Merge pull request #191 from pkuyym/build_vocab

Add vocabulary dictionary building script
pull/2/head
Yang yaming 7 years ago committed by GitHub
commit 11afffc026

@ -40,13 +40,13 @@ python datasets/librispeech/librispeech.py --help
### Preparing for Training
```
python compute_mean_std.py
python tools/compute_mean_std.py
```
It will compute mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, and the mfcc feature is also supported. To train and infer based on mfcc feature, please generate this file by
```
python compute_mean_std.py --specgram_type mfcc
python tools/compute_mean_std.py --specgram_type mfcc
```
and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluator.py or tune.py.
@ -54,7 +54,7 @@ and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluato
More help for arguments:
```
python compute_mean_std.py --help
python tools/compute_mean_std.py --help
```
### Training

@ -0,0 +1,19 @@
"""Set up paths for DS2"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import sys
def add_path(path):
if path not in sys.path:
sys.path.insert(0, path)
this_dir = os.path.dirname(__file__)
# Add project path to PYTHONPATH
proj_path = os.path.join(this_dir, '..')
add_path(proj_path)

@ -0,0 +1,59 @@
"""Build vocabulary from manifest files.
Each item in vocabulary file is a character.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import codecs
import json
from collections import Counter
import os.path
import _init_paths
from data_utils import utils
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--manifest_paths",
type=str,
help="Manifest paths for building vocabulary."
"You can provide multiple manifest files.",
nargs='+',
required=True)
parser.add_argument(
"--count_threshold",
default=0,
type=int,
help="Characters whose counts are below the threshold will be truncated. "
"(default: %(default)i)")
parser.add_argument(
"--vocab_path",
default='datasets/vocab/zh_vocab.txt',
type=str,
help="File path to write the vocabulary. (default: %(default)s)")
args = parser.parse_args()
def count_manifest(counter, manifest_path):
manifest_jsons = utils.read_manifest(manifest_path)
for line_json in manifest_jsons:
for char in line_json['text']:
counter.update(char)
def main():
counter = Counter()
for manifest_path in args.manifest_paths:
count_manifest(counter, manifest_path)
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
with codecs.open(args.vocab_path, 'w', 'utf-8') as fout:
for char, count in count_sorted:
if count < args.count_threshold: break
fout.write(char + '\n')
if __name__ == '__main__':
main()

@ -4,6 +4,7 @@ from __future__ import division
from __future__ import print_function
import argparse
import _init_paths
from data_utils.normalizer import FeatureNormalizer
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
Loading…
Cancel
Save