remove debug info, test=doc

pull/1704/head
xiongxinlei 2 years ago
parent 48fa84bee9
commit 9c03280ca6

@@ -213,14 +213,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
        else:
-           print("offline decode from the asr")
            encoder_out, encoder_mask = self.encoder(
                speech,
                speech_lengths,
                decoding_chunk_size=decoding_chunk_size,
                num_decoding_left_chunks=num_decoding_left_chunks
            )  # (B, maxlen, encoder_dim)
-           print("offline decode success")
        return encoder_out, encoder_mask

    def recognize(
@@ -281,14 +279,13 @@ class U2BaseModel(ASRInterface, nn.Layer):
            # TODO(Hui Zhang): if end_flag.sum() == running_size:
            if end_flag.cast(paddle.int64).sum() == running_size:
                break

            # 2.1 Forward decoder step
            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                running_size, 1, 1).to(device)  # (B*N, i, i)
            # logp: (B*N, vocab)
            logp, cache = self.decoder.forward_one_step(
                encoder_out, encoder_mask, hyps, hyps_mask, cache)
            # 2.2 First beam prune: select topk best prob at current time
            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
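Aside on the unchanged beam-search code in the hunk above: the "first beam prune" keeps only the beam_size best next-token log-probabilities for each running hypothesis. A minimal standalone sketch of that topk step with toy tensors (shapes and variable names mirror the diff; nothing below is taken from the repository beyond what the hunk shows):

import paddle

# Toy shapes: B utterances, each carrying N = beam_size running hypotheses.
batch_size, beam_size, vocab_size = 2, 3, 10
running_size = batch_size * beam_size

# Fake per-hypothesis next-token log-probabilities, shape (B*N, vocab),
# standing in for the decoder output `logp` in the diff above.
logp = paddle.nn.functional.log_softmax(
    paddle.randn([running_size, vocab_size]), axis=-1)

# First beam prune: keep the beam_size best candidates per hypothesis.
top_k_logp, top_k_index = logp.topk(beam_size)  # both (B*N, beam_size)
print(top_k_logp.shape, top_k_index.shape)      # [6, 3] [6, 3]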
@@ -708,7 +705,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
            List[List[int]]: transcripts.
        """
        batch_size = feats.shape[0]
-       print("start to decode the audio feat")
        if decoding_method in ['ctc_prefix_beam_search',
                               'attention_rescoring'] and batch_size > 1:
            logger.error(
@@ -716,7 +712,6 @@ class U2BaseModel(ASRInterface, nn.Layer):
            )
            logger.error(f"current batch_size is {batch_size}")
            sys.exit(1)
-       print(f"use the {decoding_method} to decode the audio feat")
        if decoding_method == 'attention':
            hyps = self.recognize(
                feats,
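The four removed lines were ad-hoc print() calls. The surrounding code already routes messages through a module-level logger (see logger.error in the hunks above), so if similar tracing is ever needed again it can go through a logger at DEBUG level instead of print. A small sketch with Python's standard logging module (the helper name below is illustrative only, not part of this commit or of PaddleSpeech):

import logging

logger = logging.getLogger(__name__)

def log_decode_start(decoding_method: str, batch_size: int) -> None:
    # Same information as the removed prints, but silenced unless the
    # effective log level is lowered to DEBUG.
    logger.debug("start to decode the audio feat (batch_size=%d)", batch_size)
    logger.debug("use the %s to decode the audio feat", decoding_method)

logging.basicConfig(level=logging.DEBUG)  # enable DEBUG output for the demo
log_decode_start("attention_rescoring", 1)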
