using forward_attention_decoder

pull/2212/head
Hui Zhang 2 years ago
parent 0d7d87120b
commit 260752aa2a

@ -69,8 +69,7 @@ class U2Infer():
with paddle.no_grad():
# read
audio, sample_rate = soundfile.read(
self.audio_file, dtype="int16", always_2d=True)
self.audio_file, dtype="int16", always_2d=True)
audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}")
@ -78,11 +77,10 @@ class U2Infer():
feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}")
np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0)
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode
logger.debug(f"decode cfg: {decode_config}")
result_transcripts = self.model.decode(
xs,
ilen,

@ -545,17 +545,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
[len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
logger.debug(f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
decoder_out = decoder_out.numpy()
decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, encoder_out)
# Only use decoder score for rescoring
best_score = -float('inf')
@ -567,11 +561,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
score += decoder_out[i][j][w]
# last decoder output token is `eos`, for laste decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos]
logger.debug(f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}")
# add ctc score (which in ln domain)
score += hyp[1] * ctc_weight
if score > best_score:
best_score = score
best_index = i
logger.debug(f"result: {hyps[best_index]}")
return hyps[best_index][0]
@jit.to_static(property=True)

Loading…
Cancel
Save