diff --git a/paddlespeech/audio/utils/tensor_utils.py b/paddlespeech/audio/utils/tensor_utils.py
index ac86757b..44dcb52e 100644
--- a/paddlespeech/audio/utils/tensor_utils.py
+++ b/paddlespeech/audio/utils/tensor_utils.py
@@ -213,3 +213,82 @@ def reverse_pad_list(ys_pad: paddle.Tensor,
     r_ys_pad = pad_sequence([(paddle.flip(y.int()[:i], [0]))
                              for y, i in zip(ys_pad, ys_lens)], True, pad_value)
     return r_ys_pad
+
+
+def st_reverse_pad_list(ys_pad: paddle.Tensor,
+                        ys_lens: paddle.Tensor,
+                        sos: float,
+                        eos: float) -> paddle.Tensor:
+    """Reverse each padded token sequence, prepend <sos> and pad with <eos>.
+
+    A tensor-op-only equivalent of reverse_pad_list followed by add_sos_eos,
+    with no Python loop over the batch.
+
+    Args:
+        ys_pad (Tensor): The padded token sequences (B, Tokenmax).
+        ys_lens (Tensor): The true lengths of the token sequences (B,).
+        sos (float): <sos> token id, prepended to every sequence.
+        eos (float): <eos> token id, used to fill the reversed padding.
+    Returns:
+        Tensor: Reversed padded tensor (B, Tokenmax + 1).
+    Examples:
+        >>> x
+        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
+        >>> st_reverse_pad_list(x, tensor([4, 3, 2]), sos, eos)
+        tensor([[sos, 4, 3, 2, 1],
+                [sos, 7, 6, 5, eos],
+                [sos, 9, 8, eos, eos]])
+    """
+    # Equal to:
+    #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
+    #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
+    B = ys_pad.shape[0]
+    _sos = paddle.ones([B, 1], dtype=ys_pad.dtype) * sos
+    max_len = paddle.max(ys_lens)
+    index_range = paddle.arange(0, max_len, 1)
+    seq_len_expand = ys_lens.unsqueeze(1)
+    seq_mask = seq_len_expand > index_range  # (beam, max_len)
+
+    index = (seq_len_expand - 1) - index_range  # (beam, max_len)
+    # >>> index
+    # >>> tensor([[ 2,  1,  0],
+    # >>>         [ 2,  1,  0],
+    # >>>         [ 0, -1, -2]])
+    index = index * seq_mask
+    # >>> index
+    # >>> tensor([[2, 1, 0],
+    # >>>         [2, 1, 0],
+    # >>>         [0, 0, 0]])
+
+    def paddle_gather(x, dim, index):
+        # Emulates torch.gather(x, dim, index) on top of paddle.gather_nd.
+        index_shape = index.shape
+        index_flatten = index.flatten()
+        if dim < 0:
+            dim = len(x.shape) + dim
+        nd_index = []
+        for k in range(len(x.shape)):
+            if k == dim:
+                nd_index.append(index_flatten)
+            else:
+                reshape_shape = [1] * len(x.shape)
+                reshape_shape[k] = x.shape[k]
+                x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
+                x_arange = x_arange.reshape(reshape_shape)
+                dim_index = paddle.expand(x_arange, index_shape).flatten()
+                nd_index.append(dim_index)
+        ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
+        paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
+        return paddle_out
+
+    r_hyps = paddle_gather(ys_pad, 1, index)
+    # >>> r_hyps
+    # >>> tensor([[3, 2, 1],
+    # >>>         [4, 8, 9],
+    # >>>         [2, 2, 2]])
+    r_hyps = paddle.where(seq_mask, r_hyps, eos)
+    # >>> r_hyps
+    # >>> tensor([[3, 2, 1],
+    # >>>         [4, 8, 9],
+    # >>>         [2, eos, eos]])
+    r_hyps = paddle.concat([_sos, r_hyps], axis=1)
+    # >>> r_hyps
+    # >>> tensor([[sos, 3, 2, 1],
+    # >>>         [sos, 4, 8, 9],
+    # >>>         [sos, 2, eos, eos]])
+    return r_hyps
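
Note (not part of the patch): a minimal sketch of what st_reverse_pad_list computes, assuming the patch is applied and a working Paddle 2.x install; the token ids and sos=10/eos=11 are made-up placeholder values.

    import paddle
    from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list

    ys_pad = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
    ys_lens = paddle.to_tensor([4, 3, 2])
    # Each row is reversed up to its true length, prefixed with <sos>,
    # and right-padded with <eos>:
    out = st_reverse_pad_list(ys_pad, ys_lens, sos=10, eos=11)
    # out -> [[10, 4, 3, 2, 1],
    #         [10, 7, 6, 5, 11],
    #         [10, 9, 8, 11, 11]]
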
diff --git a/paddlespeech/s2t/exps/u2/model.py b/paddlespeech/s2t/exps/u2/model.py
index 84b7be32..a13a6385 100644
--- a/paddlespeech/s2t/exps/u2/model.py
+++ b/paddlespeech/s2t/exps/u2/model.py
@@ -341,6 +341,7 @@ class U2Tester(U2Trainer):
         start_time = time.time()
         target_transcripts = self.id2token(texts, texts_len, self.text_feature)
+
         result_transcripts, result_tokenids = self.model.decode(
             audio,
             audio_len,
diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py
index e54a7afb..0a3e03b7 100644
--- a/paddlespeech/s2t/models/u2/u2.py
+++ b/paddlespeech/s2t/models/u2/u2.py
@@ -32,6 +32,7 @@ from paddle import nn
 from paddlespeech.audio.utils.tensor_utils import add_sos_eos
 from paddlespeech.audio.utils.tensor_utils import pad_sequence
 from paddlespeech.audio.utils.tensor_utils import reverse_pad_list
+from paddlespeech.audio.utils.tensor_utils import st_reverse_pad_list
 from paddlespeech.audio.utils.tensor_utils import th_accuracy
 from paddlespeech.s2t.decoders.scorers.ctc import CTCPrefixScorer
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
@@ -565,16 +566,12 @@ class U2BaseModel(ASRInterface, nn.Layer):
             dtype=paddle.long)  # (beam_size,)
         hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos, self.ignore_id)
         hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
-
         encoder_out = encoder_out.repeat(beam_size, 1, 1)
         encoder_mask = paddle.ones(
             (beam_size, 1, encoder_out.shape[1]), dtype=paddle.bool)
-        # used for right to left decoder
-        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens - 1,
-                                      self.ignore_id)
-        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
-                                    self.ignore_id)
+        r_hyps_pad = st_reverse_pad_list(ori_hyps_pad, hyps_lens - 1, self.sos,
+                                         self.eos)
         decoder_out, r_decoder_out, _ = self.decoder(
             encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
             reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
@@ -733,63 +730,10 @@ class U2BaseModel(ASRInterface, nn.Layer):
         r_hyps = hyps[:, 1:]
         # (num_hyps, max_hyps_len, vocab_size)
-        # Equal to:
-        #   >>> r_hyps = reverse_pad_list(r_hyps, r_hyps_lens, float(self.ignore_id))
-        #   >>> r_hyps, _ = add_sos_eos(r_hyps, self.sos, self.eos, self.ignore_id)
-        max_len = paddle.max(r_hyps_lens)
-        index_range = paddle.arange(0, max_len, 1)
-        seq_len_expand = r_hyps_lens.unsqueeze(1)
-        seq_mask = seq_len_expand > index_range  # (beam, max_len)
-
-        index = (seq_len_expand - 1) - index_range  # (beam, max_len)
-        # >>> index
-        # >>> tensor([[ 2,  1,  0],
-        # >>>         [ 2,  1,  0],
-        # >>>         [ 0, -1, -2]])
-        index = index * seq_mask
-        # >>> index
-        # >>> tensor([[2, 1, 0],
-        # >>>         [2, 1, 0],
-        # >>>         [0, 0, 0]])
-        def paddle_gather(x, dim, index):
-            index_shape = index.shape
-            index_flatten = index.flatten()
-            if dim < 0:
-                dim = len(x.shape) + dim
-            nd_index = []
-            for k in range(len(x.shape)):
-                if k == dim:
-                    nd_index.append(index_flatten)
-                else:
-                    reshape_shape = [1] * len(x.shape)
-                    reshape_shape[k] = x.shape[k]
-                    x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
-                    x_arange = x_arange.reshape(reshape_shape)
-                    dim_index = paddle.expand(x_arange, index_shape).flatten()
-                    nd_index.append(dim_index)
-            ind2 = paddle.transpose(paddle.stack(nd_index),
-                                    [1, 0]).astype("int64")
-            paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
-            return paddle_out
-
-        r_hyps = paddle_gather(r_hyps, 1, index)
-        # >>> r_hyps
-        # >>> tensor([[3, 2, 1],
-        # >>>         [4, 8, 9],
-        # >>>         [2, 2, 2]])
-        r_hyps = paddle.where(seq_mask, r_hyps, self.eos)
-        # >>> r_hyps
-        # >>> tensor([[3, 2, 1],
-        # >>>         [4, 8, 9],
-        # >>>         [2, eos, eos]])
-        r_hyps = paddle.concat([hyps[:, 0:1], r_hyps], axis=1)
-        # >>> r_hyps
-        # >>> tensor([[sos, 3, 2, 1],
-        # >>>         [sos, 4, 8, 9],
-        # >>>         [sos, 2, eos, eos]])
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, hyps,
-                                      hyps_lens, r_hyps, reverse_weight)
+        r_hyps = st_reverse_pad_list(r_hyps, r_hyps_lens, self.sos, self.eos)
+
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)
         decoder_out = paddle.nn.functional.log_softmax(decoder_out, axis=-1)
         r_decoder_out = paddle.nn.functional.log_softmax(r_decoder_out, axis=-1)
         return decoder_out, r_decoder_out
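
Note (not part of the patch): paddle.gather_nd takes full (row, col) coordinate pairs rather than torch.gather-style per-axis indices, which is why the paddle_gather helper (removed above, now living in st_reverse_pad_list) pairs each gathered column index with an arange over the rows. A standalone sketch of that trick with made-up values:

    import paddle

    x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
    index = paddle.to_tensor([[2, 1, 0], [0, 0, 2]])
    # Build one (row, col) coordinate pair per gathered element.
    rows = paddle.arange(x.shape[0]).unsqueeze(1).expand(index.shape)
    coords = paddle.stack([rows.flatten(), index.flatten()], axis=1)
    out = paddle.gather_nd(x, coords).reshape(index.shape)
    # out -> [[3, 2, 1], [4, 4, 6]], same as torch.gather(x, 1, index)
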
@@ -877,7 +821,8 @@ class U2BaseModel(ASRInterface, nn.Layer):
                 decoding_chunk_size=decoding_chunk_size,
                 num_decoding_left_chunks=num_decoding_left_chunks,
                 ctc_weight=ctc_weight,
-                simulate_streaming=simulate_streaming)
+                simulate_streaming=simulate_streaming,
+                reverse_weight=reverse_weight)
             hyps = [hyp]
         else:
             raise ValueError(f"Not support decoding method: {decoding_method}")
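
Note (not part of the patch): with reverse_weight now forwarded by decode(), attention rescoring can blend the right-to-left decoder's scores at inference time. A hypothetical call sketch; model, audio, audio_len, and text_feature are placeholders, and only the keyword names visible in the hunks above are taken from the patch:

    result_transcripts, result_tokenids = model.decode(
        audio,
        audio_len,
        text_feature=text_feature,
        decoding_method='attention_rescoring',
        beam_size=10,
        ctc_weight=0.5,
        decoding_chunk_size=-1,
        num_decoding_left_chunks=-1,
        simulate_streaming=False,
        reverse_weight=0.3)  # weight of the right-to-left rescoring pass
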