Refactor the whole data preprocessor for DS2 (re-design classes, re-organize directories, add augmentation interfaces, etc.).

1. Refactor the data preprocessor with the newly added classes AudioSegment, SpeechSegment, TextFeaturizer, AudioFeaturizer and SpeechFeaturizer.
2. Add data augmentation interfaces and the classes AugmentorBase, AugmentationPipeline, VolumnPerturbAugmentor, etc.
3. Separate the normalizer's mean and std computation from training by adding FeatureNormalizer and a separate tool, compute_mean_std.py.
4. Re-organize directory.
pull/2/head
Xinghai Sun 8 years ago
parent 9c27b1d14e
commit cd3617aeb4

@ -1,411 +0,0 @@
"""
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
import paddle.v2 as paddle
import logging
import json
import random
import soundfile
import numpy as np
import itertools
import os
RANDOM_SEED = 0
logger = logging.getLogger(__name__)
class DataGenerator(object):
    """
    DataGenerator provides a basic audio data preprocessing pipeline, and
    offers both instance-level and batch-level data reader interfaces.
    Normalized FFT are used as audio features here.

    :param vocab_filepath: Vocabulary file path for indexing tokenized
                           transcriptions.
    :type vocab_filepath: basestring
    :param normalizer_manifest_path: Manifest filepath for collecting feature
                                     normalization statistics, e.g. mean, std.
    :type normalizer_manifest_path: basestring
    :param normalizer_num_samples: Number of instances sampled for collecting
                                   feature normalization statistics.
                                   Default is 100.
    :type normalizer_num_samples: int
    :param max_duration: Audio clips with duration (in seconds) greater than
                         this will be discarded. Default is 20.0.
    :type max_duration: float
    :param min_duration: Audio clips with duration (in seconds) smaller than
                         this will be discarded. Default is 0.0.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
                      Default is 10.0.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
    :type window_ms: float
    :param max_frequency: Maximum frequency for FFT features. FFT features of
                          frequency larger than this will be discarded.
                          If set None, all features will be kept.
                          Default is None.
    :type max_frequency: float
    """

    def __init__(self,
                 vocab_filepath,
                 normalizer_manifest_path,
                 normalizer_num_samples=100,
                 max_duration=20.0,
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_frequency=None):
        self.__max_duration__ = max_duration
        self.__min_duration__ = min_duration
        self.__stride_ms__ = stride_ms
        self.__window_ms__ = window_ms
        self.__max_frequency__ = max_frequency
        self.__epoc__ = 0
        self.__random__ = random.Random(RANDOM_SEED)
        # load vocabulary (dictionary)
        self.__vocab_dict__, self.__vocab_list__ = \
            self.__load_vocabulary_from_file__(vocab_filepath)
        # collect normalizer statistics
        self.__mean__, self.__std__ = self.__collect_normalizer_statistics__(
            manifest_path=normalizer_manifest_path,
            num_samples=normalizer_num_samples)

    def __audio_featurize__(self, audio_filename):
        """
        Preprocess audio data: extract the basic features, then normalize
        them with the collected mean and stddev.
        """
        features = self.__audio_basic_featurize__(audio_filename)
        return self.__normalize__(features)

    def __text_featurize__(self, text):
        """
        Preprocess text data: convert each character into its vocabulary
        index.
        """
        return self.__convert_text_to_char_index__(
            text=text, vocabulary=self.__vocab_dict__)

    def __audio_basic_featurize__(self, audio_filename):
        """
        Compute basic (without normalization etc.) features for audio data.
        """
        return self.__spectrogram_from_file__(
            filename=audio_filename,
            stride_ms=self.__stride_ms__,
            window_ms=self.__window_ms__,
            max_freq=self.__max_frequency__)

    def __collect_normalizer_statistics__(self, manifest_path, num_samples=100):
        """
        Compute feature normalization statistics, i.e. mean and stddev.

        :param manifest_path: Manifest filepath to sample instances from.
        :param num_samples: Number of instances to sample. When the filtered
                            manifest holds fewer instances, all of them are
                            used.
        :return: Tuple (mean, std), each reshaped to a column vector.
        """
        # read manifest
        manifest = self.__read_manifest__(
            manifest_path=manifest_path,
            max_duration=self.__max_duration__,
            min_duration=self.__min_duration__)
        # sample for statistics; random.sample() raises when asked for more
        # items than the population holds, so cap the request.
        num_samples = min(num_samples, len(manifest))
        sampled_manifest = self.__random__.sample(manifest, num_samples)
        # extract spectrogram feature
        features = [
            self.__audio_basic_featurize__(instance["audio_filepath"])
            for instance in sampled_manifest
        ]
        # stack all frames side by side, then reduce along the frame axis
        features = np.hstack(features)
        mean = np.mean(features, axis=1).reshape([-1, 1])
        std = np.std(features, axis=1).reshape([-1, 1])
        return mean, std

    def __normalize__(self, features, eps=1e-14):
        """
        Normalize features to be of zero mean and unit stddev.
        `eps` guards against division by zero.
        """
        return (features - self.__mean__) / (self.__std__ + eps)

    def __spectrogram_from_file__(self,
                                  filename,
                                  stride_ms=10.0,
                                  window_ms=20.0,
                                  max_freq=None,
                                  eps=1e-14):
        """
        Load audio data and calculate the log of spectrogram by FFT.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech

        :raises ValueError: If `max_freq` exceeds half of the sample rate, or
                            if the stride is larger than the window.
        """
        audio, sample_rate = soundfile.read(filename)
        if audio.ndim >= 2:
            # collapse axis 1 (presumably the channel axis) into one signal
            audio = np.mean(audio, 1)
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        spectrogram, freqs = self.__extract_spectrogram__(
            audio,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        # keep only the frequency bins that are <= max_freq
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        return np.log(spectrogram[:ind, :] + eps)

    def __extract_spectrogram__(self, samples, window_size, stride_size,
                                sample_rate):
        """
        Compute the spectrogram by FFT for a discrete real signal.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech

        :return: Tuple (spectrogram, freqs); spectrogram has one row per
                 frequency bin and one column per frame.
        :raises ValueError: If the signal is shorter than one window.
        """
        if len(samples) < window_size:
            raise ValueError("Audio is shorter than one analysis window.")
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        # sanity check of the strided view; needs at least two frames
        if windows.shape[1] > 1:
            assert np.all(windows[:, 1] == samples[stride_size:(
                stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs

    def __load_vocabulary_from_file__(self, vocabulary_path):
        """
        Load vocabulary from file, one token per line.

        :return: Tuple (token -> index dict, list of tokens).
        :raises ValueError: If the vocabulary file does not exist.
        """
        if not os.path.exists(vocabulary_path):
            raise ValueError(
                "Vocabulary file %s not found." % vocabulary_path)
        with open(vocabulary_path, 'r') as vocab_file:
            # Strip only the line terminator: a token may itself be a space,
            # and the last line may come without a trailing newline.
            vocab_list = [line.rstrip('\n') for line in vocab_file]
        vocab_dict = dict(
            (token, index) for (index, token) in enumerate(vocab_list))
        return vocab_dict, vocab_list

    def __convert_text_to_char_index__(self, text, vocabulary):
        """
        Convert text string to a list of character index integers.
        """
        return [vocabulary[w] for w in text]

    def __read_manifest__(self, manifest_path, max_duration, min_duration):
        """
        Load and parse manifest file (one JSON object per line), keeping only
        instances whose "duration" lies in [min_duration, max_duration].

        :raises ValueError: If a line cannot be parsed as JSON.
        """
        manifest = []
        with open(manifest_path) as manifest_file:
            for json_line in manifest_file:
                try:
                    json_data = json.loads(json_line)
                except Exception as e:
                    raise ValueError("Error reading manifest: %s" % str(e))
                if (json_data["duration"] <= max_duration and
                        json_data["duration"] >= min_duration):
                    manifest.append(json_data)
        return manifest

    def __padding_batch__(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio part of features (only in the time axis -- column axis)
        with zeros, to make each instance in the batch share the same
        audio feature shape.

        If `padding_to` is set -1, the maximum column numbers in the batch will
        be used as the target size. Otherwise, `padding_to` will be the target
        size. Default is -1.

        If `flatten` is set True, audio data will be flatten to be a 1-dim
        ndarray. Default is False.

        :raises ValueError: If `padding_to` is smaller than the longest
                            instance in the batch.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be greater"
                                 " or equal to the original instance length.")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            new_batch.append((padded_audio, text))
        return new_batch

    def __batch_shuffle__(self, manifest, batch_size):
        """
        The instances have different lengths and they cannot be
        combined into a single matrix multiplication. It usually
        sorts the training examples by length and combines only
        similarly-sized instances into minibatches, pads with
        silence when necessary so that all instances in a batch
        have the same length. This batch shuffle function is used
        to make similarly-sized instances into minibatches and
        make a batch-wise shuffle.

        1. Sort the audio clips by duration (in place).
        2. Generate a random number `k`, k in [0, batch_size).
        3. Skip the first `k` instances, group the rest into minibatches of
           size batch_size.
        4. Shuffle the minibatches.
        5. Append the instances that did not fill a minibatch, then the `k`
           skipped ones, so that no instance is lost or repeated.

        :param manifest: manifest file.
        :type manifest: list
        :param batch_size: Batch size. This size is also used for generate
                           a random number for batch shuffle.
        :type batch_size: int
        :return: batch shuffled manifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self.__random__.randint(0, batch_size - 1)
        # list() so that shuffle() also works where zip() returns an iterator
        batches = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        self.__random__.shuffle(batches)
        batch_manifest = [
            instance for batch in batches for instance in batch
        ]
        res_len = len(manifest) - shift_len - len(batch_manifest)
        # Guard the slice: manifest[-0:] would be the WHOLE list, and a
        # negative res_len (shift beyond the manifest) would also duplicate.
        if res_len > 0:
            batch_manifest.extend(manifest[-res_len:])
        batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest

    def instance_reader_creator(self, manifest):
        """
        Instance reader creator for audio data. Create a callable function to
        produce instances of data.

        Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
        tokenized and indexed transcription text.

        :param manifest: Parsed manifest entries; each entry provides
                         "audio_filepath" and "text".
        :type manifest: list
        :return: Data reader function.
        :rtype: callable
        """

        def reader():
            # extract spectrogram feature
            for instance in manifest:
                spectrogram = self.__audio_featurize__(
                    instance["audio_filepath"])
                transcript = self.__text_featurize__(instance["text"])
                yield (spectrogram, transcript)

        return reader

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             batch_shuffle=False):
        """
        Batch data reader creator for audio data. Create a callable function to
        produce batches of data.

        Audio features will be padded with zeros to make each instance in the
        batch to share the same audio feature shape.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param batch_size: Instance number in a batch.
        :type batch_size: int
        :param padding_to: If set -1, the maximum column numbers in the batch
                           will be used as the target size for padding.
                           Otherwise, `padding_to` will be the target size.
                           Default is -1.
        :type padding_to: int
        :param flatten: If set True, audio data will be flatten to be a 1-dim
                        ndarray. Otherwise, 2-dim ndarray. Default is False.
        :type flatten: bool
        :param sortagrad: Sort the audio clips by duration in the first epoch
                          if set True.
        :type sortagrad: bool
        :param batch_shuffle: Shuffle the audio clips if set True. It is
                              not a thorough instance-wise shuffle, but a
                              specific batch-wise shuffle. For more details,
                              please see `__batch_shuffle__` function.
        :type batch_shuffle: bool
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = self.__read_manifest__(
                manifest_path=manifest_path,
                max_duration=self.__max_duration__,
                min_duration=self.__min_duration__)
            # sort (by duration) or shuffle manifest
            if self.__epoc__ == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            elif batch_shuffle:
                manifest = self.__batch_shuffle__(manifest, batch_size)

            instance_reader = self.instance_reader_creator(manifest)
            batch = []
            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self.__padding_batch__(batch, padding_to, flatten)
                    batch = []
            # emit the last, possibly smaller, batch
            if len(batch) > 0:
                yield self.__padding_batch__(batch, padding_to, flatten)
            self.__epoc__ += 1

        return batch_reader

    def vocabulary_size(self):
        """
        Get vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return len(self.__vocab_list__)

    def vocabulary_dict(self):
        """
        Get vocabulary in dict.

        :return: Vocabulary in dict.
        :rtype: dict
        """
        return self.__vocab_dict__

    def vocabulary_list(self):
        """
        Get vocabulary in list.

        :return: Vocabulary in list
        :rtype: list
        """
        return self.__vocab_list__

    def data_name_feeding(self):
        """
        Get feeding (data field name and corresponding field id).

        :return: Feeding dict.
        :rtype: dict
        """
        feeding = {
            "audio_spectrogram": 0,
            "transcript_text": 1,
        }
        return feeding

@ -0,0 +1,56 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from data_utils.normalizer import FeatureNormalizer
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
# Command-line interface of the mean/std computing tool.
parser = argparse.ArgumentParser(
    description='Computing mean and stddev for feature normalizer.')
parser.add_argument(
    "--manifest_path",
    default='datasets/manifest.train',
    type=str,
    # the trailing space keeps "(default: ...)" apart from the sentence
    help="Manifest path for computing normalizer's mean and stddev. "
    "(default: %(default)s)")
parser.add_argument(
    "--num_samples",
    default=500,
    type=int,
    help="Number of samples for computing mean and stddev. "
    "(default: %(default)s)")
parser.add_argument(
    "--augmentation_config",
    default='{}',
    type=str,
    help="Augmentation configuration in json-format. "
    "(default: %(default)s)")
parser.add_argument(
    "--output_file",
    default='mean_std.npz',
    type=str,
    help="Filepath to write mean and std to (.npz). "
    "(default: %(default)s)")
# Parsed at import time; `args` is read by main().
args = parser.parse_args()
def main():
    """Compute the feature mean/std from the manifest and save them.

    Audio is passed through the augmentation pipeline before featurizing,
    and the resulting statistics are written to `args.output_file`.
    """
    pipeline = AugmentationPipeline(args.augmentation_config)
    featurizer = AudioFeaturizer()

    def featurize_with_augmentation(audio_segment):
        # augment in place first, then extract features
        pipeline.transform_audio(audio_segment)
        return featurizer.featurize(audio_segment)

    # mean_std_filepath=None makes the normalizer compute the statistics
    FeatureNormalizer(
        mean_std_filepath=None,
        manifest_path=args.manifest_path,
        featurize_func=featurize_with_augmentation,
        num_samples=args.num_samples).write_to_file(args.output_file)


if __name__ == '__main__':
    main()

@ -0,0 +1,68 @@
import numpy as np
import io
import soundfile
class AudioSegment(object):
    """Monaural audio segment abstraction.

    :param samples: Audio samples; must be of dtype float32. Input with two
                    or more dimensions is averaged over axis 1.
    :type samples: ndarray
    :param sample_rate: Audio sample rate.
    :type sample_rate: int
    :raises ValueError: If the sample data type is not float32.
    """

    def __init__(self, samples, sample_rate):
        if samples.dtype != np.float32:
            raise ValueError("Sample data type of [%s] is not supported." %
                             samples.dtype)
        self._samples = samples
        self._sample_rate = sample_rate
        if self._samples.ndim >= 2:
            # collapse axis 1 (presumably the channel axis) into one signal
            self._samples = np.mean(self._samples, 1)

    @classmethod
    def from_file(cls, filepath):
        """Create an audio segment from an audio file."""
        samples, sample_rate = soundfile.read(filepath, dtype='float32')
        return cls(samples, sample_rate)

    @classmethod
    def from_bytes(cls, bytes):
        """Create an audio segment from a byte string of audio file content."""
        samples, sample_rate = soundfile.read(
            io.BytesIO(bytes), dtype='float32')
        return cls(samples, sample_rate)

    def apply_gain(self, gain):
        """Scale the samples in place by `gain`, given in decibels.

        The internal buffer is modified directly: the public `samples`
        property is read-only and hands out a copy, so it cannot be used
        for in-place updates.
        """
        self._samples *= 10.**(gain / 20.)

    def resample(self, target_sample_rate):
        """Not implemented yet."""
        raise NotImplementedError()

    def change_speed(self, rate):
        """Not implemented yet."""
        raise NotImplementedError()

    @property
    def samples(self):
        """Return a copy of the audio samples."""
        return self._samples.copy()

    @property
    def sample_rate(self):
        """Return the audio sample rate."""
        return self._sample_rate

    @property
    def duration(self):
        """Return the duration: number of samples divided by sample rate."""
        return self._samples.shape[0] / float(self._sample_rate)
class SpeechSegment(AudioSegment):
    """Audio segment that additionally carries its transcript."""

    def __init__(self, samples, sample_rate, transcript):
        AudioSegment.__init__(self, samples, sample_rate)
        self._transcript = transcript

    @classmethod
    def from_file(cls, filepath, transcript):
        """Create a speech segment from an audio file and its transcript."""
        loaded = AudioSegment.from_file(filepath)
        samples = loaded.samples
        sample_rate = loaded.sample_rate
        return cls(samples, sample_rate, transcript)

    @classmethod
    def from_bytes(cls, bytes, transcript):
        """Create a speech segment from audio bytes and its transcript."""
        loaded = AudioSegment.from_bytes(bytes)
        samples = loaded.samples
        sample_rate = loaded.sample_rate
        return cls(samples, sample_rate, transcript)

    @property
    def transcript(self):
        """Return the transcript given at construction."""
        return self._transcript

@ -0,0 +1,38 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import random
from data_utils.augmentor.volumn_perturb import VolumnPerturbAugmentor
class AugmentationPipeline(object):
    """Chain of audio augmentors, each applied with its own probability.

    :param augmentation_config: Augmentation configuration in json string.
                                Each entry provides "type", "params" and
                                "rate".
    :type augmentation_config: str
    :param random_seed: Seed of the pipeline's random generator.
    :type random_seed: int
    :raises ValueError: If the config is not valid json, or names an
                        unknown augmentor type.
    """

    def __init__(self, augmentation_config, random_seed=0):
        self._rng = random.Random(random_seed)
        parsed = self._parse_pipeline_from(augmentation_config)
        self._augmentors, self._rates = parsed

    def transform_audio(self, audio_segment):
        """Run the augmentors, in order, on the audio segment.

        One random number is drawn per augmentor; the augmentor is applied
        when that number does not exceed its rate.
        """
        for position, augmentor in enumerate(self._augmentors):
            rate = self._rates[position]
            if self._rng.uniform(0., 1.) <= rate:
                augmentor.transform_audio(audio_segment)

    def _parse_pipeline_from(self, config_json):
        """Build the augmentor list and the rate list from a json string."""
        try:
            configs = json.loads(config_json)
        except Exception as e:
            raise ValueError("Augmentation config json format error: "
                             "%s" % str(e))
        # first pass builds every augmentor, second pass collects the rates
        augmentors = []
        for config in configs:
            augmentors.append(
                self._get_augmentor(config["type"], config["params"]))
        rates = []
        for config in configs:
            rates.append(config["rate"])
        return augmentors, rates

    def _get_augmentor(self, augmentor_type, params):
        """Instantiate the augmentor registered under `augmentor_type`."""
        if augmentor_type != "volumn":
            raise ValueError("Unknown augmentor type [%s]." % augmentor_type)
        return VolumnPerturbAugmentor(self._rng, **params)

@ -0,0 +1,17 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta, abstractmethod
class AugmentorBase(object):
    """Abstract interface of an audio augmentor.

    Subclasses provide `__init__` and `transform_audio`.
    """
    # Python 2 style metaclass declaration
    __metaclass__ = ABCMeta

    @abstractmethod
    def __init__(self):
        """Set up the augmentor."""

    @abstractmethod
    def transform_audio(self, audio_segment):
        """Apply the augmentation to the given audio segment."""

@ -0,0 +1,17 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
from data_utils.augmentor.base import AugmentorBase
class VolumnPerturbAugmentor(AugmentorBase):
    """Augmentor that applies a random gain to an audio segment.

    :param rng: Random generator used to draw the gain.
    :type rng: random.Random
    :param min_gain_dBFS: Lower bound of the random gain.
    :type min_gain_dBFS: float
    :param max_gain_dBFS: Upper bound of the random gain.
    :type max_gain_dBFS: float
    """

    def __init__(self, rng, min_gain_dBFS, max_gain_dBFS):
        self._min_gain_dBFS = min_gain_dBFS
        self._max_gain_dBFS = max_gain_dBFS
        self._rng = rng

    def transform_audio(self, audio_segment):
        """Draw a gain uniformly from the configured range and apply it.

        The bounds are read from the instance; the constructor arguments
        are not in scope here.
        """
        gain = self._rng.uniform(self._min_gain_dBFS, self._max_gain_dBFS)
        audio_segment.apply_gain(gain)

@ -0,0 +1,247 @@
"""
Providing basic audio data preprocessing pipeline, and offering
both instance-level and batch-level data reader interfaces.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
import paddle.v2 as paddle
from data_utils import utils
from data_utils.augmentor.augmentation import AugmentationPipeline
from data_utils.featurizer.speech_featurizer import SpeechFeaturizer
from data_utils.audio import SpeechSegment
from data_utils.normalizer import FeatureNormalizer
class DataGenerator(object):
    """
    DataGenerator provides a basic audio data preprocessing pipeline, and
    offers batch-level data reader interfaces.
    Normalized FFT are used as audio features here.

    :param vocab_filepath: Vocabulary file path for indexing tokenized
                           transcriptions.
    :type vocab_filepath: basestring
    :param mean_std_filepath: Filepath the feature normalizer loads its
                              mean and std from.
    :type mean_std_filepath: basestring
    :param augmentation_config: Augmentation configuration in json string.
                                Default is '{}'.
    :type augmentation_config: str
    :param max_duration: Audio clips with duration (in seconds) greater than
                         this will be discarded. Default is infinity.
    :type max_duration: float
    :param min_duration: Audio clips with duration (in seconds) smaller than
                         this will be discarded. Default is 0.0.
    :type min_duration: float
    :param stride_ms: Striding size (in milliseconds) for generating frames.
                      Default is 10.0.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
    :type window_ms: float
    :param max_freq: Maximum frequency for FFT features, handed to the
                     speech featurizer. Default is None.
    :type max_freq: float
    :param random_seed: Seed for the augmentation pipeline, the speech
                        featurizer and the batch shuffle. Default is 0.
    :type random_seed: int
    """

    def __init__(self,
                 vocab_filepath,
                 mean_std_filepath,
                 augmentation_config='{}',
                 max_duration=float('inf'),
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 random_seed=0):
        self._max_duration = max_duration
        self._min_duration = min_duration
        self._normalizer = FeatureNormalizer(mean_std_filepath)
        self._augmentation_pipeline = AugmentationPipeline(
            augmentation_config=augmentation_config, random_seed=random_seed)
        self._speech_featurizer = SpeechFeaturizer(
            vocab_filepath=vocab_filepath,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            random_seed=random_seed)
        self._rng = random.Random(random_seed)
        self._epoch = 0

    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sortagrad=False,
                             batch_shuffle=False):
        """
        Batch data reader creator for audio data. Create a callable function to
        produce batches of data.

        Audio features will be padded with zeros to make each instance in the
        batch to share the same audio feature shape.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param batch_size: Instance number in a batch.
        :type batch_size: int
        :param padding_to: If set -1, the maximum column numbers in the batch
                           will be used as the target size for padding.
                           Otherwise, `padding_to` will be the target size.
                           Default is -1.
        :type padding_to: int
        :param flatten: If set True, audio data will be flatten to be a 1-dim
                        ndarray. Otherwise, 2-dim ndarray. Default is False.
        :type flatten: bool
        :param sortagrad: Sort the audio clips by duration in the first epoch
                          if set True.
        :type sortagrad: bool
        :param batch_shuffle: Shuffle the audio clips if set True. It is
                              not a thorough instance-wise shuffle, but a
                              specific batch-wise shuffle. For more details,
                              please see `_batch_shuffle` function.
        :type batch_shuffle: bool
        :return: Batch reader function, producing batches of data when called.
        :rtype: callable
        """

        def batch_reader():
            # read manifest
            manifest = utils.read_manifest(
                manifest_path=manifest_path,
                max_duration=self._max_duration,
                min_duration=self._min_duration)
            # sort (by duration) or batch-wise shuffle the manifest
            if self._epoch == 0 and sortagrad:
                manifest.sort(key=lambda x: x["duration"])
            elif batch_shuffle:
                manifest = self._batch_shuffle(manifest, batch_size)
            # prepare batches
            instance_reader = self._instance_reader_creator(manifest)
            batch = []
            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self._padding_batch(batch, padding_to, flatten)
                    batch = []
            # emit the last, possibly smaller, batch
            if len(batch) > 0:
                yield self._padding_batch(batch, padding_to, flatten)
            self._epoch += 1

        return batch_reader

    @property
    def feeding(self):
        """Returns data_reader's feeding dict."""
        return {"audio_spectrogram": 0, "transcript_text": 1}

    @property
    def vocab_size(self):
        """Returns vocabulary size."""
        return self._speech_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Returns vocabulary list."""
        return self._speech_featurizer.vocab_list

    def _process_utterance(self, filename, transcript):
        """Load one utterance, augment it, featurize it and normalize the
        audio features.

        :return: Tuple (normalized spectrogram, transcript token ids).
        """
        speech_segment = SpeechSegment.from_file(filename, transcript)
        self._augmentation_pipeline.transform_audio(speech_segment)
        specgram, text_ids = self._speech_featurizer.featurize(speech_segment)
        specgram = self._normalizer.apply(specgram)
        return specgram, text_ids

    def _instance_reader_creator(self, manifest):
        """
        Instance reader creator for audio data. Create a callable function to
        produce instances of data.

        Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
        tokenized and indexed transcription text.

        :param manifest: Parsed manifest entries; each entry provides
                         "audio_filepath" and "text".
        :type manifest: list
        :return: Data reader function.
        :rtype: callable
        """

        def reader():
            for instance in manifest:
                yield self._process_utterance(instance["audio_filepath"],
                                              instance["text"])

        return reader

    def _padding_batch(self, batch, padding_to=-1, flatten=False):
        """
        Padding audio part of features (only in the time axis -- column axis)
        with zeros, to make each instance in the batch share the same
        audio feature shape.

        If `padding_to` is set -1, the maximum column numbers in the batch will
        be used as the target size. Otherwise, `padding_to` will be the target
        size. Default is -1.

        If `flatten` is set True, audio data will be flatten to be a 1-dim
        ndarray. Default is False.

        :raises ValueError: If `padding_to` is smaller than the longest
                            instance in the batch.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be greater"
                                 " or equal to the original instance length.")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            new_batch.append((padded_audio, text))
        return new_batch

    def _batch_shuffle(self, manifest, batch_size):
        """
        The instances have different lengths and they cannot be
        combined into a single matrix multiplication. It usually
        sorts the training examples by length and combines only
        similarly-sized instances into minibatches, pads with
        silence when necessary so that all instances in a batch
        have the same length. This batch shuffle function is used
        to make similarly-sized instances into minibatches and
        make a batch-wise shuffle.

        1. Sort the audio clips by duration (in place).
        2. Generate a random number `k`, k in [0, batch_size).
        3. Skip the first `k` instances, group the rest into minibatches of
           size batch_size.
        4. Shuffle the minibatches.
        5. Append the instances that did not fill a minibatch, then the `k`
           skipped ones, so that no instance is lost or repeated.

        :param manifest: manifest file.
        :type manifest: list
        :param batch_size: Batch size. This size is also used for generate
                           a random number for batch shuffle.
        :type batch_size: int
        :return: batch shuffled manifest.
        :rtype: list
        """
        manifest.sort(key=lambda x: x["duration"])
        shift_len = self._rng.randint(0, batch_size - 1)
        # list() so that shuffle() also works where zip() returns an iterator
        batches = list(zip(*[iter(manifest[shift_len:])] * batch_size))
        self._rng.shuffle(batches)
        batch_manifest = [
            instance for batch in batches for instance in batch
        ]
        res_len = len(manifest) - shift_len - len(batch_manifest)
        # Guard the slice: manifest[-0:] would be the WHOLE list, and a
        # negative res_len (shift beyond the manifest) would also duplicate.
        if res_len > 0:
            batch_manifest.extend(manifest[-res_len:])
        batch_manifest.extend(manifest[0:shift_len])
        return batch_manifest

@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
from data_utils import utils
from data_utils.audio import AudioSegment
class AudioFeaturizer(object):
    """Audio featurizer turning an audio segment into spectrogram features.

    :param specgram_type: Spectrogram type. Only 'linear' is supported.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: Frequency bins above this are dropped. If None, half of
                     the sample rate is used.
    :type max_freq: None|float
    :param random_seed: Accepted for interface compatibility; it is not read
                        by this class.
    :type random_seed: int
    """

    def __init__(self,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 random_seed=0):
        self._specgram_type = specgram_type
        self._stride_ms = stride_ms
        self._window_ms = window_ms
        self._max_freq = max_freq

    def featurize(self, audio_segment):
        """Extract spectrogram features from the segment's samples.

        :param audio_segment: Object exposing `samples` and `sample_rate`.
        :return: Log spectrogram, one row per frequency bin and one column
                 per frame.
        :rtype: ndarray
        """
        return self._compute_specgram(audio_segment.samples,
                                      audio_segment.sample_rate)

    def _compute_specgram(self, samples, sample_rate):
        """Dispatch to the spectrogram routine selected by specgram_type.

        :raises ValueError: If the spectrogram type is unknown.
        """
        if self._specgram_type == 'linear':
            return self._compute_linear_specgram(
                samples, sample_rate, self._stride_ms, self._window_ms,
                self._max_freq)
        else:
            raise ValueError("Unknown specgram_type %s. "
                             "Supported values: linear." % self._specgram_type)

    def _compute_linear_specgram(self,
                                 samples,
                                 sample_rate,
                                 stride_ms=10.0,
                                 window_ms=20.0,
                                 max_freq=None,
                                 eps=1e-14):
        """Calculate the log of the linear spectrogram by FFT.

        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech

        :raises ValueError: If `max_freq` exceeds half of the sample rate, or
                            if the stride is larger than the window.
        """
        if max_freq is None:
            max_freq = sample_rate / 2
        if max_freq > sample_rate / 2:
            raise ValueError("max_freq must not be greater than half of "
                             "sample rate.")
        if stride_ms > window_ms:
            raise ValueError("Stride size must not be greater than "
                             "window size.")
        stride_size = int(0.001 * sample_rate * stride_ms)
        window_size = int(0.001 * sample_rate * window_ms)
        specgram, freqs = self._specgram_real(
            samples,
            window_size=window_size,
            stride_size=stride_size,
            sample_rate=sample_rate)
        # keep only the frequency bins that are <= max_freq
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        # eps keeps the logarithm finite for empty bins
        return np.log(specgram[:ind, :] + eps)

    def _specgram_real(self, samples, window_size, stride_size, sample_rate):
        """Compute the spectrogram by FFT for a discrete real signal.

        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech

        :return: Tuple (spectrogram, freqs).
        :raises ValueError: If the signal is shorter than one window.
        """
        if len(samples) < window_size:
            raise ValueError("Audio is shorter than one analysis window.")
        # extract strided windows
        truncate_size = (len(samples) - window_size) % stride_size
        samples = samples[:len(samples) - truncate_size]
        nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
        nstrides = (samples.strides[0], samples.strides[0] * stride_size)
        windows = np.lib.stride_tricks.as_strided(
            samples, shape=nshape, strides=nstrides)
        # sanity check of the strided view; needs at least two frames
        if windows.shape[1] > 1:
            assert np.all(windows[:, 1] == samples[stride_size:(
                stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)**2
        scale = np.sum(weighting**2) * sample_rate
        fft[1:-1, :] *= (2.0 / scale)
        fft[(0, -1), :] /= scale
        # prepare fft frequency list
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs

@ -0,0 +1,32 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from data_utils.featurizer.audio_featurizer import AudioFeaturizer
from data_utils.featurizer.text_featurizer import TextFeaturizer
class SpeechFeaturizer(object):
    """Featurizer for speech data: audio features plus transcript ids.

    :param vocab_filepath: Vocabulary filepath for the text featurizer.
    :type vocab_filepath: basestring
    :param specgram_type: Spectrogram type, handed to the audio featurizer.
    :type specgram_type: str
    :param stride_ms: Striding size (in milliseconds) for generating frames.
    :type stride_ms: float
    :param window_ms: Window size (in milliseconds) for generating frames.
    :type window_ms: float
    :param max_freq: Maximum frequency, handed to the audio featurizer.
    :type max_freq: None|float
    :param random_seed: Random seed, handed to the audio featurizer.
    :type random_seed: int
    """

    def __init__(self,
                 vocab_filepath,
                 specgram_type='linear',
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_freq=None,
                 random_seed=0):
        self._audio_featurizer = AudioFeaturizer(
            specgram_type=specgram_type,
            stride_ms=stride_ms,
            window_ms=window_ms,
            max_freq=max_freq,
            random_seed=random_seed)
        self._text_featurizer = TextFeaturizer(vocab_filepath)

    def featurize(self, speech_segment):
        """Extract features for a speech segment.

        :return: Tuple (audio feature, list of transcript token ids).
        """
        return (self._audio_featurizer.featurize(speech_segment),
                self._text_featurizer.text2ids(speech_segment.transcript))

    @property
    def vocab_size(self):
        """Return the vocabulary size of the text featurizer."""
        return self._text_featurizer.vocab_size

    @property
    def vocab_list(self):
        """Return the vocabulary list of the text featurizer."""
        return self._text_featurizer.vocab_list

@ -0,0 +1,39 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
class TextFeaturizer(object):
    """Text featurizer mapping characters to vocabulary indices and back.

    :param vocab_filepath: Filepath of the vocabulary, one token per line.
    :type vocab_filepath: basestring
    """

    def __init__(self, vocab_filepath):
        self._vocab_dict, self._vocab_list = self._load_vocabulary_from_file(
            vocab_filepath)

    def text2ids(self, text):
        """Convert text to a list of token indices.

        :raises KeyError: If a character is not in the vocabulary.
        """
        tokens = self._char_tokenize(text)
        return [self._vocab_dict[token] for token in tokens]

    def ids2text(self, ids):
        """Convert a list of token indices back to text."""
        return ''.join([self._vocab_list[index] for index in ids])

    @property
    def vocab_size(self):
        """Return the vocabulary size."""
        return len(self._vocab_list)

    @property
    def vocab_list(self):
        """Return the vocabulary as a list of tokens."""
        return self._vocab_list

    def _char_tokenize(self, text):
        """Split text into characters after stripping surrounding whitespace."""
        return list(text.strip())

    def _load_vocabulary_from_file(self, vocab_filepath):
        """Load vocabulary from file.

        :return: Tuple (token -> index dict, list of tokens).
        """
        with open(vocab_filepath, 'r') as vocab_file:
            # Strip only the line terminator: a token may itself be a space,
            # and the last line may come without a trailing newline.
            vocab_list = [line.rstrip('\n') for line in vocab_file]
        vocab_dict = dict(
            (token, index) for (index, token) in enumerate(vocab_list))
        return vocab_dict, vocab_list

@ -0,0 +1,49 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
import data_utils.utils as utils
from data_utils.audio import AudioSegment
class FeatureNormalizer(object):
    """Feature normalizer: zero mean and unit stddev per feature row.

    If mean_std_filepath is given, the statistics are loaded from that file.
    Otherwise they are computed from instances sampled from the manifest.

    :param mean_std_filepath: File containing the pre-computed mean and std.
    :type mean_std_filepath: None|basestring
    :param manifest_path: Manifest of instances for computing mean and std.
    :type manifest_path: None|basestring
    :param featurize_func: Function turning an AudioSegment into features.
    :type featurize_func: None|callable
    :param num_samples: Number of instances sampled for the statistics. When
                        the manifest holds fewer instances, all are used.
    :type num_samples: int
    :param random_seed: Seed of the sampling random generator.
    :type random_seed: int
    :raises ValueError: If neither a statistics file nor both manifest_path
                        and featurize_func are given.
    """

    def __init__(self,
                 mean_std_filepath,
                 manifest_path=None,
                 featurize_func=None,
                 num_samples=500,
                 random_seed=0):
        if not mean_std_filepath:
            if not (manifest_path and featurize_func):
                raise ValueError("If mean_std_filepath is None, manifest_path "
                                 "and featurize_func should not be None.")
            self._rng = random.Random(random_seed)
            self._compute_mean_std(manifest_path, featurize_func, num_samples)
        else:
            self._read_mean_std_from_file(mean_std_filepath)

    def apply(self, features, eps=1e-14):
        """Normalize features to be of zero mean and unit stddev.

        `eps` guards against division by zero.
        """
        return (features - self._mean) / (self._std + eps)

    def write_to_file(self, filepath):
        """Write the mean and std to `filepath` with numpy.savez."""
        np.savez(filepath, mean=self._mean, std=self._std)

    def _read_mean_std_from_file(self, filepath):
        """Load the mean and std written by `write_to_file`."""
        # the context manager closes the archive once both arrays are read
        with np.load(filepath) as npzfile:
            self._mean = npzfile["mean"]
            self._std = npzfile["std"]

    def _compute_mean_std(self, manifest_path, featurize_func, num_samples):
        """Compute the mean and std from sampled manifest instances."""
        manifest = utils.read_manifest(manifest_path)
        # random.sample() raises when asked for more items than the
        # population holds, so cap the request.
        num_samples = min(num_samples, len(manifest))
        sampled_manifest = self._rng.sample(manifest, num_samples)
        features = [
            featurize_func(AudioSegment.from_file(instance["audio_filepath"]))
            for instance in sampled_manifest
        ]
        # stack all frames side by side, then reduce along the frame axis
        features = np.hstack(features)
        self._mean = np.mean(features, axis=1).reshape([-1, 1])
        self._std = np.std(features, axis=1).reshape([-1, 1])

@ -0,0 +1,19 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Each line of the manifest is expected to be a JSON object carrying at
    least a "duration" field. Instances whose duration lies outside
    [min_duration, max_duration] are dropped.

    :param manifest_path: Manifest filepath.
    :type manifest_path: basestring
    :param max_duration: Instances with a larger duration are discarded.
    :type max_duration: float
    :param min_duration: Instances with a smaller duration are discarded.
    :type min_duration: float
    :return: List of the parsed instances that passed the duration filter.
    :rtype: list
    :raises IOError: If a line of the manifest cannot be parsed as JSON.
    """
    manifest = []
    # Use a context manager so the file handle is released deterministically,
    # including when a malformed line aborts parsing.
    with open(manifest_path) as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                # Kept broad on purpose: callers rely on any parse failure
                # surfacing as IOError.
                raise IOError("Error reading manifest: %s" % str(e))
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest

@ -44,7 +44,7 @@ parser.add_argument(
help="Directory to save the dataset. (default: %(default)s)")
parser.add_argument(
"--manifest_prefix",
default="manifest.libri",
default="manifest",
type=str,
help="Filepath prefix for output manifests. (default: %(default)s)")
parser.add_argument(

@ -0,0 +1,13 @@
# Prepare the LibriSpeech dataset and assemble the train/dev/test manifests.

# Abort if the dataset directory is missing; otherwise the following
# commands would run in the wrong directory.
cd librispeech || { echo "Cannot enter librispeech directory. Terminated."; exit 1; }

# Run the LibriSpeech preparation script.
if ! python librispeech.py; then
    echo "Prepare LibriSpeech failed. Terminated."
    exit 1
fi
cd -

# Merge all training manifests in shuffled order; dev and test use the
# "clean" subsets only.
cat librispeech/manifest.train* | shuf > manifest.train
cat librispeech/manifest.dev-clean > manifest.dev
cat librispeech/manifest.test-clean > manifest.test

echo "All done."

@ -2,11 +2,15 @@
Inference for a simplifed version of Baidu DeepSpeech2 model.
"""
import paddle.v2 as paddle
import distutils.util
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import gzip
from audio_data_utils import DataGenerator
import distutils.util
import paddle.v2 as paddle
from data_utils.data import DataGenerator
from model import deep_speech2
from decoder import ctc_decode
@ -38,13 +42,13 @@ parser.add_argument(
type=distutils.util.strtobool,
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
"--mean_std_filepath",
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--decode_manifest_path",
default='data/manifest.libri.test-clean',
default='datasets/manifest.test',
type=str,
help="Manifest path for decoding. (default: %(default)s)")
parser.add_argument(
@ -54,7 +58,7 @@ parser.add_argument(
help="Model filepath. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
args = parser.parse_args()
@ -67,28 +71,22 @@ def infer():
# initialize data generator
data_generator = DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
mean_std_filepath=args.mean_std_filepath,
augmentation_config='{}')
# create network config
dict_size = data_generator.vocabulary_size()
vocab_list = data_generator.vocabulary_list()
# paddle.data_type.dense_array is used for variable batch input.
# The size 161 * 161 is only a placeholder value and the real shape
# of input batch data will be inferred during training.
audio_data = paddle.layer.data(
name="audio_spectrogram",
height=161,
width=2000,
type=paddle.data_type.dense_vector(322000))
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
output_probs = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=dict_size,
dict_size=data_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size,
@ -99,31 +97,30 @@ def infer():
gzip.open(args.model_filepath))
# prepare infer data
feeding = data_generator.data_name_feeding()
test_batch_reader = data_generator.batch_reader_creator(
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.decode_manifest_path,
batch_size=args.num_samples,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
infer_data = test_batch_reader().next()
sortagrad=False,
batch_shuffle=False)
infer_data = batch_reader().next()
# run inference
infer_results = paddle.infer(
output_layer=output_probs, parameters=parameters, input=infer_data)
num_steps = len(infer_results) / len(infer_data)
num_steps = len(infer_results) // len(infer_data)
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
for i in xrange(0, len(infer_data))
for i in xrange(len(infer_data))
]
# decode and print
for i, probs in enumerate(probs_split):
output_transcription = ctc_decode(
probs_seq=probs, vocabulary=vocab_list, method="best_path")
probs_seq=probs,
vocabulary=data_generator.vocab_list,
method="best_path")
target_transcription = ''.join(
[vocab_list[index] for index in infer_data[i][1]])
[data_generator.vocab_list[index] for index in infer_data[i][1]])
print("Target Transcription: %s \nOutput Transcription: %s \n" %
(target_transcription, output_transcription))

@ -2,21 +2,21 @@
Trainer for a simplifed version of Baidu DeepSpeech2 model.
"""
import paddle.v2 as paddle
import distutils.util
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import argparse
import gzip
import time
import sys
import distutils.util
import paddle.v2 as paddle
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np
import os
from data_utils.data import DataGenerator
#TODO: add WER metric
parser = argparse.ArgumentParser(
description='Simplified version of DeepSpeech2 trainer.')
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--batch_size", default=32, type=int, help="Minibatch size.")
parser.add_argument(
@ -51,7 +51,7 @@ parser.add_argument(
help="Use gpu or not. (default: %(default)s)")
parser.add_argument(
"--use_sortagrad",
default=False,
default=True,
type=distutils.util.strtobool,
help="Use sortagrad or not. (default: %(default)s)")
parser.add_argument(
@ -60,23 +60,23 @@ parser.add_argument(
type=int,
help="Trainer number. (default: %(default)s)")
parser.add_argument(
"--normalizer_manifest_path",
default='data/manifest.libri.train-clean-100',
"--mean_std_filepath",
default='mean_std.npz',
type=str,
help="Manifest path for normalizer. (default: %(default)s)")
parser.add_argument(
"--train_manifest_path",
default='data/manifest.libri.train-clean-100',
default='datasets/manifest.train',
type=str,
help="Manifest path for training. (default: %(default)s)")
parser.add_argument(
"--dev_manifest_path",
default='data/manifest.libri.dev-clean',
default='datasets/manifest.dev',
type=str,
help="Manifest path for validation. (default: %(default)s)")
parser.add_argument(
"--vocab_filepath",
default='data/eng_vocab.txt',
default='datasets/vocab/eng_vocab.txt',
type=str,
help="Vocabulary filepath. (default: %(default)s)")
parser.add_argument(
@ -86,6 +86,12 @@ parser.add_argument(
help="If set None, the training will start from scratch. "
"Otherwise, the training will resume from "
"the existing model of this path. (default: %(default)s)")
parser.add_argument(
"--augmentation_config",
default='{}',
type=str,
help="Augmentation configuration in json-format. "
"(default: %(default)s)")
args = parser.parse_args()
@ -98,29 +104,26 @@ def train():
def data_generator():
return DataGenerator(
vocab_filepath=args.vocab_filepath,
normalizer_manifest_path=args.normalizer_manifest_path,
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
mean_std_filepath=args.mean_std_filepath,
augmentation_config=args.augmentation_config)
train_generator = data_generator()
test_generator = data_generator()
# create network config
dict_size = train_generator.vocabulary_size()
# paddle.data_type.dense_array is used for variable batch input.
# the size 161 * 161 is only an placeholder value and the real shape
# of input batch data will be set at each batch.
# The size 161 * 161 is only a placeholder value and the real shape
# of input batch data will be inferred during training.
audio_data = paddle.layer.data(
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(dict_size))
type=paddle.data_type.integer_value_sequence(
train_generator.vocab_size))
cost = deep_speech2(
audio_data=audio_data,
text_data=text_data,
dict_size=dict_size,
dict_size=train_generator.vocab_size,
num_conv_layers=args.num_conv_layers,
num_rnn_layers=args.num_rnn_layers,
rnn_size=args.rnn_layer_size,
@ -143,13 +146,13 @@ def train():
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest_path,
batch_size=args.batch_size,
sortagrad=True if args.init_model_path is None else False,
sortagrad=args.use_sortagrad if args.init_model_path is None else False,
batch_shuffle=True)
test_batch_reader = test_generator.batch_reader_creator(
manifest_path=args.dev_manifest_path,
batch_size=args.batch_size,
sortagrad=False,
batch_shuffle=False)
feeding = train_generator.data_name_feeding()
# create event handler
def event_handler(event):
@ -158,8 +161,8 @@ def train():
cost_sum += event.cost
cost_counter += 1
if event.batch_id % 50 == 0:
print "\nPass: %d, Batch: %d, TrainCost: %f" % (
event.pass_id, event.batch_id, cost_sum / cost_counter)
print("\nPass: %d, Batch: %d, TrainCost: %f" %
(event.pass_id, event.batch_id, cost_sum / cost_counter))
cost_sum, cost_counter = 0.0, 0
with gzip.open("params.tar.gz", 'w') as f:
parameters.to_tar(f)
@ -170,16 +173,17 @@ def train():
start_time = time.time()
cost_sum, cost_counter = 0.0, 0
if isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=test_batch_reader, feeding=feeding)
print "\n------- Time: %d sec, Pass: %d, ValidationCost: %s" % (
time.time() - start_time, event.pass_id, result.cost)
result = trainer.test(
reader=test_batch_reader, feeding=test_generator.feeding)
print("\n------- Time: %d sec, Pass: %d, ValidationCost: %s" %
(time.time() - start_time, event.pass_id, result.cost))
# run train
trainer.train(
reader=train_batch_reader,
event_handler=event_handler,
num_passes=args.num_passes,
feeding=feeding)
feeding=train_generator.feeding)
def main():

Loading…
Cancel
Save