diff --git a/examples/aishell/asr0/local/test.sh b/examples/aishell/asr0/local/test.sh
index 463593ef..363dbf0a 100755
--- a/examples/aishell/asr0/local/test.sh
+++ b/examples/aishell/asr0/local/test.sh
@@ -5,6 +5,8 @@ if [ $# != 4 ];then
     exit -1
 fi
 
+stage=0
+stop_stage=100
 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 
@@ -19,18 +21,45 @@ if [ $? -ne 0 ]; then
     exit 1
 fi
 
-python3 -u ${BIN_DIR}/test.py \
---ngpu ${ngpu} \
---config ${config_path} \
---decode_cfg ${decode_config_path} \
---result_file ${ckpt_prefix}.rsl \
---checkpoint_path ${ckpt_prefix} \
---model_type ${model_type}
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # format the reference test file
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref data/manifest.test.text
 
-if [ $? -ne 0 ]; then
-    echo "Failed in evaluation!"
-    exit 1
+    python3 -u ${BIN_DIR}/test.py \
+        --ngpu ${ngpu} \
+        --config ${config_path} \
+        --decode_cfg ${decode_config_path} \
+        --result_file ${ckpt_prefix}.rsl \
+        --checkpoint_path ${ckpt_prefix} \
+        --model_type ${model_type}
+
+    if [ $? -ne 0 ]; then
+        echo "Failed in evaluation!"
+        exit 1
+    fi
+
+    # format the hyp file
+    python utils/format_rsl.py \
+        --origin_hyp ${ckpt_prefix}.rsl \
+        --trans_hyp ${ckpt_prefix}.rsl.text
+
+    python utils/compute-wer.py --char=1 --v=1 \
+        data/manifest.test.text ${ckpt_prefix}.rsl.text > ${ckpt_prefix}.error
 fi
 
+if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref_sclite data/manifest.test.text.sclite
+
+    python utils/format_rsl.py \
+        --origin_hyp ${ckpt_prefix}.rsl \
+        --trans_hyp_sclite ${ckpt_prefix}.rsl.text.sclite
+
+    mkdir -p ${ckpt_prefix}_sclite
+    sclite -i wsj -r data/manifest.test.text.sclite -h ${ckpt_prefix}.rsl.text.sclite -e utf-8 -o all -O ${ckpt_prefix}_sclite -c NOASCII
+fi
 
 exit 0
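Note: utils/format_rsl.py (added at the bottom of this patch) produces the two reference/hypothesis layouts consumed above: plain "utt text" lines for utils/compute-wer.py, and trn-style "text (utt.wav)" lines for sclite. A minimal sketch of the difference, with a made-up utterance id:

    # Illustration only; the utterance id and text are hypothetical.
    utt, text = "BAC009S0764W0121", "今天天气很好"
    plain_line = utt + " " + text               # layout of data/manifest.test.text
    sclite_line = text + "(" + utt + ".wav)"    # layout of data/manifest.test.text.sclite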
diff --git a/examples/aishell/asr1/local/test.sh b/examples/aishell/asr1/local/test.sh
index 65b884e5..a88feeed 100755
--- a/examples/aishell/asr1/local/test.sh
+++ b/examples/aishell/asr1/local/test.sh
@@ -5,6 +5,8 @@ if [ $# != 3 ];then
     exit -1
 fi
 
+stage=0
+stop_stage=100
 ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 echo "using $ngpu gpus..."
 
@@ -24,49 +26,86 @@ fi
 #fi
 
-for type in attention ctc_greedy_search; do
-    echo "decoding ${type}"
-    if [ ${chunk_mode} == true ];then
-        # stream decoding only support batchsize=1
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # format the reference test file
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref data/manifest.test.text
+
+    for type in attention ctc_greedy_search; do
+        echo "decoding ${type}"
+        if [ ${chunk_mode} == true ];then
+            # stream decoding only supports batchsize=1
+            batch_size=1
+        else
+            batch_size=64
+        fi
+        output_dir=${ckpt_prefix}
+        mkdir -p ${output_dir}
+        python3 -u ${BIN_DIR}/test.py \
+            --ngpu ${ngpu} \
+            --config ${config_path} \
+            --decode_cfg ${decode_config_path} \
+            --result_file ${output_dir}/${type}.rsl \
+            --checkpoint_path ${ckpt_prefix} \
+            --opts decode.decoding_method ${type} \
+            --opts decode.decode_batch_size ${batch_size}
+
+        if [ $? -ne 0 ]; then
+            echo "Failed in evaluation!"
+            exit 1
+        fi
+
+        # format the hyp file
+        python utils/format_rsl.py \
+            --origin_hyp ${output_dir}/${type}.rsl \
+            --trans_hyp ${output_dir}/${type}.rsl.text
+        python utils/compute-wer.py --char=1 --v=1 \
+            data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
+    done
+
+    for type in ctc_prefix_beam_search attention_rescoring; do
+        echo "decoding ${type}"
         batch_size=1
-    else
-        batch_size=64
-    fi
-    output_dir=${ckpt_prefix}
-    mkdir -p ${output_dir}
-    python3 -u ${BIN_DIR}/test.py \
-    --ngpu ${ngpu} \
-    --config ${config_path} \
-    --decode_cfg ${decode_config_path} \
-    --result_file ${output_dir}/${type}.rsl \
-    --checkpoint_path ${ckpt_prefix} \
-    --opts decode.decoding_method ${type} \
-    --opts decode.decode_batch_size ${batch_size}
-
-    if [ $? -ne 0 ]; then
-        echo "Failed in evaluation!"
-        exit 1
-    fi
-done
-
-for type in ctc_prefix_beam_search attention_rescoring; do
-    echo "decoding ${type}"
-    batch_size=1
+        output_dir=${ckpt_prefix}
+        mkdir -p ${output_dir}
+        python3 -u ${BIN_DIR}/test.py \
+            --ngpu ${ngpu} \
+            --config ${config_path} \
+            --decode_cfg ${decode_config_path} \
+            --result_file ${output_dir}/${type}.rsl \
+            --checkpoint_path ${ckpt_prefix} \
+            --opts decode.decoding_method ${type} \
+            --opts decode.decode_batch_size ${batch_size}
+
+        if [ $? -ne 0 ]; then
+            echo "Failed in evaluation!"
+            exit 1
+        fi
+        python utils/format_rsl.py \
+            --origin_hyp ${output_dir}/${type}.rsl \
+            --trans_hyp ${output_dir}/${type}.rsl.text
+        python utils/compute-wer.py --char=1 --v=1 \
+            data/manifest.test.text ${output_dir}/${type}.rsl.text > ${output_dir}/${type}.error
+    done
+fi
+
+if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then
+    # format the reference test file for sclite
+    python utils/format_rsl.py \
+        --origin_ref data/manifest.test.raw \
+        --trans_ref_sclite data/manifest.test.text.sclite
+
     output_dir=${ckpt_prefix}
-    mkdir -p ${output_dir}
-    python3 -u ${BIN_DIR}/test.py \
-    --ngpu ${ngpu} \
-    --config ${config_path} \
-    --decode_cfg ${decode_config_path} \
-    --result_file ${output_dir}/${type}.rsl \
-    --checkpoint_path ${ckpt_prefix} \
-    --opts decode.decoding_method ${type} \
-    --opts decode.decode_batch_size ${batch_size}
-
-    if [ $? -ne 0 ]; then
-        echo "Failed in evaluation!"
-        exit 1
-    fi
-done
+    for type in attention ctc_greedy_search ctc_prefix_beam_search attention_rescoring; do
+        python utils/format_rsl.py \
+            --origin_hyp ${output_dir}/${type}.rsl \
+            --trans_hyp_sclite ${output_dir}/${type}.rsl.text.sclite
+
+        mkdir -p ${output_dir}/${type}_sclite
+        sclite -i wsj -r data/manifest.test.text.sclite -h ${output_dir}/${type}.rsl.text.sclite -e utf-8 -o all -O ${output_dir}/${type}_sclite -c NOASCII
+    done
+fi
 
 exit 0
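Note: both test scripts hard-code stage=0 and stop_stage=100 at the top and do not source parse_options.sh, while the sclite block is gated on stage 101, so sclite scoring is skipped by default. Running it means editing the two variables directly, e.g.:

    stage=101
    stop_stage=101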
diff --git a/examples/aishell/asr1/run.sh b/examples/aishell/asr1/run.sh
index c54dae9c..cb781b20 100644
--- a/examples/aishell/asr1/run.sh
+++ b/examples/aishell/asr1/run.sh
@@ -7,7 +7,7 @@ stage=0
 stop_stage=50
 conf_path=conf/conformer.yaml
 decode_conf_path=conf/tuning/decode.yaml
-avg_num=20
+avg_num=30
 audio_file=data/demo_01_03.wav
 
 source ${MAIN_ROOT}/utils/parse_options.sh || exit 1;
diff --git a/paddlespeech/s2t/exps/deepspeech2/model.py b/paddlespeech/s2t/exps/deepspeech2/model.py
index 3e9ede76..3c2eaab7 100644
--- a/paddlespeech/s2t/exps/deepspeech2/model.py
+++ b/paddlespeech/s2t/exps/deepspeech2/model.py
@@ -278,7 +278,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
             len_refs += len_ref
             num_ins += 1
             if fout:
-                fout.write({"utt": utt, "ref": target, "hyp": result})
+                fout.write({"utt": utt, "refs": [target], "hyps": [result]})
             logger.info(f"Utt: {utt}")
             logger.info(f"Ref: {target}")
             logger.info(f"Hyp: {result}")
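Note: the deepspeech2 change turns the scalar ref/hyp fields of each result line into one-element lists, which is the shape format_rsl.py reads (item["hyps"][0]; references come from the raw manifest's "text" field, not from the .rsl file). A hypothetical .rsl record before and after:

    # Illustration only; utterance id and text are made up.
    old_record = {"utt": "BAC009S0764W0121", "ref": "今天天气", "hyp": "今天天汽"}
    new_record = {"utt": "BAC009S0764W0121", "refs": ["今天天气"], "hyps": ["今天天汽"]}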
diff --git a/utils/compute-wer.py b/utils/compute-wer.py
index 2d7cc8e1..b3dbf225 100755
--- a/utils/compute-wer.py
+++ b/utils/compute-wer.py
@@ -1,67 +1,61 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 # CopyRight WeNet Apache-2.0 License
+
+import re, sys, unicodedata
 import codecs
-import sys
-import unicodedata
 
 remove_tag = True
-spacelist = [' ', '\t', '\r', '\n']
-puncts = [
-    '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
-    '《', '》'
-]
-
-
-def characterize(string):
-    res = []
-    i = 0
-    while i < len(string):
-        char = string[i]
-        if char in puncts:
-            i += 1
-            continue
-        cat1 = unicodedata.category(char)
-        #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
-        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:  # space or not assigned
-            i += 1
-            continue
-        if cat1 == 'Lo':  # letter-other
-            res.append(char)
-            i += 1
-        else:
-            # some input looks like: <unk>xx, we want to separate it to two words.
-            sep = ' '
-            if char == '<':
-                sep = '>'
-            j = i + 1
-            while j < len(string):
-                c = string[j]
-                if ord(c) >= 128 or (c in spacelist) or (c == sep):
-                    break
-                j += 1
-            if j < len(string) and string[j] == '>':
-                j += 1
-            res.append(string[i:j])
-            i = j
-    return res
+spacelist= [' ', '\t', '\r', '\n']
+puncts = ['!', ',', '?',
+          '、', '。', '!', ',', ';', '?',
+          ':', '「', '」', '︰', '『', '』', '《', '》']
+
+def characterize(string) :
+  res = []
+  i = 0
+  while i < len(string):
+    char = string[i]
+    if char in puncts:
+      i += 1
+      continue
+    cat1 = unicodedata.category(char)
+    #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+    if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:  # space or not assigned
+      i += 1
+      continue
+    if cat1 == 'Lo':  # letter-other
+      res.append(char)
+      i += 1
+    else:
+      # some input looks like: <unk>xx, we want to separate it to two words.
+      sep = ' '
+      if char == '<': sep = '>'
+      j = i+1
+      while j < len(string):
+        c = string[j]
+        if ord(c) >= 128 or (c in spacelist) or (c==sep):
+          break
+        j += 1
+      if j < len(string) and string[j] == '>':
+        j += 1
+      res.append(string[i:j])
+      i = j
+  return res
 
 def stripoff_tags(x):
-    if not x:
-        return ''
-    chars = []
-    i = 0
-    T = len(x)
-    while i < T:
-        if x[i] == '<':
-            while i < T and x[i] != '>':
-                i += 1
-            i += 1
-        else:
-            chars.append(x[i])
-            i += 1
-    return ''.join(chars)
+  if not x: return ''
+  chars = []
+  i = 0; T=len(x)
+  while i < T:
+    if x[i] == '<':
+      while i < T and x[i] != '>':
+        i += 1
+      i += 1
+    else:
+      chars.append(x[i])
+      i += 1
+  return ''.join(chars)
 
 
 def normalize(sentence, ignore_words, cs, split=None):
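As a sanity check on the two helpers above (an illustration, not part of the patch): characterize emits one token per CJK character but keeps ASCII runs and <...> tags whole, and stripoff_tags later removes the tags before scoring:

    # Expected behavior per the code above; input is assumed space-separated here.
    characterize("今天 ok <unk>")   # -> ['今', '天', 'ok', '<unk>']
    stripoff_tags("<unk>ok")        # -> 'ok'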
@@ -71,487 +65,436 @@ def normalize(sentence, ignore_words, cs, split=None):
     for token in sentence:
         x = token
         if not cs:
-            x = x.upper()
+          x = x.upper()
         if x in ignore_words:
-            continue
+          continue
         if remove_tag:
-            x = stripoff_tags(x)
+          x = stripoff_tags(x)
         if not x:
-            continue
+          continue
         if split and x in split:
-            new_sentence += split[x]
+          new_sentence += split[x]
         else:
-            new_sentence.append(x)
+          new_sentence.append(x)
     return new_sentence
 
-
-class Calculator:
-    def __init__(self):
-        self.data = {}
-        self.space = []
-        self.cost = {}
-        self.cost['cor'] = 0
-        self.cost['sub'] = 1
-        self.cost['del'] = 1
-        self.cost['ins'] = 1
-
-    def calculate(self, lab, rec):
-        # Initialization
-        lab.insert(0, '')
-        rec.insert(0, '')
-        while len(self.space) < len(lab):
-            self.space.append([])
-        for row in self.space:
-            for element in row:
-                element['dist'] = 0
-                element['error'] = 'non'
-            while len(row) < len(rec):
-                row.append({'dist': 0, 'error': 'non'})
-        for i in range(len(lab)):
-            self.space[i][0]['dist'] = i
-            self.space[i][0]['error'] = 'del'
-        for j in range(len(rec)):
-            self.space[0][j]['dist'] = j
-            self.space[0][j]['error'] = 'ins'
-        self.space[0][0]['error'] = 'non'
-        for token in lab:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {
-                    'all': 0,
-                    'cor': 0,
-                    'sub': 0,
-                    'ins': 0,
-                    'del': 0
-                }
-        for token in rec:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {
-                    'all': 0,
-                    'cor': 0,
-                    'sub': 0,
-                    'ins': 0,
-                    'del': 0
-                }
-        # Computing edit distance
-        for i, lab_token in enumerate(lab):
-            for j, rec_token in enumerate(rec):
-                if i == 0 or j == 0:
-                    continue
-                min_dist = sys.maxsize
-                min_error = 'none'
-                dist = self.space[i - 1][j]['dist'] + self.cost['del']
-                error = 'del'
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
-                error = 'ins'
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                if lab_token == rec_token:
-                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
-                    error = 'cor'
-                else:
-                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
-                    error = 'sub'
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                self.space[i][j]['dist'] = min_dist
-                self.space[i][j]['error'] = min_error
-        # Tracing back
-        result = {
-            'lab': [],
-            'rec': [],
-            'all': 0,
-            'cor': 0,
-            'sub': 0,
-            'ins': 0,
-            'del': 0
-        }
-        i = len(lab) - 1
-        j = len(rec) - 1
-        while True:
-            if self.space[i][j]['error'] == 'cor':  # correct
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
-                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
-                    result['all'] = result['all'] + 1
-                    result['cor'] = result['cor'] + 1
-                result['lab'].insert(0, lab[i])
-                result['rec'].insert(0, rec[j])
-                i = i - 1
-                j = j - 1
-            elif self.space[i][j]['error'] == 'sub':  # substitution
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
-                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
-                    result['all'] = result['all'] + 1
-                    result['sub'] = result['sub'] + 1
-                result['lab'].insert(0, lab[i])
-                result['rec'].insert(0, rec[j])
-                i = i - 1
-                j = j - 1
-            elif self.space[i][j]['error'] == 'del':  # deletion
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
-                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
-                    result['all'] = result['all'] + 1
-                    result['del'] = result['del'] + 1
-                result['lab'].insert(0, lab[i])
-                result['rec'].insert(0, "")
-                i = i - 1
-            elif self.space[i][j]['error'] == 'ins':  # insertion
-                if len(rec[j]) > 0:
-                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
-                    result['ins'] = result['ins'] + 1
-                result['lab'].insert(0, "")
-                result['rec'].insert(0, rec[j])
-                j = j - 1
-            elif self.space[i][j]['error'] == 'non':  # starting point
-                break
-            else:  # shouldn't reach here
-                print(
-                    'this should not happen , i = {i} , j = {j} , error = {error}'.
-                    format(i=i, j=j, error=self.space[i][j]['error']))
-        return result
-
-    def overall(self):
-        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
-        for token in self.data:
-            result['all'] = result['all'] + self.data[token]['all']
-            result['cor'] = result['cor'] + self.data[token]['cor']
-            result['sub'] = result['sub'] + self.data[token]['sub']
-            result['ins'] = result['ins'] + self.data[token]['ins']
-            result['del'] = result['del'] + self.data[token]['del']
-        return result
-
-    def cluster(self, data):
-        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
-        for token in data:
-            if token in self.data:
-                result['all'] = result['all'] + self.data[token]['all']
-                result['cor'] = result['cor'] + self.data[token]['cor']
-                result['sub'] = result['sub'] + self.data[token]['sub']
-                result['ins'] = result['ins'] + self.data[token]['ins']
-                result['del'] = result['del'] + self.data[token]['del']
-        return result
-
-    def keys(self):
-        return list(self.data.keys())
-
+class Calculator :
+  def __init__(self) :
+    self.data = {}
+    self.space = []
+    self.cost = {}
+    self.cost['cor'] = 0
+    self.cost['sub'] = 1
+    self.cost['del'] = 1
+    self.cost['ins'] = 1
+  def calculate(self, lab, rec) :
+    # Initialization
+    lab.insert(0, '')
+    rec.insert(0, '')
+    while len(self.space) < len(lab) :
+      self.space.append([])
+    for row in self.space :
+      for element in row :
+        element['dist'] = 0
+        element['error'] = 'non'
+      while len(row) < len(rec) :
+        row.append({'dist' : 0, 'error' : 'non'})
+    for i in range(len(lab)) :
+      self.space[i][0]['dist'] = i
+      self.space[i][0]['error'] = 'del'
+    for j in range(len(rec)) :
+      self.space[0][j]['dist'] = j
+      self.space[0][j]['error'] = 'ins'
+    self.space[0][0]['error'] = 'non'
+    for token in lab :
+      if token not in self.data and len(token) > 0 :
+        self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
+    for token in rec :
+      if token not in self.data and len(token) > 0 :
+        self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
+    # Computing edit distance
+    for i, lab_token in enumerate(lab) :
+      for j, rec_token in enumerate(rec) :
+        if i == 0 or j == 0 :
+          continue
+        min_dist = sys.maxsize
+        min_error = 'none'
+        dist = self.space[i-1][j]['dist'] + self.cost['del']
+        error = 'del'
+        if dist < min_dist :
+          min_dist = dist
+          min_error = error
+        dist = self.space[i][j-1]['dist'] + self.cost['ins']
+        error = 'ins'
+        if dist < min_dist :
+          min_dist = dist
+          min_error = error
+        if lab_token == rec_token :
+          dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
+          error = 'cor'
+        else :
+          dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
+          error = 'sub'
+        if dist < min_dist :
+          min_dist = dist
+          min_error = error
+        self.space[i][j]['dist'] = min_dist
+        self.space[i][j]['error'] = min_error
+    # Tracing back
+    result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+    i = len(lab) - 1
+    j = len(rec) - 1
+    while True :
+      if self.space[i][j]['error'] == 'cor' : # correct
+        if len(lab[i]) > 0 :
+          self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+          self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
+          result['all'] = result['all'] + 1
+          result['cor'] = result['cor'] + 1
+        result['lab'].insert(0, lab[i])
+        result['rec'].insert(0, rec[j])
+        i = i - 1
+        j = j - 1
+      elif self.space[i][j]['error'] == 'sub' : # substitution
+        if len(lab[i]) > 0 :
+          self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+          self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
+          result['all'] = result['all'] + 1
+          result['sub'] = result['sub'] + 1
+        result['lab'].insert(0, lab[i])
+        result['rec'].insert(0, rec[j])
+        i = i - 1
+        j = j - 1
+      elif self.space[i][j]['error'] == 'del' : # deletion
+        if len(lab[i]) > 0 :
+          self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+          self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
+          result['all'] = result['all'] + 1
+          result['del'] = result['del'] + 1
+        result['lab'].insert(0, lab[i])
+        result['rec'].insert(0, "")
+        i = i - 1
+      elif self.space[i][j]['error'] == 'ins' : # insertion
+        if len(rec[j]) > 0 :
+          self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
+          result['ins'] = result['ins'] + 1
+        result['lab'].insert(0, "")
+        result['rec'].insert(0, rec[j])
+        j = j - 1
+      elif self.space[i][j]['error'] == 'non' : # starting point
+        break
+      else : # shouldn't reach here
+        print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
+    return result
+  def overall(self) :
+    result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+    for token in self.data :
+      result['all'] = result['all'] + self.data[token]['all']
+      result['cor'] = result['cor'] + self.data[token]['cor']
+      result['sub'] = result['sub'] + self.data[token]['sub']
+      result['ins'] = result['ins'] + self.data[token]['ins']
+      result['del'] = result['del'] + self.data[token]['del']
+    return result
+  def cluster(self, data) :
+    result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
+    for token in data :
+      if token in self.data :
+        result['all'] = result['all'] + self.data[token]['all']
+        result['cor'] = result['cor'] + self.data[token]['cor']
+        result['sub'] = result['sub'] + self.data[token]['sub']
+        result['ins'] = result['ins'] + self.data[token]['ins']
+        result['del'] = result['del'] + self.data[token]['del']
+    return result
+  def keys(self) :
+    return list(self.data.keys())
 
 def width(string):
-    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
+  return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
 
+def default_cluster(word) :
+  unicode_names = [ unicodedata.name(char) for char in word ]
+  for i in reversed(range(len(unicode_names))) :
+    if unicode_names[i].startswith('DIGIT') :  # 1
+      unicode_names[i] = 'Number'  # 'DIGIT'
+    elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
+          unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
+      # 明 / 郎
+      unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
+    elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
+          unicode_names[i].startswith('LATIN SMALL LETTER')) :
+      # A / a
+      unicode_names[i] = 'English'  # 'LATIN LETTER'
+    elif unicode_names[i].startswith('HIRAGANA LETTER') :  # は こ め
+      unicode_names[i] = 'Japanese'  # 'GANA LETTER'
+    elif (unicode_names[i].startswith('AMPERSAND') or
+          unicode_names[i].startswith('APOSTROPHE') or
+          unicode_names[i].startswith('COMMERCIAL AT') or
+          unicode_names[i].startswith('DEGREE CELSIUS') or
+          unicode_names[i].startswith('EQUALS SIGN') or
+          unicode_names[i].startswith('FULL STOP') or
+          unicode_names[i].startswith('HYPHEN-MINUS') or
+          unicode_names[i].startswith('LOW LINE') or
+          unicode_names[i].startswith('NUMBER SIGN') or
+          unicode_names[i].startswith('PLUS SIGN') or
+          unicode_names[i].startswith('SEMICOLON')) :
+      # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
+      del unicode_names[i]
+    else :
+      return 'Other'
+  if len(unicode_names) == 0 :
+    return 'Other'
+  if len(unicode_names) == 1 :
+    return unicode_names[0]
+  for i in range(len(unicode_names)-1) :
+    if unicode_names[i] != unicode_names[i+1] :
+      return 'Other'
+  return unicode_names[0]
 
-def default_cluster(word):
-    unicode_names = [unicodedata.name(char) for char in word]
-    for i in reversed(range(len(unicode_names))):
-        if unicode_names[i].startswith('DIGIT'):  # 1
-            unicode_names[i] = 'Number'  # 'DIGIT'
-        elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
-              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
-            # 明 / 郎
-            unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
-        elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
-              unicode_names[i].startswith('LATIN SMALL LETTER')):
-            # A / a
-            unicode_names[i] = 'English'  # 'LATIN LETTER'
-        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
-            unicode_names[i] = 'Japanese'  # 'GANA LETTER'
-        elif (unicode_names[i].startswith('AMPERSAND') or
-              unicode_names[i].startswith('APOSTROPHE') or
-              unicode_names[i].startswith('COMMERCIAL AT') or
-              unicode_names[i].startswith('DEGREE CELSIUS') or
-              unicode_names[i].startswith('EQUALS SIGN') or
-              unicode_names[i].startswith('FULL STOP') or
-              unicode_names[i].startswith('HYPHEN-MINUS') or
-              unicode_names[i].startswith('LOW LINE') or
-              unicode_names[i].startswith('NUMBER SIGN') or
-              unicode_names[i].startswith('PLUS SIGN') or
-              unicode_names[i].startswith('SEMICOLON')):
-            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
-            del unicode_names[i]
-        else:
-            return 'Other'
-    if len(unicode_names) == 0:
-        return 'Other'
-    if len(unicode_names) == 1:
-        return unicode_names[0]
-    for i in range(len(unicode_names) - 1):
-        if unicode_names[i] != unicode_names[i + 1]:
-            return 'Other'
-    return unicode_names[0]
-
-
-def usage():
-    print(
-        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
-    )
-    print(
-        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
-    )
-
+def usage() :
+  print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
+  print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
 
 if __name__ == '__main__':
-    if len(sys.argv) == 1:
-        usage()
-        sys.exit(0)
-    calculator = Calculator()
-    cluster_file = ''
-    ignore_words = set()
-    tochar = False
-    verbose = 1
-    padding_symbol = ' '
-    case_sensitive = False
-    max_words_per_line = sys.maxsize
-    split = None
-    while len(sys.argv) > 3:
-        a = '--maxw='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):]
-            del sys.argv[1]
-            max_words_per_line = int(b)
-            continue
-        a = '--rt='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            remove_tag = (b == 'true') or (b != '0')
-            continue
-        a = '--cs='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            case_sensitive = (b == 'true') or (b != '0')
-            continue
-        a = '--cluster='
-        if sys.argv[1].startswith(a):
-            cluster_file = sys.argv[1][len(a):]
-            del sys.argv[1]
-            continue
-        a = '--splitfile='
-        if sys.argv[1].startswith(a):
-            split_file = sys.argv[1][len(a):]
-            del sys.argv[1]
-            split = dict()
-            with codecs.open(split_file, 'r', 'utf-8') as fh:
-                for line in fh:  # line in unicode
-                    words = line.strip().split()
-                    if len(words) >= 2:
-                        split[words[0]] = words[1:]
-            continue
-        a = '--ig='
-        if sys.argv[1].startswith(a):
-            ignore_file = sys.argv[1][len(a):]
-            del sys.argv[1]
-            with codecs.open(ignore_file, 'r', 'utf-8') as fh:
-                for line in fh:  # line in unicode
-                    line = line.strip()
-                    if len(line) > 0:
-                        ignore_words.add(line)
-            continue
-        a = '--char='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            tochar = (b == 'true') or (b != '0')
-            continue
-        a = '--v='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            verbose = 0
-            try:
-                verbose = int(b)
-            except Exception as e:
-                if b == 'true' or b != '0':
-                    verbose = 1
-            continue
-        a = '--padding-symbol='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            if b == 'space':
-                padding_symbol = ' '
-            elif b == 'underline':
-                padding_symbol = '_'
-            continue
-        if True or sys.argv[1].startswith('-'):
-            #ignore invalid switch
-            del sys.argv[1]
-            continue
+  if len(sys.argv) == 1 :
+    usage()
+    sys.exit(0)
+  calculator = Calculator()
+  cluster_file = ''
+  ignore_words = set()
+  tochar = False
+  verbose= 1
+  padding_symbol= ' '
+  case_sensitive = False
+  max_words_per_line = sys.maxsize
+  split = None
+  while len(sys.argv) > 3:
+    a = '--maxw='
+    if sys.argv[1].startswith(a):
+      b = sys.argv[1][len(a):]
+      del sys.argv[1]
+      max_words_per_line = int(b)
+      continue
+    a = '--rt='
+    if sys.argv[1].startswith(a):
+      b = sys.argv[1][len(a):].lower()
+      del sys.argv[1]
+      remove_tag = (b == 'true') or (b != '0')
+      continue
+    a = '--cs='
+    if sys.argv[1].startswith(a):
+      b = sys.argv[1][len(a):].lower()
+      del sys.argv[1]
+      case_sensitive = (b == 'true') or (b != '0')
+      continue
+    a = '--cluster='
+    if sys.argv[1].startswith(a):
+      cluster_file = sys.argv[1][len(a):]
+      del sys.argv[1]
+      continue
+    a = '--splitfile='
+    if sys.argv[1].startswith(a):
+      split_file = sys.argv[1][len(a):]
+      del sys.argv[1]
+      split = dict()
+      with codecs.open(split_file, 'r', 'utf-8') as fh:
+        for line in fh:  # line in unicode
+          words = line.strip().split()
+          if len(words) >= 2:
+            split[words[0]] = words[1:]
+      continue
+    a = '--ig='
+    if sys.argv[1].startswith(a):
+      ignore_file = sys.argv[1][len(a):]
+      del sys.argv[1]
+      with codecs.open(ignore_file, 'r', 'utf-8') as fh:
+        for line in fh:  # line in unicode
+          line = line.strip()
+          if len(line) > 0:
+            ignore_words.add(line)
+      continue
+    a = '--char='
+    if sys.argv[1].startswith(a):
+      b = sys.argv[1][len(a):].lower()
+      del sys.argv[1]
+      tochar = (b == 'true') or (b != '0')
+      continue
+    a = '--v='
+    if sys.argv[1].startswith(a):
+      b = sys.argv[1][len(a):].lower()
+      del sys.argv[1]
+      verbose=0
+      try:
+        verbose=int(b)
+      except:
+        if b == 'true' or b != '0':
+          verbose = 1
+      continue
+    a = '--padding-symbol='
+    if sys.argv[1].startswith(a):
+      b = sys.argv[1][len(a):].lower()
+      del sys.argv[1]
+      if b == 'space':
+        padding_symbol= ' '
+      elif b == 'underline':
+        padding_symbol= '_'
+      continue
+    if True or sys.argv[1].startswith('-'):
+      #ignore invalid switch
+      del sys.argv[1]
+      continue
 
-    if not case_sensitive:
-        ig = set([w.upper() for w in ignore_words])
-        ignore_words = ig
+  if not case_sensitive:
+    ig=set([w.upper() for w in ignore_words])
+    ignore_words = ig
 
-    default_clusters = {}
-    default_words = {}
+  default_clusters = {}
+  default_words = {}
 
-    ref_file = sys.argv[1]
-    hyp_file = sys.argv[2]
-    rec_set = {}
-    if split and not case_sensitive:
-        newsplit = dict()
-        for w in split:
-            words = split[w]
-            for i in range(len(words)):
-                words[i] = words[i].upper()
-            newsplit[w.upper()] = words
-        split = newsplit
+  ref_file = sys.argv[1]
+  hyp_file = sys.argv[2]
+  rec_set = {}
+  if split and not case_sensitive:
+    newsplit = dict()
+    for w in split:
+      words = split[w]
+      for i in range(len(words)):
+        words[i] = words[i].upper()
+      newsplit[w.upper()] = words
+    split = newsplit
 
-    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
-        for line in fh:
-            if tochar:
-                array = characterize(line)
-            else:
-                array = line.strip().split()
-            if len(array) == 0:
-                continue
-            fid = array[0]
-            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
-                                     split)
-
-    # compute error rate on the interaction of reference file and hyp file
-    for line in open(ref_file, 'r', encoding='utf-8'):
-        if tochar:
-            array = characterize(line)
-        else:
-            array = line.rstrip('\n').split()
-        if len(array) == 0:
-            continue
-        fid = array[0]
-        if fid not in rec_set:
-            continue
-        lab = normalize(array[1:], ignore_words, case_sensitive, split)
-        rec = rec_set[fid]
-        if verbose:
-            print('\nutt: %s' % fid)
-
-        for word in rec + lab:
-            if word not in default_words:
-                default_cluster_name = default_cluster(word)
-                if default_cluster_name not in default_clusters:
-                    default_clusters[default_cluster_name] = {}
-                if word not in default_clusters[default_cluster_name]:
-                    default_clusters[default_cluster_name][word] = 1
-                default_words[word] = default_cluster_name
-
-        result = calculator.calculate(lab, rec)
-        if verbose:
-            if result['all'] != 0:
-                wer = float(result['ins'] + result['sub'] + result[
-                    'del']) * 100.0 / result['all']
-            else:
-                wer = 0.0
-            print('WER: %4.2f %%' % wer, end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'],
-                   result['ins']))
-        space = {}
-        space['lab'] = []
-        space['rec'] = []
-        for idx in range(len(result['lab'])):
-            len_lab = width(result['lab'][idx])
-            len_rec = width(result['rec'][idx])
-            length = max(len_lab, len_rec)
-            space['lab'].append(length - len_lab)
-            space['rec'].append(length - len_rec)
-        upper_lab = len(result['lab'])
-        upper_rec = len(result['rec'])
-        lab1, rec1 = 0, 0
-        while lab1 < upper_lab or rec1 < upper_rec:
-            if verbose > 1:
-                print('lab(%s):' % fid.encode('utf-8'), end=' ')
-            else:
-                print('lab:', end=' ')
-            lab2 = min(upper_lab, lab1 + max_words_per_line)
-            for idx in range(lab1, lab2):
-                token = result['lab'][idx]
-                print('{token}'.format(token=token), end='')
-                for n in range(space['lab'][idx]):
-                    print(padding_symbol, end='')
-                print(' ', end='')
-            print()
-            if verbose > 1:
-                print('rec(%s):' % fid.encode('utf-8'), end=' ')
-            else:
-                print('rec:', end=' ')
-            rec2 = min(upper_rec, rec1 + max_words_per_line)
-            for idx in range(rec1, rec2):
-                token = result['rec'][idx]
-                print('{token}'.format(token=token), end='')
-                for n in range(space['rec'][idx]):
-                    print(padding_symbol, end='')
-                print(' ', end='')
-            print('\n', end='\n')
-            lab1 = lab2
-            rec1 = rec2
-
-        if verbose:
-            print(
-                '==========================================================================='
-            )
-            print()
-
-    result = calculator.overall()
-    if result['all'] != 0:
-        wer = float(result['ins'] + result['sub'] + result[
-            'del']) * 100.0 / result['all']
-    else:
-        wer = 0.0
-    print('Overall -> %4.2f %%' % wer, end=' ')
-    print('N=%d C=%d S=%d D=%d I=%d' %
-          (result['all'], result['cor'], result['sub'], result['del'],
-           result['ins']))
-    if not verbose:
-        print()
-
-    if verbose:
-        for cluster_id in default_clusters:
-            result = calculator.cluster(
-                [k for k in default_clusters[cluster_id]])
-            if result['all'] != 0:
-                wer = float(result['ins'] + result['sub'] + result[
-                    'del']) * 100.0 / result['all']
-            else:
-                wer = 0.0
-            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'],
-                   result['ins']))
-        if len(cluster_file) > 0:  # compute separated WERs for word clusters
-            cluster_id = ''
-            cluster = []
-            for line in open(cluster_file, 'r', encoding='utf-8'):
-                for token in line.decode('utf-8').rstrip('\n').split():
-                    # end of cluster reached, like </Keyword>
-                    if token[0:2] == '</' and \
-                        token.lstrip('</') == cluster_id :
-                        result = calculator.cluster(cluster)
-                        if result['all'] != 0:
-                            wer = float(result['ins'] + result['sub'] + result[
-                                'del']) * 100.0 / result['all']
-                        else:
-                            wer = 0.0
-                        print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
-                        print('N=%d C=%d S=%d D=%d I=%d' %
-                              (result['all'], result['cor'], result['sub'],
-                               result['del'], result['ins']))
-                        cluster_id = ''
-                        cluster = []
-                    # begin of cluster reached, like <Keyword>
-                    elif token[0] == '<' and token[len(token) - 1] == '>' and \
-                        cluster_id == '' :
-                        cluster_id = token.lstrip('<').rstrip('>')
-                        cluster = []
-                    # general terms, like WEATHER / CAR / ...
-                    else:
-                        cluster.append(token)
-        print()
-        print(
-            '==========================================================================='
-        )
+  with codecs.open(hyp_file, 'r', 'utf-8') as fh:
+    for line in fh:
+      if tochar:
+        array = characterize(line)
+      else:
+        array = line.strip().split()
+      if len(array)==0: continue
+      fid = array[0]
+      rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+  # compute error rate on the intersection of reference file and hyp file
+  for line in open(ref_file, 'r', encoding='utf-8') :
+    if tochar:
+      array = characterize(line)
+    else:
+      array = line.rstrip('\n').split()
+    if len(array)==0: continue
+    fid = array[0]
+    if fid not in rec_set:
+      continue
+    lab = normalize(array[1:], ignore_words, case_sensitive, split)
+    rec = rec_set[fid]
+    if verbose:
+      print('\nutt: %s' % fid)
+
+    for word in rec + lab :
+      if word not in default_words :
+        default_cluster_name = default_cluster(word)
+        if default_cluster_name not in default_clusters :
+          default_clusters[default_cluster_name] = {}
+        if word not in default_clusters[default_cluster_name] :
+          default_clusters[default_cluster_name][word] = 1
+        default_words[word] = default_cluster_name
+
+    result = calculator.calculate(lab, rec)
+    if verbose:
+      if result['all'] != 0 :
+        wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+      else :
+        wer = 0.0
+      print('WER: %4.2f %%' % wer, end = ' ')
+      print('N=%d C=%d S=%d D=%d I=%d' %
+            (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+    space = {}
+    space['lab'] = []
+    space['rec'] = []
+    for idx in range(len(result['lab'])) :
+      len_lab = width(result['lab'][idx])
+      len_rec = width(result['rec'][idx])
+      length = max(len_lab, len_rec)
+      space['lab'].append(length-len_lab)
+      space['rec'].append(length-len_rec)
+    upper_lab = len(result['lab'])
+    upper_rec = len(result['rec'])
+    lab1, rec1 = 0, 0
+    while lab1 < upper_lab or rec1 < upper_rec:
+      if verbose > 1:
+        print('lab(%s):' % fid.encode('utf-8'), end = ' ')
+      else:
+        print('lab:', end = ' ')
+      lab2 = min(upper_lab, lab1 + max_words_per_line)
+      for idx in range(lab1, lab2):
+        token = result['lab'][idx]
+        print('{token}'.format(token = token), end = '')
+        for n in range(space['lab'][idx]) :
+          print(padding_symbol, end = '')
+        print(' ',end='')
+      print()
+      if verbose > 1:
+        print('rec(%s):' % fid.encode('utf-8'), end = ' ')
+      else:
+        print('rec:', end = ' ')
+      rec2 = min(upper_rec, rec1 + max_words_per_line)
+      for idx in range(rec1, rec2):
+        token = result['rec'][idx]
+        print('{token}'.format(token = token), end = '')
+        for n in range(space['rec'][idx]) :
+          print(padding_symbol, end = '')
+        print(' ',end='')
+      print('\n', end='\n')
+      lab1 = lab2
+      rec1 = rec2
+
+    if verbose:
+      print('===========================================================================')
+      print()
+
+  result = calculator.overall()
+  if result['all'] != 0 :
+    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+  else :
+    wer = 0.0
+  print('Overall -> %4.2f %%' % wer, end = ' ')
+  print('N=%d C=%d S=%d D=%d I=%d' %
+        (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+  if not verbose:
+    print()
+
+  if verbose:
+    for cluster_id in default_clusters :
+      result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
+      if result['all'] != 0 :
+        wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+      else :
+        wer = 0.0
+      print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
+      print('N=%d C=%d S=%d D=%d I=%d' %
+            (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+    if len(cluster_file) > 0 :  # compute separated WERs for word clusters
+      cluster_id = ''
+      cluster = []
+      for line in open(cluster_file, 'r', encoding='utf-8') :
+        for token in line.rstrip('\n').split() :
+          # end of cluster reached, like </Keyword>
+          if token[0:2] == '</' and \
+             token.lstrip('</') == cluster_id :
+            result = calculator.cluster(cluster)
+            if result['all'] != 0 :
+              wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+            else :
+              wer = 0.0
+            print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
+            print('N=%d C=%d S=%d D=%d I=%d' %
+                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+            cluster_id = ''
+            cluster = []
+          # begin of cluster reached, like <Keyword>
+          elif token[0] == '<' and token[len(token)-1] == '>' and \
+               cluster_id == '' :
+            cluster_id = token.lstrip('<').rstrip('>')
+            cluster = []
+          # general terms, like WEATHER / CAR / ...
+          else :
+            cluster.append(token)
+    print()
+    print('===========================================================================')
\ No newline at end of file
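A quick way to sanity-check the Calculator class above from a Python shell (an illustration, not part of the patch; the hyphenated filename needs an explicit loader): with one deletion against a four-character reference, the backtrace should report N=4, C=3, D=1, i.e. a 25% character error rate:

    from importlib.machinery import SourceFileLoader
    # Load utils/compute-wer.py as a module named "wer"; __main__ block is skipped.
    wer = SourceFileLoader("wer", "utils/compute-wer.py").load_module()
    calc = wer.Calculator()
    result = calc.calculate(list("今天天气"), list("今天气"))
    print(result['all'], result['cor'], result['del'])  # -> 4 3 1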
diff --git a/utils/format_rsl.py b/utils/format_rsl.py
new file mode 100644
index 00000000..d5bc0017
--- /dev/null
+++ b/utils/format_rsl.py
@@ -0,0 +1,90 @@
+import os
+import argparse
+
+import jsonlines
+
+
+def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
+    """
+    Args:
+        origin_hyp: The input jsonline file which contains the model output
+        trans_hyp: The output file for calculating CER/WER
+        trans_hyp_sclite: The output file for calculating CER/WER using sclite
+    """
+    input_dict = {}
+
+    with open(origin_hyp, "r+", encoding="utf8") as f:
+        for item in jsonlines.Reader(f):
+            input_dict[item["utt"]] = item["hyps"][0]
+    if trans_hyp is not None:
+        with open(trans_hyp, "w+", encoding="utf8") as f:
+            for key in input_dict.keys():
+                f.write(key + " " + input_dict[key] + "\n")
+    if trans_hyp_sclite is not None:
+        with open(trans_hyp_sclite, "w+") as f:
+            for key in input_dict.keys():
+                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
+                f.write(line)
+
+
+def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
+    """
+    Args:
+        origin_ref: The input manifest file which contains the reference text
+        trans_ref: The output file for calculating CER/WER
+        trans_ref_sclite: The output file for calculating CER/WER using sclite
+    """
+    input_dict = {}
+
+    with open(origin_ref, "r", encoding="utf8") as f:
+        for item in jsonlines.Reader(f):
+            input_dict[item["utt"]] = item["text"]
+    if trans_ref is not None:
+        with open(trans_ref, "w", encoding="utf8") as f:
+            for key in input_dict.keys():
+                f.write(key + " " + input_dict[key] + "\n")
+
+    if trans_ref_sclite is not None:
+        with open(trans_ref_sclite, "w") as f:
+            for key in input_dict.keys():
+                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
+                f.write(line)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog='format hyp file for computing CER/WER', add_help=True)
+    parser.add_argument(
+        '--origin_hyp', type=str, default=None, help='origin hyp file')
+    parser.add_argument(
+        '--trans_hyp',
+        type=str,
+        default=None,
+        help='hyp file for calculating CER/WER')
+    parser.add_argument(
+        '--trans_hyp_sclite',
+        type=str,
+        default=None,
+        help='hyp file for calculating CER/WER by sclite')
+
+    parser.add_argument(
+        '--origin_ref', type=str, default=None, help='origin ref file')
+    parser.add_argument(
+        '--trans_ref',
+        type=str,
+        default=None,
+        help='ref file for calculating CER/WER')
+    parser.add_argument(
+        '--trans_ref_sclite',
+        type=str,
+        default=None,
+        help='ref file for calculating CER/WER by sclite')
+    parser_args = parser.parse_args()
+
+    if parser_args.origin_hyp is not None:
+        trans_hyp(
+            origin_hyp=parser_args.origin_hyp,
+            trans_hyp=parser_args.trans_hyp,
+            trans_hyp_sclite=parser_args.trans_hyp_sclite, )
+
+    if parser_args.origin_ref is not None:
+        trans_ref(
+            origin_ref=parser_args.origin_ref,
+            trans_ref=parser_args.trans_ref,
+            trans_ref_sclite=parser_args.trans_ref_sclite, )
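For reference, an end-to-end sketch of the new utility (file names and contents are made up; format_rsl.py reads "utt"/"hyps" from the jsonline result file and "utt"/"text" from the raw manifest):

    import jsonlines

    # Write a one-line fake result file, then convert it both ways.
    with jsonlines.open("demo.rsl", mode="w") as writer:
        writer.write({"utt": "BAC009S0764W0121", "refs": ["今天天气"], "hyps": ["今天天汽"]})

    # python utils/format_rsl.py --origin_hyp demo.rsl \
    #     --trans_hyp demo.rsl.text --trans_hyp_sclite demo.rsl.text.sclite
    # demo.rsl.text:        BAC009S0764W0121 今天天汽
    # demo.rsl.text.sclite: 今天天汽(BAC009S0764W0121.wav)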