From 00b2c1c8fb4fc81e723e8580cbc7ed6059378680 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 21 Sep 2022 07:50:02 +0000 Subject: [PATCH] fix forward attention decoder caller --- paddlespeech/s2t/exps/u2/bin/test_wav.py | 2 +- paddlespeech/s2t/models/u2/u2.py | 15 ++++++++------- paddlespeech/s2t/modules/decoder.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 9446884f..31890cb1 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -79,7 +79,7 @@ class U2Infer(): ilen = paddle.to_tensor(feat.shape[0]) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) decode_config = self.config.decode - logger.debug(f"decode cfg: {decode_config}") + logger.info(f"decode cfg: {decode_config}") result_transcripts = self.model.decode( xs, ilen, diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 1681bf1d..7609b71e 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -565,18 +565,18 @@ class U2BaseModel(ASRInterface, nn.Layer): [len(hyp[0]) for hyp in hyps], place=device, dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) - logger.debug( + logger.info( f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}") hyps_lens = hyps_lens + 1 # Add at begining # ctc score in ln domain # (beam_size, max_hyps_len, vocab_size) decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, - encoder_out,reverse_weight ) + encoder_out, reverse_weight) + decoder_out = decoder_out.numpy() # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # conventional transformer decoder. - r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1) r_decoder_out = r_decoder_out.numpy() # Only use decoder score for rescoring @@ -590,15 +590,16 @@ class U2BaseModel(ASRInterface, nn.Layer): # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.eos] - logger.debug( - f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}" - ) + logger.info(f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}") if reverse_weight > 0: r_score = 0.0 for j, w in enumerate(hyp[0]): r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] r_score += r_decoder_out[i][len(hyp[0])][self.eos] + + logger.info(f"hyp {i} len {len(hyp[0])} r2l score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}") + score = score * (1 - reverse_weight) + r_score * reverse_weight # add ctc score (which in ln domain) @@ -607,7 +608,7 @@ class U2BaseModel(ASRInterface, nn.Layer): best_score = score best_index = i - logger.debug(f"result: {hyps[best_index]}") + logger.info(f"result: {hyps[best_index]}") return hyps[best_index][0] @jit.to_static(property=True) diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 3b1a7f23..03b637b7 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -343,7 +343,7 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer): """ l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, ys_in_lens) - r_x = paddle.to_tensor(0.0) + r_x = paddle.zeros([1]) if reverse_weight > 0.0: r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, ys_in_lens)