diff --git a/examples/aishell/asr1/local/test.sh b/examples/aishell/asr1/local/test.sh index 26926b4a..8487e990 100755 --- a/examples/aishell/asr1/local/test.sh +++ b/examples/aishell/asr1/local/test.sh @@ -1,15 +1,21 @@ #!/bin/bash -if [ $# != 3 ];then - echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" - exit -1 -fi +set -e stage=0 stop_stage=100 + +source utils/parse_options.sh || exit 1; + ngpu=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') echo "using $ngpu gpus..." + +if [ $# != 3 ];then + echo "usage: ${0} config_path decode_config_path ckpt_path_prefix" + exit -1 +fi + config_path=$1 decode_config_path=$2 ckpt_prefix=$3 @@ -92,6 +98,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then fi if [ ${stage} -le 101 ] && [ ${stop_stage} -ge 101 ]; then + echo "using sclite to compute cer..." # format the reference test file for sclite python utils/format_rsl.py \ --origin_ref data/manifest.test.raw \ diff --git a/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py b/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py index ba178567..5d914a43 100644 --- a/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py +++ b/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py @@ -28,6 +28,7 @@ import soundfile from paddlespeech.dataset.download import download from paddlespeech.dataset.download import unpack +from paddlespeech.utils.argparse import print_arguments DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') @@ -139,7 +140,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path, subset): def main(): - print(f"args: {args}") + print_arguments(args, globals()) if args.target_dir.startswith('~'): args.target_dir = os.path.expanduser(args.target_dir) diff --git a/paddlespeech/dataset/aishell/aishell.py b/paddlespeech/dataset/aishell/aishell.py index fa90aa67..7ea4d676 100644 --- a/paddlespeech/dataset/aishell/aishell.py +++ b/paddlespeech/dataset/aishell/aishell.py @@ -28,6 +28,7 @@ import soundfile from paddlespeech.dataset.download import download from paddlespeech.dataset.download import unpack +from paddlespeech.utils.argparse import print_arguments DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') @@ -205,7 +206,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path=None, check=False): def main(): - print(f"args: {args}") + print_arguments(args, globals()) if args.target_dir.startswith('~'): args.target_dir = os.path.expanduser(args.target_dir) diff --git a/paddlespeech/dataset/s2t/__init__.py b/paddlespeech/dataset/s2t/__init__.py new file mode 100644 index 00000000..27ea9e77 --- /dev/null +++ b/paddlespeech/dataset/s2t/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# s2t utils binaries. 
+from .avg_model import main as avg_ckpts_main
+from .build_vocab import main as build_vocab_main
+from .compute_mean_std import main as compute_mean_std_main
+from .compute_wer import main as compute_wer_main
+from .format_data import main as format_data_main
+from .format_rsl import main as format_rsl_main
diff --git a/paddlespeech/dataset/s2t/avg_model.py b/paddlespeech/dataset/s2t/avg_model.py
new file mode 100755
index 00000000..c5753b72
--- /dev/null
+++ b/paddlespeech/dataset/s2t/avg_model.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import json
+import os
+
+import numpy as np
+import paddle
+
+
+def define_argparse():
+    parser = argparse.ArgumentParser(description='average model')
+    parser.add_argument('--dst_model', required=True, help='averaged model')
+    parser.add_argument(
+        '--ckpt_dir', required=True, help='ckpt model dir for average')
+    parser.add_argument(
+        '--val_best', action="store_true", help='average the checkpoints with the best val_loss')
+    parser.add_argument(
+        '--num', default=5, type=int, help='number of checkpoints to average')
+    parser.add_argument(
+        '--min_epoch',
+        default=0,
+        type=int,
+        help='min epoch used for averaging model')
+    parser.add_argument(
+        '--max_epoch',
+        default=65536,  # Big enough
+        type=int,
+        help='max epoch used for averaging model')
+
+    args = parser.parse_args()
+    return args
+
+
+def average_checkpoints(dst_model="",
+                        ckpt_dir="",
+                        val_best=True,
+                        num=5,
+                        min_epoch=0,
+                        max_epoch=65536):
+    paddle.set_device('cpu')
+
+    val_scores = []
+    jsons = glob.glob(f'{ckpt_dir}/[!train]*.json')
+    jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
+    for y in jsons:
+        with open(y, 'r') as f:
+            dic_json = json.load(f)
+        loss = dic_json['val_loss']
+        epoch = dic_json['epoch']
+        if epoch >= min_epoch and epoch <= max_epoch:
+            val_scores.append((epoch, loss))
+    assert val_scores, f"Did not find any valid checkpoints in {ckpt_dir}"
+    val_scores = np.array(val_scores)
+
+    if val_best:
+        sort_idx = np.argsort(val_scores[:, 1])
+        sorted_val_scores = val_scores[sort_idx]
+    else:
+        sorted_val_scores = val_scores
+
+    best_val_scores = sorted_val_scores[:num, 1]
+    selected_epochs = sorted_val_scores[:num, 0].astype(np.int64)
+    avg_val_score = np.mean(best_val_scores)
+    print("selected val scores = " + str(best_val_scores))
+    print("selected epochs = " + str(selected_epochs))
+    print("averaged val score = " + str(avg_val_score))
+
+    path_list = [
+        ckpt_dir + '/{}.pdparams'.format(int(epoch))
+        for epoch in sorted_val_scores[:num, 0]
+    ]
+    print(path_list)
+
+    avg = None
+    # the `num` selected checkpoints are averaged with equal weight
+    assert num == len(path_list)
+    for path in path_list:
+        print(f'Processing {path}')
+        states = paddle.load(path)
+        if avg is None:
+            avg = states
+        else:
+            for k in avg.keys():
+                avg[k] += states[k]
+    # average
+    for k in avg.keys():
+        if avg[k] is not None:
+            avg[k] /= num
+
+    paddle.save(avg, dst_model)
+    print(f'Saving to {dst_model}')
+
+    meta_path = os.path.splitext(dst_model)[0] + '.avg.json'
+    with open(meta_path, 'w') as f:
+        data = json.dumps({
+            "mode": 'val_best' if val_best else 'latest',
+            "avg_ckpt": dst_model,
+            "val_loss_mean": avg_val_score,
+            "ckpts": path_list,
+            "epochs": selected_epochs.tolist(),
+            "val_losses": best_val_scores.tolist(),
+        })
+        f.write(data + "\n")
+
+
+def main():
+    args = define_argparse()
+    average_checkpoints(**vars(args))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/paddlespeech/dataset/s2t/build_vocab.py b/paddlespeech/dataset/s2t/build_vocab.py
new file mode 100755
index 00000000..dd5f6208
--- /dev/null
+++ b/paddlespeech/dataset/s2t/build_vocab.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Build vocabulary from manifest files.
+Each item in vocabulary file is a character.
+"""
+import argparse
+import functools
+import os
+import tempfile
+from collections import Counter
+
+import jsonlines
+
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.frontend.utility import BLANK
+from paddlespeech.s2t.frontend.utility import SOS
+from paddlespeech.s2t.frontend.utility import SPACE
+from paddlespeech.s2t.frontend.utility import UNK
+from paddlespeech.utils.argparse import add_arguments
+from paddlespeech.utils.argparse import print_arguments
+
+
+def count_manifest(counter, text_feature, manifest_path):
+    manifest_jsons = []
+    with jsonlines.open(manifest_path, 'r') as reader:
+        for json_data in reader:
+            manifest_jsons.append(json_data)
+
+    for line_json in manifest_jsons:
+        if isinstance(line_json['text'], str):
+            tokens = text_feature.tokenize(
+                line_json['text'], replace_space=False)
+
+            counter.update(tokens)
+        else:
+            assert isinstance(line_json['text'], list)
+            for text in line_json['text']:
+                tokens = text_feature.tokenize(text, replace_space=False)
+                counter.update(tokens)
+
+
+def dump_text_manifest(fileobj, manifest_path, key='text'):
+    manifest_jsons = []
+    with jsonlines.open(manifest_path, 'r') as reader:
+        for json_data in reader:
+            manifest_jsons.append(json_data)
+
+    for line_json in manifest_jsons:
+        if isinstance(line_json[key], str):
+            fileobj.write(line_json[key] + "\n")
+        else:
+            assert isinstance(line_json[key], list)
+            for line in line_json[key]:
+                fileobj.write(line + "\n")
+
+
+def build_vocab(manifest_paths="",
+                vocab_path="examples/librispeech/data/vocab.txt",
+                unit_type="char",
+                count_threshold=0,
+                text_keys='text',
+                spm_mode="unigram",
+                spm_vocab_size=0,
+                spm_model_prefix="",
+                spm_character_coverage=0.9995):
+    fout = open(vocab_path, 'w', encoding='utf-8')
+    fout.write(BLANK + "\n")  # 0 will be used for "blank" in CTC
+    fout.write(UNK + '\n')  # must be 1
+
+    if unit_type == 'spm':
+        # tools/spm_train --input=$wave_data/lang_char/input.txt
+        # --vocab_size=${nbpe} --model_type=${bpemode}
+        # --model_prefix=${bpemodel} --input_sentence_size=100000000
+        import sentencepiece as spm
+
+        fp = 
tempfile.NamedTemporaryFile(mode='w', delete=False) + for manifest_path in manifest_paths: + _text_keys = [text_keys] if type( + text_keys) is not list else text_keys + for text_key in _text_keys: + dump_text_manifest(fp, manifest_path, key=text_key) + fp.close() + # train + spm.SentencePieceTrainer.Train( + input=fp.name, + vocab_size=spm_vocab_size, + model_type=spm_mode, + model_prefix=spm_model_prefix, + input_sentence_size=100000000, + character_coverage=spm_character_coverage) + os.unlink(fp.name) + + # encode + text_feature = TextFeaturizer(unit_type, "", spm_model_prefix) + counter = Counter() + + for manifest_path in manifest_paths: + count_manifest(counter, text_feature, manifest_path) + + count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) + tokens = [] + for token, count in count_sorted: + if count < count_threshold: + break + # replace space by `` + token = SPACE if token == ' ' else token + tokens.append(token) + + tokens = sorted(tokens) + for token in tokens: + fout.write(token + '\n') + + fout.write(SOS + "\n") # + fout.close() + + +def define_argparse(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + + # yapf: disable + add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") + add_arg('count_threshold', int, 0, + "Truncation threshold for char/word counts.Default 0, no truncate.") + add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath to write the vocabulary.") + add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) + add_arg('text_keys', str, + 'text', + "keys of the text in manifest for building vocabulary. " + "You can provide multiple k.", + nargs='+') + # bpe + add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") + add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm") + add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm") + add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols") + # yapf: disable + + args = parser.parse_args() + return args + +def main(): + args = define_argparse() + print_arguments(args, globals()) + build_vocab(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/compute_mean_std.py b/paddlespeech/dataset/s2t/compute_mean_std.py new file mode 100755 index 00000000..8762ee57 --- /dev/null +++ b/paddlespeech/dataset/s2t/compute_mean_std.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Compute mean and std for feature normalizer, and save to file.""" +import argparse +import functools + +from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline +from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer +from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments + + +def compute_cmvn(manifest_path="data/librispeech/manifest.train", + output_path="data/librispeech/mean_std.npz", + num_samples=2000, + num_workers=0, + spectrum_type="linear", + feat_dim=13, + delta_delta=False, + stride_ms=10, + window_ms=20, + sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + + augmentation_pipeline = AugmentationPipeline('{}') + audio_featurizer = AudioFeaturizer( + spectrum_type=spectrum_type, + feat_dim=feat_dim, + delta_delta=delta_delta, + stride_ms=float(stride_ms), + window_ms=float(window_ms), + n_fft=None, + max_freq=None, + target_sample_rate=sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB, + dither=0.0) + + def augment_and_featurize(audio_segment): + augmentation_pipeline.transform_audio(audio_segment) + return audio_featurizer.featurize(audio_segment) + + normalizer = FeatureNormalizer( + mean_std_filepath=None, + manifest_path=manifest_path, + featurize_func=augment_and_featurize, + num_samples=num_samples, + num_workers=num_workers) + normalizer.write_to_file(output_path) + + +def define_argparse(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + + # yapf: disable + add_arg('manifest_path', str, + 'data/librispeech/manifest.train', + "Filepath of manifest to compute normalizer's mean and stddev.") + + add_arg('output_path', str, + 'data/librispeech/mean_std.npz', + "Filepath of write mean and stddev to (.npz).") + add_arg('num_samples', int, 2000, "# of samples to for statistics.") + add_arg('num_workers', + default=0, + type=int, + help='num of subprocess workers for processing') + + + add_arg('spectrum_type', str, + 'linear', + "Audio feature type. Options: linear, mfcc, fbank.", + choices=['linear', 'mfcc', 'fbank']) + add_arg('feat_dim', int, 13, "Audio feature dim.") + add_arg('delta_delta', bool, False, "Audio feature with delta delta.") + add_arg('stride_ms', int, 10, "stride length in ms.") + add_arg('window_ms', int, 20, "stride length in ms.") + add_arg('sample_rate', int, 16000, "target sample rate.") + add_arg('use_dB_normalization', bool, True, "do dB normalization.") + add_arg('target_dB', int, -20, "target dB.") + # yapf: disable + + args = parser.parse_args() + return args + +def main(): + args = define_argparse() + print_arguments(args, globals()) + compute_cmvn(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/compute_wer.py b/paddlespeech/dataset/s2t/compute_wer.py new file mode 100755 index 00000000..5711c725 --- /dev/null +++ b/paddlespeech/dataset/s2t/compute_wer.py @@ -0,0 +1,558 @@ +# Copyright 2021 Mobvoi Inc. All Rights Reserved. 
+# flake8: noqa +import codecs +import re +import sys +import unicodedata + +remove_tag = True +spacelist = [' ', '\t', '\r', '\n'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res + + +def stripoff_tags(x): + if not x: return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) + + +def normalize(sentence, ignore_words, cs, split=None): + """ sentence, ignore_words are both in unicode + """ + new_sentence = [] + for token in sentence: + x = token + if not cs: + x = x.upper() + if x in ignore_words: + continue + if remove_tag: + x = stripoff_tags(x) + if not x: + continue + if split and x in split: + new_sentence += split[x] + else: + new_sentence.append(x) + return new_sentence + + +class Calculator: + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + 
i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen , i = {i} , j = {j} , error = {error}'. + format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) + + +def width(string): + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + + +def default_cluster(word): + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith('DIGIT'): # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')): + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or 
+ unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')): + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else: + return 'Other' + if len(unicode_names) == 0: + return 'Other' + if len(unicode_names) == 1: + return unicode_names[0] + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: + return 'Other' + return unicode_names[0] + + +def usage(): + print( + "compute-wer.py : compute word error rate (WER) and align recognition results and references." + ) + print( + " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer" + ) + + +def main(): + # python utils/compute-wer.py --char=1 --v=1 ref hyp > rsl.error + if len(sys.argv) == 1: + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue + + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig + + default_clusters = {} + default_words = {} + + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit + + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line 
in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, + split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8'): + if tochar: + array = characterize(line) + else: + array = line.rstrip('\n').split() + if len(array) == 0: continue + fid = array[0] + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) + + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name + + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])): + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 + + if verbose: + print( + '===========================================================================' + ) + print() + + result = calculator.overall() + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('Overall -> %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters: + result = calculator.cluster( + [k for k in default_clusters[cluster_id]]) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if len(cluster_file) > 0: # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in 
open(cluster_file, 'r', encoding='utf-8'): + for token in line.decode('utf-8').rstrip('\n').split(): + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif token[0] == '<' and token[len(token)-1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else: + cluster.append(token) + print() + print( + '===========================================================================' + ) + + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/format_data.py b/paddlespeech/dataset/s2t/format_data.py new file mode 100755 index 00000000..dcff66ea --- /dev/null +++ b/paddlespeech/dataset/s2t/format_data.py @@ -0,0 +1,154 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""format manifest with more metadata.""" +import argparse +import functools +import json + +import jsonlines + +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.frontend.utility import load_cmvn +from paddlespeech.s2t.io.utility import feat_type +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments + + +def define_argparse(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + # yapf: disable + add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) + add_arg('output_path', str, None, "filepath of formated manifest.", required=True) + add_arg('cmvn_path', str, + 'examples/librispeech/data/mean_std.json', + "Filepath of cmvn.") + add_arg('unit_type', str, "char", "Unit type, e.g. 
char, word, spm") + add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath of the vocabulary.") + # bpe + add_arg('spm_model_prefix', str, None, + "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") + + # yapf: disable + args = parser.parse_args() + return args + +def format_data( + manifest_paths="", + output_path="", + cmvn_path="examples/librispeech/data/mean_std.json", + unit_type="char", + vocab_path="examples/librispeech/data/vocab.txt", + spm_model_prefix=""): + + fout = open(output_path, 'w', encoding='utf-8') + + # get feat dim + filetype = cmvn_path.split(".")[-1] + mean, istd = load_cmvn(cmvn_path, filetype=filetype) + feat_dim = mean.shape[0] #(D) + print(f"Feature dim: {feat_dim}") + + text_feature = TextFeaturizer(unit_type, vocab_path, spm_model_prefix) + vocab_size = text_feature.vocab_size + print(f"Vocab size: {vocab_size}") + + # josnline like this + # { + # "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}], + # "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}], + # "utt2spk": "111-2222", + # "utt": "111-2222-333" + # } + count = 0 + for manifest_path in manifest_paths: + with jsonlines.open(str(manifest_path), 'r') as reader: + manifest_jsons = list(reader) + + for line_json in manifest_jsons: + output_json = { + "input": [], + "output": [], + 'utt': line_json['utt'], + 'utt2spk': line_json.get('utt2spk', 'global'), + } + + # output + line = line_json['text'] + if isinstance(line, str): + # only one target + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + output_json['output'].append({ + 'name': 'target1', + 'shape': (len(tokenids), vocab_size), + 'text': line, + 'token': ' '.join(tokens), + 'tokenid': ' '.join(map(str, tokenids)), + }) + else: + # isinstance(line, list), multi target in one vocab + for i, item in enumerate(line, 1): + tokens = text_feature.tokenize(item) + tokenids = text_feature.featurize(item) + output_json['output'].append({ + 'name': f'target{i}', + 'shape': (len(tokenids), vocab_size), + 'text': item, + 'token': ' '.join(tokens), + 'tokenid': ' '.join(map(str, tokenids)), + }) + + # input + line = line_json['feat'] + if isinstance(line, str): + # only one input + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + filetype = feat_type(line) + if filetype == 'sound': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') + + output_json['input'].append({ + "name": "input1", + "shape": feat_shape, + "feat": line, + "filetype": filetype, + }) + else: + # isinstance(line, list), multi input + raise NotImplementedError("not support multi input now!") + + fout.write(json.dumps(output_json) + '\n') + count += 1 + + print(f"{manifest_paths} Examples number: {count}") + fout.close() + +def main(): + args = define_argparse() + print_arguments(args, globals()) + format_data(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/format_rsl.py b/paddlespeech/dataset/s2t/format_rsl.py new file mode 100644 index 00000000..0a58e7e6 --- /dev/null +++ b/paddlespeech/dataset/s2t/format_rsl.py @@ -0,0 +1,143 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +format ref/hyp file for `utt text` format to compute CER/WER/MER. + +norm: +BAC009S0764W0196 明确了发展目标和重点任务 +BAC009S0764W0186 实现我国房地产市场的平稳运行 + + +sclite: +加大对结构机械化环境和收集谈控机制力度(BAC009S0906W0240.wav) +河南省新乡市丰秋县刘光镇政府东五零左右(BAC009S0770W0441.wav) +""" +import argparse + +import jsonlines + +from paddlespeech.utils.argparse import print_arguments + + +def transform_hyp(origin, trans, trans_sclite): + """ + Args: + origin: The input json file which contains the model output + trans: The output file for caculate CER/WER + trans_sclite: The output file for caculate CER/WER using sclite + """ + input_dict = {} + + with open(origin, "r+", encoding="utf8") as f: + for item in jsonlines.Reader(f): + input_dict[item["utt"]] = item["hyps"][0] + + if trans: + with open(trans, "w+", encoding="utf8") as f: + for key in input_dict.keys(): + f.write(key + " " + input_dict[key] + "\n") + print(f"transform_hyp output: {trans}") + + if trans_sclite: + with open(trans_sclite, "w+") as f: + for key in input_dict.keys(): + line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" + f.write(line) + print(f"transform_hyp output: {trans_sclite}") + + +def transform_ref(origin, trans, trans_sclite): + """ + Args: + origin: The input json file which contains the model output + trans: The output file for caculate CER/WER + trans_sclite: The output file for caculate CER/WER using sclite + """ + input_dict = {} + + with open(origin, "r", encoding="utf8") as f: + for item in jsonlines.Reader(f): + input_dict[item["utt"]] = item["text"] + + if trans: + with open(trans, "w", encoding="utf8") as f: + for key in input_dict.keys(): + f.write(key + " " + input_dict[key] + "\n") + print(f"transform_hyp output: {trans}") + + if trans_sclite: + with open(trans_sclite, "w") as f: + for key in input_dict.keys(): + line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" + f.write(line) + print(f"transform_hyp output: {trans_sclite}") + + +def define_argparse(): + parser = argparse.ArgumentParser( + prog='format ref/hyp file for compute CER/WER', add_help=True) + parser.add_argument( + '--origin_hyp', type=str, default="", help='origin hyp file') + parser.add_argument( + '--trans_hyp', + type=str, + default="", + help='hyp file for caculating CER/WER') + parser.add_argument( + '--trans_hyp_sclite', + type=str, + default="", + help='hyp file for caculating CER/WER by sclite') + + parser.add_argument( + '--origin_ref', type=str, default="", help='origin ref file') + parser.add_argument( + '--trans_ref', + type=str, + default="", + help='ref file for caculating CER/WER') + parser.add_argument( + '--trans_ref_sclite', + type=str, + default="", + help='ref file for caculating CER/WER by sclite') + parser_args = parser.parse_args() + return parser_args + + +def format_result(origin_hyp="", + trans_hyp="", + trans_hyp_sclite="", + origin_ref="", + trans_ref="", + trans_ref_sclite=""): + + if origin_hyp: + transform_hyp( + origin=origin_hyp, trans=trans_hyp, trans_sclite=trans_hyp_sclite) + + if origin_ref: + transform_ref( + origin=origin_ref, trans=trans_ref, trans_sclite=trans_ref_sclite) + + +def main(): + args = define_argparse() 
+ print_arguments(args, globals()) + + format_result(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py index 5755a5f1..f6b1ed09 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py @@ -28,8 +28,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.socket_server import AsrRequestHandler from paddlespeech.s2t.utils.socket_server import AsrTCPServer from paddlespeech.s2t.utils.socket_server import warm_up_test -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments def init_predictor(args): diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py index 0d0b4f21..fc57399d 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py @@ -26,8 +26,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.socket_server import AsrRequestHandler from paddlespeech.s2t.utils.socket_server import AsrTCPServer from paddlespeech.s2t.utils.socket_server import warm_up_test -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments def start_server(config, args): diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/export.py b/paddlespeech/s2t/exps/deepspeech2/bin/export.py index 8acd46df..07228e98 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test.py b/paddlespeech/s2t/exps/deepspeech2/bin/test.py index 030168a9..a8e20ff9 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py index d7a9402b..1e07aa80 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2ExportTester as ExportTester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git 
a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py index 66ea29d0..32a583b6 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/test_wav.py @@ -27,8 +27,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils import mp_tools from paddlespeech.s2t.utils.checkpoint import Checkpoint from paddlespeech.s2t.utils.log import Log -from paddlespeech.s2t.utils.utility import print_arguments from paddlespeech.s2t.utils.utility import UpdateConfig +from paddlespeech.utils.argparse import print_arguments logger = Log(__name__).getlog() diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/train.py b/paddlespeech/s2t/exps/deepspeech2/bin/train.py index 2c9942f9..1340aaa3 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/train.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/train.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.deepspeech2.model import DeepSpeech2Trainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/alignment.py b/paddlespeech/s2t/exps/u2/bin/alignment.py index e3390feb..cc294038 100644 --- a/paddlespeech/s2t/exps/u2/bin/alignment.py +++ b/paddlespeech/s2t/exps/u2/bin/alignment.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/export.py b/paddlespeech/s2t/exps/u2/bin/export.py index 592b1237..4725e5e1 100644 --- a/paddlespeech/s2t/exps/u2/bin/export.py +++ b/paddlespeech/s2t/exps/u2/bin/export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/test.py b/paddlespeech/s2t/exps/u2/bin/test.py index b13fd0d3..43eeff63 100644 --- a/paddlespeech/s2t/exps/u2/bin/test.py +++ b/paddlespeech/s2t/exps/u2/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Tester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2/bin/train.py b/paddlespeech/s2t/exps/u2/bin/train.py index dc3a87c1..a0f50328 100644 --- a/paddlespeech/s2t/exps/u2/bin/train.py +++ b/paddlespeech/s2t/exps/u2/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2.model import U2Trainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments # from paddlespeech.s2t.exps.u2.trainer import U2Trainer as Trainer diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/test.py 
b/paddlespeech/s2t/exps/u2_kaldi/bin/test.py index 422483b9..4137537e 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/bin/test.py +++ b/paddlespeech/s2t/exps/u2_kaldi/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments model_test_alias = { "u2": "paddlespeech.s2t.exps.u2.model:U2Tester", diff --git a/paddlespeech/s2t/exps/u2_kaldi/bin/train.py b/paddlespeech/s2t/exps/u2_kaldi/bin/train.py index b11da715..011aabac 100644 --- a/paddlespeech/s2t/exps/u2_kaldi/bin/train.py +++ b/paddlespeech/s2t/exps/u2_kaldi/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.dynamic_import import dynamic_import -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments model_train_alias = { "u2": "paddlespeech.s2t.exps.u2.model:U2Trainer", diff --git a/paddlespeech/s2t/exps/u2_st/bin/export.py b/paddlespeech/s2t/exps/u2_st/bin/export.py index c641152f..a2a7424c 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/export.py +++ b/paddlespeech/s2t/exps/u2_st/bin/export.py @@ -16,7 +16,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2_st/bin/test.py b/paddlespeech/s2t/exps/u2_st/bin/test.py index c07c95bd..30a903ce 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/test.py +++ b/paddlespeech/s2t/exps/u2_st/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2_st.model import U2STTester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/u2_st/bin/train.py b/paddlespeech/s2t/exps/u2_st/bin/train.py index 574942e5..b36a0af4 100644 --- a/paddlespeech/s2t/exps/u2_st/bin/train.py +++ b/paddlespeech/s2t/exps/u2_st/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.u2_st.model import U2STTrainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/test.py b/paddlespeech/s2t/exps/wav2vec2/bin/test.py index a376651d..c17cee0f 100644 --- a/paddlespeech/s2t/exps/wav2vec2/bin/test.py +++ b/paddlespeech/s2t/exps/wav2vec2/bin/test.py @@ -18,7 +18,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTester as Tester from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/exps/wav2vec2/bin/train.py b/paddlespeech/s2t/exps/wav2vec2/bin/train.py index 29e7ef55..0c37f796 100644 --- a/paddlespeech/s2t/exps/wav2vec2/bin/train.py 
+++ b/paddlespeech/s2t/exps/wav2vec2/bin/train.py @@ -19,7 +19,7 @@ from yacs.config import CfgNode from paddlespeech.s2t.exps.wav2vec2.model import Wav2Vec2ASRTrainer as Trainer from paddlespeech.s2t.training.cli import default_argument_parser -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import print_arguments def main_sp(config, args): diff --git a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py index 982c6b8f..7623d0b8 100644 --- a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py +++ b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py @@ -48,13 +48,16 @@ class TextFeaturizer(): self.unit_type = unit_type self.unk = UNK self.maskctc = maskctc + self.vocab_path_or_list = vocab - if vocab: + if self.vocab_path_or_list: self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( vocab, maskctc) self.vocab_size = len(self.vocab_list) else: - logger.warning("TextFeaturizer: not have vocab file or vocab list.") + logger.warning( + "TextFeaturizer: not have vocab file or vocab list. Only Tokenizer can use, can not convert to token idx" + ) if unit_type == 'spm': spm_model = spm_model_prefix + '.model' @@ -62,6 +65,7 @@ class TextFeaturizer(): self.sp.Load(spm_model) def tokenize(self, text, replace_space=True): + """tokenizer split text into text tokens""" if self.unit_type == 'char': tokens = self.char_tokenize(text, replace_space) elif self.unit_type == 'word': @@ -71,6 +75,7 @@ class TextFeaturizer(): return tokens def detokenize(self, tokens): + """tokenizer convert text tokens back to text""" if self.unit_type == 'char': text = self.char_detokenize(tokens) elif self.unit_type == 'word': @@ -88,6 +93,7 @@ class TextFeaturizer(): Returns: List[int]: List of token indices. """ + assert self.vocab_path_or_list, "toidx need vocab path or vocab list" tokens = self.tokenize(text) ids = [] for token in tokens: @@ -107,6 +113,7 @@ class TextFeaturizer(): Returns: str: Text. """ + assert self.vocab_path_or_list, "toidx need vocab path or vocab list" tokens = [] for idx in idxs: if idx == self.eos_id: @@ -127,10 +134,10 @@ class TextFeaturizer(): """ text = text.strip() if replace_space: - text_list = [SPACE if item == " " else item for item in list(text)] + tokens = [SPACE if item == " " else item for item in list(text)] else: - text_list = list(text) - return text_list + tokens = list(text) + return tokens def char_detokenize(self, tokens): """Character detokenizer. diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index d7e7c6ca..5655ec3f 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -29,10 +29,7 @@ from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() -__all__ = [ - "all_version", "UpdateConfig", "seed_all", 'print_arguments', - 'add_arguments', "log_add" -] +__all__ = ["all_version", "UpdateConfig", "seed_all", "log_add"] def all_version(): @@ -60,51 +57,6 @@ def seed_all(seed: int=20210329): paddle.seed(seed) -def print_arguments(args, info=None): - """Print argparse's arguments. - - Usage: - - .. code-block:: python - - parser = argparse.ArgumentParser() - parser.add_argument("name", default="Jonh", type=str, help="User name.") - args = parser.parse_args() - print_arguments(args) - - :param args: Input argparse.Namespace for printing. 
- :type args: argparse.Namespace - """ - filename = "" - if info: - filename = info["__file__"] - filename = os.path.basename(filename) - print(f"----------- {filename} Arguments -----------") - for arg, value in sorted(vars(args).items()): - print("%s: %s" % (arg, value)) - print("-----------------------------------------------------------") - - -def add_arguments(argname, type, default, help, argparser, **kwargs): - """Add argparse's argument. - - Usage: - - .. code-block:: python - - parser = argparse.ArgumentParser() - add_argument("name", str, "Jonh", "User name.", parser) - args = parser.parse_args() - """ - type = distutils.util.strtobool if type == bool else type - argparser.add_argument( - "--" + argname, - default=default, - type=type, - help=help + ' Default: %(default)s.', - **kwargs) - - def log_add(args: List[int]) -> float: """Stable log add diff --git a/paddlespeech/utils/argparse.py b/paddlespeech/utils/argparse.py index 4df75c5a..aad3801e 100644 --- a/paddlespeech/utils/argparse.py +++ b/paddlespeech/utils/argparse.py @@ -16,6 +16,8 @@ import os import sys from typing import Text +import distutils + __all__ = ["print_arguments", "add_arguments", "get_commandline_args"] diff --git a/utils/avg_model.py b/utils/avg_model.py index 6ee16408..039ea626 100755 --- a/utils/avg_model.py +++ b/utils/avg_model.py @@ -12,105 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import argparse -import glob -import json -import os - -import numpy as np -import paddle - - -def main(args): - paddle.set_device('cpu') - - val_scores = [] - beat_val_scores = None - selected_epochs = None - - jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') - jsons = sorted(jsons, key=os.path.getmtime, reverse=True) - for y in jsons: - with open(y, 'r') as f: - dic_json = json.load(f) - loss = dic_json['val_loss'] - epoch = dic_json['epoch'] - if epoch >= args.min_epoch and epoch <= args.max_epoch: - val_scores.append((epoch, loss)) - val_scores = np.array(val_scores) - - if args.val_best: - sort_idx = np.argsort(val_scores[:, 1]) - sorted_val_scores = val_scores[sort_idx] - else: - sorted_val_scores = val_scores - - beat_val_scores = sorted_val_scores[:args.num, 1] - selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64) - avg_val_score = np.mean(beat_val_scores) - print("selected val scores = " + str(beat_val_scores)) - print("selected epochs = " + str(selected_epochs)) - print("averaged val score = " + str(avg_val_score)) - - path_list = [ - args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) - for epoch in sorted_val_scores[:args.num, 0] - ] - print(path_list) - - avg = None - num = args.num - assert num == len(path_list) - for path in path_list: - print(f'Processing {path}') - states = paddle.load(path) - if avg is None: - avg = states - else: - for k in avg.keys(): - avg[k] += states[k] - # average - for k in avg.keys(): - if avg[k] is not None: - avg[k] /= num - - paddle.save(avg, args.dst_model) - print(f'Saving to {args.dst_model}') - - meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json' - with open(meta_path, 'w') as f: - data = json.dumps({ - "mode": 'val_best' if args.val_best else 'latest', - "avg_ckpt": args.dst_model, - "val_loss_mean": avg_val_score, - "ckpts": path_list, - "epochs": selected_epochs.tolist(), - "val_losses": beat_val_scores.tolist(), - }) - f.write(data + "\n") - +from paddlespeech.dataset.s2t import avg_ckpts_main if 
__name__ == '__main__': - parser = argparse.ArgumentParser(description='average model') - parser.add_argument('--dst_model', required=True, help='averaged model') - parser.add_argument( - '--ckpt_dir', required=True, help='ckpt model dir for average') - parser.add_argument( - '--val_best', action="store_true", help='averaged model') - parser.add_argument( - '--num', default=5, type=int, help='nums for averaged model') - parser.add_argument( - '--min_epoch', - default=0, - type=int, - help='min epoch used for averaging model') - parser.add_argument( - '--max_epoch', - default=65536, # Big enough - type=int, - help='max epoch used for averaging model') - - args = parser.parse_args() - print(args) - - main(args) + avg_ckpts_main() diff --git a/utils/build_vocab.py b/utils/build_vocab.py index e364e821..9b29dfa5 100755 --- a/utils/build_vocab.py +++ b/utils/build_vocab.py @@ -15,134 +15,7 @@ """Build vocabulary from manifest files. Each item in vocabulary file is a character. """ -import argparse -import functools -import os -import tempfile -from collections import Counter - -import jsonlines - -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.frontend.utility import BLANK -from paddlespeech.s2t.frontend.utility import SOS -from paddlespeech.s2t.frontend.utility import SPACE -from paddlespeech.s2t.frontend.utility import UNK -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") -add_arg('count_threshold', int, 0, - "Truncation threshold for char/word counts.Default 0, no truncate.") -add_arg('vocab_path', str, - 'examples/librispeech/data/vocab.txt', - "Filepath to write the vocabulary.") -add_arg('manifest_paths', str, - None, - "Filepaths of manifests for building vocabulary. " - "You can provide multiple manifest files.", - nargs='+', - required=True) -add_arg('text_keys', str, - 'text', - "keys of the text in manifest for building vocabulary. " - "You can provide multiple k.", - nargs='+') -# bpe -add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") -add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. 
only need when `unit_type` is spm") -add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm") -add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols") - -# yapf: disable -args = parser.parse_args() - - -def count_manifest(counter, text_feature, manifest_path): - manifest_jsons = [] - with jsonlines.open(manifest_path, 'r') as reader: - for json_data in reader: - manifest_jsons.append(json_data) - - for line_json in manifest_jsons: - if isinstance(line_json['text'], str): - line = text_feature.tokenize(line_json['text'], replace_space=False) - counter.update(line) - else: - assert isinstance(line_json['text'], list) - for text in line_json['text']: - line = text_feature.tokenize(text, replace_space=False) - counter.update(line) - -def dump_text_manifest(fileobj, manifest_path, key='text'): - manifest_jsons = [] - with jsonlines.open(manifest_path, 'r') as reader: - for json_data in reader: - manifest_jsons.append(json_data) - - for line_json in manifest_jsons: - if isinstance(line_json[key], str): - fileobj.write(line_json[key] + "\n") - else: - assert isinstance(line_json[key], list) - for line in line_json[key]: - fileobj.write(line + "\n") - -def main(): - print_arguments(args, globals()) - - fout = open(args.vocab_path, 'w', encoding='utf-8') - fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC - fout.write(UNK + '\n') # must be 1 - - if args.unit_type == 'spm': - # tools/spm_train --input=$wave_data/lang_char/input.txt - # --vocab_size=${nbpe} --model_type=${bpemode} - # --model_prefix=${bpemodel} --input_sentence_size=100000000 - import sentencepiece as spm - - fp = tempfile.NamedTemporaryFile(mode='w', delete=False) - for manifest_path in args.manifest_paths: - text_keys = [args.text_keys] if type(args.text_keys) is not list else args.text_keys - for text_key in text_keys: - dump_text_manifest(fp, manifest_path, key=text_key) - fp.close() - # train - spm.SentencePieceTrainer.Train( - input=fp.name, - vocab_size=args.spm_vocab_size, - model_type=args.spm_mode, - model_prefix=args.spm_model_prefix, - input_sentence_size=100000000, - character_coverage=args.spm_character_coverage) - os.unlink(fp.name) - - # encode - text_feature = TextFeaturizer(args.unit_type, "", args.spm_model_prefix) - counter = Counter() - - for manifest_path in args.manifest_paths: - count_manifest(counter, text_feature, manifest_path) - - count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) - tokens = [] - for token, count in count_sorted: - if count < args.count_threshold: - break - # replace space by `` - token = SPACE if token == ' ' else token - tokens.append(token) - - tokens = sorted(tokens) - for token in tokens: - fout.write(token + '\n') - - fout.write(SOS + "\n") # - fout.close() - +from paddlespeech.dataset.s2t import build_vocab_main if __name__ == '__main__': - main() + build_vocab_main() diff --git a/utils/compute-wer.py b/utils/compute-wer.py index 98bb24a7..1fa77216 100755 --- a/utils/compute-wer.py +++ b/utils/compute-wer.py @@ -1,554 +1,5 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- # Copyright 2021 Mobvoi Inc. All Rights Reserved. 
-import codecs -import re -import sys -import unicodedata - -remove_tag = True -spacelist = [' ', '\t', '\r', '\n'] -puncts = [ - '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', - '《', '》' -] - - -def characterize(string): - res = [] - i = 0 - while i < len(string): - char = string[i] - if char in puncts: - i += 1 - continue - cat1 = unicodedata.category(char) - #https://unicodebook.readthedocs.io/unicode.html#unicode-categories - if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned - i += 1 - continue - if cat1 == 'Lo': # letter-other - res.append(char) - i += 1 - else: - # some input looks like: , we want to separate it to two words. - sep = ' ' - if char == '<': sep = '>' - j = i + 1 - while j < len(string): - c = string[j] - if ord(c) >= 128 or (c in spacelist) or (c == sep): - break - j += 1 - if j < len(string) and string[j] == '>': - j += 1 - res.append(string[i:j]) - i = j - return res - - -def stripoff_tags(x): - if not x: return '' - chars = [] - i = 0 - T = len(x) - while i < T: - if x[i] == '<': - while i < T and x[i] != '>': - i += 1 - i += 1 - else: - chars.append(x[i]) - i += 1 - return ''.join(chars) - - -def normalize(sentence, ignore_words, cs, split=None): - """ sentence, ignore_words are both in unicode - """ - new_sentence = [] - for token in sentence: - x = token - if not cs: - x = x.upper() - if x in ignore_words: - continue - if remove_tag: - x = stripoff_tags(x) - if not x: - continue - if split and x in split: - new_sentence += split[x] - else: - new_sentence.append(x) - return new_sentence - - -class Calculator: - def __init__(self): - self.data = {} - self.space = [] - self.cost = {} - self.cost['cor'] = 0 - self.cost['sub'] = 1 - self.cost['del'] = 1 - self.cost['ins'] = 1 - - def calculate(self, lab, rec): - # Initialization - lab.insert(0, '') - rec.insert(0, '') - while len(self.space) < len(lab): - self.space.append([]) - for row in self.space: - for element in row: - element['dist'] = 0 - element['error'] = 'non' - while len(row) < len(rec): - row.append({'dist': 0, 'error': 'non'}) - for i in range(len(lab)): - self.space[i][0]['dist'] = i - self.space[i][0]['error'] = 'del' - for j in range(len(rec)): - self.space[0][j]['dist'] = j - self.space[0][j]['error'] = 'ins' - self.space[0][0]['error'] = 'non' - for token in lab: - if token not in self.data and len(token) > 0: - self.data[token] = { - 'all': 0, - 'cor': 0, - 'sub': 0, - 'ins': 0, - 'del': 0 - } - for token in rec: - if token not in self.data and len(token) > 0: - self.data[token] = { - 'all': 0, - 'cor': 0, - 'sub': 0, - 'ins': 0, - 'del': 0 - } - # Computing edit distance - for i, lab_token in enumerate(lab): - for j, rec_token in enumerate(rec): - if i == 0 or j == 0: - continue - min_dist = sys.maxsize - min_error = 'none' - dist = self.space[i - 1][j]['dist'] + self.cost['del'] - error = 'del' - if dist < min_dist: - min_dist = dist - min_error = error - dist = self.space[i][j - 1]['dist'] + self.cost['ins'] - error = 'ins' - if dist < min_dist: - min_dist = dist - min_error = error - if lab_token == rec_token: - dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] - error = 'cor' - else: - dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] - error = 'sub' - if dist < min_dist: - min_dist = dist - min_error = error - self.space[i][j]['dist'] = min_dist - self.space[i][j]['error'] = min_error - # Tracing back - result = { - 'lab': [], - 'rec': [], - 'all': 0, - 'cor': 0, - 'sub': 0, - 'ins': 0, - 'del': 0 - } - i = len(lab) - 1 
- j = len(rec) - 1 - while True: - if self.space[i][j]['error'] == 'cor': # correct - if len(lab[i]) > 0: - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 - result['all'] = result['all'] + 1 - result['cor'] = result['cor'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'sub': # substitution - if len(lab[i]) > 0: - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 - result['all'] = result['all'] + 1 - result['sub'] = result['sub'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'del': # deletion - if len(lab[i]) > 0: - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 - result['all'] = result['all'] + 1 - result['del'] = result['del'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, "") - i = i - 1 - elif self.space[i][j]['error'] == 'ins': # insertion - if len(rec[j]) > 0: - self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 - result['ins'] = result['ins'] + 1 - result['lab'].insert(0, "") - result['rec'].insert(0, rec[j]) - j = j - 1 - elif self.space[i][j]['error'] == 'non': # starting point - break - else: # shouldn't reach here - print( - 'this should not happen , i = {i} , j = {j} , error = {error}'. - format(i=i, j=j, error=self.space[i][j]['error'])) - return result - - def overall(self): - result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} - for token in self.data: - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - - def cluster(self, data): - result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} - for token in data: - if token in self.data: - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - - def keys(self): - return list(self.data.keys()) - - -def width(string): - return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) - - -def default_cluster(word): - unicode_names = [unicodedata.name(char) for char in word] - for i in reversed(range(len(unicode_names))): - if unicode_names[i].startswith('DIGIT'): # 1 - unicode_names[i] = 'Number' # 'DIGIT' - elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or - unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): - # 明 / 郎 - unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' - elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or - unicode_names[i].startswith('LATIN SMALL LETTER')): - # A / a - unicode_names[i] = 'English' # 'LATIN LETTER' - elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め - unicode_names[i] = 'Japanese' # 'GANA LETTER' - elif (unicode_names[i].startswith('AMPERSAND') or - unicode_names[i].startswith('APOSTROPHE') or - unicode_names[i].startswith('COMMERCIAL AT') or - unicode_names[i].startswith('DEGREE CELSIUS') or - unicode_names[i].startswith('EQUALS SIGN') or - 
unicode_names[i].startswith('FULL STOP') or - unicode_names[i].startswith('HYPHEN-MINUS') or - unicode_names[i].startswith('LOW LINE') or - unicode_names[i].startswith('NUMBER SIGN') or - unicode_names[i].startswith('PLUS SIGN') or - unicode_names[i].startswith('SEMICOLON')): - # & / ' / @ / ℃ / = / . / - / _ / # / + / ; - del unicode_names[i] - else: - return 'Other' - if len(unicode_names) == 0: - return 'Other' - if len(unicode_names) == 1: - return unicode_names[0] - for i in range(len(unicode_names) - 1): - if unicode_names[i] != unicode_names[i + 1]: - return 'Other' - return unicode_names[0] - - -def usage(): - print( - "compute-wer.py : compute word error rate (WER) and align recognition results and references." - ) - print( - " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer" - ) - +from paddlespeech.dataset.s2t import compute_wer_main if __name__ == '__main__': - if len(sys.argv) == 1: - usage() - sys.exit(0) - calculator = Calculator() - cluster_file = '' - ignore_words = set() - tochar = False - verbose = 1 - padding_symbol = ' ' - case_sensitive = False - max_words_per_line = sys.maxsize - split = None - while len(sys.argv) > 3: - a = '--maxw=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):] - del sys.argv[1] - max_words_per_line = int(b) - continue - a = '--rt=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - remove_tag = (b == 'true') or (b != '0') - continue - a = '--cs=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - case_sensitive = (b == 'true') or (b != '0') - continue - a = '--cluster=' - if sys.argv[1].startswith(a): - cluster_file = sys.argv[1][len(a):] - del sys.argv[1] - continue - a = '--splitfile=' - if sys.argv[1].startswith(a): - split_file = sys.argv[1][len(a):] - del sys.argv[1] - split = dict() - with codecs.open(split_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - words = line.strip().split() - if len(words) >= 2: - split[words[0]] = words[1:] - continue - a = '--ig=' - if sys.argv[1].startswith(a): - ignore_file = sys.argv[1][len(a):] - del sys.argv[1] - with codecs.open(ignore_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - line = line.strip() - if len(line) > 0: - ignore_words.add(line) - continue - a = '--char=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - tochar = (b == 'true') or (b != '0') - continue - a = '--v=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - verbose = 0 - try: - verbose = int(b) - except: - if b == 'true' or b != '0': - verbose = 1 - continue - a = '--padding-symbol=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - if b == 'space': - padding_symbol = ' ' - elif b == 'underline': - padding_symbol = '_' - continue - if True or sys.argv[1].startswith('-'): - #ignore invalid switch - del sys.argv[1] - continue - - if not case_sensitive: - ig = set([w.upper() for w in ignore_words]) - ignore_words = ig - - default_clusters = {} - default_words = {} - - ref_file = sys.argv[1] - hyp_file = sys.argv[2] - rec_set = {} - if split and not case_sensitive: - newsplit = dict() - for w in split: - words = split[w] - for i in range(len(words)): - words[i] = words[i].upper() - newsplit[w.upper()] = words - split = newsplit - - with codecs.open(hyp_file, 'r', 'utf-8') as fh: - for line in 
fh: - if tochar: - array = characterize(line) - else: - array = line.strip().split() - if len(array) == 0: continue - fid = array[0] - rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, - split) - - # compute error rate on the interaction of reference file and hyp file - for line in open(ref_file, 'r', encoding='utf-8'): - if tochar: - array = characterize(line) - else: - array = line.rstrip('\n').split() - if len(array) == 0: continue - fid = array[0] - if fid not in rec_set: - continue - lab = normalize(array[1:], ignore_words, case_sensitive, split) - rec = rec_set[fid] - if verbose: - print('\nutt: %s' % fid) - - for word in rec + lab: - if word not in default_words: - default_cluster_name = default_cluster(word) - if default_cluster_name not in default_clusters: - default_clusters[default_cluster_name] = {} - if word not in default_clusters[default_cluster_name]: - default_clusters[default_cluster_name][word] = 1 - default_words[word] = default_cluster_name - - result = calculator.calculate(lab, rec) - if verbose: - if result['all'] != 0: - wer = float(result['ins'] + result['sub'] + result[ - 'del']) * 100.0 / result['all'] - else: - wer = 0.0 - print('WER: %4.2f %%' % wer, end=' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], - result['ins'])) - space = {} - space['lab'] = [] - space['rec'] = [] - for idx in range(len(result['lab'])): - len_lab = width(result['lab'][idx]) - len_rec = width(result['rec'][idx]) - length = max(len_lab, len_rec) - space['lab'].append(length - len_lab) - space['rec'].append(length - len_rec) - upper_lab = len(result['lab']) - upper_rec = len(result['rec']) - lab1, rec1 = 0, 0 - while lab1 < upper_lab or rec1 < upper_rec: - if verbose > 1: - print('lab(%s):' % fid.encode('utf-8'), end=' ') - else: - print('lab:', end=' ') - lab2 = min(upper_lab, lab1 + max_words_per_line) - for idx in range(lab1, lab2): - token = result['lab'][idx] - print('{token}'.format(token=token), end='') - for n in range(space['lab'][idx]): - print(padding_symbol, end='') - print(' ', end='') - print() - if verbose > 1: - print('rec(%s):' % fid.encode('utf-8'), end=' ') - else: - print('rec:', end=' ') - rec2 = min(upper_rec, rec1 + max_words_per_line) - for idx in range(rec1, rec2): - token = result['rec'][idx] - print('{token}'.format(token=token), end='') - for n in range(space['rec'][idx]): - print(padding_symbol, end='') - print(' ', end='') - print('\n', end='\n') - lab1 = lab2 - rec1 = rec2 - - if verbose: - print( - '===========================================================================' - ) - print() - - result = calculator.overall() - if result['all'] != 0: - wer = float(result['ins'] + result['sub'] + result[ - 'del']) * 100.0 / result['all'] - else: - wer = 0.0 - print('Overall -> %4.2f %%' % wer, end=' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], - result['ins'])) - if not verbose: - print() - - if verbose: - for cluster_id in default_clusters: - result = calculator.cluster( - [k for k in default_clusters[cluster_id]]) - if result['all'] != 0: - wer = float(result['ins'] + result['sub'] + result[ - 'del']) * 100.0 / result['all'] - else: - wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], - result['ins'])) - if len(cluster_file) > 0: # compute separated WERs for word clusters - cluster_id = '' - cluster = [] - for line in 
open(cluster_file, 'r', encoding='utf-8'): - for token in line.decode('utf-8').rstrip('\n').split(): - # end of cluster reached, like - if token[0:2] == '' and \ - token.lstrip('') == cluster_id : - result = calculator.cluster(cluster) - if result['all'] != 0: - wer = float(result['ins'] + result['sub'] + result[ - 'del']) * 100.0 / result['all'] - else: - wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], - result['del'], result['ins'])) - cluster_id = '' - cluster = [] - # begin of cluster reached, like - elif token[0] == '<' and token[len(token)-1] == '>' and \ - cluster_id == '' : - cluster_id = token.lstrip('<').rstrip('>') - cluster = [] - # general terms, like WEATHER / CAR / ... - else: - cluster.append(token) - print() - print( - '===========================================================================' - ) + compute_wer_main() diff --git a/utils/compute_mean_std.py b/utils/compute_mean_std.py index e47554dc..6e3fc0db 100755 --- a/utils/compute_mean_std.py +++ b/utils/compute_mean_std.py @@ -13,75 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Compute mean and std for feature normalizer, and save to file.""" -import argparse -import functools - -from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline -from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer -from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('num_samples', int, 2000, "# of samples to for statistics.") - -add_arg('spectrum_type', str, - 'linear', - "Audio feature type. 
Options: linear, mfcc, fbank.", - choices=['linear', 'mfcc', 'fbank']) -add_arg('feat_dim', int, 13, "Audio feature dim.") -add_arg('delta_delta', bool, False, "Audio feature with delta delta.") -add_arg('stride_ms', int, 10, "stride length in ms.") -add_arg('window_ms', int, 20, "stride length in ms.") -add_arg('sample_rate', int, 16000, "target sample rate.") -add_arg('use_dB_normalization', bool, True, "do dB normalization.") -add_arg('target_dB', int, -20, "target dB.") - -add_arg('manifest_path', str, - 'data/librispeech/manifest.train', - "Filepath of manifest to compute normalizer's mean and stddev.") -add_arg('num_workers', - default=0, - type=int, - help='num of subprocess workers for processing') -add_arg('output_path', str, - 'data/librispeech/mean_std.npz', - "Filepath of write mean and stddev to (.npz).") -# yapf: disable -args = parser.parse_args() - - -def main(): - print_arguments(args, globals()) - - augmentation_pipeline = AugmentationPipeline('{}') - audio_featurizer = AudioFeaturizer( - spectrum_type=args.spectrum_type, - feat_dim=args.feat_dim, - delta_delta=args.delta_delta, - stride_ms=float(args.stride_ms), - window_ms=float(args.window_ms), - n_fft=None, - max_freq=None, - target_sample_rate=args.sample_rate, - use_dB_normalization=args.use_dB_normalization, - target_dB=args.target_dB, - dither=0.0) - - def augment_and_featurize(audio_segment): - augmentation_pipeline.transform_audio(audio_segment) - return audio_featurizer.featurize(audio_segment) - - normalizer = FeatureNormalizer( - mean_std_filepath=None, - manifest_path=args.manifest_path, - featurize_func=augment_and_featurize, - num_samples=args.num_samples, - num_workers=args.num_workers) - normalizer.write_to_file(args.output_path) - +from paddlespeech.dataset.s2t import compute_mean_std_main if __name__ == '__main__': - main() + compute_mean_std_main() diff --git a/utils/format_data.py b/utils/format_data.py index 6db2a1bb..574cb735 100755 --- a/utils/format_data.py +++ b/utils/format_data.py @@ -13,130 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """format manifest with more metadata.""" -import argparse -import functools -import json - -import jsonlines - -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.frontend.utility import load_cmvn -from paddlespeech.s2t.io.utility import feat_type -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('cmvn_path', str, - 'examples/librispeech/data/mean_std.json', - "Filepath of cmvn.") -add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") -add_arg('vocab_path', str, - 'examples/librispeech/data/vocab.txt', - "Filepath of the vocabulary.") -add_arg('manifest_paths', str, - None, - "Filepaths of manifests for building vocabulary. 
" - "You can provide multiple manifest files.", - nargs='+', - required=True) -# bpe -add_arg('spm_model_prefix', str, None, - "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") -add_arg('output_path', str, None, "filepath of formated manifest.", required=True) -# yapf: disable -args = parser.parse_args() - - -def main(): - print_arguments(args, globals()) - fout = open(args.output_path, 'w', encoding='utf-8') - - # get feat dim - filetype = args.cmvn_path.split(".")[-1] - mean, istd = load_cmvn(args.cmvn_path, filetype=filetype) - feat_dim = mean.shape[0] #(D) - print(f"Feature dim: {feat_dim}") - - text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) - vocab_size = text_feature.vocab_size - print(f"Vocab size: {vocab_size}") - - # josnline like this - # { - # "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}], - # "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}], - # "utt2spk": "111-2222", - # "utt": "111-2222-333" - # } - count = 0 - for manifest_path in args.manifest_paths: - with jsonlines.open(str(manifest_path), 'r') as reader: - manifest_jsons = list(reader) - - for line_json in manifest_jsons: - output_json = { - "input": [], - "output": [], - 'utt': line_json['utt'], - 'utt2spk': line_json.get('utt2spk', 'global'), - } - - # output - line = line_json['text'] - if isinstance(line, str): - # only one target - tokens = text_feature.tokenize(line) - tokenids = text_feature.featurize(line) - output_json['output'].append({ - 'name': 'target1', - 'shape': (len(tokenids), vocab_size), - 'text': line, - 'token': ' '.join(tokens), - 'tokenid': ' '.join(map(str, tokenids)), - }) - else: - # isinstance(line, list), multi target in one vocab - for i, item in enumerate(line, 1): - tokens = text_feature.tokenize(item) - tokenids = text_feature.featurize(item) - output_json['output'].append({ - 'name': f'target{i}', - 'shape': (len(tokenids), vocab_size), - 'text': item, - 'token': ' '.join(tokens), - 'tokenid': ' '.join(map(str, tokenids)), - }) - - # input - line = line_json['feat'] - if isinstance(line, str): - # only one input - feat_shape = line_json['feat_shape'] - assert isinstance(feat_shape, (list, tuple)), type(feat_shape) - filetype = feat_type(line) - if filetype == 'sound': - feat_shape.append(feat_dim) - else: # kaldi - raise NotImplementedError('no support kaldi feat now!') - - output_json['input'].append({ - "name": "input1", - "shape": feat_shape, - "feat": line, - "filetype": filetype, - }) - else: - # isinstance(line, list), multi input - raise NotImplementedError("not support multi input now!") - - fout.write(json.dumps(output_json) + '\n') - count += 1 - - print(f"{args.manifest_paths} Examples number: {count}") - fout.close() - +from paddlespeech.dataset.s2t import format_data_main if __name__ == '__main__': - main() + format_data_main() diff --git a/utils/format_rsl.py b/utils/format_rsl.py index 8230416c..a6845a67 100644 --- a/utils/format_rsl.py +++ b/utils/format_rsl.py @@ -11,96 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import argparse +from paddlespeech.dataset.s2t import format_rsl_main -import jsonlines - - -def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None): - """ - Args: - origin_hyp: The input json file which contains the model output - trans_hyp: The output file for caculate CER/WER - trans_hyp_sclite: The output file for caculate CER/WER using sclite - """ - input_dict = {} - - with open(origin_hyp, "r+", encoding="utf8") as f: - for item in jsonlines.Reader(f): - input_dict[item["utt"]] = item["hyps"][0] - if trans_hyp is not None: - with open(trans_hyp, "w+", encoding="utf8") as f: - for key in input_dict.keys(): - f.write(key + " " + input_dict[key] + "\n") - if trans_hyp_sclite is not None: - with open(trans_hyp_sclite, "w+") as f: - for key in input_dict.keys(): - line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" - f.write(line) - - -def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None): - """ - Args: - origin_hyp: The input json file which contains the model output - trans_hyp: The output file for caculate CER/WER - trans_hyp_sclite: The output file for caculate CER/WER using sclite - """ - input_dict = {} - - with open(origin_ref, "r", encoding="utf8") as f: - for item in jsonlines.Reader(f): - input_dict[item["utt"]] = item["text"] - if trans_ref is not None: - with open(trans_ref, "w", encoding="utf8") as f: - for key in input_dict.keys(): - f.write(key + " " + input_dict[key] + "\n") - - if trans_ref_sclite is not None: - with open(trans_ref_sclite, "w") as f: - for key in input_dict.keys(): - line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" - f.write(line) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog='format hyp file for compute CER/WER', add_help=True) - parser.add_argument( - '--origin_hyp', type=str, default=None, help='origin hyp file') - parser.add_argument( - '--trans_hyp', - type=str, - default=None, - help='hyp file for caculating CER/WER') - parser.add_argument( - '--trans_hyp_sclite', - type=str, - default=None, - help='hyp file for caculating CER/WER by sclite') - - parser.add_argument( - '--origin_ref', type=str, default=None, help='origin ref file') - parser.add_argument( - '--trans_ref', - type=str, - default=None, - help='ref file for caculating CER/WER') - parser.add_argument( - '--trans_ref_sclite', - type=str, - default=None, - help='ref file for caculating CER/WER by sclite') - parser_args = parser.parse_args() - - if parser_args.origin_hyp is not None: - trans_hyp( - origin_hyp=parser_args.origin_hyp, - trans_hyp=parser_args.trans_hyp, - trans_hyp_sclite=parser_args.trans_hyp_sclite, ) - - if parser_args.origin_ref is not None: - trans_ref( - origin_ref=parser_args.origin_ref, - trans_ref=parser_args.trans_ref, - trans_ref_sclite=parser_args.trans_ref_sclite, ) +if __name__ == '__main__': + format_rsl_main() diff --git a/utils/format_triplet_data.py b/utils/format_triplet_data.py index 44ff4527..e9a0cf54 100755 --- a/utils/format_triplet_data.py +++ b/utils/format_triplet_data.py @@ -22,8 +22,8 @@ import jsonlines from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.io.utility import feat_type -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments parser = argparse.ArgumentParser(description=__doc__) add_arg = 
functools.partial(add_arguments, argparser=parser)
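The hunks above repeat one pattern: the argparse helpers move to paddlespeech.utils.argparse, and each utils/*.py script becomes a thin wrapper around an entry point in paddlespeech.dataset.s2t. Below is a minimal sketch of how a caller uses the relocated helpers, assuming they keep the signatures shown in the removed paddlespeech/s2t/utils/utility.py code; the 'unit_type' argument is illustrative only, not part of this patch.

import argparse
import functools

from paddlespeech.utils.argparse import add_arguments
from paddlespeech.utils.argparse import print_arguments

parser = argparse.ArgumentParser(description=__doc__)
# add_arguments(argname, type, default, help, argparser) registers "--<argname>";
# binding argparser once keeps the per-argument calls short.
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm")

if __name__ == '__main__':
    args = parser.parse_args()
    # print_arguments(args, globals()) reads __file__ from the passed dict and prints
    # the script name plus every parsed argument, matching the banner produced by the
    # removed utility.py implementation.
    print_arguments(args, globals())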