diff --git a/README.md b/README.md new file mode 100644 index 00000000..2912ff31 --- /dev/null +++ b/README.md @@ -0,0 +1,79 @@ +# Deep Speech 2 on PaddlePaddle + +## Installation + +Please replace `$PADDLE_INSTALL_DIR` with your own paddle installation directory. + +``` +sh setup.sh +export LD_LIBRARY_PATH=$PADDLE_INSTALL_DIR/Paddle/third_party/install/warpctc/lib:$LD_LIBRARY_PATH +``` + +For some machines, we also need to install libsndfile1. Details to be added. + +## Usage + +### Preparing Data + +``` +cd datasets +sh run_all.sh +cd .. +``` + +`sh run_all.sh` prepares all ASR datasets (currently, only LibriSpeech available). After running, we have several summarization manifest files in json-format. + +A manifest file summarizes a speech data set, with each line containing the meta data (i.e. audio filepath, transcript text, audio duration) of each audio file within the data set, in json format. Manifest file serves as an interface informing our system of where and what to read the speech samples. + + +More help for arguments: + +``` +python datasets/librispeech/librispeech.py --help +``` + +### Preparing for Training + +``` +python compute_mean_std.py +``` + +`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. + +More help for arguments: + +``` +python compute_mean_std.py --help +``` + +### Training + +For GPU Training: + +``` +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py +``` + +For CPU Training: + +``` +python train.py --use_gpu False +``` + +More help for arguments: + +``` +python train.py --help +``` + +### Inferencing + +``` +CUDA_VISIBLE_DEVICES=0 python infer.py +``` + +More help for arguments: + +``` +python infer.py --help +``` diff --git a/compute_mean_std.py b/compute_mean_std.py new file mode 100644 index 00000000..9c301c93 --- /dev/null +++ b/compute_mean_std.py @@ -0,0 +1,57 @@ +"""Compute mean and std for feature normalizer, and save to file.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +from data_utils.normalizer import FeatureNormalizer +from data_utils.augmentor.augmentation import AugmentationPipeline +from data_utils.featurizer.audio_featurizer import AudioFeaturizer + +parser = argparse.ArgumentParser( + description='Computing mean and stddev for feature normalizer.') +parser.add_argument( + "--manifest_path", + default='datasets/manifest.train', + type=str, + help="Manifest path for computing normalizer's mean and stddev." + "(default: %(default)s)") +parser.add_argument( + "--num_samples", + default=2000, + type=int, + help="Number of samples for computing mean and stddev. " + "(default: %(default)s)") +parser.add_argument( + "--augmentation_config", + default='{}', + type=str, + help="Augmentation configuration in json-format. " + "(default: %(default)s)") +parser.add_argument( + "--output_file", + default='mean_std.npz', + type=str, + help="Filepath to write mean and std to (.npz)." + "(default: %(default)s)") +args = parser.parse_args() + + +def main(): + augmentation_pipeline = AugmentationPipeline(args.augmentation_config) + audio_featurizer = AudioFeaturizer() + + def augment_and_featurize(audio_segment): + augmentation_pipeline.transform_audio(audio_segment) + return audio_featurizer.featurize(audio_segment) + + normalizer = FeatureNormalizer( + mean_std_filepath=None, + manifest_path=args.manifest_path, + featurize_func=augment_and_featurize, + num_samples=args.num_samples) + normalizer.write_to_file(args.output_file) + + +if __name__ == '__main__': + main() diff --git a/data_utils/__init__.py b/data_utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_utils/audio.py b/data_utils/audio.py new file mode 100644 index 00000000..3891f5b9 --- /dev/null +++ b/data_utils/audio.py @@ -0,0 +1,622 @@ +"""Contains the audio segment class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import io +import soundfile +import resampy +from scipy import signal +import random +import copy + + +class AudioSegment(object): + """Monaural audio segment abstraction. + + :param samples: Audio samples [num_samples x num_channels]. + :type samples: ndarray.float32 + :param sample_rate: Audio sample rate. + :type sample_rate: int + :raises TypeError: If the sample data type is not float or int. + """ + + def __init__(self, samples, sample_rate): + """Create audio segment from samples. + + Samples are convert float32 internally, with int scaled to [-1, 1]. + """ + self._samples = self._convert_samples_to_float32(samples) + self._sample_rate = sample_rate + if self._samples.ndim >= 2: + self._samples = np.mean(self._samples, 1) + + def __eq__(self, other): + """Return whether two objects are equal.""" + if type(other) is not type(self): + return False + if self._sample_rate != other._sample_rate: + return False + if self._samples.shape != other._samples.shape: + return False + if np.any(self.samples != other._samples): + return False + return True + + def __ne__(self, other): + """Return whether two objects are unequal.""" + return not self.__eq__(other) + + def __str__(self): + """Return human-readable representation of segment.""" + return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, " + "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate, + self.duration, self.rms_db)) + + @classmethod + def from_file(cls, file): + """Create audio segment from audio file. + + :param filepath: Filepath or file object to audio file. + :type filepath: basestring|file + :return: Audio segment instance. + :rtype: AudioSegment + """ + samples, sample_rate = soundfile.read(file, dtype='float32') + return cls(samples, sample_rate) + + @classmethod + def slice_from_file(cls, file, start=None, end=None): + """Loads a small section of an audio without having to load + the entire file into the memory which can be incredibly wasteful. + + :param file: Input audio filepath or file object. + :type file: basestring|file + :param start: Start time in seconds. If start is negative, it wraps + around from the end. If not provided, this function + reads from the very beginning. + :type start: float + :param end: End time in seconds. If end is negative, it wraps around + from the end. If not provided, the default behvaior is + to read to the end of the file. + :type end: float + :return: AudioSegment instance of the specified slice of the input + audio file. + :rtype: AudioSegment + :raise ValueError: If start or end is incorrectly set, e.g. out of + bounds in time. + """ + sndfile = soundfile.SoundFile(file) + sample_rate = sndfile.samplerate + duration = float(len(sndfile)) / sample_rate + start = 0. if start is None else start + end = 0. if end is None else end + if start < 0.0: + start += duration + if end < 0.0: + end += duration + if start < 0.0: + raise ValueError("The slice start position (%f s) is out of " + "bounds." % start) + if end < 0.0: + raise ValueError("The slice end position (%f s) is out of bounds." % + end) + if start > end: + raise ValueError("The slice start position (%f s) is later than " + "the slice end position (%f s)." % (start, end)) + if end > duration: + raise ValueError("The slice end position (%f s) is out of bounds " + "(> %f s)" % (end, duration)) + start_frame = int(start * sample_rate) + end_frame = int(end * sample_rate) + sndfile.seek(start_frame) + data = sndfile.read(frames=end_frame - start_frame, dtype='float32') + return cls(data, sample_rate) + + @classmethod + def from_bytes(cls, bytes): + """Create audio segment from a byte string containing audio samples. + + :param bytes: Byte string containing audio samples. + :type bytes: str + :return: Audio segment instance. + :rtype: AudioSegment + """ + samples, sample_rate = soundfile.read( + io.BytesIO(bytes), dtype='float32') + return cls(samples, sample_rate) + + @classmethod + def concatenate(cls, *segments): + """Concatenate an arbitrary number of audio segments together. + + :param *segments: Input audio segments to be concatenated. + :type *segments: tuple of AudioSegment + :return: Audio segment instance as concatenating results. + :rtype: AudioSegment + :raises ValueError: If the number of segments is zero, or if the + sample_rate of any segments does not match. + :raises TypeError: If any segment is not AudioSegment instance. + """ + # Perform basic sanity-checks. + if len(segments) == 0: + raise ValueError("No audio segments are given to concatenate.") + sample_rate = segments[0]._sample_rate + for seg in segments: + if sample_rate != seg._sample_rate: + raise ValueError("Can't concatenate segments with " + "different sample rates") + if type(seg) is not cls: + raise TypeError("Only audio segments of the same type " + "can be concatenated.") + samples = np.concatenate([seg.samples for seg in segments]) + return cls(samples, sample_rate) + + @classmethod + def make_silence(cls, duration, sample_rate): + """Creates a silent audio segment of the given duration and sample rate. + + :param duration: Length of silence in seconds. + :type duration: float + :param sample_rate: Sample rate. + :type sample_rate: float + :return: Silent AudioSegment instance of the given duration. + :rtype: AudioSegment + """ + samples = np.zeros(int(duration * sample_rate)) + return cls(samples, sample_rate) + + def to_wav_file(self, filepath, dtype='float32'): + """Save audio segment to disk as wav file. + + :param filepath: WAV filepath or file object to save the + audio segment. + :type filepath: basestring|file + :param dtype: Subtype for audio file. Options: 'int16', 'int32', + 'float32', 'float64'. Default is 'float32'. + :type dtype: str + :raises TypeError: If dtype is not supported. + """ + samples = self._convert_samples_from_float32(self._samples, dtype) + subtype_map = { + 'int16': 'PCM_16', + 'int32': 'PCM_32', + 'float32': 'FLOAT', + 'float64': 'DOUBLE' + } + soundfile.write( + filepath, + samples, + self._sample_rate, + format='WAV', + subtype=subtype_map[dtype]) + + def superimpose(self, other): + """Add samples from another segment to those of this segment + (sample-wise addition, not segment concatenation). + + Note that this is an in-place transformation. + + :param other: Segment containing samples to be added in. + :type other: AudioSegments + :raise TypeError: If type of two segments don't match. + :raise ValueError: If the sample rates of the two segments are not + equal, or if the lengths of segments don't match. + """ + if type(self) != type(other): + raise TypeError("Cannot add segments of different types: %s " + "and %s." % (type(self), type(other))) + if self._sample_rate != other._sample_rate: + raise ValueError("Sample rates must match to add segments.") + if len(self._samples) != len(other._samples): + raise ValueError("Segment lengths must match to add segments.") + self._samples += other._samples + + def to_bytes(self, dtype='float32'): + """Create a byte string containing the audio content. + + :param dtype: Data type for export samples. Options: 'int16', 'int32', + 'float32', 'float64'. Default is 'float32'. + :type dtype: str + :return: Byte string containing audio content. + :rtype: str + """ + samples = self._convert_samples_from_float32(self._samples, dtype) + return samples.tostring() + + def gain_db(self, gain): + """Apply gain in decibels to samples. + + Note that this is an in-place transformation. + + :param gain: Gain in decibels to apply to samples. + :type gain: float + """ + self._samples *= 10.**(gain / 20.) + + def change_speed(self, speed_rate): + """Change the audio speed by linear interpolation. + + Note that this is an in-place transformation. + + :param speed_rate: Rate of speed change: + speed_rate > 1.0, speed up the audio; + speed_rate = 1.0, unchanged; + speed_rate < 1.0, slow down the audio; + speed_rate <= 0.0, not allowed, raise ValueError. + :type speed_rate: float + :raises ValueError: If speed_rate <= 0.0. + """ + if speed_rate <= 0: + raise ValueError("speed_rate should be greater than zero.") + old_length = self._samples.shape[0] + new_length = int(old_length / speed_rate) + old_indices = np.arange(old_length) + new_indices = np.linspace(start=0, stop=old_length, num=new_length) + self._samples = np.interp(new_indices, old_indices, self._samples) + + def normalize(self, target_db=-20, max_gain_db=300.0): + """Normalize audio to be of the desired RMS value in decibels. + + Note that this is an in-place transformation. + + :param target_db: Target RMS value in decibels. This value should be + less than 0.0 as 0.0 is full-scale audio. + :type target_db: float + :param max_gain_db: Max amount of gain in dB that can be applied for + normalization. This is to prevent nans when + attempting to normalize a signal consisting of + all zeros. + :type max_gain_db: float + :raises ValueError: If the required gain to normalize the segment to + the target_db value exceeds max_gain_db. + """ + gain = target_db - self.rms_db + if gain > max_gain_db: + raise ValueError( + "Unable to normalize segment to %f dB because the " + "the probable gain have exceeds max_gain_db (%f dB)" % + (target_db, max_gain_db)) + self.gain_db(min(max_gain_db, target_db - self.rms_db)) + + def normalize_online_bayesian(self, + target_db, + prior_db, + prior_samples, + startup_delay=0.0): + """Normalize audio using a production-compatible online/causal + algorithm. This uses an exponential likelihood and gamma prior to + make online estimates of the RMS even when there are very few samples. + + Note that this is an in-place transformation. + + :param target_db: Target RMS value in decibels. + :type target_bd: float + :param prior_db: Prior RMS estimate in decibels. + :type prior_db: float + :param prior_samples: Prior strength in number of samples. + :type prior_samples: float + :param startup_delay: Default 0.0s. If provided, this function will + accrue statistics for the first startup_delay + seconds before applying online normalization. + :type startup_delay: float + """ + # Estimate total RMS online. + startup_sample_idx = min(self.num_samples - 1, + int(self.sample_rate * startup_delay)) + prior_mean_squared = 10.**(prior_db / 10.) + prior_sum_of_squares = prior_mean_squared * prior_samples + cumsum_of_squares = np.cumsum(self.samples**2) + sample_count = np.arange(self.num_samples) + 1 + if startup_sample_idx > 0: + cumsum_of_squares[:startup_sample_idx] = \ + cumsum_of_squares[startup_sample_idx] + sample_count[:startup_sample_idx] = \ + sample_count[startup_sample_idx] + mean_squared_estimate = ((cumsum_of_squares + prior_sum_of_squares) / + (sample_count + prior_samples)) + rms_estimate_db = 10 * np.log10(mean_squared_estimate) + # Compute required time-varying gain. + gain_db = target_db - rms_estimate_db + self.gain_db(gain_db) + + def resample(self, target_sample_rate, filter='kaiser_best'): + """Resample the audio to a target sample rate. + + Note that this is an in-place transformation. + + :param target_sample_rate: Target sample rate. + :type target_sample_rate: int + :param filter: The resampling filter to use one of {'kaiser_best', + 'kaiser_fast'}. + :type filter: str + """ + self._samples = resampy.resample( + self.samples, self.sample_rate, target_sample_rate, filter=filter) + self._sample_rate = target_sample_rate + + def pad_silence(self, duration, sides='both'): + """Pad this audio sample with a period of silence. + + Note that this is an in-place transformation. + + :param duration: Length of silence in seconds to pad. + :type duration: float + :param sides: Position for padding: + 'beginning' - adds silence in the beginning; + 'end' - adds silence in the end; + 'both' - adds silence in both the beginning and the end. + :type sides: str + :raises ValueError: If sides is not supported. + """ + if duration == 0.0: + return self + cls = type(self) + silence = self.make_silence(duration, self._sample_rate) + if sides == "beginning": + padded = cls.concatenate(silence, self) + elif sides == "end": + padded = cls.concatenate(self, silence) + elif sides == "both": + padded = cls.concatenate(silence, self, silence) + else: + raise ValueError("Unknown value for the sides %s" % sides) + self._samples = padded._samples + + def shift(self, shift_ms): + """Shift the audio in time. If `shift_ms` is positive, shift with time + advance; if negative, shift with time delay. Silence are padded to + keep the duration unchanged. + + Note that this is an in-place transformation. + + :param shift_ms: Shift time in millseconds. If positive, shift with + time advance; if negative; shift with time delay. + :type shift_ms: float + :raises ValueError: If shift_ms is longer than audio duration. + """ + if abs(shift_ms) / 1000.0 > self.duration: + raise ValueError("Absolute value of shift_ms should be smaller " + "than audio duration.") + shift_samples = int(shift_ms * self._sample_rate / 1000) + if shift_samples > 0: + # time advance + self._samples[:-shift_samples] = self._samples[shift_samples:] + self._samples[-shift_samples:] = 0 + elif shift_samples < 0: + # time delay + self._samples[-shift_samples:] = self._samples[:shift_samples] + self._samples[:-shift_samples] = 0 + + def subsegment(self, start_sec=None, end_sec=None): + """Cut the AudioSegment between given boundaries. + + Note that this is an in-place transformation. + + :param start_sec: Beginning of subsegment in seconds. + :type start_sec: float + :param end_sec: End of subsegment in seconds. + :type end_sec: float + :raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out + of bounds in time. + """ + start_sec = 0.0 if start_sec is None else start_sec + end_sec = self.duration if end_sec is None else end_sec + if start_sec < 0.0: + start_sec = self.duration + start_sec + if end_sec < 0.0: + end_sec = self.duration + end_sec + if start_sec < 0.0: + raise ValueError("The slice start position (%f s) is out of " + "bounds." % start_sec) + if end_sec < 0.0: + raise ValueError("The slice end position (%f s) is out of bounds." % + end_sec) + if start_sec > end_sec: + raise ValueError("The slice start position (%f s) is later than " + "the end position (%f s)." % (start_sec, end_sec)) + if end_sec > self.duration: + raise ValueError("The slice end position (%f s) is out of bounds " + "(> %f s)" % (end_sec, self.duration)) + start_sample = int(round(start_sec * self._sample_rate)) + end_sample = int(round(end_sec * self._sample_rate)) + self._samples = self._samples[start_sample:end_sample] + + def random_subsegment(self, subsegment_length, rng=None): + """Cut the specified length of the audiosegment randomly. + + Note that this is an in-place transformation. + + :param subsegment_length: Subsegment length in seconds. + :type subsegment_length: float + :param rng: Random number generator state. + :type rng: random.Random + :raises ValueError: If the length of subsegment is greater than + the origineal segemnt. + """ + rng = random.Random() if rng is None else rng + if subsegment_length > self.duration: + raise ValueError("Length of subsegment must not be greater " + "than original segment.") + start_time = rng.uniform(0.0, self.duration - subsegment_length) + self.subsegment(start_time, start_time + subsegment_length) + + def convolve(self, impulse_segment, allow_resample=False): + """Convolve this audio segment with the given impulse segment. + + Note that this is an in-place transformation. + + :param impulse_segment: Impulse response segments. + :type impulse_segment: AudioSegment + :param allow_resample: Indicates whether resampling is allowed when + the impulse_segment has a different sample + rate from this signal. + :type allow_resample: bool + :raises ValueError: If the sample rate is not match between two + audio segments when resample is not allowed. + """ + if allow_resample and self.sample_rate != impulse_segment.sample_rate: + impulse_segment = impulse_segment.resample(self.sample_rate) + if self.sample_rate != impulse_segment.sample_rate: + raise ValueError("Impulse segment's sample rate (%d Hz) is not" + "equal to base signal sample rate (%d Hz)." % + (impulse_segment.sample_rate, self.sample_rate)) + samples = signal.fftconvolve(self.samples, impulse_segment.samples, + "full") + self._samples = samples + + def convolve_and_normalize(self, impulse_segment, allow_resample=False): + """Convolve and normalize the resulting audio segment so that it + has the same average power as the input signal. + + Note that this is an in-place transformation. + + :param impulse_segment: Impulse response segments. + :type impulse_segment: AudioSegment + :param allow_resample: Indicates whether resampling is allowed when + the impulse_segment has a different sample + rate from this signal. + :type allow_resample: bool + """ + target_db = self.rms_db + self.convolve(impulse_segment, allow_resample=allow_resample) + self.normalize(target_db) + + def add_noise(self, + noise, + snr_dB, + allow_downsampling=False, + max_gain_db=300.0, + rng=None): + """Add the given noise segment at a specific signal-to-noise ratio. + If the noise segment is longer than this segment, a random subsegment + of matching length is sampled from it and used instead. + + Note that this is an in-place transformation. + + :param noise: Noise signal to add. + :type noise: AudioSegment + :param snr_dB: Signal-to-Noise Ratio, in decibels. + :type snr_dB: float + :param allow_downsampling: Whether to allow the noise signal to be + downsampled to match the base signal sample + rate. + :type allow_downsampling: bool + :param max_gain_db: Maximum amount of gain to apply to noise signal + before adding it in. This is to prevent attempting + to apply infinite gain to a zero signal. + :type max_gain_db: float + :param rng: Random number generator state. + :type rng: None|random.Random + :raises ValueError: If the sample rate does not match between the two + audio segments when downsampling is not allowed, or + if the duration of noise segments is shorter than + original audio segments. + """ + rng = random.Random() if rng is None else rng + if allow_downsampling and noise.sample_rate > self.sample_rate: + noise = noise.resample(self.sample_rate) + if noise.sample_rate != self.sample_rate: + raise ValueError("Noise sample rate (%d Hz) is not equal to base " + "signal sample rate (%d Hz)." % (noise.sample_rate, + self.sample_rate)) + if noise.duration < self.duration: + raise ValueError("Noise signal (%f sec) must be at least as long as" + " base signal (%f sec)." % + (noise.duration, self.duration)) + noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db) + noise_new = copy.deepcopy(noise) + noise_new.random_subsegment(self.duration, rng=rng) + noise_new.gain_db(noise_gain_db) + self.superimpose(noise_new) + + @property + def samples(self): + """Return audio samples. + + :return: Audio samples. + :rtype: ndarray + """ + return self._samples.copy() + + @property + def sample_rate(self): + """Return audio sample rate. + + :return: Audio sample rate. + :rtype: int + """ + return self._sample_rate + + @property + def num_samples(self): + """Return number of samples. + + :return: Number of samples. + :rtype: int + """ + return self._samples.shape[0] + + @property + def duration(self): + """Return audio duration. + + :return: Audio duration in seconds. + :rtype: float + """ + return self._samples.shape[0] / float(self._sample_rate) + + @property + def rms_db(self): + """Return root mean square energy of the audio in decibels. + + :return: Root mean square energy in decibels. + :rtype: float + """ + # square root => multiply by 10 instead of 20 for dBs + mean_square = np.mean(self._samples**2) + return 10 * np.log10(mean_square) + + def _convert_samples_to_float32(self, samples): + """Convert sample type to float32. + + Audio sample type is usually integer or float-point. + Integers will be scaled to [-1, 1] in float32. + """ + float32_samples = samples.astype('float32') + if samples.dtype in np.sctypes['int']: + bits = np.iinfo(samples.dtype).bits + float32_samples *= (1. / 2**(bits - 1)) + elif samples.dtype in np.sctypes['float']: + pass + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return float32_samples + + def _convert_samples_from_float32(self, samples, dtype): + """Convert sample type from float32 to dtype. + + Audio sample type is usually integer or float-point. For integer + type, float32 will be rescaled from [-1, 1] to the maximum range + supported by the integer type. + + This is for writing a audio file. + """ + dtype = np.dtype(dtype) + output_samples = samples.copy() + if dtype in np.sctypes['int']: + bits = np.iinfo(dtype).bits + output_samples *= (2**(bits - 1) / 1.) + min_val = np.iinfo(dtype).min + max_val = np.iinfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + elif samples.dtype in np.sctypes['float']: + min_val = np.finfo(dtype).min + max_val = np.finfo(dtype).max + output_samples[output_samples > max_val] = max_val + output_samples[output_samples < min_val] = min_val + else: + raise TypeError("Unsupported sample type: %s." % samples.dtype) + return output_samples.astype(dtype) diff --git a/data_utils/augmentor/__init__.py b/data_utils/augmentor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py new file mode 100644 index 00000000..9dced473 --- /dev/null +++ b/data_utils/augmentor/augmentation.py @@ -0,0 +1,93 @@ +"""Contains the data augmentation pipeline.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import random +from data_utils.augmentor.volume_perturb import VolumePerturbAugmentor +from data_utils.augmentor.shift_perturb import ShiftPerturbAugmentor +from data_utils.augmentor.speed_perturb import SpeedPerturbAugmentor +from data_utils.augmentor.resample import ResampleAugmentor +from data_utils.augmentor.online_bayesian_normalization import \ + OnlineBayesianNormalizationAugmentor + + +class AugmentationPipeline(object): + """Build a pre-processing pipeline with various augmentation models.Such a + data augmentation pipeline is oftern leveraged to augment the training + samples to make the model invariant to certain types of perturbations in the + real world, improving model's generalization ability. + + The pipeline is built according the the augmentation configuration in json + string, e.g. + + .. code-block:: + + '[{"type": "volume", + "params": {"min_gain_dBFS": -15, + "max_gain_dBFS": 15}, + "prob": 0.5}, + {"type": "speed", + "params": {"min_speed_rate": 0.8, + "max_speed_rate": 1.2}, + "prob": 0.5} + ]' + + This augmentation configuration inserts two augmentation models + into the pipeline, with one is VolumePerturbAugmentor and the other + SpeedPerturbAugmentor. "prob" indicates the probability of the current + augmentor to take effect. + + :param augmentation_config: Augmentation configuration in json string. + :type augmentation_config: str + :param random_seed: Random seed. + :type random_seed: int + :raises ValueError: If the augmentation json config is in incorrect format". + """ + + def __init__(self, augmentation_config, random_seed=0): + self._rng = random.Random(random_seed) + self._augmentors, self._rates = self._parse_pipeline_from( + augmentation_config) + + def transform_audio(self, audio_segment): + """Run the pre-processing pipeline for data augmentation. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to process. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + for augmentor, rate in zip(self._augmentors, self._rates): + if self._rng.uniform(0., 1.) <= rate: + augmentor.transform_audio(audio_segment) + + def _parse_pipeline_from(self, config_json): + """Parse the config json to build a augmentation pipelien.""" + try: + configs = json.loads(config_json) + augmentors = [ + self._get_augmentor(config["type"], config["params"]) + for config in configs + ] + rates = [config["prob"] for config in configs] + except Exception as e: + raise ValueError("Failed to parse the augmentation config json: " + "%s" % str(e)) + return augmentors, rates + + def _get_augmentor(self, augmentor_type, params): + """Return an augmentation model by the type name, and pass in params.""" + if augmentor_type == "volume": + return VolumePerturbAugmentor(self._rng, **params) + elif augmentor_type == "shift": + return ShiftPerturbAugmentor(self._rng, **params) + elif augmentor_type == "speed": + return SpeedPerturbAugmentor(self._rng, **params) + elif augmentor_type == "resample": + return ResampleAugmentor(self._rng, **params) + elif augmentor_type == "bayesian_normal": + return OnlineBayesianNormalizationAugmentor(self._rng, **params) + else: + raise ValueError("Unknown augmentor type [%s]." % augmentor_type) diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py new file mode 100644 index 00000000..a323165a --- /dev/null +++ b/data_utils/augmentor/base.py @@ -0,0 +1,33 @@ +"""Contains the abstract base class for augmentation models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from abc import ABCMeta, abstractmethod + + +class AugmentorBase(object): + """Abstract base class for augmentation model (augmentor) class. + All augmentor classes should inherit from this class, and implement the + following abstract methods. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def transform_audio(self, audio_segment): + """Adds various effects to the input audio segment. Such effects + will augment the training data to make the model invariant to certain + types of perturbations in the real world, improving model's + generalization ability. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + pass diff --git a/data_utils/augmentor/online_bayesian_normalization.py b/data_utils/augmentor/online_bayesian_normalization.py new file mode 100755 index 00000000..e488ac7d --- /dev/null +++ b/data_utils/augmentor/online_bayesian_normalization.py @@ -0,0 +1,48 @@ +"""Contain the online bayesian normalization augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class OnlineBayesianNormalizationAugmentor(AugmentorBase): + """Augmentation model for adding online bayesian normalization. + + :param rng: Random generator object. + :type rng: random.Random + :param target_db: Target RMS value in decibels. + :type target_db: float + :param prior_db: Prior RMS estimate in decibels. + :type prior_db: float + :param prior_samples: Prior strength in number of samples. + :type prior_samples: int + :param startup_delay: Default 0.0s. If provided, this function will + accrue statistics for the first startup_delay + seconds before applying online normalization. + :type starup_delay: float. + """ + + def __init__(self, + rng, + target_db, + prior_db, + prior_samples, + startup_delay=0.0): + self._target_db = target_db + self._prior_db = prior_db + self._prior_samples = prior_samples + self._rng = rng + self._startup_delay = startup_delay + + def transform_audio(self, audio_segment): + """Normalizes the input audio using the online Bayesian approach. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegment|SpeechSegment + """ + audio_segment.normalize_online_bayesian(self._target_db, self._prior_db, + self._prior_samples, + self._startup_delay) diff --git a/data_utils/augmentor/resample.py b/data_utils/augmentor/resample.py new file mode 100755 index 00000000..8df17f3a --- /dev/null +++ b/data_utils/augmentor/resample.py @@ -0,0 +1,33 @@ +"""Contain the resample augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class ResampleAugmentor(AugmentorBase): + """Augmentation model for resampling. + + See more info here: + https://ccrma.stanford.edu/~jos/resample/index.html + + :param rng: Random generator object. + :type rng: random.Random + :param new_sample_rate: New sample rate in Hz. + :type new_sample_rate: int + """ + + def __init__(self, rng, new_sample_rate): + self._new_sample_rate = new_sample_rate + self._rng = rng + + def transform_audio(self, audio_segment): + """Resamples the input audio to a target sample rate. + + Note that this is an in-place transformation. + + :param audio: Audio segment to add effects to. + :type audio: AudioSegment|SpeechSegment + """ + audio_segment.resample(self._new_sample_rate) diff --git a/data_utils/augmentor/shift_perturb.py b/data_utils/augmentor/shift_perturb.py new file mode 100644 index 00000000..c4cbe3e1 --- /dev/null +++ b/data_utils/augmentor/shift_perturb.py @@ -0,0 +1,34 @@ +"""Contains the volume perturb augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class ShiftPerturbAugmentor(AugmentorBase): + """Augmentation model for adding random shift perturbation. + + :param rng: Random generator object. + :type rng: random.Random + :param min_shift_ms: Minimal shift in milliseconds. + :type min_shift_ms: float + :param max_shift_ms: Maximal shift in milliseconds. + :type max_shift_ms: float + """ + + def __init__(self, rng, min_shift_ms, max_shift_ms): + self._min_shift_ms = min_shift_ms + self._max_shift_ms = max_shift_ms + self._rng = rng + + def transform_audio(self, audio_segment): + """Shift audio. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + shift_ms = self._rng.uniform(self._min_shift_ms, self._max_shift_ms) + audio_segment.shift(shift_ms) diff --git a/data_utils/augmentor/speed_perturb.py b/data_utils/augmentor/speed_perturb.py new file mode 100644 index 00000000..cc5738bd --- /dev/null +++ b/data_utils/augmentor/speed_perturb.py @@ -0,0 +1,47 @@ +"""Contain the speech perturbation augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class SpeedPerturbAugmentor(AugmentorBase): + """Augmentation model for adding speed perturbation. + + See reference paper here: + http://www.danielpovey.com/files/2015_interspeech_augmentation.pdf + + :param rng: Random generator object. + :type rng: random.Random + :param min_speed_rate: Lower bound of new speed rate to sample and should + not be smaller than 0.9. + :type min_speed_rate: float + :param max_speed_rate: Upper bound of new speed rate to sample and should + not be larger than 1.1. + :type max_speed_rate: float + """ + + def __init__(self, rng, min_speed_rate, max_speed_rate): + if min_speed_rate < 0.9: + raise ValueError( + "Sampling speed below 0.9 can cause unnatural effects") + if max_speed_rate > 1.1: + raise ValueError( + "Sampling speed above 1.1 can cause unnatural effects") + self._min_speed_rate = min_speed_rate + self._max_speed_rate = max_speed_rate + self._rng = rng + + def transform_audio(self, audio_segment): + """Sample a new speed rate from the given range and + changes the speed of the given audio clip. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegment|SpeechSegment + """ + sampled_speed = self._rng.uniform(self._min_speed_rate, + self._max_speed_rate) + audio_segment.change_speed(sampled_speed) diff --git a/data_utils/augmentor/volume_perturb.py b/data_utils/augmentor/volume_perturb.py new file mode 100644 index 00000000..758676d5 --- /dev/null +++ b/data_utils/augmentor/volume_perturb.py @@ -0,0 +1,40 @@ +"""Contains the volume perturb augmentation model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.augmentor.base import AugmentorBase + + +class VolumePerturbAugmentor(AugmentorBase): + """Augmentation model for adding random volume perturbation. + + This is used for multi-loudness training of PCEN. See + + https://arxiv.org/pdf/1607.05666v1.pdf + + for more details. + + :param rng: Random generator object. + :type rng: random.Random + :param min_gain_dBFS: Minimal gain in dBFS. + :type min_gain_dBFS: float + :param max_gain_dBFS: Maximal gain in dBFS. + :type max_gain_dBFS: float + """ + + def __init__(self, rng, min_gain_dBFS, max_gain_dBFS): + self._min_gain_dBFS = min_gain_dBFS + self._max_gain_dBFS = max_gain_dBFS + self._rng = rng + + def transform_audio(self, audio_segment): + """Change audio loadness. + + Note that this is an in-place transformation. + + :param audio_segment: Audio segment to add effects to. + :type audio_segment: AudioSegmenet|SpeechSegment + """ + gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS) + audio_segment.gain_db(gain) diff --git a/data_utils/data.py b/data_utils/data.py new file mode 100644 index 00000000..d01ca8cc --- /dev/null +++ b/data_utils/data.py @@ -0,0 +1,287 @@ +"""Contains data generator for orgnaizing various audio data preprocessing +pipeline and offering data reader interface of PaddlePaddle requirements. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import numpy as np +import multiprocessing +import paddle.v2 as paddle +from data_utils import utils +from data_utils.augmentor.augmentation import AugmentationPipeline +from data_utils.featurizer.speech_featurizer import SpeechFeaturizer +from data_utils.speech import SpeechSegment +from data_utils.normalizer import FeatureNormalizer + + +class DataGenerator(object): + """ + DataGenerator provides basic audio data preprocessing pipeline, and offers + data reader interfaces of PaddlePaddle requirements. + + :param vocab_filepath: Vocabulary filepath for indexing tokenized + transcripts. + :type vocab_filepath: basestring + :param mean_std_filepath: File containing the pre-computed mean and stddev. + :type mean_std_filepath: None|basestring + :param augmentation_config: Augmentation configuration in json string. + Details see AugmentationPipeline.__doc__. + :type augmentation_config: str + :param max_duration: Audio with duration (in seconds) greater than + this will be discarded. + :type max_duration: float + :param min_duration: Audio with duration (in seconds) smaller than + this will be discarded. + :type min_duration: float + :param stride_ms: Striding size (in milliseconds) for generating frames. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for generating frames. + :type window_ms: float + :param max_freq: Used when specgram_type is 'linear', only FFT bins + corresponding to frequencies between [0, max_freq] are + returned. + :types max_freq: None|float + :param specgram_type: Specgram feature type. Options: 'linear'. + :type specgram_type: str + :param use_dB_normalization: Whether to normalize the audio to -20 dB + before extracting the features. + :type use_dB_normalization: bool + :param num_threads: Number of CPU threads for processing data. + :type num_threads: int + :param random_seed: Random seed. + :type random_seed: int + """ + + def __init__(self, + vocab_filepath, + mean_std_filepath, + augmentation_config='{}', + max_duration=float('inf'), + min_duration=0.0, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + specgram_type='linear', + use_dB_normalization=True, + num_threads=multiprocessing.cpu_count(), + random_seed=0): + self._max_duration = max_duration + self._min_duration = min_duration + self._normalizer = FeatureNormalizer(mean_std_filepath) + self._augmentation_pipeline = AugmentationPipeline( + augmentation_config=augmentation_config, random_seed=random_seed) + self._speech_featurizer = SpeechFeaturizer( + vocab_filepath=vocab_filepath, + specgram_type=specgram_type, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + use_dB_normalization=use_dB_normalization) + self._num_threads = num_threads + self._rng = random.Random(random_seed) + self._epoch = 0 + + def batch_reader_creator(self, + manifest_path, + batch_size, + min_batch_size=1, + padding_to=-1, + flatten=False, + sortagrad=False, + shuffle_method="batch_shuffle"): + """ + Batch data reader creator for audio data. Return a callable generator + function to produce batches of data. + + Audio features within one batch will be padded with zeros to have the + same shape, or a user-defined shape. + + :param manifest_path: Filepath of manifest for audio files. + :type manifest_path: basestring + :param batch_size: Number of instances in a batch. + :type batch_size: int + :param min_batch_size: Any batch with batch size smaller than this will + be discarded. (To be deprecated in the future.) + :type min_batch_size: int + :param padding_to: If set -1, the maximun shape in the batch + will be used as the target shape for padding. + Otherwise, `padding_to` will be the target shape. + :type padding_to: int + :param flatten: If set True, audio features will be flatten to 1darray. + :type flatten: bool + :param sortagrad: If set True, sort the instances by audio duration + in the first epoch for speed up training. + :type sortagrad: bool + :param shuffle_method: Shuffle method. Options: + '' or None: no shuffle. + 'instance_shuffle': instance-wise shuffle. + 'batch_shuffle': similarly-sized instances are + put into batches, and then + batch-wise shuffle the batches. + For more details, please see + ``_batch_shuffle.__doc__``. + 'batch_shuffle_clipped': 'batch_shuffle' with + head shift and tail + clipping. For more + details, please see + ``_batch_shuffle``. + If sortagrad is True, shuffle is disabled + for the first epoch. + :type shuffle_method: None|str + :return: Batch reader function, producing batches of data when called. + :rtype: callable + """ + + def batch_reader(): + # read manifest + manifest = utils.read_manifest( + manifest_path=manifest_path, + max_duration=self._max_duration, + min_duration=self._min_duration) + # sort (by duration) or batch-wise shuffle the manifest + if self._epoch == 0 and sortagrad: + manifest.sort(key=lambda x: x["duration"]) + else: + if shuffle_method == "batch_shuffle": + manifest = self._batch_shuffle( + manifest, batch_size, clipped=False) + elif shuffle_method == "batch_shuffle_clipped": + manifest = self._batch_shuffle( + manifest, batch_size, clipped=True) + elif shuffle_method == "instance_shuffle": + self._rng.shuffle(manifest) + elif not shuffle_method: + pass + else: + raise ValueError("Unknown shuffle method %s." % + shuffle_method) + # prepare batches + instance_reader = self._instance_reader_creator(manifest) + batch = [] + for instance in instance_reader(): + batch.append(instance) + if len(batch) == batch_size: + yield self._padding_batch(batch, padding_to, flatten) + batch = [] + if len(batch) >= min_batch_size: + yield self._padding_batch(batch, padding_to, flatten) + self._epoch += 1 + + return batch_reader + + @property + def feeding(self): + """Returns data reader's feeding dict. + + :return: Data feeding dict. + :rtype: dict + """ + return {"audio_spectrogram": 0, "transcript_text": 1} + + @property + def vocab_size(self): + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ + return self._speech_featurizer.vocab_size + + @property + def vocab_list(self): + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ + return self._speech_featurizer.vocab_list + + def _process_utterance(self, filename, transcript): + """Load, augment, featurize and normalize for speech data.""" + speech_segment = SpeechSegment.from_file(filename, transcript) + self._augmentation_pipeline.transform_audio(speech_segment) + specgram, text_ids = self._speech_featurizer.featurize(speech_segment) + specgram = self._normalizer.apply(specgram) + return specgram, text_ids + + def _instance_reader_creator(self, manifest): + """ + Instance reader creator. Create a callable function to produce + instances of data. + + Instance: a tuple of ndarray of audio spectrogram and a list of + token indices for transcript. + """ + + def reader(): + for instance in manifest: + yield instance + + def mapper(instance): + return self._process_utterance(instance["audio_filepath"], + instance["text"]) + + return paddle.reader.xmap_readers( + mapper, reader, self._num_threads, 1024, order=True) + + def _padding_batch(self, batch, padding_to=-1, flatten=False): + """ + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one bach. + + If ``padding_to`` is -1, the maximun shape in the batch will be used + as the target shape for padding. Otherwise, `padding_to` will be the + target shape (only refers to the second axis). + + If `flatten` is True, features will be flatten to 1darray. + """ + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, text in batch]) + if padding_to != -1: + if padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be larger " + "than any instance's shape in the batch") + max_length = padding_to + # padding + for audio, text in batch: + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + if flatten: + padded_audio = padded_audio.flatten() + new_batch.append((padded_audio, text)) + return new_batch + + def _batch_shuffle(self, manifest, batch_size, clipped=False): + """Put similarly-sized instances into minibatches for better efficiency + and make a batch-wise shuffle. + + 1. Sort the audio clips by duration. + 2. Generate a random number `k`, k in [0, batch_size). + 3. Randomly shift `k` instances in order to create different batches + for different epochs. Create minibatches. + 4. Shuffle the minibatches. + + :param manifest: Manifest contents. List of dict. + :type manifest: list + :param batch_size: Batch size. This size is also used for generate + a random number for batch shuffle. + :type batch_size: int + :param clipped: Whether to clip the heading (small shift) and trailing + (incomplete batch) instances. + :type clipped: bool + :return: Batch shuffled mainifest. + :rtype: list + """ + manifest.sort(key=lambda x: x["duration"]) + shift_len = self._rng.randint(0, batch_size - 1) + batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size) + self._rng.shuffle(batch_manifest) + batch_manifest = list(sum(batch_manifest, ())) + if not clipped: + res_len = len(manifest) - shift_len - len(batch_manifest) + batch_manifest.extend(manifest[-res_len:]) + batch_manifest.extend(manifest[0:shift_len]) + return batch_manifest diff --git a/data_utils/featurizer/__init__.py b/data_utils/featurizer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py new file mode 100644 index 00000000..4b4d02c6 --- /dev/null +++ b/data_utils/featurizer/audio_featurizer.py @@ -0,0 +1,144 @@ +"""Contains the audio featurizer class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from data_utils import utils +from data_utils.audio import AudioSegment + + +class AudioFeaturizer(object): + """Audio featurizer, for extracting features from audio contents of + AudioSegment or SpeechSegment. + + Currently, it only supports feature type of linear spectrogram. + + :param specgram_type: Specgram feature type. Options: 'linear'. + :type specgram_type: str + :param stride_ms: Striding size (in milliseconds) for generating frames. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for generating frames. + :type window_ms: float + :param max_freq: Used when specgram_type is 'linear', only FFT bins + corresponding to frequencies between [0, max_freq] are + returned. + :types max_freq: None|float + :param target_sample_rate: Audio are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float + """ + + def __init__(self, + specgram_type='linear', + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + self._specgram_type = specgram_type + self._stride_ms = stride_ms + self._window_ms = window_ms + self._max_freq = max_freq + self._target_sample_rate = target_sample_rate + self._use_dB_normalization = use_dB_normalization + self._target_dB = target_dB + + def featurize(self, + audio_segment, + allow_downsampling=True, + allow_upsamplling=True): + """Extract audio features from AudioSegment or SpeechSegment. + + :param audio_segment: Audio/speech segment to extract features from. + :type audio_segment: AudioSegment|SpeechSegment + :param allow_downsampling: Whether to allow audio downsampling before + featurizing. + :type allow_downsampling: bool + :param allow_upsampling: Whether to allow audio upsampling before + featurizing. + :type allow_upsampling: bool + :return: Spectrogram audio feature in 2darray. + :rtype: ndarray + :raises ValueError: If audio sample rate is not supported. + """ + # upsampling or downsampling + if ((audio_segment.sample_rate > self._target_sample_rate and + allow_downsampling) or + (audio_segment.sample_rate < self._target_sample_rate and + allow_upsampling)): + audio_segment.resample(self._target_sample_rate) + if audio_segment.sample_rate != self._target_sample_rate: + raise ValueError("Audio sample rate is not supported. " + "Turn allow_downsampling or allow up_sampling on.") + # decibel normalization + if self._use_dB_normalization: + audio_segment.normalize(target_db=self._target_dB) + # extract spectrogram + return self._compute_specgram(audio_segment.samples, + audio_segment.sample_rate) + + def _compute_specgram(self, samples, sample_rate): + """Extract various audio features.""" + if self._specgram_type == 'linear': + return self._compute_linear_specgram( + samples, sample_rate, self._stride_ms, self._window_ms, + self._max_freq) + else: + raise ValueError("Unknown specgram_type %s. " + "Supported values: linear." % self._specgram_type) + + def _compute_linear_specgram(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + eps=1e-14): + """Compute the linear spectrogram from FFT energy.""" + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + stride_size = int(0.001 * sample_rate * stride_ms) + window_size = int(0.001 * sample_rate * window_ms) + specgram, freqs = self._specgram_real( + samples, + window_size=window_size, + stride_size=stride_size, + sample_rate=sample_rate) + ind = np.where(freqs <= max_freq)[0][-1] + 1 + return np.log(specgram[:ind, :] + eps) + + def _specgram_real(self, samples, window_size, stride_size, sample_rate): + """Compute the spectrogram for samples from a real signal.""" + # extract strided windows + truncate_size = (len(samples) - window_size) % stride_size + samples = samples[:len(samples) - truncate_size] + nshape = (window_size, (len(samples) - window_size) // stride_size + 1) + nstrides = (samples.strides[0], samples.strides[0] * stride_size) + windows = np.lib.stride_tricks.as_strided( + samples, shape=nshape, strides=nstrides) + assert np.all( + windows[:, 1] == samples[stride_size:(stride_size + window_size)]) + # window weighting, squared Fast Fourier Transform (fft), scaling + weighting = np.hanning(window_size)[:, None] + fft = np.fft.rfft(windows * weighting, axis=0) + fft = np.absolute(fft)**2 + scale = np.sum(weighting**2) * sample_rate + fft[1:-1, :] *= (2.0 / scale) + fft[(0, -1), :] /= scale + # prepare fft frequency list + freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) + return fft, freqs diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py new file mode 100644 index 00000000..26283892 --- /dev/null +++ b/data_utils/featurizer/speech_featurizer.py @@ -0,0 +1,95 @@ +"""Contains the speech featurizer class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.featurizer.audio_featurizer import AudioFeaturizer +from data_utils.featurizer.text_featurizer import TextFeaturizer + + +class SpeechFeaturizer(object): + """Speech featurizer, for extracting features from both audio and transcript + contents of SpeechSegment. + + Currently, for audio parts, it only supports feature type of linear + spectrogram; for transcript parts, it only supports char-level tokenizing + and conversion into a list of token indices. Note that the token indexing + order follows the given vocabulary file. + + :param vocab_filepath: Filepath to load vocabulary for token indices + conversion. + :type specgram_type: basestring + :param specgram_type: Specgram feature type. Options: 'linear'. + :type specgram_type: str + :param stride_ms: Striding size (in milliseconds) for generating frames. + :type stride_ms: float + :param window_ms: Window size (in milliseconds) for generating frames. + :type window_ms: float + :param max_freq: Used when specgram_type is 'linear', only FFT bins + corresponding to frequencies between [0, max_freq] are + returned. + :types max_freq: None|float + :param target_sample_rate: Speech are resampled (if upsampling or + downsampling is allowed) to this before + extracting spectrogram features. + :type target_sample_rate: float + :param use_dB_normalization: Whether to normalize the audio to a certain + decibels before extracting the features. + :type use_dB_normalization: bool + :param target_dB: Target audio decibels for normalization. + :type target_dB: float + """ + + def __init__(self, + vocab_filepath, + specgram_type='linear', + stride_ms=10.0, + window_ms=20.0, + max_freq=None, + target_sample_rate=16000, + use_dB_normalization=True, + target_dB=-20): + self._audio_featurizer = AudioFeaturizer( + specgram_type=specgram_type, + stride_ms=stride_ms, + window_ms=window_ms, + max_freq=max_freq, + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB) + self._text_featurizer = TextFeaturizer(vocab_filepath) + + def featurize(self, speech_segment): + """Extract features for speech segment. + + 1. For audio parts, extract the audio features. + 2. For transcript parts, convert text string to a list of token indices + in char-level. + + :param audio_segment: Speech segment to extract features from. + :type audio_segment: SpeechSegment + :return: A tuple of 1) spectrogram audio feature in 2darray, 2) list of + char-level token indices. + :rtype: tuple + """ + audio_feature = self._audio_featurizer.featurize(speech_segment) + text_ids = self._text_featurizer.featurize(speech_segment.transcript) + return audio_feature, text_ids + + @property + def vocab_size(self): + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ + return self._text_featurizer.vocab_size + + @property + def vocab_list(self): + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ + return self._text_featurizer.vocab_list diff --git a/data_utils/featurizer/text_featurizer.py b/data_utils/featurizer/text_featurizer.py new file mode 100644 index 00000000..4f9a49b5 --- /dev/null +++ b/data_utils/featurizer/text_featurizer.py @@ -0,0 +1,67 @@ +"""Contains the text featurizer class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + + +class TextFeaturizer(object): + """Text featurizer, for processing or extracting features from text. + + Currently, it only supports char-level tokenizing and conversion into + a list of token indices. Note that the token indexing order follows the + given vocabulary file. + + :param vocab_filepath: Filepath to load vocabulary for token indices + conversion. + :type specgram_type: basestring + """ + + def __init__(self, vocab_filepath): + self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file( + vocab_filepath) + + def featurize(self, text): + """Convert text string to a list of token indices in char-level.Note + that the token indexing order follows the given vocabulary file. + + :param text: Text to process. + :type text: basestring + :return: List of char-level token indices. + :rtype: list + """ + tokens = self._char_tokenize(text) + return [self._vocab_dict[token] for token in tokens] + + @property + def vocab_size(self): + """Return the vocabulary size. + + :return: Vocabulary size. + :rtype: int + """ + return len(self._vocab_list) + + @property + def vocab_list(self): + """Return the vocabulary in list. + + :return: Vocabulary in list. + :rtype: list + """ + return self._vocab_list + + def _char_tokenize(self, text): + """Character tokenizer.""" + return list(text.strip()) + + def _load_vocabulary_from_file(self, vocab_filepath): + """Load vocabulary from file.""" + vocab_lines = [] + with open(vocab_filepath, 'r') as file: + vocab_lines.extend(file.readlines()) + vocab_list = [line[:-1] for line in vocab_lines] + vocab_dict = dict( + [(token, id) for (id, token) in enumerate(vocab_list)]) + return vocab_dict, vocab_list diff --git a/data_utils/normalizer.py b/data_utils/normalizer.py new file mode 100644 index 00000000..c123d25d --- /dev/null +++ b/data_utils/normalizer.py @@ -0,0 +1,87 @@ +"""Contains feature normalizers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import random +import data_utils.utils as utils +from data_utils.audio import AudioSegment + + +class FeatureNormalizer(object): + """Feature normalizer. Normalize features to be of zero mean and unit + stddev. + + if mean_std_filepath is provided (not None), the normalizer will directly + initilize from the file. Otherwise, both manifest_path and featurize_func + should be given for on-the-fly mean and stddev computing. + + :param mean_std_filepath: File containing the pre-computed mean and stddev. + :type mean_std_filepath: None|basestring + :param manifest_path: Manifest of instances for computing mean and stddev. + :type meanifest_path: None|basestring + :param featurize_func: Function to extract features. It should be callable + with ``featurize_func(audio_segment)``. + :type featurize_func: None|callable + :param num_samples: Number of random samples for computing mean and stddev. + :type num_samples: int + :param random_seed: Random seed for sampling instances. + :type random_seed: int + :raises ValueError: If both mean_std_filepath and manifest_path + (or both mean_std_filepath and featurize_func) are None. + """ + + def __init__(self, + mean_std_filepath, + manifest_path=None, + featurize_func=None, + num_samples=500, + random_seed=0): + if not mean_std_filepath: + if not (manifest_path and featurize_func): + raise ValueError("If mean_std_filepath is None, meanifest_path " + "and featurize_func should not be None.") + self._rng = random.Random(random_seed) + self._compute_mean_std(manifest_path, featurize_func, num_samples) + else: + self._read_mean_std_from_file(mean_std_filepath) + + def apply(self, features, eps=1e-14): + """Normalize features to be of zero mean and unit stddev. + + :param features: Input features to be normalized. + :type features: ndarray + :param eps: added to stddev to provide numerical stablibity. + :type eps: float + :return: Normalized features. + :rtype: ndarray + """ + return (features - self._mean) / (self._std + eps) + + def write_to_file(self, filepath): + """Write the mean and stddev to the file. + + :param filepath: File to write mean and stddev. + :type filepath: basestring + """ + np.savez(filepath, mean=self._mean, std=self._std) + + def _read_mean_std_from_file(self, filepath): + """Load mean and std from file.""" + npzfile = np.load(filepath) + self._mean = npzfile["mean"] + self._std = npzfile["std"] + + def _compute_mean_std(self, manifest_path, featurize_func, num_samples): + """Compute mean and std from randomly sampled instances.""" + manifest = utils.read_manifest(manifest_path) + sampled_manifest = self._rng.sample(manifest, num_samples) + features = [] + for instance in sampled_manifest: + features.append( + featurize_func( + AudioSegment.from_file(instance["audio_filepath"]))) + features = np.hstack(features) + self._mean = np.mean(features, axis=1).reshape([-1, 1]) + self._std = np.std(features, axis=1).reshape([-1, 1]) diff --git a/data_utils/speech.py b/data_utils/speech.py new file mode 100644 index 00000000..568e4443 --- /dev/null +++ b/data_utils/speech.py @@ -0,0 +1,143 @@ +"""Contains the speech segment class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from data_utils.audio import AudioSegment + + +class SpeechSegment(AudioSegment): + """Speech segment abstraction, a subclass of AudioSegment, + with an additional transcript. + + :param samples: Audio samples [num_samples x num_channels]. + :type samples: ndarray.float32 + :param sample_rate: Audio sample rate. + :type sample_rate: int + :param transcript: Transcript text for the speech. + :type transript: basestring + :raises TypeError: If the sample data type is not float or int. + """ + + def __init__(self, samples, sample_rate, transcript): + AudioSegment.__init__(self, samples, sample_rate) + self._transcript = transcript + + def __eq__(self, other): + """Return whether two objects are equal. + """ + if not AudioSegment.__eq__(self, other): + return False + if self._transcript != other._transcript: + return False + return True + + def __ne__(self, other): + """Return whether two objects are unequal.""" + return not self.__eq__(other) + + @classmethod + def from_file(cls, filepath, transcript): + """Create speech segment from audio file and corresponding transcript. + + :param filepath: Filepath or file object to audio file. + :type filepath: basestring|file + :param transcript: Transcript text for the speech. + :type transript: basestring + :return: Audio segment instance. + :rtype: AudioSegment + """ + audio = AudioSegment.from_file(filepath) + return cls(audio.samples, audio.sample_rate, transcript) + + @classmethod + def from_bytes(cls, bytes, transcript): + """Create speech segment from a byte string and corresponding + transcript. + + :param bytes: Byte string containing audio samples. + :type bytes: str + :param transcript: Transcript text for the speech. + :type transript: basestring + :return: Audio segment instance. + :rtype: AudioSegment + """ + audio = AudioSegment.from_bytes(bytes) + return cls(audio.samples, audio.sample_rate, transcript) + + @classmethod + def concatenate(cls, *segments): + """Concatenate an arbitrary number of speech segments together, both + audio and transcript will be concatenated. + + :param *segments: Input speech segments to be concatenated. + :type *segments: tuple of SpeechSegment + :return: Speech segment instance. + :rtype: SpeechSegment + :raises ValueError: If the number of segments is zero, or if the + sample_rate of any two segments does not match. + :raises TypeError: If any segment is not SpeechSegment instance. + """ + if len(segments) == 0: + raise ValueError("No speech segments are given to concatenate.") + sample_rate = segments[0]._sample_rate + transcripts = "" + for seg in segments: + if sample_rate != seg._sample_rate: + raise ValueError("Can't concatenate segments with " + "different sample rates") + if type(seg) is not cls: + raise TypeError("Only speech segments of the same type " + "instance can be concatenated.") + transcripts += seg._transcript + samples = np.concatenate([seg.samples for seg in segments]) + return cls(samples, sample_rate, transcripts) + + @classmethod + def slice_from_file(cls, filepath, transcript, start=None, end=None): + """Loads a small section of an speech without having to load + the entire file into the memory which can be incredibly wasteful. + + :param filepath: Filepath or file object to audio file. + :type filepath: basestring|file + :param start: Start time in seconds. If start is negative, it wraps + around from the end. If not provided, this function + reads from the very beginning. + :type start: float + :param end: End time in seconds. If end is negative, it wraps around + from the end. If not provided, the default behvaior is + to read to the end of the file. + :type end: float + :param transcript: Transcript text for the speech. if not provided, + the defaults is an empty string. + :type transript: basestring + :return: SpeechSegment instance of the specified slice of the input + speech file. + :rtype: SpeechSegment + """ + audio = Audiosegment.slice_from_file(filepath, start, end) + return cls(audio.samples, audio.sample_rate, transcript) + + @classmethod + def make_silence(cls, duration, sample_rate): + """Creates a silent speech segment of the given duration and + sample rate, transcript will be an empty string. + + :param duration: Length of silence in seconds. + :type duration: float + :param sample_rate: Sample rate. + :type sample_rate: float + :return: Silence of the given duration. + :rtype: SpeechSegment + """ + audio = AudioSegment.make_silence(duration, sample_rate) + return cls(audio.samples, audio.sample_rate, "") + + @property + def transcript(self): + """Return the transcript text. + + :return: Transcript text for the speech. + :rtype: basestring + """ + return self._transcript diff --git a/data_utils/utils.py b/data_utils/utils.py new file mode 100644 index 00000000..3f116571 --- /dev/null +++ b/data_utils/utils.py @@ -0,0 +1,34 @@ +"""Contains data helper functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + + +def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0): + """Load and parse manifest file. + + Instances with durations outside [min_duration, max_duration] will be + filtered out. + + :param manifest_path: Manifest file to load and parse. + :type manifest_path: basestring + :param max_duration: Maximal duration in seconds for instance filter. + :type max_duration: float + :param min_duration: Minimal duration in seconds for instance filter. + :type min_duration: float + :return: Manifest parsing results. List of dict. + :rtype: list + :raises IOError: If failed to parse the manifest. + """ + manifest = [] + for json_line in open(manifest_path): + try: + json_data = json.loads(json_line) + except Exception as e: + raise IOError("Error reading manifest: %s" % str(e)) + if (json_data["duration"] <= max_duration and + json_data["duration"] >= min_duration): + manifest.append(json_data) + return manifest diff --git a/datasets/librispeech/librispeech.py b/datasets/librispeech/librispeech.py new file mode 100644 index 00000000..87e52ae4 --- /dev/null +++ b/datasets/librispeech/librispeech.py @@ -0,0 +1,176 @@ +"""Prepare Librispeech ASR datasets. + +Download, unpack and create manifest files. +Manifest file is a json-format file with each line containing the +meta data (i.e. audio filepath, transcript and audio duration) +of each audio file in the data set. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import distutils.util +import os +import wget +import tarfile +import argparse +import soundfile +import json +from paddle.v2.dataset.common import md5file + +DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') + +URL_ROOT = "http://www.openslr.org/resources/12" +URL_TEST_CLEAN = URL_ROOT + "/test-clean.tar.gz" +URL_TEST_OTHER = URL_ROOT + "/test-other.tar.gz" +URL_DEV_CLEAN = URL_ROOT + "/dev-clean.tar.gz" +URL_DEV_OTHER = URL_ROOT + "/dev-other.tar.gz" +URL_TRAIN_CLEAN_100 = URL_ROOT + "/train-clean-100.tar.gz" +URL_TRAIN_CLEAN_360 = URL_ROOT + "/train-clean-360.tar.gz" +URL_TRAIN_OTHER_500 = URL_ROOT + "/train-other-500.tar.gz" + +MD5_TEST_CLEAN = "32fa31d27d2e1cad72775fee3f4849a9" +MD5_TEST_OTHER = "fb5a50374b501bb3bac4815ee91d3135" +MD5_DEV_CLEAN = "42e2234ba48799c1f50f24a7926300a1" +MD5_DEV_OTHER = "c8d0bcc9cca99d4f8b62fcc847357931" +MD5_TRAIN_CLEAN_100 = "2a93770f6d5c6c964bc36631d331a522" +MD5_TRAIN_CLEAN_360 = "c0e676e450a7ff2f54aeade5171606fa" +MD5_TRAIN_OTHER_500 = "d1a0fd59409feb2c614ce4d30c387708" + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--target_dir", + default=DATA_HOME + "/Libri", + type=str, + help="Directory to save the dataset. (default: %(default)s)") +parser.add_argument( + "--manifest_prefix", + default="manifest", + type=str, + help="Filepath prefix for output manifests. (default: %(default)s)") +parser.add_argument( + "--full_download", + default="True", + type=distutils.util.strtobool, + help="Download all datasets for Librispeech." + " If False, only download a minimal requirement (test-clean, dev-clean" + " train-clean-100). (default: %(default)s)") +args = parser.parse_args() + + +def download(url, md5sum, target_dir): + """ + Download file from url to target_dir, and check md5sum. + """ + if not os.path.exists(target_dir): os.makedirs(target_dir) + filepath = os.path.join(target_dir, url.split("/")[-1]) + if not (os.path.exists(filepath) and md5file(filepath) == md5sum): + print("Downloading %s ..." % url) + wget.download(url, target_dir) + print("\nMD5 Chesksum %s ..." % filepath) + if not md5file(filepath) == md5sum: + raise RuntimeError("MD5 checksum failed.") + else: + print("File exists, skip downloading. (%s)" % filepath) + return filepath + + +def unpack(filepath, target_dir): + """ + Unpack the file to the target_dir. + """ + print("Unpacking %s ..." % filepath) + tar = tarfile.open(filepath) + tar.extractall(target_dir) + tar.close() + + +def create_manifest(data_dir, manifest_path): + """ + Create a manifest json file summarizing the data set, with each line + containing the meta data (i.e. audio filepath, transcription text, audio + duration) of each audio file within the data set. + """ + print("Creating manifest %s ..." % manifest_path) + json_lines = [] + for subfolder, _, filelist in sorted(os.walk(data_dir)): + text_filelist = [ + filename for filename in filelist if filename.endswith('trans.txt') + ] + if len(text_filelist) > 0: + text_filepath = os.path.join(data_dir, subfolder, text_filelist[0]) + for line in open(text_filepath): + segments = line.strip().split() + text = ' '.join(segments[1:]).lower() + audio_filepath = os.path.join(data_dir, subfolder, + segments[0] + '.flac') + audio_data, samplerate = soundfile.read(audio_filepath) + duration = float(len(audio_data)) / samplerate + json_lines.append( + json.dumps({ + 'audio_filepath': audio_filepath, + 'duration': duration, + 'text': text + })) + with open(manifest_path, 'w') as out_file: + for line in json_lines: + out_file.write(line + '\n') + + +def prepare_dataset(url, md5sum, target_dir, manifest_path): + """ + Download, unpack and create summmary manifest file. + """ + if not os.path.exists(os.path.join(target_dir, "LibriSpeech")): + # download + filepath = download(url, md5sum, target_dir) + # unpack + unpack(filepath, target_dir) + else: + print("Skip downloading and unpacking. Data already exists in %s." % + target_dir) + # create manifest json file + create_manifest(target_dir, manifest_path) + + +def main(): + prepare_dataset( + url=URL_TEST_CLEAN, + md5sum=MD5_TEST_CLEAN, + target_dir=os.path.join(args.target_dir, "test-clean"), + manifest_path=args.manifest_prefix + ".test-clean") + prepare_dataset( + url=URL_DEV_CLEAN, + md5sum=MD5_DEV_CLEAN, + target_dir=os.path.join(args.target_dir, "dev-clean"), + manifest_path=args.manifest_prefix + ".dev-clean") + prepare_dataset( + url=URL_TRAIN_CLEAN_100, + md5sum=MD5_TRAIN_CLEAN_100, + target_dir=os.path.join(args.target_dir, "train-clean-100"), + manifest_path=args.manifest_prefix + ".train-clean-100") + if args.full_download: + prepare_dataset( + url=URL_TEST_OTHER, + md5sum=MD5_TEST_OTHER, + target_dir=os.path.join(args.target_dir, "test-other"), + manifest_path=args.manifest_prefix + ".test-other") + prepare_dataset( + url=URL_DEV_OTHER, + md5sum=MD5_DEV_OTHER, + target_dir=os.path.join(args.target_dir, "dev-other"), + manifest_path=args.manifest_prefix + ".dev-other") + prepare_dataset( + url=URL_TRAIN_CLEAN_360, + md5sum=MD5_TRAIN_CLEAN_360, + target_dir=os.path.join(args.target_dir, "train-clean-360"), + manifest_path=args.manifest_prefix + ".train-clean-360") + prepare_dataset( + url=URL_TRAIN_OTHER_500, + md5sum=MD5_TRAIN_OTHER_500, + target_dir=os.path.join(args.target_dir, "train-other-500"), + manifest_path=args.manifest_prefix + ".train-other-500") + + +if __name__ == '__main__': + main() diff --git a/datasets/run_all.sh b/datasets/run_all.sh new file mode 100644 index 00000000..ef2b721f --- /dev/null +++ b/datasets/run_all.sh @@ -0,0 +1,13 @@ +cd librispeech +python librispeech.py +if [ $? -ne 0 ]; then + echo "Prepare LibriSpeech failed. Terminated." + exit 1 +fi +cd - + +cat librispeech/manifest.train* | shuf > manifest.train +cat librispeech/manifest.dev-clean > manifest.dev +cat librispeech/manifest.test-clean > manifest.test + +echo "All done." diff --git a/datasets/vocab/eng_vocab.txt b/datasets/vocab/eng_vocab.txt new file mode 100644 index 00000000..8268f3f3 --- /dev/null +++ b/datasets/vocab/eng_vocab.txt @@ -0,0 +1,28 @@ +' + +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z diff --git a/decoder.py b/decoder.py new file mode 100644 index 00000000..77d950b8 --- /dev/null +++ b/decoder.py @@ -0,0 +1,59 @@ +"""Contains various CTC decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from itertools import groupby + + +def ctc_best_path_decode(probs_seq, vocabulary): + """Best path decoding, also called argmax decoding or greedy decoding. + Path consisting of the most probable tokens are further post-processed to + remove consecutive repetitions and all blanks. + + :param probs_seq: 2-D list of probabilities over the vocabulary for each + character. Each element is a list of float probabilities + for one character. + :type probs_seq: list + :param vocabulary: Vocabulary list. + :type vocabulary: list + :return: Decoding result string. + :rtype: baseline + """ + # dimension verification + for probs in probs_seq: + if not len(probs) == len(vocabulary) + 1: + raise ValueError("probs_seq dimension mismatchedd with vocabulary") + # argmax to get the best index for each time step + max_index_list = list(np.array(probs_seq).argmax(axis=1)) + # remove consecutive duplicate indexes + index_list = [index_group[0] for index_group in groupby(max_index_list)] + # remove blank indexes + blank_index = len(vocabulary) + index_list = [index for index in index_list if index != blank_index] + # convert index list to string + return ''.join([vocabulary[index] for index in index_list]) + + +def ctc_decode(probs_seq, vocabulary, method): + """CTC-like sequence decoding from a sequence of likelihood probablilites. + + :param probs_seq: 2-D list of probabilities over the vocabulary for each + character. Each element is a list of float probabilities + for one character. + :type probs_seq: list + :param vocabulary: Vocabulary list. + :type vocabulary: list + :param method: Decoding method name, with options: "best_path". + :type method: basestring + :return: Decoding result string. + :rtype: baseline + """ + for prob_list in probs_seq: + if not len(prob_list) == len(vocabulary) + 1: + raise ValueError("probs dimension mismatchedd with vocabulary") + if method == "best_path": + return ctc_best_path_decode(probs_seq, vocabulary) + else: + raise ValueError("Decoding method [%s] is not supported.") diff --git a/error_rate.py b/error_rate.py new file mode 100644 index 00000000..0cf17921 --- /dev/null +++ b/error_rate.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +"""This module provides functions to calculate error rate in different level. +e.g. wer for word-level, cer for char-level. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + + +def _levenshtein_distance(ref, hyp): + """Levenshtein distance is a string metric for measuring the difference between + two sequences. Informally, the levenshtein disctance is defined as the minimum + number of single-character edits (substitutions, insertions or deletions) + required to change one word into the other. We can naturally extend the edits to + word level when calculate levenshtein disctance for two sentences. + """ + ref_len = len(ref) + hyp_len = len(hyp) + + # special case + if ref == hyp: + return 0 + if ref_len == 0: + return hyp_len + if hyp_len == 0: + return ref_len + + distance = np.zeros((ref_len + 1, hyp_len + 1), dtype=np.int32) + + # initialize distance matrix + for j in xrange(hyp_len + 1): + distance[0][j] = j + for i in xrange(ref_len + 1): + distance[i][0] = i + + # calculate levenshtein distance + for i in xrange(1, ref_len + 1): + for j in xrange(1, hyp_len + 1): + if ref[i - 1] == hyp[j - 1]: + distance[i][j] = distance[i - 1][j - 1] + else: + s_num = distance[i - 1][j - 1] + 1 + i_num = distance[i][j - 1] + 1 + d_num = distance[i - 1][j] + 1 + distance[i][j] = min(s_num, i_num, d_num) + + return distance[ref_len][hyp_len] + + +def wer(reference, hypothesis, ignore_case=False, delimiter=' '): + """Calculate word error rate (WER). WER compares reference text and + hypothesis text in word-level. WER is defined as: + + .. math:: + WER = (Sw + Dw + Iw) / Nw + + where + + .. code-block:: text + + Sw is the number of words subsituted, + Dw is the number of words deleted, + Iw is the number of words inserted, + Nw is the number of words in the reference + + We can use levenshtein distance to calculate WER. Please draw an attention that + empty items will be removed when splitting sentences by delimiter. + + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :param delimiter: Delimiter of input sentences. + :type delimiter: char + :return: Word error rate. + :rtype: float + :raises ValueError: If the reference length is zero. + """ + if ignore_case == True: + reference = reference.lower() + hypothesis = hypothesis.lower() + + ref_words = filter(None, reference.split(delimiter)) + hyp_words = filter(None, hypothesis.split(delimiter)) + + if len(ref_words) == 0: + raise ValueError("Reference's word number should be greater than 0.") + + edit_distance = _levenshtein_distance(ref_words, hyp_words) + wer = float(edit_distance) / len(ref_words) + return wer + + +def cer(reference, hypothesis, ignore_case=False): + """Calculate charactor error rate (CER). CER compares reference text and + hypothesis text in char-level. CER is defined as: + + .. math:: + CER = (Sc + Dc + Ic) / Nc + + where + + .. code-block:: text + + Sc is the number of characters substituted, + Dc is the number of characters deleted, + Ic is the number of characters inserted + Nc is the number of characters in the reference + + We can use levenshtein distance to calculate CER. Chinese input should be + encoded to unicode. Please draw an attention that the leading and tailing + white space characters will be truncated and multiple consecutive white + space characters in a sentence will be replaced by one white space character. + + :param reference: The reference sentence. + :type reference: basestring + :param hypothesis: The hypothesis sentence. + :type hypothesis: basestring + :param ignore_case: Whether case-sensitive or not. + :type ignore_case: bool + :return: Character error rate. + :rtype: float + :raises ValueError: If the reference length is zero. + """ + if ignore_case == True: + reference = reference.lower() + hypothesis = hypothesis.lower() + + reference = ' '.join(filter(None, reference.split(' '))) + hypothesis = ' '.join(filter(None, hypothesis.split(' '))) + + if len(reference) == 0: + raise ValueError("Length of reference should be greater than 0.") + + edit_distance = _levenshtein_distance(reference, hypothesis) + cer = float(edit_distance) / len(reference) + return cer diff --git a/infer.py b/infer.py new file mode 100644 index 00000000..9037a108 --- /dev/null +++ b/infer.py @@ -0,0 +1,137 @@ +"""Inferer for DeepSpeech2 model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import gzip +import distutils.util +import multiprocessing +import paddle.v2 as paddle +from data_utils.data import DataGenerator +from model import deep_speech2 +from decoder import ctc_decode +import utils + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--num_samples", + default=10, + type=int, + help="Number of samples for inference. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--mean_std_filepath", + default='mean_std.npz', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--decode_manifest_path", + default='datasets/manifest.test', + type=str, + help="Manifest path for decoding. (default: %(default)s)") +parser.add_argument( + "--model_filepath", + default='checkpoints/params.latest.tar.gz', + type=str, + help="Model filepath. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='datasets/vocab/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +args = parser.parse_args() + + +def infer(): + """Max-ctc-decoding for DeepSpeech2.""" + # initialize data generator + data_generator = DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config='{}', + num_threads=args.num_threads_data) + + # create network config + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. + audio_data = paddle.layer.data( + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence(data_generator.vocab_size)) + output_probs = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=data_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=True) + + # load parameters + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.model_filepath)) + + # prepare infer data + batch_reader = data_generator.batch_reader_creator( + manifest_path=args.decode_manifest_path, + batch_size=args.num_samples, + sortagrad=False, + shuffle_method=None) + infer_data = batch_reader().next() + + # run inference + infer_results = paddle.infer( + output_layer=output_probs, parameters=parameters, input=infer_data) + num_steps = len(infer_results) // len(infer_data) + probs_split = [ + infer_results[i * num_steps:(i + 1) * num_steps] + for i in xrange(len(infer_data)) + ] + + # decode and print + for i, probs in enumerate(probs_split): + output_transcription = ctc_decode( + probs_seq=probs, + vocabulary=data_generator.vocab_list, + method="best_path") + target_transcription = ''.join( + [data_generator.vocab_list[index] for index in infer_data[i][1]]) + print("Target Transcription: %s \nOutput Transcription: %s \n" % + (target_transcription, output_transcription)) + + +def main(): + utils.print_arguments(args) + paddle.init(use_gpu=args.use_gpu, trainer_count=1) + infer() + + +if __name__ == '__main__': + main() diff --git a/model.py b/model.py new file mode 100644 index 00000000..cb0b4ecb --- /dev/null +++ b/model.py @@ -0,0 +1,143 @@ +"""Contains DeepSpeech2 model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.v2 as paddle + + +def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, + padding, act): + """ + Convolution layer with batch normalization. + """ + conv_layer = paddle.layer.img_conv( + input=input, + filter_size=filter_size, + num_channels=num_channels_in, + num_filters=num_channels_out, + stride=stride, + padding=padding, + act=paddle.activation.Linear(), + bias_attr=False) + return paddle.layer.batch_norm(input=conv_layer, act=act) + + +def bidirectional_simple_rnn_bn_layer(name, input, size, act): + """ + Bidirectonal simple rnn layer with sequence-wise batch normalization. + The batch normalization is only performed on input-state weights. + """ + # input-hidden weights shared across bi-direcitonal rnn. + input_proj = paddle.layer.fc( + input=input, size=size, act=paddle.activation.Linear(), bias_attr=False) + # batch norm is only performed on input-state projection + input_proj_bn = paddle.layer.batch_norm( + input=input_proj, act=paddle.activation.Linear()) + # forward and backward in time + forward_simple_rnn = paddle.layer.recurrent( + input=input_proj_bn, act=act, reverse=False) + backward_simple_rnn = paddle.layer.recurrent( + input=input_proj_bn, act=act, reverse=True) + return paddle.layer.concat(input=[forward_simple_rnn, backward_simple_rnn]) + + +def conv_group(input, num_stacks): + """ + Convolution group with several stacking convolution layers. + """ + conv = conv_bn_layer( + input=input, + filter_size=(11, 41), + num_channels_in=1, + num_channels_out=32, + stride=(3, 2), + padding=(5, 20), + act=paddle.activation.BRelu()) + for i in xrange(num_stacks - 1): + conv = conv_bn_layer( + input=conv, + filter_size=(11, 21), + num_channels_in=32, + num_channels_out=32, + stride=(1, 2), + padding=(5, 10), + act=paddle.activation.BRelu()) + output_num_channels = 32 + output_height = 160 // pow(2, num_stacks) + 1 + return conv, output_num_channels, output_height + + +def rnn_group(input, size, num_stacks): + """ + RNN group with several stacking RNN layers. + """ + output = input + for i in xrange(num_stacks): + output = bidirectional_simple_rnn_bn_layer( + name=str(i), input=output, size=size, act=paddle.activation.BRelu()) + return output + + +def deep_speech2(audio_data, + text_data, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=256, + is_inference=False): + """ + The whole DeepSpeech2 model structure (a simplified version). + + :param audio_data: Audio spectrogram data layer. + :type audio_data: LayerOutput + :param text_data: Transcription text data layer. + :type text_data: LayerOutput + :param dict_size: Dictionary size for tokenized transcription. + :type dict_size: int + :param num_conv_layers: Number of stacking convolution layers. + :type num_conv_layers: int + :param num_rnn_layers: Number of stacking RNN layers. + :type num_rnn_layers: int + :param rnn_size: RNN layer size (number of RNN cells). + :type rnn_size: int + :param is_inference: False in the training mode, and True in the + inferene mode. + :type is_inference: bool + :return: If is_inference set False, return a ctc cost layer; + if is_inference set True, return a sequence layer of output + probability distribution. + :rtype: tuple of LayerOutput + """ + # convolution group + conv_group_output, conv_group_num_channels, conv_group_height = conv_group( + input=audio_data, num_stacks=num_conv_layers) + # convert data form convolution feature map to sequence of vectors + conv2seq = paddle.layer.block_expand( + input=conv_group_output, + num_channels=conv_group_num_channels, + stride_x=1, + stride_y=1, + block_x=1, + block_y=conv_group_height) + # rnn group + rnn_group_output = rnn_group( + input=conv2seq, size=rnn_size, num_stacks=num_rnn_layers) + fc = paddle.layer.fc( + input=rnn_group_output, + size=dict_size + 1, + act=paddle.activation.Linear(), + bias_attr=True) + if is_inference: + # probability distribution with softmax + return paddle.layer.mixed( + input=paddle.layer.identity_projection(input=fc), + act=paddle.activation.Softmax()) + else: + # ctc cost + return paddle.layer.warp_ctc( + input=fc, + label=text_data, + size=dict_size + 1, + blank=dict_size, + norm_by_times=True) diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 00000000..967b4f8c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +wget==3.2 +scipy==0.13.1 +resampy==0.1.5 \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100644 index 00000000..8cba91ec --- /dev/null +++ b/setup.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# install python dependencies +if [ -f "requirements.txt" ]; then + pip install -r requirements.txt +fi +if [ $? != 0 ]; then + echo "Install python dependencies failed !!!" + exit 1 +fi + +# install package Soundfile +curl -O "http://www.mega-nerd.com/libsndfile/files/libsndfile-1.0.28.tar.gz" +if [ $? != 0 ]; then + echo "Download libsndfile-1.0.28.tar.gz failed !!!" + exit 1 +fi +tar -zxvf libsndfile-1.0.28.tar.gz +cd libsndfile-1.0.28 +./configure && make && make install +cd - +rm -rf libsndfile-1.0.28 +rm libsndfile-1.0.28.tar.gz +pip install SoundFile==0.9.0.post1 +if [ $? != 0 ]; then + echo "Install SoundFile failed !!!" + exit 1 +fi + +# prepare ./checkpoints +mkdir checkpoints + +echo "Install all dependencies successfully." diff --git a/tests/test_error_rate.py b/tests/test_error_rate.py new file mode 100644 index 00000000..be7313f3 --- /dev/null +++ b/tests/test_error_rate.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +"""Test error rate.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +import error_rate + + +class TestParse(unittest.TestCase): + def test_wer_1(self): + ref = 'i UM the PHONE IS i LEFT THE portable PHONE UPSTAIRS last night' + hyp = 'i GOT IT TO the FULLEST i LOVE TO portable FROM OF STORES last night' + word_error_rate = error_rate.wer(ref, hyp) + self.assertTrue(abs(word_error_rate - 0.769230769231) < 1e-6) + + def test_wer_2(self): + ref = 'i UM the PHONE IS i LEFT THE portable PHONE UPSTAIRS last night' + word_error_rate = error_rate.wer(ref, ref) + self.assertEqual(word_error_rate, 0.0) + + def test_wer_3(self): + ref = ' ' + hyp = 'Hypothesis sentence' + with self.assertRaises(ValueError): + word_error_rate = error_rate.wer(ref, hyp) + + def test_cer_1(self): + ref = 'werewolf' + hyp = 'weae wolf' + char_error_rate = error_rate.cer(ref, hyp) + self.assertTrue(abs(char_error_rate - 0.25) < 1e-6) + + def test_cer_2(self): + ref = 'werewolf' + char_error_rate = error_rate.cer(ref, ref) + self.assertEqual(char_error_rate, 0.0) + + def test_cer_3(self): + ref = u'我是中国人' + hyp = u'我是 美洲人' + char_error_rate = error_rate.cer(ref, hyp) + self.assertTrue(abs(char_error_rate - 0.6) < 1e-6) + + def test_cer_4(self): + ref = u'我是中国人' + char_error_rate = error_rate.cer(ref, ref) + self.assertFalse(char_error_rate, 0.0) + + def test_cer_5(self): + ref = '' + hyp = 'Hypothesis' + with self.assertRaises(ValueError): + char_error_rate = error_rate.cer(ref, hyp) + + +if __name__ == '__main__': + unittest.main() diff --git a/train.py b/train.py new file mode 100644 index 00000000..3a2d0cad --- /dev/null +++ b/train.py @@ -0,0 +1,226 @@ +"""Trainer for DeepSpeech2 model.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import os +import argparse +import gzip +import time +import distutils.util +import multiprocessing +import paddle.v2 as paddle +from model import deep_speech2 +from data_utils.data import DataGenerator +import utils + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + "--batch_size", default=256, type=int, help="Minibatch size.") +parser.add_argument( + "--num_passes", + default=200, + type=int, + help="Training pass number. (default: %(default)s)") +parser.add_argument( + "--num_conv_layers", + default=2, + type=int, + help="Convolution layer number. (default: %(default)s)") +parser.add_argument( + "--num_rnn_layers", + default=3, + type=int, + help="RNN layer number. (default: %(default)s)") +parser.add_argument( + "--rnn_layer_size", + default=512, + type=int, + help="RNN layer cell number. (default: %(default)s)") +parser.add_argument( + "--adam_learning_rate", + default=5e-4, + type=float, + help="Learning rate for ADAM Optimizer. (default: %(default)s)") +parser.add_argument( + "--use_gpu", + default=True, + type=distutils.util.strtobool, + help="Use gpu or not. (default: %(default)s)") +parser.add_argument( + "--use_sortagrad", + default=True, + type=distutils.util.strtobool, + help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--max_duration", + default=27.0, + type=float, + help="Audios with duration larger than this will be discarded. " + "(default: %(default)s)") +parser.add_argument( + "--min_duration", + default=0.0, + type=float, + help="Audios with duration smaller than this will be discarded. " + "(default: %(default)s)") +parser.add_argument( + "--shuffle_method", + default='batch_shuffle_clipped', + type=str, + help="Shuffle method: 'instance_shuffle', 'batch_shuffle', " + "'batch_shuffle_batch'. (default: %(default)s)") +parser.add_argument( + "--trainer_count", + default=8, + type=int, + help="Trainer number. (default: %(default)s)") +parser.add_argument( + "--num_threads_data", + default=multiprocessing.cpu_count(), + type=int, + help="Number of cpu threads for preprocessing data. (default: %(default)s)") +parser.add_argument( + "--mean_std_filepath", + default='mean_std.npz', + type=str, + help="Manifest path for normalizer. (default: %(default)s)") +parser.add_argument( + "--train_manifest_path", + default='datasets/manifest.train', + type=str, + help="Manifest path for training. (default: %(default)s)") +parser.add_argument( + "--dev_manifest_path", + default='datasets/manifest.dev', + type=str, + help="Manifest path for validation. (default: %(default)s)") +parser.add_argument( + "--vocab_filepath", + default='datasets/vocab/eng_vocab.txt', + type=str, + help="Vocabulary filepath. (default: %(default)s)") +parser.add_argument( + "--init_model_path", + default=None, + type=str, + help="If set None, the training will start from scratch. " + "Otherwise, the training will resume from " + "the existing model of this path. (default: %(default)s)") +parser.add_argument( + "--augmentation_config", + default='[{"type": "shift", ' + '"params": {"min_shift_ms": -5, "max_shift_ms": 5},' + '"prob": 1.0}]', + type=str, + help="Augmentation configuration in json-format. " + "(default: %(default)s)") +args = parser.parse_args() + + +def train(): + """DeepSpeech2 training.""" + + # initialize data generator + def data_generator(): + return DataGenerator( + vocab_filepath=args.vocab_filepath, + mean_std_filepath=args.mean_std_filepath, + augmentation_config=args.augmentation_config, + max_duration=args.max_duration, + min_duration=args.min_duration, + num_threads=args.num_threads_data) + + train_generator = data_generator() + test_generator = data_generator() + + # create network config + # paddle.data_type.dense_array is used for variable batch input. + # The size 161 * 161 is only an placeholder value and the real shape + # of input batch data will be induced during training. + audio_data = paddle.layer.data( + name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161)) + text_data = paddle.layer.data( + name="transcript_text", + type=paddle.data_type.integer_value_sequence( + train_generator.vocab_size)) + cost = deep_speech2( + audio_data=audio_data, + text_data=text_data, + dict_size=train_generator.vocab_size, + num_conv_layers=args.num_conv_layers, + num_rnn_layers=args.num_rnn_layers, + rnn_size=args.rnn_layer_size, + is_inference=False) + + # create/load parameters and optimizer + if args.init_model_path is None: + parameters = paddle.parameters.create(cost) + else: + if not os.path.isfile(args.init_model_path): + raise IOError("Invalid model!") + parameters = paddle.parameters.Parameters.from_tar( + gzip.open(args.init_model_path)) + optimizer = paddle.optimizer.Adam( + learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400) + trainer = paddle.trainer.SGD( + cost=cost, parameters=parameters, update_equation=optimizer) + + # prepare data reader + train_batch_reader = train_generator.batch_reader_creator( + manifest_path=args.train_manifest_path, + batch_size=args.batch_size, + min_batch_size=args.trainer_count, + sortagrad=args.use_sortagrad if args.init_model_path is None else False, + shuffle_method=args.shuffle_method) + test_batch_reader = test_generator.batch_reader_creator( + manifest_path=args.dev_manifest_path, + batch_size=args.batch_size, + min_batch_size=1, # must be 1, but will have errors. + sortagrad=False, + shuffle_method=None) + + # create event handler + def event_handler(event): + global start_time, cost_sum, cost_counter + if isinstance(event, paddle.event.EndIteration): + cost_sum += event.cost + cost_counter += 1 + if (event.batch_id + 1) % 100 == 0: + print("\nPass: %d, Batch: %d, TrainCost: %f" % ( + event.pass_id, event.batch_id + 1, cost_sum / cost_counter)) + cost_sum, cost_counter = 0.0, 0 + with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f: + parameters.to_tar(f) + else: + sys.stdout.write('.') + sys.stdout.flush() + if isinstance(event, paddle.event.BeginPass): + start_time = time.time() + cost_sum, cost_counter = 0.0, 0 + if isinstance(event, paddle.event.EndPass): + result = trainer.test( + reader=test_batch_reader, feeding=test_generator.feeding) + print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % + (time.time() - start_time, event.pass_id, result.cost)) + with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id, + 'w') as f: + parameters.to_tar(f) + + # run train + trainer.train( + reader=train_batch_reader, + event_handler=event_handler, + num_passes=args.num_passes, + feeding=train_generator.feeding) + + +def main(): + utils.print_arguments(args) + paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) + train() + + +if __name__ == '__main__': + main() diff --git a/utils.py b/utils.py new file mode 100644 index 00000000..9ca363c8 --- /dev/null +++ b/utils.py @@ -0,0 +1,25 @@ +"""Contains common utility functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----- Configuration Arguments -----") + for arg, value in vars(args).iteritems(): + print("%s: %s" % (arg, value)) + print("------------------------------------")