Bug exists when run training.pull/2/head
parent
d59b8ca97e
commit
3fc94427db
@ -1 +1,7 @@
|
|||||||
TBD
|
# Deep Speech 2 on PaddlePaddle
|
||||||
|
|
||||||
|
```
|
||||||
|
sh requirements.sh
|
||||||
|
python librispeech.py
|
||||||
|
python train.py
|
||||||
|
```
|
||||||
|
@ -0,0 +1,159 @@
|
|||||||
|
import paddle.v2 as paddle
|
||||||
|
import logging
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import soundfile
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
# TODO: add z-score normalization.
|
||||||
|
|
||||||
|
ENGLISH_CHAR_VOCAB_FILEPATH = "eng_vocab.txt"
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram_from_file(filename,
|
||||||
|
stride_ms=10,
|
||||||
|
window_ms=20,
|
||||||
|
max_freq=None,
|
||||||
|
eps=1e-14):
|
||||||
|
"""
|
||||||
|
Calculate the log of linear spectrogram from FFT energy
|
||||||
|
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
|
||||||
|
"""
|
||||||
|
audio, sample_rate = soundfile.read(filename)
|
||||||
|
if audio.ndim >= 2:
|
||||||
|
audio = np.mean(audio, 1)
|
||||||
|
if max_freq is None:
|
||||||
|
max_freq = sample_rate / 2
|
||||||
|
if max_freq > sample_rate / 2:
|
||||||
|
raise ValueError("max_freq must be greater than half of "
|
||||||
|
"sample rate.")
|
||||||
|
if stride_ms > window_ms:
|
||||||
|
raise ValueError("Stride size must not be greater than window size.")
|
||||||
|
stride_size = int(0.001 * sample_rate * stride_ms)
|
||||||
|
window_size = int(0.001 * sample_rate * window_ms)
|
||||||
|
spectrogram, freqs = extract_spectrogram(
|
||||||
|
audio,
|
||||||
|
window_size=window_size,
|
||||||
|
stride_size=stride_size,
|
||||||
|
sample_rate=sample_rate)
|
||||||
|
ind = np.where(freqs <= max_freq)[0][-1] + 1
|
||||||
|
return np.log(spectrogram[:ind, :] + eps)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_spectrogram(samples, window_size, stride_size, sample_rate):
|
||||||
|
"""
|
||||||
|
Compute the spectrogram for a real discrete signal.
|
||||||
|
Refer to utils.py in https://github.com/baidu-research/ba-dls-deepspeech
|
||||||
|
"""
|
||||||
|
# extract strided windows
|
||||||
|
truncate_size = (len(samples) - window_size) % stride_size
|
||||||
|
samples = samples[:len(samples) - truncate_size]
|
||||||
|
nshape = (window_size, (len(samples) - window_size) // stride_size + 1)
|
||||||
|
nstrides = (samples.strides[0], samples.strides[0] * stride_size)
|
||||||
|
windows = np.lib.stride_tricks.as_strided(
|
||||||
|
samples, shape=nshape, strides=nstrides)
|
||||||
|
assert np.all(
|
||||||
|
windows[:, 1] == samples[stride_size:(stride_size + window_size)])
|
||||||
|
# window weighting, compute squared Fast Fourier Transform (fft), scaling
|
||||||
|
weighting = np.hanning(window_size)[:, None]
|
||||||
|
fft = np.fft.rfft(windows * weighting, axis=0)
|
||||||
|
fft = np.absolute(fft)**2
|
||||||
|
scale = np.sum(weighting**2) * sample_rate
|
||||||
|
fft[1:-1, :] *= (2.0 / scale)
|
||||||
|
fft[(0, -1), :] /= scale
|
||||||
|
# prepare fft frequency list
|
||||||
|
freqs = float(sample_rate) / window_size * np.arange(fft.shape[0])
|
||||||
|
return fft, freqs
|
||||||
|
|
||||||
|
|
||||||
|
def vocabulary_from_file(vocabulary_path):
|
||||||
|
"""
|
||||||
|
Load vocabulary from file.
|
||||||
|
"""
|
||||||
|
if os.path.exists(vocabulary_path):
|
||||||
|
vocab_lines = []
|
||||||
|
with open(vocabulary_path, 'r') as file:
|
||||||
|
vocab_lines.extend(file.readlines())
|
||||||
|
vocab_list = [line[:-1] for line in vocab_lines]
|
||||||
|
vocab_dict = dict(
|
||||||
|
[(token, id) for (id, token) in enumerate(vocab_list)])
|
||||||
|
return vocab_dict, vocab_list
|
||||||
|
else:
|
||||||
|
raise ValueError("Vocabulary file %s not found.", vocabulary_path)
|
||||||
|
|
||||||
|
|
||||||
|
def get_vocabulary_size():
|
||||||
|
vocab_dict, _ = vocabulary_from_file(ENGLISH_CHAR_VOCAB_FILEPATH)
|
||||||
|
return len(vocab_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_transcript(text, vocabulary):
|
||||||
|
"""
|
||||||
|
Convert the transcript text string to list of token index integers..
|
||||||
|
"""
|
||||||
|
return [vocabulary[w] for w in text]
|
||||||
|
|
||||||
|
|
||||||
|
def reader_creator(manifest_path,
|
||||||
|
sort_by_duration=True,
|
||||||
|
shuffle=False,
|
||||||
|
max_duration=10.0,
|
||||||
|
min_duration=0.0):
|
||||||
|
if sort_by_duration and shuffle:
|
||||||
|
sort_by_duration = False
|
||||||
|
logger.warn("When shuffle set to true, "
|
||||||
|
"sort_by_duration is forced to set False.")
|
||||||
|
vocab_dict, _ = vocabulary_from_file(ENGLISH_CHAR_VOCAB_FILEPATH)
|
||||||
|
|
||||||
|
def reader():
|
||||||
|
# read manifest
|
||||||
|
manifest_data = []
|
||||||
|
for json_line in open(manifest_path):
|
||||||
|
try:
|
||||||
|
json_data = json.loads(json_line)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError("Error reading manifest: %s" % str(e))
|
||||||
|
if (json_data["duration"] <= max_duration and
|
||||||
|
json_data["duration"] >= min_duration):
|
||||||
|
manifest_data.append(json_data)
|
||||||
|
# sort (by duration) or shuffle manifest
|
||||||
|
if sort_by_duration:
|
||||||
|
manifest_data.sort(key=lambda x: x["duration"])
|
||||||
|
if shuffle:
|
||||||
|
random.shuffle(manifest_data)
|
||||||
|
# extract spectrogram feature
|
||||||
|
for instance in manifest_data:
|
||||||
|
spectrogram = spectrogram_from_file(instance["audio_filepath"])
|
||||||
|
text = parse_transcript(instance["text"], vocab_dict)
|
||||||
|
yield (spectrogram, text)
|
||||||
|
|
||||||
|
return reader
|
||||||
|
|
||||||
|
|
||||||
|
def padding_batch_reader(batch_reader, padding=[-1, -1], flatten=True):
|
||||||
|
def padding_batch(batch):
|
||||||
|
new_batch = []
|
||||||
|
# get target shape within batch
|
||||||
|
nshape_list = [padding]
|
||||||
|
for audio, text in batch:
|
||||||
|
nshape_list.append(audio.shape)
|
||||||
|
target_shape = np.array(nshape_list).max(axis=0)
|
||||||
|
# padding
|
||||||
|
for audio, text in batch:
|
||||||
|
pad_shape = target_shape - audio.shape
|
||||||
|
assert np.all(pad_shape >= 0)
|
||||||
|
padded_audio = np.pad(
|
||||||
|
audio, [(0, pad_shape[0]), (0, pad_shape[1])], mode="constant")
|
||||||
|
if flatten:
|
||||||
|
padded_audio = padded_audio.flatten()
|
||||||
|
new_batch.append((padded_audio, text))
|
||||||
|
return new_batch
|
||||||
|
|
||||||
|
def new_batch_reader():
|
||||||
|
for batch in batch_reader():
|
||||||
|
yield padding_batch(batch)
|
||||||
|
|
||||||
|
return new_batch_reader
|
@ -0,0 +1,28 @@
|
|||||||
|
'
|
||||||
|
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
d
|
||||||
|
e
|
||||||
|
f
|
||||||
|
g
|
||||||
|
h
|
||||||
|
i
|
||||||
|
j
|
||||||
|
k
|
||||||
|
l
|
||||||
|
m
|
||||||
|
n
|
||||||
|
o
|
||||||
|
p
|
||||||
|
q
|
||||||
|
r
|
||||||
|
s
|
||||||
|
t
|
||||||
|
u
|
||||||
|
v
|
||||||
|
w
|
||||||
|
x
|
||||||
|
y
|
||||||
|
z
|
@ -0,0 +1,97 @@
|
|||||||
|
import paddle.v2 as paddle
|
||||||
|
import os
|
||||||
|
import wget
|
||||||
|
import tarfile
|
||||||
|
import argparse
|
||||||
|
import soundfile
|
||||||
|
import json
|
||||||
|
|
||||||
|
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
|
||||||
|
|
||||||
|
URL_TEST = "http://www.openslr.org/resources/12/test-clean.tar.gz"
|
||||||
|
URL_DEV = "http://www.openslr.org/resources/12/dev-clean.tar.gz"
|
||||||
|
URL_TRAIN = "http://www.openslr.org/resources/12/train-clean-100.tar.gz"
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Downloads and prepare LibriSpeech dataset.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--target_dir",
|
||||||
|
default=DATA_HOME + "/Libri",
|
||||||
|
type=str,
|
||||||
|
help="Directory to save the dataset.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--manifest",
|
||||||
|
default="./libri.manifest",
|
||||||
|
type=str,
|
||||||
|
help="Filepath prefix of output manifests.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def download(url, target_dir):
|
||||||
|
if not os.path.exists(target_dir):
|
||||||
|
os.makedirs(target_dir)
|
||||||
|
filepath = os.path.join(target_dir, url.split("/")[-1])
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
print("Downloading %s ..." % url)
|
||||||
|
wget.download(url, target_dir)
|
||||||
|
print("")
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
|
||||||
|
def unpack(filepath, target_dir):
|
||||||
|
print("Unpacking %s ..." % filepath)
|
||||||
|
tar = tarfile.open(filepath)
|
||||||
|
tar.extractall(target_dir)
|
||||||
|
tar.close()
|
||||||
|
return target_dir
|
||||||
|
|
||||||
|
|
||||||
|
def create_manifest(data_dir, manifest_path):
|
||||||
|
print("Creating manifest %s ..." % manifest_path)
|
||||||
|
json_lines = []
|
||||||
|
for subfolder, _, filelist in os.walk(data_dir):
|
||||||
|
text_filelist = [
|
||||||
|
filename for filename in filelist if filename.endswith('trans.txt')
|
||||||
|
]
|
||||||
|
if len(text_filelist) > 0:
|
||||||
|
text_filepath = os.path.join(data_dir, subfolder, text_filelist[0])
|
||||||
|
for line in open(text_filepath):
|
||||||
|
segments = line.strip().split()
|
||||||
|
text = ' '.join(segments[1:]).lower()
|
||||||
|
audio_filepath = os.path.join(data_dir, subfolder,
|
||||||
|
segments[0] + '.flac')
|
||||||
|
audio_data, samplerate = soundfile.read(audio_filepath)
|
||||||
|
duration = float(len(audio_data)) / samplerate
|
||||||
|
json_lines.append(
|
||||||
|
json.dumps({
|
||||||
|
'audio_filepath': audio_filepath,
|
||||||
|
'duration': duration,
|
||||||
|
'text': text
|
||||||
|
}))
|
||||||
|
with open(manifest_path, 'w') as out_file:
|
||||||
|
for line in json_lines:
|
||||||
|
out_file.write(line + '\n')
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataset(url, target_dir, manifest_path):
|
||||||
|
filepath = download(url, target_dir)
|
||||||
|
unpacked_dir = unpack(filepath, target_dir)
|
||||||
|
create_manifest(unpacked_dir, manifest_path)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
prepare_dataset(
|
||||||
|
url=URL_TEST,
|
||||||
|
target_dir=os.path.join(args.target_dir),
|
||||||
|
manifest_path=args.manifest + ".test")
|
||||||
|
prepare_dataset(
|
||||||
|
url=URL_DEV,
|
||||||
|
target_dir=os.path.join(args.target_dir),
|
||||||
|
manifest_path=args.manifest + ".dev")
|
||||||
|
#prepare_dataset(url=URL_TRAIN,
|
||||||
|
#target_dir=os.path.join(args.target_dir),
|
||||||
|
#manifest_path=args.manifest + ".train")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -0,0 +1,5 @@
|
|||||||
|
pip install wget
|
||||||
|
pip install soundfile
|
||||||
|
|
||||||
|
# For Linux only
|
||||||
|
apt-get install libsndfile1
|
@ -0,0 +1,188 @@
|
|||||||
|
import paddle.v2 as paddle
|
||||||
|
import audio_data_utils
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Simpled version of DeepSpeech2 trainer.')
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch_size", default=512, type=int, help="Minibatch size.")
|
||||||
|
parser.add_argument("--trainer", default=1, type=int, help="Trainer number.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_passes", default=20, type=int, help="Training pass number.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
|
||||||
|
padding, act):
|
||||||
|
conv_layer = paddle.layer.img_conv(
|
||||||
|
input=input,
|
||||||
|
filter_size=filter_size,
|
||||||
|
num_channels=num_channels_in,
|
||||||
|
num_filters=num_channels_out,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
act=paddle.activation.Linear(),
|
||||||
|
bias_attr=False)
|
||||||
|
return paddle.layer.batch_norm(input=conv_layer, act=act)
|
||||||
|
|
||||||
|
|
||||||
|
def bidirectonal_simple_rnn_bn_layer(name, input, size, act):
|
||||||
|
def __simple_rnn_step__(input):
|
||||||
|
last_state = paddle.layer.memory(name=name + "_state", size=size)
|
||||||
|
input_fc = paddle.layer.fc(
|
||||||
|
input=input,
|
||||||
|
size=size,
|
||||||
|
act=paddle.activation.Linear(),
|
||||||
|
bias_attr=False)
|
||||||
|
input_fc_bn = paddle.layer.batch_norm(
|
||||||
|
input=input_fc, act=paddle.activation.Linear())
|
||||||
|
state_fc = paddle.layer.fc(
|
||||||
|
input=last_state,
|
||||||
|
size=size,
|
||||||
|
act=paddle.activation.Linear(),
|
||||||
|
bias_attr=False)
|
||||||
|
return paddle.layer.addto(
|
||||||
|
name=name + "_state", input=[input_fc_bn, state_fc], act=act)
|
||||||
|
|
||||||
|
forward = paddle.layer.recurrent_group(
|
||||||
|
step=__simple_rnn_step__, input=input)
|
||||||
|
return forward
|
||||||
|
# argument reverse is not exposed in V2 recurrent_group
|
||||||
|
#backward = paddle.layer.recurrent_group(
|
||||||
|
|
||||||
|
|
||||||
|
#step=__simple_rnn_step__,
|
||||||
|
#input=input,
|
||||||
|
#reverse=True)
|
||||||
|
#return paddle.layer.concat(input=[forward, backward])
|
||||||
|
|
||||||
|
|
||||||
|
def conv_group(input):
|
||||||
|
conv1 = conv_bn_layer(
|
||||||
|
input=input,
|
||||||
|
filter_size=(11, 41),
|
||||||
|
num_channels_in=1,
|
||||||
|
num_channels_out=32,
|
||||||
|
stride=(3, 2),
|
||||||
|
padding=(5, 20),
|
||||||
|
act=paddle.activation.BRelu())
|
||||||
|
conv2 = conv_bn_layer(
|
||||||
|
input=conv1,
|
||||||
|
filter_size=(11, 21),
|
||||||
|
num_channels_in=32,
|
||||||
|
num_channels_out=32,
|
||||||
|
stride=(1, 2),
|
||||||
|
padding=(5, 10),
|
||||||
|
act=paddle.activation.BRelu())
|
||||||
|
conv3 = conv_bn_layer(
|
||||||
|
input=conv2,
|
||||||
|
filter_size=(11, 21),
|
||||||
|
num_channels_in=32,
|
||||||
|
num_channels_out=32,
|
||||||
|
stride=(1, 2),
|
||||||
|
padding=(5, 10),
|
||||||
|
act=paddle.activation.BRelu())
|
||||||
|
return conv3
|
||||||
|
|
||||||
|
|
||||||
|
def rnn_group(input, size, num_stacks):
|
||||||
|
output = input
|
||||||
|
for i in xrange(num_stacks):
|
||||||
|
output = bidirectonal_simple_rnn_bn_layer(
|
||||||
|
name=str(i), input=output, size=size, act=paddle.activation.BRelu())
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def deep_speech2(audio_data, text_data, dict_size):
|
||||||
|
conv_group_output = conv_group(input=audio_data)
|
||||||
|
conv2seq = paddle.layer.block_expand(
|
||||||
|
input=conv_group_output,
|
||||||
|
num_channels=32,
|
||||||
|
stride_x=1,
|
||||||
|
stride_y=1,
|
||||||
|
block_x=1,
|
||||||
|
block_y=21)
|
||||||
|
rnn_group_output = rnn_group(input=conv2seq, size=256, num_stacks=5)
|
||||||
|
fc = paddle.layer.fc(
|
||||||
|
input=rnn_group_output,
|
||||||
|
size=dict_size + 1,
|
||||||
|
act=paddle.activation.Linear(),
|
||||||
|
bias_attr=True)
|
||||||
|
cost = paddle.layer.warp_ctc(
|
||||||
|
input=fc,
|
||||||
|
label=text_data,
|
||||||
|
size=dict_size + 1,
|
||||||
|
blank=dict_size,
|
||||||
|
norm_by_times=True)
|
||||||
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
def train():
|
||||||
|
# create network config
|
||||||
|
dict_size = audio_data_utils.get_vocabulary_size()
|
||||||
|
audio_data = paddle.layer.data(
|
||||||
|
name="audio_spectrogram",
|
||||||
|
height=161,
|
||||||
|
width=1000,
|
||||||
|
type=paddle.data_type.dense_vector(161000))
|
||||||
|
text_data = paddle.layer.data(
|
||||||
|
name="transcript_text",
|
||||||
|
type=paddle.data_type.integer_value_sequence(dict_size))
|
||||||
|
cost = deep_speech2(audio_data, text_data, dict_size)
|
||||||
|
|
||||||
|
# create parameters and optimizer
|
||||||
|
parameters = paddle.parameters.create(cost)
|
||||||
|
optimizer = paddle.optimizer.Adam(
|
||||||
|
learning_rate=5e-5,
|
||||||
|
gradient_clipping_threshold=5,
|
||||||
|
regularization=paddle.optimizer.L2Regularization(rate=8e-4))
|
||||||
|
trainer = paddle.trainer.SGD(
|
||||||
|
cost=cost, parameters=parameters, update_equation=optimizer)
|
||||||
|
return
|
||||||
|
|
||||||
|
# create data readers
|
||||||
|
feeding = {
|
||||||
|
"audio_spectrogram": 0,
|
||||||
|
"transcript_text": 1,
|
||||||
|
}
|
||||||
|
train_batch_reader = audio_data_utils.padding_batch_reader(
|
||||||
|
paddle.batch(
|
||||||
|
audio_data_utils.reader_creator("./libri.manifest.dev"),
|
||||||
|
batch_size=args.batch_size // args.trainer),
|
||||||
|
padding=[-1, 1000])
|
||||||
|
test_batch_reader = audio_data_utils.padding_batch_reader(
|
||||||
|
paddle.batch(
|
||||||
|
audio_data_utils.reader_creator("./libri.manifest.test"),
|
||||||
|
batch_size=args.batch_size // args.trainer),
|
||||||
|
padding=[-1, 1000])
|
||||||
|
|
||||||
|
# create event handler
|
||||||
|
def event_handler(event):
|
||||||
|
if isinstance(event, paddle.event.EndIteration):
|
||||||
|
if event.batch_id % 10 == 0:
|
||||||
|
print "Pass: %d, Batch: %d, TrainCost: %f, %s" % (
|
||||||
|
event.pass_id, event.batch_id, event.cost, event.metrics)
|
||||||
|
else:
|
||||||
|
sys.stdout.write('.')
|
||||||
|
sys.stdout.flush()
|
||||||
|
if isinstance(event, paddle.event.EndPass):
|
||||||
|
result = trainer.test(reader=test_batch_reader, feeding=feeding)
|
||||||
|
print "Pass: %d, TestCost: %f, %s" % (event.pass_id, event.cost,
|
||||||
|
result.metrics)
|
||||||
|
with gzip.open("params.tar.gz", 'w') as f:
|
||||||
|
parameters.to_tar(f)
|
||||||
|
|
||||||
|
# run train
|
||||||
|
trainer.train(
|
||||||
|
reader=train_batch_reader,
|
||||||
|
event_handler=event_handler,
|
||||||
|
num_passes=10,
|
||||||
|
feeding=feeding)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
train()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in new issue