[Fix] s2t test

pull/3950/head
megemini 9 months ago
parent f3eb9508de
commit fb3f11a0ad

@ -127,7 +127,7 @@ source path.h
bash ./local/data.sh
CUDA_VISIBLE_DEVICES= ./local/train.sh conf/transformer_mtl_noam.yaml transformer_mtl_noam
avg.sh latest exp/transformer_mtl_noam/checkpoints 5
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml exp/transformer_mtl_noam/checkpoints/avg_5
CUDA_VISIBLE_DEVICES= ./local/test.sh conf/transformer_mtl_noam.yaml conf/tuning/decode.yaml exp/transformer_mtl_noam/checkpoints/avg_5
```
The performance of the released models are shown below:
### Transformer

@ -34,9 +34,6 @@ def main(config, args):
if __name__ == "__main__":
parser = default_argument_parser()
# save asr result to
parser.add_argument(
"--result_file", type=str, help="path of save the asr result")
args = parser.parse_args()
print_arguments(args, globals())

@ -115,6 +115,10 @@ class TextFeaturizer():
"""
assert self.vocab_path_or_list, "toidx need vocab path or vocab list"
tokens = []
# unwrap `idxs`` like `[[1,2,3]]`
if idxs and isinstance(idxs[0], (list, tuple)) and len(idxs) == 1:
idxs = idxs[0]
for idx in idxs:
if idx == self.eos_id:
break

Loading…
Cancel
Save