diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 2279812b..93c5d910 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -571,8 +571,8 @@ class U2BaseModel(ASRInterface, nn.Layer): # ctc score in ln domain # (beam_size, max_hyps_len, vocab_size) - decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, - encoder_out, reverse_weight) + decoder_out, r_decoder_out = self.forward_attention_decoder( + hyps_pad, hyps_lens, encoder_out, reverse_weight) decoder_out = decoder_out.numpy() # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a @@ -590,7 +590,9 @@ class U2BaseModel(ASRInterface, nn.Layer): # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.eos] - logger.debug(f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}") + logger.debug( + f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}" + ) if reverse_weight > 0: r_score = 0.0 @@ -598,7 +600,9 @@ class U2BaseModel(ASRInterface, nn.Layer): r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] r_score += r_decoder_out[i][len(hyp[0])][self.eos] - logger.info(f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}") + logger.info( + f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}" + ) score = score * (1 - reverse_weight) + r_score * reverse_weight @@ -702,12 +706,11 @@ class U2BaseModel(ASRInterface, nn.Layer): return self.ctc.log_softmax(xs) # @jit.to_static - def forward_attention_decoder( - self, - hyps: paddle.Tensor, - hyps_lens: paddle.Tensor, - encoder_out: paddle.Tensor, - reverse_weight: float=0.0) -> paddle.Tensor: + def forward_attention_decoder(self, + hyps: paddle.Tensor, + hyps_lens: paddle.Tensor, + encoder_out: paddle.Tensor, + reverse_weight: float=0.0) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: