From 1325cd9b8ed0d2d12042cdd0aaad9a7087ded162 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 9 Aug 2017 16:21:44 +0800 Subject: [PATCH 1/2] Create 'tools' to hold tool scripts and add vocabulary dictionary building script. --- README.md | 6 +- tools/_init_paths.py | 16 +++++ tools/build_vocab.py | 63 +++++++++++++++++++ .../compute_mean_std.py | 1 + 4 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 tools/_init_paths.py create mode 100644 tools/build_vocab.py rename compute_mean_std.py => tools/compute_mean_std.py (99%) diff --git a/README.md b/README.md index 96fbb7d0..9d39903b 100644 --- a/README.md +++ b/README.md @@ -40,13 +40,13 @@ python datasets/librispeech/librispeech.py --help ### Preparing for Training ``` -python compute_mean_std.py +python tools/compute_mean_std.py ``` It will compute mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, and the mfcc feature is also supported. To train and infer based on mfcc feature, please generate this file by ``` -python compute_mean_std.py --specgram_type mfcc +python tools/compute_mean_std.py --specgram_type mfcc ``` and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluator.py or tune.py. @@ -54,7 +54,7 @@ and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluato More help for arguments: ``` -python compute_mean_std.py --help +python tools/compute_mean_std.py --help ``` ### Training diff --git a/tools/_init_paths.py b/tools/_init_paths.py new file mode 100644 index 00000000..3bb2fd19 --- /dev/null +++ b/tools/_init_paths.py @@ -0,0 +1,16 @@ +"""Set up paths for DS2""" + +import os.path +import sys + + +def add_path(path): + if path not in sys.path: + sys.path.insert(0, path) + + +this_dir = os.path.dirname(__file__) + +# Add project path to PYTHONPATH +proj_path = os.path.join(this_dir, '..') +add_path(proj_path) diff --git a/tools/build_vocab.py b/tools/build_vocab.py new file mode 100644 index 00000000..59be4031 --- /dev/null +++ b/tools/build_vocab.py @@ -0,0 +1,63 @@ +"""Build vocabulary dictionary from manifest files. + +Each item in vocabulary file is a character. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import codecs +import json +from collections import Counter +import os.path + +parser = argparse.ArgumentParser( + description='Build vocabulary dictionary from transcription texts.') +parser.add_argument( + "--manifest_paths", + type=str, + help="Manifest paths for building vocabulary dictionary." + "You can provide multiple manifest files.", + nargs='+', + required=True) +parser.add_argument( + "--count_threshold", + default=0, + type=int, + help="Characters whose count below the threshold will be truncated. " + "(default: %(default)s)") +parser.add_argument( + "--vocab_path", + default='datasets/vocab/zh_vocab.txt', + type=str, + help="Filepath to write vocabularies. (default: %(default)s)") +args = parser.parse_args() + + +def count_manifest(counter, manifest_path): + for json_line in codecs.open(manifest_path, 'r', 'utf-8'): + try: + json_data = json.loads(json_line) + except Exception as e: + raise Exception('Error parsing manifest: %s, %s' % \ + (manifest_path, e)) + text = json_data['text'] + for char in text: + counter.update(char) + + +def main(): + counter = Counter() + for manifest_path in args.manifest_paths: + count_manifest(counter, manifest_path) + + count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) + with codecs.open(args.vocab_path, 'w', 'utf-8') as fout: + for item_pair in count_sorted: + if item_pair[1] < args.count_threshold: break + fout.write(item_pair[0] + '\n') + + +if __name__ == '__main__': + main() diff --git a/compute_mean_std.py b/tools/compute_mean_std.py similarity index 99% rename from compute_mean_std.py rename to tools/compute_mean_std.py index 0cc84e73..da49eb4c 100644 --- a/compute_mean_std.py +++ b/tools/compute_mean_std.py @@ -4,6 +4,7 @@ from __future__ import division from __future__ import print_function import argparse +import _init_paths from data_utils.normalizer import FeatureNormalizer from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.featurizer.audio_featurizer import AudioFeaturizer From c2e6378a64b1526076e4fb99aa6f9228d25891c8 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 9 Aug 2017 23:03:30 +0800 Subject: [PATCH 2/2] Simplify codes and comments. --- tools/_init_paths.py | 3 +++ tools/build_vocab.py | 32 ++++++++++++++------------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/tools/_init_paths.py b/tools/_init_paths.py index 3bb2fd19..ddabb535 100644 --- a/tools/_init_paths.py +++ b/tools/_init_paths.py @@ -1,4 +1,7 @@ """Set up paths for DS2""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import os.path import sys diff --git a/tools/build_vocab.py b/tools/build_vocab.py index 59be4031..618f2498 100644 --- a/tools/build_vocab.py +++ b/tools/build_vocab.py @@ -1,4 +1,4 @@ -"""Build vocabulary dictionary from manifest files. +"""Build vocabulary from manifest files. Each item in vocabulary file is a character. """ @@ -11,13 +11,14 @@ import codecs import json from collections import Counter import os.path +import _init_paths +from data_utils import utils -parser = argparse.ArgumentParser( - description='Build vocabulary dictionary from transcription texts.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--manifest_paths", type=str, - help="Manifest paths for building vocabulary dictionary." + help="Manifest paths for building vocabulary." "You can provide multiple manifest files.", nargs='+', required=True) @@ -25,25 +26,20 @@ parser.add_argument( "--count_threshold", default=0, type=int, - help="Characters whose count below the threshold will be truncated. " - "(default: %(default)s)") + help="Characters whose counts are below the threshold will be truncated. " + "(default: %(default)i)") parser.add_argument( "--vocab_path", default='datasets/vocab/zh_vocab.txt', type=str, - help="Filepath to write vocabularies. (default: %(default)s)") + help="File path to write the vocabulary. (default: %(default)s)") args = parser.parse_args() def count_manifest(counter, manifest_path): - for json_line in codecs.open(manifest_path, 'r', 'utf-8'): - try: - json_data = json.loads(json_line) - except Exception as e: - raise Exception('Error parsing manifest: %s, %s' % \ - (manifest_path, e)) - text = json_data['text'] - for char in text: + manifest_jsons = utils.read_manifest(manifest_path) + for line_json in manifest_jsons: + for char in line_json['text']: counter.update(char) @@ -54,9 +50,9 @@ def main(): count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) with codecs.open(args.vocab_path, 'w', 'utf-8') as fout: - for item_pair in count_sorted: - if item_pair[1] < args.count_threshold: break - fout.write(item_pair[0] + '\n') + for char, count in count_sorted: + if count < args.count_threshold: break + fout.write(char + '\n') if __name__ == '__main__':