diff --git a/examples/aishell/asr0/local/test.sh b/examples/aishell/asr0/local/test.sh index 463593ef..363dbf0a 100755 --- a/examples/aishell/asr0/local/test.sh +++ b/examples/aishell/asr0/local/test.sh @@ -5,6 +5,8 @@ if [ $# != 4 ];then exit -1 fi +stage=0 +stop_stage=100 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." @@ -19,18 +21,45 @@ if [ $? -ne 0 ]; then exit 1 fi -python3 -u ${BIN_DIR}/test.py \ ---ngpu ${ngpu} \ ---config ${config_path} \ ---decode_cfg ${decode_config_path} \ ---result_file ${ckpt_prefix}.rsl \ ---checkpoint_path ${ckpt_prefix} \ ---model_type ${model_type} +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # format the reference test file + python utils/format_rsl.py \ + --origin_ref data/manifest.test.raw \ + --trans_ref data/manifest.test.text -if [ $? -ne 0 ]; then - echo "Failed in evaluation!" - exit 1 + python3 -u ${BIN_DIR}/test.py \ + --ngpu ${ngpu} \ + --config ${config_path} \ + --decode_cfg ${decode_config_path} \ + --result_file ${ckpt_prefix}.rsl \ + --checkpoint_path ${ckpt_prefix} \ + --model_type ${model_type} + + if [ $? -ne 0 ]; then + echo "Failed in evaluation!" + exit 1 + fi + + # format the hyp file + python utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp ${ckpt_prefix}.rsl.text + + python utils/compute-wer.py --char=1 --v=1 \ + data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error fi +if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + python utils/format_rsl.py \ + --origin_ref data/manifest.test.raw \ + --trans_ref_sclite data/manifest.test.text.sclite + + python utils/format_rsl.py \ + --origin_hyp ${ckpt_prefix}.rsl \ + --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite + + mkdir -p ${ckpt_prefix}_sclite + sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII +fi exit 0 diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py index 3e9ede76..3c2eaab7 100644 --- a/paddlespeech/s2t/exps/deepspeech2/model.py +++ b/paddlespeech/s2t/exps/deepspeech2/model.py @@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): len_refs += len_ref num_ins += 1 if fout: - fout.write({"utt": utt, "ref": target, "hyp": result}) + fout.write({"utt": utt, "refs": [target], "hyps": [result]}) logger.info(f"Utt: {utt}") logger.info(f"Ref: {target}") logger.info(f"Hyp: {result}")