fix format rsl

pull/3189/head
Hui Zhang 2 years ago
parent a72d37a838
commit 4b5be948ad

@ -1,15 +1,21 @@
#!/bin/bash #!/bin/bash
if [ $# != 3 ];then set -e
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
stage=0 stage=0
stop_stage=100 stop_stage=100
source utils/parse_options.sh || exit 1;
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..." echo "using $ngpu gpus..."
if [ $# != 3 ];then
echo "usage: ${0} config_path decode_config_path ckpt_path_prefix"
exit -1
fi
config_path=$1 config_path=$1
decode_config_path=$2 decode_config_path=$2
ckpt_prefix=$3 ckpt_prefix=$3
@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi fi
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
echo "using sclite to compute cer..."
# format the reference test file for sclite # format the reference test file for sclite
python utils/format_rsl.py \ python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \ --origin_ref data/manifest.test.raw \

@ -12,114 +12,130 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
format ref/hyp file for `utt text` format to compute CER/WER/MER format ref/hyp file for `utt text` format to compute CER/WER/MER.
norm:
BAC009S0764W0196 明确了发展目标和重点任务
BAC009S0764W0186 实现我国房地产市场的平稳运行
sclite:
加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav)
河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav)
""" """
import argparse import argparse
import jsonlines import jsonlines
from paddlespeech.utils.argparse import print_arguments
def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None): def transform_hyp(origin, trans, trans_sclite):
""" """
Args: Args:
origin_hyp: The input json file which contains the model output origin: The input json file which contains the model output
trans_hyp: The output file for calculating CER/WER trans: The output file for calculating CER/WER
trans_hyp_sclite: The output file for calculating CER/WER using sclite trans_sclite: The output file for calculating CER/WER using sclite
""" """
input_dict = {} input_dict = {}
with open(origin_hyp, "r+", encoding="utf8") as f: with open(origin, "r+", encoding="utf8") as f:
for item in jsonlines.Reader(f): for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["hyps"][0] input_dict[item["utt"]] = item["hyps"][0]
if trans_hyp is not None:
with open(trans_hyp, "w+", encoding="utf8") as f: if trans:
with open(trans, "w+", encoding="utf8") as f:
for key in input_dict.keys(): for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n") f.write(key + " " + input_dict[key] + "\n")
if trans_hyp_sclite is not None: print(f"transform_hyp output: {trans}")
with open(trans_hyp_sclite, "w+") as f:
if trans_sclite:
with open(trans_sclite, "w+") as f:
for key in input_dict.keys(): for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line) f.write(line)
print(f"transform_hyp output: {trans_sclite}")
def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None): def transform_ref(origin, trans, trans_sclite):
""" """
Args: Args:
origin_hyp: The input json file which contains the model output origin: The input json file which contains the model output
trans_hyp: The output file for calculating CER/WER trans: The output file for calculating CER/WER
trans_hyp_sclite: The output file for calculating CER/WER using sclite trans_sclite: The output file for calculating CER/WER using sclite
""" """
input_dict = {} input_dict = {}
with open(origin_ref, "r", encoding="utf8") as f: with open(origin, "r", encoding="utf8") as f:
for item in jsonlines.Reader(f): for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["text"] input_dict[item["utt"]] = item["text"]
if trans_ref is not None:
with open(trans_ref, "w", encoding="utf8") as f: if trans:
with open(trans, "w", encoding="utf8") as f:
for key in input_dict.keys(): for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n") f.write(key + " " + input_dict[key] + "\n")
print(f"transform_hyp output: {trans}")
if trans_ref_sclite is not None: if trans_sclite:
with open(trans_ref_sclite, "w") as f: with open(trans_sclite, "w") as f:
for key in input_dict.keys(): for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line) f.write(line)
print(f"transform_hyp output: {trans_sclite}")
def define_argparse(): def define_argparse():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog='format ref/hyp file for compute CER/WER', add_help=True) prog='format ref/hyp file for compute CER/WER', add_help=True)
parser.add_argument( parser.add_argument(
'--origin_hyp', type=str, default=None, help='origin hyp file') '--origin_hyp', type=str, default="", help='origin hyp file')
parser.add_argument( parser.add_argument(
'--trans_hyp', '--trans_hyp',
type=str, type=str,
default=None, default="",
help='hyp file for caculating CER/WER') help='hyp file for caculating CER/WER')
parser.add_argument( parser.add_argument(
'--trans_hyp_sclite', '--trans_hyp_sclite',
type=str, type=str,
default=None, default="",
help='hyp file for caculating CER/WER by sclite') help='hyp file for caculating CER/WER by sclite')
parser.add_argument( parser.add_argument(
'--origin_ref', type=str, default=None, help='origin ref file') '--origin_ref', type=str, default="", help='origin ref file')
parser.add_argument( parser.add_argument(
'--trans_ref', '--trans_ref',
type=str, type=str,
default=None, default="",
help='ref file for caculating CER/WER') help='ref file for caculating CER/WER')
parser.add_argument( parser.add_argument(
'--trans_ref_sclite', '--trans_ref_sclite',
type=str, type=str,
default=None, default="",
help='ref file for caculating CER/WER by sclite') help='ref file for caculating CER/WER by sclite')
parser_args = parser.parse_args() parser_args = parser.parse_args()
return parser_args return parser_args
def format_result(origin_hyp=None, def format_result(origin_hyp="",
trans_hyp=None, trans_hyp="",
trans_hyp_sclite=None, trans_hyp_sclite="",
origin_ref=None, origin_ref="",
trans_ref=None, trans_ref="",
trans_ref_sclite=None): trans_ref_sclite=""):
if origin_hyp is not None: if origin_hyp:
trans_hyp( transform_hyp(
origin_hyp=origin_hyp, origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite)
trans_hyp=trans_hyp,
trans_hyp_sclite=trans_hyp_sclite, )
if origin_ref is not None: if origin_ref:
trans_ref( transform_ref(
origin_ref=origin_ref, origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite)
trans_ref=trans_ref,
trans_ref_sclite=trans_ref_sclite, )
def main(): def main():
args = define_argparse() args = define_argparse()
print_arguments(args, globals())
format_result(**vars(args)) format_result(**vars(args))

Loading…
Cancel
Save