From fb3f11a0ad40597710cce529db7130447b78cb36 Mon Sep 17 00:00:00 2001 From: megemini Date: Thu, 12 Dec 2024 22:39:24 +0800 Subject: [PATCH] [Fix] s2t test --- examples/ted_en_zh/st0/README.md | 2 +- paddlespeech/s2t/exps/u2_st/bin/test.py | 3 --- paddlespeech/s2t/frontend/featurizer/text_featurizer.py | 4 ++++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/ted_en_zh/st0/README.md b/examples/ted_en_zh/st0/README.md index 112d63c71..4c08e0fe1 100644 --- a/examples/ted_en_zh/st0/README.md +++ b/examples/ted_en_zh/st0/README.md @@ -127,7 +127,7 @@ source path.h bash ./local/data.sh CUDA_VISIBLE_DEVICES= ./local/train.sh conf/transformer_mtl_noam.yaml transformer_mtl_noam avg.sh latest exp/transformer_mtl_noam/checkpoints 5 -CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml exp/transformer_mtl_noam/checkpoints/avg_5 +CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml conf/tuning/decode.yaml exp/transformer_mtl_noam/checkpoints/avg_5 ``` The performance of the released models are shown below: ### Transformer diff --git a/paddlespeech/s2t/exps/u2_st/bin/test.py b/paddlespeech/s2t/exps/u2_st/bin/test.py index 30a903ceb..a2e37e84d 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/test.py +++ b/paddlespeech/s2t/exps/u2_st/bin/test.py @@ -34,9 +34,6 @@ def main(config, args): if __name__ == "__main__": parser = default_argument_parser() - # save asr result to - parser.add_argument( - "--result_file", type=str, help="path of save the asr result") args = parser.parse_args() print_arguments(args, globals()) diff --git a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py index 7623d0b87..0db0d63b9 100644 --- a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py +++ b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py @@ -115,6 +115,10 @@ class TextFeaturizer(): """ assert self.vocab_path_or_list, "toidx need vocab path or vocab list" tokens = [] + # 
unwrap `idxs` like `[[1,2,3]]` + if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1: + idxs = idxs[0] + for idx in idxs: if idx == self.eos_id: break