diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py
index 51b72209d..4588def0b 100644
--- a/paddlespeech/s2t/exps/u2/bin/test_wav.py
+++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py
@@ -40,7 +40,7 @@ class U2Infer():
         self.preprocess_conf = config.preprocess_config
         self.preprocess_args = {"train": False}
         self.preprocessing = Transformation(self.preprocess_conf)
-
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)
         self.text_feature = TextFeaturizer(
             unit_type=config.unit_type,
             vocab=config.vocab_filepath,
@@ -90,7 +90,7 @@ class U2Infer():
                 decoding_chunk_size=decode_config.decoding_chunk_size,
                 num_decoding_left_chunks=decode_config.num_decoding_left_chunks,
                 simulate_streaming=decode_config.simulate_streaming,
-                reverse_weight=self.config.model_conf.reverse_weight)
+                reverse_weight=self.reverse_weight)
             rsl = result_transcripts[0][0]
             utt = Path(self.audio_file).name
             logger.info(f"hyp: {utt} {result_transcripts[0][0]}")
diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index 99a0434d5..84b7be323 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -316,7 +316,7 @@ class U2Tester(U2Trainer):
             vocab=self.config.vocab_filepath,
             spm_model_prefix=self.config.spm_model_prefix)
         self.vocab_list = self.text_feature.vocab_list
-        self.reverse_weight = getattr(config, 'reverse_weight', '0.0')
+        self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0)

     def id2token(self, texts, texts_len, text_feature):
         """ ord() id to chr() chr """
diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index 84c0e5b5e..e54a7afbc 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -689,24 +689,24 @@ class U2BaseModel(ASRInterface, nn.Layer):
         """
         return self.ctc.log_softmax(xs)

-    @jit.to_static
+    # @jit.to_static
     def is_bidirectional_decoder(self) -> bool:
         """
         Returns:
-            torch.Tensor: decoder output
+            paddle.Tensor: decoder output
         """
         if hasattr(self.decoder, 'right_decoder'):
             return True
         else:
             return False

-    @jit.to_static
+    # @jit.to_static
     def forward_attention_decoder(
             self,
             hyps: paddle.Tensor,
             hyps_lens: paddle.Tensor,
             encoder_out: paddle.Tensor,
-            reverse_weight: float=0, ) -> paddle.Tensor:
+            reverse_weight: float=0.0, ) -> paddle.Tensor:
         """ Export interface for c++ call, forward decoder with multiple
             hypothesis from ctc prefix beam search and one encoder output
         Args:
@@ -783,7 +783,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         # >>> tensor([[3, 2, 1],
         # >>>         [4, 8, 9],
         # >>>         [2, eos, eos]])
-        r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1)
+        r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
         # >>> r_hyps
         # >>> tensor([[sos, 3, 2, 1],
         # >>>         [sos, 4, 8, 9],
@@ -791,7 +791,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
         decoder_out, r_decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
-        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+        r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
         return decoder_out, r_decoder_out

     @paddle.no_grad()
diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py
index 2052a19e1..3b1a7f23d 100644
--- a/paddlespeech/s2t/modules/decoder.py
+++ b/paddlespeech/s2t/modules/decoder.py
@@ -363,9 +363,8 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
             memory: encoded memory, float32 (batch, maxlen_in, feat)
             memory_mask: encoded memory mask, (batch, 1, maxlen_in)
             tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                      dtype=torch.uint8 in PyTorch 1.2-
-                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            tgt_mask: input token mask, (batch, maxlen_out, maxlen_out)
+                      dtype=paddle.bool
             cache: cached output list of (batch, max_time_out-1, size)
         Returns:
             y, cache: NN output value and cache per `self.decoders`.
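Side note on the `getattr` fix in `U2Tester` above: the old call read `reverse_weight` from the top-level `config` rather than `config.model_conf`, and its fallback `'0.0'` was a string, so downstream arithmetic such as `1 - reverse_weight` in attention rescoring would raise. A minimal sketch of the pitfall, using `types.SimpleNamespace` as a hypothetical stand-in for the real config objects:

```python
# Sketch only: SimpleNamespace stands in for the real yacs-style config.
from types import SimpleNamespace

config = SimpleNamespace(model_conf=SimpleNamespace())  # reverse_weight unset

# Old (buggy): wrong lookup target, and the default is the *string* '0.0'.
old = getattr(config, 'reverse_weight', '0.0')
# 1 - old  # would raise TypeError: unsupported operand type(s)

# New (fixed): read from model_conf with a float default.
new = getattr(config.model_conf, 'reverse_weight', 0.0)
assert isinstance(new, float)
assert 1 - new == 1.0  # rescoring weights now combine cleanly
```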