remove debug info, test=doc

pull/1704/head
xiongxinlei 2 years ago
parent 48fa84bee9
commit 9c03280ca6

@ -213,14 +213,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
else:
print("offline decode from the asr")
encoder_out, encoder_mask = self.encoder(
speech,
speech_lengths,
decoding_chunk_size=decoding_chunk_size,
num_decoding_left_chunks=num_decoding_left_chunks
) # (B, maxlen, encoder_dim)
print("offline decode success")
return encoder_out, encoder_mask
def recognize(
@ -281,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
# logp: (B*N, vocab)
logp, cache = self.decoder.forward_one_step(
encoder_out, encoder_mask, hyps, hyps_mask, cache)
# 2.2 First beam prune: select topk best prob at current time
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
@ -708,7 +705,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
List[List[int]]: transcripts.
"""
batch_size = feats.shape[0]
print("start to decode the audio feat")
if decoding_method in ['ctc_prefix_beam_search',
'attention_rescoring'] and batch_size > 1:
logger.error(
@ -716,7 +712,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
)
logger.error(f"current batch_size is {batch_size}")
sys.exit(1)
print(f"use the {decoding_method} to decode the audio feat")
if decoding_method == 'attention':
hyps = self.recognize(
feats,

Loading…
Cancel
Save