|
|
@ -571,8 +571,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
# ctc score in ln domain
|
|
|
|
# (beam_size, max_hyps_len, vocab_size)
|
|
|
|
# (beam_size, max_hyps_len, vocab_size)
|
|
|
|
decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
|
|
|
|
decoder_out, r_decoder_out = self.forward_attention_decoder(
|
|
|
|
encoder_out, reverse_weight)
|
|
|
|
hyps_pad, hyps_lens, encoder_out, reverse_weight)
|
|
|
|
|
|
|
|
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
|
|
|
|
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
|
|
|
@ -590,7 +590,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
# last decoder output token is `eos`, for the last decoder input token.
|
|
|
|
# last decoder output token is `eos`, for the last decoder input token.
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
|
|
|
|
logger.debug(
|
|
|
|
|
|
|
|
f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
if reverse_weight > 0:
|
|
|
|
if reverse_weight > 0:
|
|
|
|
r_score = 0.0
|
|
|
|
r_score = 0.0
|
|
|
@ -598,7 +600,9 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
|
|
|
|
logger.info(
|
|
|
|
|
|
|
|
f"hyp {i} len {len(hyp[0])} r2l score: {r_score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
score = score * (1 - reverse_weight) + r_score * reverse_weight
|
|
|
|
score = score * (1 - reverse_weight) + r_score * reverse_weight
|
|
|
|
|
|
|
|
|
|
|
@ -702,8 +706,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
return self.ctc.log_softmax(xs)
|
|
|
|
return self.ctc.log_softmax(xs)
|
|
|
|
|
|
|
|
|
|
|
|
# @jit.to_static
|
|
|
|
# @jit.to_static
|
|
|
|
def forward_attention_decoder(
|
|
|
|
def forward_attention_decoder(self,
|
|
|
|
self,
|
|
|
|
|
|
|
|
hyps: paddle.Tensor,
|
|
|
|
hyps: paddle.Tensor,
|
|
|
|
hyps_lens: paddle.Tensor,
|
|
|
|
hyps_lens: paddle.Tensor,
|
|
|
|
encoder_out: paddle.Tensor,
|
|
|
|
encoder_out: paddle.Tensor,
|
|
|
|