|
|
|
@ -565,18 +565,18 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
[len(hyp[0]) for hyp in hyps], place=device,
|
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
|
logger.debug(
|
|
|
|
|
logger.info(
|
|
|
|
|
f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
|
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
|
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
|
# (beam_size, max_hyps_len, vocab_size)
|
|
|
|
|
decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
|
|
|
|
|
encoder_out,reverse_weight )
|
|
|
|
|
encoder_out, reverse_weight)
|
|
|
|
|
|
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
|
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
|
|
|
|
|
# conventional transformer decoder.
|
|
|
|
|
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
|
|
|
|
|
r_decoder_out = r_decoder_out.numpy()
|
|
|
|
|
|
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
@ -590,15 +590,16 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}"
|
|
|
|
|
)
|
|
|
|
|
logger.info(f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
|
|
|
|
|
|
|
|
|
|
if reverse_weight > 0:
|
|
|
|
|
r_score = 0.0
|
|
|
|
|
for j, w in enumerate(hyp[0]):
|
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
|
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
|
|
logger.info(f"hyp {i} len {len(hyp[0])} r2l score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
|
|
|
|
|
|
|
|
|
|
score = score * (1 - reverse_weight) + r_score * reverse_weight
|
|
|
|
|
|
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
@ -607,7 +608,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
best_score = score
|
|
|
|
|
best_index = i
|
|
|
|
|
|
|
|
|
|
logger.debug(f"result: {hyps[best_index]}")
|
|
|
|
|
logger.info(f"result: {hyps[best_index]}")
|
|
|
|
|
return hyps[best_index][0]
|
|
|
|
|
|
|
|
|
|
@jit.to_static(property=True)
|
|
|
|
|