|
|
@ -545,17 +545,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
[len(hyp[0]) for hyp in hyps], place=device,
|
|
|
|
[len(hyp[0]) for hyp in hyps], place=device,
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
dtype=paddle.long) # (beam_size,)
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
|
|
|
|
|
|
|
|
logger.debug(f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
hyps_lens = hyps_lens + 1 # Add <sos> at begining
|
|
|
|
|
|
|
|
|
|
|
|
encoder_out = encoder_out.repeat(beam_size, 1, 1)
|
|
|
|
|
|
|
|
encoder_mask = paddle.ones(
|
|
|
|
|
|
|
|
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
|
|
|
|
|
|
|
|
decoder_out, _ = self.decoder(
|
|
|
|
|
|
|
|
encoder_out, encoder_mask, hyps_pad,
|
|
|
|
|
|
|
|
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
# ctc score in ln domain
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, encoder_out)
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
# Only use decoder score for rescoring
|
|
|
|
best_score = -float('inf')
|
|
|
|
best_score = -float('inf')
|
|
|
@ -567,11 +561,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
# last decoder output token is `eos`, for laste decoder input token.
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
|
|
|
logger.debug(f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}")
|
|
|
|
|
|
|
|
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
if score > best_score:
|
|
|
|
if score > best_score:
|
|
|
|
best_score = score
|
|
|
|
best_score = score
|
|
|
|
best_index = i
|
|
|
|
best_index = i
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.debug(f"result: {hyps[best_index]}")
|
|
|
|
return hyps[best_index][0]
|
|
|
|
return hyps[best_index][0]
|
|
|
|
|
|
|
|
|
|
|
|
@jit.to_static(property=True)
|
|
|
|
@jit.to_static(property=True)
|
|
|
|