Refactor data utils into a class and add feature normalization.

pull/2/head
Xinghai Sun 8 years ago
parent 9c3cd3c704
commit e6a349992b

@@ -1,5 +1,6 @@
"""
Provides a basic audio data preprocessing pipeline and offers both
instance-level and batch-level data reader interfaces.
"""
import paddle.v2 as paddle
import logging
@@ -9,20 +10,127 @@ import soundfile
import numpy as np
import os
RANDOM_SEED = 0
logger = logging.getLogger(__name__)


class DataGenerator(object):
"""
DataGenerator provides basic audio data preprocessing pipeline, and offer
both instance-level and batch-level data reader interfaces.
Normalized FFT are used as audio features here.
:param vocab_filepath: Vocabulary file path for indexing tokenized
transcriptions.
:type vocab_filepath: basestring
:param normalizer_manifest_path: Manifest filepath for collecting feature
normalization statistics, e.g. mean, std.
:type normalizer_manifest_path: basestring
:param normalizer_num_samples: Number of instances sampled for collecting
feature normalization statistics.
Default is 100.
:type normalizer_num_samples: int
:param max_duration: Audio clips with duration (in seconds) greater than
this will be discarded. Default is 20.0.
:type max_duration: float
:param min_duration: Audio clips with duration (in seconds) smaller than
this will be discarded. Default is 0.0.
:type min_duration: float
:param stride_ms: Striding size (in milliseconds) for generating frames.
Default is 10.0.
:type stride_ms: float
:param window_ms: Window size (in milliseconds) for frames. Default is 20.0.
:type window_ms: float
:param max_frequency: Maximun frequency for FFT features. FFT features of
frequency larger than this will be discarded.
If set None, all features will be kept.
Default is None.
:type max_frequency: float
"""
    def __init__(self,
                 vocab_filepath,
                 normalizer_manifest_path,
                 normalizer_num_samples=100,
                 max_duration=20.0,
                 min_duration=0.0,
                 stride_ms=10.0,
                 window_ms=20.0,
                 max_frequency=None):
        self.__max_duration__ = max_duration
        self.__min_duration__ = min_duration
        self.__stride_ms__ = stride_ms
        self.__window_ms__ = window_ms
        self.__max_frequency__ = max_frequency
        self.__random__ = random.Random(RANDOM_SEED)
        # load vocabulary (dictionary)
        self.__vocab_dict__, self.__vocab_list__ = \
            self.__load_vocabulary_from_file__(vocab_filepath)
        # collect normalizer statistics
        self.__mean__, self.__std__ = self.__collect_normalizer_statistics__(
            manifest_path=normalizer_manifest_path,
            num_samples=normalizer_num_samples)
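An illustrative construction sketch (not part of this commit), mirroring the
trainer script further down; the vocabulary and manifest paths are the same
placeholders used there:

from audio_data_utils import DataGenerator

data_generator = DataGenerator(
    vocab_filepath='eng_vocab.txt',
    normalizer_manifest_path='./libri.manifest.train',
    normalizer_num_samples=200,
    max_duration=20.0,
    min_duration=0.0,
    stride_ms=10.0,
    window_ms=20.0)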
    def __audio_featurize__(self, audio_filename):
        """
        Preprocess audio data, including feature extraction, normalization,
        etc.
        """
        features = self.__audio_basic_featurize__(audio_filename)
        return self.__normalize__(features)

    def __text_featurize__(self, text):
        """
        Preprocess text data, including tokenizing and token indexing, etc.
        """
        return self.__convert_text_to_char_index__(
            text=text, vocabulary=self.__vocab_dict__)

    def __audio_basic_featurize__(self, audio_filename):
        """
        Compute basic (without normalization etc.) features for audio data.
        """
        return self.__spectrogram_from_file__(
            filename=audio_filename,
            stride_ms=self.__stride_ms__,
            window_ms=self.__window_ms__,
            max_freq=self.__max_frequency__)
    def __collect_normalizer_statistics__(self, manifest_path,
                                          num_samples=100):
        """
        Compute feature normalization statistics, i.e. mean and stddev.
        """
        # read manifest
        manifest = self.__read_manifest__(
            manifest_path=manifest_path,
            max_duration=self.__max_duration__,
            min_duration=self.__min_duration__)
        # sample for statistics
        sampled_manifest = self.__random__.sample(manifest, num_samples)
        # extract spectrogram feature
        features = []
        for instance in sampled_manifest:
            spectrogram = self.__audio_basic_featurize__(
                instance["audio_filepath"])
            features.append(spectrogram)
        features = np.hstack(features)
        mean = np.mean(features, axis=1).reshape([-1, 1])
        std = np.std(features, axis=1).reshape([-1, 1])
        return mean, std
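To make the statistics concrete: features from the sampled clips are stacked
along the time axis, then a per-frequency-bin mean and stddev are taken. A
standalone sketch, assuming each spectrogram is a (num_bins, num_frames)
ndarray:

import numpy as np

# two fake spectrograms with 161 frequency bins and different lengths
specs = [np.random.rand(161, 50), np.random.rand(161, 80)]
stacked = np.hstack(specs)                        # shape (161, 130)
mean = np.mean(stacked, axis=1).reshape([-1, 1])  # shape (161, 1)
std = np.std(stacked, axis=1).reshape([-1, 1])    # shape (161, 1)
normalized = (specs[0] - mean) / (std + 1e-14)    # approx. zero mean, unit stddev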
    def __normalize__(self, features, eps=1e-14):
        """
        Normalize features to be of zero mean and unit stddev.
        """
        return (features - self.__mean__) / (self.__std__ + eps)

    def __spectrogram_from_file__(self,
                                  filename,
                                  stride_ms=10.0,
                                  window_ms=20.0,
                                  max_freq=None,
                                  eps=1e-14):
        """
        Load audio data and calculate the log of spectrogram by FFT.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
        """
        audio, sample_rate = soundfile.read(filename)
@@ -34,10 +142,11 @@ def spectrogram_from_file(filename,
raise ValueError("max_freq must be greater than half of " raise ValueError("max_freq must be greater than half of "
"sample rate.") "sample rate.")
if stride_ms > window_ms: if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than window size.") raise ValueError("Stride size must not be greater than "
"window size.")
stride_size = int(0.001 * sample_rate * stride_ms) stride_size = int(0.001 * sample_rate * stride_ms)
window_size = int(0.001 * sample_rate * window_ms) window_size = int(0.001 * sample_rate * window_ms)
spectrogram, freqs = extract_spectrogram( spectrogram, freqs = self.__extract_spectrogram__(
audio, audio,
window_size=window_size, window_size=window_size,
stride_size=stride_size, stride_size=stride_size,
@@ -45,10 +154,10 @@ def spectrogram_from_file(filename,
        ind = np.where(freqs <= max_freq)[0][-1] + 1
        return np.log(spectrogram[:ind, :] + eps)
    def __extract_spectrogram__(self, samples, window_size, stride_size,
                                sample_rate):
        """
        Compute the spectrogram by FFT for a discrete real signal.
        Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
        """
        # extract strided windows
@@ -60,7 +169,7 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
            samples, shape=nshape, strides=nstrides)
        assert np.all(
            windows[:, 1] == samples[stride_size:(stride_size + window_size)])
        # window weighting, squared Fast Fourier Transform (fft), scaling
        weighting = np.hanning(window_size)[:, None]
        fft = np.fft.rfft(windows * weighting, axis=0)
        fft = np.absolute(fft)**2
@@ -71,12 +180,12 @@ def extract_spectrogram(samples, window_size, stride_size, sample_rate):
        freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
        return fft, freqs
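The windowing above relies on numpy's stride tricks to build a
(window_size, num_windows) view without copying data. A toy sketch of the
same idea; the nshape/nstrides arithmetic is an assumption based on the
referenced ba-dls-deepspeech utils.py, since those lines fall outside this
hunk:

import numpy as np

samples = np.arange(10.0)
window_size, stride_size = 4, 2
truncate_size = (len(samples) - window_size) % stride_size
samples = samples[:len(samples) - truncate_size]
nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
nstrides = (samples.strides[0], samples.strides[0] * stride_size)
windows = np.lib.stride_tricks.as_strided(
    samples, shape=nshape, strides=nstrides)
# column i is samples[i * stride_size : i * stride_size + window_size],
# so windows[:, 1] == samples[2:6], matching the assertion above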
    def __load_vocabulary_from_file__(self, vocabulary_path):
        """
        Load vocabulary from file.
        """
        if not os.path.exists(vocabulary_path):
            raise ValueError(
                "Vocabulary file %s not found." % vocabulary_path)
        vocab_lines = []
        with open(vocabulary_path, 'r') as file:
            vocab_lines.extend(file.readlines())
@@ -84,56 +193,76 @@ def vocabulary_from_file(vocabulary_path):
        vocab_dict = dict(
            [(token, id) for (id, token) in enumerate(vocab_list)])
        return vocab_dict, vocab_list
    def __convert_text_to_char_index__(self, text, vocabulary):
        """
        Convert text string to a list of character index integers.
        """
        return [vocabulary[w] for w in text]

    def __read_manifest__(self, manifest_path, max_duration, min_duration):
        """
        Load and parse manifest file.
        """
        manifest = []
        for json_line in open(manifest_path):
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                raise ValueError("Error reading manifest: %s" % str(e))
            if (json_data["duration"] <= max_duration and
                    json_data["duration"] >= min_duration):
                manifest.append(json_data)
        return manifest
    def __padding_batch__(self, batch, padding_to=-1, flatten=False):
        """
        Pad the audio features with zeros (only along the time axis, i.e.
        the column axis) so that each instance in the batch shares the same
        audio feature shape.

        If `padding_to` is set to -1, the maximum column number in the batch
        will be used as the target size. Otherwise, `padding_to` will be the
        target size. Default is -1.

        If `flatten` is set to True, audio data will be flattened into a
        1-dim ndarray. Default is False.
        """
        new_batch = []
        # get target shape
        max_length = max([audio.shape[1] for audio, text in batch])
        if padding_to != -1:
            if padding_to < max_length:
                raise ValueError("If padding_to is not -1, it should be "
                                 "greater than or equal to the original "
                                 "instance length.")
            max_length = padding_to
        # padding
        for audio, text in batch:
            padded_audio = np.zeros([audio.shape[0], max_length])
            padded_audio[:, :audio.shape[1]] = audio
            if flatten:
                padded_audio = padded_audio.flatten()
            new_batch.append((padded_audio, text))
        return new_batch
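A toy run of the padding logic (illustrative, not from the commit), assuming
two instances with 3 and 5 time steps; both come out with shape (2, 5), the
first zero-padded in its last two columns:

import numpy as np

batch = [(np.ones([2, 3]), [7, 8]), (np.ones([2, 5]), [9])]
max_length = max([audio.shape[1] for audio, text in batch])  # 5
new_batch = []
for audio, text in batch:
    padded_audio = np.zeros([audio.shape[0], max_length])
    padded_audio[:, :audio.shape[1]] = audio
    new_batch.append((padded_audio, text))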
    def instance_reader_creator(self,
                                manifest_path,
                                sort_by_duration=True,
                                shuffle=False):
        """
        Instance reader creator for audio data. Creates a callable function
        to produce instances of data.

        Instance: a tuple of a numpy ndarray of audio spectrogram and a list
        of tokenized and indexed transcription text.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param sort_by_duration: Sort the audio clips by duration if set True
                                 (for SortaGrad).
        :type sort_by_duration: bool
        :param shuffle: Shuffle the audio clips if set True.
        :type shuffle: bool
        :return: Data reader function.
        :rtype: callable
        """
@@ -141,75 +270,114 @@ def reader_creator(manifest_path,
            sort_by_duration = False
            logger.warn("When shuffle is set to True, "
                        "sort_by_duration is forced to False.")

        def reader():
            # read manifest
            manifest = self.__read_manifest__(
                manifest_path=manifest_path,
                max_duration=self.__max_duration__,
                min_duration=self.__min_duration__)
            # sort (by duration) or shuffle manifest
            if sort_by_duration:
                manifest.sort(key=lambda x: x["duration"])
            if shuffle:
                self.__random__.shuffle(manifest)
            # extract spectrogram feature
            for instance in manifest:
                spectrogram = self.__audio_featurize__(
                    instance["audio_filepath"])
                transcript = self.__text_featurize__(instance["text"])
                yield (spectrogram, transcript)

        return reader
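Typical use, continuing the construction sketch above (the manifest path is
a placeholder):

reader = data_generator.instance_reader_creator(
    manifest_path='./libri.manifest.train',
    sort_by_duration=True,
    shuffle=False)
for spectrogram, transcript in reader():
    # spectrogram: 2-dim ndarray; transcript: list of token index ints
    pass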
    def batch_reader_creator(self,
                             manifest_path,
                             batch_size,
                             padding_to=-1,
                             flatten=False,
                             sort_by_duration=True,
                             shuffle=False):
        """
        Batch data reader creator for audio data. Creates a callable function
        to produce batches of data.

        Audio features will be padded with zeros so that each instance in the
        batch shares the same audio feature shape.

        :param manifest_path: Filepath of manifest for audio clip files.
        :type manifest_path: basestring
        :param batch_size: Instance number in a batch.
        :type batch_size: int
        :param padding_to: If set to -1, the maximum column number in the
                           batch will be used as the target size for padding.
                           Otherwise, `padding_to` will be the target size.
                           Default is -1.
        :type padding_to: int
        :param flatten: If set True, audio data will be flattened into a
                        1-dim ndarray; otherwise, a 2-dim ndarray. Default is
                        False.
        :type flatten: bool
        :param sort_by_duration: Sort the audio clips by duration if set True
                                 (for SortaGrad).
        :type sort_by_duration: bool
        :param shuffle: Shuffle the audio clips if set True.
        :type shuffle: bool
        :return: Batch reader function, producing batches of data when
                 called.
        :rtype: callable
        """

        def batch_reader():
            instance_reader = self.instance_reader_creator(
                manifest_path=manifest_path,
                sort_by_duration=sort_by_duration,
                shuffle=shuffle)
            batch = []
            for instance in instance_reader():
                batch.append(instance)
                if len(batch) == batch_size:
                    yield self.__padding_batch__(batch, padding_to, flatten)
                    batch = []
            if len(batch) > 0:
                yield self.__padding_batch__(batch, padding_to, flatten)

        return batch_reader
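And the batch-level counterpart, again with a placeholder path; with
padding_to=-1 each batch is padded only to its own longest instance:

batch_reader = data_generator.batch_reader_creator(
    manifest_path='./libri.manifest.train',
    batch_size=32,
    padding_to=-1,
    flatten=False,
    sort_by_duration=True,
    shuffle=False)
for batch in batch_reader():
    # batch: list of up to 32 (padded_audio, transcript) tuples
    pass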
    def vocabulary_size(self):
        """
        Get vocabulary size.

        :return: Vocabulary size.
        :rtype: int
        """
        return len(self.__vocab_list__)

    def vocabulary_dict(self):
        """
        Get vocabulary as a dict.

        :return: Vocabulary as a dict.
        :rtype: dict
        """
        return self.__vocab_dict__

    def vocabulary_list(self):
        """
        Get vocabulary as a list.

        :return: Vocabulary as a list.
        :rtype: list
        """
        return self.__vocab_list__
    def data_name_feeding(self):
        """
        Get feedings (data field name and corresponding field id).

        :return: Feeding dict.
        :rtype: dict
        """
        feeding = {
            "audio_spectrogram": 0,
            "transcript_text": 1,
        }
        return feeding
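The feeding dict simply maps each data layer name to its position in the
(spectrogram, transcript) tuples the readers yield; a minimal sketch of how
the trainer below uses it:

feeding = data_generator.data_name_feeding()
assert feeding == {"audio_spectrogram": 0, "transcript_text": 1}
# paddle pairs batch element 0 with the "audio_spectrogram" layer and
# element 1 with the "transcript_text" layer, e.g.:
# trainer.train(reader=batch_reader, feeding=feeding, ...)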

@@ -5,16 +5,18 @@
import paddle.v2 as paddle
import argparse
import gzip
import time
import sys
from model import deep_speech2
from audio_data_utils import DataGenerator
import numpy as np
#TODO: add WER metric

parser = argparse.ArgumentParser(
    description='Simplified version of DeepSpeech2 trainer.')
parser.add_argument(
    "--batch_size", default=32, type=int, help="Minibatch size.")
parser.add_argument("--trainer", default=1, type=int, help="Trainer number.")
parser.add_argument(
    "--num_passes", default=20, type=int, help="Training pass number.")
@@ -23,7 +25,7 @@ parser.add_argument(
parser.add_argument(
    "--num_rnn_layers", default=5, type=int, help="RNN layer number.")
parser.add_argument(
    "--rnn_layer_size", default=512, type=int, help="RNN layer cell number.")
parser.add_argument(
    "--use_gpu", default=True, type=bool, help="Use gpu or not.")
parser.add_argument(
@@ -37,13 +39,45 @@ def train():
""" """
DeepSpeech2 training. DeepSpeech2 training.
""" """
# create data readers
data_generator = DataGenerator(
vocab_filepath='eng_vocab.txt',
normalizer_manifest_path='./libri.manifest.train',
normalizer_num_samples=200,
max_duration=20.0,
min_duration=0.0,
stride_ms=10,
window_ms=20)
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
manifest_path='./libri.manifest.dev.small',
batch_size=args.batch_size // args.trainer,
padding_to=2000,
flatten=True,
sort_by_duration=True,
shuffle=False)
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
manifest_path='./libri.manifest.dev.small',
batch_size=args.batch_size // args.trainer,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=True)
test_batch_reader = data_generator.batch_reader_creator(
manifest_path='./libri.manifest.test',
batch_size=args.batch_size // args.trainer,
padding_to=2000,
flatten=True,
sort_by_duration=False,
shuffle=False)
feeding = data_generator.data_name_feeding()
    # create network config
    dict_size = data_generator.vocabulary_size()
    audio_data = paddle.layer.data(
        name="audio_spectrogram",
        height=161,
        width=2000,
        type=paddle.data_type.dense_vector(322000))
    text_data = paddle.layer.data(
        name="transcript_text",
        type=paddle.data_type.integer_value_sequence(dict_size))
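A quick sanity check on the shape bookkeeping (not in the commit): batches
are padded to 2000 frames and spectrograms have 161 frequency bins, so the
flattened dense vector is their product, which is why the layer width and
vector size both doubled along with padding_to:

assert 161 * 2000 == 322000  # height * width == dense_vector size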
@@ -58,47 +92,26 @@ def train():
    # create parameters and optimizer
    parameters = paddle.parameters.create(cost)
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-5, gradient_clipping_threshold=400)
    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)
    # create event handler
    def event_handler(event):
        global start_time
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 10 == 0:
                print "\nPass: %d, Batch: %d, TrainCost: %f" % (
                    event.pass_id, event.batch_id, event.cost)
            else:
                sys.stdout.write('.')
                sys.stdout.flush()
        if isinstance(event, paddle.event.BeginPass):
            start_time = time.time()
        if isinstance(event, paddle.event.EndPass):
            result = trainer.test(reader=test_batch_reader, feeding=feeding)
            print "\n------- Time: %d, Pass: %d, TestCost: %s" % (
                time.time() - start_time, event.pass_id, result.cost)
            with gzip.open("params.tar.gz", 'w') as f:
                parameters.to_tar(f)
@@ -106,14 +119,14 @@ def train():
    # first pass with sortagrad
    if args.use_sortagrad:
        trainer.train(
            reader=train_batch_reader_sortagrad,
            event_handler=event_handler,
            num_passes=1,
            feeding=feeding)
        args.num_passes -= 1
    # other passes without sortagrad
    trainer.train(
        reader=train_batch_reader_nosortagrad,
        event_handler=event_handler,
        num_passes=args.num_passes,
        feeding=feeding)
