using forward_attention_decoder

pull/2212/head
Hui Zhang 2 years ago
parent 0d7d87120b
commit 260752aa2a

@ -70,7 +70,6 @@ class U2Infer():
# read # read
audio, sample_rate = soundfile.read( audio, sample_rate = soundfile.read(
self.audio_file, dtype="int16", always_2d=True) self.audio_file, dtype="int16", always_2d=True)
audio = audio[:, 0] audio = audio[:, 0]
logger.info(f"audio shape: {audio.shape}") logger.info(f"audio shape: {audio.shape}")
@ -78,11 +77,10 @@ class U2Infer():
feat = self.preprocessing(audio, **self.preprocess_args) feat = self.preprocessing(audio, **self.preprocess_args)
logger.info(f"feat shape: {feat.shape}") logger.info(f"feat shape: {feat.shape}")
np.savetxt("feat.transform.txt", feat)
ilen = paddle.to_tensor(feat.shape[0]) ilen = paddle.to_tensor(feat.shape[0])
xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(axis=0) xs = paddle.to_tensor(feat, dtype='float32').unsqueeze(0)
decode_config = self.config.decode decode_config = self.config.decode
logger.debug(f"decode cfg: {decode_config}")
result_transcripts = self.model.decode( result_transcripts = self.model.decode(
xs, xs,
ilen, ilen,

@ -545,17 +545,11 @@ class U2BaseModel(ASRInterface, nn.Layer):
[len(hyp[0]) for hyp in hyps], place=device, [len(hyp[0]) for hyp in hyps], place=device,
dtype=paddle.long) # (beam_size,) dtype=paddle.long) # (beam_size,)
hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
logger.debug(f"hyps pad: {hyps_pad} {self.sos} {self.eos} {self.ignore_id}")
hyps_lens = hyps_lens + 1 # Add <sos> at begining hyps_lens = hyps_lens + 1 # Add <sos> at begining
encoder_out = encoder_out.repeat(beam_size, 1, 1)
encoder_mask = paddle.ones(
(beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps_pad,
hyps_lens) # (beam_size, max_hyps_len, vocab_size)
# ctc score in ln domain # ctc score in ln domain
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) decoder_out = self.forward_attention_decoder(hyps_pad, hyps_lens, encoder_out)
decoder_out = decoder_out.numpy()
# Only use decoder score for rescoring # Only use decoder score for rescoring
best_score = -float('inf') best_score = -float('inf')
@ -567,11 +561,15 @@ class U2BaseModel(ASRInterface, nn.Layer):
score += decoder_out[i][j][w] score += decoder_out[i][j][w]
# last decoder output token is `eos`, for last decoder input token. # last decoder output token is `eos`, for last decoder input token.
score += decoder_out[i][len(hyp[0])][self.eos] score += decoder_out[i][len(hyp[0])][self.eos]
logger.debug(f"hyp {i} len {len(hyp[0])} l2r rescore_score: {score} ctc_score: {hyp[1]}")
# add ctc score (which in ln domain) # add ctc score (which in ln domain)
score += hyp[1] * ctc_weight score += hyp[1] * ctc_weight
if score > best_score: if score > best_score:
best_score = score best_score = score
best_index = i best_index = i
logger.debug(f"result: {hyps[best_index]}")
return hyps[best_index][0] return hyps[best_index][0]
@jit.to_static(property=True) @jit.to_static(property=True)

Loading…
Cancel
Save