From 1a56a6e42bccedee0285d8a22205d802878bab92 Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Tue, 20 Sep 2022 03:42:07 +0000 Subject: [PATCH 1/6] add bitransformer decoder, test=asr --- paddlespeech/audio/utils/tensor_utils.py | 41 ++++-- paddlespeech/s2t/exps/u2/bin/test_wav.py | 3 +- paddlespeech/s2t/exps/u2/model.py | 9 +- paddlespeech/s2t/models/u2/u2.py | 152 ++++++++++++++++++++--- paddlespeech/s2t/modules/decoder.py | 128 ++++++++++++++++++- 5 files changed, 302 insertions(+), 31 deletions(-) diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index 16f60810..ac86757b 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -31,7 +31,6 @@ def has_tensor(val): return True elif isinstance(val, dict): for k, v in val.items(): - print(k) if has_tensor(v): return True else: @@ -143,14 +142,15 @@ def add_sos_eos(ys_pad: paddle.Tensor, sos: int, eos: int, [ 7, 8, 9, 11, -1, -1]]) """ # TODO(Hui Zhang): using comment code, - #_sos = paddle.to_tensor( - # [sos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) - #_eos = paddle.to_tensor( - # [eos], dtype=paddle.long, stop_gradient=True, place=ys_pad.place) - #ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys - #ys_in = [paddle.cat([_sos, y], dim=0) for y in ys] - #ys_out = [paddle.cat([y, _eos], dim=0) for y in ys] - #return pad_sequence(ys_in, padding_value=eos), pad_sequence(ys_out, padding_value=ignore_id) + # _sos = paddle.to_tensor( + # [sos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place) + # _eos = paddle.to_tensor( + # [eos], dtype=ys_pad.dtype, stop_gradient=True, place=ys_pad.place) + # ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys + # ys_in = [paddle.concat([_sos, y], axis=0) for y in ys] + # ys_out = [paddle.concat([y, _eos], axis=0) for y in ys] + # return pad_sequence(ys_in, padding_value=eos).transpose([1,0]), pad_sequence(ys_out, padding_value=ignore_id).transpose([1,0]) + B = ys_pad.shape[0] _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos _eos = paddle.ones([B, 1], dtype=ys_pad.dtype) * eos @@ -190,3 +190,26 @@ def th_accuracy(pad_outputs: paddle.Tensor, # denominator = paddle.sum(mask) denominator = paddle.sum(mask.type_as(pad_targets)) return float(numerator) / float(denominator) + + +def reverse_pad_list(ys_pad: paddle.Tensor, + ys_lens: paddle.Tensor, + pad_value: float=-1.0) -> paddle.Tensor: + """Reverse padding for the list of tensors. + Args: + ys_pad (tensor): The padded tensor (B, Tokenmax). + ys_lens (tensor): The lens of token seqs (B) + pad_value (int): Value for padding. + Returns: + Tensor: Padded tensor (B, Tokenmax). + Examples: + >>> x + tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) + >>> pad_list(x, 0) + tensor([[4, 3, 2, 1], + [7, 6, 5, 0], + [9, 8, 0, 0]]) + """ + r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0])) + for y, i in zip(ys_pad, ys_lens)], True, pad_value) + return r_ys_pad diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 887ec7a6..51b72209 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -89,7 +89,8 @@ class U2Infer(): ctc_weight=decode_config.ctc_weight, decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, - simulate_streaming=decode_config.simulate_streaming) + simulate_streaming=decode_config.simulate_streaming, + reverse_weight=self.config.model_conf.reverse_weight) rsl = result_transcripts[0][0] utt = Path(self.audio_file).name logger.info(f"hyp: {utt} {result_transcripts[0][0]}") diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index db60083b..a7ccba48 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -250,10 +250,12 @@ class U2Trainer(Trainer): model_conf.output_dim = self.train_loader.vocab_size else: model_conf.input_dim = self.test_loader.feat_dim - model_conf.output_dim = self.test_loader.vocab_size + model_conf.output_dim = 5538 model = U2Model.from_config(model_conf) - + # params = model.state_dict() + # paddle.save(params, 'for_torch/test.pdparams') + # exit() if self.parallel: model = paddle.DataParallel(model) @@ -350,7 +352,8 @@ class U2Tester(U2Trainer): ctc_weight=decode_config.ctc_weight, decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, - simulate_streaming=decode_config.simulate_streaming) + simulate_streaming=decode_config.simulate_streaming, + reverse_weight=self.config.model_conf.reverse_weight) decode_time = time.time() - start_time for utt, target, result, rec_tids in zip( diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 813e1e52..84c0e5b5 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -31,6 +31,7 @@ from paddle import nn from paddlespeech.audio.utils.tensor_utils import add_sos_eos from paddlespeech.audio.utils.tensor_utils import pad_sequence +from paddlespeech.audio.utils.tensor_utils import reverse_pad_list from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer from paddlespeech.s2t.frontend.utility import IGNORE_ID @@ -38,6 +39,7 @@ from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.models.asr_interface import ASRInterface from paddlespeech.s2t.modules.cmvn import GlobalCMVN from paddlespeech.s2t.modules.ctc import CTCDecoderBase +from paddlespeech.s2t.modules.decoder import BiTransformerDecoder from paddlespeech.s2t.modules.decoder import TransformerDecoder from paddlespeech.s2t.modules.encoder import ConformerEncoder from paddlespeech.s2t.modules.encoder import TransformerEncoder @@ -69,6 +71,7 @@ class U2BaseModel(ASRInterface, nn.Layer): ctc: CTCDecoderBase, ctc_weight: float=0.5, ignore_id: int=IGNORE_ID, + reverse_weight: float=0.0, lsm_weight: float=0.0, length_normalized_loss: bool=False, **kwargs): @@ -82,6 +85,7 @@ class U2BaseModel(ASRInterface, nn.Layer): self.vocab_size = vocab_size self.ignore_id = ignore_id self.ctc_weight = ctc_weight + self.reverse_weight = reverse_weight self.encoder = encoder self.decoder = decoder @@ -171,12 +175,21 @@ class U2BaseModel(ASRInterface, nn.Layer): self.ignore_id) ys_in_lens = ys_pad_lens + 1 + r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id)) + r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, + self.ignore_id) # 1. Forward decoder - decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad, - ys_in_lens) + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, ys_in_pad, ys_in_lens, r_ys_in_pad, + self.reverse_weight) # 2. Compute attention loss loss_att = self.criterion_att(decoder_out, ys_out_pad) + r_loss_att = paddle.to_tensor(0.0) + if self.reverse_weight > 0.0: + r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad) + loss_att = loss_att * (1 - self.reverse_weight + ) + r_loss_att * self.reverse_weight acc_att = th_accuracy( decoder_out.view(-1, self.vocab_size), ys_out_pad, @@ -359,6 +372,7 @@ class U2BaseModel(ASRInterface, nn.Layer): # Let's assume B = batch_size # encoder_out: (B, maxlen, encoder_dim) # encoder_mask: (B, 1, Tmax) + encoder_out, encoder_mask = self._forward_encoder( speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks, simulate_streaming) @@ -500,7 +514,8 @@ class U2BaseModel(ASRInterface, nn.Layer): decoding_chunk_size: int=-1, num_decoding_left_chunks: int=-1, ctc_weight: float=0.0, - simulate_streaming: bool=False, ) -> List[int]: + simulate_streaming: bool=False, + reverse_weight: float=0.0, ) -> List[int]: """ Apply attention rescoring decoding, CTC prefix beam search is applied first to get nbest, then we resoring the nbest on attention decoder with corresponding encoder out @@ -520,6 +535,9 @@ class U2BaseModel(ASRInterface, nn.Layer): """ assert speech.shape[0] == speech_lengths.shape[0] assert decoding_chunk_size != 0 + if reverse_weight > 0.0: + # decoder should be a bitransformer decoder if reverse_weight > 0.0 + assert hasattr(self.decoder, 'right_decoder') device = speech.place batch_size = speech.shape[0] # For attention rescoring we only support batch_size=1 @@ -541,6 +559,7 @@ class U2BaseModel(ASRInterface, nn.Layer): hyp_content, place=device, dtype=paddle.long) hyp_list.append(hyp_content) hyps_pad = pad_sequence(hyp_list, True, self.ignore_id) + ori_hyps_pad = hyps_pad hyps_lens = paddle.to_tensor( [len(hyp[0]) for hyp in hyps], place=device, dtype=paddle.long) # (beam_size,) @@ -550,13 +569,24 @@ class U2BaseModel(ASRInterface, nn.Layer): encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - decoder_out, _ = self.decoder( - encoder_out, encoder_mask, hyps_pad, - hyps_lens) # (beam_size, max_hyps_len, vocab_size) + + # used for right to left decoder + r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens - 1, + self.ignore_id) + r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, + self.ignore_id) + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, + reverse_weight) # (beam_size, max_hyps_len, vocab_size) # ctc score in ln domain decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) decoder_out = decoder_out.numpy() + # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a + # conventional transformer decoder. + r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1) + r_decoder_out = r_decoder_out.numpy() + # Only use decoder score for rescoring best_score = -float('inf') best_index = 0 @@ -567,6 +597,12 @@ class U2BaseModel(ASRInterface, nn.Layer): score += decoder_out[i][j][w] # last decoder output token is `eos`, for laste decoder input token. score += decoder_out[i][len(hyp[0])][self.eos] + if reverse_weight > 0: + r_score = 0.0 + for j, w in enumerate(hyp[0]): + r_score += r_decoder_out[i][len(hyp[0]) - j - 1][w] + r_score += r_decoder_out[i][len(hyp[0])][self.eos] + score = score * (1 - reverse_weight) + r_score * reverse_weight # add ctc score (which in ln domain) score += hyp[1] * ctc_weight if score > best_score: @@ -653,12 +689,24 @@ class U2BaseModel(ASRInterface, nn.Layer): """ return self.ctc.log_softmax(xs) + @jit.to_static + def is_bidirectional_decoder(self) -> bool: + """ + Returns: + torch.Tensor: decoder output + """ + if hasattr(self.decoder, 'right_decoder'): + return True + else: + return False + @jit.to_static def forward_attention_decoder( self, hyps: paddle.Tensor, hyps_lens: paddle.Tensor, - encoder_out: paddle.Tensor, ) -> paddle.Tensor: + encoder_out: paddle.Tensor, + reverse_weight: float=0, ) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: @@ -676,11 +724,75 @@ class U2BaseModel(ASRInterface, nn.Layer): # (B, 1, T) encoder_mask = paddle.ones( [num_hyps, 1, encoder_out.shape[1]], dtype=paddle.bool) + + # input for right to left decoder + # this hyps_lens has count token, we need minus it. + r_hyps_lens = hyps_lens - 1 + # this hyps has included token, so it should be + # convert the original hyps. + r_hyps = hyps[:, 1:] # (num_hyps, max_hyps_len, vocab_size) + + # Equal to: + # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) + # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) + max_len = paddle.max(r_hyps_lens) + index_range = paddle.arange(0, max_len, 1) + seq_len_expand = r_hyps_lens.unsqueeze(1) + seq_mask = seq_len_expand > index_range # (beam, max_len) + + index = (seq_len_expand - 1) - index_range # (beam, max_len) + # >>> index + # >>> tensor([[ 2, 1, 0], + # >>> [ 2, 1, 0], + # >>> [ 0, -1, -2]]) + index = index * seq_mask + + # >>> index + # >>> tensor([[2, 1, 0], + # >>> [2, 1, 0], + # >>> [0, 0, 0]]) + def paddle_gather(x, dim, index): + index_shape = index.shape + index_flatten = index.flatten() + if dim < 0: + dim = len(x.shape) + dim + nd_index = [] + for k in range(len(x.shape)): + if k == dim: + nd_index.append(index_flatten) + else: + reshape_shape = [1] * len(x.shape) + reshape_shape[k] = x.shape[k] + x_arange = paddle.arange(x.shape[k], dtype=index.dtype) + x_arange = x_arange.reshape(reshape_shape) + dim_index = paddle.expand(x_arange, index_shape).flatten() + nd_index.append(dim_index) + ind2 = paddle.transpose(paddle.stack(nd_index), + [1, 0]).astype("int64") + paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape) + return paddle_out + + r_hyps = paddle_gather(r_hyps, 1, index) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, 2, 2]]) + r_hyps = paddle.where(seq_mask, r_hyps, self.eos) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, eos, eos]]) + r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1) + # >>> r_hyps + # >>> tensor([[sos, 3, 2, 1], + # >>> [sos, 4, 8, 9], + # >>> [sos, 2, eos, eos]]) decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, - hyps_lens) + hyps_lens, r_hyps, reverse_weight) decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - return decoder_out + r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + return decoder_out, r_decoder_out @paddle.no_grad() def decode(self, @@ -692,7 +804,8 @@ class U2BaseModel(ASRInterface, nn.Layer): ctc_weight: float=0.0, decoding_chunk_size: int=-1, num_decoding_left_chunks: int=-1, - simulate_streaming: bool=False): + simulate_streaming: bool=False, + reverse_weight: float=0.0): """u2 decoding. Args: @@ -801,7 +914,6 @@ class U2Model(U2DecodeModel): with DefaultInitializerContext(init_type): vocab_size, encoder, decoder, ctc = U2Model._init_from_config( configs) - super().__init__( vocab_size=vocab_size, encoder=encoder, @@ -851,10 +963,20 @@ class U2Model(U2DecodeModel): raise ValueError(f"not support encoder type:{encoder_type}") # decoder - decoder = TransformerDecoder(vocab_size, - encoder.output_size(), - **configs['decoder_conf']) - + decoder_type = configs.get('decoder', 'transformer') + logger.debug(f"U2 Decoder type: {decoder_type}") + if decoder_type == 'transformer': + decoder = TransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + elif decoder_type == 'bitransformer': + assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0 + assert configs['decoder_conf']['r_num_blocks'] > 0 + decoder = BiTransformerDecoder(vocab_size, + encoder.output_size(), + **configs['decoder_conf']) + else: + raise ValueError(f"not support decoder type:{decoder_type}") # ctc decoder and ctc loss model_conf = configs.get('model_conf', dict()) dropout_rate = model_conf.get('ctc_dropout_rate', 0.0) diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index ccc8482d..2052a19e 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -35,7 +35,6 @@ from paddlespeech.s2t.modules.mask import make_xs_mask from paddlespeech.s2t.modules.mask import subsequent_mask from paddlespeech.s2t.modules.positionwise_feed_forward import PositionwiseFeedForward from paddlespeech.s2t.utils.log import Log - logger = Log(__name__).getlog() __all__ = ["TransformerDecoder"] @@ -116,13 +115,19 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): memory: paddle.Tensor, memory_mask: paddle.Tensor, ys_in_pad: paddle.Tensor, - ys_in_lens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor]: + ys_in_lens: paddle.Tensor, + r_ys_in_pad: paddle.Tensor=paddle.empty([0]), + reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]: """Forward decoder. Args: memory: encoded memory, float32 (batch, maxlen_in, feat) memory_mask: encoder memory mask, (batch, 1, maxlen_in) ys_in_pad: padded input token ids, int64 (batch, maxlen_out) ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: not used in transformer decoder, in order to unify api + with bidirectional decoder + reverse_weight: not used in transformer decoder, in order to unify + api with bidirectional decode Returns: (tuple): tuple containing: x: decoded token score before softmax (batch, maxlen_out, vocab_size) @@ -151,7 +156,7 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): # TODO(Hui Zhang): reduce_sum not support bool type # olens = tgt_mask.sum(1) olens = tgt_mask.astype(paddle.int).sum(1) - return x, olens + return x, paddle.to_tensor(0.0), olens def forward_one_step( self, @@ -251,3 +256,120 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer): state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] return logp, state_list + + +class BiTransformerDecoder(BatchScorerInterface, nn.Layer): + """Base class of Transfomer decoder module. + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the hidden units number of position-wise feedforward + num_blocks: the number of decoder blocks + r_num_blocks: the number of right to left decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: + True: use layer_norm before each sub-block of a layer. + False: use layer_norm after each sub-block of a layer. + concat_after: whether to concat attention layer's input and output + True: x -> x + linear(concat(x, att(x))) + False: x -> x + att(x) + """ + + def __init__(self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int=4, + linear_units: int=2048, + num_blocks: int=6, + r_num_blocks: int=0, + dropout_rate: float=0.1, + positional_dropout_rate: float=0.1, + self_attention_dropout_rate: float=0.0, + src_attention_dropout_rate: float=0.0, + input_layer: str="embed", + use_output_layer: bool=True, + normalize_before: bool=True, + concat_after: bool=False, + max_len: int=5000): + + assert check_argument_types() + + nn.Layer.__init__(self) + self.left_decoder = TransformerDecoder( + vocab_size, encoder_output_size, attention_heads, linear_units, + num_blocks, dropout_rate, positional_dropout_rate, + self_attention_dropout_rate, src_attention_dropout_rate, + input_layer, use_output_layer, normalize_before, concat_after, + max_len) + + self.right_decoder = TransformerDecoder( + vocab_size, encoder_output_size, attention_heads, linear_units, + r_num_blocks, dropout_rate, positional_dropout_rate, + self_attention_dropout_rate, src_attention_dropout_rate, + input_layer, use_output_layer, normalize_before, concat_after, + max_len) + + def forward( + self, + memory: paddle.Tensor, + memory_mask: paddle.Tensor, + ys_in_pad: paddle.Tensor, + ys_in_lens: paddle.Tensor, + r_ys_in_pad: paddle.Tensor, + reverse_weight: float=0.0, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Forward decoder. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoder memory mask, (batch, 1, maxlen_in) + ys_in_pad: padded input token ids, int64 (batch, maxlen_out) + ys_in_lens: input lengths of this batch (batch) + r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out), + used for right to left decoder + reverse_weight: used for right to left decoder + Returns: + (tuple): tuple containing: + x: decoded token score before softmax (batch, maxlen_out, + vocab_size) if use_output_layer is True, + r_x: x: decoded token score (right to left decoder) + before softmax (batch, maxlen_out, vocab_size) + if use_output_layer is True, + olens: (batch, ) + """ + l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad, + ys_in_lens) + r_x = paddle.to_tensor(0.0) + if reverse_weight > 0.0: + r_x, _, olens = self.right_decoder(memory, memory_mask, r_ys_in_pad, + ys_in_lens) + return l_x, r_x, olens + + def forward_one_step( + self, + memory: paddle.Tensor, + memory_mask: paddle.Tensor, + tgt: paddle.Tensor, + tgt_mask: paddle.Tensor, + cache: Optional[List[paddle.Tensor]]=None, + ) -> Tuple[paddle.Tensor, List[paddle.Tensor]]: + """Forward one step. + This is only used for decoding. + Args: + memory: encoded memory, float32 (batch, maxlen_in, feat) + memory_mask: encoded memory mask, (batch, 1, maxlen_in) + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + return self.left_decoder.forward_one_step(memory, memory_mask, tgt, + tgt_mask, cache) From ecbf324286c55125e5fd2712c16bedc22f1e51c9 Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Tue, 20 Sep 2022 05:28:02 +0000 Subject: [PATCH 2/6] support bitransformer decoder, test=asr --- paddlespeech/server/engine/asr/online/python/asr_engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlespeech/server/engine/asr/online/python/asr_engine.py b/paddlespeech/server/engine/asr/online/python/asr_engine.py index 87d88ee6..4c7c4b37 100644 --- a/paddlespeech/server/engine/asr/online/python/asr_engine.py +++ b/paddlespeech/server/engine/asr/online/python/asr_engine.py @@ -613,7 +613,8 @@ class PaddleASRConnectionHanddler: encoder_out = self.encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - decoder_out, _ = self.model.decoder( + + decoder_out, _, _ = self.model.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens) # (beam_size, max_hyps_len, vocab_size) # ctc score in ln domain From 455379b88eb6654917b8fb691b1c7750a1e3234a Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Tue, 20 Sep 2022 09:07:28 +0000 Subject: [PATCH 3/6] support bitransformer decoder --- paddlespeech/s2t/exps/u2/model.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index a7ccba48..99a0434d 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -250,12 +250,9 @@ class U2Trainer(Trainer): model_conf.output_dim = self.train_loader.vocab_size else: model_conf.input_dim = self.test_loader.feat_dim - model_conf.output_dim = 5538 + model_conf.output_dim = self.test_loader.vocab_size model = U2Model.from_config(model_conf) - # params = model.state_dict() - # paddle.save(params, 'for_torch/test.pdparams') - # exit() if self.parallel: model = paddle.DataParallel(model) @@ -319,6 +316,7 @@ class U2Tester(U2Trainer): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list + self.reverse_weight = getattr(config, 'reverse_weight', '0.0') def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ @@ -353,7 +351,7 @@ class U2Tester(U2Trainer): decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, simulate_streaming=decode_config.simulate_streaming, - reverse_weight=self.config.model_conf.reverse_weight) + reverse_weight=self.reverse_weight) decode_time = time.time() - start_time for utt, target, result, rec_tids in zip( From 0a95689461c9337074ebeeb1bc015a82caf23bd3 Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Tue, 20 Sep 2022 09:36:59 +0000 Subject: [PATCH 4/6] support bitransformer decoder --- paddlespeech/s2t/exps/u2/bin/test_wav.py | 4 ++-- paddlespeech/s2t/exps/u2/model.py | 2 +- paddlespeech/s2t/models/u2/u2.py | 12 ++++++------ paddlespeech/s2t/modules/decoder.py | 5 ++--- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/paddlespeech/s2t/exps/u2/bin/test_wav.py b/paddlespeech/s2t/exps/u2/bin/test_wav.py index 51b72209..4588def0 100644 --- a/paddlespeech/s2t/exps/u2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/u2/bin/test_wav.py @@ -40,7 +40,7 @@ class U2Infer(): self.preprocess_conf = config.preprocess_config self.preprocess_args = {"train": False} self.preprocessing = Transformation(self.preprocess_conf) - + self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0) self.text_feature = TextFeaturizer( unit_type=config.unit_type, vocab=config.vocab_filepath, @@ -90,7 +90,7 @@ class U2Infer(): decoding_chunk_size=decode_config.decoding_chunk_size, num_decoding_left_chunks=decode_config.num_decoding_left_chunks, simulate_streaming=decode_config.simulate_streaming, - reverse_weight=self.config.model_conf.reverse_weight) + reverse_weight=self.reverse_weight) rsl = result_transcripts[0][0] utt = Path(self.audio_file).name logger.info(f"hyp: {utt} {result_transcripts[0][0]}") diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 99a0434d..84b7be32 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -316,7 +316,7 @@ class U2Tester(U2Trainer): vocab=self.config.vocab_filepath, spm_model_prefix=self.config.spm_model_prefix) self.vocab_list = self.text_feature.vocab_list - self.reverse_weight = getattr(config, 'reverse_weight', '0.0') + self.reverse_weight = getattr(config.model_conf, 'reverse_weight', 0.0) def id2token(self, texts, texts_len, text_feature): """ ord() id to chr() chr """ diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 84c0e5b5..e54a7afb 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -689,24 +689,24 @@ class U2BaseModel(ASRInterface, nn.Layer): """ return self.ctc.log_softmax(xs) - @jit.to_static + # @jit.to_static def is_bidirectional_decoder(self) -> bool: """ Returns: - torch.Tensor: decoder output + paddle.Tensor: decoder output """ if hasattr(self.decoder, 'right_decoder'): return True else: return False - @jit.to_static + # @jit.to_static def forward_attention_decoder( self, hyps: paddle.Tensor, hyps_lens: paddle.Tensor, encoder_out: paddle.Tensor, - reverse_weight: float=0, ) -> paddle.Tensor: + reverse_weight: float=0.0, ) -> paddle.Tensor: """ Export interface for c++ call, forward decoder with multiple hypothesis from ctc prefix beam search and one encoder output Args: @@ -783,7 +783,7 @@ class U2BaseModel(ASRInterface, nn.Layer): # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, eos, eos]]) - r_hyps = torch.concat([hyps[:, 0:1], r_hyps], axis=1) + r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1) # >>> r_hyps # >>> tensor([[sos, 3, 2, 1], # >>> [sos, 4, 8, 9], @@ -791,7 +791,7 @@ class U2BaseModel(ASRInterface, nn.Layer): decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight) decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) - r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1) + r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1) return decoder_out, r_decoder_out @paddle.no_grad() diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py index 2052a19e..3b1a7f23 100644 --- a/paddlespeech/s2t/modules/decoder.py +++ b/paddlespeech/s2t/modules/decoder.py @@ -363,9 +363,8 @@ class BiTransformerDecoder(BatchScorerInterface, nn.Layer): memory: encoded memory, float32 (batch, maxlen_in, feat) memory_mask: encoded memory mask, (batch, 1, maxlen_in) tgt: input token ids, int64 (batch, maxlen_out) - tgt_mask: input token mask, (batch, maxlen_out) - dtype=torch.uint8 in PyTorch 1.2- - dtype=torch.bool in PyTorch 1.2+ (include 1.2) + tgt_mask: input token mask, (batch, maxlen_out, maxlen_out) + dtype=paddle.bool cache: cached output list of (batch, max_time_out-1, size) Returns: y, cache: NN output value and cache per `self.decoders`. From 027535dec19aea8aa1ee5685fe348fbf0c181757 Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Tue, 20 Sep 2022 13:02:39 +0000 Subject: [PATCH 5/6] support bitransformer decoder, test=asr --- paddlespeech/audio/utils/tensor_utils.py | 79 ++++++++++++++++++++++++ paddlespeech/s2t/exps/u2/model.py | 1 + paddlespeech/s2t/models/u2/u2.py | 73 +++------------------- 3 files changed, 89 insertions(+), 64 deletions(-) diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index ac86757b..44dcb52e 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -213,3 +213,82 @@ def reverse_pad_list(ys_pad: paddle.Tensor, r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0])) for y, i in zip(ys_pad, ys_lens)], True, pad_value) return r_ys_pad + + +def st_reverse_pad_list(ys_pad: paddle.Tensor, + ys_lens: paddle.Tensor, + sos: float, + eos: float) -> paddle.Tensor: + """Reverse padding for the list of tensors. + Args: + ys_pad (tensor): The padded tensor (B, Tokenmax). + ys_lens (tensor): The lens of token seqs (B) + Returns: + Tensor: Padded tensor (B, Tokenmax). + Examples: + >>> x + tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]]) + >>> pad_list(x, 0) + tensor([[4, 3, 2, 1], + [7, 6, 5, 0], + [9, 8, 0, 0]]) + """ + # Equal to: + # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) + # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) + B = ys_pad.shape[0] + _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos + max_len = paddle.max(ys_lens) + index_range = paddle.arange(0, max_len, 1) + seq_len_expand = ys_lens.unsqueeze(1) + seq_mask = seq_len_expand > index_range # (beam, max_len) + + index = (seq_len_expand - 1) - index_range # (beam, max_len) + # >>> index + # >>> tensor([[ 2, 1, 0], + # >>> [ 2, 1, 0], + # >>> [ 0, -1, -2]]) + index = index * seq_mask + + # >>> index + # >>> tensor([[2, 1, 0], + # >>> [2, 1, 0], + # >>> [0, 0, 0]]) + def paddle_gather(x, dim, index): + index_shape = index.shape + index_flatten = index.flatten() + if dim < 0: + dim = len(x.shape) + dim + nd_index = [] + for k in range(len(x.shape)): + if k == dim: + nd_index.append(index_flatten) + else: + reshape_shape = [1] * len(x.shape) + reshape_shape[k] = x.shape[k] + x_arange = paddle.arange(x.shape[k], dtype=index.dtype) + x_arange = x_arange.reshape(reshape_shape) + dim_index = paddle.expand(x_arange, index_shape).flatten() + nd_index.append(dim_index) + ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64") + paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape) + return paddle_out + + r_hyps = paddle_gather(ys_pad, 1, index) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, 2, 2]]) + r_hyps = paddle.where(seq_mask, r_hyps, eos) + # >>> r_hyps + # >>> tensor([[3, 2, 1], + # >>> [4, 8, 9], + # >>> [2, eos, eos]]) + + r_hyps = paddle.cat([_sos, r_hyps], dim=1) + # r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1) + # >>> r_hyps + # >>> tensor([[sos, 3, 2, 1], + # >>> [sos, 4, 8, 9], + # >>> [sos, 2, eos, eos]]) + return r_hyps diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py index 84b7be32..a13a6385 100644 --- a/paddlespeech/s2t/exps/u2/model.py +++ b/paddlespeech/s2t/exps/u2/model.py @@ -341,6 +341,7 @@ class U2Tester(U2Trainer): start_time = time.time() target_transcripts = self.id2token(texts, texts_len, self.text_feature) + result_transcripts, result_tokenids = self.model.decode( audio, audio_len, diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index e54a7afb..0a3e03b7 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -32,6 +32,7 @@ from paddle import nn from paddlespeech.audio.utils.tensor_utils import add_sos_eos from paddlespeech.audio.utils.tensor_utils import pad_sequence from paddlespeech.audio.utils.tensor_utils import reverse_pad_list +from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list from paddlespeech.audio.utils.tensor_utils import th_accuracy from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer from paddlespeech.s2t.frontend.utility import IGNORE_ID @@ -565,16 +566,12 @@ class U2BaseModel(ASRInterface, nn.Layer): dtype=paddle.long) # (beam_size,) hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id) hyps_lens = hyps_lens + 1 # Add at begining - encoder_out = encoder_out.repeat(beam_size, 1, 1) encoder_mask = paddle.ones( (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool) - # used for right to left decoder - r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens - 1, - self.ignore_id) - r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos, - self.ignore_id) + r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos, + self.eos) decoder_out, r_decoder_out, _ = self.decoder( encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad, reverse_weight) # (beam_size, max_hyps_len, vocab_size) @@ -733,63 +730,10 @@ class U2BaseModel(ASRInterface, nn.Layer): r_hyps = hyps[:, 1:] # (num_hyps, max_hyps_len, vocab_size) - # Equal to: - # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) - # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) - max_len = paddle.max(r_hyps_lens) - index_range = paddle.arange(0, max_len, 1) - seq_len_expand = r_hyps_lens.unsqueeze(1) - seq_mask = seq_len_expand > index_range # (beam, max_len) - - index = (seq_len_expand - 1) - index_range # (beam, max_len) - # >>> index - # >>> tensor([[ 2, 1, 0], - # >>> [ 2, 1, 0], - # >>> [ 0, -1, -2]]) - index = index * seq_mask - - # >>> index - # >>> tensor([[2, 1, 0], - # >>> [2, 1, 0], - # >>> [0, 0, 0]]) - def paddle_gather(x, dim, index): - index_shape = index.shape - index_flatten = index.flatten() - if dim < 0: - dim = len(x.shape) + dim - nd_index = [] - for k in range(len(x.shape)): - if k == dim: - nd_index.append(index_flatten) - else: - reshape_shape = [1] * len(x.shape) - reshape_shape[k] = x.shape[k] - x_arange = paddle.arange(x.shape[k], dtype=index.dtype) - x_arange = x_arange.reshape(reshape_shape) - dim_index = paddle.expand(x_arange, index_shape).flatten() - nd_index.append(dim_index) - ind2 = paddle.transpose(paddle.stack(nd_index), - [1, 0]).astype("int64") - paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape) - return paddle_out - - r_hyps = paddle_gather(r_hyps, 1, index) - # >>> r_hyps - # >>> tensor([[3, 2, 1], - # >>> [4, 8, 9], - # >>> [2, 2, 2]]) - r_hyps = paddle.where(seq_mask, r_hyps, self.eos) - # >>> r_hyps - # >>> tensor([[3, 2, 1], - # >>> [4, 8, 9], - # >>> [2, eos, eos]]) - r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1) - # >>> r_hyps - # >>> tensor([[sos, 3, 2, 1], - # >>> [sos, 4, 8, 9], - # >>> [sos, 2, eos, eos]]) - decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps, - hyps_lens, r_hyps, reverse_weight) + r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos) + + decoder_out, r_decoder_out, _ = self.decoder( + encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight) decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1) r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1) return decoder_out, r_decoder_out @@ -877,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer): decoding_chunk_size=decoding_chunk_size, num_decoding_left_chunks=num_decoding_left_chunks, ctc_weight=ctc_weight, - simulate_streaming=simulate_streaming) + simulate_streaming=simulate_streaming, + reverse_weight=reverse_weight) hyps = [hyp] else: raise ValueError(f"Not support decoding method: {decoding_method}") From d3e59375912bfe9f28dd49e52b5a883ac9a6e2fe Mon Sep 17 00:00:00 2001 From: tianhao zhang <15600919271@163.com> Date: Wed, 21 Sep 2022 12:24:52 +0000 Subject: [PATCH 6/6] support bitransformer decoder --- paddlespeech/audio/utils/tensor_utils.py | 3 ++- paddlespeech/s2t/io/dataloader.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py index 44dcb52e..e9008f17 100644 --- a/paddlespeech/audio/utils/tensor_utils.py +++ b/paddlespeech/audio/utils/tensor_utils.py @@ -237,7 +237,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor, # >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id)) # >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id) B = ys_pad.shape[0] - _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos + _sos = paddle.full([B, 1], sos, dtype=ys_pad.dtype) max_len = paddle.max(ys_lens) index_range = paddle.arange(0, max_len, 1) seq_len_expand = ys_lens.unsqueeze(1) @@ -279,6 +279,7 @@ def st_reverse_pad_list(ys_pad: paddle.Tensor, # >>> tensor([[3, 2, 1], # >>> [4, 8, 9], # >>> [2, 2, 2]]) + eos = paddle.full([1], eos, dtype=r_hyps.dtype) r_hyps = paddle.where(seq_mask, r_hyps, eos) # >>> r_hyps # >>> tensor([[3, 2, 1], diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py index 4cc8274f..5ba891c3 100644 --- a/paddlespeech/s2t/io/dataloader.py +++ b/paddlespeech/s2t/io/dataloader.py @@ -361,7 +361,7 @@ class DataLoaderFactory(): elif mode == 'valid': config['manifest'] = config.dev_manifest config['train_mode'] = False - elif model == 'test' or mode == 'align': + elif mode == 'test' or mode == 'align': config['manifest'] = config.test_manifest config['train_mode'] = False config['dither'] = 0.0