diff --git a/examples/aishell/asr1/path.sh b/examples/aishell/asr1/path.sh
index c6eed668e..449829109 100644
--- a/examples/aishell/asr1/path.sh
+++ b/examples/aishell/asr1/path.sh
@@ -27,5 +27,3 @@ export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
 export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
 [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!"
 [ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh
-
-unset GREP_OPTIONS
diff --git a/paddlespeech/dataset/s2t/__init__.py b/paddlespeech/dataset/s2t/__init__.py
index 3f546855e..27ea9e778 100644
--- a/paddlespeech/dataset/s2t/__init__.py
+++ b/paddlespeech/dataset/s2t/__init__.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # s2t utils binaries.
+from .avg_model import main as avg_ckpts_main
 from .build_vocab import main as build_vocab_main
 from .compute_mean_std import main as compute_mean_std_main
+from .compute_wer import main as compute_wer_main
 from .format_data import main as format_data_main
+from .format_rsl import main as format_rsl_main
diff --git a/paddlespeech/dataset/s2t/avg_model.py b/paddlespeech/dataset/s2t/avg_model.py
new file mode 100755
index 000000000..99111ccc1
--- /dev/null
+++ b/paddlespeech/dataset/s2t/avg_model.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import json
+import os
+
+import numpy as np
+import paddle
+
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='path of the output averaged model')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best', action="store_true", help='select checkpoints with the best val scores')
+    parser.add_argument(
+        '--num', default=5, type=int, help='number of checkpoints to average')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+
+    args = parser.parse_args()
+    return args
+
+
+def average_checkpoints(dst_model="",
+                        ckpt_dir="",
+                        val_best=True,
+                        num=5,
+                        min_epoch=0,
+                        max_epoch=65536):
+    paddle.set_device('cpu')
+
+    val_scores = []
+    best_val_scores = None
+    selected_epochs = None
+
+    jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
+    jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
+    for y in jsons:
+        with open(y, 'r') as f:
+            dic_json = json.load(f)
+        loss = dic_json['val_loss']
+        epoch = dic_json['epoch']
+        if epoch >= min_epoch and epoch <= max_epoch:
+            val_scores.append((epoch, loss))
+    val_scores = np.array(val_scores)
+
+    if val_best:
+        sort_idx = np.argsort(val_scores[:, 1])
+        sorted_val_scores = val_scores[sort_idx]
+    else:
+        sorted_val_scores = val_scores
+
+    best_val_scores = sorted_val_scores[:num, 1]
+    selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
+    avg_val_score = np.mean(best_val_scores)
+    print("selected val scores = " + str(best_val_scores))
+    print("selected epochs = " + str(selected_epochs))
+    print("averaged val score = " + str(avg_val_score))
+
+    path_list = [
+        ckpt_dir + '/{}.pdparams'.format(int(epoch))
+        for epoch in sorted_val_scores[:num, 0]
+    ]
+    print(path_list)
+
+    avg = None
+    # one weight file per selected epoch
+    assert num == len(path_list)
+    for path in path_list:
+        print(f'Processing {path}')
+        states = paddle.load(path)
+        if avg is None:
+            avg = states
+        else:
+            for k in avg.keys():
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            avg[k] /= num
+
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')
+
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
+    with open(meta_path, 'w') as f:
+        data = json.dumps({
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
+            "val_loss_mean": avg_val_score,
+            "ckpts": path_list,
+            "epochs": selected_epochs.tolist(),
+            "val_losses": best_val_scores.tolist(),
+        })
+        f.write(data + "\n")
+
+
+def main():
+    args = define_argparse()
+    average_checkpoints(**vars(args))
+
+
+if __name__ == '__main__':
+    main()
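For reference, a minimal usage sketch of the new averaging module. The `exp/` paths are hypothetical, not taken from this patch; the directory only needs `{epoch}.pdparams` weight files with matching `{epoch}.json` metadata carrying `epoch` and `val_loss` keys, which is what `average_checkpoints` globs for.

```python
from paddlespeech.dataset.s2t.avg_model import average_checkpoints

# Hypothetical experiment dir with {epoch}.pdparams + {epoch}.json files.
average_checkpoints(
    dst_model='exp/conformer/checkpoints/avg_5.pdparams',
    ckpt_dir='exp/conformer/checkpoints',
    val_best=True,  # average the 5 checkpoints with the lowest val_loss
    num=5)
```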
diff --git a/paddlespeech/dataset/s2t/compute_wer.py b/paddlespeech/dataset/s2t/compute_wer.py
new file mode 100755
index 000000000..5711c725b
--- /dev/null
+++ b/paddlespeech/dataset/s2t/compute_wer.py
@@ -0,0 +1,559 @@
+# Copyright 2021 Mobvoi Inc. All Rights Reserved.
+# flake8: noqa
+import codecs
+import re
+import sys
+import unicodedata
+
+remove_tag = True
+spacelist = [' ', '\t', '\r', '\n']
+puncts = [
+    '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
+    '《', '》'
+]
+
+
+def characterize(string):
+    res = []
+    i = 0
+    while i < len(string):
+        char = string[i]
+        if char in puncts:
+            i += 1
+            continue
+        cat1 = unicodedata.category(char)
+        #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:  # space or not assigned
+            i += 1
+            continue
+        if cat1 == 'Lo':  # letter-other
+            res.append(char)
+            i += 1
+        else:
+            # some input looks like: , we want to separate it to two words.
+            sep = ' '
+            if char == '<': sep = '>'
+            j = i + 1
+            while j < len(string):
+                c = string[j]
+                if ord(c) >= 128 or (c in spacelist) or (c == sep):
+                    break
+                j += 1
+            if j < len(string) and string[j] == '>':
+                j += 1
+            res.append(string[i:j])
+            i = j
+    return res
+
+
+def stripoff_tags(x):
+    if not x: return ''
+    chars = []
+    i = 0
+    T = len(x)
+    while i < T:
+        if x[i] == '<':
+            while i < T and x[i] != '>':
+                i += 1
+            i += 1
+        else:
+            chars.append(x[i])
+            i += 1
+    return ''.join(chars)
+
+
+def normalize(sentence, ignore_words, cs, split=None):
+    """ sentence, ignore_words are both in unicode
+    """
+    new_sentence = []
+    for token in sentence:
+        x = token
+        if not cs:
+            x = x.upper()
+        if x in ignore_words:
+            continue
+        if remove_tag:
+            x = stripoff_tags(x)
+        if not x:
+            continue
+        if split and x in split:
+            new_sentence += split[x]
+        else:
+            new_sentence.append(x)
+    return new_sentence
+
+
+class Calculator:
+    def __init__(self):
+        self.data = {}
+        self.space = []
+        self.cost = {}
+        self.cost['cor'] = 0
+        self.cost['sub'] = 1
+        self.cost['del'] = 1
+        self.cost['ins'] = 1
+
+    def calculate(self, lab, rec):
+        # Initialization
+        lab.insert(0, '')
+        rec.insert(0, '')
+        while len(self.space) < len(lab):
+            self.space.append([])
+        for row in self.space:
+            for element in row:
+                element['dist'] = 0
+                element['error'] = 'non'
+            while len(row) < len(rec):
+                row.append({'dist': 0, 'error': 'non'})
+        for i in range(len(lab)):
+            self.space[i][0]['dist'] = i
+            self.space[i][0]['error'] = 'del'
+        for j in range(len(rec)):
+            self.space[0][j]['dist'] = j
+            self.space[0][j]['error'] = 'ins'
+        self.space[0][0]['error'] = 'non'
+        for token in lab:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {
+                    'all': 0,
+                    'cor': 0,
+                    'sub': 0,
+                    'ins': 0,
+                    'del': 0
+                }
+        for token in rec:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {
+                    'all': 0,
+                    'cor': 0,
+                    'sub': 0,
+                    'ins': 0,
+                    'del': 0
+                }
+        # Computing edit distance
+        for i, lab_token in enumerate(lab):
+            for j, rec_token in enumerate(rec):
+                if i == 0 or j == 0:
+                    continue
+                min_dist = sys.maxsize
+                min_error = 'none'
+                dist = self.space[i - 1][j]['dist'] + self.cost['del']
+                error = 'del'
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
+                error = 'ins'
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                if lab_token == rec_token:
+                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
+                    error = 'cor'
+                else:
+                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
+                    error = 'sub'
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                self.space[i][j]['dist'] = min_dist
+                self.space[i][j]['error'] = min_error
+        # Tracing back
+        result = {
+            'lab': [],
+            'rec': [],
+            'all': 0,
+            'cor': 0,
+            'sub': 0,
+            'ins': 0,
+            'del': 0
+        }
+        i = len(lab) - 1
+        j = len(rec) - 1
+        while True:
+            if self.space[i][j]['error'] == 'cor':  # correct
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
+                result['all'] = result['all'] + 1
+                result['cor'] = result['cor'] + 1
+                result['lab'].insert(0, lab[i])
+                result['rec'].insert(0, rec[j])
+                i = i - 1
+                j = j - 1
+            elif self.space[i][j]['error'] == 'sub':  # substitution
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
+                result['all'] = result['all'] + 1
+                result['sub'] = result['sub'] + 1
+                result['lab'].insert(0, lab[i])
+                result['rec'].insert(0, rec[j])
+                i = i - 1
+                j = j - 1
+            elif self.space[i][j]['error'] == 'del':  # deletion
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
+                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
+                result['all'] = result['all'] + 1
+                result['del'] = result['del'] + 1
+                result['lab'].insert(0, lab[i])
+                result['rec'].insert(0, "")
+                i = i - 1
+            elif self.space[i][j]['error'] == 'ins':  # insertion
+                if len(rec[j]) > 0:
+                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
+                result['ins'] = result['ins'] + 1
+                result['lab'].insert(0, "")
+                result['rec'].insert(0, rec[j])
+                j = j - 1
+            elif self.space[i][j]['error'] == 'non':  # starting point
+                break
+            else:  # shouldn't reach here
+                print(
+                    'this should not happen , i = {i} , j = {j} , error = {error}'.
+                    format(i=i, j=j, error=self.space[i][j]['error']))
+        return result
+
+    def overall(self):
+        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+        for token in self.data:
+            result['all'] = result['all'] + self.data[token]['all']
+            result['cor'] = result['cor'] + self.data[token]['cor']
+            result['sub'] = result['sub'] + self.data[token]['sub']
+            result['ins'] = result['ins'] + self.data[token]['ins']
+            result['del'] = result['del'] + self.data[token]['del']
+        return result
+
+    def cluster(self, data):
+        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
+        for token in data:
+            if token in self.data:
+                result['all'] = result['all'] + self.data[token]['all']
+                result['cor'] = result['cor'] + self.data[token]['cor']
+                result['sub'] = result['sub'] + self.data[token]['sub']
+                result['ins'] = result['ins'] + self.data[token]['ins']
+                result['del'] = result['del'] + self.data[token]['del']
+        return result
+
+    def keys(self):
+        return list(self.data.keys())
+
+
+def width(string):
+    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
+
+
+def default_cluster(word):
+    unicode_names = [unicodedata.name(char) for char in word]
+    for i in reversed(range(len(unicode_names))):
+        if unicode_names[i].startswith('DIGIT'):  # 1
+            unicode_names[i] = 'Number'  # 'DIGIT'
+        elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
+              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
+            # 明 / 郎
+            unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
+        elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
+              unicode_names[i].startswith('LATIN SMALL LETTER')):
+            # A / a
+            unicode_names[i] = 'English'  # 'LATIN LETTER'
+        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
+            unicode_names[i] = 'Japanese'  # 'GANA LETTER'
+        elif (unicode_names[i].startswith('AMPERSAND') or
+              unicode_names[i].startswith('APOSTROPHE') or
+              unicode_names[i].startswith('COMMERCIAL AT') or
+              unicode_names[i].startswith('DEGREE CELSIUS') or
+              unicode_names[i].startswith('EQUALS SIGN') or
+              unicode_names[i].startswith('FULL STOP') or
+              unicode_names[i].startswith('HYPHEN-MINUS') or
+              unicode_names[i].startswith('LOW LINE') or
+              unicode_names[i].startswith('NUMBER SIGN') or
+              unicode_names[i].startswith('PLUS SIGN') or
+              unicode_names[i].startswith('SEMICOLON')):
+            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
+            del unicode_names[i]
+        else:
+            return 'Other'
+    if len(unicode_names) == 0:
+        return 'Other'
+    if len(unicode_names) == 1:
+        return unicode_names[0]
+    for i in range(len(unicode_names) - 1):
+        if unicode_names[i] != unicode_names[i + 1]:
+            return 'Other'
+    return unicode_names[0]
+
+
+def usage():
+    print(
+        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
+    )
+    print(
+        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
+    )
+
+
+def main():
+    # python utils/compute-wer.py --char=1 --v=1 ref hyp > rsl.error
+    global remove_tag  # let --rt= update the module-level flag read by normalize()
+    if len(sys.argv) == 1:
+        usage()
+        sys.exit(0)
+    calculator = Calculator()
+    cluster_file = ''
+    ignore_words = set()
+    tochar = False
+    verbose = 1
+    padding_symbol = ' '
+    case_sensitive = False
+    max_words_per_line = sys.maxsize
+    split = None
+    while len(sys.argv) > 3:
+        a = '--maxw='
+        if sys.argv[1].startswith(a):
+            b = sys.argv[1][len(a):]
+            del sys.argv[1]
+            max_words_per_line = int(b)
+            continue
+        a = '--rt='
+        if sys.argv[1].startswith(a):
+            b = sys.argv[1][len(a):].lower()
+            del sys.argv[1]
+            remove_tag = (b == 'true') or (b != '0')
+            continue
+        a = '--cs='
+        if sys.argv[1].startswith(a):
+            b = sys.argv[1][len(a):].lower()
+            del sys.argv[1]
+            case_sensitive = (b == 'true') or (b != '0')
+            continue
+        a = '--cluster='
+        if sys.argv[1].startswith(a):
+            cluster_file = sys.argv[1][len(a):]
+            del sys.argv[1]
+            continue
+        a = '--splitfile='
+        if sys.argv[1].startswith(a):
+            split_file = sys.argv[1][len(a):]
+            del sys.argv[1]
+            split = dict()
+            with codecs.open(split_file, 'r', 'utf-8') as fh:
+                for line in fh:  # line in unicode
+                    words = line.strip().split()
+                    if len(words) >= 2:
+                        split[words[0]] = words[1:]
+            continue
+        a = '--ig='
+        if sys.argv[1].startswith(a):
+            ignore_file = sys.argv[1][len(a):]
+            del sys.argv[1]
+            with codecs.open(ignore_file, 'r', 'utf-8') as fh:
+                for line in fh:  # line in unicode
+                    line = line.strip()
+                    if len(line) > 0:
+                        ignore_words.add(line)
+            continue
+        a = '--char='
+        if sys.argv[1].startswith(a):
+            b = sys.argv[1][len(a):].lower()
+            del sys.argv[1]
+            tochar = (b == 'true') or (b != '0')
+            continue
+        a = '--v='
+        if sys.argv[1].startswith(a):
+            b = sys.argv[1][len(a):].lower()
+            del sys.argv[1]
+            verbose = 0
+            try:
+                verbose = int(b)
+            except:
+                if b == 'true' or b != '0':
+                    verbose = 1
+            continue
+        a = '--padding-symbol='
+        if sys.argv[1].startswith(a):
+            b = sys.argv[1][len(a):].lower()
+            del sys.argv[1]
+            if b == 'space':
+                padding_symbol = ' '
+            elif b == 'underline':
+                padding_symbol = '_'
+            continue
+        if True or sys.argv[1].startswith('-'):
+            #ignore invalid switch
+            del sys.argv[1]
+            continue
+
+    if not case_sensitive:
+        ig = set([w.upper() for w in ignore_words])
+        ignore_words = ig
+
+    default_clusters = {}
+    default_words = {}
+
+    ref_file = sys.argv[1]
+    hyp_file = sys.argv[2]
+    rec_set = {}
+    if split and not case_sensitive:
+        newsplit = dict()
+        for w in split:
+            words = split[w]
+            for i in range(len(words)):
+                words[i] = words[i].upper()
+            newsplit[w.upper()] = words
+        split = newsplit
+
+    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
+        for line in fh:
+            if tochar:
+                array = characterize(line)
+            else:
+                array = line.strip().split()
+            if len(array) == 0: continue
+            fid = array[0]
+            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
+                                     split)
+
+    # compute error rate on the intersection of reference file and hyp file
+    for line in open(ref_file, 'r', encoding='utf-8'):
+        if tochar:
+            array = characterize(line)
+        else:
+            array = line.rstrip('\n').split()
+        if len(array) == 0: continue
+        fid = array[0]
+        if fid not in rec_set:
+            continue
+        lab = normalize(array[1:], ignore_words, case_sensitive, split)
+        rec = rec_set[fid]
+        if verbose:
+            print('\nutt: %s' % fid)
+
+        for word in rec + lab:
+            if word not in default_words:
+                default_cluster_name = default_cluster(word)
+                if default_cluster_name not in default_clusters:
+                    default_clusters[default_cluster_name] = {}
+                if word not in default_clusters[default_cluster_name]:
+                    default_clusters[default_cluster_name][word] = 1
+                default_words[word] = default_cluster_name
+
+        result = calculator.calculate(lab, rec)
+        if verbose:
+            if result['all'] != 0:
+                wer = float(result['ins'] + result['sub'] + result[
+                    'del']) * 100.0 / result['all']
+            else:
+                wer = 0.0
+            print('WER: %4.2f %%' % wer, end=' ')
+            print('N=%d C=%d S=%d D=%d I=%d' %
+                  (result['all'], result['cor'], result['sub'], result['del'],
+                   result['ins']))
+            space = {}
+            space['lab'] = []
+            space['rec'] = []
+            for idx in range(len(result['lab'])):
+                len_lab = width(result['lab'][idx])
+                len_rec = width(result['rec'][idx])
+                length = max(len_lab, len_rec)
+                space['lab'].append(length - len_lab)
+                space['rec'].append(length - len_rec)
+            upper_lab = len(result['lab'])
+            upper_rec = len(result['rec'])
+            lab1, rec1 = 0, 0
+            while lab1 < upper_lab or rec1 < upper_rec:
+                if verbose > 1:
+                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
+                else:
+                    print('lab:', end=' ')
+                lab2 = min(upper_lab, lab1 + max_words_per_line)
+                for idx in range(lab1, lab2):
+                    token = result['lab'][idx]
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['lab'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
+                print()
+                if verbose > 1:
+                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
+                else:
+                    print('rec:', end=' ')
+                rec2 = min(upper_rec, rec1 + max_words_per_line)
+                for idx in range(rec1, rec2):
+                    token = result['rec'][idx]
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['rec'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
+                print('\n', end='\n')
+                lab1 = lab2
+                rec1 = rec2
+
+    if verbose:
+        print(
+            '==========================================================================='
+        )
+        print()
+
+    result = calculator.overall()
+    if result['all'] != 0:
+        wer = float(result['ins'] + result['sub'] + result[
+            'del']) * 100.0 / result['all']
+    else:
+        wer = 0.0
+    print('Overall -> %4.2f %%' % wer, end=' ')
+    print('N=%d C=%d S=%d D=%d I=%d' %
+          (result['all'], result['cor'], result['sub'], result['del'],
+           result['ins']))
+    if not verbose:
+        print()
+
+    if verbose:
+        for cluster_id in default_clusters:
+            result = calculator.cluster(
+                [k for k in default_clusters[cluster_id]])
+            if result['all'] != 0:
+                wer = float(result['ins'] + result['sub'] + result[
+                    'del']) * 100.0 / result['all']
+            else:
+                wer = 0.0
+            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
+            print('N=%d C=%d S=%d D=%d I=%d' %
+                  (result['all'], result['cor'], result['sub'], result['del'],
+                   result['ins']))
+        if len(cluster_file) > 0:  # compute separated WERs for word clusters
+            cluster_id = ''
+            cluster = []
+            for line in open(cluster_file, 'r', encoding='utf-8'):
+                for token in line.rstrip('\n').split():
+                    # end of cluster reached, like </Keyword>
+                    if token[0:2] == '</' and \
+                        token.lstrip('</').rstrip('>') == cluster_id :
+                        result = calculator.cluster(cluster)
+                        if result['all'] != 0:
+                            wer = float(result['ins'] + result['sub'] + result[
+                                'del']) * 100.0 / result['all']
+                        else:
+                            wer = 0.0
+                        print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
+                        print('N=%d C=%d S=%d D=%d I=%d' %
+                              (result['all'], result['cor'], result['sub'],
+                               result['del'], result['ins']))
+                        cluster_id = ''
+                        cluster = []
+                    # begin of cluster reached, like <Keyword>
+                    elif token[0] == '<' and token[len(token)-1] == '>' and \
+                        cluster_id == '' :
+                        cluster_id = token.lstrip('<').rstrip('>')
+                        cluster = []
+                    # general terms, like WEATHER / CAR / ...
+                    else:
+                        cluster.append(token)
+        print()
+        print(
+            '==========================================================================='
+        )
+
+
+if __name__ == '__main__':
+    main()
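A quick sanity check of the relocated scorer. The file names and contents below are invented; each line follows the Kaldi-style `utt-id transcript` convention the script expects:

```python
import sys

from paddlespeech.dataset.s2t.compute_wer import main as compute_wer_main

# Hypothetical ref/hyp files, one "utt-id transcript" pair per line.
with open('ref.txt', 'w', encoding='utf8') as f:
    f.write('utt1 今天天气怎么样\n')
with open('hyp.txt', 'w', encoding='utf8') as f:
    f.write('utt1 今天天汽怎么样\n')

# The script still reads sys.argv directly, so fake a CLI invocation:
# --char=1 scores at character level, --v=1 prints aligned lab/rec lines.
sys.argv = ['compute_wer.py', '--char=1', '--v=1', 'ref.txt', 'hyp.txt']
compute_wer_main()  # one substitution out of 7 chars -> WER: 14.29 %
```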
+""" +format ref/hyp file for `utt text` format to compute CER/WER/MER +""" +import argparse + +import jsonlines + + +def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None): + """ + Args: + origin_hyp: The input json file which contains the model output + trans_hyp: The output file for caculate CER/WER + trans_hyp_sclite: The output file for caculate CER/WER using sclite + """ + input_dict = {} + + with open(origin_hyp, "r+", encoding="utf8") as f: + for item in jsonlines.Reader(f): + input_dict[item["utt"]] = item["hyps"][0] + if trans_hyp is not None: + with open(trans_hyp, "w+", encoding="utf8") as f: + for key in input_dict.keys(): + f.write(key + " " + input_dict[key] + "\n") + if trans_hyp_sclite is not None: + with open(trans_hyp_sclite, "w+") as f: + for key in input_dict.keys(): + line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" + f.write(line) + + +def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None): + """ + Args: + origin_hyp: The input json file which contains the model output + trans_hyp: The output file for caculate CER/WER + trans_hyp_sclite: The output file for caculate CER/WER using sclite + """ + input_dict = {} + + with open(origin_ref, "r", encoding="utf8") as f: + for item in jsonlines.Reader(f): + input_dict[item["utt"]] = item["text"] + if trans_ref is not None: + with open(trans_ref, "w", encoding="utf8") as f: + for key in input_dict.keys(): + f.write(key + " " + input_dict[key] + "\n") + + if trans_ref_sclite is not None: + with open(trans_ref_sclite, "w") as f: + for key in input_dict.keys(): + line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" + f.write(line) + + +def define_argparse(): + parser = argparse.ArgumentParser( + prog='format ref/hyp file for compute CER/WER', add_help=True) + parser.add_argument( + '--origin_hyp', type=str, default=None, help='origin hyp file') + parser.add_argument( + '--trans_hyp', + type=str, + default=None, + help='hyp file for caculating CER/WER') + parser.add_argument( + '--trans_hyp_sclite', + type=str, + default=None, + help='hyp file for caculating CER/WER by sclite') + + parser.add_argument( + '--origin_ref', type=str, default=None, help='origin ref file') + parser.add_argument( + '--trans_ref', + type=str, + default=None, + help='ref file for caculating CER/WER') + parser.add_argument( + '--trans_ref_sclite', + type=str, + default=None, + help='ref file for caculating CER/WER by sclite') + parser_args = parser.parse_args() + return parser_args + + +def format_result(origin_hyp=None, + trans_hyp=None, + trans_hyp_sclite=None, + origin_ref=None, + trans_ref=None, + trans_ref_sclite=None): + + if origin_hyp is not None: + trans_hyp( + origin_hyp=origin_hyp, + trans_hyp=trans_hyp, + trans_hyp_sclite=trans_hyp_sclite, ) + + if origin_ref is not None: + trans_ref( + origin_ref=origin_ref, + trans_ref=trans_ref, + trans_ref_sclite=trans_ref_sclite, ) + + +def main(): + args = define_argparse() + format_result(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py index 8acd46dfc..07228e98b 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from 
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py
index 8acd46dfc..07228e98b 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py
index 030168a9a..a8e20ff93 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
index d7a9402b9..1e07aa800 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py
index 66ea29d08..32a583b6a 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py
@@ -27,8 +27,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils import mp_tools
 from paddlespeech.s2t.utils.checkpoint import Checkpoint
 from paddlespeech.s2t.utils.log import Log
-from paddlespeech.s2t.utils.utility import print_arguments
 from paddlespeech.s2t.utils.utility import UpdateConfig
+from paddlespeech.utils.argparse import print_arguments
 
 
 logger = Log(__name__).getlog()
diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py
index 2c9942f9b..1340aaa3b 100644
--- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py
+++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/u2/bin/alignment.py b/paddlespeech/s2t/exps/u2/bin/alignment.py
index e3390feb1..cc2940388 100644
--- a/paddlespeech/s2t/exps/u2/bin/alignment.py
+++ b/paddlespeech/s2t/exps/u2/bin/alignment.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/u2/bin/export.py b/paddlespeech/s2t/exps/u2/bin/export.py
index 592b12379..4725e5e13 100644
--- a/paddlespeech/s2t/exps/u2/bin/export.py
+++ b/paddlespeech/s2t/exps/u2/bin/export.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/u2/bin/test.py b/paddlespeech/s2t/exps/u2/bin/test.py
index b13fd0d3f..43eeff631 100644
--- a/paddlespeech/s2t/exps/u2/bin/test.py
+++ b/paddlespeech/s2t/exps/u2/bin/test.py
@@ -18,7 +18,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2.model import U2Tester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/u2/bin/train.py b/paddlespeech/s2t/exps/u2/bin/train.py
index dc3a87c16..a0f503288 100644
--- a/paddlespeech/s2t/exps/u2/bin/train.py
+++ b/paddlespeech/s2t/exps/u2/bin/train.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 # from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer
diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/test.py b/paddlespeech/s2t/exps/u2_kaldi/bin/test.py
index 422483b97..4137537e9 100644
--- a/paddlespeech/s2t/exps/u2_kaldi/bin/test.py
+++ b/paddlespeech/s2t/exps/u2_kaldi/bin/test.py
@@ -18,7 +18,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 model_test_alias = {
     "u2": "paddlespeech.s2t.exps.u2.model:U2Tester",
diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/train.py b/paddlespeech/s2t/exps/u2_kaldi/bin/train.py
index b11da7154..011aabac4 100644
--- a/paddlespeech/s2t/exps/u2_kaldi/bin/train.py
+++ b/paddlespeech/s2t/exps/u2_kaldi/bin/train.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.training.cli import default_argument_parser
 from paddlespeech.s2t.utils.dynamic_import import dynamic_import
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 model_train_alias = {
     "u2": "paddlespeech.s2t.exps.u2.model:U2Trainer",
diff --git a/paddlespeech/s2t/exps/u2_st/bin/export.py b/paddlespeech/s2t/exps/u2_st/bin/export.py
index c641152fe..a2a7424c1 100644
--- a/paddlespeech/s2t/exps/u2_st/bin/export.py
+++ b/paddlespeech/s2t/exps/u2_st/bin/export.py
@@ -16,7 +16,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/u2_st/bin/test.py b/paddlespeech/s2t/exps/u2_st/bin/test.py
index c07c95bd5..30a903ceb 100644
--- a/paddlespeech/s2t/exps/u2_st/bin/test.py
+++ b/paddlespeech/s2t/exps/u2_st/bin/test.py
@@ -18,7 +18,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/u2_st/bin/train.py b/paddlespeech/s2t/exps/u2_st/bin/train.py
index 574942e5a..b36a0af4d 100644
--- a/paddlespeech/s2t/exps/u2_st/bin/train.py
+++ b/paddlespeech/s2t/exps/u2_st/bin/train.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.u2_st.model import U2STTrainer as Trainer
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test.py b/paddlespeech/s2t/exps/wav2vec2/bin/test.py
index a376651df..c17cee0fd 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/test.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/test.py
@@ -18,7 +18,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
index 29e7ef552..0c37f796a 100644
--- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py
@@ -19,7 +19,7 @@ from yacs.config import CfgNode
 
 from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTrainer as Trainer
 from paddlespeech.s2t.training.cli import default_argument_parser
-from paddlespeech.s2t.utils.utility import print_arguments
+from paddlespeech.utils.argparse import print_arguments
 
 
 def main_sp(config, args):
diff --git a/utils/avg_model.py b/utils/avg_model.py
index 6ee16408d..039ea6267 100755
--- a/utils/avg_model.py
+++ b/utils/avg_model.py
@@ -12,105 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
-import glob
-import json
-import os
-
-import numpy as np
-import paddle
-
-
-def main(args):
-    paddle.set_device('cpu')
-
-    val_scores = []
-    beat_val_scores = None
-    selected_epochs = None
-
-    jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
-    jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
-    for y in jsons:
-        with open(y, 'r') as f:
-            dic_json = json.load(f)
-        loss = dic_json['val_loss']
-        epoch = dic_json['epoch']
-        if epoch >= args.min_epoch and epoch <= args.max_epoch:
-            val_scores.append((epoch, loss))
-    val_scores = np.array(val_scores)
-
-    if args.val_best:
-        sort_idx = np.argsort(val_scores[:, 1])
-        sorted_val_scores = val_scores[sort_idx]
-    else:
-        sorted_val_scores = val_scores
-
-    beat_val_scores = sorted_val_scores[:args.num, 1]
-    selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
-    avg_val_score = np.mean(beat_val_scores)
-    print("selected val scores = " + str(beat_val_scores))
-    print("selected epochs = " + str(selected_epochs))
-    print("averaged val score = " + str(avg_val_score))
-
-    path_list = [
-        args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
-        for epoch in sorted_val_scores[:args.num, 0]
-    ]
-    print(path_list)
-
-    avg = None
-    num = args.num
-    assert num == len(path_list)
-    for path in path_list:
-        print(f'Processing {path}')
-        states = paddle.load(path)
-        if avg is None:
-            avg = states
-        else:
-            for k in avg.keys():
-                avg[k] += states[k]
-    # average
-    for k in avg.keys():
-        if avg[k] is not None:
-            avg[k] /= num
-
-    paddle.save(avg, args.dst_model)
-    print(f'Saving to {args.dst_model}')
-
-    meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
-    with open(meta_path, 'w') as f:
-        data = json.dumps({
-            "mode": 'val_best' if args.val_best else 'latest',
-            "avg_ckpt": args.dst_model,
-            "val_loss_mean": avg_val_score,
-            "ckpts": path_list,
-            "epochs": selected_epochs.tolist(),
-            "val_losses": beat_val_scores.tolist(),
-        })
-        f.write(data + "\n")
-
+from paddlespeech.dataset.s2t import avg_ckpts_main
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='average model')
-    parser.add_argument('--dst_model', required=True, help='averaged model')
-    parser.add_argument(
-        '--ckpt_dir', required=True, help='ckpt model dir for average')
-    parser.add_argument(
-        '--val_best', action="store_true", help='averaged model')
-    parser.add_argument(
-        '--num', default=5, type=int, help='nums for averaged model')
-    parser.add_argument(
-        '--min_epoch',
-        default=0,
-        type=int,
-        help='min epoch used for averaging model')
-    parser.add_argument(
-        '--max_epoch',
-        default=65536,  # Big enough
-        type=int,
-        help='max epoch used for averaging model')
-
-    args = parser.parse_args()
-    print(args)
-
-    main(args)
+    avg_ckpts_main()
diff --git a/utils/compute-wer.py b/utils/compute-wer.py
index 98bb24a7e..1fa77216d 100755
--- a/utils/compute-wer.py
+++ b/utils/compute-wer.py
@@ -1,554 +1,5 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 # Copyright 2021 Mobvoi Inc. All Rights Reserved.
-import codecs
-import re
-import sys
-import unicodedata
-
-remove_tag = True
-spacelist = [' ', '\t', '\r', '\n']
-puncts = [
-    '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
-    '《', '》'
-]
-
-
-def characterize(string):
-    res = []
-    i = 0
-    while i < len(string):
-        char = string[i]
-        if char in puncts:
-            i += 1
-            continue
-        cat1 = unicodedata.category(char)
-        #https://unicodebook.readthedocs.io/unicode.html#unicode-categories
-        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:  # space or not assigned
-            i += 1
-            continue
-        if cat1 == 'Lo':  # letter-other
-            res.append(char)
-            i += 1
-        else:
-            # some input looks like: , we want to separate it to two words.
-            sep = ' '
-            if char == '<': sep = '>'
-            j = i + 1
-            while j < len(string):
-                c = string[j]
-                if ord(c) >= 128 or (c in spacelist) or (c == sep):
-                    break
-                j += 1
-            if j < len(string) and string[j] == '>':
-                j += 1
-            res.append(string[i:j])
-            i = j
-    return res
-
-
-def stripoff_tags(x):
-    if not x: return ''
-    chars = []
-    i = 0
-    T = len(x)
-    while i < T:
-        if x[i] == '<':
-            while i < T and x[i] != '>':
-                i += 1
-            i += 1
-        else:
-            chars.append(x[i])
-            i += 1
-    return ''.join(chars)
-
-
-def normalize(sentence, ignore_words, cs, split=None):
-    """ sentence, ignore_words are both in unicode
-    """
-    new_sentence = []
-    for token in sentence:
-        x = token
-        if not cs:
-            x = x.upper()
-        if x in ignore_words:
-            continue
-        if remove_tag:
-            x = stripoff_tags(x)
-        if not x:
-            continue
-        if split and x in split:
-            new_sentence += split[x]
-        else:
-            new_sentence.append(x)
-    return new_sentence
-
-
-class Calculator:
-    def __init__(self):
-        self.data = {}
-        self.space = []
-        self.cost = {}
-        self.cost['cor'] = 0
-        self.cost['sub'] = 1
-        self.cost['del'] = 1
-        self.cost['ins'] = 1
-
-    def calculate(self, lab, rec):
-        # Initialization
-        lab.insert(0, '')
-        rec.insert(0, '')
-        while len(self.space) < len(lab):
-            self.space.append([])
-        for row in self.space:
-            for element in row:
-                element['dist'] = 0
-                element['error'] = 'non'
-            while len(row) < len(rec):
-                row.append({'dist': 0, 'error': 'non'})
-        for i in range(len(lab)):
-            self.space[i][0]['dist'] = i
-            self.space[i][0]['error'] = 'del'
-        for j in range(len(rec)):
-            self.space[0][j]['dist'] = j
-            self.space[0][j]['error'] = 'ins'
-        self.space[0][0]['error'] = 'non'
-        for token in lab:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {
-                    'all': 0,
-                    'cor': 0,
-                    'sub': 0,
-                    'ins': 0,
-                    'del': 0
-                }
-        for token in rec:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {
-                    'all': 0,
-                    'cor': 0,
-                    'sub': 0,
-                    'ins': 0,
-                    'del': 0
-                }
-        # Computing edit distance
-        for i, lab_token in enumerate(lab):
-            for j, rec_token in enumerate(rec):
-                if i == 0 or j == 0:
-                    continue
-                min_dist = sys.maxsize
-                min_error = 'none'
-                dist = self.space[i - 1][j]['dist'] + self.cost['del']
-                error = 'del'
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
-                error = 'ins'
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                if lab_token == rec_token:
-                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
-                    error = 'cor'
-                else:
-                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
-                    error = 'sub'
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                self.space[i][j]['dist'] = min_dist
-                self.space[i][j]['error'] = min_error
-        # Tracing back
-        result = {
-            'lab': [],
-            'rec': [],
-            'all': 0,
-            'cor': 0,
-            'sub': 0,
-            'ins': 0,
-            'del': 0
-        }
-        i = len(lab) - 1
-        j = len(rec) - 1
-        while True:
-            if self.space[i][j]['error'] == 'cor':  # correct
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
-                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
-                result['all'] = result['all'] + 1
-                result['cor'] = result['cor'] + 1
-                result['lab'].insert(0, lab[i])
-                result['rec'].insert(0, rec[j])
-                i = i - 1
-                j = j - 1
-            elif self.space[i][j]['error'] == 'sub':  # substitution
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
-                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
-                result['all'] = result['all'] + 1
-                result['sub'] = result['sub'] + 1
-                result['lab'].insert(0, lab[i])
-                result['rec'].insert(0, rec[j])
-                i = i - 1
-                j = j - 1
-            elif self.space[i][j]['error'] == 'del':  # deletion
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
-                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
-                result['all'] = result['all'] + 1
-                result['del'] = result['del'] + 1
-                result['lab'].insert(0, lab[i])
-                result['rec'].insert(0, "")
-                i = i - 1
-            elif self.space[i][j]['error'] == 'ins':  # insertion
-                if len(rec[j]) > 0:
-                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
-                result['ins'] = result['ins'] + 1
-                result['lab'].insert(0, "")
-                result['rec'].insert(0, rec[j])
-                j = j - 1
-            elif self.space[i][j]['error'] == 'non':  # starting point
-                break
-            else:  # shouldn't reach here
-                print(
-                    'this should not happen , i = {i} , j = {j} , error = {error}'.
-                    format(i=i, j=j, error=self.space[i][j]['error']))
-        return result
-
-    def overall(self):
-        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
-        for token in self.data:
-            result['all'] = result['all'] + self.data[token]['all']
-            result['cor'] = result['cor'] + self.data[token]['cor']
-            result['sub'] = result['sub'] + self.data[token]['sub']
-            result['ins'] = result['ins'] + self.data[token]['ins']
-            result['del'] = result['del'] + self.data[token]['del']
-        return result
-
-    def cluster(self, data):
-        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
-        for token in data:
-            if token in self.data:
-                result['all'] = result['all'] + self.data[token]['all']
-                result['cor'] = result['cor'] + self.data[token]['cor']
-                result['sub'] = result['sub'] + self.data[token]['sub']
-                result['ins'] = result['ins'] + self.data[token]['ins']
-                result['del'] = result['del'] + self.data[token]['del']
-        return result
-
-    def keys(self):
-        return list(self.data.keys())
-
-
-def width(string):
-    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
-
-
-def default_cluster(word):
-    unicode_names = [unicodedata.name(char) for char in word]
-    for i in reversed(range(len(unicode_names))):
-        if unicode_names[i].startswith('DIGIT'):  # 1
-            unicode_names[i] = 'Number'  # 'DIGIT'
-        elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
-              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
-            # 明 / 郎
-            unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
-        elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
-              unicode_names[i].startswith('LATIN SMALL LETTER')):
-            # A / a
-            unicode_names[i] = 'English'  # 'LATIN LETTER'
-        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
-            unicode_names[i] = 'Japanese'  # 'GANA LETTER'
-        elif (unicode_names[i].startswith('AMPERSAND') or
-              unicode_names[i].startswith('APOSTROPHE') or
-              unicode_names[i].startswith('COMMERCIAL AT') or
-              unicode_names[i].startswith('DEGREE CELSIUS') or
-              unicode_names[i].startswith('EQUALS SIGN') or
-              unicode_names[i].startswith('FULL STOP') or
-              unicode_names[i].startswith('HYPHEN-MINUS') or
-              unicode_names[i].startswith('LOW LINE') or
-              unicode_names[i].startswith('NUMBER SIGN') or
-              unicode_names[i].startswith('PLUS SIGN') or
-              unicode_names[i].startswith('SEMICOLON')):
-            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
-            del unicode_names[i]
-        else:
-            return 'Other'
-    if len(unicode_names) == 0:
-        return 'Other'
-    if len(unicode_names) == 1:
-        return unicode_names[0]
-    for i in range(len(unicode_names) - 1):
-        if unicode_names[i] != unicode_names[i + 1]:
-            return 'Other'
-    return unicode_names[0]
-
-
-def usage():
-    print(
-        "compute-wer.py : compute word error rate (WER) and align recognition results and references."
-    )
-    print(
-        " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
-    )
-
+from paddlespeech.dataset.s2t import compute_wer_main
 
 if __name__ == '__main__':
-    if len(sys.argv) == 1:
-        usage()
-        sys.exit(0)
-    calculator = Calculator()
-    cluster_file = ''
-    ignore_words = set()
-    tochar = False
-    verbose = 1
-    padding_symbol = ' '
-    case_sensitive = False
-    max_words_per_line = sys.maxsize
-    split = None
-    while len(sys.argv) > 3:
-        a = '--maxw='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):]
-            del sys.argv[1]
-            max_words_per_line = int(b)
-            continue
-        a = '--rt='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            remove_tag = (b == 'true') or (b != '0')
-            continue
-        a = '--cs='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            case_sensitive = (b == 'true') or (b != '0')
-            continue
-        a = '--cluster='
-        if sys.argv[1].startswith(a):
-            cluster_file = sys.argv[1][len(a):]
-            del sys.argv[1]
-            continue
-        a = '--splitfile='
-        if sys.argv[1].startswith(a):
-            split_file = sys.argv[1][len(a):]
-            del sys.argv[1]
-            split = dict()
-            with codecs.open(split_file, 'r', 'utf-8') as fh:
-                for line in fh:  # line in unicode
-                    words = line.strip().split()
-                    if len(words) >= 2:
-                        split[words[0]] = words[1:]
-            continue
-        a = '--ig='
-        if sys.argv[1].startswith(a):
-            ignore_file = sys.argv[1][len(a):]
-            del sys.argv[1]
-            with codecs.open(ignore_file, 'r', 'utf-8') as fh:
-                for line in fh:  # line in unicode
-                    line = line.strip()
-                    if len(line) > 0:
-                        ignore_words.add(line)
-            continue
-        a = '--char='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            tochar = (b == 'true') or (b != '0')
-            continue
-        a = '--v='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            verbose = 0
-            try:
-                verbose = int(b)
-            except:
-                if b == 'true' or b != '0':
-                    verbose = 1
-            continue
-        a = '--padding-symbol='
-        if sys.argv[1].startswith(a):
-            b = sys.argv[1][len(a):].lower()
-            del sys.argv[1]
-            if b == 'space':
-                padding_symbol = ' '
-            elif b == 'underline':
-                padding_symbol = '_'
-            continue
-        if True or sys.argv[1].startswith('-'):
-            #ignore invalid switch
-            del sys.argv[1]
-            continue
-
-    if not case_sensitive:
-        ig = set([w.upper() for w in ignore_words])
-        ignore_words = ig
-
-    default_clusters = {}
-    default_words = {}
-
-    ref_file = sys.argv[1]
-    hyp_file = sys.argv[2]
-    rec_set = {}
-    if split and not case_sensitive:
-        newsplit = dict()
-        for w in split:
-            words = split[w]
-            for i in range(len(words)):
-                words[i] = words[i].upper()
-            newsplit[w.upper()] = words
-        split = newsplit
-
-    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
-        for line in fh:
-            if tochar:
-                array = characterize(line)
-            else:
-                array = line.strip().split()
-            if len(array) == 0: continue
-            fid = array[0]
-            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
-                                     split)
-
-    # compute error rate on the interaction of reference file and hyp file
-    for line in open(ref_file, 'r', encoding='utf-8'):
-        if tochar:
-            array = characterize(line)
-        else:
-            array = line.rstrip('\n').split()
-        if len(array) == 0: continue
-        fid = array[0]
-        if fid not in rec_set:
-            continue
-        lab = normalize(array[1:], ignore_words, case_sensitive, split)
-        rec = rec_set[fid]
-        if verbose:
-            print('\nutt: %s' % fid)
-
-        for word in rec + lab:
-            if word not in default_words:
-                default_cluster_name = default_cluster(word)
-                if default_cluster_name not in default_clusters:
-                    default_clusters[default_cluster_name] = {}
-                if word not in default_clusters[default_cluster_name]:
-                    default_clusters[default_cluster_name][word] = 1
-                default_words[word] = default_cluster_name
-
-        result = calculator.calculate(lab, rec)
-        if verbose:
-            if result['all'] != 0:
-                wer = float(result['ins'] + result['sub'] + result[
-                    'del']) * 100.0 / result['all']
-            else:
-                wer = 0.0
-            print('WER: %4.2f %%' % wer, end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'],
-                   result['ins']))
-            space = {}
-            space['lab'] = []
-            space['rec'] = []
-            for idx in range(len(result['lab'])):
-                len_lab = width(result['lab'][idx])
-                len_rec = width(result['rec'][idx])
-                length = max(len_lab, len_rec)
-                space['lab'].append(length - len_lab)
-                space['rec'].append(length - len_rec)
-            upper_lab = len(result['lab'])
-            upper_rec = len(result['rec'])
-            lab1, rec1 = 0, 0
-            while lab1 < upper_lab or rec1 < upper_rec:
-                if verbose > 1:
-                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
-                else:
-                    print('lab:', end=' ')
-                lab2 = min(upper_lab, lab1 + max_words_per_line)
-                for idx in range(lab1, lab2):
-                    token = result['lab'][idx]
-                    print('{token}'.format(token=token), end='')
-                    for n in range(space['lab'][idx]):
-                        print(padding_symbol, end='')
-                    print(' ', end='')
-                print()
-                if verbose > 1:
-                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
-                else:
-                    print('rec:', end=' ')
-                rec2 = min(upper_rec, rec1 + max_words_per_line)
-                for idx in range(rec1, rec2):
-                    token = result['rec'][idx]
-                    print('{token}'.format(token=token), end='')
-                    for n in range(space['rec'][idx]):
-                        print(padding_symbol, end='')
-                    print(' ', end='')
-                print('\n', end='\n')
-                lab1 = lab2
-                rec1 = rec2
-
-    if verbose:
-        print(
-            '==========================================================================='
-        )
-        print()
-
-    result = calculator.overall()
-    if result['all'] != 0:
-        wer = float(result['ins'] + result['sub'] + result[
-            'del']) * 100.0 / result['all']
-    else:
-        wer = 0.0
-    print('Overall -> %4.2f %%' % wer, end=' ')
-    print('N=%d C=%d S=%d D=%d I=%d' %
-          (result['all'], result['cor'], result['sub'], result['del'],
-           result['ins']))
-    if not verbose:
-        print()
-
-    if verbose:
-        for cluster_id in default_clusters:
-            result = calculator.cluster(
-                [k for k in default_clusters[cluster_id]])
-            if result['all'] != 0:
-                wer = float(result['ins'] + result['sub'] + result[
-                    'del']) * 100.0 / result['all']
-            else:
-                wer = 0.0
-            print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'],
-                   result['ins']))
-        if len(cluster_file) > 0:  # compute separated WERs for word clusters
-            cluster_id = ''
-            cluster = []
-            for line in open(cluster_file, 'r', encoding='utf-8'):
-                for token in line.decode('utf-8').rstrip('\n').split():
-                    # end of cluster reached, like </Keyword>
-                    if token[0:2] == '</' and \
-                        token.lstrip('</').rstrip('>') == cluster_id :
-                        result = calculator.cluster(cluster)
-                        if result['all'] != 0:
-                            wer = float(result['ins'] + result['sub'] + result[
-                                'del']) * 100.0 / result['all']
-                        else:
-                            wer = 0.0
-                        print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
-                        print('N=%d C=%d S=%d D=%d I=%d' %
-                              (result['all'], result['cor'], result['sub'],
-                               result['del'], result['ins']))
-                        cluster_id = ''
-                        cluster = []
-                    # begin of cluster reached, like <Keyword>
-                    elif token[0] == '<' and token[len(token)-1] == '>' and \
-                        cluster_id == '' :
-                        cluster_id = token.lstrip('<').rstrip('>')
-                        cluster = []
-                    # general terms, like WEATHER / CAR / ...
-                    else:
-                        cluster.append(token)
-    print()
-    print(
-        '==========================================================================='
-    )
+    compute_wer_main()
diff --git a/utils/format_rsl.py b/utils/format_rsl.py
index 8230416c4..a6845a671 100644
--- a/utils/format_rsl.py
+++ b/utils/format_rsl.py
@@ -11,96 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
+from paddlespeech.dataset.s2t import format_rsl_main
-
-import jsonlines
-
-
-def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
-    """
-    Args:
-        origin_hyp: The input json file which contains the model output
-        trans_hyp: The output file for caculate CER/WER
-        trans_hyp_sclite: The output file for caculate CER/WER using sclite
-    """
-    input_dict = {}
-
-    with open(origin_hyp, "r+", encoding="utf8") as f:
-        for item in jsonlines.Reader(f):
-            input_dict[item["utt"]] = item["hyps"][0]
-    if trans_hyp is not None:
-        with open(trans_hyp, "w+", encoding="utf8") as f:
-            for key in input_dict.keys():
-                f.write(key + " " + input_dict[key] + "\n")
-    if trans_hyp_sclite is not None:
-        with open(trans_hyp_sclite, "w+") as f:
-            for key in input_dict.keys():
-                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
-                f.write(line)
-
-
-def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
-    """
-    Args:
-        origin_hyp: The input json file which contains the model output
-        trans_hyp: The output file for caculate CER/WER
-        trans_hyp_sclite: The output file for caculate CER/WER using sclite
-    """
-    input_dict = {}
-
-    with open(origin_ref, "r", encoding="utf8") as f:
-        for item in jsonlines.Reader(f):
-            input_dict[item["utt"]] = item["text"]
-    if trans_ref is not None:
-        with open(trans_ref, "w", encoding="utf8") as f:
-            for key in input_dict.keys():
-                f.write(key + " " + input_dict[key] + "\n")
-
-    if trans_ref_sclite is not None:
-        with open(trans_ref_sclite, "w") as f:
-            for key in input_dict.keys():
-                line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
-                f.write(line)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        prog='format hyp file for compute CER/WER', add_help=True)
-    parser.add_argument(
-        '--origin_hyp', type=str, default=None, help='origin hyp file')
-    parser.add_argument(
-        '--trans_hyp',
-        type=str,
-        default=None,
-        help='hyp file for caculating CER/WER')
-    parser.add_argument(
-        '--trans_hyp_sclite',
-        type=str,
-        default=None,
-        help='hyp file for caculating CER/WER by sclite')
-
-    parser.add_argument(
-        '--origin_ref', type=str, default=None, help='origin ref file')
-    parser.add_argument(
-        '--trans_ref',
-        type=str,
-        default=None,
-        help='ref file for caculating CER/WER')
-    parser.add_argument(
-        '--trans_ref_sclite',
-        type=str,
-        default=None,
-        help='ref file for caculating CER/WER by sclite')
-    parser_args = parser.parse_args()
-
-    if parser_args.origin_hyp is not None:
-        trans_hyp(
-            origin_hyp=parser_args.origin_hyp,
-            trans_hyp=parser_args.trans_hyp,
-            trans_hyp_sclite=parser_args.trans_hyp_sclite, )
-
-    if parser_args.origin_ref is not None:
-        trans_ref(
-            origin_ref=parser_args.origin_ref,
-            trans_ref=parser_args.trans_ref,
-            trans_ref_sclite=parser_args.trans_ref_sclite, )
+
+if __name__ == '__main__':
+    format_rsl_main()
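With the shims above, utils/avg_model.py, utils/compute-wer.py and utils/format_rsl.py stay usable as scripts while the logic lives in the package. The same entry points are also reachable through the re-exports added to `paddlespeech/dataset/s2t/__init__.py`; a sketch with invented argument values:

```python
import sys

# The package __init__ now re-exports all three tools:
from paddlespeech.dataset.s2t import (avg_ckpts_main, compute_wer_main,
                                      format_rsl_main)

# Each main still parses sys.argv itself, so a thin driver just swaps argv.
# File names are hypothetical and mirror the old utils/ script flags.
sys.argv = ['format_rsl', '--origin_hyp', 'decode.rsl',
            '--trans_hyp', 'hyp.txt']
format_rsl_main()

sys.argv = ['compute_wer', '--char=1', 'ref.txt', 'hyp.txt']
compute_wer_main()
```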