|
|
@ -315,6 +315,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
speech: paddle.Tensor,
|
|
|
|
speech: paddle.Tensor,
|
|
|
|
speech_lengths: paddle.Tensor,
|
|
|
|
speech_lengths: paddle.Tensor,
|
|
|
|
beam_size: int=10,
|
|
|
|
beam_size: int=10,
|
|
|
|
|
|
|
|
word_reward: float=0.0,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
simulate_streaming: bool=False, ) -> paddle.Tensor:
|
|
|
|
simulate_streaming: bool=False, ) -> paddle.Tensor:
|
|
|
@ -378,6 +379,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
# 2.2 First beam prune: select topk best prob at current time
|
|
|
|
# 2.2 First beam prune: select topk best prob at current time
|
|
|
|
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
|
|
|
|
top_k_logp, top_k_index = logp.topk(beam_size) # (B*N, N)
|
|
|
|
|
|
|
|
top_k_logp += word_reward
|
|
|
|
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
|
|
|
|
top_k_logp = mask_finished_scores(top_k_logp, end_flag)
|
|
|
|
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
|
|
|
|
top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
|
|
|
|
|
|
|
|
|
|
|
@ -528,6 +530,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
cutoff_top_n: int,
|
|
|
|
cutoff_top_n: int,
|
|
|
|
num_processes: int,
|
|
|
|
num_processes: int,
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
ctc_weight: float=0.0,
|
|
|
|
|
|
|
|
word_reward: float=0.0,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
decoding_chunk_size: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
num_decoding_left_chunks: int=-1,
|
|
|
|
simulate_streaming: bool=False):
|
|
|
|
simulate_streaming: bool=False):
|
|
|
@ -569,6 +572,7 @@ class U2STBaseModel(nn.Layer):
|
|
|
|
feats,
|
|
|
|
feats,
|
|
|
|
feats_lengths,
|
|
|
|
feats_lengths,
|
|
|
|
beam_size=beam_size,
|
|
|
|
beam_size=beam_size,
|
|
|
|
|
|
|
|
word_reward=word_reward,
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
decoding_chunk_size=decoding_chunk_size,
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
num_decoding_left_chunks=num_decoding_left_chunks,
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|
simulate_streaming=simulate_streaming)
|
|
|
|