Merge pull request #1673 from Jackwaterveg/CER

[asr] Add new cer tools
pull/1735/head
Hui Zhang 3 years ago committed by GitHub
commit 2b8c08e3e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,8 @@ if [ $# != 4 ];then
exit -1
fi
stage=0
stop_stage=100
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
@ -19,6 +21,12 @@ if [ $? -ne 0 ]; then
exit 1
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# format the reference test file
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref data/manifest.test.text
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
@ -32,5 +40,26 @@ if [ $? -ne 0 ]; then
exit 1
fi
# format the hyp file
python utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.rsl \
--trans_hyp ${ckpt_prefix}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error
fi
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref_sclite data/manifest.test.text.sclite
python utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.rsl \
--trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite
mkdir -p ${ckpt_prefix}_sclite
sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII
fi
exit 0

@ -5,6 +5,8 @@ if [ $# != 3 ];then
exit -1
fi
stage=0
stop_stage=100
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
@ -24,6 +26,12 @@ fi
#fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# format the reference test file
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref data/manifest.test.text
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
@ -46,7 +54,15 @@ for type in attention ctc_greedy_search; do
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
# format the hyp file
python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl \
--trans_hyp ${output_dir}/${type}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
done
for type in ctc_prefix_beam_search attention_rescoring; do
@ -67,6 +83,29 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "Failed in evaluation!"
exit 1
fi
python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl
--trans_hyp ${output_dir}/${type}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
done
fi
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
# format the reference test file for sclite
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref_sclite data/manifest.test.text.sclite
output_dir=${ckpt_prefix}
for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl
--trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite
mkdir -p ${output_dir}/${type}_sclite
sclite -i wsj -r data/manifest.test.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII
done
fi
exit 0

@ -7,7 +7,7 @@ stage=0
stop_stage=50
conf_path=conf/conformer.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=20
avg_num=30
audio_file=data/demo_01_03.wav
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref": target, "hyp": result})
fout.write({"utt": utt, "refs": [target], "hyps": [result]})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")

@ -1,17 +1,15 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import re, sys, unicodedata
import codecs
import sys
import unicodedata
remove_tag = True
spacelist= [' ', '\t', '\r', '\n']
puncts = [
'!', ',', '?', '', '', '', '', '', '', '', '', '', '', '', '',
'', ''
]
puncts = ['!', ',', '?',
'', '', '', '', '', '',
'', '', '', '', '', '', '', '']
def characterize(string) :
res = []
@ -32,8 +30,7 @@ def characterize(string):
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<':
sep = '>'
if char == '<': sep = '>'
j = i+1
while j < len(string):
c = string[j]
@ -46,13 +43,10 @@ def characterize(string):
i = j
return res
def stripoff_tags(x):
if not x:
return ''
if not x: return ''
chars = []
i = 0
T = len(x)
i = 0; T=len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
@ -84,7 +78,6 @@ def normalize(sentence, ignore_words, cs, split=None):
new_sentence.append(x)
return new_sentence
class Calculator :
def __init__(self) :
self.data = {}
@ -94,7 +87,6 @@ class Calculator:
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec) :
# Initialization
lab.insert(0, '')
@ -116,22 +108,10 @@ class Calculator:
self.space[0][0]['error'] = 'non'
for token in lab :
if token not in self.data and len(token) > 0 :
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
for token in rec :
if token not in self.data and len(token) > 0 :
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
# Computing edit distance
for i, lab_token in enumerate(lab) :
for j, rec_token in enumerate(rec) :
@ -161,15 +141,7 @@ class Calculator:
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
i = len(lab) - 1
j = len(rec) - 1
while True :
@ -212,11 +184,8 @@ class Calculator:
elif self.space[i][j]['error'] == 'non' : # starting point
break
else : # shouldn't reach here
print(
'this should not happen , i = {i} , j = {j} , error = {error}'.
format(i=i, j=j, error=self.space[i][j]['error']))
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
return result
def overall(self) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in self.data :
@ -226,7 +195,6 @@ class Calculator:
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in data :
@ -237,15 +205,12 @@ class Calculator:
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self) :
return list(self.data.keys())
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word) :
unicode_names = [ unicodedata.name(char) for char in word ]
for i in reversed(range(len(unicode_names))) :
@ -285,15 +250,9 @@ def default_cluster(word):
return 'Other'
return unicode_names[0]
def usage() :
print(
"compute-wer.py : compute word error rate (WER) and align recognition results and references."
)
print(
" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
)
print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
if __name__ == '__main__':
if len(sys.argv) == 1 :
@ -366,7 +325,7 @@ if __name__ == '__main__':
verbose=0
try:
verbose=int(b)
except Exception as e:
except:
if b == 'true' or b != '0':
verbose = 1
continue
@ -409,11 +368,9 @@ if __name__ == '__main__':
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0:
continue
if len(array)==0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
split)
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8') :
@ -421,8 +378,7 @@ if __name__ == '__main__':
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array) == 0:
continue
if len(array)==0: continue
fid = array[0]
if fid not in rec_set:
continue
@ -443,14 +399,12 @@ if __name__ == '__main__':
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('WER: %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
@ -492,37 +446,30 @@ if __name__ == '__main__':
rec1 = rec2
if verbose:
print(
'==========================================================================='
)
print('===========================================================================')
print()
result = calculator.overall()
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('Overall -> %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters :
result = calculator.cluster(
[k for k in default_clusters[cluster_id]])
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if len(cluster_file) > 0 : # compute separated WERs for word clusters
cluster_id = ''
cluster = []
@ -533,14 +480,12 @@ if __name__ == '__main__':
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'],
result['del'], result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
@ -552,6 +497,4 @@ if __name__ == '__main__':
else :
cluster.append(token)
print()
print(
'==========================================================================='
)
print('===========================================================================')

@ -0,0 +1,90 @@
import os
import argparse
import jsonlines
def trans_hyp(origin_hyp,
trans_hyp = None,
trans_hyp_sclite = None):
"""
Args:
origin_hyp: The input json file which contains the model output
trans_hyp: The output file for caculate CER/WER
trans_hyp_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin_hyp, "r+", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["hyps"][0]
if trans_hyp is not None:
with open(trans_hyp, "w+", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_hyp_sclite is not None:
with open(trans_hyp_sclite, "w+") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
f.write(line)
def trans_ref(origin_ref,
trans_ref = None,
trans_ref_sclite = None):
"""
Args:
origin_hyp: The input json file which contains the model output
trans_hyp: The output file for caculate CER/WER
trans_hyp_sclite: The output file for caculate CER/WER using sclite
"""
input_dict = {}
with open(origin_ref, "r", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["text"]
if trans_ref is not None:
with open(trans_ref, "w", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_ref_sclite is not None:
with open(trans_ref_sclite, "w") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
f.write(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True)
parser.add_argument(
'--origin_hyp',
type=str,
default = None,
help='origin hyp file')
parser.add_argument(
'--trans_hyp', type=str, default = None, help='hyp file for caculating CER/WER')
parser.add_argument(
'--trans_hyp_sclite', type=str, default = None, help='hyp file for caculating CER/WER by sclite')
parser.add_argument(
'--origin_ref',
type=str,
default = None,
help='origin ref file')
parser.add_argument(
'--trans_ref', type=str, default = None, help='ref file for caculating CER/WER')
parser.add_argument(
'--trans_ref_sclite', type=str, default = None, help='ref file for caculating CER/WER by sclite')
parser_args = parser.parse_args()
if parser_args.origin_hyp is not None:
trans_hyp(
origin_hyp = parser_args.origin_hyp,
trans_hyp = parser_args.trans_hyp,
trans_hyp_sclite = parser_args.trans_hyp_sclite, )
if parser_args.origin_ref is not None:
trans_ref(
origin_ref = parser_args.origin_ref,
trans_ref = parser_args.trans_ref,
trans_ref_sclite = parser_args.trans_ref_sclite, )
Loading…
Cancel
Save