|
|
|
@ -507,16 +507,14 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
num_decoding_left_chunks, simulate_streaming)
|
|
|
|
|
return hyps[0][0]
|
|
|
|
|
|
|
|
|
|
def attention_rescoring(
|
|
|
|
|
self,
|
|
|
|
|
def attention_rescoring(self,
|
|
|
|
|
speech: paddle.Tensor,
|
|
|
|
|
speech_lengths: paddle.Tensor,
|
|
|
|
|
beam_size: int,
|
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
|
simulate_streaming: bool=False,
|
|
|
|
|
reverse_weight: float=0.0, ) -> List[int]:
|
|
|
|
|
simulate_streaming: bool=False) -> List[int]:
|
|
|
|
|
""" Apply attention rescoring decoding, CTC prefix beam search
|
|
|
|
|
is applied first to get nbest, then we rescore the nbest on
|
|
|
|
|
attention decoder with corresponding encoder out
|
|
|
|
@ -536,7 +534,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
"""
|
|
|
|
|
assert speech.shape[0] == speech_lengths.shape[0]
|
|
|
|
|
assert decoding_chunk_size != 0
|
|
|
|
|
if reverse_weight > 0.0:
|
|
|
|
|
if self.reverse_weight > 0.0:
|
|
|
|
|
# decoder should be a bitransformer decoder if reverse_weight > 0.0
|
|
|
|
|
assert hasattr(self.decoder, 'right_decoder')
|
|
|
|
|
device = speech.place
|
|
|
|
@ -574,7 +572,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
self.eos)
|
|
|
|
|
decoder_out, r_decoder_out, _ = self.decoder(
|
|
|
|
|
encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
|
|
|
|
|
reverse_weight) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
|
self.reverse_weight) # (beam_size, max_hyps_len, vocab_size)
|
|
|
|
|
# ctc score in ln domain
|
|
|
|
|
decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
|
|
|
|
|
decoder_out = decoder_out.numpy()
|
|
|
|
@ -594,12 +592,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
score += decoder_out[i][j][w]
|
|
|
|
|
# last decoder output token is `eos`, for the last decoder input token.
|
|
|
|
|
score += decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
if reverse_weight > 0:
|
|
|
|
|
if self.reverse_weight > 0:
|
|
|
|
|
r_score = 0.0
|
|
|
|
|
for j, w in enumerate(hyp[0]):
|
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w]
|
|
|
|
|
r_score += r_decoder_out[i][len(hyp[0])][self.eos]
|
|
|
|
|
score = score * (1 - reverse_weight) + r_score * reverse_weight
|
|
|
|
|
score = score * (1 - self.reverse_weight
|
|
|
|
|
) + r_score * self.reverse_weight
|
|
|
|
|
# add ctc score (which in ln domain)
|
|
|
|
|
score += hyp[1] * ctc_weight
|
|
|
|
|
if score > best_score:
|
|
|
|
@ -748,8 +747,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
|
simulate_streaming: bool=False,
|
|
|
|
|
reverse_weight: float=0.0):
|
|
|
|
|
simulate_streaming: bool=False):
|
|
|
|
|
"""u2 decoding.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
@ -821,8 +819,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
|
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
|
ctc_weight=ctc_weight,
|
|
|
|
|
simulate_streaming=simulate_streaming,
|
|
|
|
|
reverse_weight=reverse_weight)
|
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|
|
hyps = [hyp]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Not support decoding method: {decoding_method}")
|
|
|
|
|