Merge pull request #444 from pkuyym/fix-442

Support padding removal.
Yang yaming committed via GitHub · commit 4913cba53c

@@ -60,6 +60,9 @@ class DataGenerator(object):
be passed forward directly without
converting to index sequence.
:type keep_transcription_text: bool
:param num_conv_layers: The number of convolution layers, used to compute
the sequence length.
:type num_conv_layers: int
"""
def __init__(self,
@@ -75,7 +78,8 @@ class DataGenerator(object):
use_dB_normalization=True,
num_threads=multiprocessing.cpu_count() // 2,
random_seed=0,
keep_transcription_text=False):
keep_transcription_text=False,
num_conv_layers=2):
self._max_duration = max_duration
self._min_duration = min_duration
self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -96,6 +100,7 @@ class DataGenerator(object):
self._local_data = local()
self._local_data.tar2info = {}
self._local_data.tar2object = {}
self._num_conv_layers = num_conv_layers
def process_utterance(self, filename, transcript):
"""Load, augment, featurize and normalize for speech data.
@@ -214,7 +219,15 @@ class DataGenerator(object):
:return: Data feeding dict.
:rtype: dict
"""
return {"audio_spectrogram": 0, "transcript_text": 1}
feeding_dict = {
"audio_spectrogram": 0,
"transcript_text": 1,
"sequence_offset": 2,
"sequence_length": 3
}
for i in xrange(self._num_conv_layers):
feeding_dict["conv%d_index_range" % i] = len(feeding_dict)
return feeding_dict
@property
def vocab_size(self):
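
With the default num_conv_layers=2, the feeding property above yields the following field-to-index mapping (a worked example of the loop, not new behavior):

feeding = {
    "audio_spectrogram": 0,
    "transcript_text": 1,
    "sequence_offset": 2,
    "sequence_length": 3,
    "conv0_index_range": 4,
    "conv1_index_range": 5
}

Each padded instance emitted by the generator is a list whose positions match these indices.
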
@@ -312,7 +325,30 @@ class DataGenerator(object):
padded_audio[:, :audio.shape[1]] = audio
if flatten:
padded_audio = padded_audio.flatten()
new_batch.append((padded_audio, text))
# Stride sizes are hard-coded here to match the network:
# conv0 uses stride (3, 2); conv1 through convN use stride (1, 2)
padded_instance = [padded_audio, text]
padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
valid_w = (audio.shape[1] - 1) // 3 + 1
padded_instance += [
[0], # sequence offset, always 0
[valid_w], # valid sequence length
# Index ranges over channel, height and width;
# see the scale_sub_region layer for details
[1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
]
pre_padded_h = padded_conv0_h
for i in xrange(self._num_conv_layers - 1):
padded_h = (pre_padded_h - 1) // 2 + 1
pre_padded_h = padded_h
padded_instance += [
[1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
]
new_batch.append(padded_instance)
return new_batch
def _batch_shuffle(self, manifest, batch_size, clipped=False):
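
To sanity-check the padding arithmetic above with hypothetical shapes (a 161-bin spectrogram padded to 2000 frames, of which 1500 are valid):

padded_conv0_h = (161 - 1) // 2 + 1    # 81
padded_conv0_w = (2000 - 1) // 3 + 1   # 667
valid_w = (1500 - 1) // 3 + 1          # 500
# conv0 index range [1, 32, 1, 81, 501, 667] marks the 167 padded
# columns across all 32 channels; for conv1 the height shrinks to
# (81 - 1) // 2 + 1 == 41 while the width bounds stay unchanged,
# since conv1 through convN use stride 1 along the width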

@@ -69,7 +69,8 @@ def infer():
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=1,
keep_transcription_text=True)
keep_transcription_text=True,
num_conv_layers=args.num_conv_layers)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.infer_manifest,
batch_size=args.num_samples,
@@ -100,10 +101,11 @@ def infer():
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
num_processes=args.num_proc_bsearch,
feeding_dict=data_generator.feeding)
error_rate_func = cer if args.error_rate_type == 'cer' else wer
target_transcripts = [transcript for _, transcript in infer_data]
target_transcripts = [data[1] for data in infer_data]
for target, result in zip(target_transcripts, result_transcripts):
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
(target, result))

@@ -165,7 +165,7 @@ class DeepSpeech2Model(object):
def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
beam_size, cutoff_prob, cutoff_top_n, vocab_list,
language_model_path, num_processes):
language_model_path, num_processes, feeding_dict):
"""Model inference. Infer the transcription for a batch of speech
utterances.
@@ -195,6 +195,9 @@ class DeepSpeech2Model(object):
:type language_model_path: basestring|None
:param num_processes: Number of processes (CPU) for decoder.
:type num_processes: int
:param feeding_dict: A mapping from data field names to the tuple
indices of the data that the reader returns.
:type feeding_dict: dict|list
:return: List of transcription texts.
:rtype: List of basestring
"""
@@ -203,10 +206,13 @@ class DeepSpeech2Model(object):
self._inferer = paddle.inference.Inference(
output_layer=self._log_probs, parameters=self._parameters)
# run inference
infer_results = self._inferer.infer(input=infer_data)
num_steps = len(infer_results) // len(infer_data)
infer_results = self._inferer.infer(
input=infer_data, feeding=feeding_dict)
start_pos = [0] * (len(infer_data) + 1)
for i in xrange(len(infer_data)):
start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
probs_split = [
infer_results[i * num_steps:(i + 1) * num_steps]
infer_results[start_pos[i]:start_pos[i + 1]]
for i in xrange(0, len(infer_data))
]
# run decoder
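
The slicing above replaces the old uniform split (a fixed num_steps per utterance) with per-utterance valid lengths read from field index 3. A minimal sketch of the same cumulative-offset logic, using hypothetical lengths:

valid_lens = [500, 480, 510]   # hypothetical data[3][0] values
start_pos = [0]
for n in valid_lens:
    start_pos.append(start_pos[-1] + n)
# start_pos == [0, 500, 980, 1490]; utterance i owns
# infer_results[start_pos[i]:start_pos[i + 1]]
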
@@ -274,9 +280,25 @@ class DeepSpeech2Model(object):
text_data = paddle.layer.data(
name="transcript_text",
type=paddle.data_type.integer_value_sequence(vocab_size))
seq_offset_data = paddle.layer.data(
name='sequence_offset',
type=paddle.data_type.integer_value_sequence(1))
seq_len_data = paddle.layer.data(
name='sequence_length',
type=paddle.data_type.integer_value_sequence(1))
index_range_datas = []
# one index-range input is needed per convolution layer
for i in xrange(num_conv_layers):
index_range_datas.append(
paddle.layer.data(
name='conv%d_index_range' % i,
type=paddle.data_type.dense_vector(6)))
self._log_probs, self._loss = deep_speech_v2_network(
audio_data=audio_data,
text_data=text_data,
seq_offset_data=seq_offset_data,
seq_len_data=seq_len_data,
index_range_datas=index_range_datas,
dict_size=vocab_size,
num_conv_layers=num_conv_layers,
num_rnn_layers=num_rnn_layers,

@@ -7,7 +7,7 @@ import paddle.v2 as paddle
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
padding, act):
padding, act, index_range_data):
"""Convolution layer with batch normalization.
:param input: Input layer.
@@ -24,6 +24,8 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
:type padding: int|tuple|list
:param act: Activation type.
:type act: BaseActivation
:param index_range_data: Index range indicating the padding sub-region to reset to zero.
:type index_range_data: LayerOutput
:return: Batch norm layer after convolution layer.
:rtype: LayerOutput
"""
@@ -36,7 +38,11 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
padding=padding,
act=paddle.activation.Linear(),
bias_attr=False)
return paddle.layer.batch_norm(input=conv_layer, act=act)
batch_norm = paddle.layer.batch_norm(input=conv_layer, act=act)
# reset padding part to 0
scale_sub_region = paddle.layer.scale_sub_region(
batch_norm, index_range_data, value=0.0)
return scale_sub_region
def bidirectional_simple_rnn_bn_layer(name, input, size, act, share_weights):
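
A NumPy sketch of the zeroing that scale_sub_region performs here, assuming 1-based, inclusive index ranges as built by the data generator (the six values bound channel, height and width):

import numpy as np

def scale_sub_region_np(x, index_range, value):
    # x is a (batch, channels, height, width) feature map
    c0, c1, h0, h1, w0, w1 = index_range
    out = x.copy()
    out[:, c0 - 1:c1, h0 - 1:h1, w0 - 1:w1] *= value
    return out

feat = np.ones((1, 32, 81, 667), dtype='float32')
masked = scale_sub_region_np(feat, [1, 32, 1, 81, 501, 667], 0.0)
# columns 501..667 (the padding) are now zero in every channel

The mask is applied after batch_norm rather than before the convolution, presumably because the normalization's shift would otherwise move the padded region away from zero.
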
@@ -136,13 +142,15 @@ def bidirectional_gru_bn_layer(name, input, size, act):
return paddle.layer.concat(input=[forward_gru, backward_gru])
def conv_group(input, num_stacks):
def conv_group(input, num_stacks, index_range_datas):
"""Convolution group with stacked convolution layers.
:param input: Input layer.
:type input: LayerOutput
:param num_stacks: Number of stacked convolution layers.
:type num_stacks: int
:param index_range_datas: Index ranges for each convolution layer.
:type index_range_datas: tuple|list
:return: Output layer of the convolution group.
:rtype: LayerOutput
"""
@@ -153,7 +161,8 @@ def conv_group(input, num_stacks):
num_channels_out=32,
stride=(3, 2),
padding=(5, 20),
act=paddle.activation.BRelu())
act=paddle.activation.BRelu(),
index_range_data=index_range_datas[0])
for i in xrange(num_stacks - 1):
conv = conv_bn_layer(
input=conv,
@@ -162,7 +171,8 @@ def conv_group(input, num_stacks):
num_channels_out=32,
stride=(1, 2),
padding=(5, 10),
act=paddle.activation.BRelu())
act=paddle.activation.BRelu(),
index_range_data=index_range_datas[i + 1])
output_num_channels = 32
output_height = 160 // pow(2, num_stacks) + 1
return conv, output_num_channels, output_height
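
For the default num_stacks=2 this gives output_height = 160 // 4 + 1 = 41, which agrees with the height chain computed in the data generator for a 161-bin spectrogram: 161 rows become 81 after conv0 and 41 after conv1.
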
@@ -207,6 +217,9 @@ def rnn_group(input, size, num_stacks, use_gru, share_rnn_weights):
def deep_speech_v2_network(audio_data,
text_data,
seq_offset_data,
seq_len_data,
index_range_datas,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
@@ -219,6 +232,12 @@ def deep_speech_v2_network(audio_data,
:type audio_data: LayerOutput
:param text_data: Transcription text data layer.
:type text_data: LayerOutput
:param seq_offset_data: Sequence offset data layer.
:type seq_offset_data: LayerOutput
:param seq_len_data: Valid sequence length data layer.
:type seq_len_data: LayerOutput
:param index_range_datas: Index ranges data layers.
:type index_range_datas: tuple|list
:param dict_size: Dictionary size for tokenized transcription.
:type dict_size: int
:param num_conv_layers: Number of stacking convolution layers.
@@ -239,7 +258,9 @@ def deep_speech_v2_network(audio_data,
"""
# convolution group
conv_group_output, conv_group_num_channels, conv_group_height = conv_group(
input=audio_data, num_stacks=num_conv_layers)
input=audio_data,
num_stacks=num_conv_layers,
index_range_datas=index_range_datas)
# convert data from convolution feature map to sequence of vectors
conv2seq = paddle.layer.block_expand(
input=conv_group_output,
@@ -248,9 +269,16 @@ def deep_speech_v2_network(audio_data,
stride_y=1,
block_x=1,
block_y=conv_group_height)
# remove padding part
remove_padding_data = paddle.layer.sub_seq(
input=conv2seq,
offsets=seq_offset_data,
sizes=seq_len_data,
act=paddle.activation.Linear(),
bias_attr=False)
# rnn group
rnn_group_output = rnn_group(
input=conv2seq,
input=remove_padding_data,
size=rnn_size,
num_stacks=num_rnn_layers,
use_gru=use_gru,
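
Conceptually, the sub_seq step trims each expanded sequence back to its valid length before the recurrent layers see it. A minimal sketch of that semantics (not the Paddle API):

def remove_padding(seq, offset, size):
    # seq holds the per-timestep vectors of one utterance, padding included
    return seq[offset:offset + size]

# with sequence_offset == 0 and sequence_length == valid_w, only the
# first valid_w time steps reach the RNN group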

@@ -70,7 +70,8 @@ def evaluate():
augmentation_config='{}',
specgram_type=args.specgram_type,
num_threads=args.num_proc_data,
keep_transcription_text=True)
keep_transcription_text=True,
num_conv_layers=args.num_conv_layers)
batch_reader = data_generator.batch_reader_creator(
manifest_path=args.test_manifest,
batch_size=args.batch_size,
@@ -103,8 +104,9 @@ def evaluate():
cutoff_top_n=args.cutoff_top_n,
vocab_list=vocab_list,
language_model_path=args.lang_model_path,
num_processes=args.num_proc_bsearch)
target_transcripts = [transcript for _, transcript in infer_data]
num_processes=args.num_proc_bsearch,
feeding_dict=data_generator.feeding)
target_transcripts = [data[1] for data in infer_data]
for target, result in zip(target_transcripts, result_transcripts):
error_sum += error_rate_func(target, result)
num_ins += 1

@@ -75,13 +75,15 @@ def train():
max_duration=args.max_duration,
min_duration=args.min_duration,
specgram_type=args.specgram_type,
num_threads=args.num_proc_data)
num_threads=args.num_proc_data,
num_conv_layers=args.num_conv_layers)
dev_generator = DataGenerator(
vocab_filepath=args.vocab_path,
mean_std_filepath=args.mean_std_path,
augmentation_config="{}",
specgram_type=args.specgram_type,
num_threads=args.num_proc_data)
num_threads=args.num_proc_data,
num_conv_layers=args.num_conv_layers)
train_batch_reader = train_generator.batch_reader_creator(
manifest_path=args.train_manifest,
batch_size=args.batch_size,
