From 4b5be948ad3b8815f2172f75502bf799452b9bd2 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 21 Apr 2023 08:59:20 +0000 Subject: [PATCH] fix format rsl --- examples/aishell/asr1/local/test.sh | 15 ++-- paddlespeech/dataset/s2t/format_rsl.py | 98 +++++++++++++++----------- 2 files changed, 68 insertions(+), 45 deletions(-) diff --git a/examples/aishell/asr1/local/test.sh b/examples/aishell/asr1/local/test.sh index 26926b4a9..8487e9904 100755 --- a/examples/aishell/asr1/local/test.sh +++ b/examples/aishell/asr1/local/test.sh @@ -1,15 +1,21 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" - exit -1 -fi +set -e stage=0 stop_stage=100 + +source utils/parse_options.sh || exit 1; + ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." + +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" + exit -1 +fi + config_path=$1 decode_config_path=$2 ckpt_prefix=$3 @@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then fi if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + echo "using sclite to compute cer..." # format the reference test file for sclite python utils/format_rsl.py \ --origin_ref data/manifest.test.raw \ diff --git a/paddlespeech/dataset/s2t/format_rsl.py b/paddlespeech/dataset/s2t/format_rsl.py index 640a72021..0a58e7e68 100644 --- a/paddlespeech/dataset/s2t/format_rsl.py +++ b/paddlespeech/dataset/s2t/format_rsl.py @@ -12,114 +12,130 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -format ref/hyp file for `utt text` format to compute CER/WER/MER +format ref/hyp file for `utt text` format to compute CER/WER/MER. + +norm: +BAC009S0764W0196 明确了发展目标和重点任务 +BAC009S0764W0186 实现我国房地产市场的平稳运行 + + +sclite: +加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav) +河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav) """ import argparse import jsonlines +from paddlespeech.utils.argparse import print_arguments + -def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None): +def transform_hyp(origin, trans, trans_sclite): """ Args: - origin_hyp: The input json file which contains the model output - trans_hyp: The output file for caculate CER/WER - trans_hyp_sclite: The output file for caculate CER/WER using sclite + origin: The input json file which contains the model output + trans: The output file for caculate CER/WER + trans_sclite: The output file for caculate CER/WER using sclite """ input_dict = {} - with open(origin_hyp, "r+", encoding="utf8") as f: + with open(origin, "r+", encoding="utf8") as f: for item in jsonlines.Reader(f): input_dict[item["utt"]] = item["hyps"][0] - if trans_hyp is not None: - with open(trans_hyp, "w+", encoding="utf8") as f: + + if trans: + with open(trans, "w+", encoding="utf8") as f: for key in input_dict.keys(): f.write(key + " " + input_dict[key] + "\n") - if trans_hyp_sclite is not None: - with open(trans_hyp_sclite, "w+") as f: + print(f"transform_hyp output: {trans}") + + if trans_sclite: + with open(trans_sclite, "w+") as f: for key in input_dict.keys(): line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" f.write(line) + print(f"transform_hyp output: {trans_sclite}") -def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None): +def transform_ref(origin, trans, trans_sclite): """ Args: - origin_hyp: The input json file which contains the model output - trans_hyp: The output file for caculate CER/WER - trans_hyp_sclite: The output file for caculate CER/WER using sclite + origin: The input json file which contains the model output + trans: The output file for caculate CER/WER + trans_sclite: The output file for caculate CER/WER using sclite """ input_dict = {} - with open(origin_ref, "r", encoding="utf8") as f: + with open(origin, "r", encoding="utf8") as f: for item in jsonlines.Reader(f): input_dict[item["utt"]] = item["text"] - if trans_ref is not None: - with open(trans_ref, "w", encoding="utf8") as f: + + if trans: + with open(trans, "w", encoding="utf8") as f: for key in input_dict.keys(): f.write(key + " " + input_dict[key] + "\n") + print(f"transform_hyp output: {trans}") - if trans_ref_sclite is not None: - with open(trans_ref_sclite, "w") as f: + if trans_sclite: + with open(trans_sclite, "w") as f: for key in input_dict.keys(): line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" f.write(line) + print(f"transform_hyp output: {trans_sclite}") def define_argparse(): parser = argparse.ArgumentParser( prog='format ref/hyp file for compute CER/WER', add_help=True) parser.add_argument( - '--origin_hyp', type=str, default=None, help='origin hyp file') + '--origin_hyp', type=str, default="", help='origin hyp file') parser.add_argument( '--trans_hyp', type=str, - default=None, + default="", help='hyp file for caculating CER/WER') parser.add_argument( '--trans_hyp_sclite', type=str, - default=None, + default="", help='hyp file for caculating CER/WER by sclite') parser.add_argument( - '--origin_ref', type=str, default=None, help='origin ref file') + '--origin_ref', type=str, default="", help='origin ref file') parser.add_argument( '--trans_ref', type=str, - default=None, + default="", help='ref file for caculating CER/WER') parser.add_argument( '--trans_ref_sclite', type=str, - default=None, + default="", help='ref file for caculating CER/WER by sclite') parser_args = parser.parse_args() return parser_args -def format_result(origin_hyp=None, - trans_hyp=None, - trans_hyp_sclite=None, - origin_ref=None, - trans_ref=None, - trans_ref_sclite=None): +def format_result(origin_hyp="", + trans_hyp="", + trans_hyp_sclite="", + origin_ref="", + trans_ref="", + trans_ref_sclite=""): - if origin_hyp is not None: - trans_hyp( - origin_hyp=origin_hyp, - trans_hyp=trans_hyp, - trans_hyp_sclite=trans_hyp_sclite, ) + if origin_hyp: + transform_hyp( + origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite) - if origin_ref is not None: - trans_ref( - origin_ref=origin_ref, - trans_ref=trans_ref, - trans_ref_sclite=trans_ref_sclite, ) + if origin_ref: + transform_ref( + origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite) def main(): args = define_argparse() + print_arguments(args, globals()) + format_result(**vars(args))