From b4c2f3bae3d158442fc47ea6e27dc2f024919c83 Mon Sep 17 00:00:00 2001
From: megemini
Date: Wed, 18 Dec 2024 14:32:37 +0800
Subject: [PATCH] [Hackathon 7th] Fix `s2t` example errors (#3950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Fix] s2t

* [Fix] s2t test
---
 examples/ted_en_zh/st0/README.md                   |  2 +-
 paddlespeech/s2t/exps/u2_st/bin/test.py            |  3 ---
 .../s2t/frontend/featurizer/text_featurizer.py     |  4 ++++
 paddlespeech/s2t/io/dataloader.py                  |  6 ++++++
 paddlespeech/s2t/models/u2_st/u2_st.py             |  8 ++++----
 paddlespeech/s2t/modules/decoder.py                | 16 ++++++++--------
 6 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/examples/ted_en_zh/st0/README.md b/examples/ted_en_zh/st0/README.md
index 112d63c7..4c08e0fe 100644
--- a/examples/ted_en_zh/st0/README.md
+++ b/examples/ted_en_zh/st0/README.md
@@ -127,7 +127,7 @@ source path.sh
 bash ./local/data.sh
 CUDA_VISIBLE_DEVICES= ./local/train.sh conf/transformer_mtl_noam.yaml transformer_mtl_noam
 avg.sh latest exp/transformer_mtl_noam/checkpoints 5
-CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml exp/transformer_mtl_noam/checkpoints/avg_5
+CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml conf/tuning/decode.yaml exp/transformer_mtl_noam/checkpoints/avg_5
 ```
 The performance of the released models are shown below:
 ### Transformer
diff --git a/paddlespeech/s2t/exps/u2_st/bin/test.py b/paddlespeech/s2t/exps/u2_st/bin/test.py
index 30a903ce..a2e37e84 100644
--- a/paddlespeech/s2t/exps/u2_st/bin/test.py
+++ b/paddlespeech/s2t/exps/u2_st/bin/test.py
@@ -34,9 +34,6 @@ def main(config, args):
 
 
 if __name__ == "__main__":
     parser = default_argument_parser()
-    # save asr result to
-    parser.add_argument(
-        "--result_file", type=str, help="path of save the asr result")
     args = parser.parse_args()
     print_arguments(args, globals())
diff --git a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py
index 7623d0b8..0db0d63b 100644
--- a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py
+++ b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py
@@ -115,6 +115,10 @@ class TextFeaturizer():
         """
         assert self.vocab_path_or_list, "toidx need vocab path or vocab list"
         tokens = []
+        # unwrap `idxs`` like `[[1,2,3]]`
+        if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1:
+            idxs = idxs[0]
+
         for idx in idxs:
             if idx == self.eos_id:
                 break
diff --git a/paddlespeech/s2t/io/dataloader.py b/paddlespeech/s2t/io/dataloader.py
index db6292f2..5065c31e 100644
--- a/paddlespeech/s2t/io/dataloader.py
+++ b/paddlespeech/s2t/io/dataloader.py
@@ -404,6 +404,12 @@ class DataLoaderFactory():
             config['subsampling_factor'] = 1
             config['num_encs'] = 1
             config['shortest_first'] = False
+            config['minibatches'] = 0
+            config['batch_count'] = 'auto'
+            config['batch_bins'] = 0
+            config['batch_frames_in'] = 0
+            config['batch_frames_out'] = 0
+            config['batch_frames_inout'] = 0
         elif mode == 'valid':
             config['manifest'] = config.dev_manifest
             config['train_mode'] = False
diff --git a/paddlespeech/s2t/models/u2_st/u2_st.py b/paddlespeech/s2t/models/u2_st/u2_st.py
index 339af4b7..3fe1d352 100644
--- a/paddlespeech/s2t/models/u2_st/u2_st.py
+++ b/paddlespeech/s2t/models/u2_st/u2_st.py
@@ -170,8 +170,8 @@ class U2STBaseModel(nn.Layer):
         ys_in_lens = ys_pad_lens + 1
 
         # 1. Forward decoder
-        decoder_out, _ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
-                                         ys_in_lens)
+        decoder_out, *_ = self.st_decoder(encoder_out, encoder_mask, ys_in_pad,
+                                          ys_in_lens)
 
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
@@ -203,8 +203,8 @@ class U2STBaseModel(nn.Layer):
         ys_in_lens = ys_pad_lens + 1
 
         # 1. Forward decoder
-        decoder_out, _ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
-                                      ys_in_lens)
+        decoder_out, *_ = self.decoder(encoder_out, encoder_mask, ys_in_pad,
+                                       ys_in_lens)
 
         # 2. Compute attention loss
         loss_att = self.criterion_att(decoder_out, ys_out_pad)
diff --git a/paddlespeech/s2t/modules/decoder.py b/paddlespeech/s2t/modules/decoder.py
index 4ddf057b..1881a865 100644
--- a/paddlespeech/s2t/modules/decoder.py
+++ b/paddlespeech/s2t/modules/decoder.py
@@ -110,14 +110,14 @@ class TransformerDecoder(BatchScorerInterface, nn.Layer):
                 concat_after=concat_after, ) for _ in range(num_blocks)
         ])
 
-    def forward(
-            self,
-            memory: paddle.Tensor,
-            memory_mask: paddle.Tensor,
-            ys_in_pad: paddle.Tensor,
-            ys_in_lens: paddle.Tensor,
-            r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
-            reverse_weight: float=0.0) -> Tuple[paddle.Tensor, paddle.Tensor]:
+    def forward(self,
+                memory: paddle.Tensor,
+                memory_mask: paddle.Tensor,
+                ys_in_pad: paddle.Tensor,
+                ys_in_lens: paddle.Tensor,
+                r_ys_in_pad: paddle.Tensor=paddle.empty([0]),
+                reverse_weight: float=0.0
+                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Forward decoder.
         Args:
             memory: encoded memory, float32 (batch, maxlen_in, feat)
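
A note on the `decoder_out, *_ = ...` changes above: the return annotation of `TransformerDecoder.forward` now declares three tensors (`Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]`), so call sites that unpacked exactly two values would fail, and the starred target absorbs the extra outputs. Below is a minimal, framework-free sketch of that unpacking pattern; `fake_decoder` and the names `r_decoder_out` / `olens` are illustrative stand-ins, not PaddleSpeech API.

```python
from typing import List, Tuple


def fake_decoder() -> Tuple[List[float], List[float], int]:
    """Stand-in for a decoder whose forward now returns three outputs."""
    decoder_out = [0.1, 0.2]  # left-to-right decoder scores (placeholder values)
    r_decoder_out = []        # right-to-left decoder scores (unused here)
    olens = 2                 # output lengths (placeholder)
    return decoder_out, r_decoder_out, olens


# The old pattern `decoder_out, _ = fake_decoder()` raises
# "ValueError: too many values to unpack" once a third output is added.
# Starred unpacking keeps the first value and absorbs however many follow:
decoder_out, *_ = fake_decoder()
print(decoder_out)  # [0.1, 0.2]
```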