@@ -213,14 +213,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 num_decoding_left_chunks=num_decoding_left_chunks
             )  # (B, maxlen, encoder_dim)
         else:
-            print("offline decode from the asr")
             encoder_out, encoder_mask = self.encoder(
                 speech,
                 speech_lengths,
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks
             )  # (B, maxlen, encoder_dim)
-            print("offline decode success")
         return encoder_out, encoder_mask
 
     def recognize(
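
For orientation: this first hunk edits `_forward_encoder`, which dispatches between a simulated-streaming encoder pass and a plain offline pass. Below is a minimal sketch of the whole method, filled in around the context lines above; the signature defaults and the `simulate_streaming` branch condition are assumptions based on the usual U2 implementation, not taken from this patch.

    # Sketch of U2BaseModel._forward_encoder. Only the print statements
    # removed by the hunk are omitted; signature defaults are assumed.
    def _forward_encoder(self, speech, speech_lengths,
                         decoding_chunk_size=-1,
                         num_decoding_left_chunks=-1,
                         simulate_streaming=False):
        if simulate_streaming and decoding_chunk_size > 0:
            # streaming simulation: run the encoder chunk by chunk
            encoder_out, encoder_mask = self.encoder.forward_chunk_by_chunk(
                speech,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
            # offline: a single forward pass over the whole utterance
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        return encoder_out, encoder_mask
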
@@ -281,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
             # TODO(Hui Zhang): if end_flag.sum() == running_size:
             if end_flag.cast(paddle.int64).sum() == running_size:
                 break
 
             # 2.1 Forward decoder step
             hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                 running_size, 1, 1).to(device)  # (B*N, i, i)
             # logp: (B*N, vocab)
             logp, cache = self.decoder.forward_one_step(
                 encoder_out, encoder_mask, hyps, hyps_mask, cache)
 
             # 2.2 First beam prune: select topk best prob at current time
             top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
             top_k_logp = mask_finished_scores(top_k_logp, end_flag)
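
The decoder step in the hunk above leans on two mask helpers, `subsequent_mask` and `mask_finished_scores`. Here is a rough NumPy sketch of what they compute, assuming WeNet-style semantics; the real utilities operate on `paddle.Tensor`.

    import numpy as np

    def subsequent_mask(size):
        # Causal mask: position i may only attend to positions <= i.
        return np.tril(np.ones((size, size), dtype=bool))

    def mask_finished_scores(scores, end_flag):
        # For beams that already emitted <eos> (end_flag == 1), keep only
        # the first expansion at score 0 and push the rest to -inf, so a
        # finished hypothesis's total score stays frozen through topk.
        frozen = np.zeros_like(scores)
        frozen[:, 1:] = -np.inf
        finished = end_flag.astype(bool).reshape(-1, 1)  # (B*N, 1)
        return np.where(finished, frozen, scores)
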
@@ -708,7 +705,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
             List[List[int]]: transcripts.
         """
         batch_size = feats.shape[0]
-        print("start to decode the audio feat")
         if decoding_method in ['ctc_prefix_beam_search',
                                'attention_rescoring'] and batch_size > 1:
             logger.error(
@@ -716,7 +712,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
             )
             logger.error(f"current batch_size is {batch_size}")
             sys.exit(1)
-        print(f"use the {decoding_method} to decode the audio feat")
         if decoding_method == 'attention':
             hyps = self.recognize(
                 feats,
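
Taken together, the hunks only strip temporary `print` tracing from the encoder forward and the `decode` dispatch. If that tracing is still wanted, one possible alternative (not part of this patch) is to keep it behind the module's existing `logger` at debug level, so it can be switched off without editing the code:

    logger.debug("start to decode the audio feat")
    logger.debug(f"use the {decoding_method} to decode the audio feat")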