From bac9e0b153ea94f19156345a6bc73153a2e5f7b3 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Fri, 21 Apr 2023 07:18:38 +0000 Subject: [PATCH] move s2t data preprocess into paddlespeech.dataset --- examples/aishell/asr1/path.sh | 2 + .../aidatatang_200zh/aidatatang_200zh.py | 3 +- paddlespeech/dataset/aishell/aishell.py | 3 +- paddlespeech/dataset/s2t/__init__.py | 17 ++ paddlespeech/dataset/s2t/build_vocab.py | 166 ++++++++++++++++++ paddlespeech/dataset/s2t/compute_mean_std.py | 106 +++++++++++ paddlespeech/dataset/s2t/format_data.py | 155 ++++++++++++++++ .../exps/deepspeech2/bin/deploy/runtime.py | 4 +- .../s2t/exps/deepspeech2/bin/deploy/server.py | 4 +- .../frontend/featurizer/text_featurizer.py | 17 +- paddlespeech/s2t/utils/utility.py | 50 +----- paddlespeech/utils/argparse.py | 2 + utils/build_vocab.py | 131 +------------- utils/compute_mean_std.py | 72 +------- utils/format_data.py | 127 +------------- utils/format_triplet_data.py | 4 +- 16 files changed, 477 insertions(+), 386 deletions(-) create mode 100644 paddlespeech/dataset/s2t/__init__.py create mode 100755 paddlespeech/dataset/s2t/build_vocab.py create mode 100755 paddlespeech/dataset/s2t/compute_mean_std.py create mode 100755 paddlespeech/dataset/s2t/format_data.py diff --git a/examples/aishell/asr1/path.sh b/examples/aishell/asr1/path.sh index 449829109..c6eed668e 100644 --- a/examples/aishell/asr1/path.sh +++ b/examples/aishell/asr1/path.sh @@ -27,3 +27,5 @@ export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, can not using Kaldi!" [ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh + +unset GREP_OPTIONS diff --git a/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py b/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py index ba1785672..5d914a438 100644 --- a/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py +++ b/paddlespeech/dataset/aidatatang_200zh/aidatatang_200zh.py @@ -28,6 +28,7 @@ import soundfile from paddlespeech.dataset.download import download from paddlespeech.dataset.download import unpack +from paddlespeech.utils.argparse import print_arguments DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') @@ -139,7 +140,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path, subset): def main(): - print(f"args: {args}") + print_arguments(args, globals()) if args.target_dir.startswith('~'): args.target_dir = os.path.expanduser(args.target_dir) diff --git a/paddlespeech/dataset/aishell/aishell.py b/paddlespeech/dataset/aishell/aishell.py index fa90aa67b..7ea4d6766 100644 --- a/paddlespeech/dataset/aishell/aishell.py +++ b/paddlespeech/dataset/aishell/aishell.py @@ -28,6 +28,7 @@ import soundfile from paddlespeech.dataset.download import download from paddlespeech.dataset.download import unpack +from paddlespeech.utils.argparse import print_arguments DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') @@ -205,7 +206,7 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path=None, check=False): def main(): - print(f"args: {args}") + print_arguments(args, globals()) if args.target_dir.startswith('~'): args.target_dir = os.path.expanduser(args.target_dir) diff --git a/paddlespeech/dataset/s2t/__init__.py b/paddlespeech/dataset/s2t/__init__.py new file mode 100644 index 000000000..3f546855e --- /dev/null +++ 
b/paddlespeech/dataset/s2t/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# s2t utils binaries. +from .build_vocab import main as build_vocab_main +from .compute_mean_std import main as compute_mean_std_main +from .format_data import main as format_data_main diff --git a/paddlespeech/dataset/s2t/build_vocab.py b/paddlespeech/dataset/s2t/build_vocab.py new file mode 100755 index 000000000..dd5f62081 --- /dev/null +++ b/paddlespeech/dataset/s2t/build_vocab.py @@ -0,0 +1,166 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Build vocabulary from manifest files. +Each item in vocabulary file is a character. 
+""" +import argparse +import functools +import os +import tempfile +from collections import Counter + +import jsonlines + +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.frontend.utility import BLANK +from paddlespeech.s2t.frontend.utility import SOS +from paddlespeech.s2t.frontend.utility import SPACE +from paddlespeech.s2t.frontend.utility import UNK +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments + + +def count_manifest(counter, text_feature, manifest_path): + manifest_jsons = [] + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + manifest_jsons.append(json_data) + + for line_json in manifest_jsons: + if isinstance(line_json['text'], str): + tokens = text_feature.tokenize( + line_json['text'], replace_space=False) + + counter.update(tokens) + else: + assert isinstance(line_json['text'], list) + for text in line_json['text']: + tokens = text_feature.tokenize(text, replace_space=False) + counter.update(tokens) + + +def dump_text_manifest(fileobj, manifest_path, key='text'): + manifest_jsons = [] + with jsonlines.open(manifest_path, 'r') as reader: + for json_data in reader: + manifest_jsons.append(json_data) + + for line_json in manifest_jsons: + if isinstance(line_json[key], str): + fileobj.write(line_json[key] + "\n") + else: + assert isinstance(line_json[key], list) + for line in line_json[key]: + fileobj.write(line + "\n") + + +def build_vocab(manifest_paths="", + vocab_path="examples/librispeech/data/vocab.txt", + unit_type="char", + count_threshold=0, + text_keys='text', + spm_mode="unigram", + spm_vocab_size=0, + spm_model_prefix="", + spm_character_coverage=0.9995): + fout = open(vocab_path, 'w', encoding='utf-8') + fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC + fout.write(UNK + '\n') # must be 1 + + if unit_type == 'spm': + # tools/spm_train --input=$wave_data/lang_char/input.txt + # --vocab_size=${nbpe} --model_type=${bpemode} + # --model_prefix=${bpemodel} --input_sentence_size=100000000 + import sentencepiece as spm + + fp = tempfile.NamedTemporaryFile(mode='w', delete=False) + for manifest_path in manifest_paths: + _text_keys = [text_keys] if type( + text_keys) is not list else text_keys + for text_key in _text_keys: + dump_text_manifest(fp, manifest_path, key=text_key) + fp.close() + # train + spm.SentencePieceTrainer.Train( + input=fp.name, + vocab_size=spm_vocab_size, + model_type=spm_mode, + model_prefix=spm_model_prefix, + input_sentence_size=100000000, + character_coverage=spm_character_coverage) + os.unlink(fp.name) + + # encode + text_feature = TextFeaturizer(unit_type, "", spm_model_prefix) + counter = Counter() + + for manifest_path in manifest_paths: + count_manifest(counter, text_feature, manifest_path) + + count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) + tokens = [] + for token, count in count_sorted: + if count < count_threshold: + break + # replace space by `` + token = SPACE if token == ' ' else token + tokens.append(token) + + tokens = sorted(tokens) + for token in tokens: + fout.write(token + '\n') + + fout.write(SOS + "\n") # + fout.close() + + +def define_argparse(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + + # yapf: disable + add_arg('unit_type', str, "char", "Unit type, e.g. 
char, word, spm") + add_arg('count_threshold', int, 0, + "Truncation threshold for char/word counts.Default 0, no truncate.") + add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath to write the vocabulary.") + add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) + add_arg('text_keys', str, + 'text', + "keys of the text in manifest for building vocabulary. " + "You can provide multiple k.", + nargs='+') + # bpe + add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") + add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm") + add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm") + add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols") + # yapf: disable + + args = parser.parse_args() + return args + +def main(): + args = define_argparse() + print_arguments(args, globals()) + build_vocab(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/compute_mean_std.py b/paddlespeech/dataset/s2t/compute_mean_std.py new file mode 100755 index 000000000..8762ee57e --- /dev/null +++ b/paddlespeech/dataset/s2t/compute_mean_std.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Compute mean and std for feature normalizer, and save to file.""" +import argparse +import functools + +from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline +from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer +from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments + + +def compute_cmvn(manifest_path="data/librispeech/manifest.train", + output_path="data/librispeech/mean_std.npz", + num_samples=2000, + num_workers=0, + spectrum_type="linear", + feat_dim=13, + delta_delta=False, + stride_ms=10, + window_ms=20, + sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + + augmentation_pipeline = AugmentationPipeline('{}') + audio_featurizer = AudioFeaturizer( + spectrum_type=spectrum_type, + feat_dim=feat_dim, + delta_delta=delta_delta, + stride_ms=float(stride_ms), + window_ms=float(window_ms), + n_fft=None, + max_freq=None, + target_sample_rate=sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB, + dither=0.0) + + def augment_and_featurize(audio_segment): + augmentation_pipeline.transform_audio(audio_segment) + return audio_featurizer.featurize(audio_segment) + + normalizer = FeatureNormalizer( + mean_std_filepath=None, + manifest_path=manifest_path, + featurize_func=augment_and_featurize, + num_samples=num_samples, + num_workers=num_workers) + normalizer.write_to_file(output_path) + + +def define_argparse(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + + # yapf: disable + add_arg('manifest_path', str, + 'data/librispeech/manifest.train', + "Filepath of manifest to compute normalizer's mean and stddev.") + + add_arg('output_path', str, + 'data/librispeech/mean_std.npz', + "Filepath of write mean and stddev to (.npz).") + add_arg('num_samples', int, 2000, "# of samples to for statistics.") + add_arg('num_workers', + default=0, + type=int, + help='num of subprocess workers for processing') + + + add_arg('spectrum_type', str, + 'linear', + "Audio feature type. Options: linear, mfcc, fbank.", + choices=['linear', 'mfcc', 'fbank']) + add_arg('feat_dim', int, 13, "Audio feature dim.") + add_arg('delta_delta', bool, False, "Audio feature with delta delta.") + add_arg('stride_ms', int, 10, "stride length in ms.") + add_arg('window_ms', int, 20, "stride length in ms.") + add_arg('sample_rate', int, 16000, "target sample rate.") + add_arg('use_dB_normalization', bool, True, "do dB normalization.") + add_arg('target_dB', int, -20, "target dB.") + # yapf: disable + + args = parser.parse_args() + return args + +def main(): + args = define_argparse() + print_arguments(args, globals()) + compute_cmvn(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/dataset/s2t/format_data.py b/paddlespeech/dataset/s2t/format_data.py new file mode 100755 index 000000000..fae717dc5 --- /dev/null +++ b/paddlespeech/dataset/s2t/format_data.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""format manifest with more metadata.""" +import argparse +import functools +import json + +import jsonlines + +from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer +from paddlespeech.s2t.frontend.utility import load_cmvn +from paddlespeech.s2t.io.utility import feat_type +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments + + +def define_argparse(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + # yapf: disable + add_arg('manifest_paths', str, + None, + "Filepaths of manifests for building vocabulary. " + "You can provide multiple manifest files.", + nargs='+', + required=True) + add_arg('output_path', str, None, "filepath of formated manifest.", required=True) + add_arg('cmvn_path', str, + 'examples/librispeech/data/mean_std.json', + "Filepath of cmvn.") + add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") + add_arg('vocab_path', str, + 'examples/librispeech/data/vocab.txt', + "Filepath of the vocabulary.") + # bpe + add_arg('spm_model_prefix', str, None, + "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") + + # yapf: disable + args = parser.parse_args() + return args + +def format_data( + manifest_paths="", + output_path="", + cmvn_path="examples/librispeech/data/mean_std.json", + unit_type="char", + vocab_path="examples/librispeech/data/vocab.txt", + spm_model_prefix=""): + + fout = open(output_path, 'w', encoding='utf-8') + + # get feat dim + filetype = cmvn_path.split(".")[-1] + mean, istd = load_cmvn(cmvn_path, filetype=filetype) + feat_dim = mean.shape[0] #(D) + print(f"Feature dim: {feat_dim}") + + text_feature = TextFeaturizer(unit_type, vocab_path, spm_model_prefix) + vocab_size = text_feature.vocab_size + print(f"Vocab size: {vocab_size}") + + # josnline like this + # { + # "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}], + # "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}], + # "utt2spk": "111-2222", + # "utt": "111-2222-333" + # } + count = 0 + for manifest_path in manifest_paths: + with jsonlines.open(str(manifest_path), 'r') as reader: + manifest_jsons = list(reader) + + for line_json in manifest_jsons: + output_json = { + "input": [], + "output": [], + 'utt': line_json['utt'], + 'utt2spk': line_json.get('utt2spk', 'global'), + } + + # output + line = line_json['text'] + if isinstance(line, str): + # only one target + tokens = text_feature.tokenize(line) + tokenids = text_feature.featurize(line) + output_json['output'].append({ + 'name': 'target1', + 'shape': (len(tokenids), vocab_size), + 'text': line, + 'token': ' '.join(tokens), + 'tokenid': ' '.join(map(str, tokenids)), + }) + else: + # isinstance(line, list), multi target in one vocab + for i, item in enumerate(line, 1): + tokens = text_feature.tokenize(item) + tokenids = text_feature.featurize(item) + output_json['output'].append({ + 'name': f'target{i}', + 'shape': (len(tokenids), vocab_size), + 'text': item, + 'token': ' '.join(tokens), 
+ 'tokenid': ' '.join(map(str, tokenids)), + }) + + # input + line = line_json['feat'] + if isinstance(line, str): + # only one input + feat_shape = line_json['feat_shape'] + assert isinstance(feat_shape, (list, tuple)), type(feat_shape) + filetype = feat_type(line) + if filetype == 'sound': + feat_shape.append(feat_dim) + else: # kaldi + raise NotImplementedError('no support kaldi feat now!') + + output_json['input'].append({ + "name": "input1", + "shape": feat_shape, + "feat": line, + "filetype": filetype, + }) + else: + # isinstance(line, list), multi input + raise NotImplementedError("not support multi input now!") + + fout.write(json.dumps(output_json) + '\n') + count += 1 + + print(f"{manifest_paths} Examples number: {count}") + fout.close() + +def main(): + args = define_argparse() + print_arguments(args, globals()) + format_data(**vars(args)) + +if __name__ == '__main__': + main() diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py index 5755a5f10..f6b1ed096 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/runtime.py @@ -28,8 +28,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.socket_server import AsrRequestHandler from paddlespeech.s2t.utils.socket_server import AsrTCPServer from paddlespeech.s2t.utils.socket_server import warm_up_test -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments def init_predictor(args): diff --git a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py index 0d0b4f219..fc57399df 100644 --- a/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py +++ b/paddlespeech/s2t/exps/deepspeech2/bin/deploy/server.py @@ -26,8 +26,8 @@ from paddlespeech.s2t.training.cli import default_argument_parser from paddlespeech.s2t.utils.socket_server import AsrRequestHandler from paddlespeech.s2t.utils.socket_server import AsrTCPServer from paddlespeech.s2t.utils.socket_server import warm_up_test -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments def start_server(config, args): diff --git a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py index 982c6b8fe..7623d0b87 100644 --- a/paddlespeech/s2t/frontend/featurizer/text_featurizer.py +++ b/paddlespeech/s2t/frontend/featurizer/text_featurizer.py @@ -48,13 +48,16 @@ class TextFeaturizer(): self.unit_type = unit_type self.unk = UNK self.maskctc = maskctc + self.vocab_path_or_list = vocab - if vocab: + if self.vocab_path_or_list: self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file( vocab, maskctc) self.vocab_size = len(self.vocab_list) else: - logger.warning("TextFeaturizer: not have vocab file or vocab list.") + logger.warning( + "TextFeaturizer: not have vocab file or vocab list. 
Only Tokenizer can use, can not convert to token idx" + ) if unit_type == 'spm': spm_model = spm_model_prefix + '.model' @@ -62,6 +65,7 @@ class TextFeaturizer(): self.sp.Load(spm_model) def tokenize(self, text, replace_space=True): + """tokenizer split text into text tokens""" if self.unit_type == 'char': tokens = self.char_tokenize(text, replace_space) elif self.unit_type == 'word': @@ -71,6 +75,7 @@ class TextFeaturizer(): return tokens def detokenize(self, tokens): + """tokenizer convert text tokens back to text""" if self.unit_type == 'char': text = self.char_detokenize(tokens) elif self.unit_type == 'word': @@ -88,6 +93,7 @@ class TextFeaturizer(): Returns: List[int]: List of token indices. """ + assert self.vocab_path_or_list, "toidx need vocab path or vocab list" tokens = self.tokenize(text) ids = [] for token in tokens: @@ -107,6 +113,7 @@ class TextFeaturizer(): Returns: str: Text. """ + assert self.vocab_path_or_list, "toidx need vocab path or vocab list" tokens = [] for idx in idxs: if idx == self.eos_id: @@ -127,10 +134,10 @@ class TextFeaturizer(): """ text = text.strip() if replace_space: - text_list = [SPACE if item == " " else item for item in list(text)] + tokens = [SPACE if item == " " else item for item in list(text)] else: - text_list = list(text) - return text_list + tokens = list(text) + return tokens def char_detokenize(self, tokens): """Character detokenizer. diff --git a/paddlespeech/s2t/utils/utility.py b/paddlespeech/s2t/utils/utility.py index d7e7c6ca2..5655ec3fb 100644 --- a/paddlespeech/s2t/utils/utility.py +++ b/paddlespeech/s2t/utils/utility.py @@ -29,10 +29,7 @@ from paddlespeech.s2t.utils.log import Log logger = Log(__name__).getlog() -__all__ = [ - "all_version", "UpdateConfig", "seed_all", 'print_arguments', - 'add_arguments', "log_add" -] +__all__ = ["all_version", "UpdateConfig", "seed_all", "log_add"] def all_version(): @@ -60,51 +57,6 @@ def seed_all(seed: int=20210329): paddle.seed(seed) -def print_arguments(args, info=None): - """Print argparse's arguments. - - Usage: - - .. code-block:: python - - parser = argparse.ArgumentParser() - parser.add_argument("name", default="Jonh", type=str, help="User name.") - args = parser.parse_args() - print_arguments(args) - - :param args: Input argparse.Namespace for printing. - :type args: argparse.Namespace - """ - filename = "" - if info: - filename = info["__file__"] - filename = os.path.basename(filename) - print(f"----------- {filename} Arguments -----------") - for arg, value in sorted(vars(args).items()): - print("%s: %s" % (arg, value)) - print("-----------------------------------------------------------") - - -def add_arguments(argname, type, default, help, argparser, **kwargs): - """Add argparse's argument. - - Usage: - - .. 
code-block:: python - - parser = argparse.ArgumentParser() - add_argument("name", str, "Jonh", "User name.", parser) - args = parser.parse_args() - """ - type = distutils.util.strtobool if type == bool else type - argparser.add_argument( - "--" + argname, - default=default, - type=type, - help=help + ' Default: %(default)s.', - **kwargs) - - def log_add(args: List[int]) -> float: """Stable log add diff --git a/paddlespeech/utils/argparse.py b/paddlespeech/utils/argparse.py index 4df75c5ae..aad3801ea 100644 --- a/paddlespeech/utils/argparse.py +++ b/paddlespeech/utils/argparse.py @@ -16,6 +16,8 @@ import os import sys from typing import Text +import distutils + __all__ = ["print_arguments", "add_arguments", "get_commandline_args"] diff --git a/utils/build_vocab.py b/utils/build_vocab.py index e364e821e..9b29dfa54 100755 --- a/utils/build_vocab.py +++ b/utils/build_vocab.py @@ -15,134 +15,7 @@ """Build vocabulary from manifest files. Each item in vocabulary file is a character. """ -import argparse -import functools -import os -import tempfile -from collections import Counter - -import jsonlines - -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.frontend.utility import BLANK -from paddlespeech.s2t.frontend.utility import SOS -from paddlespeech.s2t.frontend.utility import SPACE -from paddlespeech.s2t.frontend.utility import UNK -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") -add_arg('count_threshold', int, 0, - "Truncation threshold for char/word counts.Default 0, no truncate.") -add_arg('vocab_path', str, - 'examples/librispeech/data/vocab.txt', - "Filepath to write the vocabulary.") -add_arg('manifest_paths', str, - None, - "Filepaths of manifests for building vocabulary. " - "You can provide multiple manifest files.", - nargs='+', - required=True) -add_arg('text_keys', str, - 'text', - "keys of the text in manifest for building vocabulary. " - "You can provide multiple k.", - nargs='+') -# bpe -add_arg('spm_vocab_size', int, 0, "Vocab size for spm.") -add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. 
only need when `unit_type` is spm") -add_arg('spm_model_prefix', str, "", "spm_model_%(spm_mode)_%(count_threshold), spm model prefix, only need when `unit_type` is spm") -add_arg('spm_character_coverage', float, 0.9995, "character coverage to determine the minimum symbols") - -# yapf: disable -args = parser.parse_args() - - -def count_manifest(counter, text_feature, manifest_path): - manifest_jsons = [] - with jsonlines.open(manifest_path, 'r') as reader: - for json_data in reader: - manifest_jsons.append(json_data) - - for line_json in manifest_jsons: - if isinstance(line_json['text'], str): - line = text_feature.tokenize(line_json['text'], replace_space=False) - counter.update(line) - else: - assert isinstance(line_json['text'], list) - for text in line_json['text']: - line = text_feature.tokenize(text, replace_space=False) - counter.update(line) - -def dump_text_manifest(fileobj, manifest_path, key='text'): - manifest_jsons = [] - with jsonlines.open(manifest_path, 'r') as reader: - for json_data in reader: - manifest_jsons.append(json_data) - - for line_json in manifest_jsons: - if isinstance(line_json[key], str): - fileobj.write(line_json[key] + "\n") - else: - assert isinstance(line_json[key], list) - for line in line_json[key]: - fileobj.write(line + "\n") - -def main(): - print_arguments(args, globals()) - - fout = open(args.vocab_path, 'w', encoding='utf-8') - fout.write(BLANK + "\n") # 0 will be used for "blank" in CTC - fout.write(UNK + '\n') # must be 1 - - if args.unit_type == 'spm': - # tools/spm_train --input=$wave_data/lang_char/input.txt - # --vocab_size=${nbpe} --model_type=${bpemode} - # --model_prefix=${bpemodel} --input_sentence_size=100000000 - import sentencepiece as spm - - fp = tempfile.NamedTemporaryFile(mode='w', delete=False) - for manifest_path in args.manifest_paths: - text_keys = [args.text_keys] if type(args.text_keys) is not list else args.text_keys - for text_key in text_keys: - dump_text_manifest(fp, manifest_path, key=text_key) - fp.close() - # train - spm.SentencePieceTrainer.Train( - input=fp.name, - vocab_size=args.spm_vocab_size, - model_type=args.spm_mode, - model_prefix=args.spm_model_prefix, - input_sentence_size=100000000, - character_coverage=args.spm_character_coverage) - os.unlink(fp.name) - - # encode - text_feature = TextFeaturizer(args.unit_type, "", args.spm_model_prefix) - counter = Counter() - - for manifest_path in args.manifest_paths: - count_manifest(counter, text_feature, manifest_path) - - count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) - tokens = [] - for token, count in count_sorted: - if count < args.count_threshold: - break - # replace space by `` - token = SPACE if token == ' ' else token - tokens.append(token) - - tokens = sorted(tokens) - for token in tokens: - fout.write(token + '\n') - - fout.write(SOS + "\n") # - fout.close() - +from paddlespeech.dataset.s2t import build_vocab_main if __name__ == '__main__': - main() + build_vocab_main() diff --git a/utils/compute_mean_std.py b/utils/compute_mean_std.py index e47554dca..6e3fc0db2 100755 --- a/utils/compute_mean_std.py +++ b/utils/compute_mean_std.py @@ -13,75 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Compute mean and std for feature normalizer, and save to file.""" -import argparse -import functools - -from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline -from paddlespeech.s2t.frontend.featurizer.audio_featurizer import AudioFeaturizer -from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('num_samples', int, 2000, "# of samples to for statistics.") - -add_arg('spectrum_type', str, - 'linear', - "Audio feature type. Options: linear, mfcc, fbank.", - choices=['linear', 'mfcc', 'fbank']) -add_arg('feat_dim', int, 13, "Audio feature dim.") -add_arg('delta_delta', bool, False, "Audio feature with delta delta.") -add_arg('stride_ms', int, 10, "stride length in ms.") -add_arg('window_ms', int, 20, "stride length in ms.") -add_arg('sample_rate', int, 16000, "target sample rate.") -add_arg('use_dB_normalization', bool, True, "do dB normalization.") -add_arg('target_dB', int, -20, "target dB.") - -add_arg('manifest_path', str, - 'data/librispeech/manifest.train', - "Filepath of manifest to compute normalizer's mean and stddev.") -add_arg('num_workers', - default=0, - type=int, - help='num of subprocess workers for processing') -add_arg('output_path', str, - 'data/librispeech/mean_std.npz', - "Filepath of write mean and stddev to (.npz).") -# yapf: disable -args = parser.parse_args() - - -def main(): - print_arguments(args, globals()) - - augmentation_pipeline = AugmentationPipeline('{}') - audio_featurizer = AudioFeaturizer( - spectrum_type=args.spectrum_type, - feat_dim=args.feat_dim, - delta_delta=args.delta_delta, - stride_ms=float(args.stride_ms), - window_ms=float(args.window_ms), - n_fft=None, - max_freq=None, - target_sample_rate=args.sample_rate, - use_dB_normalization=args.use_dB_normalization, - target_dB=args.target_dB, - dither=0.0) - - def augment_and_featurize(audio_segment): - augmentation_pipeline.transform_audio(audio_segment) - return audio_featurizer.featurize(audio_segment) - - normalizer = FeatureNormalizer( - mean_std_filepath=None, - manifest_path=args.manifest_path, - featurize_func=augment_and_featurize, - num_samples=args.num_samples, - num_workers=args.num_workers) - normalizer.write_to_file(args.output_path) - +from paddlespeech.dataset.s2t import compute_mean_std_main if __name__ == '__main__': - main() + compute_mean_std_main() diff --git a/utils/format_data.py b/utils/format_data.py index 6db2a1bbb..574cb735b 100755 --- a/utils/format_data.py +++ b/utils/format_data.py @@ -13,130 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""format manifest with more metadata.""" -import argparse -import functools -import json - -import jsonlines - -from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer -from paddlespeech.s2t.frontend.utility import load_cmvn -from paddlespeech.s2t.io.utility import feat_type -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments - -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('cmvn_path', str, - 'examples/librispeech/data/mean_std.json', - "Filepath of cmvn.") -add_arg('unit_type', str, "char", "Unit type, e.g. char, word, spm") -add_arg('vocab_path', str, - 'examples/librispeech/data/vocab.txt', - "Filepath of the vocabulary.") -add_arg('manifest_paths', str, - None, - "Filepaths of manifests for building vocabulary. " - "You can provide multiple manifest files.", - nargs='+', - required=True) -# bpe -add_arg('spm_model_prefix', str, None, - "spm model prefix, spm_model_%(bpe_mode)_%(count_threshold), only need when `unit_type` is spm") -add_arg('output_path', str, None, "filepath of formated manifest.", required=True) -# yapf: disable -args = parser.parse_args() - - -def main(): - print_arguments(args, globals()) - fout = open(args.output_path, 'w', encoding='utf-8') - - # get feat dim - filetype = args.cmvn_path.split(".")[-1] - mean, istd = load_cmvn(args.cmvn_path, filetype=filetype) - feat_dim = mean.shape[0] #(D) - print(f"Feature dim: {feat_dim}") - - text_feature = TextFeaturizer(args.unit_type, args.vocab_path, args.spm_model_prefix) - vocab_size = text_feature.vocab_size - print(f"Vocab size: {vocab_size}") - - # josnline like this - # { - # "input": [{"name": "input1", "shape": (100, 83), "feat": "xxx.ark:123"}], - # "output": [{"name":"target1", "shape": (40, 5002), "text": "a b c de"}], - # "utt2spk": "111-2222", - # "utt": "111-2222-333" - # } - count = 0 - for manifest_path in args.manifest_paths: - with jsonlines.open(str(manifest_path), 'r') as reader: - manifest_jsons = list(reader) - - for line_json in manifest_jsons: - output_json = { - "input": [], - "output": [], - 'utt': line_json['utt'], - 'utt2spk': line_json.get('utt2spk', 'global'), - } - - # output - line = line_json['text'] - if isinstance(line, str): - # only one target - tokens = text_feature.tokenize(line) - tokenids = text_feature.featurize(line) - output_json['output'].append({ - 'name': 'target1', - 'shape': (len(tokenids), vocab_size), - 'text': line, - 'token': ' '.join(tokens), - 'tokenid': ' '.join(map(str, tokenids)), - }) - else: - # isinstance(line, list), multi target in one vocab - for i, item in enumerate(line, 1): - tokens = text_feature.tokenize(item) - tokenids = text_feature.featurize(item) - output_json['output'].append({ - 'name': f'target{i}', - 'shape': (len(tokenids), vocab_size), - 'text': item, - 'token': ' '.join(tokens), - 'tokenid': ' '.join(map(str, tokenids)), - }) - - # input - line = line_json['feat'] - if isinstance(line, str): - # only one input - feat_shape = line_json['feat_shape'] - assert isinstance(feat_shape, (list, tuple)), type(feat_shape) - filetype = feat_type(line) - if filetype == 'sound': - feat_shape.append(feat_dim) - else: # kaldi - raise NotImplementedError('no support kaldi feat now!') - - output_json['input'].append({ - "name": "input1", - "shape": feat_shape, - "feat": line, - "filetype": filetype, - }) - else: - # isinstance(line, list), multi input - raise 
NotImplementedError("not support multi input now!") - - fout.write(json.dumps(output_json) + '\n') - count += 1 - - print(f"{args.manifest_paths} Examples number: {count}") - fout.close() - +from paddlespeech.dataset.s2t import format_data_main if __name__ == '__main__': - main() + format_data_main() diff --git a/utils/format_triplet_data.py b/utils/format_triplet_data.py index 44ff4527c..e9a0cf54c 100755 --- a/utils/format_triplet_data.py +++ b/utils/format_triplet_data.py @@ -22,8 +22,8 @@ import jsonlines from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.utility import load_cmvn from paddlespeech.s2t.io.utility import feat_type -from paddlespeech.s2t.utils.utility import add_arguments -from paddlespeech.s2t.utils.utility import print_arguments +from paddlespeech.utils.argparse import add_arguments +from paddlespeech.utils.argparse import print_arguments parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser)