From 43aad7a01873489daeb5f620f50b1ef3012702d9 Mon Sep 17 00:00:00 2001
From: Junkun
Date: Fri, 14 Jan 2022 13:59:53 -0800
Subject: [PATCH 1/3] beam search with optimality guarantees

---
 paddlespeech/s2t/exps/u2_st/model.py   |   4 +-
 paddlespeech/s2t/models/u2_st/u2_st.py | 145 ++++++++++++-------------
 2 files changed, 75 insertions(+), 74 deletions(-)

diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py
index b03ca38b6..b642e9337 100644
--- a/paddlespeech/s2t/exps/u2_st/model.py
+++ b/paddlespeech/s2t/exps/u2_st/model.py
@@ -285,7 +285,7 @@ class U2STTrainer(Trainer):
                 subsampling_factor=1,
                 load_aux_output=load_transcript,
                 num_encs=1,
-                dist_sampler=True)
+                dist_sampler=False)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
@@ -408,6 +408,7 @@ class U2STTester(U2STTrainer):
             decoding_method=decode_cfg.decoding_method,
             beam_size=decode_cfg.beam_size,
             word_reward=decode_cfg.word_reward,
+            maxlen_ratio=decode_cfg.maxlen_ratio,
             decoding_chunk_size=decode_cfg.decoding_chunk_size,
             num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
             simulate_streaming=decode_cfg.simulate_streaming)
@@ -435,6 +436,7 @@ class U2STTester(U2STTrainer):
             decoding_method=decode_cfg.decoding_method,
             beam_size=decode_cfg.beam_size,
             word_reward=decode_cfg.word_reward,
+            maxlen_ratio=decode_cfg.maxlen_ratio,
             decoding_chunk_size=decode_cfg.decoding_chunk_size,
             num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
             simulate_streaming=decode_cfg.simulate_streaming)
diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py
index 79ca423f8..211813f63 100644
--- a/paddlespeech/s2t/models/u2_st/u2_st.py
+++ b/paddlespeech/s2t/models/u2_st/u2_st.py
@@ -264,14 +264,17 @@ class U2STBaseModel(nn.Layer):
             speech_lengths: paddle.Tensor,
             beam_size: int=10,
             word_reward: float=0.0,
+            maxlen_ratio: float=0.5,
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False, ) -> paddle.Tensor:
         """ Apply beam search on attention decoder with length penalty
         Args:
             speech (paddle.Tensor): (batch, max_len, feat_dim)
             speech_length (paddle.Tensor): (batch, )
             beam_size (int): beam size for beam search
+            word_reward (float): word reward used in beam search
+            maxlen_ratio (float): max length ratio to bound the length of translated text
             decoding_chunk_size (int): decoding chunk for dynamic chunk
                 trained model.
                 <0: for decoding, use full chunk.
@@ -284,90 +287,84 @@ class U2STBaseModel(nn.Layer):
         """
         assert speech.shape[0] == speech_lengths.shape[0]
         assert decoding_chunk_size != 0
+        assert speech.shape[0] == 1
         device = speech.place
-        batch_size = speech.shape[0]
         # Let's assume B = batch_size and N = beam_size
-        # 1. Encoder
+        # 1. Encoder and init hypothesis
         encoder_out, encoder_mask = self._forward_encoder(
             speech, speech_lengths, decoding_chunk_size,
             num_decoding_left_chunks,
             simulate_streaming)  # (B, maxlen, encoder_dim)
-        maxlen = encoder_out.shape[1]
-        encoder_dim = encoder_out.shape[2]
-        running_size = batch_size * beam_size
-        encoder_out = encoder_out.unsqueeze(1).repeat(1, beam_size, 1, 1).view(
-            running_size, maxlen, encoder_dim)  # (B*N, maxlen, encoder_dim)
-        encoder_mask = encoder_mask.unsqueeze(1).repeat(
-            1, beam_size, 1, 1).view(running_size, 1,
-                                     maxlen)  # (B*N, 1, max_len)
-
-        hyps = paddle.ones(
-            [running_size, 1], dtype=paddle.long).fill_(self.sos)  # (B*N, 1)
-        # log scale score
-        scores = paddle.to_tensor(
-            [0.0] + [-float('inf')] * (beam_size - 1), dtype=paddle.float)
-        scores = scores.to(device).repeat(batch_size).unsqueeze(1).to(
-            device)  # (B*N, 1)
-        end_flag = paddle.zeros_like(scores, dtype=paddle.bool)  # (B*N, 1)
-        cache: Optional[List[paddle.Tensor]] = None
+
+        maxlen = max(int(encoder_out.shape[1] * maxlen_ratio), 5)
+
+        hyp = {"score": 0.0, "yseq": [self.sos], "cache": None}
+        hyps = [hyp]
+        ended_hyps = []
+        cur_best_score = -float("inf")
+        cache = None
+
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
-            # Stop if all batch and all beam produce eos
-            # TODO(Hui Zhang): if end_flag.sum() == running_size:
-            if end_flag.cast(paddle.int64).sum() == running_size:
-                break
-
-            # 2.1 Forward decoder step
-            hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
-                running_size, 1, 1).to(device)  # (B*N, i, i)
-            # logp: (B*N, vocab)
+            ys = paddle.ones((len(hyps), i), dtype=paddle.long)
+
+            if hyps[0]["cache"] is not None:
+                cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
+            for j, hyp in enumerate(hyps):
+                ys[j, :] = paddle.to_tensor(hyp["yseq"])
+                if hyps[0]["cache"] is not None:
+                    for k in range(len(cache)):
+                        cache[k][j] = hyps[j]["cache"][k]
+            ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
+
             logp, cache = self.st_decoder.forward_one_step(
-                encoder_out, encoder_mask, hyps, hyps_mask, cache)
-
-            # 2.2 First beam prune: select topk best prob at current time
-            top_k_logp, top_k_index = logp.topk(beam_size)  # (B*N, N)
-            top_k_logp += word_reward
-            top_k_logp = mask_finished_scores(top_k_logp, end_flag)
-            top_k_index = mask_finished_preds(top_k_index, end_flag, self.eos)
-
-            # 2.3 Seconde beam prune: select topk score with history
-            scores = scores + top_k_logp  # (B*N, N), broadcast add
-            scores = scores.view(batch_size, beam_size * beam_size)  # (B, N*N)
-            scores, offset_k_index = scores.topk(k=beam_size)  # (B, N)
-            scores = scores.view(-1, 1)  # (B*N, 1)
-
-            # 2.4. Compute base index in top_k_index,
-            # regard top_k_index as (B*N*N),regard offset_k_index as (B*N),
-            # then find offset_k_index in top_k_index
-            base_k_index = paddle.arange(batch_size).view(-1, 1).repeat(
-                1, beam_size)  # (B, N)
-            base_k_index = base_k_index * beam_size * beam_size
-            best_k_index = base_k_index.view(-1) + offset_k_index.view(
-                -1)  # (B*N)
-
-            # 2.5 Update best hyps
-            best_k_pred = paddle.index_select(
-                top_k_index.view(-1), index=best_k_index, axis=0)  # (B*N)
-            best_hyps_index = best_k_index // beam_size
-            last_best_k_hyps = paddle.index_select(
-                hyps, index=best_hyps_index, axis=0)  # (B*N, i)
-            hyps = paddle.cat(
-                (last_best_k_hyps, best_k_pred.view(-1, 1)),
-                dim=1)  # (B*N, i+1)
-
-            # 2.6 Update end flag
-            end_flag = paddle.eq(hyps[:, -1], self.eos).view(-1, 1)
+            encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
+
+            hyps_best_kept = []
+            for j, hyp in enumerate(hyps):
+                top_k_logp, top_k_index = logp[j : j + 1].topk(beam_size)
+
+                for b in range(beam_size):
+                    new_hyp = {}
+                    new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
+                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
+                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
+                    new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
+                    new_hyp["cache"] = [cache_[j] for cache_ in cache]
+                    # will be (2 x beam) hyps at most
+                    hyps_best_kept.append(new_hyp)
+
+                hyps_best_kept = sorted(
+                    hyps_best_kept, key=lambda x: -x["score"])[:beam_size]
+            
+            # sort and get nbest
+            hyps = hyps_best_kept
+            if i == maxlen:
+                for hyp in hyps:
+                    hyp["yseq"].append(self.eos)
+            
+            # finalize the ended hypotheses with word reward (by length)
+            remained_hyps = []
+            for hyp in hyps:
+                if hyp["yseq"][-1] == self.eos:
+                    hyp["score"] += (i - 1) * word_reward
+                    cur_best_score = max(cur_best_score, hyp["score"])
+                    ended_hyps.append(hyp)
+                else:
+                    # keep searching only while optimality can be guaranteed
+                    if hyp["score"] + maxlen * word_reward > cur_best_score:
+                        remained_hyps.append(hyp)
+            
+            # stop prediction when there is no unended hypothesis
+            if not remained_hyps:
+                break
+            hyps = remained_hyps
 
         # 3. Select best of best
-        scores = scores.view(batch_size, beam_size)
-        # TODO: length normalization
-        best_index = paddle.argmax(scores, axis=-1).long()  # (B)
-        best_hyps_index = best_index + paddle.arange(
-            batch_size, dtype=paddle.long) * beam_size
-        best_hyps = paddle.index_select(hyps, index=best_hyps_index, axis=0)
-        best_hyps = best_hyps[:, 1:]
-        return best_hyps
+        best_hyp = max(ended_hyps, key=lambda x: x["score"])
+
+        return paddle.to_tensor([best_hyp["yseq"][1:]]) 
 
     # @jit.to_static
     def subsampling_rate(self) -> int:
@@ -472,6 +469,7 @@ class U2STBaseModel(nn.Layer):
             decoding_method: str,
             beam_size: int,
             word_reward: float=0.0,
+            maxlen_ratio: float=0.5,
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False):
@@ -507,6 +505,7 @@ class U2STBaseModel(nn.Layer):
                 feats_lengths,
                 beam_size=beam_size,
                 word_reward=word_reward,
+                maxlen_ratio=maxlen_ratio,
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 simulate_streaming=simulate_streaming)

From f866059b744958b76e87fd342bc6dbbb000fec91 Mon Sep 17 00:00:00 2001
From: Junkun
Date: Fri, 14 Jan 2022 14:58:23 -0800
Subject: [PATCH 2/3] config and formalize

---
 examples/ted_en_zh/st0/conf/tuning/decode.yaml |  3 ++-
 examples/ted_en_zh/st1/conf/tuning/decode.yaml |  3 ++-
 paddlespeech/s2t/models/u2_st/u2_st.py         | 26 ++++++++++++-------
 3 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/examples/ted_en_zh/st0/conf/tuning/decode.yaml b/examples/ted_en_zh/st0/conf/tuning/decode.yaml
index ed081cf4a..7606ee35f 100644
--- a/examples/ted_en_zh/st0/conf/tuning/decode.yaml
+++ b/examples/ted_en_zh/st0/conf/tuning/decode.yaml
@@ -1,8 +1,9 @@
-batch_size: 5
+batch_size: 1
 error_rate_type: char-bleu
 decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
 beam_size: 10
 word_reward: 0.7
+maxlen_ratio: 0.3
 decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
       # <0: for decoding, use full chunk.
       # >0: for decoding, use fixed chunk size as set.
diff --git a/examples/ted_en_zh/st1/conf/tuning/decode.yaml b/examples/ted_en_zh/st1/conf/tuning/decode.yaml
index d6104dbce..9f00dd764 100644
--- a/examples/ted_en_zh/st1/conf/tuning/decode.yaml
+++ b/examples/ted_en_zh/st1/conf/tuning/decode.yaml
@@ -1,9 +1,10 @@
-batch_size: 5
+batch_size: 1
 
 error_rate_type: char-bleu
 decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
 beam_size: 10
 word_reward: 0.7
+maxlen_ratio: 0.3
 decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
       # <0: for decoding, use full chunk.
       # >0: for decoding, use fixed chunk size as set.
diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py
index 211813f63..f92268eb7 100644
--- a/paddlespeech/s2t/models/u2_st/u2_st.py
+++ b/paddlespeech/s2t/models/u2_st/u2_st.py
@@ -308,28 +308,34 @@ class U2STBaseModel(nn.Layer):
         # 2. Decoder forward step by step
         for i in range(1, maxlen + 1):
             ys = paddle.ones((len(hyps), i), dtype=paddle.long)
-            
+
             if hyps[0]["cache"] is not None:
-                cache = [paddle.ones((len(hyps), i-1, hyps[0]["cache"][0].shape[-1]), dtype=paddle.float32) for _ in range(len(hyps[0]["cache"]))]
+                cache = [
+                    paddle.ones(
+                        (len(hyps), i - 1, hyp_cache.shape[-1]),
+                        dtype=paddle.float32)
+                    for hyp_cache in hyps[0]["cache"]
+                ]
             for j, hyp in enumerate(hyps):
                 ys[j, :] = paddle.to_tensor(hyp["yseq"])
                 if hyps[0]["cache"] is not None:
                     for k in range(len(cache)):
                         cache[k][j] = hyps[j]["cache"][k]
             ys_mask = subsequent_mask(i).unsqueeze(0).to(device)
-            
+
             logp, cache = self.st_decoder.forward_one_step(
-                encoder_out.repeat(len(hyps), 1, 1), encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
+                encoder_out.repeat(len(hyps), 1, 1),
+                encoder_mask.repeat(len(hyps), 1, 1), ys, ys_mask, cache)
 
             hyps_best_kept = []
             for j, hyp in enumerate(hyps):
-                top_k_logp, top_k_index = logp[j : j + 1].topk(beam_size)
+                top_k_logp, top_k_index = logp[j:j + 1].topk(beam_size)
 
                 for b in range(beam_size):
                     new_hyp = {}
                     new_hyp["score"] = hyp["score"] + float(top_k_logp[0, b])
                     new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
-                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
+                    new_hyp["yseq"][:len(hyp["yseq"])] = hyp["yseq"]
                     new_hyp["yseq"][len(hyp["yseq"])] = int(top_k_index[0, b])
                     new_hyp["cache"] = [cache_[j] for cache_ in cache]
                     # will be (2 x beam) hyps at most
                     hyps_best_kept.append(new_hyp)
@@ -337,13 +343,13 @@ class U2STBaseModel(nn.Layer):
 
                 hyps_best_kept = sorted(
                     hyps_best_kept, key=lambda x: -x["score"])[:beam_size]
-            
+
             # sort and get nbest
             hyps = hyps_best_kept
             if i == maxlen:
                 for hyp in hyps:
                     hyp["yseq"].append(self.eos)
-            
+
             # finalize the ended hypotheses with word reward (by length)
             remained_hyps = []
             for hyp in hyps:
@@ -355,7 +361,7 @@ class U2STBaseModel(nn.Layer):
                     # keep searching only while optimality can be guaranteed
                     if hyp["score"] + maxlen * word_reward > cur_best_score:
                         remained_hyps.append(hyp)
-            
+
             # stop prediction when there is no unended hypothesis
             if not remained_hyps:
                 break
@@ -364,7 +370,7 @@ class U2STBaseModel(nn.Layer):
         # 3. Select best of best
         best_hyp = max(ended_hyps, key=lambda x: x["score"])
 
-        return paddle.to_tensor([best_hyp["yseq"][1:]]) 
+        return paddle.to_tensor([best_hyp["yseq"][1:]])
 
     # @jit.to_static
     def subsampling_rate(self) -> int:

From 44408e5211b8c1457351c273d959591b518d8aeb Mon Sep 17 00:00:00 2001
From: Junkun
Date: Fri, 14 Jan 2022 16:16:43 -0800
Subject: [PATCH 3/3] sync the variable name with the others

---
 examples/ted_en_zh/st0/conf/tuning/decode.yaml |  2 +-
 examples/ted_en_zh/st1/conf/tuning/decode.yaml |  2 +-
 paddlespeech/s2t/exps/u2_st/model.py           |  4 ++--
 paddlespeech/s2t/models/u2_st/u2_st.py         | 10 +++++-----
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/examples/ted_en_zh/st0/conf/tuning/decode.yaml b/examples/ted_en_zh/st0/conf/tuning/decode.yaml
index 7606ee35f..7d8d1daf1 100644
--- a/examples/ted_en_zh/st0/conf/tuning/decode.yaml
+++ b/examples/ted_en_zh/st0/conf/tuning/decode.yaml
@@ -3,7 +3,7 @@ error_rate_type: char-bleu
 decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
 beam_size: 10
 word_reward: 0.7
-maxlen_ratio: 0.3
+maxlenratio: 0.3
 decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
       # <0: for decoding, use full chunk.
       # >0: for decoding, use fixed chunk size as set.
diff --git a/examples/ted_en_zh/st1/conf/tuning/decode.yaml b/examples/ted_en_zh/st1/conf/tuning/decode.yaml
index 9f00dd764..4f10acf74 100644
--- a/examples/ted_en_zh/st1/conf/tuning/decode.yaml
+++ b/examples/ted_en_zh/st1/conf/tuning/decode.yaml
@@ -4,7 +4,7 @@ error_rate_type: char-bleu
 decoding_method: fullsentence  # 'fullsentence', 'simultaneous'
 beam_size: 10
 word_reward: 0.7
-maxlen_ratio: 0.3
+maxlenratio: 0.3
 decoding_chunk_size: -1  # decoding chunk size. Defaults to -1.
       # <0: for decoding, use full chunk.
       # >0: for decoding, use fixed chunk size as set.
diff --git a/paddlespeech/s2t/exps/u2_st/model.py b/paddlespeech/s2t/exps/u2_st/model.py
index b642e9337..6a32eda77 100644
--- a/paddlespeech/s2t/exps/u2_st/model.py
+++ b/paddlespeech/s2t/exps/u2_st/model.py
@@ -408,7 +408,7 @@ class U2STTester(U2STTrainer):
             decoding_method=decode_cfg.decoding_method,
             beam_size=decode_cfg.beam_size,
             word_reward=decode_cfg.word_reward,
-            maxlen_ratio=decode_cfg.maxlen_ratio,
+            maxlenratio=decode_cfg.maxlenratio,
             decoding_chunk_size=decode_cfg.decoding_chunk_size,
             num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
             simulate_streaming=decode_cfg.simulate_streaming)
@@ -436,7 +436,7 @@ class U2STTester(U2STTrainer):
             decoding_method=decode_cfg.decoding_method,
             beam_size=decode_cfg.beam_size,
             word_reward=decode_cfg.word_reward,
-            maxlen_ratio=decode_cfg.maxlen_ratio,
+            maxlenratio=decode_cfg.maxlenratio,
             decoding_chunk_size=decode_cfg.decoding_chunk_size,
             num_decoding_left_chunks=decode_cfg.num_decoding_left_chunks,
             simulate_streaming=decode_cfg.simulate_streaming)
diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py
index f92268eb7..bc76de7ad 100644
--- a/paddlespeech/s2t/models/u2_st/u2_st.py
+++ b/paddlespeech/s2t/models/u2_st/u2_st.py
@@ -264,7 +264,7 @@ class U2STBaseModel(nn.Layer):
             speech_lengths: paddle.Tensor,
             beam_size: int=10,
             word_reward: float=0.0,
-            maxlen_ratio: float=0.5,
+            maxlenratio: float=0.5,
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False, ) -> paddle.Tensor:
@@ -274,7 +274,7 @@ class U2STBaseModel(nn.Layer):
             speech_length (paddle.Tensor): (batch, )
             beam_size (int): beam size for beam search
             word_reward (float): word reward used in beam search
-            maxlen_ratio (float): max length ratio to bound the length of translated text
+            maxlenratio (float): max length ratio to bound the length of translated text
             decoding_chunk_size (int): decoding chunk for dynamic chunk
                 trained model.
                 <0: for decoding, use full chunk.
@@ -297,7 +297,7 @@ class U2STBaseModel(nn.Layer):
 
-        maxlen = max(int(encoder_out.shape[1] * maxlen_ratio), 5)
+        maxlen = max(int(encoder_out.shape[1] * maxlenratio), 5)
 
         hyp = {"score": 0.0, "yseq": [self.sos], "cache": None}
         hyps = [hyp]
         ended_hyps = []
         cur_best_score = -float("inf")
@@ -475,7 +475,7 @@ class U2STBaseModel(nn.Layer):
             decoding_method: str,
             beam_size: int,
             word_reward: float=0.0,
-            maxlen_ratio: float=0.5,
+            maxlenratio: float=0.5,
             decoding_chunk_size: int=-1,
             num_decoding_left_chunks: int=-1,
             simulate_streaming: bool=False):
@@ -511,7 +511,7 @@ class U2STBaseModel(nn.Layer):
                 feats_lengths,
                 beam_size=beam_size,
                 word_reward=word_reward,
-                maxlen_ratio=maxlen_ratio,
+                maxlenratio=maxlenratio,
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 simulate_streaming=simulate_streaming)
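
Note on why the pruning in PATCH 1 preserves optimality: every token
log-probability is non-positive, so once a hypothesis has score s, the only
way its final score can rise is through the end-of-sentence word reward,
which is at most maxlen * word_reward for a non-negative word_reward. An
unended hypothesis with s + maxlen * word_reward <= cur_best_score can
therefore never overtake the best ended hypothesis, and dropping it (or
stopping once no live hypothesis remains) cannot change the outcome of the
search within the beam. The per-hypothesis Python bookkeeping also handles
one utterance at a time, which is why the configs switch to batch_size: 1
and PATCH 1 adds the assert on speech.shape[0] == 1.

The sketch below replays the same loop outside the model. It is a minimal
stand-alone illustration: toy_log_probs, SOS, EOS, VOCAB and the constants
are invented stand-ins for st_decoder.forward_one_step and the real
vocabulary, not PaddleSpeech APIs.

import math

SOS, EOS, VOCAB = 0, 1, 8
BEAM, MAXLEN, WORD_REWARD = 3, 10, 0.7


def toy_log_probs(yseq):
    # Deterministic stand-in for st_decoder.forward_one_step: returns a
    # log-softmax over VOCAB tokens that depends only on the prefix length.
    logits = [math.sin(0.7 * t + 1.3 * len(yseq)) for t in range(VOCAB)]
    norm = math.log(sum(math.exp(x) for x in logits))
    return [x - norm for x in logits]


def translate_beam(beam=BEAM, maxlen=MAXLEN, word_reward=WORD_REWARD):
    hyps = [{"score": 0.0, "yseq": [SOS]}]
    ended_hyps = []
    cur_best_score = -math.inf
    for i in range(1, maxlen + 1):
        # expand every live hypothesis by its top-k next tokens
        hyps_best_kept = []
        for hyp in hyps:
            logp = toy_log_probs(hyp["yseq"])
            for tok in sorted(range(VOCAB), key=lambda t: -logp[t])[:beam]:
                hyps_best_kept.append({
                    "score": hyp["score"] + logp[tok],
                    "yseq": hyp["yseq"] + [tok],
                })
        hyps = sorted(hyps_best_kept, key=lambda x: -x["score"])[:beam]
        if i == maxlen:  # force-close every survivor at the length bound
            for hyp in hyps:
                hyp["yseq"].append(EOS)
        remained_hyps = []
        for hyp in hyps:
            if hyp["yseq"][-1] == EOS:
                hyp["score"] += (i - 1) * word_reward  # reward by length
                cur_best_score = max(cur_best_score, hyp["score"])
                ended_hyps.append(hyp)
            # log-probs are <= 0, so an unended hypothesis can gain at most
            # maxlen * word_reward more score; below this bound it can never
            # overtake the best ended hypothesis
            elif hyp["score"] + maxlen * word_reward > cur_best_score:
                remained_hyps.append(hyp)
        if not remained_hyps:  # nothing left that could still win
            break
        hyps = remained_hyps
    best = max(ended_hyps, key=lambda x: x["score"])
    return best["yseq"][1:]


if __name__ == "__main__":
    print(translate_beam())

With word_reward = 0 the bound reduces to hyp["score"] > cur_best_score,
i.e. ordinary early stopping as soon as the best ended hypothesis dominates
every live one.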