parent
2ea00755f7
commit
927c60a5c1
@ -1,143 +0,0 @@
|
|||||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
format ref/hyp file for `utt text` format to compute CER/WER/MER.
|
|
||||||
|
|
||||||
norm:
|
|
||||||
BAC009S0764W0196 明确了发展目标和重点任务
|
|
||||||
BAC009S0764W0186 实现我国房地产市场的平稳运行
|
|
||||||
|
|
||||||
|
|
||||||
sclite:
|
|
||||||
加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav)
|
|
||||||
河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav)
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import jsonlines
|
|
||||||
|
|
||||||
from paddlespeech.utils.argparse import print_arguments
|
|
||||||
|
|
||||||
|
|
||||||
def transform_hyp(origin, trans, trans_sclite):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
origin: The input json file which contains the model output
|
|
||||||
trans: The output file for caculate CER/WER
|
|
||||||
trans_sclite: The output file for caculate CER/WER using sclite
|
|
||||||
"""
|
|
||||||
input_dict = {}
|
|
||||||
|
|
||||||
with open(origin, "r+", encoding="utf8") as f:
|
|
||||||
for item in jsonlines.Reader(f):
|
|
||||||
input_dict[item["utt"]] = item["hyps"][0]
|
|
||||||
|
|
||||||
if trans:
|
|
||||||
with open(trans, "w+", encoding="utf8") as f:
|
|
||||||
for key in input_dict.keys():
|
|
||||||
f.write(key + " " + input_dict[key] + "\n")
|
|
||||||
print(f"transform_hyp output: {trans}")
|
|
||||||
|
|
||||||
if trans_sclite:
|
|
||||||
with open(trans_sclite, "w+") as f:
|
|
||||||
for key in input_dict.keys():
|
|
||||||
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
|
|
||||||
f.write(line)
|
|
||||||
print(f"transform_hyp output: {trans_sclite}")
|
|
||||||
|
|
||||||
|
|
||||||
def transform_ref(origin, trans, trans_sclite):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
origin: The input json file which contains the model output
|
|
||||||
trans: The output file for caculate CER/WER
|
|
||||||
trans_sclite: The output file for caculate CER/WER using sclite
|
|
||||||
"""
|
|
||||||
input_dict = {}
|
|
||||||
|
|
||||||
with open(origin, "r", encoding="utf8") as f:
|
|
||||||
for item in jsonlines.Reader(f):
|
|
||||||
input_dict[item["utt"]] = item["text"]
|
|
||||||
|
|
||||||
if trans:
|
|
||||||
with open(trans, "w", encoding="utf8") as f:
|
|
||||||
for key in input_dict.keys():
|
|
||||||
f.write(key + " " + input_dict[key] + "\n")
|
|
||||||
print(f"transform_hyp output: {trans}")
|
|
||||||
|
|
||||||
if trans_sclite:
|
|
||||||
with open(trans_sclite, "w") as f:
|
|
||||||
for key in input_dict.keys():
|
|
||||||
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
|
|
||||||
f.write(line)
|
|
||||||
print(f"transform_hyp output: {trans_sclite}")
|
|
||||||
|
|
||||||
|
|
||||||
def define_argparse():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
prog='format ref/hyp file for compute CER/WER', add_help=True)
|
|
||||||
parser.add_argument(
|
|
||||||
'--origin_hyp', type=str, default="", help='origin hyp file')
|
|
||||||
parser.add_argument(
|
|
||||||
'--trans_hyp',
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help='hyp file for caculating CER/WER')
|
|
||||||
parser.add_argument(
|
|
||||||
'--trans_hyp_sclite',
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help='hyp file for caculating CER/WER by sclite')
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
'--origin_ref', type=str, default="", help='origin ref file')
|
|
||||||
parser.add_argument(
|
|
||||||
'--trans_ref',
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help='ref file for caculating CER/WER')
|
|
||||||
parser.add_argument(
|
|
||||||
'--trans_ref_sclite',
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help='ref file for caculating CER/WER by sclite')
|
|
||||||
parser_args = parser.parse_args()
|
|
||||||
return parser_args
|
|
||||||
|
|
||||||
|
|
||||||
def format_result(origin_hyp="",
|
|
||||||
trans_hyp="",
|
|
||||||
trans_hyp_sclite="",
|
|
||||||
origin_ref="",
|
|
||||||
trans_ref="",
|
|
||||||
trans_ref_sclite=""):
|
|
||||||
|
|
||||||
if origin_hyp:
|
|
||||||
transform_hyp(
|
|
||||||
origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite)
|
|
||||||
|
|
||||||
if origin_ref:
|
|
||||||
transform_ref(
|
|
||||||
origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = define_argparse()
|
|
||||||
print_arguments(args, globals())
|
|
||||||
|
|
||||||
format_result(**vars(args))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
Loading…
Reference in new issue