Merge branch 'develop' of https://github.com/PaddlePaddle/models into fix-81

pull/2/head
yangyaming 8 years ago
commit 397b2fc288

@ -18,9 +18,14 @@ For some machines, we also need to install libsndfile1. Details to be added.
``` ```
cd data cd data
python librispeech.py python librispeech.py
cat manifest.libri.train-* > manifest.libri.train-all
cd .. cd ..
``` ```
After running librispeech.py, we have several "manifest" json files named with a prefix `manifest.libri.`. A manifest file summarizes a speech data set, with each line containing the meta data (i.e. audio filepath, transcription text, audio duration) of each audio file within the data set, in json format.
By `cat manifest.libri.train-* > manifest.libri.train-all`, we simply merge the three seperate sample sets of LibriSpeech (train-clean-100, train-clean-360, train-other-500) into one training set. This is a simple way for merging different data sets.
More help for arguments: More help for arguments:
``` ```
@ -32,13 +37,13 @@ python librispeech.py --help
For GPU Training: For GPU Training:
``` ```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 --train_manifest_path ./data/manifest.libri.train-all
``` ```
For CPU Training: For CPU Training:
``` ```
python train.py --trainer_count 8 --use_gpu False python train.py --trainer_count 8 --use_gpu False -- train_manifest_path ./data/manifest.libri.train-all
``` ```
More help for arguments: More help for arguments:

@ -8,6 +8,7 @@ import json
import random import random
import soundfile import soundfile
import numpy as np import numpy as np
import itertools
import os import os
RANDOM_SEED = 0 RANDOM_SEED = 0
@ -62,6 +63,7 @@ class DataGenerator(object):
self.__stride_ms__ = stride_ms self.__stride_ms__ = stride_ms
self.__window_ms__ = window_ms self.__window_ms__ = window_ms
self.__max_frequency__ = max_frequency self.__max_frequency__ = max_frequency
self.__epoc__ = 0
self.__random__ = random.Random(RANDOM_SEED) self.__random__ = random.Random(RANDOM_SEED)
# load vocabulary (dictionary) # load vocabulary (dictionary)
self.__vocab_dict__, self.__vocab_list__ = \ self.__vocab_dict__, self.__vocab_list__ = \
@ -245,10 +247,42 @@ class DataGenerator(object):
new_batch.append((padded_audio, text)) new_batch.append((padded_audio, text))
return new_batch return new_batch
def instance_reader_creator(self, def __batch_shuffle__(self, manifest, batch_size):
manifest_path, """
sort_by_duration=True, The instances have different lengths and they cannot be
shuffle=False): combined into a single matrix multiplication. It usually
sorts the training examples by length and combines only
similarly-sized instances into minibatches, pads with
silence when necessary so that all instances in a batch
have the same length. This batch shuffle fuction is used
to make similarly-sized instances into minibatches and
make a batch-wise shuffle.
1. Sort the audio clips by duration.
2. Generate a random number `k`, k in [0, batch_size).
3. Randomly remove `k` instances in order to make different mini-batches,
then make minibatches and each minibatch size is batch_size.
4. Shuffle the minibatches.
:param manifest: manifest file.
:type manifest: list
:param batch_size: Batch size. This size is also used for generate
a random number for batch shuffle.
:type batch_size: int
:return: batch shuffled mainifest.
:rtype: list
"""
manifest.sort(key=lambda x: x["duration"])
shift_len = self.__random__.randint(0, batch_size - 1)
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
self.__random__.shuffle(batch_manifest)
batch_manifest = list(sum(batch_manifest, ()))
res_len = len(manifest) - shift_len - len(batch_manifest)
batch_manifest.extend(manifest[-res_len:])
batch_manifest.extend(manifest[0:shift_len])
return batch_manifest
def instance_reader_creator(self, manifest):
""" """
Instance reader creator for audio data. Creat a callable function to Instance reader creator for audio data. Creat a callable function to
produce instances of data. produce instances of data.
@ -256,32 +290,13 @@ class DataGenerator(object):
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
tokenized and indexed transcription text. tokenized and indexed transcription text.
:param manifest_path: Filepath of manifest for audio clip files. :param manifest: Filepath of manifest for audio clip files.
:type manifest_path: basestring :type manifest: basestring
:param sort_by_duration: Sort the audio clips by duration if set True
(for SortaGrad).
:type sort_by_duration: bool
:param shuffle: Shuffle the audio clips if set True.
:type shuffle: bool
:return: Data reader function. :return: Data reader function.
:rtype: callable :rtype: callable
""" """
if sort_by_duration and shuffle:
sort_by_duration = False
logger.warn("When shuffle set to true, "
"sort_by_duration is forced to set False.")
def reader(): def reader():
# read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path,
max_duration=self.__max_duration__,
min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if sort_by_duration:
manifest.sort(key=lambda x: x["duration"])
if shuffle:
self.__random__.shuffle(manifest)
# extract spectrogram feature # extract spectrogram feature
for instance in manifest: for instance in manifest:
spectrogram = self.__audio_featurize__( spectrogram = self.__audio_featurize__(
@ -296,8 +311,8 @@ class DataGenerator(object):
batch_size, batch_size,
padding_to=-1, padding_to=-1,
flatten=False, flatten=False,
sort_by_duration=True, sortagrad=False,
shuffle=False): batch_shuffle=False):
""" """
Batch data reader creator for audio data. Creat a callable function to Batch data reader creator for audio data. Creat a callable function to
produce batches of data. produce batches of data.
@ -317,20 +332,32 @@ class DataGenerator(object):
:param flatten: If set True, audio data will be flatten to be a 1-dim :param flatten: If set True, audio data will be flatten to be a 1-dim
ndarray. Otherwise, 2-dim ndarray. Default is False. ndarray. Otherwise, 2-dim ndarray. Default is False.
:type flatten: bool :type flatten: bool
:param sort_by_duration: Sort the audio clips by duration if set True :param sortagrad: Sort the audio clips by duration in the first epoc
(for SortaGrad). if set True.
:type sort_by_duration: bool :type sortagrad: bool
:param shuffle: Shuffle the audio clips if set True. :param batch_shuffle: Shuffle the audio clips if set True. It is
:type shuffle: bool not a thorough instance-wise shuffle, but a
specific batch-wise shuffle. For more details,
please see `__batch_shuffle__` function.
:type batch_shuffle: bool
:return: Batch reader function, producing batches of data when called. :return: Batch reader function, producing batches of data when called.
:rtype: callable :rtype: callable
""" """
def batch_reader(): def batch_reader():
instance_reader = self.instance_reader_creator( # read manifest
manifest = self.__read_manifest__(
manifest_path=manifest_path, manifest_path=manifest_path,
sort_by_duration=sort_by_duration, max_duration=self.__max_duration__,
shuffle=shuffle) min_duration=self.__min_duration__)
# sort (by duration) or shuffle manifest
if self.__epoc__ == 0 and sortagrad:
manifest.sort(key=lambda x: x["duration"])
elif batch_shuffle:
manifest = self.__batch_shuffle__(manifest, batch_size)
instance_reader = self.instance_reader_creator(manifest)
batch = [] batch = []
for instance in instance_reader(): for instance in instance_reader():
batch.append(instance) batch.append(instance)
@ -339,6 +366,7 @@ class DataGenerator(object):
batch = [] batch = []
if len(batch) > 0: if len(batch) > 0:
yield self.__padding_batch__(batch, padding_to, flatten) yield self.__padding_batch__(batch, padding_to, flatten)
self.__epoc__ += 1
return batch_reader return batch_reader

@ -1,13 +1,14 @@
""" """
Download, unpack and create manifest for Librespeech dataset. Download, unpack and create manifest json files for the Librespeech dataset.
Manifest is a json file with each line containing one audio clip filepath, A manifest is a json file summarizing filelist in a data set, with each line
its transcription text string, and its duration. It servers as a unified containing the meta data (i.e. audio filepath, transcription text, audio
interfance to organize different data sets. duration) of each audio file in the data set.
""" """
import paddle.v2 as paddle import paddle.v2 as paddle
from paddle.v2.dataset.common import md5file from paddle.v2.dataset.common import md5file
import distutils.util
import os import os
import wget import wget
import tarfile import tarfile
@ -27,7 +28,9 @@ URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz"
URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz" URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz"
MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9" MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9"
MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135"
MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1" MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1"
MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931"
MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522" MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522"
MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa" MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa"
MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708" MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708"
@ -44,6 +47,13 @@ parser.add_argument(
default="manifest.libri", default="manifest.libri",
type=str, type=str,
help="Filepath prefix for output manifests. (default: %(default)s)") help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(
"--full_download",
default="True",
type=distutils.util.strtobool,
help="Download all datasets for Librispeech."
" If False, only download a minimal requirement (test-clean, dev-clean"
" train-clean-100). (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
@ -57,7 +67,10 @@ def download(url, md5sum, target_dir):
print("Downloading %s ..." % url) print("Downloading %s ..." % url)
wget.download(url, target_dir) wget.download(url, target_dir)
print("\nMD5 Chesksum %s ..." % filepath) print("\nMD5 Chesksum %s ..." % filepath)
assert md5file(filepath) == md5sum, "MD5 checksum failed." if not md5file(filepath) == md5sum:
raise RuntimeError("MD5 checksum failed.")
else:
print("File exists, skip downloading. (%s)" % filepath)
return filepath return filepath
@ -69,21 +82,17 @@ def unpack(filepath, target_dir):
tar = tarfile.open(filepath) tar = tarfile.open(filepath)
tar.extractall(target_dir) tar.extractall(target_dir)
tar.close() tar.close()
return target_dir
def create_manifest(data_dir, manifest_path): def create_manifest(data_dir, manifest_path):
""" """
Create a manifest file summarizing the dataset (list of filepath and meta Create a manifest json file summarizing the data set, with each line
data). containing the meta data (i.e. audio filepath, transcription text, audio
duration) of each audio file within the data set.
Each line of the manifest contains one audio clip filepath, its
transcription text string, and its duration. Manifest file servers as a
unified interfance to organize data sets.
""" """
print("Creating manifest %s ..." % manifest_path) print("Creating manifest %s ..." % manifest_path)
json_lines = [] json_lines = []
for subfolder, _, filelist in os.walk(data_dir): for subfolder, _, filelist in sorted(os.walk(data_dir)):
text_filelist = [ text_filelist = [
filename for filename in filelist if filename.endswith('trans.txt') filename for filename in filelist if filename.endswith('trans.txt')
] ]
@ -111,9 +120,16 @@ def prepare_dataset(url, md5sum, target_dir, manifest_path):
""" """
Download, unpack and create summmary manifest file. Download, unpack and create summmary manifest file.
""" """
filepath = download(url, md5sum, target_dir) if not os.path.exists(os.path.join(target_dir, "LibriSpeech")):
unpacked_dir = unpack(filepath, target_dir) # download
create_manifest(unpacked_dir, manifest_path) filepath = download(url, md5sum, target_dir)
# unpack
unpack(filepath, target_dir)
else:
print("Skip downloading and unpacking. Data already exists in %s." %
target_dir)
# create manifest json file
create_manifest(target_dir, manifest_path)
def main(): def main():
@ -132,6 +148,27 @@ def main():
md5sum=MD5_TRAIN_CLEAN_100, md5sum=MD5_TRAIN_CLEAN_100,
target_dir=os.path.join(args.target_dir, "train-clean-100"), target_dir=os.path.join(args.target_dir, "train-clean-100"),
manifest_path=args.manifest_prefix + ".train-clean-100") manifest_path=args.manifest_prefix + ".train-clean-100")
if args.full_download:
prepare_dataset(
url=URL_TEST_OTHER,
md5sum=MD5_TEST_OTHER,
target_dir=os.path.join(args.target_dir, "test-other"),
manifest_path=args.manifest_prefix + ".test-other")
prepare_dataset(
url=URL_DEV_OTHER,
md5sum=MD5_DEV_OTHER,
target_dir=os.path.join(args.target_dir, "dev-other"),
manifest_path=args.manifest_prefix + ".dev-other")
prepare_dataset(
url=URL_TRAIN_CLEAN_360,
md5sum=MD5_TRAIN_CLEAN_360,
target_dir=os.path.join(args.target_dir, "train-clean-360"),
manifest_path=args.manifest_prefix + ".train-clean-360")
prepare_dataset(
url=URL_TRAIN_OTHER_500,
md5sum=MD5_TRAIN_OTHER_500,
target_dir=os.path.join(args.target_dir, "train-other-500"),
manifest_path=args.manifest_prefix + ".train-other-500")
if __name__ == '__main__': if __name__ == '__main__':

@ -11,6 +11,7 @@ import sys
from model import deep_speech2 from model import deep_speech2
from audio_data_utils import DataGenerator from audio_data_utils import DataGenerator
import numpy as np import numpy as np
import os
#TODO: add WER metric #TODO: add WER metric
@ -78,6 +79,13 @@ parser.add_argument(
default='data/eng_vocab.txt', default='data/eng_vocab.txt',
type=str, type=str,
help="Vocabulary filepath. (default: %(default)s)") help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
"--init_model_path",
default=None,
type=str,
help="If set None, the training will start from scratch. "
"Otherwise, the training will resume from "
"the existing model of this path. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
@ -85,23 +93,27 @@ def train():
""" """
DeepSpeech2 training. DeepSpeech2 training.
""" """
# initialize data generator # initialize data generator
data_generator = DataGenerator( def data_generator():
vocab_filepath=args.vocab_filepath, return DataGenerator(
normalizer_manifest_path=args.normalizer_manifest_path, vocab_filepath=args.vocab_filepath,
normalizer_num_samples=200, normalizer_manifest_path=args.normalizer_manifest_path,
max_duration=20.0, normalizer_num_samples=200,
min_duration=0.0, max_duration=20.0,
stride_ms=10, min_duration=0.0,
window_ms=20) stride_ms=10,
window_ms=20)
train_generator = data_generator()
test_generator = data_generator()
# create network config # create network config
dict_size = data_generator.vocabulary_size() dict_size = train_generator.vocabulary_size()
# paddle.data_type.dense_array is used for variable batch input.
# the size 161 * 161 is only an placeholder value and the real shape
# of input batch data will be set at each batch.
audio_data = paddle.layer.data( audio_data = paddle.layer.data(
name="audio_spectrogram", name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
text_data = paddle.layer.data( text_data = paddle.layer.data(
name="transcript_text", name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size)) type=paddle.data_type.integer_value_sequence(dict_size))
@ -114,36 +126,30 @@ def train():
rnn_size=args.rnn_layer_size, rnn_size=args.rnn_layer_size,
is_inference=False) is_inference=False)
# create parameters and optimizer # create/load parameters and optimizer
parameters = paddle.parameters.create(cost) if args.init_model_path is None:
parameters = paddle.parameters.create(cost)
else:
if not os.path.isfile(args.init_model_path):
raise IOError("Invalid model!")
parameters = paddle.parameters.Parameters.from_tar(
gzip.open(args.init_model_path))
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.Adam(
learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400) learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
trainer = paddle.trainer.SGD( trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer) cost=cost, parameters=parameters, update_equation=optimizer)
# prepare data reader # prepare data reader
train_batch_reader_sortagrad = data_generator.batch_reader_creator( train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path=args.train_manifest_path, manifest_path=args.train_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
padding_to=2000, sortagrad=True if args.init_model_path is None else False,
flatten=True, batch_shuffle=True)
sort_by_duration=False, test_batch_reader = test_generator.batch_reader_creator(
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path, manifest_path=args.dev_manifest_path,
batch_size=args.batch_size, batch_size=args.batch_size,
padding_to=2000, batch_shuffle=False)
flatten=True, feeding = train_generator.data_name_feeding()
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
# create event handler # create event handler
def event_handler(event): def event_handler(event):
@ -169,17 +175,8 @@ def train():
time.time() - start_time, event.pass_id, result.cost) time.time() - start_time, event.pass_id, result.cost)
# run train # run train
# first pass with sortagrad
if args.use_sortagrad:
trainer.train(
reader=train_batch_reader_sortagrad,
event_handler=event_handler,
num_passes=1,
feeding=feeding)
args.num_passes -= 1
# other passes without sortagrad
trainer.train( trainer.train(
reader=train_batch_reader_nosortagrad, reader=train_batch_reader,
event_handler=event_handler, event_handler=event_handler,
num_passes=args.num_passes, num_passes=args.num_passes,
feeding=feeding) feeding=feeding)

Loading…
Cancel
Save