Merge pull request #1673 from Jackwaterveg/CER

[asr] Add new cer tools
pull/1735/head
Hui Zhang 2 years ago committed by GitHub
commit 2b8c08e3e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,6 +5,8 @@ if [ $# != 4 ];then
exit -1
fi
stage=0
stop_stage=100
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
@ -19,18 +21,45 @@ if [ $? -ne 0 ]; then
exit 1
fi
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# format the reference test file
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref data/manifest.test.text
if [ $? -ne 0 ]; then
python3 -u ${BIN_DIR}/test.py \
--ngpu ${ngpu} \
--config ${config_path} \
--decode_cfg ${decode_config_path} \
--result_file ${ckpt_prefix}.rsl \
--checkpoint_path ${ckpt_prefix} \
--model_type ${model_type}
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
# format the hyp file
python utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.rsl \
--trans_hyp ${ckpt_prefix}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error
fi
# Optional sclite scoring stage. It only runs when stop_stage is raised to at
# least 101, so the default stage=0 / stop_stage=100 set above skips it.
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
# Convert the raw test manifest into the reference file passed to sclite.
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref_sclite data/manifest.test.text.sclite
# Convert the decoding result (${ckpt_prefix}.rsl) into the hypothesis file
# passed to sclite.
python utils/format_rsl.py \
--origin_hyp ${ckpt_prefix}.rsl \
--trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite
# Create the directory handed to sclite via -O before invoking it.
mkdir -p ${ckpt_prefix}_sclite
# NOTE(review): requires the sclite binary on PATH — presumably from SCTK; confirm.
sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII
fi
exit 0

@ -5,6 +5,8 @@ if [ $# != 3 ];then
exit -1
fi
stage=0
stop_stage=100
ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
echo "using $ngpu gpus..."
@ -24,7 +26,13 @@ fi
#fi
for type in attention ctc_greedy_search; do
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# format the reference test file
python utils/format_rsl.py \
--origin_ref data/manifest.test.raw \
--trans_ref data/manifest.test.text
for type in attention ctc_greedy_search; do
echo "decoding ${type}"
if [ ${chunk_mode} == true ];then
# stream decoding only support batchsize=1
@ -46,10 +54,18 @@ for type in attention ctc_greedy_search; do
if [ $? -ne 0 ]; then
echo "Failed in evaluation!"
exit 1
fi
done
# format the hyp file
python utils/format_rsl.py \
--origin_hyp ${output_dir}/${type}.rsl \
--trans_hyp ${output_dir}/${type}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
done
for type in ctc_prefix_beam_search attention_rescoring; do
for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
output_dir=${ckpt_prefix}
@ -67,6 +83,29 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "Failed in evaluation!"
exit 1
fi
done
# format the hyp file
# Every continued line must end with a backslash; without it the --trans_hyp
# line is executed as a separate (failing) command and the .rsl.text file that
# compute-wer.py reads next is never produced.
python utils/format_rsl.py \
    --origin_hyp ${output_dir}/${type}.rsl \
    --trans_hyp ${output_dir}/${type}.rsl.text
python utils/compute-wer.py --char=1 --v=1 \
data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
done
fi
# Optional sclite scoring stage: runs only when stop_stage is at least 101.
if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
    # format the reference test file for sclite
    python utils/format_rsl.py \
        --origin_ref data/manifest.test.raw \
        --trans_ref_sclite data/manifest.test.text.sclite

    output_dir=${ckpt_prefix}
    for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
        # format the hyp file for sclite
        # The trailing backslash after the --origin_hyp argument is required:
        # without it --trans_hyp_sclite is run as its own command and the
        # hypothesis file handed to sclite below is never written.
        python utils/format_rsl.py \
            --origin_hyp ${output_dir}/${type}.rsl \
            --trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite

        # One report directory per decoding method, created before sclite uses it.
        mkdir -p ${output_dir}/${type}_sclite
        sclite -i wsj -r data/manifest.test.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII
    done
fi
exit 0

@ -7,7 +7,7 @@ stage=0
stop_stage=50
conf_path=conf/conformer.yaml
decode_conf_path=conf/tuning/decode.yaml
avg_num=20
avg_num=30
audio_file=data/demo_01_03.wav
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;

@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
len_refs += len_ref
num_ins += 1
if fout:
fout.write({"utt": utt, "ref": target, "hyp": result})
fout.write({"utt": utt, "refs": [target], "hyps": [result]})
logger.info(f"Utt: {utt}")
logger.info(f"Ref: {target}")
logger.info(f"Hyp: {result}")

@ -1,19 +1,17 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import re, sys, unicodedata
import codecs
import sys
import unicodedata
remove_tag = True
spacelist = [' ', '\t', '\r', '\n']
puncts = [
'!', ',', '?', '', '', '', '', '', '', '', '', '', '', '', '',
'', ''
]
spacelist= [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
'', '', '', '', '', '',
'', '', '', '', '', '', '', '']
def characterize(string):
def characterize(string) :
res = []
i = 0
while i < len(string):
@ -32,12 +30,11 @@ def characterize(string):
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<':
sep = '>'
j = i + 1
if char == '<': sep = '>'
j = i+1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c == sep):
if ord(c) >= 128 or (c in spacelist) or (c==sep):
break
j += 1
if j < len(string) and string[j] == '>':
@ -46,13 +43,10 @@ def characterize(string):
i = j
return res
def stripoff_tags(x):
if not x:
return ''
if not x: return ''
chars = []
i = 0
T = len(x)
i = 0; T=len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
@ -84,9 +78,8 @@ def normalize(sentence, ignore_words, cs, split=None):
new_sentence.append(x)
return new_sentence
class Calculator:
def __init__(self):
class Calculator :
def __init__(self) :
self.data = {}
self.space = []
self.cost = {}
@ -94,87 +87,66 @@ class Calculator:
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec):
def calculate(self, lab, rec) :
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab):
while len(self.space) < len(lab) :
self.space.append([])
for row in self.space:
for element in row:
for row in self.space :
for element in row :
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec):
row.append({'dist': 0, 'error': 'non'})
for i in range(len(lab)):
while len(row) < len(rec) :
row.append({'dist' : 0, 'error' : 'non'})
for i in range(len(lab)) :
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)):
for j in range(len(rec)) :
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in lab :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
for token in rec :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
for i, lab_token in enumerate(lab) :
for j, rec_token in enumerate(rec) :
if i == 0 or j == 0 :
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i - 1][j]['dist'] + self.cost['del']
dist = self.space[i-1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist:
if dist < min_dist :
min_dist = dist
min_error = error
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
dist = self.space[i][j-1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist:
if dist < min_dist :
min_dist = dist
min_error = error
if lab_token == rec_token:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
if lab_token == rec_token :
dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
error = 'cor'
else:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
else :
dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist:
if dist < min_dist :
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]['error'] == 'cor': # correct
if len(lab[i]) > 0:
while True :
if self.space[i][j]['error'] == 'cor' : # correct
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
@ -183,8 +155,8 @@ class Calculator:
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub': # substitution
if len(lab[i]) > 0:
elif self.space[i][j]['error'] == 'sub' : # substitution
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
@ -193,8 +165,8 @@ class Calculator:
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del': # deletion
if len(lab[i]) > 0:
elif self.space[i][j]['error'] == 'del' : # deletion
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
@ -202,64 +174,57 @@ class Calculator:
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins': # insertion
if len(rec[j]) > 0:
elif self.space[i][j]['error'] == 'ins' : # insertion
if len(rec[j]) > 0 :
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non': # starting point
elif self.space[i][j]['error'] == 'non' : # starting point
break
else: # shouldn't reach here
print(
'this should not happen , i = {i} , j = {j} , error = {error}'.
format(i=i, j=j, error=self.space[i][j]['error']))
else : # shouldn't reach here
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
return result
def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
def overall(self) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
if token in self.data:
def cluster(self, data) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in data :
if token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self):
def keys(self) :
return list(self.data.keys())
def width(string):
    """Return the column count of *string* for alignment padding.

    A character whose East Asian Width class is A, F or W counts as two
    columns; any other character counts as one.
    """
    columns = 0
    for ch in string:
        if unicodedata.east_asian_width(ch) in "AFW":
            columns += 2
        else:
            columns += 1
    return columns
def default_cluster(word):
unicode_names = [unicodedata.name(char) for char in word]
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith('DIGIT'): # 1
def default_cluster(word) :
unicode_names = [ unicodedata.name(char) for char in word ]
for i in reversed(range(len(unicode_names))) :
if unicode_names[i].startswith('DIGIT') : # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
unicode_names[i].startswith('LATIN SMALL LETTER')):
unicode_names[i].startswith('LATIN SMALL LETTER')) :
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND') or
unicode_names[i].startswith('APOSTROPHE') or
@ -271,40 +236,34 @@ def default_cluster(word):
unicode_names[i].startswith('LOW LINE') or
unicode_names[i].startswith('NUMBER SIGN') or
unicode_names[i].startswith('PLUS SIGN') or
unicode_names[i].startswith('SEMICOLON')):
unicode_names[i].startswith('SEMICOLON')) :
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else:
else :
return 'Other'
if len(unicode_names) == 0:
if len(unicode_names) == 0 :
return 'Other'
if len(unicode_names) == 1:
if len(unicode_names) == 1 :
return unicode_names[0]
for i in range(len(unicode_names) - 1):
if unicode_names[i] != unicode_names[i + 1]:
for i in range(len(unicode_names)-1) :
if unicode_names[i] != unicode_names[i+1] :
return 'Other'
return unicode_names[0]
def usage():
print(
"compute-wer.py : compute word error rate (WER) and align recognition results and references."
)
print(
" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
)
def usage() :
print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
if __name__ == '__main__':
if len(sys.argv) == 1:
if len(sys.argv) == 1 :
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose = 1
padding_symbol = ' '
verbose= 1
padding_symbol= ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
@ -363,10 +322,10 @@ if __name__ == '__main__':
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose = 0
verbose=0
try:
verbose = int(b)
except Exception as e:
verbose=int(b)
except:
if b == 'true' or b != '0':
verbose = 1
continue
@ -375,9 +334,9 @@ if __name__ == '__main__':
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol = ' '
padding_symbol= ' '
elif b == 'underline':
padding_symbol = '_'
padding_symbol= '_'
continue
if True or sys.argv[1].startswith('-'):
#ignore invalid switch
@ -385,7 +344,7 @@ if __name__ == '__main__':
continue
if not case_sensitive:
ig = set([w.upper() for w in ignore_words])
ig=set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
@ -409,20 +368,17 @@ if __name__ == '__main__':
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0:
continue
if len(array)==0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
split)
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8'):
for line in open(ref_file, 'r', encoding='utf-8') :
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array) == 0:
continue
if len(array)==0: continue
fid = array[0]
if fid not in rec_set:
continue
@ -431,127 +387,114 @@ if __name__ == '__main__':
if verbose:
print('\nutt: %s' % fid)
for word in rec + lab:
if word not in default_words:
for word in rec + lab :
if word not in default_words :
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters:
if default_cluster_name not in default_clusters :
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name]:
if word not in default_clusters[default_cluster_name] :
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('WER: %4.2f %%' % wer, end=' ')
print('WER: %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])):
for idx in range(len(result['lab'])) :
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length - len_lab)
space['rec'].append(length - len_rec)
space['lab'].append(length-len_lab)
space['rec'].append(length-len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end=' ')
print('lab(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('lab:', end=' ')
print('lab:', end = ' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['lab'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print('{token}'.format(token = token), end = '')
for n in range(space['lab'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end=' ')
print('rec(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('rec:', end=' ')
print('rec:', end = ' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['rec'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print('{token}'.format(token = token), end = '')
for n in range(space['rec'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
if verbose:
print(
'==========================================================================='
)
print('===========================================================================')
print()
result = calculator.overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('Overall -> %4.2f %%' % wer, end=' ')
print('Overall -> %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters:
result = calculator.cluster(
[k for k in default_clusters[cluster_id]])
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
for cluster_id in default_clusters :
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if len(cluster_file) > 0: # compute separated WERs for word clusters
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if len(cluster_file) > 0 : # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8'):
for token in line.decode('utf-8').rstrip('\n').split():
for line in open(cluster_file, 'r', encoding='utf-8') :
for token in line.decode('utf-8').rstrip('\n').split() :
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token) - 1] == '>' and \
if token[0:2] == '</' and token[len(token)-1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'],
result['del'], result['ins']))
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif token[0] == '<' and token[len(token) - 1] == '>' and \
elif token[0] == '<' and token[len(token)-1] == '>' and \
cluster_id == '' :
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else:
else :
cluster.append(token)
print()
print(
'==========================================================================='
)
print('===========================================================================')

@ -0,0 +1,90 @@
import os
import argparse
import jsonlines
def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
    """Convert a decoding result file into plain-text hypothesis files.

    Args:
        origin_hyp: Path of the input JSON-lines file containing the model
            output. Each record is read for its "utt" and "hyps" fields.
        trans_hyp: Optional output path. One "<utt> <hyp>" line is written
            per utterance, for computing CER/WER.
        trans_hyp_sclite: Optional output path. One "<hyp>(<utt>.wav)" line
            is written per utterance, for computing CER/WER with sclite.
    """
    hyps = {}
    # The result file is only read, so open it read-only; "r+" would fail on
    # files the user cannot write to.
    with open(origin_hyp, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            # Only the first hypothesis of each utterance is kept.
            hyps[item["utt"]] = item["hyps"][0]

    if trans_hyp is not None:
        with open(trans_hyp, "w", encoding="utf8") as f:
            for utt, hyp in hyps.items():
                f.write(utt + " " + hyp + "\n")

    if trans_hyp_sclite is not None:
        # Pin the encoding: the platform default may be unable to encode the
        # transcripts, and the other output of this function is UTF-8 too.
        with open(trans_hyp_sclite, "w", encoding="utf8") as f:
            for utt, hyp in hyps.items():
                f.write(hyp + "(" + utt + ".wav" + ")" + "\n")
def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
    """Convert a reference manifest into plain-text reference files.

    Args:
        origin_ref: Path of the input JSON-lines manifest. Each record is
            read for its "utt" and "text" fields.
        trans_ref: Optional output path. One "<utt> <text>" line is written
            per utterance, for computing CER/WER.
        trans_ref_sclite: Optional output path. One "<text>(<utt>.wav)" line
            is written per utterance, for computing CER/WER with sclite.
    """
    refs = {}
    with open(origin_ref, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            refs[item["utt"]] = item["text"]

    if trans_ref is not None:
        with open(trans_ref, "w", encoding="utf8") as f:
            for utt, text in refs.items():
                f.write(utt + " " + text + "\n")

    if trans_ref_sclite is not None:
        # Pin the encoding: the platform default may be unable to encode the
        # transcripts, and the other output of this function is UTF-8 too.
        with open(trans_ref_sclite, "w", encoding="utf8") as f:
            for utt, text in refs.items():
                f.write(text + "(" + utt + ".wav" + ")" + "\n")
if __name__ == "__main__":
    # Command-line entry point: converts a hypothesis file and/or a reference
    # file, depending on which --origin_* option is supplied.
    parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER')

    # (flag, help text) pairs, registered in the order shown by --help.
    # Every option is an optional string that defaults to None.
    option_table = (
        ('--origin_hyp', 'origin hyp file'),
        ('--trans_hyp', 'hyp file for caculating CER/WER'),
        ('--trans_hyp_sclite', 'hyp file for caculating CER/WER by sclite'),
        ('--origin_ref', 'origin ref file'),
        ('--trans_ref', 'ref file for caculating CER/WER'),
        ('--trans_ref_sclite', 'ref file for caculating CER/WER by sclite'),
    )
    for flag, description in option_table:
        parser.add_argument(flag, type=str, default=None, help=description)

    cli = parser.parse_args()

    # Hypothesis conversion happens only when an input hyp file is given.
    if cli.origin_hyp is not None:
        trans_hyp(
            origin_hyp=cli.origin_hyp,
            trans_hyp=cli.trans_hyp,
            trans_hyp_sclite=cli.trans_hyp_sclite)

    # Reference conversion happens only when an input ref file is given.
    if cli.origin_ref is not None:
        trans_ref(
            origin_ref=cli.origin_ref,
            trans_ref=cli.trans_ref,
            trans_ref_sclite=cli.trans_ref_sclite)
Loading…
Cancel
Save