fix forward attention decoder caller

pull/2425/head
Hui Zhang 2 years ago
parent 309c8d70d9
commit 00b2c1c8fb

@ -79,7 +79,7 @@ class U2Infer():
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode decode_config = self.config.decode
logger.debug(f"decode cfg: {decode_config}") logger.info(f"decode cfg: {decode_config}")
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
xs, xs,
ilen, ilen,

@ -565,18 +565,18 @@ class U2BaseModel(ASRInterface, nn.Layer):
[len(hyp[0]) for hyp in hyps], place=device, [len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,) dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
logger.debug( logger.info(
f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}") f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
hyps_lens = hyps_lens + 1 # Add <sos> at begining hyps_lens = hyps_lens + 1 # Add <sos> at begining
# ctc score in ln domain # ctc score in ln domain
# (beam_size, max_hyps_len, vocab_size) # (beam_size, max_hyps_len, vocab_size)
decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, decoder_out, r_decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens,
encoder_out,reverse_weight ) encoder_out, reverse_weight)
decoder_out = decoder_out.numpy()
# r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
# conventional transformer decoder. # conventional transformer decoder.
r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
r_decoder_out = r_decoder_out.numpy() r_decoder_out = r_decoder_out.numpy()
# Only use decoder score for rescoring # Only use decoder score for rescoring
@ -590,15 +590,16 @@ class U2BaseModel(ASRInterface, nn.Layer):
# last decoder output token is `eos`, for laste decoder input token. # last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos] score += decoder_out[i][len(hyp[0])][self.eos]
logger.debug( logger.info(f"hyp {i} len {len(hyp[0])} l2r score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}"
)
if reverse_weight > 0: if reverse_weight > 0:
r_score = 0.0 r_score = 0.0
for j, w in enumerate(hyp[0]): for j, w in enumerate(hyp[0]):
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
r_score += r_decoder_out[i][len(hyp[0])][self.eos] r_score += r_decoder_out[i][len(hyp[0])][self.eos]
logger.info(f"hyp {i} len {len(hyp[0])} r2l score: {score} ctc_score: {hyp[1]} reverse_weight: {reverse_weight}")
score = score * (1 - reverse_weight) + r_score * reverse_weight score = score * (1 - reverse_weight) + r_score * reverse_weight
# add ctc score (which in ln domain) # add ctc score (which in ln domain)
@ -607,7 +608,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
best_score = score best_score = score
best_index = i best_index = i
logger.debug(f"result: {hyps[best_index]}") logger.info(f"result: {hyps[best_index]}")
return hyps[best_index][0] return hyps[best_index][0]
@jit.to_static(property=True) @jit.to_static(property=True)

@ -343,7 +343,7 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer):
""" """
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
ys_in_lens) ys_in_lens)
r_x = paddle.to_tensor(0.0) r_x = paddle.zeros([1])
if reverse_weight > 0.0: if reverse_weight > 0.0:
r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad,
ys_in_lens) ys_in_lens)

Loading…
Cancel
Save