From cd3617aeb4df0dbe998060ba410c782856b2abf3 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Mon, 12 Jun 2017 23:19:40 +0800 Subject: [PATCH 1/5] Refactor whole data preprocessor for DS2 (re-design classes, re-organize dir, add augmentaion interfaces etc.). 1. Refactor data preprocessor with new added class AudioSegment, SpeechSegment, TextFeaturizer, AudioFeaturizer, SpeechFeaturizer. 2. Add data augmentation interfaces and class AugmentorBase, AugmentationPipeline, VolumnPerturbAugmentor etc.. 3. Seperate normalizer's mean and std computing from training, by adding FeatureNormalizer and a seperate tool compute_mean_std.py. 4. Re-organize directory. --- audio_data_utils.py | 411 ------------------ compute_mean_std.py | 56 +++ data_utils/__init__.py | 0 data_utils/audio.py | 68 +++ data_utils/augmentor/__init__.py | 0 data_utils/augmentor/augmentation.py | 38 ++ data_utils/augmentor/base.py | 17 + data_utils/augmentor/volumn_perturb.py | 17 + data_utils/data.py | 247 +++++++++++ data_utils/featurizer/__init__.py | 0 data_utils/featurizer/audio_featurizer.py | 86 ++++ data_utils/featurizer/speech_featurizer.py | 32 ++ data_utils/featurizer/text_featurizer.py | 39 ++ data_utils/normalizer.py | 49 +++ data_utils/utils.py | 19 + {data => datasets/librispeech}/librispeech.py | 2 +- datasets/run_all.sh | 13 + {data => datasets/vocab}/eng_vocab.txt | 0 infer.py | 61 ++- train.py | 74 ++-- 20 files changed, 750 insertions(+), 479 deletions(-) delete mode 100644 audio_data_utils.py create mode 100755 compute_mean_std.py create mode 100755 data_utils/__init__.py create mode 100755 data_utils/audio.py create mode 100755 data_utils/augmentor/__init__.py create mode 100755 data_utils/augmentor/augmentation.py create mode 100755 data_utils/augmentor/base.py create mode 100755 data_utils/augmentor/volumn_perturb.py create mode 100644 data_utils/data.py create mode 100755 data_utils/featurizer/__init__.py create mode 100755 data_utils/featurizer/audio_featurizer.py create mode 100755 data_utils/featurizer/speech_featurizer.py create mode 100755 data_utils/featurizer/text_featurizer.py create mode 100755 data_utils/normalizer.py create mode 100755 data_utils/utils.py rename {data => datasets/librispeech}/librispeech.py (99%) create mode 100755 datasets/run_all.sh rename {data => datasets/vocab}/eng_vocab.txt (100%) diff --git a/audio_data_utils.py b/audio_data_utils.py deleted file mode 100644 index 1cd29be11..000000000 --- a/audio_data_utils.py +++ /dev/null @@ -1,411 +0,0 @@ -""" - Providing basic audio data preprocessing pipeline, and offering - both instance-level and batch-level data reader interfaces. -""" -import paddle.v2 as paddle -import logging -import json -import random -import soundfile -import numpy as np -import itertools -import os - -RANDOM_SEED = 0 -logger = logging.getLogger(__name__) - - -class DataGenerator(object): - """ - DataGenerator provides basic audio data preprocessing pipeline, and offers - both instance-level and batch-level data reader interfaces. - Normalized FFT are used as audio features here. - - :param vocab_filepath: Vocabulary file path for indexing tokenized - transcriptions. - :type vocab_filepath: basestring - :param normalizer_manifest_path: Manifest filepath for collecting feature - normalization statistics, e.g. mean, std. - :type normalizer_manifest_path: basestring - :param normalizer_num_samples: Number of instances sampled for collecting - feature normalization statistics. - Default is 100. - :type normalizer_num_samples: int - :param max_duration: Audio clips with duration (in seconds) greater than - this will be discarded. Default is 20.0. - :type max_duration: float - :param min_duration: Audio clips with duration (in seconds) smaller than - this will be discarded. Default is 0.0. - :type min_duration: float - :param stride_ms: Striding size (in milliseconds) for generating frames. - Default is 10.0. - :type stride_ms: float - :param window_ms: Window size (in milliseconds) for frames. Default is 20.0. - :type window_ms: float - :param max_frequency: Maximun frequency for FFT features. FFT features of - frequency larger than this will be discarded. - If set None, all features will be kept. - Default is None. - :type max_frequency: float - """ - - def __init__(self, - vocab_filepath, - normalizer_manifest_path, - normalizer_num_samples=100, - max_duration=20.0, - min_duration=0.0, - stride_ms=10.0, - window_ms=20.0, - max_frequency=None): - self.__max_duration__ = max_duration - self.__min_duration__ = min_duration - self.__stride_ms__ = stride_ms - self.__window_ms__ = window_ms - self.__max_frequency__ = max_frequency - self.__epoc__ = 0 - self.__random__ = random.Random(RANDOM_SEED) - # load vocabulary (dictionary) - self.__vocab_dict__, self.__vocab_list__ = \ - self.__load_vocabulary_from_file__(vocab_filepath) - # collect normalizer statistics - self.__mean__, self.__std__ = self.__collect_normalizer_statistics__( - manifest_path=normalizer_manifest_path, - num_samples=normalizer_num_samples) - - def __audio_featurize__(self, audio_filename): - """ - Preprocess audio data, including feature extraction, normalization etc.. - """ - features = self.__audio_basic_featurize__(audio_filename) - return self.__normalize__(features) - - def __text_featurize__(self, text): - """ - Preprocess text data, including tokenizing and token indexing etc.. - """ - return self.__convert_text_to_char_index__( - text=text, vocabulary=self.__vocab_dict__) - - def __audio_basic_featurize__(self, audio_filename): - """ - Compute basic (without normalization etc.) features for audio data. - """ - return self.__spectrogram_from_file__( - filename=audio_filename, - stride_ms=self.__stride_ms__, - window_ms=self.__window_ms__, - max_freq=self.__max_frequency__) - - def __collect_normalizer_statistics__(self, manifest_path, num_samples=100): - """ - Compute feature normalization statistics, i.e. mean and stddev. - """ - # read manifest - manifest = self.__read_manifest__( - manifest_path=manifest_path, - max_duration=self.__max_duration__, - min_duration=self.__min_duration__) - # sample for statistics - sampled_manifest = self.__random__.sample(manifest, num_samples) - # extract spectrogram feature - features = [] - for instance in sampled_manifest: - spectrogram = self.__audio_basic_featurize__( - instance["audio_filepath"]) - features.append(spectrogram) - features = np.hstack(features) - mean = np.mean(features, axis=1).reshape([-1, 1]) - std = np.std(features, axis=1).reshape([-1, 1]) - return mean, std - - def __normalize__(self, features, eps=1e-14): - """ - Normalize features to be of zero mean and unit stddev. - """ - return (features - self.__mean__) / (self.__std__ + eps) - - def __spectrogram_from_file__(self, - filename, - stride_ms=10.0, - window_ms=20.0, - max_freq=None, - eps=1e-14): - """ - Laod audio data and calculate the log of spectrogram by FFT. - Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech - """ - audio, sample_rate = soundfile.read(filename) - if audio.ndim >= 2: - audio = np.mean(audio, 1) - if max_freq is None: - max_freq = sample_rate / 2 - if max_freq > sample_rate / 2: - raise ValueError("max_freq must be greater than half of " - "sample rate.") - if stride_ms > window_ms: - raise ValueError("Stride size must not be greater than " - "window size.") - stride_size = int(0.001 * sample_rate * stride_ms) - window_size = int(0.001 * sample_rate * window_ms) - spectrogram, freqs = self.__extract_spectrogram__( - audio, - window_size=window_size, - stride_size=stride_size, - sample_rate=sample_rate) - ind = np.where(freqs <= max_freq)[0][-1] + 1 - return np.log(spectrogram[:ind, :] + eps) - - def __extract_spectrogram__(self, samples, window_size, stride_size, - sample_rate): - """ - Compute the spectrogram by FFT for a discrete real signal. - Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech - """ - # extract strided windows - truncate_size = (len(samples) - window_size) % stride_size - samples = samples[:len(samples) - truncate_size] - nshape = (window_size, (len(samples) - window_size) // stride_size + 1) - nstrides = (samples.strides[0], samples.strides[0] * stride_size) - windows = np.lib.stride_tricks.as_strided( - samples, shape=nshape, strides=nstrides) - assert np.all( - windows[:, 1] == samples[stride_size:(stride_size + window_size)]) - # window weighting, squared Fast Fourier Transform (fft), scaling - weighting = np.hanning(window_size)[:, None] - fft = np.fft.rfft(windows * weighting, axis=0) - fft = np.absolute(fft)**2 - scale = np.sum(weighting**2) * sample_rate - fft[1:-1, :] *= (2.0 / scale) - fft[(0, -1), :] /= scale - # prepare fft frequency list - freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) - return fft, freqs - - def __load_vocabulary_from_file__(self, vocabulary_path): - """ - Load vocabulary from file. - """ - if not os.path.exists(vocabulary_path): - raise ValueError("Vocabulary file %s not found.", vocabulary_path) - vocab_lines = [] - with open(vocabulary_path, 'r') as file: - vocab_lines.extend(file.readlines()) - vocab_list = [line[:-1] for line in vocab_lines] - vocab_dict = dict( - [(token, id) for (id, token) in enumerate(vocab_list)]) - return vocab_dict, vocab_list - - def __convert_text_to_char_index__(self, text, vocabulary): - """ - Convert text string to a list of character index integers. - """ - return [vocabulary[w] for w in text] - - def __read_manifest__(self, manifest_path, max_duration, min_duration): - """ - Load and parse manifest file. - """ - manifest = [] - for json_line in open(manifest_path): - try: - json_data = json.loads(json_line) - except Exception as e: - raise ValueError("Error reading manifest: %s" % str(e)) - if (json_data["duration"] <= max_duration and - json_data["duration"] >= min_duration): - manifest.append(json_data) - return manifest - - def __padding_batch__(self, batch, padding_to=-1, flatten=False): - """ - Padding audio part of features (only in the time axis -- column axis) - with zeros, to make each instance in the batch share the same - audio feature shape. - - If `padding_to` is set -1, the maximun column numbers in the batch will - be used as the target size. Otherwise, `padding_to` will be the target - size. Default is -1. - - If `flatten` is set True, audio data will be flatten to be a 1-dim - ndarray. Default is False. - """ - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, text in batch]) - if padding_to != -1: - if padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be greater" - " or equal to the original instance length.") - max_length = padding_to - # padding - for audio, text in batch: - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - if flatten: - padded_audio = padded_audio.flatten() - new_batch.append((padded_audio, text)) - return new_batch - - def __batch_shuffle__(self, manifest, batch_size): - """ - The instances have different lengths and they cannot be - combined into a single matrix multiplication. It usually - sorts the training examples by length and combines only - similarly-sized instances into minibatches, pads with - silence when necessary so that all instances in a batch - have the same length. This batch shuffle fuction is used - to make similarly-sized instances into minibatches and - make a batch-wise shuffle. - - 1. Sort the audio clips by duration. - 2. Generate a random number `k`, k in [0, batch_size). - 3. Randomly remove `k` instances in order to make different mini-batches, - then make minibatches and each minibatch size is batch_size. - 4. Shuffle the minibatches. - - :param manifest: manifest file. - :type manifest: list - :param batch_size: Batch size. This size is also used for generate - a random number for batch shuffle. - :type batch_size: int - :return: batch shuffled mainifest. - :rtype: list - """ - manifest.sort(key=lambda x: x["duration"]) - shift_len = self.__random__.randint(0, batch_size - 1) - batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) - self.__random__.shuffle(batch_manifest) - batch_manifest = list(sum(batch_manifest, ())) - res_len = len(manifest) - shift_len - len(batch_manifest) - batch_manifest.extend(manifest[-res_len:]) - batch_manifest.extend(manifest[0:shift_len]) - return batch_manifest - - def instance_reader_creator(self, manifest): - """ - Instance reader creator for audio data. Creat a callable function to - produce instances of data. - - Instance: a tuple of a numpy ndarray of audio spectrogram and a list of - tokenized and indexed transcription text. - - :param manifest: Filepath of manifest for audio clip files. - :type manifest: basestring - :return: Data reader function. - :rtype: callable - """ - - def reader(): - # extract spectrogram feature - for instance in manifest: - spectrogram = self.__audio_featurize__( - instance["audio_filepath"]) - transcript = self.__text_featurize__(instance["text"]) - yield (spectrogram, transcript) - - return reader - - def batch_reader_creator(self, - manifest_path, - batch_size, - padding_to=-1, - flatten=False, - sortagrad=False, - batch_shuffle=False): - """ - Batch data reader creator for audio data. Creat a callable function to - produce batches of data. - - Audio features will be padded with zeros to make each instance in the - batch to share the same audio feature shape. - - :param manifest_path: Filepath of manifest for audio clip files. - :type manifest_path: basestring - :param batch_size: Instance number in a batch. - :type batch_size: int - :param padding_to: If set -1, the maximun column numbers in the batch - will be used as the target size for padding. - Otherwise, `padding_to` will be the target size. - Default is -1. - :type padding_to: int - :param flatten: If set True, audio data will be flatten to be a 1-dim - ndarray. Otherwise, 2-dim ndarray. Default is False. - :type flatten: bool - :param sortagrad: Sort the audio clips by duration in the first epoc - if set True. - :type sortagrad: bool - :param batch_shuffle: Shuffle the audio clips if set True. It is - not a thorough instance-wise shuffle, but a - specific batch-wise shuffle. For more details, - please see `__batch_shuffle__` function. - :type batch_shuffle: bool - :return: Batch reader function, producing batches of data when called. - :rtype: callable - """ - - def batch_reader(): - # read manifest - manifest = self.__read_manifest__( - manifest_path=manifest_path, - max_duration=self.__max_duration__, - min_duration=self.__min_duration__) - - # sort (by duration) or shuffle manifest - if self.__epoc__ == 0 and sortagrad: - manifest.sort(key=lambda x: x["duration"]) - elif batch_shuffle: - manifest = self.__batch_shuffle__(manifest, batch_size) - - instance_reader = self.instance_reader_creator(manifest) - batch = [] - for instance in instance_reader(): - batch.append(instance) - if len(batch) == batch_size: - yield self.__padding_batch__(batch, padding_to, flatten) - batch = [] - if len(batch) > 0: - yield self.__padding_batch__(batch, padding_to, flatten) - self.__epoc__ += 1 - - return batch_reader - - def vocabulary_size(self): - """ - Get vocabulary size. - - :return: Vocabulary size. - :rtype: int - """ - return len(self.__vocab_list__) - - def vocabulary_dict(self): - """ - Get vocabulary in dict. - - :return: Vocabulary in dict. - :rtype: dict - """ - return self.__vocab_dict__ - - def vocabulary_list(self): - """ - Get vocabulary in list. - - :return: Vocabulary in list - :rtype: list - """ - return self.__vocab_list__ - - def data_name_feeding(self): - """ - Get feeddings (data field name and corresponding field id). - - :return: Feeding dict. - :rtype: dict - """ - feeding = { - "audio_spectrogram": 0, - "transcript_text": 1, - } - return feeding diff --git a/compute_mean_std.py b/compute_mean_std.py new file mode 100755 index 000000000..b3015df73 --- /dev/null +++ b/compute_mean_std.py @@ -0,0 +1,56 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from data_utils.normalizer import FeatureNormalizer +from data_utils.augmentor.augmentation import AugmentationPipeline +from data_utils.featurizer.audio_featurizer import AudioFeaturizer + +parser = argparse.ArgumentParser( + description='Computing mean and stddev for feature normalizer.') +parser.add_argument( + "--manifest_path", + default='datasets/manifest.train', + type=str, + help="Manifest path for computing normalizer's mean and stddev." + "(default: %(default)s)") +parser.add_argument( + "--num_samples", + default=500, + type=int, + help="Number of samples for computing mean and stddev. " + "(default: %(default)s)") +parser.add_argument( + "--augmentation_config", + default='{}', + type=str, + help="Augmentation configuration in json-format. " + "(default: %(default)s)") +parser.add_argument( + "--output_file", + default='mean_std.npz', + type=str, + help="Filepath to write mean and std to (.npz)." + "(default: %(default)s)") +args = parser.parse_args() + + +def main(): + augmentation_pipeline = AugmentationPipeline(args.augmentation_config) + audio_featurizer = AudioFeaturizer() + + def augment_and_featurize(audio_segment): + augmentation_pipeline.transform_audio(audio_segment) + return audio_featurizer.featurize(audio_segment) + + normalizer = FeatureNormalizer( + mean_std_filepath=None, + manifest_path=args.manifest_path, + featurize_func=augment_and_featurize, + num_samples=args.num_samples) + normalizer.write_to_file(args.output_file) + + +if __name__ == '__main__': + main() diff --git a/data_utils/__init__.py b/data_utils/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/data_utils/audio.py b/data_utils/audio.py new file mode 100755 index 000000000..46b241201 --- /dev/null +++ b/data_utils/audio.py @@ -0,0 +1,68 @@ +import numpy as np +import io +import soundfile + + +class AudioSegment(object): + """Monaural audio segment abstraction. + """ + + def __init__(self, samples, sample_rate): + if not samples.dtype == np.float32: + raise ValueError("Sample data type of [%s] is not supported.") + self._samples = samples + self._sample_rate = sample_rate + if self._samples.ndim >= 2: + self._samples = np.mean(self._samples, 1) + + @classmethod + def from_file(cls, filepath): + samples, sample_rate = soundfile.read(filepath, dtype='float32') + return cls(samples, sample_rate) + + @classmethod + def from_bytes(cls, bytes): + samples, sample_rate = soundfile.read( + io.BytesIO(bytes), dtype='float32') + return cls(samples, sample_rate) + + def apply_gain(self, gain): + self.samples *= 10.**(gain / 20.) + + def resample(self, target_sample_rate): + raise NotImplementedError() + + def change_speed(self, rate): + raise NotImplementedError() + + @property + def samples(self): + return self._samples.copy() + + @property + def sample_rate(self): + return self._sample_rate + + @property + def duration(self): + return self._samples.shape[0] / float(self._sample_rate) + + +class SpeechSegment(AudioSegment): + def __init__(self, samples, sample_rate, transcript): + AudioSegment.__init__(self, samples, sample_rate) + self._transcript = transcript + + @classmethod + def from_file(cls, filepath, transcript): + audio = AudioSegment.from_file(filepath) + return cls(audio.samples, audio.sample_rate, transcript) + + @classmethod + def from_bytes(cls, bytes, transcript): + audio = AudioSegment.from_bytes(bytes) + return cls(audio.samples, audio.sample_rate, transcript) + + @property + def transcript(self): + return self._transcript diff --git a/data_utils/augmentor/__init__.py b/data_utils/augmentor/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py new file mode 100755 index 000000000..3a1426a1f --- /dev/null +++ b/data_utils/augmentor/augmentation.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import random +from data_utils.augmentor.volumn_perturb import VolumnPerturbAugmentor + + +class AugmentationPipeline(object): + def __init__(self, augmentation_config, random_seed=0): + self._rng = random.Random(random_seed) + self._augmentors, self._rates = self._parse_pipeline_from( + augmentation_config) + + def transform_audio(self, audio_segment): + for augmentor, rate in zip(self._augmentors, self._rates): + if self._rng.uniform(0., 1.) <= rate: + augmentor.transform_audio(audio_segment) + + def _parse_pipeline_from(self, config_json): + try: + configs = json.loads(config_json) + except Exception as e: + raise ValueError("Augmentation config json format error: " + "%s" % str(e)) + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in configs + ] + rates = [config["rate"] for config in configs] + return augmentors, rates + + def _get_augmentor(self, augmentor_type, params): + if augmentor_type == "volumn": + return VolumnPerturbAugmentor(self._rng, **params) + else: + raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py new file mode 100755 index 000000000..e801b9b18 --- /dev/null +++ b/data_utils/augmentor/base.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from abc import ABCMeta, abstractmethod + + +class AugmentorBase(object): + __metaclass__ = ABCMeta + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def transform_audio(self, audio_segment): + pass diff --git a/data_utils/augmentor/volumn_perturb.py b/data_utils/augmentor/volumn_perturb.py new file mode 100755 index 000000000..dd1ba53a7 --- /dev/null +++ b/data_utils/augmentor/volumn_perturb.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +from data_utils.augmentor.base import AugmentorBase + + +class VolumnPerturbAugmentor(AugmentorBase): + def __init__(self, rng, min_gain_dBFS, max_gain_dBFS): + self._min_gain_dBFS = min_gain_dBFS + self._max_gain_dBFS = max_gain_dBFS + self._rng = rng + + def transform_audio(self, audio_segment): + gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) + audio_segment.apply_gain(gain) diff --git a/data_utils/data.py b/data_utils/data.py new file mode 100644 index 000000000..630007932 --- /dev/null +++ b/data_utils/data.py @@ -0,0 +1,247 @@ +""" + Providing basic audio data preprocessing pipeline, and offering + both instance-level and batch-level data reader interfaces. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import numpy as np +import paddle.v2 as paddle +from data_utils import utils +from data_utils.augmentor.augmentation import AugmentationPipeline +from data_utils.featurizer.speech_featurizer import SpeechFeaturizer +from data_utils.audio import SpeechSegment +from data_utils.normalizer import FeatureNormalizer + + +class DataGenerator(object): + """ + DataGenerator provides basic audio data preprocessing pipeline, and offers + both instance-level and batch-level data reader interfaces. + Normalized FFT are used as audio features here. + + :param vocab_filepath: Vocabulary file path for indexing tokenized + transcriptions. + :type vocab_filepath: basestring + :param normalizer_manifest_path: Manifest filepath for collecting feature + normalization statistics, e.g. mean, std. + :type normalizer_manifest_path: basestring + :param normalizer_num_samples: Number of instances sampled for collecting + feature normalization statistics. + Default is 100. + :type normalizer_num_samples: int + :param max_duration: Audio clips with duration (in seconds) greater than + this will be discarded. Default is 20.0. + :type max_duration: float + :param min_duration: Audio clips with duration (in seconds) smaller than + this will be discarded. Default is 0.0. + :type min_duration: float + :param stride_ms: Striding size (in milliseconds) for generating frames. + Default is 10.0. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for frames. Default is 20.0. + :type window_ms: float + :param max_frequency: Maximun frequency for FFT features. FFT features of + frequency larger than this will be discarded. + If set None, all features will be kept. + Default is None. + :type max_frequency: float + """ + + def __init__(self, + vocab_filepath, + mean_std_filepath, + augmentation_config='{}', + max_duration=float('inf'), + min_duration=0.0, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + random_seed=0): + self._max_duration = max_duration + self._min_duration = min_duration + self._normalizer = FeatureNormalizer(mean_std_filepath) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=augmentation_config, random_seed=random_seed) + self._speech_featurizer = SpeechFeaturizer( + vocab_filepath=vocab_filepath, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + random_seed=random_seed) + self._rng = random.Random(random_seed) + self._epoch = 0 + + def batch_reader_creator(self, + manifest_path, + batch_size, + padding_to=-1, + flatten=False, + sortagrad=False, + batch_shuffle=False): + """ + Batch data reader creator for audio data. Creat a callable function to + produce batches of data. + + Audio features will be padded with zeros to make each instance in the + batch to share the same audio feature shape. + + :param manifest_path: Filepath of manifest for audio clip files. + :type manifest_path: basestring + :param batch_size: Instance number in a batch. + :type batch_size: int + :param padding_to: If set -1, the maximun column numbers in the batch + will be used as the target size for padding. + Otherwise, `padding_to` will be the target size. + Default is -1. + :type padding_to: int + :param flatten: If set True, audio data will be flatten to be a 1-dim + ndarray. Otherwise, 2-dim ndarray. Default is False. + :type flatten: bool + :param sortagrad: Sort the audio clips by duration in the first epoc + if set True. + :type sortagrad: bool + :param batch_shuffle: Shuffle the audio clips if set True. It is + not a thorough instance-wise shuffle, but a + specific batch-wise shuffle. For more details, + please see `_batch_shuffle` function. + :type batch_shuffle: bool + :return: Batch reader function, producing batches of data when called. + :rtype: callable + """ + + def batch_reader(): + # read manifest + manifest = utils.read_manifest( + manifest_path=manifest_path, + max_duration=self._max_duration, + min_duration=self._min_duration) + # sort (by duration) or batch-wise shuffle the manifest + if self._epoch == 0 and sortagrad: + manifest.sort(key=lambda x: x["duration"]) + elif batch_shuffle: + manifest = self._batch_shuffle(manifest, batch_size) + # prepare batches + instance_reader = self._instance_reader_creator(manifest) + batch = [] + for instance in instance_reader(): + batch.append(instance) + if len(batch) == batch_size: + yield self._padding_batch(batch, padding_to, flatten) + batch = [] + if len(batch) > 0: + yield self._padding_batch(batch, padding_to, flatten) + self._epoch += 1 + + return batch_reader + + @property + def feeding(self): + """Returns data_reader's feeding dict.""" + return {"audio_spectrogram": 0, "transcript_text": 1} + + @property + def vocab_size(self): + """Returns vocabulary size.""" + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + """Returns vocabulary list.""" + return self._speech_featurizer.vocab_list + + def _process_utterance(self, filename, transcript): + speech_segment = SpeechSegment.from_file(filename, transcript) + self._augmentation_pipeline.transform_audio(speech_segment) + specgram, text_ids = self._speech_featurizer.featurize(speech_segment) + specgram = self._normalizer.apply(specgram) + return specgram, text_ids + + def _instance_reader_creator(self, manifest): + """ + Instance reader creator for audio data. Creat a callable function to + produce instances of data. + + Instance: a tuple of a numpy ndarray of audio spectrogram and a list of + tokenized and indexed transcription text. + + :param manifest: Filepath of manifest for audio clip files. + :type manifest: basestring + :return: Data reader function. + :rtype: callable + """ + + def reader(): + for instance in manifest: + yield self._process_utterance(instance["audio_filepath"], + instance["text"]) + + return reader + + def _padding_batch(self, batch, padding_to=-1, flatten=False): + """ + Padding audio part of features (only in the time axis -- column axis) + with zeros, to make each instance in the batch share the same + audio feature shape. + + If `padding_to` is set -1, the maximun column numbers in the batch will + be used as the target size. Otherwise, `padding_to` will be the target + size. Default is -1. + + If `flatten` is set True, audio data will be flatten to be a 1-dim + ndarray. Default is False. + """ + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, text in batch]) + if padding_to != -1: + if padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be greater" + " or equal to the original instance length.") + max_length = padding_to + # padding + for audio, text in batch: + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + if flatten: + padded_audio = padded_audio.flatten() + new_batch.append((padded_audio, text)) + return new_batch + + def _batch_shuffle(self, manifest, batch_size): + """ + The instances have different lengths and they cannot be + combined into a single matrix multiplication. It usually + sorts the training examples by length and combines only + similarly-sized instances into minibatches, pads with + silence when necessary so that all instances in a batch + have the same length. This batch shuffle fuction is used + to make similarly-sized instances into minibatches and + make a batch-wise shuffle. + + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Randomly remove `k` instances in order to make different mini-batches, + then make minibatches and each minibatch size is batch_size. + 4. Shuffle the minibatches. + + :param manifest: manifest file. + :type manifest: list + :param batch_size: Batch size. This size is also used for generate + a random number for batch shuffle. + :type batch_size: int + :return: batch shuffled mainifest. + :rtype: list + """ + manifest.sort(key=lambda x: x["duration"]) + shift_len = self._rng.randint(0, batch_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) + self._rng.shuffle(batch_manifest) + batch_manifest = list(sum(batch_manifest, ())) + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[-res_len:]) + batch_manifest.extend(manifest[0:shift_len]) + return batch_manifest diff --git a/data_utils/featurizer/__init__.py b/data_utils/featurizer/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py new file mode 100755 index 000000000..5d9c68836 --- /dev/null +++ b/data_utils/featurizer/audio_featurizer.py @@ -0,0 +1,86 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import random +from data_utils import utils +from data_utils.audio import AudioSegment + + +class AudioFeaturizer(object): + def __init__(self, + specgram_type='linear', + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + random_seed=0): + self._specgram_type = specgram_type + self._stride_ms = stride_ms + self._window_ms = window_ms + self._max_freq = max_freq + + def featurize(self, audio_segment): + return self._compute_specgram(audio_segment.samples, + audio_segment.sample_rate) + + def _compute_specgram(self, samples, sample_rate): + if self._specgram_type == 'linear': + return self._compute_linear_specgram( + samples, sample_rate, self._stride_ms, self._window_ms, + self._max_freq) + else: + raise ValueError("Unknown specgram_type %s. " + "Supported values: linear." % self._specgram_type) + + def _compute_linear_specgram(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + eps=1e-14): + """Laod audio data and calculate the log of spectrogram by FFT. + Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech + """ + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + stride_size = int(0.001 * sample_rate * stride_ms) + window_size = int(0.001 * sample_rate * window_ms) + specgram, freqs = self._specgram_real( + samples, + window_size=window_size, + stride_size=stride_size, + sample_rate=sample_rate) + ind = np.where(freqs <= max_freq)[0][-1] + 1 + return np.log(specgram[:ind, :] + eps) + + def _specgram_real(self, samples, window_size, stride_size, sample_rate): + """Compute the spectrogram by FFT for a discrete real signal. + Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech + """ + # extract strided windows + truncate_size = (len(samples) - window_size) % stride_size + samples = samples[:len(samples) - truncate_size] + nshape = (window_size, (len(samples) - window_size) // stride_size + 1) + nstrides = (samples.strides[0], samples.strides[0] * stride_size) + windows = np.lib.stride_tricks.as_strided( + samples, shape=nshape, strides=nstrides) + assert np.all( + windows[:, 1] == samples[stride_size:(stride_size + window_size)]) + # window weighting, squared Fast Fourier Transform (fft), scaling + weighting = np.hanning(window_size)[:, None] + fft = np.fft.rfft(windows * weighting, axis=0) + fft = np.absolute(fft)**2 + scale = np.sum(weighting**2) * sample_rate + fft[1:-1, :] *= (2.0 / scale) + fft[(0, -1), :] /= scale + # prepare fft frequency list + freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) + return fft, freqs diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py new file mode 100755 index 000000000..06af7a026 --- /dev/null +++ b/data_utils/featurizer/speech_featurizer.py @@ -0,0 +1,32 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.featurizer.audio_featurizer import AudioFeaturizer +from data_utils.featurizer.text_featurizer import TextFeaturizer + + +class SpeechFeaturizer(object): + def __init__(self, + vocab_filepath, + specgram_type='linear', + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + random_seed=0): + self._audio_featurizer = AudioFeaturizer( + specgram_type, stride_ms, window_ms, max_freq, random_seed) + self._text_featurizer = TextFeaturizer(vocab_filepath) + + def featurize(self, speech_segment): + audio_feature = self._audio_featurizer.featurize(speech_segment) + text_ids = self._text_featurizer.text2ids(speech_segment.transcript) + return audio_feature, text_ids + + @property + def vocab_size(self): + return self._text_featurizer.vocab_size + + @property + def vocab_list(self): + return self._text_featurizer.vocab_list diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py new file mode 100755 index 000000000..7e4b69d7b --- /dev/null +++ b/data_utils/featurizer/text_featurizer.py @@ -0,0 +1,39 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + + +class TextFeaturizer(object): + def __init__(self, vocab_filepath): + self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( + vocab_filepath) + + def text2ids(self, text): + tokens = self._char_tokenize(text) + return [self._vocab_dict[token] for token in tokens] + + def ids2text(self, ids): + return ''.join([self._vocab_list[id] for id in ids]) + + @property + def vocab_size(self): + return len(self._vocab_list) + + @property + def vocab_list(self): + return self._vocab_list + + def _char_tokenize(self, text): + return list(text.strip()) + + def _load_vocabulary_from_file(self, vocab_filepath): + """Load vocabulary from file.""" + vocab_lines = [] + with open(vocab_filepath, 'r') as file: + vocab_lines.extend(file.readlines()) + vocab_list = [line[:-1] for line in vocab_lines] + vocab_dict = dict( + [(token, id) for (id, token) in enumerate(vocab_list)]) + return vocab_dict, vocab_list diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py new file mode 100755 index 000000000..364600af8 --- /dev/null +++ b/data_utils/normalizer.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import random +import data_utils.utils as utils +from data_utils.audio import AudioSegment + + +class FeatureNormalizer(object): + def __init__(self, + mean_std_filepath, + manifest_path=None, + featurize_func=None, + num_samples=500, + random_seed=0): + if not mean_std_filepath: + if not (manifest_path and featurize_func): + raise ValueError("If mean_std_filepath is None, meanifest_path " + "and featurize_func should not be None.") + self._rng = random.Random(random_seed) + self._compute_mean_std(manifest_path, featurize_func, num_samples) + else: + self._read_mean_std_from_file(mean_std_filepath) + + def apply(self, features, eps=1e-14): + """Normalize features to be of zero mean and unit stddev.""" + return (features - self._mean) / (self._std + eps) + + def write_to_file(self, filepath): + np.savez(filepath, mean=self._mean, std=self._std) + + def _read_mean_std_from_file(self, filepath): + npzfile = np.load(filepath) + self._mean = npzfile["mean"] + self._std = npzfile["std"] + + def _compute_mean_std(self, manifest_path, featurize_func, num_samples): + manifest = utils.read_manifest(manifest_path) + sampled_manifest = self._rng.sample(manifest, num_samples) + features = [] + for instance in sampled_manifest: + features.append( + featurize_func( + AudioSegment.from_file(instance["audio_filepath"]))) + features = np.hstack(features) + self._mean = np.mean(features, axis=1).reshape([-1, 1]) + self._std = np.std(features, axis=1).reshape([-1, 1]) diff --git a/data_utils/utils.py b/data_utils/utils.py new file mode 100755 index 000000000..2a916b54f --- /dev/null +++ b/data_utils/utils.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + + +def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0): + """Load and parse manifest file.""" + manifest = [] + for json_line in open(manifest_path): + try: + json_data = json.loads(json_line) + except Exception as e: + raise IOError("Error reading manifest: %s" % str(e)) + if (json_data["duration"] <= max_duration and + json_data["duration"] >= min_duration): + manifest.append(json_data) + return manifest diff --git a/data/librispeech.py b/datasets/librispeech/librispeech.py similarity index 99% rename from data/librispeech.py rename to datasets/librispeech/librispeech.py index 653caa926..1ba2a4422 100644 --- a/data/librispeech.py +++ b/datasets/librispeech/librispeech.py @@ -44,7 +44,7 @@ parser.add_argument( help="Directory to save the dataset. (default: %(default)s)") parser.add_argument( "--manifest_prefix", - default="manifest.libri", + default="manifest", type=str, help="Filepath prefix for output manifests. (default: %(default)s)") parser.add_argument( diff --git a/datasets/run_all.sh b/datasets/run_all.sh new file mode 100755 index 000000000..ef2b721fb --- /dev/null +++ b/datasets/run_all.sh @@ -0,0 +1,13 @@ +cd librispeech +python librispeech.py +if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 +fi +cd - + +cat librispeech/manifest.train* | shuf > manifest.train +cat librispeech/manifest.dev-clean > manifest.dev +cat librispeech/manifest.test-clean > manifest.test + +echo "All done." diff --git a/data/eng_vocab.txt b/datasets/vocab/eng_vocab.txt similarity index 100% rename from data/eng_vocab.txt rename to datasets/vocab/eng_vocab.txt diff --git a/infer.py b/infer.py index 598c348b0..eb31254ce 100644 --- a/infer.py +++ b/infer.py @@ -2,11 +2,15 @@ Inference for a simplifed version of Baidu DeepSpeech2 model. """ -import paddle.v2 as paddle -import distutils.util +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import argparse import gzip -from audio_data_utils import DataGenerator +import distutils.util +import paddle.v2 as paddle +from data_utils.data import DataGenerator from model import deep_speech2 from decoder import ctc_decode @@ -38,13 +42,13 @@ parser.add_argument( type=distutils.util.strtobool, help="Use gpu or not. (default: %(default)s)") parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', + "--mean_std_filepath", + default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", - default='data/manifest.libri.test-clean', + default='datasets/manifest.test', type=str, help="Manifest path for decoding. (default: %(default)s)") parser.add_argument( @@ -54,7 +58,7 @@ parser.add_argument( help="Model filepath. (default: %(default)s)") parser.add_argument( "--vocab_filepath", - default='data/eng_vocab.txt', + default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") args = parser.parse_args() @@ -67,28 +71,22 @@ def infer(): # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, - normalizer_manifest_path=args.normalizer_manifest_path, - normalizer_num_samples=200, - max_duration=20.0, - min_duration=0.0, - stride_ms=10, - window_ms=20) + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}') # create network config - dict_size = data_generator.vocabulary_size() - vocab_list = data_generator.vocabulary_list() + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. audio_data = paddle.layer.data( - name="audio_spectrogram", - height=161, - width=2000, - type=paddle.data_type.dense_vector(322000)) + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) text_data = paddle.layer.data( name="transcript_text", - type=paddle.data_type.integer_value_sequence(dict_size)) + type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) output_probs = deep_speech2( audio_data=audio_data, text_data=text_data, - dict_size=dict_size, + dict_size=data_generator.vocab_size, num_conv_layers=args.num_conv_layers, num_rnn_layers=args.num_rnn_layers, rnn_size=args.rnn_layer_size, @@ -99,31 +97,30 @@ def infer(): gzip.open(args.model_filepath)) # prepare infer data - feeding = data_generator.data_name_feeding() - test_batch_reader = data_generator.batch_reader_creator( + batch_reader = data_generator.batch_reader_creator( manifest_path=args.decode_manifest_path, batch_size=args.num_samples, - padding_to=2000, - flatten=True, - sort_by_duration=False, - shuffle=False) - infer_data = test_batch_reader().next() + sortagrad=False, + batch_shuffle=False) + infer_data = batch_reader().next() # run inference infer_results = paddle.infer( output_layer=output_probs, parameters=parameters, input=infer_data) - num_steps = len(infer_results) / len(infer_data) + num_steps = len(infer_results) // len(infer_data) probs_split = [ infer_results[i * num_steps:(i + 1) * num_steps] - for i in xrange(0, len(infer_data)) + for i in xrange(len(infer_data)) ] # decode and print for i, probs in enumerate(probs_split): output_transcription = ctc_decode( - probs_seq=probs, vocabulary=vocab_list, method="best_path") + probs_seq=probs, + vocabulary=data_generator.vocab_list, + method="best_path") target_transcription = ''.join( - [vocab_list[index] for index in infer_data[i][1]]) + [data_generator.vocab_list[index] for index in infer_data[i][1]]) print("Target Transcription: %s \nOutput Transcription: %s \n" % (target_transcription, output_transcription)) diff --git a/train.py b/train.py index 957c24267..c6aa97527 100644 --- a/train.py +++ b/train.py @@ -2,21 +2,21 @@ Trainer for a simplifed version of Baidu DeepSpeech2 model. """ -import paddle.v2 as paddle -import distutils.util +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import os import argparse import gzip import time -import sys +import distutils.util +import paddle.v2 as paddle from model import deep_speech2 -from audio_data_utils import DataGenerator -import numpy as np -import os +from data_utils.data import DataGenerator -#TODO: add WER metric - -parser = argparse.ArgumentParser( - description='Simplified version of DeepSpeech2 trainer.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--batch_size", default=32, type=int, help="Minibatch size.") parser.add_argument( @@ -51,7 +51,7 @@ parser.add_argument( help="Use gpu or not. (default: %(default)s)") parser.add_argument( "--use_sortagrad", - default=False, + default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") parser.add_argument( @@ -60,23 +60,23 @@ parser.add_argument( type=int, help="Trainer number. (default: %(default)s)") parser.add_argument( - "--normalizer_manifest_path", - default='data/manifest.libri.train-clean-100', + "--mean_std_filepath", + default='mean_std.npz', type=str, help="Manifest path for normalizer. (default: %(default)s)") parser.add_argument( "--train_manifest_path", - default='data/manifest.libri.train-clean-100', + default='datasets/manifest.train', type=str, help="Manifest path for training. (default: %(default)s)") parser.add_argument( "--dev_manifest_path", - default='data/manifest.libri.dev-clean', + default='datasets/manifest.dev', type=str, help="Manifest path for validation. (default: %(default)s)") parser.add_argument( "--vocab_filepath", - default='data/eng_vocab.txt', + default='datasets/vocab/eng_vocab.txt', type=str, help="Vocabulary filepath. (default: %(default)s)") parser.add_argument( @@ -86,6 +86,12 @@ parser.add_argument( help="If set None, the training will start from scratch. " "Otherwise, the training will resume from " "the existing model of this path. (default: %(default)s)") +parser.add_argument( + "--augmentation_config", + default='{}', + type=str, + help="Augmentation configuration in json-format. " + "(default: %(default)s)") args = parser.parse_args() @@ -98,29 +104,26 @@ def train(): def data_generator(): return DataGenerator( vocab_filepath=args.vocab_filepath, - normalizer_manifest_path=args.normalizer_manifest_path, - normalizer_num_samples=200, - max_duration=20.0, - min_duration=0.0, - stride_ms=10, - window_ms=20) + mean_std_filepath=args.mean_std_filepath, + augmentation_config=args.augmentation_config) train_generator = data_generator() test_generator = data_generator() + # create network config - dict_size = train_generator.vocabulary_size() # paddle.data_type.dense_array is used for variable batch input. - # the size 161 * 161 is only an placeholder value and the real shape - # of input batch data will be set at each batch. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. audio_data = paddle.layer.data( name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) text_data = paddle.layer.data( name="transcript_text", - type=paddle.data_type.integer_value_sequence(dict_size)) + type=paddle.data_type.integer_value_sequence( + train_generator.vocab_size)) cost = deep_speech2( audio_data=audio_data, text_data=text_data, - dict_size=dict_size, + dict_size=train_generator.vocab_size, num_conv_layers=args.num_conv_layers, num_rnn_layers=args.num_rnn_layers, rnn_size=args.rnn_layer_size, @@ -143,13 +146,13 @@ def train(): train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, - sortagrad=True if args.init_model_path is None else False, + sortagrad=args.use_sortagrad if args.init_model_path is None else False, batch_shuffle=True) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, + sortagrad=False, batch_shuffle=False) - feeding = train_generator.data_name_feeding() # create event handler def event_handler(event): @@ -158,8 +161,8 @@ def train(): cost_sum += event.cost cost_counter += 1 if event.batch_id % 50 == 0: - print "\nPass: %d, Batch: %d, TrainCost: %f" % ( - event.pass_id, event.batch_id, cost_sum / cost_counter) + print("\nPass: %d, Batch: %d, TrainCost: %f" % + (event.pass_id, event.batch_id, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 with gzip.open("params.tar.gz", 'w') as f: parameters.to_tar(f) @@ -170,16 +173,17 @@ def train(): start_time = time.time() cost_sum, cost_counter = 0.0, 0 if isinstance(event, paddle.event.EndPass): - result = trainer.test(reader=test_batch_reader, feeding=feeding) - print "\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % ( - time.time() - start_time, event.pass_id, result.cost) + result = trainer.test( + reader=test_batch_reader, feeding=test_generator.feeding) + print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % + (time.time() - start_time, event.pass_id, result.cost)) # run train trainer.train( reader=train_batch_reader, event_handler=event_handler, num_passes=args.num_passes, - feeding=feeding) + feeding=train_generator.feeding) def main(): From b07ee84a1d613511193a486363937750880ea6fa Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Tue, 13 Jun 2017 23:16:07 +0800 Subject: [PATCH 2/5] Add function, class and module docs for data parts in DS2. --- compute_mean_std.py | 3 +- data_utils/audio.py | 232 ++++++++++++++++++--- data_utils/augmentor/augmentation.py | 60 +++++- data_utils/augmentor/base.py | 16 ++ data_utils/augmentor/volume_perturb.py | 40 ++++ data_utils/augmentor/volumn_perturb.py | 17 -- data_utils/data.py | 166 +++++++-------- data_utils/featurizer/audio_featurizer.py | 38 +++- data_utils/featurizer/speech_featurizer.py | 55 ++++- data_utils/featurizer/text_featurizer.py | 36 +++- data_utils/normalizer.py | 40 +++- data_utils/speech.py | 75 +++++++ data_utils/utils.py | 17 +- datasets/librispeech/librispeech.py | 16 +- decoder.py | 9 +- infer.py | 5 +- model.py | 9 +- train.py | 7 +- 18 files changed, 662 insertions(+), 179 deletions(-) create mode 100755 data_utils/augmentor/volume_perturb.py delete mode 100755 data_utils/augmentor/volumn_perturb.py create mode 100755 data_utils/speech.py diff --git a/compute_mean_std.py b/compute_mean_std.py index b3015df73..9c301c93f 100755 --- a/compute_mean_std.py +++ b/compute_mean_std.py @@ -1,3 +1,4 @@ +"""Compute mean and std for feature normalizer, and save to file.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -17,7 +18,7 @@ parser.add_argument( "(default: %(default)s)") parser.add_argument( "--num_samples", - default=500, + default=2000, type=int, help="Number of samples for computing mean and stddev. " "(default: %(default)s)") diff --git a/data_utils/audio.py b/data_utils/audio.py index 46b241201..916c8ac1a 100755 --- a/data_utils/audio.py +++ b/data_utils/audio.py @@ -1,3 +1,8 @@ +"""Contains the audio segment class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import numpy as np import io import soundfile @@ -5,64 +10,243 @@ import soundfile class AudioSegment(object): """Monaural audio segment abstraction. + + :param samples: Audio samples [num_samples x num_channels]. + :type samples: ndarray.float32 + :param sample_rate: Audio sample rate. + :type sample_rate: int + :raises TypeError: If the sample data type is not float or int. """ def __init__(self, samples, sample_rate): - if not samples.dtype == np.float32: - raise ValueError("Sample data type of [%s] is not supported.") - self._samples = samples + """Create audio segment from samples. + + Samples are convert float32 internally, with int scaled to [-1, 1]. + """ + self._samples = self._convert_samples_to_float32(samples) self._sample_rate = sample_rate if self._samples.ndim >= 2: self._samples = np.mean(self._samples, 1) + def __eq__(self, other): + """Return whether two objects are equal.""" + if type(other) is not type(self): + return False + if self._sample_rate != other._sample_rate: + return False + if self._samples.shape != other._samples.shape: + return False + if np.any(self.samples != other._samples): + return False + return True + + def __ne__(self, other): + """Return whether two objects are unequal.""" + return not self.__eq__(other) + + def __str__(self): + """Return human-readable representation of segment.""" + return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, " + "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate, + self.duration, self.rms_db)) + @classmethod - def from_file(cls, filepath): - samples, sample_rate = soundfile.read(filepath, dtype='float32') + def from_file(cls, file): + """Create audio segment from audio file. + + :param filepath: Filepath or file object to audio file. + :type filepath: basestring|file + :return: Audio segment instance. + :rtype: AudioSegment + """ + samples, sample_rate = soundfile.read(file, dtype='float32') return cls(samples, sample_rate) @classmethod def from_bytes(cls, bytes): + """Create audio segment from a byte string containing audio samples. + + :param bytes: Byte string containing audio samples. + :type bytes: str + :return: Audio segment instance. + :rtype: AudioSegment + """ samples, sample_rate = soundfile.read( io.BytesIO(bytes), dtype='float32') return cls(samples, sample_rate) + def to_wav_file(self, filepath, dtype='float32'): + """Save audio segment to disk as wav file. + + :param filepath: WAV filepath or file object to save the + audio segment. + :type filepath: basestring|file + :param dtype: Subtype for audio file. Options: 'int16', 'int32', + 'float32', 'float64'. Default is 'float32'. + :type dtype: str + :raises TypeError: If dtype is not supported. + """ + samples = self._convert_samples_from_float32(self._samples, dtype) + subtype_map = { + 'int16': 'PCM_16', + 'int32': 'PCM_32', + 'float32': 'FLOAT', + 'float64': 'DOUBLE' + } + soundfile.write( + filepath, + samples, + self._sample_rate, + format='WAV', + subtype=subtype_map[dtype]) + + def to_bytes(self, dtype='float32'): + """Create a byte string containing the audio content. + + :param dtype: Data type for export samples. Options: 'int16', 'int32', + 'float32', 'float64'. Default is 'float32'. + :type dtype: str + :return: Byte string containing audio content. + :rtype: str + """ + samples = self._convert_samples_from_float32(self._samples, dtype) + return samples.tostring() + def apply_gain(self, gain): - self.samples *= 10.**(gain / 20.) + """Apply gain in decibels to samples. + + Note that this is an in-place transformation. + + :param gain: Gain in decibels to apply to samples. + :type gain: float + """ + self._samples *= 10.**(gain / 20.) + + def change_speed(self, speed_rate): + """Change the audio speed by linear interpolation. + + Note that this is an in-place transformation. + + :param speed_rate: Rate of speed change: + speed_rate > 1.0, speed up the audio; + speed_rate = 1.0, unchanged; + speed_rate < 1.0, slow down the audio; + speed_rate <= 0.0, not allowed, raise ValueError. + :type speed_rate: float + :raises ValueError: If speed_rate <= 0.0. + """ + if speed_rate <= 0: + raise ValueError("speed_rate should be greater than zero.") + old_length = self._samples.shape[0] + new_length = int(old_length / speed_rate) + old_indices = np.arange(old_length) + new_indices = np.linspace(start=0, stop=old_length, num=new_length) + self._samples = np.interp(new_indices, old_indices, self._samples) + + def normalize(self, target_sample_rate): + raise NotImplementedError() def resample(self, target_sample_rate): raise NotImplementedError() - def change_speed(self, rate): + def pad_silence(self, duration, sides='both'): + raise NotImplementedError() + + def subsegment(self, start_sec=None, end_sec=None): + raise NotImplementedError() + + def convolve(self, filter, allow_resample=False): + raise NotImplementedError() + + def convolve_and_normalize(self, filter, allow_resample=False): raise NotImplementedError() @property def samples(self): + """Return audio samples. + + :return: Audio samples. + :rtype: ndarray + """ return self._samples.copy() @property def sample_rate(self): + """Return audio sample rate. + + :return: Audio sample rate. + :rtype: int + """ return self._sample_rate @property - def duration(self): - return self._samples.shape[0] / float(self._sample_rate) - + def num_samples(self): + """Return number of samples. -class SpeechSegment(AudioSegment): - def __init__(self, samples, sample_rate, transcript): - AudioSegment.__init__(self, samples, sample_rate) - self._transcript = transcript + :return: Number of samples. + :rtype: int + """ + return self._samples.shape(0) - @classmethod - def from_file(cls, filepath, transcript): - audio = AudioSegment.from_file(filepath) - return cls(audio.samples, audio.sample_rate, transcript) + @property + def duration(self): + """Return audio duration. - @classmethod - def from_bytes(cls, bytes, transcript): - audio = AudioSegment.from_bytes(bytes) - return cls(audio.samples, audio.sample_rate, transcript) + :return: Audio duration in seconds. + :rtype: float + """ + return self._samples.shape[0] / float(self._sample_rate) @property - def transcript(self): - return self._transcript + def rms_db(self): + """Return root mean square energy of the audio in decibels. + + :return: Root mean square energy in decibels. + :rtype: float + """ + # square root => multiply by 10 instead of 20 for dBs + mean_square = np.mean(self._samples**2) + return 10 * np.log10(mean_square) + + def _convert_samples_to_float32(self, samples): + """Convert sample type to float32. + + Audio sample type is usually integer or float-point. + Integers will be scaled to [-1, 1] in float32. + """ + float32_samples = samples.astype('float32') + if samples.dtype in np.sctypes['int']: + bits = np.iinfo(samples.dtype).bits + float32_samples *= (1. / 2**(bits - 1)) + elif samples.dtype in np.sctypes['float']: + pass + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return float32_samples + + def _convert_samples_from_float32(self, samples, dtype): + """Convert sample type from float32 to dtype. + + Audio sample type is usually integer or float-point. For integer + type, float32 will be rescaled from [-1, 1] to the maximum range + supported by the integer type. + + This is for writing a audio file. + """ + dtype = np.dtype(dtype) + output_samples = samples.copy() + if dtype in np.sctypes['int']: + bits = np.iinfo(dtype).bits + output_samples *= (2**(bits - 1) / 1.) + min_val = np.iinfo(dtype).min + max_val = np.iinfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + elif samples.dtype in np.sctypes['float']: + min_val = np.finfo(dtype).min + max_val = np.finfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return output_samples.astype(dtype) diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py index 3a1426a1f..abe1a0ec8 100755 --- a/data_utils/augmentor/augmentation.py +++ b/data_utils/augmentor/augmentation.py @@ -1,38 +1,80 @@ +"""Contains the data augmentation pipeline.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import json import random -from data_utils.augmentor.volumn_perturb import VolumnPerturbAugmentor +from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor class AugmentationPipeline(object): + """Build a pre-processing pipeline with various augmentation models.Such a + data augmentation pipeline is oftern leveraged to augment the training + samples to make the model invariant to certain types of perturbations in the + real world, improving model's generalization ability. + + The pipeline is built according the the augmentation configuration in json + string, e.g. + + .. code-block:: + + '[{"type": "volume", + "params": {"min_gain_dBFS": -15, + "max_gain_dBFS": 15}, + "prob": 0.5}, + {"type": "speed", + "params": {"min_speed_rate": 0.8, + "max_speed_rate": 1.2}, + "prob": 0.5} + ]' + + This augmentation configuration inserts two augmentation models + into the pipeline, with one is VolumePerturbAugmentor and the other + SpeedPerturbAugmentor. "prob" indicates the probability of the current + augmentor to take effect. + + :param augmentation_config: Augmentation configuration in json string. + :type augmentation_config: str + :param random_seed: Random seed. + :type random_seed: int + :raises ValueError: If the augmentation json config is in incorrect format". + """ + def __init__(self, augmentation_config, random_seed=0): self._rng = random.Random(random_seed) self._augmentors, self._rates = self._parse_pipeline_from( augmentation_config) def transform_audio(self, audio_segment): + """Run the pre-processing pipeline for data augmentation. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to process. + :type audio_segment: AudioSegmenet|SpeechSegment + """ for augmentor, rate in zip(self._augmentors, self._rates): if self._rng.uniform(0., 1.) <= rate: augmentor.transform_audio(audio_segment) def _parse_pipeline_from(self, config_json): + """Parse the config json to build a augmentation pipelien.""" try: configs = json.loads(config_json) + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in configs + ] + rates = [config["prob"] for config in configs] except Exception as e: - raise ValueError("Augmentation config json format error: " + raise ValueError("Failed to parse the augmentation config json: " "%s" % str(e)) - augmentors = [ - self._get_augmentor(config["type"], config["params"]) - for config in configs - ] - rates = [config["rate"] for config in configs] return augmentors, rates def _get_augmentor(self, augmentor_type, params): - if augmentor_type == "volumn": - return VolumnPerturbAugmentor(self._rng, **params) + """Return an augmentation model by the type name, and pass in params.""" + if augmentor_type == "volume": + return VolumePerturbAugmentor(self._rng, **params) else: raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py index e801b9b18..a323165aa 100755 --- a/data_utils/augmentor/base.py +++ b/data_utils/augmentor/base.py @@ -1,3 +1,4 @@ +"""Contains the abstract base class for augmentation models.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -6,6 +7,11 @@ from abc import ABCMeta, abstractmethod class AugmentorBase(object): + """Abstract base class for augmentation model (augmentor) class. + All augmentor classes should inherit from this class, and implement the + following abstract methods. + """ + __metaclass__ = ABCMeta @abstractmethod @@ -14,4 +20,14 @@ class AugmentorBase(object): @abstractmethod def transform_audio(self, audio_segment): + """Adds various effects to the input audio segment. Such effects + will augment the training data to make the model invariant to certain + types of perturbations in the real world, improving model's + generalization ability. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ pass diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py new file mode 100755 index 000000000..a5a9f6cad --- /dev/null +++ b/data_utils/augmentor/volume_perturb.py @@ -0,0 +1,40 @@ +"""Contains the volume perturb augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class VolumePerturbAugmentor(AugmentorBase): + """Augmentation model for adding random volume perturbation. + + This is used for multi-loudness training of PCEN. See + + https://arxiv.org/pdf/1607.05666v1.pdf + + for more details. + + :param rng: Random generator object. + :type rng: random.Random + :param min_gain_dBFS: Minimal gain in dBFS. + :type min_gain_dBFS: float + :param max_gain_dBFS: Maximal gain in dBFS. + :type max_gain_dBFS: float + """ + + def __init__(self, rng, min_gain_dBFS, max_gain_dBFS): + self._min_gain_dBFS = min_gain_dBFS + self._max_gain_dBFS = max_gain_dBFS + self._rng = rng + + def transform_audio(self, audio_segment): + """Change audio loadness. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) + audio_segment.apply_gain(gain) diff --git a/data_utils/augmentor/volumn_perturb.py b/data_utils/augmentor/volumn_perturb.py deleted file mode 100755 index dd1ba53a7..000000000 --- a/data_utils/augmentor/volumn_perturb.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import random -from data_utils.augmentor.base import AugmentorBase - - -class VolumnPerturbAugmentor(AugmentorBase): - def __init__(self, rng, min_gain_dBFS, max_gain_dBFS): - self._min_gain_dBFS = min_gain_dBFS - self._max_gain_dBFS = max_gain_dBFS - self._rng = rng - - def transform_audio(self, audio_segment): - gain = self._rng.uniform(min_gain_dBFS, max_gain_dBFS) - audio_segment.apply_gain(gain) diff --git a/data_utils/data.py b/data_utils/data.py index 630007932..48e03fe85 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -1,8 +1,6 @@ +"""Contains data generator for orgnaizing various audio data preprocessing +pipeline and offering data reader interface of PaddlePaddle requirements. """ - Providing basic audio data preprocessing pipeline, and offering - both instance-level and batch-level data reader interfaces. -""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -13,42 +11,41 @@ import paddle.v2 as paddle from data_utils import utils from data_utils.augmentor.augmentation import AugmentationPipeline from data_utils.featurizer.speech_featurizer import SpeechFeaturizer -from data_utils.audio import SpeechSegment +from data_utils.speech import SpeechSegment from data_utils.normalizer import FeatureNormalizer class DataGenerator(object): """ DataGenerator provides basic audio data preprocessing pipeline, and offers - both instance-level and batch-level data reader interfaces. - Normalized FFT are used as audio features here. + data reader interfaces of PaddlePaddle requirements. - :param vocab_filepath: Vocabulary file path for indexing tokenized - transcriptions. + :param vocab_filepath: Vocabulary filepath for indexing tokenized + transcripts. :type vocab_filepath: basestring - :param normalizer_manifest_path: Manifest filepath for collecting feature - normalization statistics, e.g. mean, std. - :type normalizer_manifest_path: basestring - :param normalizer_num_samples: Number of instances sampled for collecting - feature normalization statistics. - Default is 100. - :type normalizer_num_samples: int - :param max_duration: Audio clips with duration (in seconds) greater than - this will be discarded. Default is 20.0. + :param mean_std_filepath: File containing the pre-computed mean and stddev. + :type mean_std_filepath: None|basestring + :param augmentation_config: Augmentation configuration in json string. + Details see AugmentationPipeline.__doc__. + :type augmentation_config: str + :param max_duration: Audio with duration (in seconds) greater than + this will be discarded. :type max_duration: float - :param min_duration: Audio clips with duration (in seconds) smaller than - this will be discarded. Default is 0.0. + :param min_duration: Audio with duration (in seconds) smaller than + this will be discarded. :type min_duration: float :param stride_ms: Striding size (in milliseconds) for generating frames. - Default is 10.0. :type stride_ms: float - :param window_ms: Window size (in milliseconds) for frames. Default is 20.0. + :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_frequency: Maximun frequency for FFT features. FFT features of - frequency larger than this will be discarded. - If set None, all features will be kept. - Default is None. - :type max_frequency: float + :param max_freq: Used when specgram_type is 'linear', only FFT bins + corresponding to frequencies between [0, max_freq] are + returned. + :types max_freq: None|float + :param specgram_type: Specgram feature type. Options: 'linear'. + :type specgram_type: str + :param random_seed: Random seed. + :type random_seed: int """ def __init__(self, @@ -60,6 +57,7 @@ class DataGenerator(object): stride_ms=10.0, window_ms=20.0, max_freq=None, + specgram_type='linear', random_seed=0): self._max_duration = max_duration self._min_duration = min_duration @@ -68,46 +66,49 @@ class DataGenerator(object): augmentation_config=augmentation_config, random_seed=random_seed) self._speech_featurizer = SpeechFeaturizer( vocab_filepath=vocab_filepath, + specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, - max_freq=max_freq, - random_seed=random_seed) + max_freq=max_freq) self._rng = random.Random(random_seed) self._epoch = 0 def batch_reader_creator(self, manifest_path, batch_size, + min_batch_size=1, padding_to=-1, flatten=False, sortagrad=False, batch_shuffle=False): """ - Batch data reader creator for audio data. Creat a callable function to - produce batches of data. + Batch data reader creator for audio data. Return a callable generator + function to produce batches of data. - Audio features will be padded with zeros to make each instance in the - batch to share the same audio feature shape. + Audio features within one batch will be padded with zeros to have the + same shape, or a user-defined shape. - :param manifest_path: Filepath of manifest for audio clip files. + :param manifest_path: Filepath of manifest for audio files. :type manifest_path: basestring - :param batch_size: Instance number in a batch. + :param batch_size: Number of instances in a batch. :type batch_size: int - :param padding_to: If set -1, the maximun column numbers in the batch - will be used as the target size for padding. - Otherwise, `padding_to` will be the target size. - Default is -1. + :param min_batch_size: Any batch with batch size smaller than this will + be discarded. (To be deprecated in the future.) + :type min_batch_size: int + :param padding_to: If set -1, the maximun shape in the batch + will be used as the target shape for padding. + Otherwise, `padding_to` will be the target shape. :type padding_to: int - :param flatten: If set True, audio data will be flatten to be a 1-dim - ndarray. Otherwise, 2-dim ndarray. Default is False. + :param flatten: If set True, audio features will be flatten to 1darray. :type flatten: bool - :param sortagrad: Sort the audio clips by duration in the first epoc - if set True. + :param sortagrad: If set True, sort the instances by audio duration + in the first epoch for speed up training. :type sortagrad: bool - :param batch_shuffle: Shuffle the audio clips if set True. It is - not a thorough instance-wise shuffle, but a - specific batch-wise shuffle. For more details, - please see `_batch_shuffle` function. + :param batch_shuffle: If set True, instances are batch-wise shuffled. + For more details, please see + ``_batch_shuffle.__doc__``. + If sortagrad is True, batch_shuffle is disabled + for the first epoch. :type batch_shuffle: bool :return: Batch reader function, producing batches of data when called. :rtype: callable @@ -132,7 +133,7 @@ class DataGenerator(object): if len(batch) == batch_size: yield self._padding_batch(batch, padding_to, flatten) batch = [] - if len(batch) > 0: + if len(batch) >= min_batch_size: yield self._padding_batch(batch, padding_to, flatten) self._epoch += 1 @@ -140,20 +141,33 @@ class DataGenerator(object): @property def feeding(self): - """Returns data_reader's feeding dict.""" + """Returns data reader's feeding dict. + + :return: Data feeding dict. + :rtype: dict + """ return {"audio_spectrogram": 0, "transcript_text": 1} @property def vocab_size(self): - """Returns vocabulary size.""" + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ return self._speech_featurizer.vocab_size @property def vocab_list(self): - """Returns vocabulary list.""" + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ return self._speech_featurizer.vocab_list def _process_utterance(self, filename, transcript): + """Load, augment, featurize and normalize for speech data.""" speech_segment = SpeechSegment.from_file(filename, transcript) self._augmentation_pipeline.transform_audio(speech_segment) specgram, text_ids = self._speech_featurizer.featurize(speech_segment) @@ -162,16 +176,11 @@ class DataGenerator(object): def _instance_reader_creator(self, manifest): """ - Instance reader creator for audio data. Creat a callable function to - produce instances of data. + Instance reader creator. Create a callable function to produce + instances of data. - Instance: a tuple of a numpy ndarray of audio spectrogram and a list of - tokenized and indexed transcription text. - - :param manifest: Filepath of manifest for audio clip files. - :type manifest: basestring - :return: Data reader function. - :rtype: callable + Instance: a tuple of ndarray of audio spectrogram and a list of + token indices for transcript. """ def reader(): @@ -183,24 +192,22 @@ class DataGenerator(object): def _padding_batch(self, batch, padding_to=-1, flatten=False): """ - Padding audio part of features (only in the time axis -- column axis) - with zeros, to make each instance in the batch share the same - audio feature shape. + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one bach. - If `padding_to` is set -1, the maximun column numbers in the batch will - be used as the target size. Otherwise, `padding_to` will be the target - size. Default is -1. + If ``padding_to`` is -1, the maximun shape in the batch will be used + as the target shape for padding. Otherwise, `padding_to` will be the + target shape (only refers to the second axis). - If `flatten` is set True, audio data will be flatten to be a 1-dim - ndarray. Default is False. + If `flatten` is True, features will be flatten to 1darray. """ new_batch = [] # get target shape max_length = max([audio.shape[1] for audio, text in batch]) if padding_to != -1: if padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be greater" - " or equal to the original instance length.") + raise ValueError("If padding_to is not -1, it should be larger " + "than any instance's shape in the batch") max_length = padding_to # padding for audio, text in batch: @@ -212,28 +219,21 @@ class DataGenerator(object): return new_batch def _batch_shuffle(self, manifest, batch_size): - """ - The instances have different lengths and they cannot be - combined into a single matrix multiplication. It usually - sorts the training examples by length and combines only - similarly-sized instances into minibatches, pads with - silence when necessary so that all instances in a batch - have the same length. This batch shuffle fuction is used - to make similarly-sized instances into minibatches and - make a batch-wise shuffle. + """Put similarly-sized instances into minibatches for better efficiency + and make a batch-wise shuffle. 1. Sort the audio clips by duration. 2. Generate a random number `k`, k in [0, batch_size). - 3. Randomly remove `k` instances in order to make different mini-batches, - then make minibatches and each minibatch size is batch_size. + 3. Randomly shift `k` instances in order to create different batches + for different epochs. Create minibatches. 4. Shuffle the minibatches. - :param manifest: manifest file. + :param manifest: Manifest contents. List of dict. :type manifest: list :param batch_size: Batch size. This size is also used for generate a random number for batch shuffle. :type batch_size: int - :return: batch shuffled mainifest. + :return: Batch shuffled mainifest. :rtype: list """ manifest.sort(key=lambda x: x["duration"]) diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py index 5d9c68836..9f9d4e505 100755 --- a/data_utils/featurizer/audio_featurizer.py +++ b/data_utils/featurizer/audio_featurizer.py @@ -1,30 +1,54 @@ +"""Contains the audio featurizer class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -import random from data_utils import utils from data_utils.audio import AudioSegment class AudioFeaturizer(object): + """Audio featurizer, for extracting features from audio contents of + AudioSegment or SpeechSegment. + + Currently, it only supports feature type of linear spectrogram. + + :param specgram_type: Specgram feature type. Options: 'linear'. + :type specgram_type: str + :param stride_ms: Striding size (in milliseconds) for generating frames. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for generating frames. + :type window_ms: float + :param max_freq: Used when specgram_type is 'linear', only FFT bins + corresponding to frequencies between [0, max_freq] are + returned. + :types max_freq: None|float + """ + def __init__(self, specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None, - random_seed=0): + max_freq=None): self._specgram_type = specgram_type self._stride_ms = stride_ms self._window_ms = window_ms self._max_freq = max_freq def featurize(self, audio_segment): + """Extract audio features from AudioSegment or SpeechSegment. + + :param audio_segment: Audio/speech segment to extract features from. + :type audio_segment: AudioSegment|SpeechSegment + :return: Spectrogram audio feature in 2darray. + :rtype: ndarray + """ return self._compute_specgram(audio_segment.samples, audio_segment.sample_rate) def _compute_specgram(self, samples, sample_rate): + """Extract various audio features.""" if self._specgram_type == 'linear': return self._compute_linear_specgram( samples, sample_rate, self._stride_ms, self._window_ms, @@ -40,9 +64,7 @@ class AudioFeaturizer(object): window_ms=20.0, max_freq=None, eps=1e-14): - """Laod audio data and calculate the log of spectrogram by FFT. - Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech - """ + """Compute the linear spectrogram from FFT energy.""" if max_freq is None: max_freq = sample_rate / 2 if max_freq > sample_rate / 2: @@ -62,9 +84,7 @@ class AudioFeaturizer(object): return np.log(specgram[:ind, :] + eps) def _specgram_real(self, samples, window_size, stride_size, sample_rate): - """Compute the spectrogram by FFT for a discrete real signal. - Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech - """ + """Compute the spectrogram for samples from a real signal.""" # extract strided windows truncate_size = (len(samples) - window_size) % stride_size samples = samples[:len(samples) - truncate_size] diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py index 06af7a026..770204559 100755 --- a/data_utils/featurizer/speech_featurizer.py +++ b/data_utils/featurizer/speech_featurizer.py @@ -1,3 +1,4 @@ +"""Contains the speech featurizer class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -7,26 +8,70 @@ from data_utils.featurizer.text_featurizer import TextFeaturizer class SpeechFeaturizer(object): + """Speech featurizer, for extracting features from both audio and transcript + contents of SpeechSegment. + + Currently, for audio parts, it only supports feature type of linear + spectrogram; for transcript parts, it only supports char-level tokenizing + and conversion into a list of token indices. Note that the token indexing + order follows the given vocabulary file. + + :param vocab_filepath: Filepath to load vocabulary for token indices + conversion. + :type specgram_type: basestring + :param specgram_type: Specgram feature type. Options: 'linear'. + :type specgram_type: str + :param stride_ms: Striding size (in milliseconds) for generating frames. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for generating frames. + :type window_ms: float + :param max_freq: Used when specgram_type is 'linear', only FFT bins + corresponding to frequencies between [0, max_freq] are + returned. + :types max_freq: None|float + """ + def __init__(self, vocab_filepath, specgram_type='linear', stride_ms=10.0, window_ms=20.0, - max_freq=None, - random_seed=0): - self._audio_featurizer = AudioFeaturizer( - specgram_type, stride_ms, window_ms, max_freq, random_seed) + max_freq=None): + self._audio_featurizer = AudioFeaturizer(specgram_type, stride_ms, + window_ms, max_freq) self._text_featurizer = TextFeaturizer(vocab_filepath) def featurize(self, speech_segment): + """Extract features for speech segment. + + 1. For audio parts, extract the audio features. + 2. For transcript parts, convert text string to a list of token indices + in char-level. + + :param audio_segment: Speech segment to extract features from. + :type audio_segment: SpeechSegment + :return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of + char-level token indices. + :rtype: tuple + """ audio_feature = self._audio_featurizer.featurize(speech_segment) - text_ids = self._text_featurizer.text2ids(speech_segment.transcript) + text_ids = self._text_featurizer.featurize(speech_segment.transcript) return audio_feature, text_ids @property def vocab_size(self): + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ return self._text_featurizer.vocab_size @property def vocab_list(self): + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ return self._text_featurizer.vocab_list diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py index 7e4b69d7b..4f9a49b59 100755 --- a/data_utils/featurizer/text_featurizer.py +++ b/data_utils/featurizer/text_featurizer.py @@ -1,3 +1,4 @@ +"""Contains the text featurizer class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -6,26 +7,53 @@ import os class TextFeaturizer(object): + """Text featurizer, for processing or extracting features from text. + + Currently, it only supports char-level tokenizing and conversion into + a list of token indices. Note that the token indexing order follows the + given vocabulary file. + + :param vocab_filepath: Filepath to load vocabulary for token indices + conversion. + :type specgram_type: basestring + """ + def __init__(self, vocab_filepath): self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( vocab_filepath) - def text2ids(self, text): + def featurize(self, text): + """Convert text string to a list of token indices in char-level.Note + that the token indexing order follows the given vocabulary file. + + :param text: Text to process. + :type text: basestring + :return: List of char-level token indices. + :rtype: list + """ tokens = self._char_tokenize(text) return [self._vocab_dict[token] for token in tokens] - def ids2text(self, ids): - return ''.join([self._vocab_list[id] for id in ids]) - @property def vocab_size(self): + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ return len(self._vocab_list) @property def vocab_list(self): + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ return self._vocab_list def _char_tokenize(self, text): + """Character tokenizer.""" return list(text.strip()) def _load_vocabulary_from_file(self, vocab_filepath): diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py index 364600af8..c123d25d2 100755 --- a/data_utils/normalizer.py +++ b/data_utils/normalizer.py @@ -1,3 +1,4 @@ +"""Contains feature normalizers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -9,6 +10,28 @@ from data_utils.audio import AudioSegment class FeatureNormalizer(object): + """Feature normalizer. Normalize features to be of zero mean and unit + stddev. + + if mean_std_filepath is provided (not None), the normalizer will directly + initilize from the file. Otherwise, both manifest_path and featurize_func + should be given for on-the-fly mean and stddev computing. + + :param mean_std_filepath: File containing the pre-computed mean and stddev. + :type mean_std_filepath: None|basestring + :param manifest_path: Manifest of instances for computing mean and stddev. + :type meanifest_path: None|basestring + :param featurize_func: Function to extract features. It should be callable + with ``featurize_func(audio_segment)``. + :type featurize_func: None|callable + :param num_samples: Number of random samples for computing mean and stddev. + :type num_samples: int + :param random_seed: Random seed for sampling instances. + :type random_seed: int + :raises ValueError: If both mean_std_filepath and manifest_path + (or both mean_std_filepath and featurize_func) are None. + """ + def __init__(self, mean_std_filepath, manifest_path=None, @@ -25,18 +48,33 @@ class FeatureNormalizer(object): self._read_mean_std_from_file(mean_std_filepath) def apply(self, features, eps=1e-14): - """Normalize features to be of zero mean and unit stddev.""" + """Normalize features to be of zero mean and unit stddev. + + :param features: Input features to be normalized. + :type features: ndarray + :param eps: added to stddev to provide numerical stablibity. + :type eps: float + :return: Normalized features. + :rtype: ndarray + """ return (features - self._mean) / (self._std + eps) def write_to_file(self, filepath): + """Write the mean and stddev to the file. + + :param filepath: File to write mean and stddev. + :type filepath: basestring + """ np.savez(filepath, mean=self._mean, std=self._std) def _read_mean_std_from_file(self, filepath): + """Load mean and std from file.""" npzfile = np.load(filepath) self._mean = npzfile["mean"] self._std = npzfile["std"] def _compute_mean_std(self, manifest_path, featurize_func, num_samples): + """Compute mean and std from randomly sampled instances.""" manifest = utils.read_manifest(manifest_path) sampled_manifest = self._rng.sample(manifest, num_samples) features = [] diff --git a/data_utils/speech.py b/data_utils/speech.py new file mode 100755 index 000000000..48db595b4 --- /dev/null +++ b/data_utils/speech.py @@ -0,0 +1,75 @@ +"""Contains the speech segment class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.audio import AudioSegment + + +class SpeechSegment(AudioSegment): + """Speech segment abstraction, a subclass of AudioSegment, + with an additional transcript. + + :param samples: Audio samples [num_samples x num_channels]. + :type samples: ndarray.float32 + :param sample_rate: Audio sample rate. + :type sample_rate: int + :param transcript: Transcript text for the speech. + :type transript: basestring + :raises TypeError: If the sample data type is not float or int. + """ + + def __init__(self, samples, sample_rate, transcript): + AudioSegment.__init__(self, samples, sample_rate) + self._transcript = transcript + + def __eq__(self, other): + """Return whether two objects are equal. + """ + if not AudioSegment.__eq__(self, other): + return False + if self._transcript != other._transcript: + return False + return True + + def __ne__(self, other): + """Return whether two objects are unequal.""" + return not self.__eq__(other) + + @classmethod + def from_file(cls, filepath, transcript): + """Create speech segment from audio file and corresponding transcript. + + :param filepath: Filepath or file object to audio file. + :type filepath: basestring|file + :param transcript: Transcript text for the speech. + :type transript: basestring + :return: Audio segment instance. + :rtype: AudioSegment + """ + audio = AudioSegment.from_file(filepath) + return cls(audio.samples, audio.sample_rate, transcript) + + @classmethod + def from_bytes(cls, bytes, transcript): + """Create speech segment from a byte string and corresponding + transcript. + + :param bytes: Byte string containing audio samples. + :type bytes: str + :param transcript: Transcript text for the speech. + :type transript: basestring + :return: Audio segment instance. + :rtype: AudioSegment + """ + audio = AudioSegment.from_bytes(bytes) + return cls(audio.samples, audio.sample_rate, transcript) + + @property + def transcript(self): + """Return the transcript text. + + :return: Transcript text for the speech. + :rtype: basestring + """ + return self._transcript diff --git a/data_utils/utils.py b/data_utils/utils.py index 2a916b54f..3f1165718 100755 --- a/data_utils/utils.py +++ b/data_utils/utils.py @@ -1,3 +1,4 @@ +"""Contains data helper functions.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -6,7 +7,21 @@ import json def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0): - """Load and parse manifest file.""" + """Load and parse manifest file. + + Instances with durations outside [min_duration, max_duration] will be + filtered out. + + :param manifest_path: Manifest file to load and parse. + :type manifest_path: basestring + :param max_duration: Maximal duration in seconds for instance filter. + :type max_duration: float + :param min_duration: Minimal duration in seconds for instance filter. + :type min_duration: float + :return: Manifest parsing results. List of dict. + :rtype: list + :raises IOError: If failed to parse the manifest. + """ manifest = [] for json_line in open(manifest_path): try: diff --git a/datasets/librispeech/librispeech.py b/datasets/librispeech/librispeech.py index 1ba2a4422..faf038cc1 100644 --- a/datasets/librispeech/librispeech.py +++ b/datasets/librispeech/librispeech.py @@ -1,13 +1,14 @@ -""" - Download, unpack and create manifest json files for the Librespeech dataset. +"""Prepare Librispeech ASR datasets. - A manifest is a json file summarizing filelist in a data set, with each line - containing the meta data (i.e. audio filepath, transcription text, audio - duration) of each audio file in the data set. +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. """ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -import paddle.v2 as paddle -from paddle.v2.dataset.common import md5file import distutils.util import os import wget @@ -15,6 +16,7 @@ import tarfile import argparse import soundfile import json +from paddle.v2.dataset.common import md5file DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') diff --git a/decoder.py b/decoder.py index 7c4b95263..8314885ce 100755 --- a/decoder.py +++ b/decoder.py @@ -1,9 +1,10 @@ -""" - CTC-like decoder utilitis. -""" +"""Contains various CTC decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -from itertools import groupby import numpy as np +from itertools import groupby def ctc_best_path_decode(probs_seq, vocabulary): diff --git a/infer.py b/infer.py index eb31254ce..f7c99df11 100644 --- a/infer.py +++ b/infer.py @@ -1,7 +1,4 @@ -""" - Inference for a simplifed version of Baidu DeepSpeech2 model. -""" - +"""Inferer for DeepSpeech2 model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function diff --git a/model.py b/model.py index 13ff829b9..cb0b4ecbb 100644 --- a/model.py +++ b/model.py @@ -1,11 +1,10 @@ -""" - A simplifed version of Baidu DeepSpeech2 model. -""" +"""Contains DeepSpeech2 model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import paddle.v2 as paddle -#TODO: add bidirectional rnn. - def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, padding, act): diff --git a/train.py b/train.py index c6aa97527..7ac4626f4 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,4 @@ -""" - Trainer for a simplifed version of Baidu DeepSpeech2 model. -""" - +"""Trainer for DeepSpeech2 model.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -164,7 +161,7 @@ def train(): print("\nPass: %d, Batch: %d, TrainCost: %f" % (event.pass_id, event.batch_id, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 - with gzip.open("params.tar.gz", 'w') as f: + with gzip.open("params_tmp.tar.gz", 'w') as f: parameters.to_tar(f) else: sys.stdout.write('.') From 1cef98f2101b37c9ff63a02ed6955c99f5edb09e Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Tue, 13 Jun 2017 23:33:38 +0800 Subject: [PATCH 3/5] Update README.md for DS2. --- README.md | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7a372e9be..23e0b412b 100644 --- a/README.md +++ b/README.md @@ -16,34 +16,48 @@ For some machines, we also need to install libsndfile1. Details to be added. ### Preparing Data ``` -cd data -python librispeech.py -cat manifest.libri.train-* > manifest.libri.train-all +cd datasets +sh run_all.sh cd .. ``` -After running librispeech.py, we have several "manifest" json files named with a prefix `manifest.libri.`. A manifest file summarizes a speech data set, with each line containing the meta data (i.e. audio filepath, transcription text, audio duration) of each audio file within the data set, in json format. +`sh run_all.sh` prepares all ASR datasets (currently, only LibriSpeech available). After running, we have several summarization manifest files in json-format. -By `cat manifest.libri.train-* > manifest.libri.train-all`, we simply merge the three seperate sample sets of LibriSpeech (train-clean-100, train-clean-360, train-other-500) into one training set. This is a simple way for merging different data sets. +A manifest file summarizes a speech data set, with each line containing the meta data (i.e. audio filepath, transcript text, audio duration) of each audio file within the data set, in json format. Manifest file serves as an interface informing our system of where and what to read the speech samples. + + +More help for arguments: + +``` +python datasets/librispeech/librispeech.py --help +``` + +### Preparing for Training + +``` +python compute_mean_std.py +``` + +`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. More help for arguments: ``` -python librispeech.py --help +python compute_mean_std.py --help ``` -### Traininig +### Training For GPU Training: ``` -CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 --train_manifest_path ./data/manifest.libri.train-all +CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --trainer_count 4 ``` For CPU Training: ``` -python train.py --trainer_count 8 --use_gpu False -- train_manifest_path ./data/manifest.libri.train-all +python train.py --trainer_count 8 --use_gpu False ``` More help for arguments: @@ -55,7 +69,7 @@ python train.py --help ### Inferencing ``` -python infer.py +CUDA_VISIBLE_DEVICES=0 python infer.py ``` More help for arguments: From 04a225ae4f8f7f4af068207627bb65b93bdd5fe6 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Wed, 14 Jun 2017 18:14:50 +0800 Subject: [PATCH 4/5] Enable min_batch_num in train.py and update train info print. --- compute_mean_std.py | 0 data_utils/__init__.py | 0 data_utils/audio.py | 0 data_utils/augmentor/__init__.py | 0 data_utils/augmentor/augmentation.py | 0 data_utils/augmentor/base.py | 0 data_utils/augmentor/volume_perturb.py | 0 data_utils/featurizer/__init__.py | 0 data_utils/featurizer/audio_featurizer.py | 0 data_utils/featurizer/speech_featurizer.py | 0 data_utils/featurizer/text_featurizer.py | 0 data_utils/normalizer.py | 0 data_utils/speech.py | 0 data_utils/utils.py | 0 datasets/run_all.sh | 0 decoder.py | 0 train.py | 10 ++++++---- 17 files changed, 6 insertions(+), 4 deletions(-) mode change 100755 => 100644 compute_mean_std.py mode change 100755 => 100644 data_utils/__init__.py mode change 100755 => 100644 data_utils/audio.py mode change 100755 => 100644 data_utils/augmentor/__init__.py mode change 100755 => 100644 data_utils/augmentor/augmentation.py mode change 100755 => 100644 data_utils/augmentor/base.py mode change 100755 => 100644 data_utils/augmentor/volume_perturb.py mode change 100755 => 100644 data_utils/featurizer/__init__.py mode change 100755 => 100644 data_utils/featurizer/audio_featurizer.py mode change 100755 => 100644 data_utils/featurizer/speech_featurizer.py mode change 100755 => 100644 data_utils/featurizer/text_featurizer.py mode change 100755 => 100644 data_utils/normalizer.py mode change 100755 => 100644 data_utils/speech.py mode change 100755 => 100644 data_utils/utils.py mode change 100755 => 100644 datasets/run_all.sh mode change 100755 => 100644 decoder.py diff --git a/compute_mean_std.py b/compute_mean_std.py old mode 100755 new mode 100644 diff --git a/data_utils/__init__.py b/data_utils/__init__.py old mode 100755 new mode 100644 diff --git a/data_utils/audio.py b/data_utils/audio.py old mode 100755 new mode 100644 diff --git a/data_utils/augmentor/__init__.py b/data_utils/augmentor/__init__.py old mode 100755 new mode 100644 diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py old mode 100755 new mode 100644 diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py old mode 100755 new mode 100644 diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py old mode 100755 new mode 100644 diff --git a/data_utils/featurizer/__init__.py b/data_utils/featurizer/__init__.py old mode 100755 new mode 100644 diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py old mode 100755 new mode 100644 diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py old mode 100755 new mode 100644 diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py old mode 100755 new mode 100644 diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py old mode 100755 new mode 100644 diff --git a/data_utils/speech.py b/data_utils/speech.py old mode 100755 new mode 100644 diff --git a/data_utils/utils.py b/data_utils/utils.py old mode 100755 new mode 100644 diff --git a/datasets/run_all.sh b/datasets/run_all.sh old mode 100755 new mode 100644 diff --git a/decoder.py b/decoder.py old mode 100755 new mode 100644 diff --git a/train.py b/train.py index 7ac4626f4..6074aa358 100644 --- a/train.py +++ b/train.py @@ -143,11 +143,13 @@ def train(): train_batch_reader = train_generator.batch_reader_creator( manifest_path=args.train_manifest_path, batch_size=args.batch_size, + min_batch_size=args.trainer_count, sortagrad=args.use_sortagrad if args.init_model_path is None else False, batch_shuffle=True) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, + min_batch_size=1, # must be 1, but will have errors. sortagrad=False, batch_shuffle=False) @@ -157,11 +159,11 @@ def train(): if isinstance(event, paddle.event.EndIteration): cost_sum += event.cost cost_counter += 1 - if event.batch_id % 50 == 0: - print("\nPass: %d, Batch: %d, TrainCost: %f" % - (event.pass_id, event.batch_id, cost_sum / cost_counter)) + if (event.batch_id + 1) % 100 == 0: + print("\nPass: %d, Batch: %d, TrainCost: %f" % ( + event.pass_id, event.batch_id + 1, cost_sum / cost_counter)) cost_sum, cost_counter = 0.0, 0 - with gzip.open("params_tmp.tar.gz", 'w') as f: + with gzip.open("params.tar.gz", 'w') as f: parameters.to_tar(f) else: sys.stdout.write('.') From ed5f04afb86e7285cdd2d9d36dbf4b63431b5968 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Thu, 15 Jun 2017 17:05:00 +0800 Subject: [PATCH 5/5] Add shuffle type of instance_shuffle and batch_shuffle_clipped. --- data_utils/data.py | 50 ++++++++++++++++++++++------- datasets/librispeech/librispeech.py | 3 +- decoder.py | 6 ++-- infer.py | 11 +++---- train.py | 16 ++++++--- utils.py | 25 +++++++++++++++ 6 files changed, 82 insertions(+), 29 deletions(-) create mode 100644 utils.py diff --git a/data_utils/data.py b/data_utils/data.py index 48e03fe85..424343a48 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -80,7 +80,7 @@ class DataGenerator(object): padding_to=-1, flatten=False, sortagrad=False, - batch_shuffle=False): + shuffle_method="batch_shuffle"): """ Batch data reader creator for audio data. Return a callable generator function to produce batches of data. @@ -104,12 +104,22 @@ class DataGenerator(object): :param sortagrad: If set True, sort the instances by audio duration in the first epoch for speed up training. :type sortagrad: bool - :param batch_shuffle: If set True, instances are batch-wise shuffled. - For more details, please see - ``_batch_shuffle.__doc__``. - If sortagrad is True, batch_shuffle is disabled + :param shuffle_method: Shuffle method. Options: + '' or None: no shuffle. + 'instance_shuffle': instance-wise shuffle. + 'batch_shuffle': similarly-sized instances are + put into batches, and then + batch-wise shuffle the batches. + For more details, please see + ``_batch_shuffle.__doc__``. + 'batch_shuffle_clipped': 'batch_shuffle' with + head shift and tail + clipping. For more + details, please see + ``_batch_shuffle``. + If sortagrad is True, shuffle is disabled for the first epoch. - :type batch_shuffle: bool + :type shuffle_method: None|str :return: Batch reader function, producing batches of data when called. :rtype: callable """ @@ -123,8 +133,20 @@ class DataGenerator(object): # sort (by duration) or batch-wise shuffle the manifest if self._epoch == 0 and sortagrad: manifest.sort(key=lambda x: x["duration"]) - elif batch_shuffle: - manifest = self._batch_shuffle(manifest, batch_size) + else: + if shuffle_method == "batch_shuffle": + manifest = self._batch_shuffle( + manifest, batch_size, clipped=False) + elif shuffle_method == "batch_shuffle_clipped": + manifest = self._batch_shuffle( + manifest, batch_size, clipped=True) + elif shuffle_method == "instance_shuffle": + self._rng.shuffle(manifest) + elif not shuffle_method: + pass + else: + raise ValueError("Unknown shuffle method %s." % + shuffle_method) # prepare batches instance_reader = self._instance_reader_creator(manifest) batch = [] @@ -218,7 +240,7 @@ class DataGenerator(object): new_batch.append((padded_audio, text)) return new_batch - def _batch_shuffle(self, manifest, batch_size): + def _batch_shuffle(self, manifest, batch_size, clipped=False): """Put similarly-sized instances into minibatches for better efficiency and make a batch-wise shuffle. @@ -233,6 +255,9 @@ class DataGenerator(object): :param batch_size: Batch size. This size is also used for generate a random number for batch shuffle. :type batch_size: int + :param clipped: Whether to clip the heading (small shift) and trailing + (incomplete batch) instances. + :type clipped: bool :return: Batch shuffled mainifest. :rtype: list """ @@ -241,7 +266,8 @@ class DataGenerator(object): batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) self._rng.shuffle(batch_manifest) batch_manifest = list(sum(batch_manifest, ())) - res_len = len(manifest) - shift_len - len(batch_manifest) - batch_manifest.extend(manifest[-res_len:]) - batch_manifest.extend(manifest[0:shift_len]) + if not clipped: + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[-res_len:]) + batch_manifest.extend(manifest[0:shift_len]) return batch_manifest diff --git a/datasets/librispeech/librispeech.py b/datasets/librispeech/librispeech.py index faf038cc1..87e52ae4a 100644 --- a/datasets/librispeech/librispeech.py +++ b/datasets/librispeech/librispeech.py @@ -37,8 +37,7 @@ MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522" MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa" MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708" -parser = argparse.ArgumentParser( - description='Downloads and prepare LibriSpeech dataset.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--target_dir", default=DATA_HOME + "/Libri", diff --git a/decoder.py b/decoder.py index 8314885ce..77d950b8d 100644 --- a/decoder.py +++ b/decoder.py @@ -8,8 +8,7 @@ from itertools import groupby def ctc_best_path_decode(probs_seq, vocabulary): - """ - Best path decoding, also called argmax decoding or greedy decoding. + """Best path decoding, also called argmax decoding or greedy decoding. Path consisting of the most probable tokens are further post-processed to remove consecutive repetitions and all blanks. @@ -38,8 +37,7 @@ def ctc_best_path_decode(probs_seq, vocabulary): def ctc_decode(probs_seq, vocabulary, method): - """ - CTC-like sequence decoding from a sequence of likelihood probablilites. + """CTC-like sequence decoding from a sequence of likelihood probablilites. :param probs_seq: 2-D list of probabilities over the vocabulary for each character. Each element is a list of float probabilities diff --git a/infer.py b/infer.py index f7c99df11..06449ab05 100644 --- a/infer.py +++ b/infer.py @@ -10,9 +10,9 @@ import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 from decoder import ctc_decode +import utils -parser = argparse.ArgumentParser( - description='Simplified version of DeepSpeech2 inference.') +parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--num_samples", default=10, @@ -62,9 +62,7 @@ args = parser.parse_args() def infer(): - """ - Max-ctc-decoding for DeepSpeech2. - """ + """Max-ctc-decoding for DeepSpeech2.""" # initialize data generator data_generator = DataGenerator( vocab_filepath=args.vocab_filepath, @@ -98,7 +96,7 @@ def infer(): manifest_path=args.decode_manifest_path, batch_size=args.num_samples, sortagrad=False, - batch_shuffle=False) + shuffle_method=None) infer_data = batch_reader().next() # run inference @@ -123,6 +121,7 @@ def infer(): def main(): + utils.print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=1) infer() diff --git a/train.py b/train.py index 6074aa358..c60a039b6 100644 --- a/train.py +++ b/train.py @@ -12,6 +12,7 @@ import distutils.util import paddle.v2 as paddle from model import deep_speech2 from data_utils.data import DataGenerator +import utils parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( @@ -51,6 +52,12 @@ parser.add_argument( default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--shuffle_method", + default='instance_shuffle', + type=str, + help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " + "'batch_shuffle_batch'. (default: %(default)s)") parser.add_argument( "--trainer_count", default=4, @@ -93,9 +100,7 @@ args = parser.parse_args() def train(): - """ - DeepSpeech2 training. - """ + """DeepSpeech2 training.""" # initialize data generator def data_generator(): @@ -145,13 +150,13 @@ def train(): batch_size=args.batch_size, min_batch_size=args.trainer_count, sortagrad=args.use_sortagrad if args.init_model_path is None else False, - batch_shuffle=True) + shuffle_method=args.shuffle_method) test_batch_reader = test_generator.batch_reader_creator( manifest_path=args.dev_manifest_path, batch_size=args.batch_size, min_batch_size=1, # must be 1, but will have errors. sortagrad=False, - batch_shuffle=False) + shuffle_method=None) # create event handler def event_handler(event): @@ -186,6 +191,7 @@ def train(): def main(): + utils.print_arguments(args) paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) train() diff --git a/utils.py b/utils.py new file mode 100644 index 000000000..9ca363c8f --- /dev/null +++ b/utils.py @@ -0,0 +1,25 @@ +"""Contains common utility functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----- Configuration Arguments -----") + for arg, value in vars(args).iteritems(): + print("%s: %s" % (arg, value)) + print("------------------------------------")