From 260752aa2a3284a37c06b88da2fef3b6d0118280 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 19 Sep 2022 14:10:16 +0000 Subject: [PATCH] using forward_attention_decoder --- paddlespeech/s2t/exps/u2/bin/test_wav.py | 8 +++----- paddlespeech/s2t/models/u2/u2.py | 14 ++++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index c04e3ae4..a55a1eca 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -69,8 +69,7 @@ class U2Infer(): with paddle.no_grad(): # read audio, sample_rate = soundfile.read( - self.audio_file, dtype="int16", always_2d=True) - + self.audio_file, dtype="int16", always_2d=True) audio = audio[:, 0] logger.info(f"audio shape: {audio.shape}") @@ -78,11 +77,10 @@ class U2Infer(): feat = self.preprocessing(audio, **self.preprocess_args) logger.info(f"feat shape: {feat.shape}") - np.savetxt("feat.transform.txt", feat) - ilen = paddle.to_tensor(feat.shape[0]) - xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) + xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) decode_config = self.config.decode + logger.debug(f"decode cfg: {decode_config}") result_transcripts = self.model.decode( xs, ilen, diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index d7b8630a..b4ec6b03 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -545,17 +545,11 @@ class U2BaseModel(ASRInterface, nn.Layer): [len(hyp[0]) for hyp in hyps], place=device, dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) + logger.debug(f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}") hyps_lens = hyps_lens + 1 # Add at begining - encoder_out = encoder_out.repeat(beam_size, 1, 1) - encoder_mask = paddle.ones( - (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - decoder_out, _ = self.decoder( - encoder_out, encoder_mask, hyps_pad, - hyps_lens) # (beam_size, max_hyps_len, vocab_size) # ctc score in ln domain - decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - decoder_out = decoder_out.numpy() + decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, encoder_out) # Only use decoder score for rescoring best_score = -float('inf') @@ -567,11 +561,15 @@ class U2BaseModel(ASRInterface, nn.Layer): score += decoder_out[i][j][w] # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.eos] + logger.debug(f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}") + # add ctc score (which in ln domain) score += hyp[1] * ctc_weight if score > best_score: best_score = score best_index = i + + logger.debug(f"result: {hyps[best_index]}") return hyps[best_index][0] @jit.to_static(property=True)