fix forward attention decoder caller

pull/2425/head
Hui Zhang 2 years ago
parent 309c8d70d9
commit 00b2c1c8fb

@ -79,7 +79,7 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode
logger.debug(f"decode cfg: {decode_config}")
logger.info(f"decode cfg: {decode_config}")
result_transcripts = self.model.decode(
xs,
ilen,

@ -565,18 +565,18 @@ class U2BaseModel(ASRInterface, nn.Layer):
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
logger.debug(
logger.info(
f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
hyps_lens = hyps_lens + 1 # Add <sos> at begining
# ctc score in ln domain
# (beam_size, max_hyps_len, vocab_size)
decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
encoder_out,reverse_weight )
encoder_out, reverse_weight)
decoder_out = decoder_out.numpy()
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy()
# Only use decoder score for rescoring
@ -590,15 +590,16 @@ class U2BaseModel(ASRInterface, nn.Layer):
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
logger.debug(
f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}"
)
logger.info(f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
if reverse_weight > 0:
r_score = 0.0
for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
logger.info(f"hyp {i} len {len(hyp[0])} r2l score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score (which in ln domain)
@ -607,7 +608,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
best_score = score
best_index = i
logger.debug(f"result: {hyps[best_index]}")
logger.info(f"result: {hyps[best_index]}")
return hyps[best_index][0]
@jit.to_static(property=True)

@ -343,7 +343,7 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
"""
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens)
r_x = paddle.to_tensor(0.0)
r_x = paddle.zeros([1])
if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad,
ys_in_lens)

Loading…
Cancel
Save