add word reward to beam search.

pull/983/head
Junkun 3 years ago
parent e5edc83a43
commit 7c8843448c

@@ -99,6 +99,7 @@ decoding:
   alpha: 2.5
   beta: 0.3
   beam_size: 10
+  word_reward: 0.7
   cutoff_prob: 1.0
   cutoff_top_n: 0
   num_proc_bsearch: 8
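For context on the new key: beam-search scores are sums of per-token log-probabilities, which are negative, so the raw score systematically favors short hypotheses; a positive word_reward added for each emitted token offsets that bias. A quick illustration in plain Python (not the decoder code itself; the token log-probabilities below are made up, and 0.7 is simply the value this config chooses):

```python
# Toy log-probabilities for a short and a long candidate translation (made up).
short_hyp = [-1.0, -0.5]                 # 2 tokens, sum = -1.5
long_hyp = [-1.0, -0.5, -0.6, -0.6]      # 4 tokens, sum = -2.7

def hyp_score(token_logps, word_reward=0.0):
    # Over a full hypothesis, the per-step `top_k_logp += word_reward` in the
    # diff adds word_reward once per emitted token, i.e. word_reward * length.
    return sum(token_logps) + word_reward * len(token_logps)

print(hyp_score(short_hyp), hyp_score(long_hyp))              # -1.5  -2.7  -> short wins
print(hyp_score(short_hyp, 0.7), hyp_score(long_hyp, 0.7))    # ~-0.1  ~0.1 -> long wins
```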

@@ -441,10 +441,7 @@ class U2STTester(U2STTrainer):
             "".join(chr(t) for t in text[:text_len])
             for text, text_len in zip(texts, texts_len)
         ]
-        # from IPython import embed
-        # import os
-        # embed()
-        # os._exit(0)
         hyps = self.model.decode(
             audio,
             audio_len,
@@ -458,6 +455,7 @@ class U2STTester(U2STTrainer):
             cutoff_top_n=cfg.cutoff_top_n,
             num_processes=cfg.num_proc_bsearch,
             ctc_weight=cfg.ctc_weight,
+            word_reward=cfg.word_reward,
             decoding_chunk_size=cfg.decoding_chunk_size,
             num_decoding_left_chunks=cfg.num_decoding_left_chunks,
             simulate_streaming=cfg.simulate_streaming)

@@ -315,6 +315,7 @@ class U2STBaseModel(nn.Layer):
                   speech: paddle.Tensor,
                   speech_lengths: paddle.Tensor,
                   beam_size: int=10,
+                  word_reward: float=0.0,
                   decoding_chunk_size: int=-1,
                   num_decoding_left_chunks: int=-1,
                   simulate_streaming: bool=False, ) -> paddle.Tensor:
@@ -378,6 +379,7 @@ class U2STBaseModel(nn.Layer):
             # 2.2 First beam prune: select topk best prob at current time
             top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
+            top_k_logp += word_reward
             top_k_logp = mask_finished_scores(top_k_logp, end_flag)
             top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
@@ -528,6 +530,7 @@ class U2STBaseModel(nn.Layer):
                cutoff_top_n: int,
                num_processes: int,
                ctc_weight: float=0.0,
+               word_reward: float=0.0,
                decoding_chunk_size: int=-1,
                num_decoding_left_chunks: int=-1,
                simulate_streaming: bool=False):
@@ -569,6 +572,7 @@ class U2STBaseModel(nn.Layer):
             feats,
             feats_lengths,
             beam_size=beam_size,
+            word_reward=word_reward,
             decoding_chunk_size=decoding_chunk_size,
             num_decoding_left_chunks=num_decoding_left_chunks,
             simulate_streaming=simulate_streaming)
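To make the scoring change concrete, here is a minimal NumPy sketch of one pruning step of the attention beam search with the word reward folded in. It mirrors the hunk above but is a simplification, not the PaddleSpeech implementation: the masking logic, shapes, and toy inputs are assumptions standing in for mask_finished_scores / mask_finished_preds.

```python
import numpy as np

def beam_step(logp, running_scores, end_flag, beam_size, word_reward=0.0, eos=2):
    """One pruning step of attention beam search with a per-token word reward.

    Simplified stand-in for the step changed above
    (`top_k_logp += word_reward` in U2STBaseModel.translate).
    """
    # logp: (num_hyps, vocab_size) log-probabilities for the next token.
    top_k_index = np.argsort(-logp, axis=1)[:, :beam_size]       # (num_hyps, beam_size)
    top_k_logp = np.take_along_axis(logp, top_k_index, axis=1)   # (num_hyps, beam_size)

    # Reward every newly emitted token. Summed over a hypothesis this adds
    # word_reward * length, countering the bias toward short outputs.
    top_k_logp = top_k_logp + word_reward

    # Finished hypotheses must not keep collecting probability (or reward):
    # keep one copy with zero increment, push the rest to -inf, and force eos.
    # Because this masking runs after the reward is added, finished beams are
    # not rewarded further -- same ordering as in the hunk above.
    top_k_logp[end_flag, 0] = 0.0
    top_k_logp[end_flag, 1:] = -np.inf
    top_k_index[end_flag] = eos

    return running_scores[:, None] + top_k_logp, top_k_index

# Toy example: 2 live hypotheses, vocab of 4, beam size 2, second one finished.
logp = np.log(np.array([[0.1, 0.2, 0.3, 0.4],
                        [0.25, 0.25, 0.25, 0.25]]))
scores, ids = beam_step(logp,
                        running_scores=np.array([-1.2, -0.8]),
                        end_flag=np.array([False, True]),
                        beam_size=2,
                        word_reward=0.7)
print(scores)   # candidate scores after the rewarded step
print(ids)      # chosen token ids (eos for the finished hypothesis)
```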
