export feature size

add trainer and utils
add setup model and dataloader
update travis using Bionic dist
pull/522/head
Hui Zhang 5 years ago
parent 508182752e
commit c2ccb11ba0

@@ -1,7 +1,7 @@
 language: cpp
 cache: ccache
 sudo: required
-dist: xenial
+dist: Bionic
 services:
 - docker
 os:

@@ -188,8 +188,6 @@ class DataGenerator():
             max_duration=self._max_duration,
             min_duration=self._min_duration)
         # sort (by duration) or batch-wise shuffle the manifest
         if self._epoch == 0 and sortagrad:
             manifest.sort(key=lambda x: x["duration"])
@@ -365,7 +363,7 @@ class DataGenerator():
         """
         manifest.sort(key=lambda x: x["duration"])
         shift_len = self._rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
         self._rng.shuffle(batch_manifest)
         batch_manifest = [item for batch in batch_manifest for item in batch]
         if not clipped:
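
Note on the hunk above: aside from the whitespace fix, `zip(*[iter(manifest[shift_len:])] * batch_size)` is a dense idiom worth unpacking. A minimal, self-contained sketch of what it computes (toy values, not from the repo):

    # zip(*[iter(seq)] * n) repeats ONE iterator n times, so each zip step
    # consumes n consecutive items; the final incomplete group is dropped.
    manifest = list(range(10))
    batch_size = 3
    batches = list(zip(*[iter(manifest)] * batch_size))
    print(batches)  # [(0, 1, 2), (3, 4, 5), (6, 7, 8)] -- 9 is dropped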

@@ -211,7 +211,7 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         rng = np.random.RandomState(self.epoch)
         manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
         rng.shuffle(batch_manifest)
         batch_manifest = [item for batch in batch_manifest for item in batch]
         if not clipped:
@@ -347,7 +347,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         rng = np.random.RandomState(self.epoch)
         manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
         rng.shuffle(batch_manifest)
         batch_manifest = [item for batch in batch_manifest for item in batch]
         if not clipped:

@@ -63,6 +63,7 @@ class AudioFeaturizer(object):
         self._target_sample_rate = target_sample_rate
         self._use_dB_normalization = use_dB_normalization
         self._target_dB = target_dB
+        self._fft_point = None

     def featurize(self,
                   audio_segment,
@@ -98,6 +99,19 @@ class AudioFeaturizer(object):
         return self._compute_specgram(audio_segment.samples,
                                       audio_segment.sample_rate)

+    @property
+    def feature_size(self):
+        """audio feature size"""
+        if self._specgram_type == 'linear':
+            fft_point = self._window_ms if self._fft_point is None else self._fft_point
+            return fft_point * (self._target_sample_rate / 1000) / 2 + 1
+        elif self._specgram_type == 'mfcc':
+            # mfcc, delta, delta-delta
+            return 13 * 3
+        else:
+            raise ValueError("Unknown specgram_type %s. "
+                             "Supported values: linear." % self._specgram_type)
+
     def _compute_specgram(self, samples, sample_rate):
         """Extract various audio features."""
         if self._specgram_type == 'linear':
@@ -150,7 +164,8 @@ class AudioFeaturizer(object):
                windows[:, 1] == samples[stride_size:(stride_size + window_size)])
         # window weighting, squared Fast Fourier Transform (fft), scaling
         weighting = np.hanning(window_size)[:, None]
-        fft = np.fft.rfft(windows * weighting, axis=0)
+        # https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html
+        fft = np.fft.rfft(windows * weighting, n=None, axis=0)
         fft = np.absolute(fft)
         fft = fft**2
         scale = np.sum(weighting**2) * sample_rate
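
The new `feature_size` property and the explicit `n=None` in `np.fft.rfft` fit together: with `n=None`, rfft over a window of `window_ms * sample_rate / 1000` samples yields `window_size // 2 + 1` frequency bins. A small sketch of the arithmetic, assuming the defaults used elsewhere in this repo (window_ms=20, target_sample_rate=16000):

    window_ms = 20
    target_sample_rate = 16000
    window_size = window_ms * target_sample_rate // 1000  # 320 samples per frame
    feature_size = window_size // 2 + 1                   # rfft output bins
    print(feature_size)  # 161 -- the spectrogram height the network consumes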

@@ -106,3 +106,12 @@ class SpeechFeaturizer(object):
         :rtype: list
         """
         return self._text_featurizer.vocab_list
+
+    @property
+    def feature_size(self):
+        """Return the audio feature size.
+        :return: audio feature size.
+        :rtype: int
+        """
+        return self._audio_featurizer.feature_size

@@ -16,13 +16,12 @@
 import sys
 import argparse
 import functools
-import paddle.fluid as fluid
+from model_utils.model_check import check_cuda, check_version
+from utils.utility import add_arguments, print_arguments
+from utils.error_rate import wer, cer
 from data_utils.data import DataGenerator
 from data_utils.dataset import create_dataloader
 from model_utils.model import DeepSpeech2Model
-from model_utils.model_check import check_cuda, check_version
-from utils.error_rate import wer, cer
-from utils.utility import add_arguments, print_arguments

 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
@@ -132,7 +131,7 @@ def infer():
         dict_size=batch_reader.dataset.vocab_size,
         num_conv_layers=args.num_conv_layers,
         num_rnn_layers=args.num_rnn_layers,
-        #rnn_size=1024,
+        rnn_size=args.rnn_layer_size,
         use_gru=args.use_gru,
         share_rnn_weights=args.share_rnn_weights,
         )

@@ -24,179 +24,177 @@ import collections
 import multiprocessing
 import numpy as np
 from distutils.dir_util import mkpath
 import paddle.fluid as fluid
-import paddle.fluid.compiler as compiler
+from paddle.io import DataLoader
+from training import Trainer
+from model_utils.network import DeepSpeech2
+from model_utils.network import DeepSpeech2Loss
 from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_greedy_decoder
 from decoders.swig_wrapper import ctc_beam_search_decoder_batch
-from model_utils.network import deep_speech_v2_network

 logging.basicConfig(
     format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')

-class DeepSpeech2Model(object):
-    """DeepSpeech2Model class.
-    :param vocab_size: Decoding vocabulary size.
-    :type vocab_size: int
-    :param num_conv_layers: Number of stacking convolution layers.
-    :type num_conv_layers: int
-    :param num_rnn_layers: Number of stacking RNN layers.
-    :type num_rnn_layers: int
-    :param rnn_layer_size: RNN layer size (number of RNN cells).
-    :type rnn_layer_size: int
-    :param use_gru: Use gru if set True. Use simple rnn if set False.
-    :type use_gru: bool
-    :param share_rnn_weights: Whether to share input-hidden weights between
-                              forward and backward directional RNNs. Notice that
-                              for GRU, weight sharing is not supported.
-    :type share_rnn_weights: bool
-    :param place: Program running place.
-    :type place: CPUPlace or CUDAPlace
-    :param init_from_pretrained_model: Pretrained model path. If None, will train
-                                       from scratch.
-    :type init_from_pretrained_model: string|None
-    :param output_model_dir: Output model directory. If None, output to current directory.
-    :type output_model_dir: string|None
-    """
-
-    def __init__(self,
-                 vocab_size,
-                 num_conv_layers,
-                 num_rnn_layers,
-                 rnn_layer_size,
-                 use_gru=False,
-                 share_rnn_weights=True,
-                 place=fluid.CPUPlace(),
-                 init_from_pretrained_model=None,
-                 output_model_dir=None):
-        self._vocab_size = vocab_size
-        self._num_conv_layers = num_conv_layers
-        self._num_rnn_layers = num_rnn_layers
-        self._rnn_layer_size = rnn_layer_size
-        self._use_gru = use_gru
-        self._share_rnn_weights = share_rnn_weights
-        self._place = place
-        self._init_from_pretrained_model = init_from_pretrained_model
-        self._output_model_dir = output_model_dir
-        self._ext_scorer = None
-        self.logger = logging.getLogger("")
-        self.logger.setLevel(level=logging.INFO)
-
-    def create_network(self, is_infer=False):
-        """Create data layers and model network.
-        :param is_training: Whether to create a network for training.
-        :type is_training: bool
-        :return reader: Reader for input.
-        :rtype reader: read generator
-        :return log_probs: An output unnormalized log probability layer.
-        :rtype log_probs: Variable
-        :return loss: A ctc loss layer.
-        :rtype loss: Variable
-        """
-        if not is_infer:
-            reader = DataLoader.from_generator(
-                feed_list=inputs,
-                capacity=64,
-                iterable=False,
-                use_double_buffer=True)
-            (audio_data, text_data, seq_len_data, masks) = inputs
-        else:
-            audio_data = fluid.data(
-                name='audio_data',
-                shape=[None, 161, None],
-                dtype='float32',
-                lod_level=0)
-            seq_len_data = fluid.data(
-                name='seq_len_data',
-                shape=[None, 1],
-                dtype='int64',
-                lod_level=0)
-            masks = fluid.data(
-                name='masks',
-                shape=[None, 32, 81, None],
-                dtype='float32',
-                lod_level=0)
-            text_data = None
-            reader = fluid.DataFeeder([audio_data, seq_len_data, masks],
-                                      self._place)
-        log_probs, loss = deep_speech_v2_network(
-            audio_data=audio_data,
-            text_data=text_data,
-            seq_len_data=seq_len_data,
-            masks=masks,
-            dict_size=self._vocab_size,
-            num_conv_layers=self._num_conv_layers,
-            num_rnn_layers=self._num_rnn_layers,
-            rnn_size=self._rnn_layer_size,
-            use_gru=self._use_gru,
-            share_rnn_weights=self._share_rnn_weights)
-        return reader, log_probs, loss
-
-    def init_from_pretrained_model(self, exe, program):
-        '''Init params from pretrain model. '''
-        assert isinstance(self._init_from_pretrained_model, str)
-        if not os.path.exists(self._init_from_pretrained_model):
-            print(self._init_from_pretrained_model)
-            raise Warning("The pretrained params do not exist.")
-            return False
-        fluid.io.load_params(
-            exe,
-            self._init_from_pretrained_model,
-            main_program=program,
-            filename="params.pdparams")
-        print("finish initing model from pretrained params from %s" %
-              (self._init_from_pretrained_model))
-        pre_epoch = 0
-        dir_name = self._init_from_pretrained_model.split('_')
-        if len(dir_name) >= 2 and dir_name[-2].endswith('epoch') and dir_name[
-                -1].isdigit():
-            pre_epoch = int(dir_name[-1])
-        return pre_epoch + 1
-
-    def save_param(self, exe, program, dirname):
-        '''Save model params to dirname'''
-        assert isinstance(self._output_model_dir, str)
-        param_dir = os.path.join(self._output_model_dir)
-        if not os.path.exists(param_dir):
-            os.mkdir(param_dir)
-        fluid.io.save_params(
-            exe,
-            os.path.join(param_dir, dirname),
-            main_program=program,
-            filename="params.pdparams")
-        print("save parameters at %s" % (os.path.join(param_dir, dirname)))
-        return True
-
-    def test(self, exe, dev_batch_reader, test_program, test_reader,
-             fetch_list):
+class SpeechCollator():
+    def __init__(self, padding_to=-1):
+        """
+        Padding audio features with zeros to make them have the same shape (or
+        a user-defined shape) within one batch.
+        If ``padding_to`` is -1, the maximum shape in the batch will be used
+        as the target shape for padding. Otherwise, ``padding_to`` will be the
+        target shape (only refers to the second axis).
+        """
+        self._padding_to = padding_to
+
+    def __call__(self, batch):
+        new_batch = []
+        # get target shape
+        max_length = max([audio.shape[1] for audio, _ in batch])
+        if self._padding_to != -1:
+            if self._padding_to < max_length:
+                raise ValueError("If padding_to is not -1, it should be larger "
+                                 "than any instance's shape in the batch")
+            max_length = self._padding_to
+        max_text_length = max([len(text) for _, text in batch])
+        # padding
+        padded_audios = []
+        audio_lens = []
+        texts, text_lens = [], []
+        for audio, text in batch:
+            # audio
+            padded_audio = np.zeros([audio.shape[0], max_length])
+            padded_audio[:, :audio.shape[1]] = audio
+            padded_audios.append(padded_audio)
+            audio_lens.append(audio.shape[1])
+            # text
+            padded_text = np.zeros([max_text_length])
+            padded_text[:len(text)] = text
+            texts.append(padded_text)
+            text_lens.append(len(text))
+        padded_audios = np.array(padded_audios).astype('float32')
+        audio_lens = np.array(audio_lens).astype('int64')
+        texts = np.array(texts).astype('int32')
+        text_lens = np.array(text_lens).astype('int64')
+        return padded_audios, texts, audio_lens, text_lens
+
+class DeepSpeech2Trainer(Trainer):
+    def __init__(self):
+        self._ext_scorer = None
+
+    def setup_dataloader(self):
+        config = self.config
+        train_dataset = DeepSpeech2Dataset(
+            config.data.train_manifest_path,
+            config.data.vocab_filepath,
+            config.data.mean_std_filepath,
+            augmentation_config=config.data.augmentation_config,
+            max_duration=config.data.max_duration,
+            min_duration=config.data.min_duration,
+            stride_ms=config.data.stride_ms,
+            window_ms=config.data.window_ms,
+            max_freq=config.data.max_freq,
+            specgram_type=config.data.specgram_type,
+            use_dB_normalization=config.data.use_dB_normalization,
+            random_seed=config.data.random_seed,
+            keep_transcription_text=False)
+        dev_dataset = DeepSpeech2Dataset(
+            config.data.dev_manifest_path,
+            config.data.vocab_filepath,
+            config.data.mean_std_filepath,
+            augmentation_config=config.data.augmentation_config,
+            max_duration=config.data.max_duration,
+            min_duration=config.data.min_duration,
+            stride_ms=config.data.stride_ms,
+            window_ms=config.data.window_ms,
+            max_freq=config.data.max_freq,
+            specgram_type=config.data.specgram_type,
+            use_dB_normalization=config.data.use_dB_normalization,
+            random_seed=config.data.random_seed,
+            keep_transcription_text=False)
+        if self.parallel:
+            batch_sampler = DeepSpeech2DistributedBatchSampler(
+                train_dataset,
+                batch_size=config.data.batch_size,
+                num_replicas=None,
+                rank=None,
+                shuffle=True,
+                drop_last=True,
+                sortagrad=config.data.sortagrad,
+                shuffle_method=config.data.shuffle_method)
+        else:
+            batch_sampler = DeepSpeech2BatchSampler(
+                train_dataset,
+                shuffle=True,
+                batch_size=config.data.batch_size,
+                drop_last=True,
+                sortagrad=config.data.sortagrad,
+                shuffle_method=config.data.shuffle_method)
+        collate_fn = SpeechCollator()
+        self.train_loader = DataLoader(
+            train_dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=collate_fn,
+            num_workers=config.data.num_workers, )
+        self.valid_loader = DataLoader(
+            dev_dataset,
+            batch_size=config.data.batch_size,
+            shuffle=False,
+            drop_last=False,
+            collate_fn=collate_fn)
+        self.logger.info("Setup train/valid Dataloader!")
+
+    def setup_model(self):
+        config = self.config
+        model = DeepSpeech2(
+            feat_size=self.train_loader.feature_size,
+            dict_size=self.train_loader.vocab_size,
+            num_conv_layers=config.model.num_conv_layers,
+            num_rnn_layers=config.model.num_rnn_layers,
+            rnn_size=config.model.rnn_layer_size,
+            share_rnn_weights=config.model.share_rnn_weights)
+        if self.parallel:
+            model = paddle.DataParallel(model)
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(config.training.grad_clip)
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=config.training.lr,
+            parameters=model.parameters(),
+            weight_decay=paddle.regularizer.L2Decay(
+                config.training.weight_decay),
+            grad_clip=grad_clip, )
+        criterion = DeepSpeech2Loss(self.train_loader.vocab_size)
+        self.model = model
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.logger.info("Setup model/optimizer/criterion!")
+
+    def compute_losses(self, inputs, outputs):
+        pass
+
+    def test(self, test_reader):
         '''Test the model.
         :param exe: The executor of program.
         :type exe: Executor
-        :param dev_batch_reader: The reader of test data.
-        :type dev_batch_reader: read generator
         :param test_program: The program of test.
         :type test_program: Program
         :param test_reader: Reader of test.
         :type test_reader: Reader
-        :param fetch_list: Fetch list.
-        :type fetch_list: list
         :return: An output unnormalized log probability.
        :rtype: array
        '''
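
A hypothetical usage sketch of the new `SpeechCollator` (shapes assumed from the padding code above: each batch item is an audio array of shape [D, T] plus a token-id sequence):

    import numpy as np
    collate = SpeechCollator(padding_to=-1)
    batch = [(np.random.randn(161, 50), [1, 2, 3]),
             (np.random.randn(161, 80), [4, 5])]
    audios, texts, audio_lens, text_lens = collate(batch)
    print(audios.shape, texts.shape)                # (2, 161, 80) (2, 3)
    print(audio_lens.tolist(), text_lens.tolist())  # [50, 80] [3, 2]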
@@ -254,13 +252,6 @@ class DeepSpeech2Model(object):
         :param test_off: Turn off testing.
         :type test_off: bool
         """
-        # prepare model output directory
-        if not os.path.exists(self._output_model_dir):
-            mkpath(self._output_model_dir)
-
-        # adapt the feeding dict according to the network
-        adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
-
         if isinstance(self._place, fluid.CUDAPlace):
             dev_count = fluid.core.get_cuda_device_count()
         else:
@@ -298,16 +289,6 @@ class DeepSpeech2Model(object):
         if self._init_from_pretrained_model:
             pre_epoch = self.init_from_pretrained_model(exe, train_program)
-
-        build_strategy = compiler.BuildStrategy()
-        exec_strategy = fluid.ExecutionStrategy()
-        # pass the build_strategy to with_data_parallel API
-        compiled_prog = compiler.CompiledProgram(
-            train_program).with_data_parallel(
-                loss_name=ctc_loss.name,
-                build_strategy=build_strategy,
-                exec_strategy=exec_strategy)
-
         train_reader.set_batch_generator(train_batch_reader)
         test_reader.set_batch_generator(dev_batch_reader)
@@ -386,9 +367,6 @@ class DeepSpeech2Model(object):
         infer_program = fluid.Program()
         startup_prog = fluid.Program()
-        # adapt the feeding dict according to the network
-        adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
-
         # prepare the network
         with fluid.program_guard(infer_program, startup_prog):
             with fluid.unique_name.guard():
@@ -523,35 +501,3 @@ class DeepSpeech2Model(object):
         results = [result[0][1] for result in beam_search_results]
         return results
-
-    def _adapt_feeding_dict(self, feeding_dict):
-        """Adapt feeding dict according to network struct.
-        To remove impacts from padding part, we add scale_sub_region layer and
-        sub_seq layer. For sub_seq layer, 'sequence_offset' and
-        'sequence_length' fields are appended. For each scale_sub_region layer
-        'convN_index_range' field is appended.
-        :param feeding_dict: Feeding is a map of field name and tuple index
-                             of the data that reader returns.
-        :type feeding_dict: dict|list
-        :return: Adapted feeding dict.
-        :rtype: dict|list
-        """
-        adapted_feeding_dict = copy.deepcopy(feeding_dict)
-        if isinstance(feeding_dict, dict):
-            adapted_feeding_dict["sequence_offset"] = len(adapted_feeding_dict)
-            adapted_feeding_dict["sequence_length"] = len(adapted_feeding_dict)
-            for i in range(self._num_conv_layers):
-                adapted_feeding_dict["conv%d_index_range" % i] = \
-                    len(adapted_feeding_dict)
-        elif isinstance(feeding_dict, list):
-            adapted_feeding_dict.append("sequence_offset")
-            adapted_feeding_dict.append("sequence_length")
-            for i in range(self._num_conv_layers):
-                adapted_feeding_dict.append("conv%d_index_range" % i)
-        else:
-            raise ValueError("Type of feeding_dict is %s, not supported." %
-                             type(feeding_dict))
-        return adapted_feeding_dict
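
For reference, a minimal standalone sketch of the optimizer wiring used in `setup_model` above, with the misspelled `paddle.regulaerizer` corrected to the real `paddle.regularizer` module (the hyperparameter values here are invented, not the repo's config):

    import paddle

    model = paddle.nn.Linear(4, 2)  # stand-in for the DeepSpeech2 model
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-4,
        parameters=model.parameters(),
        weight_decay=paddle.regularizer.L2Decay(1e-6),
        grad_clip=paddle.nn.ClipGradByGlobalNorm(5.0))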

@@ -12,31 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import collections
-import paddle.fluid as fluid
 import numpy as np
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+from paddle.nn import initializer as I

+__all__ = ['DeepSpeech2', 'DeepSpeech2Loss']

-def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
-                  padding, act, masks, name):
+def brelu(x, t_min=0.0, t_max=24.0, name=None):
+    t_min = paddle.to_tensor(t_min)
+    t_max = paddle.to_tensor(t_max)
+    return x.maximum(t_min).minimum(t_max)
+
+def sequence_mask(x_len, max_len=None, dtype='float32'):
+    max_len = max_len or x_len.max()
+    x_len = paddle.unsqueeze(x_len, -1)
+    row_vector = paddle.arange(max_len)
+    mask = row_vector < x_len
+    mask = paddle.cast(mask, dtype)
+    return mask
+
+class ConvBn(nn.Layer):
     """Convolution layer with batch normalization.
-    :param input: Input layer.
-    :type input: Variable
-    :param filter_size: The x dimension of a filter kernel. Or input a tuple for
+    :param kernel_size: The x dimension of a filter kernel. Or input a tuple for
         two image dimension.
-    :type filter_size: int|tuple|list
+    :type kernel_size: int|tuple|list
     :param num_channels_in: Number of input channels.
     :type num_channels_in: int
     :param num_channels_out: Number of output channels.
     :type num_channels_out: int
     :param stride: The x dimension of the stride. Or input a tuple for two
         image dimension.
     :type stride: int|tuple|list
     :param padding: The x dimension of the padding. Or input a tuple for two
         image dimension.
     :type padding: int|tuple|list
-    :param act: Activation type.
+    :param act: Activation type, relu|brelu
     :type act: string
     :param masks: Masks data layer to reset padding.
     :type masks: Variable
@@ -46,89 +64,255 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
     :rtype: Variable
     """
-    conv_layer = fluid.layers.conv2d(
-        input=input,
-        num_filters=num_channels_out,
-        filter_size=filter_size,
-        stride=stride,
-        padding=padding,
-        param_attr=fluid.ParamAttr(name=name + '_conv2d_weight'),
-        act=None,
-        bias_attr=False)
-    batch_norm = fluid.layers.batch_norm(
-        input=conv_layer,
-        act=act,
-        param_attr=fluid.ParamAttr(name=name + '_batch_norm_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_batch_norm_bias'),
-        moving_mean_name=name + '_batch_norm_moving_mean',
-        moving_variance_name=name + '_batch_norm_moving_variance')
-    # reset padding part to 0
-    padding_reset = fluid.layers.elementwise_mul(batch_norm, masks)
-    return padding_reset
-
-class RNNCell(fluid.layers.RNNCell):
-    """A simple rnn cell."""
+    def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
+                 padding, act):
+        super().__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.conv = nn.Conv2D(
+            num_channels_in,
+            num_channels_out,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            weight_attr=None,
+            bias_attr=None,
+            data_format='NCHW')
+        self.bn = nn.BatchNorm2D(
+            num_channels_out,
+            weight_attr=None,
+            bias_attr=None,
+            data_format='NCHW')
+        self.act = paddle.relu if act == 'relu' else brelu
+
+    def forward(self, x, x_len):
+        """
+        x(Tensor): audio, shape [B, C, D, T]
+        """
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.act(x)
+        x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
+                 ) // self.stride[1] + 1
+        # reset padding part to 0
+        masks = sequence_mask(x_len)  #[B, T]
+        masks = masks.unsqueeze(1).unsqueeze(1)  # [B, 1, 1, T]
+        x = x.multiply(masks)
+        return x, x_len
+
+class ConvStack(nn.Layer):
+    """Convolution group with stacked convolution layers.
+    :param feat_size: audio feature dim.
+    :type feat_size: int
+    :param num_stacks: Number of stacked convolution layers.
+    :type num_stacks: int
+    """
+
+    def __init__(self, feat_size, num_stacks):
+        super().__init__()
+        self.feat_size = feat_size  # D
+        self.num_stacks = num_stacks
+        self.filter_size = (41, 11)  # [D, T]
+        self.stride = (2, 3)
+        self.padding = (20, 5)
+        self.conv_in = ConvBn(
+            num_channels_in=1,
+            num_channels_out=32,
+            kernel_size=self.filter_size,
+            stride=self.stride,
+            padding=self.padding,
+            act='brelu', )
+        out_channel = 32
+        self.conv_stack = nn.LayerList([
+            ConvBn(
+                num_channels_in=32,
+                num_channels_out=out_channel,
+                kernel_size=(21, 11),
+                stride=(2, 1),
+                padding=(10, 5),
+                act='brelu') for i in range(num_stacks - 1)
+        ])
+        # conv output feat_dim
+        output_height = (feat_size - 1) // 2 + 1
+        for i in range(self.num_stacks - 1):
+            output_height = (output_height - 1) // 2 + 1
+        self.output_height = out_channel * output_height
+
+    def forward(self, x, x_len):
+        """
+        x: shape [B, C, D, T]
+        x_len: shape [B]
+        """
+        x, x_len = self.conv_in(x, x_len)
+        for i, conv in enumerate(self.conv_stack):
+            x, x_len = conv(x, x_len)
+        return x, x_len
+
+class RNNCell(nn.RNNCellBase):
+    r"""
+    Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
+    computes the outputs and updates states.
+    The formula used is as follows:
+    .. math::
+        h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
+        y_{t} & = h_{t}
+    where :math:`act` is for :attr:`activation`.
+    """
+
     def __init__(self,
                  hidden_size,
-                 param_attr=None,
-                 bias_attr=None,
-                 hidden_activation=None,
-                 activation=None,
-                 dtype="float32",
-                 name="RNNCell"):
-        """Initialize simple rnn cell.
-        :param hidden_size: Dimension of RNN cells.
-        :type hidden_size: int
-        :param param_attr: Parameter properties of hidden layer weights that
-                           can be learned
-        :type param_attr: ParamAttr
-        :param bias_attr: Bias properties of hidden layer weights that can be learned
-        :type bias_attr: ParamAttr
-        :param hidden_activation: Activation for hidden cell
-        :type hidden_activation: Activation
-        :param activation: Activation for output
-        :type activation: Activation
-        :param name: Name of cell
-        :type name: string
-        """
+                 activation="tanh",
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super().__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_hh = self.create_parameter(
+            (hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = self.create_parameter(
+            (hidden_size, ),
+            bias_ih_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_hh = self.create_parameter(
+            (hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.hidden_size = hidden_size
+        if activation not in ["tanh", "relu", "brelu"]:
+            raise ValueError(
+                "activation for SimpleRNNCell should be tanh or relu, "
+                "but get {}".format(activation))
+        self.activation = activation
+        self._activation_fn = paddle.tanh \
+            if activation == "tanh" \
+            else F.relu
+        if activation == 'brelu':
+            self._activation_fn = brelu
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+        pre_h = states
+        i2h = inputs
+        if self.bias_ih is not None:
+            i2h += self.bias_ih
+        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            h2h += self.bias_hh
+        h = self._activation_fn(i2h + h2h)
+        return h, h
+
+    @property
+    def state_shape(self):
+        return (self.hidden_size, )
+
+class GRUCellShare(nn.RNNCellBase):
+    r"""
+    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
+    it computes the outputs and updates states.
+    The formula for GRU used is as follows:
+    .. math::
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
+        y_{t} & = h_{t}
+    where :math:`\sigma` is the sigmoid function, and * is the elementwise
+    multiplication operator.
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super().__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_hh = self.create_parameter(
+            (3 * hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = self.create_parameter(
+            (3 * hidden_size, ),
+            bias_ih_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_hh = self.create_parameter(
+            (3 * hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
         self.hidden_size = hidden_size
-        self.param_attr = param_attr
-        self.bias_attr = bias_attr
-        self.hidden_activation = hidden_activation
-        self.activation = activation or fluid.layers.brelu
-        self.name = name
-
-    def call(self, inputs, states):
-        new_hidden = fluid.layers.fc(
-            input=states,
-            size=self.hidden_size,
-            act=self.hidden_activation,
-            param_attr=self.param_attr,
-            bias_attr=self.bias_attr)
-        new_hidden = fluid.layers.elementwise_add(new_hidden, inputs)
-        new_hidden = self.activation(new_hidden)
-        return new_hidden, new_hidden
+        self.input_size = input_size
+        self._gate_activation = F.sigmoid
+        self._activation = paddle.tanh
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+        pre_hidden = states
+        x_gates = inputs
+        if self.bias_ih is not None:
+            x_gates = x_gates + self.bias_ih
+        h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            h_gates = h_gates + self.bias_hh
+        x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
+        h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
+        r = self._gate_activation(x_r + h_r)
+        z = self._gate_activation(x_z + h_z)
+        c = self._activation(x_c + r * h_c)  # apply reset gate after mm
+        h = (pre_hidden - c) * z + c
+        return h, h

     @property
     def state_shape(self):
-        return [self.hidden_size]
+        r"""
+        The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
+        size would be automatically inserted into shape). The shape corresponds
+        to the shape of :math:`h_{t-1}`.
+        """
+        return (self.hidden_size, )

-def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights):
+class BiRNNWithBN(nn.Layer):
     """Bidirectional simple rnn layer with sequence-wise batch normalization.
     The batch normalization is only performed on input-state weights.
     :param name: Name of the layer parameters.
     :type name: string
-    :param input: Input layer.
-    :type input: Variable
     :param size: Dimension of RNN cells.
     :type size: int
     :param share_weights: Whether to share input-hidden weights between
@@ -137,88 +321,44 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights):
     :return: Bidirectional simple rnn layer.
     :rtype: Variable
     """
-    forward_cell = RNNCell(
-        hidden_size=size,
-        activation=fluid.layers.brelu,
-        param_attr=fluid.ParamAttr(name=name + '_forward_rnn_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_forward_rnn_bias'))
-    reverse_cell = RNNCell(
-        hidden_size=size,
-        activation=fluid.layers.brelu,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_rnn_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_reverse_rnn_bias'))
-    pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32))
-    if share_weights:
-        #input-hidden weights shared between bi-directional rnn.
-        input_proj = fluid.layers.fc(
-            input=input,
-            size=size,
-            act=None,
-            param_attr=fluid.ParamAttr(name=name + '_fc_weight'),
-            bias_attr=False)
-        # batch norm is only performed on input-state projection
-        input_proj_bn_forward = fluid.layers.batch_norm(
-            input=input_proj,
-            act=None,
-            param_attr=fluid.ParamAttr(name=name + '_batch_norm_weight'),
-            bias_attr=fluid.ParamAttr(name=name + '_batch_norm_bias'),
-            moving_mean_name=name + '_batch_norm_moving_mean',
-            moving_variance_name=name + '_batch_norm_moving_variance')
-        input_proj_bn_reverse = input_proj_bn_forward
-    else:
-        input_proj_forward = fluid.layers.fc(
-            input=input,
-            size=size,
-            act=None,
-            param_attr=fluid.ParamAttr(name=name + '_forward_fc_weight'),
-            bias_attr=False)
-        input_proj_reverse = fluid.layers.fc(
-            input=input,
-            size=size,
-            act=None,
-            param_attr=fluid.ParamAttr(name=name + '_reverse_fc_weight'),
-            bias_attr=False)
-        #batch norm is only performed on input-state projection
-        input_proj_bn_forward = fluid.layers.batch_norm(
-            input=input_proj_forward,
-            act=None,
-            param_attr=fluid.ParamAttr(
-                name=name + '_forward_batch_norm_weight'),
-            bias_attr=fluid.ParamAttr(name=name + '_forward_batch_norm_bias'),
-            moving_mean_name=name + '_forward_batch_norm_moving_mean',
-            moving_variance_name=name + '_forward_batch_norm_moving_variance')
-        input_proj_bn_reverse = fluid.layers.batch_norm(
-            input=input_proj_reverse,
-            act=None,
-            param_attr=fluid.ParamAttr(
-                name=name + '_reverse_batch_norm_weight'),
-            bias_attr=fluid.ParamAttr(name=name + '_reverse_batch_norm_bias'),
-            moving_mean_name=name + '_reverse_batch_norm_moving_mean',
-            moving_variance_name=name + '_reverse_batch_norm_moving_variance')
-    # forward and backward in time
-    input, length = fluid.layers.sequence_pad(input_proj_bn_forward, pad_value)
-    forward_rnn, _ = fluid.layers.rnn(
-        cell=forward_cell, inputs=input, time_major=False, is_reverse=False)
-    forward_rnn = fluid.layers.sequence_unpad(x=forward_rnn, length=length)
-    input, length = fluid.layers.sequence_pad(input_proj_bn_reverse, pad_value)
-    reverse_rnn, _ = fluid.layers.rnn(
-        cell=reverse_cell,
-        inputs=input,
-        sequence_length=length,
-        time_major=False,
-        is_reverse=True)
-    reverse_rnn = fluid.layers.sequence_unpad(x=reverse_rnn, length=length)
-    out = fluid.layers.concat(input=[forward_rnn, reverse_rnn], axis=1)
-    return out
-
-def bidirectional_gru_bn_layer(name, input, size, act):
+    def __init__(self, i_size, h_size, share_weights):
+        super().__init__()
+        self.share_weights = share_weights
+        self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32))
+        if self.share_weights:
+            #input-hidden weights shared between bi-directional rnn.
+            self.fw_fc = nn.Linear(i_size, h_size)
+            # batch norm is only performed on input-state projection
+            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+            self.bw_fc = self.fw_fc
+            self.bw_bn = self.fw_bn
+        else:
+            self.fw_fc = nn.Linear(i_size, h_size)
+            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+            self.bw_fc = nn.Linear(i_size, h_size)
+            self.bw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+        self.fw_cell = RNNCell(hidden_size=h_size, activation='relu')
+        self.bw_cell = RNNCell(
+            hidden_size=h_size,
+            activation='relu', )
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+class BiGRUWithBN(nn.Layer):
     """Bidirectional gru layer with sequence-wise batch normalization.
     The batch normalization is only performed on input-state weights.
@@ -233,108 +373,33 @@ def bidirectional_gru_bn_layer(name, input, size, act):
     :return: Bidirectional GRU layer.
     :rtype: Variable
     """
-    input_proj_forward = fluid.layers.fc(
-        input=input,
-        size=size * 3,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_forward_fc_weight'),
-        bias_attr=False)
-    input_proj_reverse = fluid.layers.fc(
-        input=input,
-        size=size * 3,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_fc_weight'),
-        bias_attr=False)
-    #batch norm is only performed on input-related projections
-    input_proj_bn_forward = fluid.layers.batch_norm(
-        input=input_proj_forward,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_forward_batch_norm_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_forward_batch_norm_bias'),
-        moving_mean_name=name + '_forward_batch_norm_moving_mean',
-        moving_variance_name=name + '_forward_batch_norm_moving_variance')
-    input_proj_bn_reverse = fluid.layers.batch_norm(
-        input=input_proj_reverse,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_batch_norm_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_reverse_batch_norm_bias'),
-        moving_mean_name=name + '_reverse_batch_norm_moving_mean',
-        moving_variance_name=name + '_reverse_batch_norm_moving_variance')
-    #forward and backward in time
-    forward_gru = fluid.layers.dynamic_gru(
-        input=input_proj_bn_forward,
-        size=size,
-        gate_activation='sigmoid',
-        candidate_activation=act,
-        param_attr=fluid.ParamAttr(name=name + '_forward_gru_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_forward_gru_bias'),
-        is_reverse=False)
-    reverse_gru = fluid.layers.dynamic_gru(
-        input=input_proj_bn_reverse,
-        size=size,
-        gate_activation='sigmoid',
-        candidate_activation=act,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_gru_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_reverse_gru_bias'),
-        is_reverse=True)
-    return fluid.layers.concat(input=[forward_gru, reverse_gru], axis=1)
-
-def conv_group(input, num_stacks, seq_len_data, masks):
-    """Convolution group with stacked convolution layers.
-    :param input: Input layer.
-    :type input: Variable
-    :param num_stacks: Number of stacked convolution layers.
-    :type num_stacks: int
-    :param seq_len_data: Valid sequence length data layer.
-    :type seq_len_data: Variable
-    :param masks: Masks data layer to reset padding.
-    :type masks: Variable
-    :return: Output layer of the convolution group.
-    :rtype: Variable
-    """
-    filter_size = (41, 11)
-    stride = (2, 3)
-    padding = (20, 5)
-    conv = conv_bn_layer(
-        input=input,
-        filter_size=filter_size,
-        num_channels_in=1,
-        num_channels_out=32,
-        stride=stride,
-        padding=padding,
-        act="brelu",
-        masks=masks,
-        name='layer_0', )
-    seq_len_data = (np.array(seq_len_data) - filter_size[1] + 2 * padding[1]
-                    ) // stride[1] + 1
-    output_height = (161 - 1) // 2 + 1
-    for i in range(num_stacks - 1):
-        #reshape masks
-        output_height = (output_height - 1) // 2 + 1
-        masks = fluid.layers.slice(
-            masks, axes=[2], starts=[0], ends=[output_height])
-        conv = conv_bn_layer(
-            input=conv,
-            filter_size=(21, 11),
-            num_channels_in=32,
-            num_channels_out=32,
-            stride=(2, 1),
-            padding=(10, 5),
-            act="brelu",
-            masks=masks,
-            name='layer_{}'.format(i + 1), )
-    output_num_channels = 32
-    return conv, output_num_channels, output_height, seq_len_data
-
-def rnn_group(input, size, num_stacks, num_conv_layers, use_gru,
-              share_rnn_weights):
+    def __init__(self, i_size, h_size, act):
+        super().__init__()
+        hidden_size = h_size * 3
+        self.fw_fc = nn.Linear(i_size, hidden_size)
+        self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
+        self.bw_fc = nn.Linear(i_size, hidden_size)
+        self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
+        self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
+        self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+class RNNStack(nn.Layer):
     """RNN group with stacked bidirectional simple RNN or GRU layers.
     :param input: Input layer.
@@ -352,42 +417,42 @@ def rnn_group(input, size, num_stacks, num_conv_layers, use_gru,
     :return: Output layer of the RNN group.
     :rtype: Variable
     """
-    output = input
-    for i in range(num_stacks):
-        if use_gru:
-            output = bidirectional_gru_bn_layer(
-                name='layer_{}'.format(i + num_conv_layers),
-                input=output,
-                size=size,
-                act="relu")
-        else:
-            name = 'layer_{}'.format(i + num_conv_layers)
-            output = bidirectional_simple_rnn_bn_layer(
-                name=name,
-                input=output,
-                size=size,
-                share_weights=share_rnn_weights)
-    return output
-
-def deep_speech_v2_network(audio_data,
-                           text_data,
-                           seq_len_data,
-                           masks,
-                           dict_size,
-                           num_conv_layers=2,
-                           num_rnn_layers=3,
-                           rnn_size=256,
-                           use_gru=False,
-                           share_rnn_weights=True):
+    def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):
+        super().__init__()
+        self.rnn_stacks = nn.LayerList()
+        for i in range(num_stacks):
+            if use_gru:
+                #default: GRU using tanh
+                self.rnn_stacks.append(
+                    BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu"))
+            else:
+                self.rnn_stacks.append(
+                    BiRNNWithBN(
+                        i_size=i_size,
+                        h_size=h_size,
+                        share_weights=share_rnn_weights))
+            i_size = h_size * 2
+
+    def forward(self, x, x_len):
+        """
+        x: shape [B, T, D]
+        x_len: shape [B]
+        """
+        for i, rnn in enumerate(self.rnn_stacks):
+            x, x_len = rnn(x, x_len)
+        return x, x_len
+
+class DeepSpeech2(nn.Layer):
     """The DeepSpeech2 network structure.
     :param audio_data: Audio spectrogram data layer.
     :type audio_data: Variable
     :param text_data: Transcription text data layer.
     :type text_data: Variable
-    :param seq_len_data: Valid sequence length data layer.
-    :type seq_len_data: Variable
+    :param audio_len: Valid sequence length data layer.
+    :type audio_len: Variable
     :param masks: Masks data layer to reset padding.
     :type masks: Variable
     :param dict_size: Dictionary size for tokenized transcription.
@@ -408,51 +473,80 @@ def deep_speech_v2_network(audio_data,
         before softmax) and a ctc cost layer.
     :rtype: tuple of LayerOutput
     """
-    audio_data = fluid.layers.unsqueeze(audio_data, axes=[1])
-    # convolution group
-    conv_group_output, conv_group_num_channels, conv_group_height, seq_len_data = conv_group(
-        input=audio_data,
-        num_stacks=num_conv_layers,
-        seq_len_data=seq_len_data,
-        masks=masks)
-    # convert data from convolution feature map to sequence of vectors
-    transpose = fluid.layers.transpose(conv_group_output, perm=[0, 3, 1, 2])
-    reshape_conv_output = fluid.layers.reshape(
-        x=transpose,
-        shape=[0, -1, conv_group_height * conv_group_num_channels],
-        inplace=False)
-    # remove padding part
-    seq_len_data = fluid.layers.reshape(seq_len_data, [-1])
-    sequence = fluid.layers.sequence_unpad(
-        x=reshape_conv_output, length=seq_len_data)
-    #rnn group
-    rnn_group_output = rnn_group(
-        input=sequence,
-        size=rnn_size,
-        num_stacks=num_rnn_layers,
-        num_conv_layers=num_conv_layers,
-        use_gru=use_gru,
-        share_rnn_weights=share_rnn_weights)
-    fc = fluid.layers.fc(
-        input=rnn_group_output,
-        size=dict_size + 1,
-        act=None,
-        param_attr=fluid.ParamAttr(
-            name='layer_{}'.format(num_conv_layers + num_rnn_layers) +
-            '_fc_weight'),
-        bias_attr=fluid.ParamAttr(
-            name='layer_{}'.format(num_conv_layers + num_rnn_layers) +
-            '_fc_bias'))
-    # probability distribution with softmax
-    log_probs = fluid.layers.softmax(fc)
-    log_probs.persistable = True
-    if not text_data:
-        return log_probs, None
-    else:
-        #ctc cost
-        ctc_loss = fluid.layers.warpctc(
-            input=fc, label=text_data, blank=dict_size, norm_by_times=True)
-        ctc_loss = fluid.layers.reduce_sum(ctc_loss)
-        return log_probs, ctc_loss
+    def __init__(self,
+                 feat_size,
+                 dict_size,
+                 num_conv_layers=2,
+                 num_rnn_layers=3,
+                 rnn_size=256,
+                 use_gru=False,
+                 share_rnn_weights=True):
+        super().__init__()
+        self.feat_size = feat_size  # 161 for linear
+        self.dict_size = dict_size
+        self.conv = ConvStack(feat_size, num_conv_layers)
+        i_size = self.conv.output_height  # H after conv stack
+        self.rnn = RNNStack(
+            i_size=i_size,
+            h_size=rnn_size,
+            num_stacks=num_rnn_layers,
+            use_gru=use_gru,
+            share_rnn_weights=share_rnn_weights)
+        self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
+
+    def predict(self, audio, audio_len):
+        # [B, D, T] -> [B, C=1, D, T]
+        audio = audio.unsqueeze(1)
+        # convolution group
+        x, audio_len = self.conv(audio, audio_len)
+        # convert data from convolution feature map to sequence of vectors
+        B, C, D, T = paddle.shape(x)
+        x = x.transpose([0, 3, 1, 2])  #[B, T, C, D]
+        x = x.reshape([B, T, C * D])  #[B, T, C*D]
+        # remove padding part
+        x, audio_len = self.rnn(x, audio_len)  #[B, T, D]
+        logits = self.fc(x)  #[B, T, V + 1]
+        #ctcdecoder needs probs, not log_probs
+        probs = F.softmax(logits)
+        return logits, probs
+
+    @paddle.no_grad()
+    def infer(self, audio, audio_len):
+        _, probs = self.predict(audio, audio_len)
+        return probs
+
+    def forward(self, audio, text, audio_len, text_len):
+        """
+        audio: shape [B, D, T]
+        text: shape [B, T]
+        audio_len: shape [B]
+        text_len: shape [B]
+        """
+        logits, _ = self.predict(audio, audio_len)
+        return logits
+
+class DeepSpeech2Loss(nn.Layer):
+    def __init__(self, vocab_size):
+        super().__init__()
+        # last token id as blank id
+        self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
+
+    def forward(self, logits, text, audio_len, text_len):
+        # warp-ctc does softmax on the activations
+        # warp-ctc needs activations with shape [T, B, V + 1]
+        logits = logits.transpose([1, 0, 2])
+        ctc_loss = self.loss(logits, text, audio_len, text_len)
+        ctc_loss /= text_len  # norm_by_times
+        ctc_loss = ctc_loss.sum()
+        return ctc_loss
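
To make the shape contract of `DeepSpeech2Loss` concrete, a minimal sketch against `paddle.nn.CTCLoss` directly (paddle 2.x assumed; all sizes invented). It also shows why the collator casts labels to int32, which is the dtype paddle's CTC loss expects:

    import paddle

    B, T, V = 2, 50, 28                          # batch, time steps, vocab size
    logits = paddle.randn([B, T, V + 1])         # model output; blank id == V
    acts = logits.transpose([1, 0, 2])           # [T, B, V + 1], the CTC layout
    text = paddle.randint(0, V, [B, 10], dtype='int32')
    audio_len = paddle.to_tensor([T, T - 8], dtype='int64')
    text_len = paddle.to_tensor([10, 8], dtype='int64')
    ctc = paddle.nn.CTCLoss(blank=V, reduction='none')
    loss = ctc(acts, text, audio_len, text_len)  # per-utterance loss, shape [B]
    print(loss.shape)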

@@ -1,555 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import collections
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
__all__ = ['DeepSpeech2']
def brelu(x, t_min=0.0, t_max=24.0, name=None):
t_min = paddle.to_tensor(t_min)
t_max = paddle.to_tensor(t_max)
return x.maximum(t_min).minimum(t_max)
def sequence_mask(x_len, max_len=None, dtype='float32'):
max_len = max_len or x_len.max()
x_len = paddle.unsqueeze(x_len, -1)
row_vector = paddle.arange(max_len)
mask = row_vector < x_len
mask = paddle.cast(mask, dtype)
return mask
class ConvBn(nn.Layer):
"""Convolution layer with batch normalization.
:param kernel_size: The x dimension of a filter kernel. Or input a tuple for
two image dimension.
:type kernel_size: int|tuple|list
:param num_channels_in: Number of input channels.
:type num_channels_in: int
:param num_channels_out: Number of output channels.
:type num_channels_out: int
:param stride: The x dimension of the stride. Or input a tuple for two
image dimension.
:type stride: int|tuple|list
:param padding: The x dimension of the padding. Or input a tuple for two
image dimension.
:type padding: int|tuple|list
:param act: Activation type, relu|brelu
:type act: string
:param masks: Masks data layer to reset padding.
:type masks: Variable
:param name: Name of the layer.
:param name: string
:return: Batch norm layer after convolution layer.
:rtype: Variable
"""
def __init__(self, num_channels_in, num_channels_out, kernel_size, stride,
padding, act):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.conv = nn.Conv2D(
num_channels_in,
num_channels_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
weight_attr=None,
bias_attr=None,
data_format='NCHW')
self.bn = nn.BatchNorm2D(
num_channels_out,
weight_attr=None,
bias_attr=None,
data_format='NCHW')
self.act = paddle.relu if act == 'relu' else brelu
def forward(self, x, x_len):
"""
x(Tensor): audio, shape [B, C, D, T]
"""
x = self.conv(x)
x = self.bn(x)
x = self.act(x)
x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1]
) // self.stride[1] + 1
# reset padding part to 0
masks = sequence_mask(x_len) #[B, T]
masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T]
x = x.multiply(masks)
return x, x_len
class ConvStack(nn.Layer):
"""Convolution group with stacked convolution layers.
:param feat_size: audio feature dim.
:type feat_size: int
:param num_stacks: Number of stacked convolution layers.
:type num_stacks: int
"""
def __init__(self, feat_size, num_stacks):
super().__init__()
self.feat_size = feat_size # D
self.num_stacks = num_stacks
self.filter_size = (41, 11) # [D, T]
self.stride = (2, 3)
self.padding = (20, 5)
self.conv_in = ConvBn(
num_channels_in=1,
num_channels_out=32,
kernel_size=self.filter_size,
stride=self.stride,
padding=self.padding,
act='brelu', )
out_channel = 32
self.conv_stack = nn.LayerList([
ConvBn(
num_channels_in=32,
num_channels_out=out_channel,
kernel_size=(21, 11),
stride=(2, 1),
padding=(10, 5),
act='brelu') for i in range(num_stacks - 1)
])
# conv output feat_dim
output_height = (feat_size - 1) // 2 + 1
for i in range(self.num_stacks - 1):
output_height = (output_height - 1) // 2 + 1
self.output_height = out_channel * output_height
def forward(self, x, x_len):
"""
x: shape [B, C, D, T]
x_len : shape [B]
"""
x, x_len = self.conv_in(x, x_len)
for i, conv in enumerate(self.conv_stack):
x, x_len = conv(x, x_len)
return x, x_len
class RNNCell(nn.RNNCellBase):
r"""
Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it
computes the outputs and updates states.
The formula used is as follows:
.. math::
h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh})
y_{t} & = h_{t}
where :math:`act` is for :attr:`activation`.
"""
def __init__(self,
hidden_size,
activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = self.create_parameter(
(hidden_size, ),
bias_ih_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.bias_hh = self.create_parameter(
(hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
if activation not in ["tanh", "relu", "brelu"]:
raise ValueError(
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu
if activation == 'brelu':
self._activation_fn = brelu
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_h = states
i2h = inputs
if self.bias_ih is not None:
i2h += self.bias_ih
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
h = self._activation_fn(i2h + h2h)
return h, h
@property
def state_shape(self):
return (self.hidden_size, )
class GRUCellShare(nn.RNNCellBase):
r"""
Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
it computes the outputs and updates states.
The formula for GRU used is as follows:
.. math::
r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
\widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
y_{t} & = h_{t}
where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise
multiplication operator.
"""
def __init__(self,
input_size,
hidden_size,
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
bias_hh_attr=None,
name=None):
super().__init__()
std = 1.0 / math.sqrt(hidden_size)
self.weight_hh = self.create_parameter(
(3 * hidden_size, hidden_size),
weight_hh_attr,
default_initializer=I.Uniform(-std, std))
self.bias_ih = self.create_parameter(
(3 * hidden_size, ),
bias_ih_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.bias_hh = self.create_parameter(
(3 * hidden_size, ),
bias_hh_attr,
is_bias=True,
default_initializer=I.Uniform(-std, std))
self.hidden_size = hidden_size
self.input_size = input_size
self._gate_activation = F.sigmoid
self._activation = paddle.tanh
def forward(self, inputs, states=None):
if states is None:
states = self.get_initial_states(inputs, self.state_shape)
pre_hidden = states
x_gates = inputs
if self.bias_ih is not None:
x_gates = x_gates + self.bias_ih
h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h_gates = h_gates + self.bias_hh
x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1)
h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1)
r = self._gate_activation(x_r + h_r)
z = self._gate_activation(x_z + h_z)
c = self._activation(x_c + r * h_c) # apply reset gate after mm
h = (pre_hidden - c) * z + c
return h, h
@property
def state_shape(self):
r"""
The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch
size would be automatically inserted into shape). The shape corresponds
to the shape of :math:`h_{t-1}`.
"""
return (self.hidden_size, )
class BiRNNWithBN(nn.Layer):
"""Bidirectonal simple rnn layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer parameters.
:type name: string
:param size: Dimension of RNN cells.
:type size: int
:param share_weights: Whether to share input-hidden weights between
forward and backward directional RNNs.
:type share_weights: bool
:return: Bidirectional simple rnn layer.
:rtype: Variable
"""
def __init__(self, i_size, h_size, share_weights):
super().__init__()
self.share_weights = share_weights
self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32))
if self.share_weights:
#input-hidden weights shared between bi-directional rnn.
self.fw_fc = nn.Linear(i_size, h_size)
# batch norm is only performed on input-state projection
self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
self.bw_fc = self.fw_fc
self.bw_bn = self.fw_bn
else:
self.fw_fc = nn.Linear(i_size, h_size)
self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
self.bw_fc = nn.Linear(i_size, h_size)
self.bw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
self.fw_cell = RNNCell(hidden_size=h_size, activation='relu')
self.bw_cell = RNNCell(
hidden_size=h_size,
activation='relu', )
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class BiGRUWithBN(nn.Layer):
"""Bidirectonal gru layer with sequence-wise batch normalization.
The batch normalization is only performed on input-state weights.
:param name: Name of the layer.
:type name: string
:param input: Input layer.
:type input: Variable
:param size: Dimension of GRU cells.
:type size: int
:param act: Activation type.
:type act: string
:return: Bidirectional GRU layer.
:rtype: Variable
"""
def __init__(self, i_size, h_size, act):
super().__init__()
hidden_size = h_size * 3
self.fw_fc = nn.Linear(i_size, hidden_size)
self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
self.bw_fc = nn.Linear(i_size, hidden_size)
self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
self.fw_rnn = nn.RNN(
self.fw_cell, is_reverse=False, time_major=False) #[B, T, D]
self.bw_rnn = nn.RNN(
self.fw_cell, is_reverse=True, time_major=False) #[B, T, D]
def forward(self, x, x_len):
# x, shape [B, T, D]
fw_x = self.fw_bn(self.fw_fc(x))
bw_x = self.bw_bn(self.bw_fc(x))
fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
x = paddle.concat([fw_x, bw_x], axis=-1)
return x, x_len
class RNNStack(nn.Layer):
    """RNN group with stacked bidirectional simple RNN or GRU layers.
    :param i_size: Dimension of the input features.
    :type i_size: int
    :param h_size: Dimension of RNN cells in each layer.
    :type h_size: int
    :param num_stacks: Number of stacked rnn layers.
    :type num_stacks: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward directional RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    """
def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):
super().__init__()
self.rnn_stacks = nn.LayerList()
for i in range(num_stacks):
if use_gru:
                # default GRU activation is tanh; relu is requested here
self.rnn_stacks.append(
BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu"))
else:
self.rnn_stacks.append(
BiRNNWithBN(
i_size=i_size,
h_size=h_size,
share_weights=share_rnn_weights))
i_size = h_size * 2
def forward(self, x, x_len):
"""
x: shape [B, T, D]
        x_len: shape [B]
"""
for i, rnn in enumerate(self.rnn_stacks):
x, x_len = rnn(x, x_len)
return x, x_len
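# A stacking sketch: the first layer consumes i_size features and every
# bidirectional layer outputs 2 * h_size, which feeds the next layer. All
# sizes below are illustrative.
def _demo_rnn_stack():
    stack = RNNStack(
        i_size=161,
        h_size=256,
        num_stacks=3,
        use_gru=True,
        share_rnn_weights=False)
    x = paddle.randn([4, 20, 161])  # [B, T, D]
    x_len = paddle.full([4], 20, dtype='int64')
    y, y_len = stack(x, x_len)  # y: [B, T, 2 * 256]
    return y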
class DeepSpeech2(nn.Layer):
    """The DeepSpeech2 network structure.
    :param feat_size: Dimension of the input audio feature (e.g. 161 for a
                      linear spectrogram).
    :type feat_size: int
    :param dict_size: Dictionary size for tokenized transcription.
    :type dict_size: int
    :param num_conv_layers: Number of stacking convolution layers.
    :type num_conv_layers: int
    :param num_rnn_layers: Number of stacking RNN layers.
    :type num_rnn_layers: int
    :param rnn_size: RNN layer size (dimension of RNN cells).
    :type rnn_size: int
    :param use_gru: Use gru if set True. Use simple rnn if set False.
    :type use_gru: bool
    :param share_rnn_weights: Whether to share input-hidden weights between
                              forward and backward direction RNNs.
                              It is only available when use_gru=False.
    :type share_rnn_weights: bool
    """
def __init__(self,
feat_size,
dict_size,
num_conv_layers=2,
num_rnn_layers=3,
rnn_size=256,
use_gru=False,
share_rnn_weights=True):
super().__init__()
self.feat_size = feat_size # 161 for linear
self.dict_size = dict_size
self.conv = ConvStack(feat_size, num_conv_layers)
i_size = self.conv.output_height # H after conv stack
self.rnn = RNNStack(
i_size=i_size,
h_size=rnn_size,
num_stacks=num_rnn_layers,
use_gru=use_gru,
share_rnn_weights=share_rnn_weights)
self.fc = nn.Linear(rnn_size * 2, dict_size + 1)
def predict(self, audio, audio_len):
# [B, D, T] -> [B, C=1, D, T]
audio = audio.unsqueeze(1)
# convolution group
x, audio_len = self.conv(audio, audio_len)
# convert data from convolution feature map to sequence of vectors
B, C, D, T = paddle.shape(x)
x = x.transpose([0, 3, 1, 2]) #[B, T, C, D]
x = x.reshape([B, T, C * D]) #[B, T, C*D]
# remove padding part
x, audio_len = self.rnn(x, audio_len) #[B, T, D]
logits = self.fc(x) #[B, T, V + 1]
        # the CTC decoder needs probs, not log_probs
        probs = F.softmax(logits, axis=-1)
return logits, probs
@paddle.no_grad()
def infer(self, audio, audio_len):
_, probs = self.predict(audio, audio_len)
return probs
def forward(self, audio, text, audio_len, text_len):
"""
audio: shape [B, D, T]
text: shape [B, T]
audio_len: shape [B]
text_len: shape [B]
"""
        logits, _ = self.predict(audio, audio_len)
        return logits
class DeepSpeechLoss(nn.Layer):
def __init__(self, vocab_size):
super().__init__()
self.loss = nn.CTCLoss(blank=vocab_size, reduction='none')
def forward(self, logits, text, audio_len, text_len):
        # warp-ctc applies softmax to the activations internally
        # warp-ctc needs activations with shape [T, B, V + 1]
logits = logits.transpose([1, 0, 2])
ctc_loss = self.loss(logits, text, audio_len, text_len)
ctc_loss /= text_len # norm_by_times
ctc_loss = ctc_loss.sum()
return ctc_loss
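# A minimal inference sketch for the model above; the feature and dictionary
# sizes are illustrative. Note that for training, the CTC input lengths must
# match the conv-downsampled time dimension of the logits.
def _demo_deepspeech2_infer():
    model = DeepSpeech2(feat_size=161, dict_size=28)
    audio = paddle.randn([2, 161, 100])  # [B, D, T] spectrogram batch
    audio_len = paddle.to_tensor([100, 80], dtype='int64')
    probs = model.infer(audio, audio_len)  # [B, T', 29] (dict_size + 1)
    return probs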

@ -0,0 +1,279 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import logging
from pathlib import Path
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from tensorboardX import SummaryWriter
from collections import defaultdict
import parakeet
from parakeet.utils import checkpoint, mp_tools
__all__ = ["ExperimentBase"]
class ExperimentBase(object):
"""
An experiment template in order to structure the training code and take
care of saving, loading, logging, visualization stuffs. It's intended to
be flexible and simple.
So it only handles output directory (create directory for the output,
create a checkpoint directory, dump the config in use and create
visualizer and logger) in a standard way without enforcing any
input-output protocols to the model and dataloader. It leaves the main
part for the user to implement their own (setup the model, criterion,
optimizer, define a training step, define a validation function and
customize all the text and visual logs).
It does not save too much boilerplate code. The users still have to write
the forward/backward/update mannually, but they are free to add
non-standard behaviors if needed.
We have some conventions to follow.
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
``valid_loader``, ``config`` and ``args`` attributes.
2. The config should have a ``training`` field, which has
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
used as the trigger to invoke validation, checkpointing and stop of the
experiment.
3. There are four methods, namely ``train_batch``, ``valid``,
``setup_model`` and ``setup_dataloader`` that should be implemented.
Feel free to add/overwrite other methods and standalone functions if you
need.
Parameters
----------
config: yacs.config.CfgNode
The configuration used for the experiment.
args: argparse.Namespace
The parsed command line arguments.
Examples
--------
>>> def main_sp(config, args):
>>> exp = Experiment(config, args)
>>> exp.setup()
>>> exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>> config.merge_from_file(args.config)
>>> if args.opts:
>>> config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.nprocs > 1 and args.device == "gpu":
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
>>> else:
>>> main_sp(config, args)
"""
def __init__(self, config, args):
self.config = config
self.args = args
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
if self.parallel:
self.init_parallel()
self.setup_output_dir()
self.dump_config()
self.setup_visualizer()
self.setup_logger()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
@property
def parallel(self):
"""A flag indicating whether the experiment should run with
multiprocessing.
"""
return self.args.device == "gpu" and self.args.nprocs > 1
def init_parallel(self):
"""Init environment for multiprocess training.
"""
dist.init_parallel_env()
def save(self):
"""Save checkpoint (model parameters and optimizer states).
"""
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer)
def load_or_resume(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
iteration = checkpoint.load_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration
def read_batch(self):
"""Read a batch from the train_loader.
Returns
-------
List[Tensor]
A batch.
"""
try:
batch = next(self.iterator)
except StopIteration:
self.new_epoch()
batch = next(self.iterator)
return batch
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.
"""
self.epoch += 1
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.iterator = iter(self.train_loader)
def train(self):
"""The training process.
        It includes forward/backward/update and periodic validation and
        saving.
"""
self.new_epoch()
while self.iteration < self.config.training.max_iteration:
self.iteration += 1
self.train_batch()
if self.iteration % self.config.training.valid_interval == 0:
self.valid()
if self.iteration % self.config.training.save_interval == 0:
self.save()
def run(self):
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
self.load_or_resume()
try:
self.train()
except KeyboardInterrupt:
self.save()
exit(-1)
@mp_tools.rank_zero_only
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
@mp_tools.rank_zero_only
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir
@mp_tools.rank_zero_only
def setup_visualizer(self):
"""Initialize a visualizer to log the experiment.
The visual log is saved in the output directory.
Notes
------
        Only the main process has a visualizer. Using multiple visualizers
        from multiple processes to write to the same log file may cause
        unexpected behavior.
"""
# visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir))
self.visualizer = visualizer
def setup_logger(self):
"""Initialize a text logger to log the experiment.
        Each process has its own text logger. Logging messages are written to
        standard output and to a text file named ``worker_n.log`` in the
        output directory, where ``n`` is the rank of the process.
"""
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
logger.addHandler(logging.StreamHandler())
log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
logger.addHandler(logging.FileHandler(str(log_file)))
self.logger = logger
@mp_tools.rank_zero_only
def dump_config(self):
"""Save the configuration used for this experiment.
        It is saved to ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
print(self.config, file=f)
def train_batch(self):
"""The training loop. A subclass should implement this method.
"""
raise NotImplementedError("train_batch should be implemented.")
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
"""The validation. A subclass should implement this method.
"""
raise NotImplementedError("valid should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.
"""
raise NotImplementedError("setup_model should be implemented.")
def setup_dataloader(self):
"""Setup training dataloader and validation dataloader. A subclass
should implement this method.
"""
raise NotImplementedError("setup_dataloader should be implemented.")

@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from training.trainer import *

@ -0,0 +1,279 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import logging
from pathlib import Path
import numpy as np
from collections import defaultdict
import paddle
from paddle import distributed as dist
from tensorboardX import SummaryWriter
from utils import checkpoint
from utils import mp_tools
__all__ = ["Trainer"]
class Trainer():
"""
An experiment template in order to structure the training code and take
care of saving, loading, logging, visualization stuffs. It's intended to
be flexible and simple.
So it only handles output directory (create directory for the output,
create a checkpoint directory, dump the config in use and create
visualizer and logger) in a standard way without enforcing any
input-output protocols to the model and dataloader. It leaves the main
part for the user to implement their own (setup the model, criterion,
optimizer, define a training step, define a validation function and
customize all the text and visual logs).
It does not save too much boilerplate code. The users still have to write
the forward/backward/update mannually, but they are free to add
non-standard behaviors if needed.
We have some conventions to follow.
1. Experiment should have ``model``, ``optimizer``, ``train_loader`` and
``valid_loader``, ``config`` and ``args`` attributes.
2. The config should have a ``training`` field, which has
``valid_interval``, ``save_interval`` and ``max_iteration`` keys. It is
used as the trigger to invoke validation, checkpointing and stop of the
experiment.
3. There are four methods, namely ``train_batch``, ``valid``,
``setup_model`` and ``setup_dataloader`` that should be implemented.
Feel free to add/overwrite other methods and standalone functions if you
need.
Parameters
----------
config: yacs.config.CfgNode
The configuration used for the experiment.
args: argparse.Namespace
The parsed command line arguments.
Examples
--------
>>> def main_sp(config, args):
>>> exp = Trainer(config, args)
>>> exp.setup()
>>> exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>> config.merge_from_file(args.config)
>>> if args.opts:
>>> config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.nprocs > 1 and args.device == "gpu":
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
>>> else:
>>> main_sp(config, args)
"""
def __init__(self, config, args):
self.config = config
self.args = args
def setup(self):
"""Setup the experiment.
"""
paddle.set_device(self.args.device)
if self.parallel:
self.init_parallel()
self.setup_output_dir()
self.dump_config()
self.setup_visualizer()
self.setup_logger()
self.setup_checkpointer()
self.setup_dataloader()
self.setup_model()
self.iteration = 0
self.epoch = 0
@property
def parallel(self):
"""A flag indicating whether the experiment should run with
multiprocessing.
"""
return self.args.device == "gpu" and self.args.nprocs > 1
def init_parallel(self):
"""Init environment for multiprocess training.
"""
dist.init_parallel_env()
def save(self):
"""Save checkpoint (model parameters and optimizer states).
"""
checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
self.model, self.optimizer)
def resume_or_load(self):
"""Resume from latest checkpoint at checkpoints in the output
directory or load a specified checkpoint.
If ``args.checkpoint_path`` is not None, load the checkpoint, else
resume training.
"""
iteration = checkpoint.load_parameters(
self.model,
self.optimizer,
checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path)
self.iteration = iteration
def read_batch(self):
"""Read a batch from the train_loader.
Returns
-------
List[Tensor]
A batch.
"""
try:
batch = next(self.iterator)
except StopIteration:
self.new_epoch()
batch = next(self.iterator)
return batch
def new_epoch(self):
"""Reset the train loader and increment ``epoch``.
"""
self.epoch += 1
if self.parallel:
self.train_loader.batch_sampler.set_epoch(self.epoch)
self.iterator = iter(self.train_loader)
def train(self):
"""The training process.
        It includes forward/backward/update and periodic validation and
        saving.
"""
self.new_epoch()
while self.iteration < self.config.training.max_iteration:
self.iteration += 1
self.train_batch()
if self.iteration % self.config.training.valid_interval == 0:
self.valid()
if self.iteration % self.config.training.save_interval == 0:
self.save()
def run(self):
"""The routine of the experiment after setup. This method is intended
to be used by the user.
"""
self.resume_or_load()
try:
self.train()
except KeyboardInterrupt:
self.save()
exit(-1)
@mp_tools.rank_zero_only
def setup_output_dir(self):
"""Create a directory used for output.
"""
# output dir
output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir
@mp_tools.rank_zero_only
def setup_checkpointer(self):
"""Create a directory used to save checkpoints into.
It is "checkpoints" inside the output directory.
"""
# checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir
@mp_tools.rank_zero_only
def setup_visualizer(self):
"""Initialize a visualizer to log the experiment.
The visual log is saved in the output directory.
Notes
------
        Only the main process has a visualizer. Using multiple visualizers
        from multiple processes to write to the same log file may cause
        unexpected behavior.
"""
# visualizer
visualizer = SummaryWriter(logdir=str(self.output_dir))
self.visualizer = visualizer
def setup_logger(self):
"""Initialize a text logger to log the experiment.
        Each process has its own text logger. Logging messages are written to
        standard output and to a text file named ``worker_n.log`` in the
        output directory, where ``n`` is the rank of the process.
"""
logger = logging.getLogger(__name__)
logger.setLevel("INFO")
logger.addHandler(logging.StreamHandler())
log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
logger.addHandler(logging.FileHandler(str(log_file)))
self.logger = logger
@mp_tools.rank_zero_only
def dump_config(self):
"""Save the configuration used for this experiment.
        It is saved to ``config.yaml`` in the output directory at the
beginning of the experiment.
"""
with open(self.output_dir / "config.yaml", 'wt') as f:
print(self.config, file=f)
def train_batch(self):
"""The training loop. A subclass should implement this method.
"""
raise NotImplementedError("train_batch should be implemented.")
@mp_tools.rank_zero_only
@paddle.no_grad()
def valid(self):
"""The validation. A subclass should implement this method.
"""
raise NotImplementedError("valid should be implemented.")
def setup_model(self):
"""Setup model, criterion and optimizer, etc. A subclass should
implement this method.
"""
raise NotImplementedError("setup_model should be implemented.")
def setup_dataloader(self):
"""Setup training dataloader and validation dataloader. A subclass
should implement this method.
"""
raise NotImplementedError("setup_dataloader should be implemented.")

@ -0,0 +1,135 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import numpy as np
import paddle
from paddle import distributed as dist
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from utils import mp_tools
__all__ = ["load_parameters", "save_parameters"]
def _load_latest_checkpoint(checkpoint_dir: str) -> int:
"""Get the iteration number corresponding to the latest saved checkpoint.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
Returns:
int: the latest iteration number.
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.isfile(checkpoint_record):
return 0
# Fetch the latest checkpoint index.
with open(checkpoint_record, "rt") as handle:
latest_checkpoint = handle.readlines()[-1].strip()
step = latest_checkpoint.split(":")[-1]
iteration = int(step.split("-")[-1])
return iteration
def _save_checkpoint(checkpoint_dir: str, iteration: int):
"""Save the iteration number of the latest model to be checkpointed.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_record, "a+") as handle:
handle.write("model_checkpoint_path:step-{}\n".format(iteration))
def load_parameters(model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None):
"""Load a specific model checkpoint from disk.
Args:
model (Layer): model to load parameters.
optimizer (Optimizer, optional): optimizer to load states if needed.
Defaults to None.
checkpoint_dir (str, optional): the directory where checkpoint is saved.
checkpoint_path (str, optional): if specified, load the checkpoint
stored in the checkpoint_path and the argument 'checkpoint_dir' will
be ignored. Defaults to None.
Returns:
iteration (int): number of iterations that the loaded checkpoint has
been trained.
"""
if checkpoint_path is not None:
iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
elif checkpoint_dir is not None:
iteration = _load_latest_checkpoint(checkpoint_dir)
if iteration == 0:
return iteration
checkpoint_path = os.path.join(checkpoint_dir,
"step-{}".format(iteration))
else:
raise ValueError(
"At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
)
rank = dist.get_rank()
params_path = checkpoint_path + ".pdparams"
model_dict = paddle.load(params_path)
model.set_state_dict(model_dict)
print(
"[checkpoint] Rank {}: loaded model from {}".format(rank, params_path))
optimizer_path = checkpoint_path + ".pdopt"
if optimizer and os.path.isfile(optimizer_path):
optimizer_dict = paddle.load(optimizer_path)
optimizer.set_state_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
rank, optimizer_path))
return iteration
@mp_tools.rank_zero_only
def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
model (Layer): model to be checkpointed.
optimizer (Optimizer, optional): optimizer to be checkpointed.
Defaults to None.
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
model_dict = model.state_dict()
params_path = checkpoint_path + ".pdparams"
paddle.save(model_dict, params_path)
print("[checkpoint] Saved model to {}".format(params_path))
if optimizer:
opt_dict = optimizer.state_dict()
optimizer_path = checkpoint_path + ".pdopt"
paddle.save(opt_dict, optimizer_path)
print("[checkpoint] Saved optimzier state to {}".format(optimizer_path))
_save_checkpoint(checkpoint_dir, iteration)
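# A minimal save/load round trip; the directory name and the linear model
# below are placeholders for illustration only.
def _demo_checkpoint_roundtrip(checkpoint_dir="checkpoints_demo"):
    os.makedirs(checkpoint_dir, exist_ok=True)
    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters())
    save_parameters(checkpoint_dir, iteration=100, model=model,
                    optimizer=optimizer)
    # restores the latest step recorded in the "checkpoint" file (100 here)
    iteration = load_parameters(
        model, optimizer, checkpoint_dir=checkpoint_dir)
    return iteration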

@ -14,9 +14,10 @@
"""This module provides functions to calculate error rate in different level. """This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level. e.g. wer for word-level, cer for char-level.
""" """
import numpy as np import numpy as np
__all__ = ['word_errors', 'char_errors', 'wer', 'cer']
def _levenshtein_distance(ref, hyp): def _levenshtein_distance(ref, hyp):
"""Levenshtein distance is a string metric for measuring the difference """Levenshtein distance is a string metric for measuring the difference

@ -38,12 +38,12 @@ def check_version():
Log error and exit when the installed version of paddlepaddle is Log error and exit when the installed version of paddlepaddle is
not satisfied. not satisfied.
""" """
err = "PaddlePaddle version 1.6 or higher is required, " \ err = "PaddlePaddle version 2.0.0 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \ "or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \ "Please make sure the version is good with your code." \
try: try:
fluid.require_version('1.6.0') fluid.require_version('2.0.0')
except Exception as e: except Exception as e:
print(err) print(err)
sys.exit(1) sys.exit(1)

@ -0,0 +1,32 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import distributed as dist
from functools import wraps
__all__ = ["rank_zero_only"]
def rank_zero_only(func):
rank = dist.get_rank()
@wraps(func)
def wrapper(*args, **kwargs):
if rank != 0:
return
result = func(*args, **kwargs)
return result
return wrapper
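# A minimal usage sketch: in a multi-process run only rank 0 executes the
# decorated function; other ranks return None. The function name below is
# illustrative.
def _demo_rank_zero_only():
    @rank_zero_only
    def report(msg):
        print("[rank 0 only]", msg)

    report("saving checkpoint")  # printed once, not once per process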