From c2ccb11ba09dd2e46e061e1673fb11d8d7ed1afd Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Mon, 8 Feb 2021 09:32:07 +0000
Subject: [PATCH] export feature size

add trainer and utils
add setup model and dataloader
update travis using bionic dist
---
 .travis.yml                                |   2 +-
 data_utils/data.py                         |   4 +-
 data_utils/dataset.py                      |   4 +-
 data_utils/featurizer/audio_featurizer.py  |  17 +-
 data_utils/featurizer/speech_featurizer.py |   9 +
 infer.py                                   |   9 +-
 model_utils/model.py                       | 354 ++++------
 model_utils/network.py                     | 772 ++++++++++++---------
 model_utils/network2.py                    | 555 ---------------
 model_utils/trainer.py                     | 279 ++++++++
 training/__init__.py                       |  15 +
 training/trainer.py                        | 279 ++++++++
 utils/checkpoint.py                        | 135 ++++
 utils/error_rate.py                        |   3 +-
 {model_utils => utils}/model_check.py      |   4 +-
 utils/mp_tools.py                          |  32 +
 16 files changed, 1360 insertions(+), 1113 deletions(-)
 delete mode 100644 model_utils/network2.py
 create mode 100644 model_utils/trainer.py
 create mode 100644 training/__init__.py
 create mode 100644 training/trainer.py
 create mode 100644 utils/checkpoint.py
 rename {model_utils => utils}/model_check.py (94%)
 create mode 100644 utils/mp_tools.py

diff --git a/.travis.yml b/.travis.yml
index 6ca50d954..b2af6a4c4 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,7 @@
 language: cpp
 cache: ccache
 sudo: required
-dist: xenial
+dist: bionic
 services:
 - docker
 os:
diff --git a/data_utils/data.py b/data_utils/data.py
index 245daf5c3..30819f578 100644
--- a/data_utils/data.py
+++ b/data_utils/data.py
@@ -188,8 +188,6 @@ class DataGenerator():
             max_duration=self._max_duration,
             min_duration=self._min_duration)
-
-
         # sort (by duration) or batch-wise shuffle the manifest
         if self._epoch == 0 and sortagrad:
             manifest.sort(key=lambda x: x["duration"])
@@ -365,7 +363,7 @@ class DataGenerator():
         """
         manifest.sort(key=lambda x: x["duration"])
         shift_len = self._rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
         self._rng.shuffle(batch_manifest)
         batch_manifest = [item for batch in batch_manifest for item in batch]
         if not clipped:
diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 67c1b57ee..3dcf72030 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -211,7 +211,7 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         rng = np.random.RandomState(self.epoch)
         manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
         rng.shuffle(batch_manifest)
         batch_manifest = [item for batch in batch_manifest for item in batch]
         if not clipped:
@@ -347,7 +347,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         rng = np.random.RandomState(self.epoch)
         manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(* [iter(manifest[shift_len:])] * batch_size))
+        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
         rng.shuffle(batch_manifest)
         batch_manifest = [item for batch in batch_manifest for item in batch]
         if not clipped:
diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py
index 0afd19870..7e336969d 100644
--- a/data_utils/featurizer/audio_featurizer.py
+++ b/data_utils/featurizer/audio_featurizer.py
@@ -63,6 +63,7 @@ class AudioFeaturizer(object):
         self._target_sample_rate = target_sample_rate
         self._use_dB_normalization = use_dB_normalization
         self._target_dB = target_dB
+        self._fft_point = None
 
     def featurize(self,
                   audio_segment,
@@ -98,6 +99,19 @@ class AudioFeaturizer(object):
         return self._compute_specgram(audio_segment.samples,
                                       audio_segment.sample_rate)
 
+    @property
+    def feature_size(self):
+        """Audio feature size."""
+        if self._specgram_type == 'linear':
+            fft_point = self._window_ms if self._fft_point is None else self._fft_point
+            return int(fft_point * (self._target_sample_rate / 1000) / 2 + 1)
+        elif self._specgram_type == 'mfcc':
+            # mfcc, delta, delta-delta
+            return 13 * 3
+        else:
+            raise ValueError("Unknown specgram_type %s. "
+                             "Supported values: linear, mfcc." % self._specgram_type)
+
     def _compute_specgram(self, samples, sample_rate):
         """Extract various audio features."""
         if self._specgram_type == 'linear':
@@ -150,7 +164,8 @@ class AudioFeaturizer(object):
             windows[:, 1] == samples[stride_size:(stride_size + window_size)])
         # window weighting, squared Fast Fourier Transform (fft), scaling
         weighting = np.hanning(window_size)[:, None]
-        fft = np.fft.rfft(windows * weighting, axis=0)
+        # https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html
+        fft = np.fft.rfft(windows * weighting, n=None, axis=0)
         fft = np.absolute(fft)
         fft = fft**2
         scale = np.sum(weighting**2) * sample_rate
diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py
index 2e1424fa4..333a40cd1 100644
--- a/data_utils/featurizer/speech_featurizer.py
+++ b/data_utils/featurizer/speech_featurizer.py
@@ -106,3 +106,12 @@ class SpeechFeaturizer(object):
         :rtype: list
         """
         return self._text_featurizer.vocab_list
+
+    @property
+    def feature_size(self):
+        """Return the audio feature size.
+
+        :return: audio feature size.
+ :rtype: int + """ + return self._audio_featurizer.feature_size \ No newline at end of file diff --git a/infer.py b/infer.py index 11a4ad7ab..ec1bb8ace 100644 --- a/infer.py +++ b/infer.py @@ -16,13 +16,12 @@ import sys import argparse import functools -import paddle.fluid as fluid +from model_utils.model_check import check_cuda, check_version +from utils.utility import add_arguments, print_arguments +from utils.error_rate import wer, cer from data_utils.data import DataGenerator from data_utils.dataset import create_dataloader from model_utils.model import DeepSpeech2Model -from model_utils.model_check import check_cuda, check_version -from utils.error_rate import wer, cer -from utils.utility import add_arguments, print_arguments parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) @@ -132,7 +131,7 @@ def infer(): dict_size=batch_reader.dataset.vocab_size, num_conv_layers=args.num_conv_layers, num_rnn_layers=args.num_rnn_layers, - #rnn_size=1024, + rnn_size=args.rnn_layer_size, use_gru=args.use_gru, share_rnn_weights=args.share_rnn_weights, ) diff --git a/model_utils/model.py b/model_utils/model.py index 16d8b0b96..28c218dd9 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -24,179 +24,177 @@ import collections import multiprocessing import numpy as np from distutils.dir_util import mkpath + import paddle.fluid as fluid -from paddle.io import DataLoader -import paddle.fluid.compiler as compiler + +from training import Trainer +from model_utils.network import DeepSpeech2 +from model_utils.network import DeepSpeech2Loss + from decoders.swig_wrapper import Scorer from decoders.swig_wrapper import ctc_greedy_decoder from decoders.swig_wrapper import ctc_beam_search_decoder_batch -from model_utils.network import deep_speech_v2_network logging.basicConfig( format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') -class DeepSpeech2Model(object): - """DeepSpeech2Model class. - - :param vocab_size: Decoding vocabulary size. - :type vocab_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_layer_size: RNN layer size (number of RNN cells). - :type rnn_layer_size: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward directional RNNs.Notice that - for GRU, weight sharing is not supported. - :type share_rnn_weights: bool - :param place: Program running place. - :type place: CPUPlace or CUDAPlace - :param init_from_pretrained_model: Pretrained model path. If None, will train - from stratch. - :type init_from_pretrained_model: string|None - :param output_model_dir: Output model directory. If None, output to current directory. 
-    :type output_model_dir: string|None
-    """
-
-    def __init__(self,
-                 vocab_size,
-                 num_conv_layers,
-                 num_rnn_layers,
-                 rnn_layer_size,
-                 use_gru=False,
-                 share_rnn_weights=True,
-                 place=fluid.CPUPlace(),
-                 init_from_pretrained_model=None,
-                 output_model_dir=None):
-        self._vocab_size = vocab_size
-        self._num_conv_layers = num_conv_layers
-        self._num_rnn_layers = num_rnn_layers
-        self._rnn_layer_size = rnn_layer_size
-        self._use_gru = use_gru
-        self._share_rnn_weights = share_rnn_weights
-        self._place = place
-        self._init_from_pretrained_model = init_from_pretrained_model
-        self._output_model_dir = output_model_dir
-        self._ext_scorer = None
-        self.logger = logging.getLogger("")
-        self.logger.setLevel(level=logging.INFO)
-
-    def create_network(self, is_infer=False):
-        """Create data layers and model network.
-        :param is_training: Whether to create a network for training.
-        :type is_training: bool
-        :return reader: Reader for input.
-        :rtype reader: read generater
-        :return log_probs: An output unnormalized log probability layer.
-        :rtype lig_probs: Varable
-        :return loss: A ctc loss layer.
-        :rtype loss: Variable
+class SpeechCollator():
+    def __init__(self, padding_to=-1):
+        """
-
-        if not is_infer:
-            reader = DataLoader.from_generator(
-                feed_list=inputs,
-                capacity=64,
-                iterable=False,
-                use_double_buffer=True)
+        Padding audio features with zeros to make them have the same shape (or
+        a user-defined shape) within one batch.
+
+        If ``padding_to`` is -1, the maximum shape in the batch will be used
+        as the target shape for padding. Otherwise, ``padding_to`` will be the
+        target shape (only refers to the second axis).
+        """
+        self._padding_to = padding_to
+
+    def __call__(self, batch):
+        new_batch = []
+        # get target shape
+        max_length = max([audio.shape[1] for audio, _ in batch])
+        if self._padding_to != -1:
+            if self._padding_to < max_length:
+                raise ValueError("If padding_to is not -1, it should be larger "
+                                 "than any instance's shape in the batch")
+            max_length = self._padding_to
+        max_text_length = max([len(text) for _, text in batch])
+        # padding
+        padded_audios = []
+        audio_lens = []
+        texts, text_lens = [], []
+        for audio, text in batch:
+            # audio
+            padded_audio = np.zeros([audio.shape[0], max_length])
+            padded_audio[:, :audio.shape[1]] = audio
+            padded_audios.append(padded_audio)
+            audio_lens.append(audio.shape[1])
+            # text
+            padded_text = np.zeros([max_text_length])
+            padded_text[:len(text)] = text
+            texts.append(padded_text)
+            text_lens.append(len(text))
+
+        padded_audios = np.array(padded_audios).astype('float32')
+        audio_lens = np.array(audio_lens).astype('int64')
+        texts = np.array(texts).astype('int32')
+        text_lens = np.array(text_lens).astype('int64')
+        return padded_audios, texts, audio_lens, text_lens
+
+
+class DeepSpeech2Trainer(Trainer):
+    def __init__(self):
+        self._ext_scorer = None
 
-        (audio_data, text_data, seq_len_data, masks) = inputs
+    def setup_dataloader(self):
+        config = self.config
+
+        train_dataset = DeepSpeech2Dataset(
+            config.data.train_manifest_path,
+            config.data.vocab_filepath,
+            config.data.mean_std_filepath,
+            augmentation_config=config.data.augmentation_config,
+            max_duration=config.data.max_duration,
+            min_duration=config.data.min_duration,
+            stride_ms=config.data.stride_ms,
+            window_ms=config.data.window_ms,
+            max_freq=config.data.max_freq,
+            specgram_type=config.data.specgram_type,
+            use_dB_normalization=config.data.use_dB_normalization,
+            random_seed=config.data.random_seed,
+            keep_transcription_text=False)
+
+        dev_dataset = DeepSpeech2Dataset(
config.data.dev_manifest_path, + config.data.vocab_filepath, + config.data.mean_std_filepath, + augmentation_config=config.data.augmentation_config, + max_duration=config.data.max_duration, + min_duration=config.data.min_duration, + stride_ms=config.data.stride_ms, + window_ms=config.data.window_ms, + max_freq=config.data.max_freq, + specgram_type=config.data.specgram_type, + use_dB_normalization=config.data.use_dB_normalization, + random_seed=config.data.random_seed, + keep_transcription_text=False) + + if self.parallel: + batch_sampler = DeepSpeech2DistributedBatchSampler( + train_dataset, + batch_size=config.data.batch_size, + num_replicas=None, + rank=None, + shuffle=True, + drop_last=True, + sortagrad=config.data.sortagrad, + shuffle_method=config.data.shuffle_method) else: - audio_data = fluid.data( - name='audio_data', - shape=[None, 161, None], - dtype='float32', - lod_level=0) - seq_len_data = fluid.data( - name='seq_len_data', - shape=[None, 1], - dtype='int64', - lod_level=0) - masks = fluid.data( - name='masks', - shape=[None, 32, 81, None], - dtype='float32', - lod_level=0) - text_data = None - reader = fluid.DataFeeder([audio_data, seq_len_data, masks], - self._place) - - log_probs, loss = deep_speech_v2_network( - audio_data=audio_data, - text_data=text_data, - seq_len_data=seq_len_data, - masks=masks, - dict_size=self._vocab_size, - num_conv_layers=self._num_conv_layers, - num_rnn_layers=self._num_rnn_layers, - rnn_size=self._rnn_layer_size, - use_gru=self._use_gru, - share_rnn_weights=self._share_rnn_weights) - return reader, log_probs, loss - - def init_from_pretrained_model(self, exe, program): - '''Init params from pretrain model. ''' - - assert isinstance(self._init_from_pretrained_model, str) - - if not os.path.exists(self._init_from_pretrained_model): - print(self._init_from_pretrained_model) - raise Warning("The pretrained params do not exist.") - return False - fluid.io.load_params( - exe, - self._init_from_pretrained_model, - main_program=program, - filename="params.pdparams") - - print("finish initing model from pretrained params from %s" % - (self._init_from_pretrained_model)) - - pre_epoch = 0 - dir_name = self._init_from_pretrained_model.split('_') - if len(dir_name) >= 2 and dir_name[-2].endswith('epoch') and dir_name[ - -1].isdigit(): - pre_epoch = int(dir_name[-1]) - - return pre_epoch + 1 - - def save_param(self, exe, program, dirname): - '''Save model params to dirname''' - - assert isinstance(self._output_model_dir, str) - - param_dir = os.path.join(self._output_model_dir) - - if not os.path.exists(param_dir): - os.mkdir(param_dir) - - fluid.io.save_params( - exe, - os.path.join(param_dir, dirname), - main_program=program, - filename="params.pdparams") - print("save parameters at %s" % (os.path.join(param_dir, dirname))) - - return True - - def test(self, exe, dev_batch_reader, test_program, test_reader, - fetch_list): + batch_sampler = DeepSpeech2BatchSampler( + train_dataset, + shuffle=True, + batch_size=config.data.batch_size, + drop_last=True, + sortagrad=config.data.sortagrad, + shuffle_method=config.data.shuffle_method) + + collate_fn = SpeechCollator() + self.train_loader = DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=config.data.num_workers, ) + self.valid_loader = DataLoader( + dev_dataset, + batch_size=config.data.batch_size, + shuffle=False, + drop_last=False, + collate_fn=collate_fn) + self.logger.info("Setup train/valid Dataloader!") + + def setup_model(self): + config = self.config + 
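+        # The data pipeline exports both the input feature dim and the vocab
+        # size (see SpeechFeaturizer.feature_size added above), so the network
+        # no longer hard-codes the 161-bin linear spectrogram height.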
+        model = DeepSpeech2(
+            feat_size=self.train_loader.feature_size,
+            dict_size=self.train_loader.vocab_size,
+            num_conv_layers=config.model.num_conv_layers,
+            num_rnn_layers=config.model.num_rnn_layers,
+            rnn_size=config.model.rnn_layer_size,
+            share_rnn_weights=config.model.share_rnn_weights)
+
+        if self.parallel:
+            model = paddle.DataParallel(model)
+
+        grad_clip = paddle.nn.ClipGradByGlobalNorm(config.training.grad_clip)
+
+        optimizer = paddle.optimizer.Adam(
+            learning_rate=config.training.lr,
+            parameters=model.parameters(),
+            weight_decay=paddle.regularizer.L2Decay(
+                config.training.weight_decay),
+            grad_clip=grad_clip, )
+
+        criterion = DeepSpeech2Loss(self.train_loader.vocab_size)
+
+        self.model = model
+        self.optimizer = optimizer
+        self.criterion = criterion
+        self.logger.info("Setup model/optimizer/criterion!")
+
+    def compute_losses(self, inputs, outputs):
+        pass
+
+    def test(self, test_reader):
         '''Test the model.
 
         :param exe:The executor of program.
         :type exe: Executor
-        :param dev_batch_reader: The reader of test dataa.
-        :type dev_batch_reader: read generator
         :param test_program: The program of test.
         :type test_program: Program
         :param test_reader: Reader of test.
         :type test_reader: Reader
-        :param fetch_list: Fetch list.
-        :type fetch_list: list
         :return: An output unnormalized log probability.
         :rtype: array
         '''
@@ -254,13 +252,6 @@
         :param test_off: Turn off testing.
         :type test_off: bool
         """
-        # prepare model output directory
-        if not os.path.exists(self._output_model_dir):
-            mkpath(self._output_model_dir)
-
-        # adapt the feeding dict according to the network
-        adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
-
         if isinstance(self._place, fluid.CUDAPlace):
             dev_count = fluid.core.get_cuda_device_count()
         else:
@@ -298,16 +289,6 @@
         if self._init_from_pretrained_model:
             pre_epoch = self.init_from_pretrained_model(exe, train_program)
 
-        build_strategy = compiler.BuildStrategy()
-        exec_strategy = fluid.ExecutionStrategy()
-
-        # pass the build_strategy to with_data_parallel API
-        compiled_prog = compiler.CompiledProgram(
-            train_program).with_data_parallel(
-                loss_name=ctc_loss.name,
-                build_strategy=build_strategy,
-                exec_strategy=exec_strategy)
-
         train_reader.set_batch_generator(train_batch_reader)
         test_reader.set_batch_generator(dev_batch_reader)
 
@@ -386,9 +367,6 @@
         infer_program = fluid.Program()
         startup_prog = fluid.Program()
 
-        # adapt the feeding dict according to the network
-        adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
-
         # prepare the network
         with fluid.program_guard(infer_program, startup_prog):
             with fluid.unique_name.guard():
@@ -523,35 +501,3 @@
 
         results = [result[0][1] for result in beam_search_results]
         return results
-
-    def _adapt_feeding_dict(self, feeding_dict):
-        """Adapt feeding dict according to network struct.
-
-        To remove impacts from padding part, we add scale_sub_region layer and
-        sub_seq layer. For sub_seq layer, 'sequence_offset' and
-        'sequence_length' fields are appended. For each scale_sub_region layer
-        'convN_index_range' field is appended.
-
-        :param feeding_dict: Feeding is a map of field name and tuple index
-                             of the data that reader returns.
-        :type feeding_dict: dict|list
-        :return: Adapted feeding dict.
- :rtype: dict|list - """ - adapted_feeding_dict = copy.deepcopy(feeding_dict) - if isinstance(feeding_dict, dict): - adapted_feeding_dict["sequence_offset"] = len(adapted_feeding_dict) - adapted_feeding_dict["sequence_length"] = len(adapted_feeding_dict) - for i in range(self._num_conv_layers): - adapted_feeding_dict["conv%d_index_range" %i] = \ - len(adapted_feeding_dict) - elif isinstance(feeding_dict, list): - adapted_feeding_dict.append("sequence_offset") - adapted_feeding_dict.append("sequence_length") - for i in range(self._num_conv_layers): - adapted_feeding_dict.append("conv%d_index_range" % i) - else: - raise ValueError("Type of feeding_dict is %s, not supported." % - type(feeding_dict)) - - return adapted_feeding_dict diff --git a/model_utils/network.py b/model_utils/network.py index 19f9d887c..83b91fb70 100644 --- a/model_utils/network.py +++ b/model_utils/network.py @@ -12,31 +12,49 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import collections -import paddle.fluid as fluid import numpy as np +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn import initializer as I +__all__ = ['DeepSpeech2', 'DeepSpeech2Loss'] -def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, - padding, act, masks, name): + +def brelu(x, t_min=0.0, t_max=24.0, name=None): + t_min = paddle.to_tensor(t_min) + t_max = paddle.to_tensor(t_max) + return x.maximum(t_min).minimum(t_max) + + +def sequence_mask(x_len, max_len=None, dtype='float32'): + max_len = max_len or x_len.max() + x_len = paddle.unsqueeze(x_len, -1) + row_vector = paddle.arange(max_len) + mask = row_vector < x_len + mask = paddle.cast(mask, dtype) + return mask + + +class ConvBn(nn.Layer): """Convolution layer with batch normalization. - :param input: Input layer. - :type input: Variable - :param filter_size: The x dimension of a filter kernel. Or input a tuple for + :param kernel_size: The x dimension of a filter kernel. Or input a tuple for two image dimension. - :type filter_size: int|tuple|list + :type kernel_size: int|tuple|list :param num_channels_in: Number of input channels. :type num_channels_in: int :param num_channels_out: Number of output channels. :type num_channels_out: int :param stride: The x dimension of the stride. Or input a tuple for two - image dimension. + image dimension. :type stride: int|tuple|list :param padding: The x dimension of the padding. Or input a tuple for two image dimension. :type padding: int|tuple|list - :param act: Activation type. + :param act: Activation type, relu|brelu :type act: string :param masks: Masks data layer to reset padding. :type masks: Variable @@ -44,91 +62,257 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride, :param name: string :return: Batch norm layer after convolution layer. 
:rtype: Variable - + + """ + + def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, + padding, act): + + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + + self.conv = nn.Conv2D( + num_channels_in, + num_channels_out, + kernel_size=kernel_size, + stride=stride, + padding=padding, + weight_attr=None, + bias_attr=None, + data_format='NCHW') + + self.bn = nn.BatchNorm2D( + num_channels_out, + weight_attr=None, + bias_attr=None, + data_format='NCHW') + self.act = paddle.relu if act == 'relu' else brelu + + def forward(self, x, x_len): + """ + x(Tensor): audio, shape [B, C, D, T] + """ + x = self.conv(x) + x = self.bn(x) + x = self.act(x) + + x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] + ) // self.stride[1] + 1 + + # reset padding part to 0 + masks = sequence_mask(x_len) #[B, T] + masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] + x = x.multiply(masks) + + return x, x_len + + +class ConvStack(nn.Layer): + """Convolution group with stacked convolution layers. + + :param feat_size: audio feature dim. + :type feat_size: int + :param num_stacks: Number of stacked convolution layers. + :type num_stacks: int + """ + + def __init__(self, feat_size, num_stacks): + super().__init__() + self.feat_size = feat_size # D + self.num_stacks = num_stacks + + self.filter_size = (41, 11) # [D, T] + self.stride = (2, 3) + self.padding = (20, 5) + self.conv_in = ConvBn( + num_channels_in=1, + num_channels_out=32, + kernel_size=self.filter_size, + stride=self.stride, + padding=self.padding, + act='brelu', ) + + out_channel = 32 + self.conv_stack = nn.LayerList([ + ConvBn( + num_channels_in=32, + num_channels_out=out_channel, + kernel_size=(21, 11), + stride=(2, 1), + padding=(10, 5), + act='brelu') for i in range(num_stacks - 1) + ]) + + # conv output feat_dim + output_height = (feat_size - 1) // 2 + 1 + for i in range(self.num_stacks - 1): + output_height = (output_height - 1) // 2 + 1 + self.output_height = out_channel * output_height + + def forward(self, x, x_len): + """ + x: shape [B, C, D, T] + x_len : shape [B] + """ + x, x_len = self.conv_in(x, x_len) + for i, conv in enumerate(self.conv_stack): + x, x_len = conv(x, x_len) + return x, x_len + + +class RNNCell(nn.RNNCellBase): + r""" + Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it + computes the outputs and updates states. + The formula used is as follows: + .. math:: + h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) + y_{t} & = h_{t} + + where :math:`act` is for :attr:`activation`. """ - conv_layer = fluid.layers.conv2d( - input=input, - num_filters=num_channels_out, - filter_size=filter_size, - stride=stride, - padding=padding, - param_attr=fluid.ParamAttr(name=name + '_conv2d_weight'), - act=None, - bias_attr=False) - - batch_norm = fluid.layers.batch_norm( - input=conv_layer, - act=act, - param_attr=fluid.ParamAttr(name=name + '_batch_norm_weight'), - bias_attr=fluid.ParamAttr(name=name + '_batch_norm_bias'), - moving_mean_name=name + '_batch_norm_moving_mean', - moving_variance_name=name + '_batch_norm_moving_variance') - - # reset padding part to 0 - padding_reset = fluid.layers.elementwise_mul(batch_norm, masks) - return padding_reset - - -class RNNCell(fluid.layers.RNNCell): - """A simple rnn cell.""" def __init__(self, hidden_size, - param_attr=None, - bias_attr=None, - hidden_activation=None, - activation=None, - dtype="float32", - name="RNNCell"): - """Initialize simple rnn cell. 
-
-    :param hidden_size: Dimension of RNN cells.
-    :type hidden_size: int
-    :param param_attr: Parameter properties of hidden layer weights that
-        can be learned
-    :type param_attr: ParamAttr
-    :param bias_attr: Bias properties of hidden layer weights that can be learned
-    :type bias_attr: ParamAttr
-    :param hidden_activation: Activation for hidden cell
-    :type hidden_activation: Activation
-    :param activation: Activation for output
-    :type activation: Activation
-    :param name: Name of cell
-    :type name: string
-    """
+                 activation="tanh",
+                 weight_ih_attr=None,
+                 weight_hh_attr=None,
+                 bias_ih_attr=None,
+                 bias_hh_attr=None,
+                 name=None):
+        super().__init__()
+        std = 1.0 / math.sqrt(hidden_size)
+        self.weight_hh = self.create_parameter(
+            (hidden_size, hidden_size),
+            weight_hh_attr,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_ih = self.create_parameter(
+            (hidden_size, ),
+            bias_ih_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
+        self.bias_hh = self.create_parameter(
+            (hidden_size, ),
+            bias_hh_attr,
+            is_bias=True,
+            default_initializer=I.Uniform(-std, std))
         self.hidden_size = hidden_size
-        self.param_attr = param_attr
-        self.bias_attr = bias_attr
-        self.hidden_activation = hidden_activation
-        self.activation = activation or fluid.layers.brelu
-        self.name = name
-
-    def call(self, inputs, states):
-        new_hidden = fluid.layers.fc(
-            input=states,
-            size=self.hidden_size,
-            act=self.hidden_activation,
-            param_attr=self.param_attr,
-            bias_attr=self.bias_attr)
-        new_hidden = fluid.layers.elementwise_add(new_hidden, inputs)
-        new_hidden = self.activation(new_hidden)
-
-        return new_hidden, new_hidden
+        if activation not in ["tanh", "relu", "brelu"]:
+            raise ValueError(
+                "activation for SimpleRNNCell should be tanh, relu or brelu, "
+                "but got {}".format(activation))
+        self.activation = activation
+        self._activation_fn = paddle.tanh \
+            if activation == "tanh" \
+            else F.relu
+        if activation == 'brelu':
+            self._activation_fn = brelu
+
+    def forward(self, inputs, states=None):
+        if states is None:
+            states = self.get_initial_states(inputs, self.state_shape)
+        pre_h = states
+        i2h = inputs
+        if self.bias_ih is not None:
+            i2h += self.bias_ih
+        h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
+        if self.bias_hh is not None:
+            h2h += self.bias_hh
+        h = self._activation_fn(i2h + h2h)
+        return h, h
 
     @property
     def state_shape(self):
-        return [self.hidden_size]
+        return (self.hidden_size, )
+
+
+class GRUCellShare(nn.RNNCellBase):
+    r"""
+    Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states,
+    it computes the outputs and updates states.
+    The formula for GRU used is as follows:
+    .. math::
+        r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr})
+        z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz})
+        \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc}))
+        h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t}
+        y_{t} & = h_{t}
+
+    where :math:`\sigma` is the sigmoid function, and * is the elementwise
+    multiplication operator.
+ """ + + def __init__(self, + input_size, + hidden_size, + weight_ih_attr=None, + weight_hh_attr=None, + bias_ih_attr=None, + bias_hh_attr=None, + name=None): + super().__init__() + std = 1.0 / math.sqrt(hidden_size) + self.weight_hh = self.create_parameter( + (3 * hidden_size, hidden_size), + weight_hh_attr, + default_initializer=I.Uniform(-std, std)) + self.bias_ih = self.create_parameter( + (3 * hidden_size, ), + bias_ih_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + self.bias_hh = self.create_parameter( + (3 * hidden_size, ), + bias_hh_attr, + is_bias=True, + default_initializer=I.Uniform(-std, std)) + + self.hidden_size = hidden_size + self.input_size = input_size + self._gate_activation = F.sigmoid + self._activation = paddle.tanh + + def forward(self, inputs, states=None): + if states is None: + states = self.get_initial_states(inputs, self.state_shape) + + pre_hidden = states + x_gates = inputs + if self.bias_ih is not None: + x_gates = x_gates + self.bias_ih + h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) + if self.bias_hh is not None: + h_gates = h_gates + self.bias_hh + + x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) + h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) + r = self._gate_activation(x_r + h_r) + z = self._gate_activation(x_z + h_z) + c = self._activation(x_c + r * h_c) # apply reset gate after mm + h = (pre_hidden - c) * z + c -def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights): + return h, h + + @property + def state_shape(self): + r""" + The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch + size would be automatically inserted into shape). The shape corresponds + to the shape of :math:`h_{t-1}`. + """ + return (self.hidden_size, ) + + +class BiRNNWithBN(nn.Layer): """Bidirectonal simple rnn layer with sequence-wise batch normalization. The batch normalization is only performed on input-state weights. :param name: Name of the layer parameters. :type name: string - :param input: Input layer. - :type input: Variable :param size: Dimension of RNN cells. :type size: int :param share_weights: Whether to share input-hidden weights between @@ -137,88 +321,44 @@ def bidirectional_simple_rnn_bn_layer(name, input, size, share_weights): :return: Bidirectional simple rnn layer. :rtype: Variable """ - forward_cell = RNNCell( - hidden_size=size, - activation=fluid.layers.brelu, - param_attr=fluid.ParamAttr(name=name + '_forward_rnn_weight'), - bias_attr=fluid.ParamAttr(name=name + '_forward_rnn_bias')) - - reverse_cell = RNNCell( - hidden_size=size, - activation=fluid.layers.brelu, - param_attr=fluid.ParamAttr(name=name + '_reverse_rnn_weight'), - bias_attr=fluid.ParamAttr(name=name + '_reverse_rnn_bias')) - - pad_value = fluid.layers.assign(input=np.array([0.0], dtype=np.float32)) - - if share_weights: - #input-hidden weights shared between bi-directional rnn. 
- input_proj = fluid.layers.fc( - input=input, - size=size, - act=None, - param_attr=fluid.ParamAttr(name=name + '_fc_weight'), - bias_attr=False) - - # batch norm is only performed on input-state projection - input_proj_bn_forward = fluid.layers.batch_norm( - input=input_proj, - act=None, - param_attr=fluid.ParamAttr(name=name + '_batch_norm_weight'), - bias_attr=fluid.ParamAttr(name=name + '_batch_norm_bias'), - moving_mean_name=name + '_batch_norm_moving_mean', - moving_variance_name=name + '_batch_norm_moving_variance') - input_proj_bn_reverse = input_proj_bn_forward - else: - input_proj_forward = fluid.layers.fc( - input=input, - size=size, - act=None, - param_attr=fluid.ParamAttr(name=name + '_forward_fc_weight'), - bias_attr=False) - input_proj_reverse = fluid.layers.fc( - input=input, - size=size, - act=None, - param_attr=fluid.ParamAttr(name=name + '_reverse_fc_weight'), - bias_attr=False) - #batch norm is only performed on input-state projection - input_proj_bn_forward = fluid.layers.batch_norm( - input=input_proj_forward, - act=None, - param_attr=fluid.ParamAttr( - name=name + '_forward_batch_norm_weight'), - bias_attr=fluid.ParamAttr(name=name + '_forward_batch_norm_bias'), - moving_mean_name=name + '_forward_batch_norm_moving_mean', - moving_variance_name=name + '_forward_batch_norm_moving_variance') - input_proj_bn_reverse = fluid.layers.batch_norm( - input=input_proj_reverse, - act=None, - param_attr=fluid.ParamAttr( - name=name + '_reverse_batch_norm_weight'), - bias_attr=fluid.ParamAttr(name=name + '_reverse_batch_norm_bias'), - moving_mean_name=name + '_reverse_batch_norm_moving_mean', - moving_variance_name=name + '_reverse_batch_norm_moving_variance') - # forward and backward in time - input, length = fluid.layers.sequence_pad(input_proj_bn_forward, pad_value) - forward_rnn, _ = fluid.layers.rnn( - cell=forward_cell, inputs=input, time_major=False, is_reverse=False) - forward_rnn = fluid.layers.sequence_unpad(x=forward_rnn, length=length) - - input, length = fluid.layers.sequence_pad(input_proj_bn_reverse, pad_value) - reverse_rnn, _ = fluid.layers.rnn( - cell=reverse_cell, - inputs=input, - sequence_length=length, - time_major=False, - is_reverse=True) - reverse_rnn = fluid.layers.sequence_unpad(x=reverse_rnn, length=length) - - out = fluid.layers.concat(input=[forward_rnn, reverse_rnn], axis=1) - return out - - -def bidirectional_gru_bn_layer(name, input, size, act): + + def __init__(self, i_size, h_size, share_weights): + super().__init__() + self.share_weights = share_weights + self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32)) + if self.share_weights: + #input-hidden weights shared between bi-directional rnn. 
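+                # Sharing means the backward direction reuses the very same
+                # Linear and BatchNorm objects (bw_fc/bw_bn below are aliases
+                # of fw_fc/fw_bn), so both directions consume an identical
+                # input projection.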
+            self.fw_fc = nn.Linear(i_size, h_size)
+            # batch norm is only performed on input-state projection
+            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+            self.bw_fc = self.fw_fc
+            self.bw_bn = self.fw_bn
+        else:
+            self.fw_fc = nn.Linear(i_size, h_size)
+            self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+            self.bw_fc = nn.Linear(i_size, h_size)
+            self.bw_bn = nn.BatchNorm1D(h_size, data_format='NLC')
+
+        self.fw_cell = RNNCell(hidden_size=h_size, activation='relu')
+        self.bw_cell = RNNCell(
+            hidden_size=h_size,
+            activation='relu', )
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+
+class BiGRUWithBN(nn.Layer):
     """Bidirectonal gru layer with sequence-wise batch normalization.
     The batch normalization is only performed on input-state weights.
 
@@ -233,108 +373,33 @@
     :param name: Name of the layer.
     :type name: string
     :param input: Input layer.
     :type input: Variable
     :param size: Dimension of GRU cells.
     :type size: int
     :param act: Activation type.
     :type act: string
     :return: Bidirectional GRU layer.
     :rtype: Variable
     """
-    input_proj_forward = fluid.layers.fc(
-        input=input,
-        size=size * 3,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_forward_fc_weight'),
-        bias_attr=False)
-    input_proj_reverse = fluid.layers.fc(
-        input=input,
-        size=size * 3,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_fc_weight'),
-        bias_attr=False)
-    #batch norm is only performed on input-related prohections
-    input_proj_bn_forward = fluid.layers.batch_norm(
-        input=input_proj_forward,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_forward_batch_norm_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_forward_batch_norm_bias'),
-        moving_mean_name=name + '_forward_batch_norm_moving_mean',
-        moving_variance_name=name + '_forward_batch_norm_moving_variance')
-    input_proj_bn_reverse = fluid.layers.batch_norm(
-        input=input_proj_reverse,
-        act=None,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_batch_norm_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_reverse_batch_norm_bias'),
-        moving_mean_name=name + '_reverse_batch_norm_moving_mean',
-        moving_variance_name=name + '_reverse_batch_norm_moving_variance')
-    #forward and backward in time
-    forward_gru = fluid.layers.dynamic_gru(
-        input=input_proj_bn_forward,
-        size=size,
-        gate_activation='sigmoid',
-        candidate_activation=act,
-        param_attr=fluid.ParamAttr(name=name + '_forward_gru_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_forward_gru_bias'),
-        is_reverse=False)
-    reverse_gru = fluid.layers.dynamic_gru(
-        input=input_proj_bn_reverse,
-        size=size,
-        gate_activation='sigmoid',
-        candidate_activation=act,
-        param_attr=fluid.ParamAttr(name=name + '_reverse_gru_weight'),
-        bias_attr=fluid.ParamAttr(name=name + '_reverse_gru_bias'),
-        is_reverse=True)
-    return fluid.layers.concat(input=[forward_gru, reverse_gru], axis=1)
-
-
-def conv_group(input, num_stacks, seq_len_data, masks):
-    """Convolution group with stacked convolution layers.
-    :param input: Input layer.
-    :type input: Variable
-    :param num_stacks: Number of stacked convolution layers.
-    :type num_stacks: int
-    :param seq_len_data:Valid sequence length data layer.
-    :type seq_len_data:Variable
-    :param masks: Masks data layer to reset padding.
-    :type masks: Variable
-    :return: Output layer of the convolution group.
-    :rtype: Variable
-    """
-    filter_size = (41, 11)
-    stride = (2, 3)
-    padding = (20, 5)
-    conv = conv_bn_layer(
-        input=input,
-        filter_size=filter_size,
-        num_channels_in=1,
-        num_channels_out=32,
-        stride=stride,
-        padding=padding,
-        act="brelu",
-        masks=masks,
-        name='layer_0', )
-
-    seq_len_data = (np.array(seq_len_data) - filter_size[1] + 2 * padding[1]
-                    ) // stride[1] + 1
-
-    output_height = (161 - 1) // 2 + 1
-
-    for i in range(num_stacks - 1):
-        #reshape masks
-        output_height = (output_height - 1) // 2 + 1
-        masks = fluid.layers.slice(
-            masks, axes=[2], starts=[0], ends=[output_height])
-        conv = conv_bn_layer(
-            input=conv,
-            filter_size=(21, 11),
-            num_channels_in=32,
-            num_channels_out=32,
-            stride=(2, 1),
-            padding=(10, 5),
-            act="brelu",
-            masks=masks,
-            name='layer_{}'.format(i + 1), )
-
-    output_num_channels = 32
-    return conv, output_num_channels, output_height, seq_len_data
-
-
-def rnn_group(input, size, num_stacks, num_conv_layers, use_gru,
-              share_rnn_weights):
+    def __init__(self, i_size, h_size, act):
+        super().__init__()
+        hidden_size = h_size * 3
+        self.fw_fc = nn.Linear(i_size, hidden_size)
+        self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
+        self.bw_fc = nn.Linear(i_size, hidden_size)
+        self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC')
+
+        self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
+        self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size)
+        self.fw_rnn = nn.RNN(
+            self.fw_cell, is_reverse=False, time_major=False)  #[B, T, D]
+        self.bw_rnn = nn.RNN(
+            self.bw_cell, is_reverse=True, time_major=False)  #[B, T, D]
+
+    def forward(self, x, x_len):
+        # x, shape [B, T, D]
+        fw_x = self.fw_bn(self.fw_fc(x))
+        bw_x = self.bw_bn(self.bw_fc(x))
+        fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len)
+        bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len)
+        x = paddle.concat([fw_x, bw_x], axis=-1)
+        return x, x_len
+
+
+class RNNStack(nn.Layer):
     """RNN group with stacked bidirectional simple RNN or GRU layers.
 
@@ -352,42 +417,42 @@
     :param input: Input layer.
     :type input: Variable
    :param size: Dimension of RNN cells in each layer.
     :type size: int
     :param num_stacks: Number of stacked rnn layers.
     :type num_stacks: int
     :param use_gru: Use gru if set True. Use simple rnn if set False.
     :type use_gru: bool
     :param share_rnn_weights: Whether to share input-hidden weights between
                               forward and backward directional RNNs.
                               It is only available when use_gru=False.
     :type share_weights: bool
     :return: Output layer of the RNN group.
     :rtype: Variable
     """
-    output = input
-    for i in range(num_stacks):
-        if use_gru:
-            output = bidirectional_gru_bn_layer(
-                name='layer_{}'.format(i + num_conv_layers),
-                input=output,
-                size=size,
-                act="relu")
-        else:
-            name = 'layer_{}'.format(i + num_conv_layers)
-            output = bidirectional_simple_rnn_bn_layer(
-                name=name,
-                input=output,
-                size=size,
-                share_weights=share_rnn_weights)
-    return output
-
-
-def deep_speech_v2_network(audio_data,
-                           text_data,
-                           seq_len_data,
-                           masks,
-                           dict_size,
-                           num_conv_layers=2,
-                           num_rnn_layers=3,
-                           rnn_size=256,
-                           use_gru=False,
-                           share_rnn_weights=True):
+
+    def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights):
+        super().__init__()
+        self.rnn_stacks = nn.LayerList()
+        for i in range(num_stacks):
+            if use_gru:
+                #default:GRU using tanh
+                self.rnn_stacks.append(
+                    BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu"))
+            else:
+                self.rnn_stacks.append(
+                    BiRNNWithBN(
+                        i_size=i_size,
+                        h_size=h_size,
+                        share_weights=share_rnn_weights))
+            i_size = h_size * 2
+
+    def forward(self, x, x_len):
+        """
+        x: shape [B, T, D]
+        x_len: shape [B]
+        """
+        for i, rnn in enumerate(self.rnn_stacks):
+            x, x_len = rnn(x, x_len)
+        return x, x_len
+
+
+class DeepSpeech2(nn.Layer):
     """The DeepSpeech2 network structure.
:param audio_data: Audio spectrogram data layer. :type audio_data: Variable :param text_data: Transcription text data layer. :type text_data: Variable - :param seq_len_data: Valid sequence length data layer. - :type seq_len_data: Variable + :param audio_len: Valid sequence length data layer. + :type audio_len: Variable :param masks: Masks data layer to reset padding. :type masks: Variable :param dict_size: Dictionary size for tokenized transcription. @@ -408,51 +473,80 @@ def deep_speech_v2_network(audio_data, before softmax) and a ctc cost layer. :rtype: tuple of LayerOutput """ - audio_data = fluid.layers.unsqueeze(audio_data, axes=[1]) - - # convolution group - conv_group_output, conv_group_num_channels, conv_group_height, seq_len_data = conv_group( - input=audio_data, - num_stacks=num_conv_layers, - seq_len_data=seq_len_data, - masks=masks) - - # convert data form convolution feature map to sequence of vectors - transpose = fluid.layers.transpose(conv_group_output, perm=[0, 3, 1, 2]) - reshape_conv_output = fluid.layers.reshape( - x=transpose, - shape=[0, -1, conv_group_height * conv_group_num_channels], - inplace=False) - # remove padding part - seq_len_data = fluid.layers.reshape(seq_len_data, [-1]) - sequence = fluid.layers.sequence_unpad( - x=reshape_conv_output, length=seq_len_data) - #rnn group - rnn_group_output = rnn_group( - input=sequence, - size=rnn_size, - num_stacks=num_rnn_layers, - num_conv_layers=num_conv_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - fc = fluid.layers.fc( - input=rnn_group_output, - size=dict_size + 1, - act=None, - param_attr=fluid.ParamAttr( - name='layer_{}'.format(num_conv_layers + num_rnn_layers) + - '_fc_weight'), - bias_attr=fluid.ParamAttr( - name='layer_{}'.format(num_conv_layers + num_rnn_layers) + - '_fc_bias')) - # pribability distribution with softmax - log_probs = fluid.layers.softmax(fc) - log_probs.persistable = True - if not text_data: - return log_probs, None - else: - #ctc cost - ctc_loss = fluid.layers.warpctc( - input=fc, label=text_data, blank=dict_size, norm_by_times=True) - ctc_loss = fluid.layers.reduce_sum(ctc_loss) - return log_probs, ctc_loss + + def __init__(self, + feat_size, + dict_size, + num_conv_layers=2, + num_rnn_layers=3, + rnn_size=256, + use_gru=False, + share_rnn_weights=True): + super().__init__() + self.feat_size = feat_size # 161 for linear + self.dict_size = dict_size + + self.conv = ConvStack(feat_size, num_conv_layers) + + i_size = self.conv.output_height # H after conv stack + self.rnn = RNNStack( + i_size=i_size, + h_size=rnn_size, + num_stacks=num_rnn_layers, + use_gru=use_gru, + share_rnn_weights=share_rnn_weights) + self.fc = nn.Linear(rnn_size * 2, dict_size + 1) + + def predict(self, audio, audio_len): + # [B, D, T] -> [B, C=1, D, T] + audio = audio.unsqueeze(1) + + # convolution group + x, audio_len = self.conv(audio, audio_len) + + # convert data from convolution feature map to sequence of vectors + B, C, D, T = paddle.shape(x) + x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] + x = x.reshape([B, T, C * D]) #[B, T, C*D] + + # remove padding part + x, audio_len = self.rnn(x, audio_len) #[B, T, D] + + logits = self.fc(x) #[B, T, V + 1] + + #ctcdecoder need probs, not log_probs + probs = F.softmax(logits) + + return logits, probs + + @paddle.no_grad() + def infer(self, audio, audio_len): + _, probs = self.predict(audio, audio_len) + return probs + + def forward(self, audio, text, audio_len, text_len): + """ + audio: shape [B, D, T] + text: shape [B, T] + audio_len: shape [B] + 
text_len: shape [B] + """ + logits, _ = self.predict(audio, audio_len) + return logits + + +class DeepSpeech2Loss(nn.Layer): + def __init__(self, vocab_size): + super().__init__() + # last token id as blank id + self.loss = nn.CTCLoss(blank=vocab_size, reduction='none') + + def forward(self, logits, text, audio_len, text_len): + # warp-ctc do softmax on activations + # warp-ctc need activation with shape [T, B, V + 1] + logits = logits.transpose([1, 0, 2]) + + ctc_loss = self.loss(logits, text, audio_len, text_len) + ctc_loss /= text_len # norm_by_times + ctc_loss = ctc_loss.sum() + return ctc_loss diff --git a/model_utils/network2.py b/model_utils/network2.py deleted file mode 100644 index bab97a3cc..000000000 --- a/model_utils/network2.py +++ /dev/null @@ -1,555 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import collections -import numpy as np -import paddle -from paddle import nn -from paddle.nn import functional as F -from paddle.nn import initializer as I - -__all__ = ['DeepSpeech2'] - - -def brelu(x, t_min=0.0, t_max=24.0, name=None): - t_min = paddle.to_tensor(t_min) - t_max = paddle.to_tensor(t_max) - return x.maximum(t_min).minimum(t_max) - - -def sequence_mask(x_len, max_len=None, dtype='float32'): - max_len = max_len or x_len.max() - x_len = paddle.unsqueeze(x_len, -1) - row_vector = paddle.arange(max_len) - mask = row_vector < x_len - mask = paddle.cast(mask, dtype) - return mask - - -class ConvBn(nn.Layer): - """Convolution layer with batch normalization. - - :param kernel_size: The x dimension of a filter kernel. Or input a tuple for - two image dimension. - :type kernel_size: int|tuple|list - :param num_channels_in: Number of input channels. - :type num_channels_in: int - :param num_channels_out: Number of output channels. - :type num_channels_out: int - :param stride: The x dimension of the stride. Or input a tuple for two - image dimension. - :type stride: int|tuple|list - :param padding: The x dimension of the padding. Or input a tuple for two - image dimension. - :type padding: int|tuple|list - :param act: Activation type, relu|brelu - :type act: string - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param name: Name of the layer. - :param name: string - :return: Batch norm layer after convolution layer. 
- :rtype: Variable - - """ - - def __init__(self, num_channels_in, num_channels_out, kernel_size, stride, - padding, act): - - super().__init__() - self.kernel_size = kernel_size - self.stride = stride - self.padding = padding - - self.conv = nn.Conv2D( - num_channels_in, - num_channels_out, - kernel_size=kernel_size, - stride=stride, - padding=padding, - weight_attr=None, - bias_attr=None, - data_format='NCHW') - - self.bn = nn.BatchNorm2D( - num_channels_out, - weight_attr=None, - bias_attr=None, - data_format='NCHW') - self.act = paddle.relu if act == 'relu' else brelu - - def forward(self, x, x_len): - """ - x(Tensor): audio, shape [B, C, D, T] - """ - x = self.conv(x) - x = self.bn(x) - x = self.act(x) - - x_len = (x_len - self.kernel_size[1] + 2 * self.padding[1] - ) // self.stride[1] + 1 - - # reset padding part to 0 - masks = sequence_mask(x_len) #[B, T] - masks = masks.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T] - x = x.multiply(masks) - - return x, x_len - - -class ConvStack(nn.Layer): - """Convolution group with stacked convolution layers. - - :param feat_size: audio feature dim. - :type feat_size: int - :param num_stacks: Number of stacked convolution layers. - :type num_stacks: int - """ - - def __init__(self, feat_size, num_stacks): - super().__init__() - self.feat_size = feat_size # D - self.num_stacks = num_stacks - - self.filter_size = (41, 11) # [D, T] - self.stride = (2, 3) - self.padding = (20, 5) - self.conv_in = ConvBn( - num_channels_in=1, - num_channels_out=32, - kernel_size=self.filter_size, - stride=self.stride, - padding=self.padding, - act='brelu', ) - - out_channel = 32 - self.conv_stack = nn.LayerList([ - ConvBn( - num_channels_in=32, - num_channels_out=out_channel, - kernel_size=(21, 11), - stride=(2, 1), - padding=(10, 5), - act='brelu') for i in range(num_stacks - 1) - ]) - - # conv output feat_dim - output_height = (feat_size - 1) // 2 + 1 - for i in range(self.num_stacks - 1): - output_height = (output_height - 1) // 2 + 1 - self.output_height = out_channel * output_height - - def forward(self, x, x_len): - """ - x: shape [B, C, D, T] - x_len : shape [B] - """ - x, x_len = self.conv_in(x, x_len) - for i, conv in enumerate(self.conv_stack): - x, x_len = conv(x, x_len) - return x, x_len - - -class RNNCell(nn.RNNCellBase): - r""" - Elman RNN (SimpleRNN) cell. Given the inputs and previous states, it - computes the outputs and updates states. - The formula used is as follows: - .. math:: - h_{t} & = act(x_{t} + b_{ih} + W_{hh}h_{t-1} + b_{hh}) - y_{t} & = h_{t} - - where :math:`act` is for :attr:`activation`. 
- """ - - def __init__(self, - hidden_size, - activation="tanh", - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = self.create_parameter( - (hidden_size, ), - bias_ih_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - self.bias_hh = self.create_parameter( - (hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - if activation not in ["tanh", "relu", "brelu"]: - raise ValueError( - "activation for SimpleRNNCell should be tanh or relu, " - "but get {}".format(activation)) - self.activation = activation - self._activation_fn = paddle.tanh \ - if activation == "tanh" \ - else F.relu - if activation == 'brelu': - self._activation_fn = brelu - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - pre_h = states - i2h = inputs - if self.bias_ih is not None: - i2h += self.bias_ih - h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h2h += self.bias_hh - h = self._activation_fn(i2h + h2h) - return h, h - - @property - def state_shape(self): - return (self.hidden_size, ) - - -class GRUCellShare(nn.RNNCellBase): - r""" - Gated Recurrent Unit (GRU) RNN cell. Given the inputs and previous states, - it computes the outputs and updates states. - The formula for GRU used is as follows: - .. math:: - r_{t} & = \sigma(W_{ir}x_{t} + b_{ir} + W_{hr}h_{t-1} + b_{hr}) - z_{t} & = \sigma(W_{iz}x_{t} + b_{iz} + W_{hz}h_{t-1} + b_{hz}) - \widetilde{h}_{t} & = \tanh(W_{ic}x_{t} + b_{ic} + r_{t} * (W_{hc}h_{t-1} + b_{hc})) - h_{t} & = z_{t} * h_{t-1} + (1 - z_{t}) * \widetilde{h}_{t} - y_{t} & = h_{t} - - where :math:`\sigma` is the sigmoid fucntion, and * is the elemetwise - multiplication operator. 
- """ - - def __init__(self, - input_size, - hidden_size, - weight_ih_attr=None, - weight_hh_attr=None, - bias_ih_attr=None, - bias_hh_attr=None, - name=None): - super().__init__() - std = 1.0 / math.sqrt(hidden_size) - self.weight_hh = self.create_parameter( - (3 * hidden_size, hidden_size), - weight_hh_attr, - default_initializer=I.Uniform(-std, std)) - self.bias_ih = self.create_parameter( - (3 * hidden_size, ), - bias_ih_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - self.bias_hh = self.create_parameter( - (3 * hidden_size, ), - bias_hh_attr, - is_bias=True, - default_initializer=I.Uniform(-std, std)) - - self.hidden_size = hidden_size - self.input_size = input_size - self._gate_activation = F.sigmoid - self._activation = paddle.tanh - - def forward(self, inputs, states=None): - if states is None: - states = self.get_initial_states(inputs, self.state_shape) - - pre_hidden = states - x_gates = inputs - if self.bias_ih is not None: - x_gates = x_gates + self.bias_ih - h_gates = paddle.matmul(pre_hidden, self.weight_hh, transpose_y=True) - if self.bias_hh is not None: - h_gates = h_gates + self.bias_hh - - x_r, x_z, x_c = paddle.split(x_gates, num_or_sections=3, axis=1) - h_r, h_z, h_c = paddle.split(h_gates, num_or_sections=3, axis=1) - - r = self._gate_activation(x_r + h_r) - z = self._gate_activation(x_z + h_z) - c = self._activation(x_c + r * h_c) # apply reset gate after mm - h = (pre_hidden - c) * z + c - - return h, h - - @property - def state_shape(self): - r""" - The `state_shape` of GRUCell is a shape `[hidden_size]` (-1 for batch - size would be automatically inserted into shape). The shape corresponds - to the shape of :math:`h_{t-1}`. - """ - return (self.hidden_size, ) - - -class BiRNNWithBN(nn.Layer): - """Bidirectonal simple rnn layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer parameters. - :type name: string - :param size: Dimension of RNN cells. - :type size: int - :param share_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - :type share_weights: bool - :return: Bidirectional simple rnn layer. - :rtype: Variable - """ - - def __init__(self, i_size, h_size, share_weights): - super().__init__() - self.share_weights = share_weights - self.pad_value = paddle.to_tensor(np.array([0.0], dtype=np.float32)) - if self.share_weights: - #input-hidden weights shared between bi-directional rnn. 
- self.fw_fc = nn.Linear(i_size, h_size) - # batch norm is only performed on input-state projection - self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC') - self.bw_fc = self.fw_fc - self.bw_bn = self.fw_bn - else: - self.fw_fc = nn.Linear(i_size, h_size) - self.fw_bn = nn.BatchNorm1D(h_size, data_format='NLC') - self.bw_fc = nn.Linear(i_size, h_size) - self.bw_bn = nn.BatchNorm1D(h_size, data_format='NLC') - - self.fw_cell = RNNCell(hidden_size=h_size, activation='relu') - self.bw_cell = RNNCell( - hidden_size=h_size, - activation='relu', ) - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class BiGRUWithBN(nn.Layer): - """Bidirectonal gru layer with sequence-wise batch normalization. - The batch normalization is only performed on input-state weights. - - :param name: Name of the layer. - :type name: string - :param input: Input layer. - :type input: Variable - :param size: Dimension of GRU cells. - :type size: int - :param act: Activation type. - :type act: string - :return: Bidirectional GRU layer. - :rtype: Variable - """ - - def __init__(self, i_size, h_size, act): - super().__init__() - hidden_size = h_size * 3 - self.fw_fc = nn.Linear(i_size, hidden_size) - self.fw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC') - self.bw_fc = nn.Linear(i_size, hidden_size) - self.bw_bn = nn.BatchNorm1D(hidden_size, data_format='NLC') - - self.fw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size) - self.bw_cell = GRUCellShare(input_size=hidden_size, hidden_size=h_size) - self.fw_rnn = nn.RNN( - self.fw_cell, is_reverse=False, time_major=False) #[B, T, D] - self.bw_rnn = nn.RNN( - self.fw_cell, is_reverse=True, time_major=False) #[B, T, D] - - def forward(self, x, x_len): - # x, shape [B, T, D] - fw_x = self.fw_bn(self.fw_fc(x)) - bw_x = self.bw_bn(self.bw_fc(x)) - fw_x, _ = self.fw_rnn(inputs=fw_x, sequence_length=x_len) - bw_x, _ = self.bw_rnn(inputs=bw_x, sequence_length=x_len) - x = paddle.concat([fw_x, bw_x], axis=-1) - return x, x_len - - -class RNNStack(nn.Layer): - """RNN group with stacked bidirectional simple RNN or GRU layers. - - :param input: Input layer. - :type input: Variable - :param size: Dimension of RNN cells in each layer. - :type size: int - :param num_stacks: Number of stacked rnn layers. - :type num_stacks: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward directional RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: Output layer of the RNN group. 
- :rtype: Variable - """ - - def __init__(self, i_size, h_size, num_stacks, use_gru, share_rnn_weights): - super().__init__() - self.rnn_stacks = nn.LayerList() - for i in range(num_stacks): - if use_gru: - #default:GRU using tanh - self.rnn_stacks.append( - BiGRUWithBN(i_size=i_size, h_size=h_size, act="relu")) - else: - self.rnn_stacks.append( - BiRNNWithBN( - i_size=i_size, - h_size=h_size, - share_weights=share_rnn_weights)) - i_size = h_size * 2 - - def forward(self, x, x_len): - """ - x: shape [B, T, D] - x_len: shpae [B] - """ - for i, rnn in enumerate(self.rnn_stacks): - x, x_len = rnn(x, x_len) - return x, x_len - - -class DeepSpeech2(nn.Layer): - """The DeepSpeech2 network structure. - - :param audio_data: Audio spectrogram data layer. - :type audio_data: Variable - :param text_data: Transcription text data layer. - :type text_data: Variable - :param audio_len: Valid sequence length data layer. - :type audio_len: Variable - :param masks: Masks data layer to reset padding. - :type masks: Variable - :param dict_size: Dictionary size for tokenized transcription. - :type dict_size: int - :param num_conv_layers: Number of stacking convolution layers. - :type num_conv_layers: int - :param num_rnn_layers: Number of stacking RNN layers. - :type num_rnn_layers: int - :param rnn_size: RNN layer size (dimension of RNN cells). - :type rnn_size: int - :param use_gru: Use gru if set True. Use simple rnn if set False. - :type use_gru: bool - :param share_rnn_weights: Whether to share input-hidden weights between - forward and backward direction RNNs. - It is only available when use_gru=False. - :type share_weights: bool - :return: A tuple of an output unnormalized log probability layer ( - before softmax) and a ctc cost layer. - :rtype: tuple of LayerOutput - """ - - def __init__(self, - feat_size, - dict_size, - num_conv_layers=2, - num_rnn_layers=3, - rnn_size=256, - use_gru=False, - share_rnn_weights=True): - super().__init__() - self.feat_size = feat_size # 161 for linear - self.dict_size = dict_size - - self.conv = ConvStack(feat_size, num_conv_layers) - - i_size = self.conv.output_height # H after conv stack - self.rnn = RNNStack( - i_size=i_size, - h_size=rnn_size, - num_stacks=num_rnn_layers, - use_gru=use_gru, - share_rnn_weights=share_rnn_weights) - self.fc = nn.Linear(rnn_size * 2, dict_size + 1) - - def predict(self, audio, audio_len): - # [B, D, T] -> [B, C=1, D, T] - audio = audio.unsqueeze(1) - - # convolution group - x, audio_len = self.conv(audio, audio_len) - - # convert data from convolution feature map to sequence of vectors - B, C, D, T = paddle.shape(x) - x = x.transpose([0, 3, 1, 2]) #[B, T, C, D] - x = x.reshape([B, T, C * D]) #[B, T, C*D] - - # remove padding part - x, audio_len = self.rnn(x, audio_len) #[B, T, D] - - logits = self.fc(x) #[B, T, V + 1] - - #ctcdecoder need probs, not log_probs - probs = F.softmax(logits) - - return logits, probs - - @paddle.no_grad() - def infer(self, audio, audio_len): - _, probs = self.predict(audio, audio_len) - return probs - - def forward(self, audio, text, audio_len, text_len): - """ - audio: shape [B, D, T] - text: shape [B, T] - audio_len: shape [B] - text_len: shape [B] - """ - logits, probs = self.predict(audio, audio_len) - print(logits.shape) - print(text.shape) - print(audio_len.shape) - print(text_len.shape) - return logits - - -class DeepSpeechLoss(nn.Layer): - def __init__(self, vocab_size): - super().__init__() - self.loss = nn.CTCLoss(blank=vocab_size, reduction='none') - - def forward(self, logits, text, audio_len, 
-                text_len):
-        # warp-ctc do softmax on activations
-        # warp-ctc need activation with shape [T, B, V + 1]
-        logits = logits.transpose([1, 0, 2])
-
-        ctc_loss = self.loss(logits, text, audio_len, text_len)
-        ctc_loss /= text_len  # norm_by_times
-        ctc_loss = ctc_loss.sum()
-        return ctc_loss
diff --git a/model_utils/trainer.py b/model_utils/trainer.py
new file mode 100644
index 000000000..90a2bfb85
--- /dev/null
+++ b/model_utils/trainer.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import logging
+from pathlib import Path
+import numpy as np
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader, DistributedBatchSampler
+from tensorboardX import SummaryWriter
+from collections import defaultdict
+
+from utils import checkpoint
+from utils import mp_tools
+
+__all__ = ["ExperimentBase"]
+
+
+class ExperimentBase(object):
+    """
+    An experiment template that structures the training code and takes
+    care of saving, loading, logging and visualization. It is intended to
+    be flexible and simple.
+
+    It only handles the output directory (creating the output directory
+    and a checkpoint directory inside it, dumping the config in use, and
+    creating the visualizer and logger) in a standard way, without forcing
+    any input-output protocol onto the model and dataloader. The main part
+    is left for the user to implement (setting up the model, criterion and
+    optimizer, defining a training step, defining a validation function
+    and customizing all the text and visual logs).
+    It does not aim to remove all boilerplate: users still have to write
+    the forward/backward/update steps manually, but they are free to add
+    non-standard behaviors if needed.
+    There are some conventions to follow:
+    1. The experiment should have ``model``, ``optimizer``, ``train_loader``,
+    ``valid_loader``, ``config`` and ``args`` attributes.
+    2. The config should have a ``training`` field with ``valid_interval``,
+    ``save_interval`` and ``max_iteration`` keys, which are used to trigger
+    validation, checkpointing and termination of the experiment.
+    3. Four methods, namely ``train_batch``, ``valid``, ``setup_model`` and
+    ``setup_dataloader``, should be implemented. Feel free to add or
+    overwrite other methods and standalone functions as needed.
+
+    Parameters
+    ----------
+    config: yacs.config.CfgNode
+        The configuration used for the experiment.
+
+    args: argparse.Namespace
+        The parsed command line arguments.
+    Examples
+    --------
+    >>> def main_sp(config, args):
+    >>>     exp = Experiment(config, args)
+    >>>     exp.setup()
+    >>>     exp.run()
+    >>>
+    >>> config = get_cfg_defaults()
+    >>> parser = default_argument_parser()
+    >>> args = parser.parse_args()
+    >>> if args.config:
+    >>>     config.merge_from_file(args.config)
+    >>> if args.opts:
+    >>>     config.merge_from_list(args.opts)
+    >>> config.freeze()
+    >>>
+    >>> if args.nprocs > 1 and args.device == "gpu":
+    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
+    >>> else:
+    >>>     main_sp(config, args)
+    """
+
+    def __init__(self, config, args):
+        self.config = config
+        self.args = args
+
+    def setup(self):
+        """Set up the experiment.
+        """
+        paddle.set_device(self.args.device)
+        if self.parallel:
+            self.init_parallel()
+
+        self.setup_output_dir()
+        self.dump_config()
+        self.setup_visualizer()
+        self.setup_logger()
+        self.setup_checkpointer()
+
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    @property
+    def parallel(self):
+        """A flag indicating whether the experiment should run with
+        multiprocessing.
+        """
+        return self.args.device == "gpu" and self.args.nprocs > 1
+
+    def init_parallel(self):
+        """Initialize the environment for multiprocess training.
+        """
+        dist.init_parallel_env()
+
+    def save(self):
+        """Save checkpoint (model parameters and optimizer states).
+        """
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
+                                   self.model, self.optimizer)
+
+    def load_or_resume(self):
+        """Resume from the latest checkpoint in the checkpoint directory
+        inside the output directory, or load a specified checkpoint.
+
+        If ``args.checkpoint_path`` is not None, load that checkpoint;
+        otherwise resume training from the latest one.
+        """
+        iteration = checkpoint.load_parameters(
+            self.model,
+            self.optimizer,
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_path=self.args.checkpoint_path)
+        self.iteration = iteration
+
+    def read_batch(self):
+        """Read a batch from the train_loader.
+        Returns
+        -------
+        List[Tensor]
+            A batch.
+        """
+        try:
+            batch = next(self.iterator)
+        except StopIteration:
+            self.new_epoch()
+            batch = next(self.iterator)
+        return batch
+
+    def new_epoch(self):
+        """Reset the train loader and increment ``epoch``.
+        """
+        self.epoch += 1
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.iterator = iter(self.train_loader)
+
+    def train(self):
+        """The training process.
+
+        It includes forward/backward/update and periodic validation and
+        saving.
+        """
+        self.new_epoch()
+        while self.iteration < self.config.training.max_iteration:
+            self.iteration += 1
+            self.train_batch()
+
+            if self.iteration % self.config.training.valid_interval == 0:
+                self.valid()
+
+            if self.iteration % self.config.training.save_interval == 0:
+                self.save()
+
+    def run(self):
+        """The routine of the experiment after setup. This method is intended
+        to be used by the user.
+        """
+        self.load_or_resume()
+        try:
+            self.train()
+        except KeyboardInterrupt:
+            self.save()
+            exit(-1)
+
+    @mp_tools.rank_zero_only
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        output_dir = Path(self.args.output).expanduser()
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
+
+    @mp_tools.rank_zero_only
+    def setup_checkpointer(self):
+        """Create a directory used to save checkpoints into.
+
+        It is "checkpoints" inside the output directory.
+        """
+        # checkpoint dir
+        checkpoint_dir = self.output_dir / "checkpoints"
+        checkpoint_dir.mkdir(exist_ok=True)
+
+        self.checkpoint_dir = checkpoint_dir
+
+    @mp_tools.rank_zero_only
+    def setup_visualizer(self):
+        """Initialize a visualizer to log the experiment.
+
+        The visual log is saved in the output directory.
+
+        Notes
+        -----
+        Only the main process has a visualizer. Using multiple visualizers
+        from multiple processes to write to the same log file may cause
+        unexpected behaviors.
+        """
+        # visualizer
+        visualizer = SummaryWriter(logdir=str(self.output_dir))
+
+        self.visualizer = visualizer
+
+    def setup_logger(self):
+        """Initialize a text logger to log the experiment.
+
+        Each process has its own text logger. Logging messages are written
+        to standard output and to a text file named ``worker_n.log`` in the
+        output directory, where ``n`` is the rank of the process.
+        """
+        logger = logging.getLogger(__name__)
+        logger.setLevel("INFO")
+        logger.addHandler(logging.StreamHandler())
+        log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
+        logger.addHandler(logging.FileHandler(str(log_file)))
+
+        self.logger = logger
+
+    @mp_tools.rank_zero_only
+    def dump_config(self):
+        """Save the configuration used for this experiment.
+
+        It is saved to ``config.yaml`` in the output directory at the
+        beginning of the experiment.
+        """
+        with open(self.output_dir / "config.yaml", 'wt') as f:
+            print(self.config, file=f)
+
+    def train_batch(self):
+        """One step of the training loop. A subclass should implement this
+        method.
+        """
+        raise NotImplementedError("train_batch should be implemented.")
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def valid(self):
+        """The validation procedure. A subclass should implement this method.
+        """
+        raise NotImplementedError("valid should be implemented.")
+
+    def setup_model(self):
+        """Set up the model, criterion, optimizer, etc. A subclass should
+        implement this method.
+        """
+        raise NotImplementedError("setup_model should be implemented.")
+
+    def setup_dataloader(self):
+        """Set up the training and validation dataloaders. A subclass
+        should implement this method.
+        """
+        raise NotImplementedError("setup_dataloader should be implemented.")
\ No newline at end of file
diff --git a/training/__init__.py b/training/__init__.py
new file mode 100644
index 000000000..932432db1
--- /dev/null
+++ b/training/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from training.trainer import *
\ No newline at end of file
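Because training/__init__.py re-exports everything from training.trainer with a star import, and the trainer module (added in the next hunk) declares __all__ = ["Trainer"], both import forms below resolve to the same class; a minimal sketch:

    from training import Trainer          # via the star re-export above
    from training.trainer import Trainer  # direct module path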
diff --git a/training/trainer.py b/training/trainer.py
new file mode 100644
index 000000000..d4173d5ec
--- /dev/null
+++ b/training/trainer.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import logging
+from pathlib import Path
+import numpy as np
+from collections import defaultdict
+
+import paddle
+from paddle import distributed as dist
+from tensorboardX import SummaryWriter
+
+from utils import checkpoint
+from utils import mp_tools
+
+__all__ = ["Trainer"]
+
+
+class Trainer():
+    """
+    An experiment template that structures the training code and takes
+    care of saving, loading, logging and visualization. It is intended to
+    be flexible and simple.
+
+    It only handles the output directory (creating the output directory
+    and a checkpoint directory inside it, dumping the config in use, and
+    creating the visualizer and logger) in a standard way, without forcing
+    any input-output protocol onto the model and dataloader. The main part
+    is left for the user to implement (setting up the model, criterion and
+    optimizer, defining a training step, defining a validation function
+    and customizing all the text and visual logs).
+    It does not aim to remove all boilerplate: users still have to write
+    the forward/backward/update steps manually, but they are free to add
+    non-standard behaviors if needed.
+    There are some conventions to follow:
+    1. The experiment should have ``model``, ``optimizer``, ``train_loader``,
+    ``valid_loader``, ``config`` and ``args`` attributes.
+    2. The config should have a ``training`` field with ``valid_interval``,
+    ``save_interval`` and ``max_iteration`` keys, which are used to trigger
+    validation, checkpointing and termination of the experiment.
+    3. Four methods, namely ``train_batch``, ``valid``, ``setup_model`` and
+    ``setup_dataloader``, should be implemented. Feel free to add or
+    overwrite other methods and standalone functions as needed.
+
+    Parameters
+    ----------
+    config: yacs.config.CfgNode
+        The configuration used for the experiment.
+
+    args: argparse.Namespace
+        The parsed command line arguments.
+    Examples
+    --------
+    >>> def main_sp(config, args):
+    >>>     exp = Trainer(config, args)
+    >>>     exp.setup()
+    >>>     exp.run()
+    >>>
+    >>> config = get_cfg_defaults()
+    >>> parser = default_argument_parser()
+    >>> args = parser.parse_args()
+    >>> if args.config:
+    >>>     config.merge_from_file(args.config)
+    >>> if args.opts:
+    >>>     config.merge_from_list(args.opts)
+    >>> config.freeze()
+    >>>
+    >>> if args.nprocs > 1 and args.device == "gpu":
+    >>>     dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
+    >>> else:
+    >>>     main_sp(config, args)
+    """
+
+    def __init__(self, config, args):
+        self.config = config
+        self.args = args
+
+    def setup(self):
+        """Set up the experiment.
+        """
+        paddle.set_device(self.args.device)
+        if self.parallel:
+            self.init_parallel()
+
+        self.setup_output_dir()
+        self.dump_config()
+        self.setup_visualizer()
+        self.setup_logger()
+        self.setup_checkpointer()
+
+        self.setup_dataloader()
+        self.setup_model()
+
+        self.iteration = 0
+        self.epoch = 0
+
+    @property
+    def parallel(self):
+        """A flag indicating whether the experiment should run with
+        multiprocessing.
+        """
+        return self.args.device == "gpu" and self.args.nprocs > 1
+
+    def init_parallel(self):
+        """Initialize the environment for multiprocess training.
+        """
+        dist.init_parallel_env()
+
+    def save(self):
+        """Save checkpoint (model parameters and optimizer states).
+        """
+        checkpoint.save_parameters(self.checkpoint_dir, self.iteration,
+                                   self.model, self.optimizer)
+
+    def resume_or_load(self):
+        """Resume from the latest checkpoint in the checkpoint directory
+        inside the output directory, or load a specified checkpoint.
+
+        If ``args.checkpoint_path`` is not None, load that checkpoint;
+        otherwise resume training from the latest one.
+        """
+        iteration = checkpoint.load_parameters(
+            self.model,
+            self.optimizer,
+            checkpoint_dir=self.checkpoint_dir,
+            checkpoint_path=self.args.checkpoint_path)
+        self.iteration = iteration
+
+    def read_batch(self):
+        """Read a batch from the train_loader.
+        Returns
+        -------
+        List[Tensor]
+            A batch.
+        """
+        try:
+            batch = next(self.iterator)
+        except StopIteration:
+            self.new_epoch()
+            batch = next(self.iterator)
+        return batch
+
+    def new_epoch(self):
+        """Reset the train loader and increment ``epoch``.
+        """
+        self.epoch += 1
+        if self.parallel:
+            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.iterator = iter(self.train_loader)
+
+    def train(self):
+        """The training process.
+
+        It includes forward/backward/update and periodic validation and
+        saving.
+        """
+        self.new_epoch()
+        while self.iteration < self.config.training.max_iteration:
+            self.iteration += 1
+            self.train_batch()
+
+            if self.iteration % self.config.training.valid_interval == 0:
+                self.valid()
+
+            if self.iteration % self.config.training.save_interval == 0:
+                self.save()
+
+    def run(self):
+        """The routine of the experiment after setup. This method is intended
+        to be used by the user.
+        """
+        self.resume_or_load()
+        try:
+            self.train()
+        except KeyboardInterrupt:
+            self.save()
+            exit(-1)
+
+    @mp_tools.rank_zero_only
+    def setup_output_dir(self):
+        """Create a directory used for output.
+        """
+        # output dir
+        output_dir = Path(self.args.output).expanduser()
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        self.output_dir = output_dir
+
+    @mp_tools.rank_zero_only
+    def setup_checkpointer(self):
+        """Create a directory used to save checkpoints into.
+
+        It is "checkpoints" inside the output directory.
+        """
+        # checkpoint dir
+        checkpoint_dir = self.output_dir / "checkpoints"
+        checkpoint_dir.mkdir(exist_ok=True)
+
+        self.checkpoint_dir = checkpoint_dir
+
+    @mp_tools.rank_zero_only
+    def setup_visualizer(self):
+        """Initialize a visualizer to log the experiment.
+
+        The visual log is saved in the output directory.
+
+        Notes
+        -----
+        Only the main process has a visualizer. Using multiple visualizers
+        from multiple processes to write to the same log file may cause
+        unexpected behaviors.
+        """
+        # visualizer
+        visualizer = SummaryWriter(logdir=str(self.output_dir))
+
+        self.visualizer = visualizer
+
+    def setup_logger(self):
+        """Initialize a text logger to log the experiment.
+
+        Each process has its own text logger. Logging messages are written
+        to standard output and to a text file named ``worker_n.log`` in the
+        output directory, where ``n`` is the rank of the process.
+        """
+        logger = logging.getLogger(__name__)
+        logger.setLevel("INFO")
+        logger.addHandler(logging.StreamHandler())
+        log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
+        logger.addHandler(logging.FileHandler(str(log_file)))
+
+        self.logger = logger
+
+    @mp_tools.rank_zero_only
+    def dump_config(self):
+        """Save the configuration used for this experiment.
+
+        It is saved to ``config.yaml`` in the output directory at the
+        beginning of the experiment.
+        """
+        with open(self.output_dir / "config.yaml", 'wt') as f:
+            print(self.config, file=f)
+
+    def train_batch(self):
+        """One step of the training loop. A subclass should implement this
+        method.
+        """
+        raise NotImplementedError("train_batch should be implemented.")
+
+    @mp_tools.rank_zero_only
+    @paddle.no_grad()
+    def valid(self):
+        """The validation procedure. A subclass should implement this method.
+        """
+        raise NotImplementedError("valid should be implemented.")
+
+    def setup_model(self):
+        """Set up the model, criterion, optimizer, etc. A subclass should
+        implement this method.
+        """
+        raise NotImplementedError("setup_model should be implemented.")
+
+    def setup_dataloader(self):
+        """Set up the training and validation dataloaders. A subclass
+        should implement this method.
+        """
+        raise NotImplementedError("setup_dataloader should be implemented.")
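To make the conventions above concrete, here is a minimal sketch of a Trainer subclass implementing the four required methods (the same shape works for ExperimentBase in model_utils/trainer.py). Everything beyond the Trainer API itself is a hypothetical stand-in, not part of this patch: ToyExperiment, RandomDataset and the config.training.lr key are made up for illustration.

    import paddle
    from paddle.io import DataLoader
    from training import Trainer


    class RandomDataset(paddle.io.Dataset):
        # 128 random (feature, target) pairs, just enough to make this runnable
        def __getitem__(self, idx):
            return paddle.randn([16]), paddle.randn([4])

        def __len__(self):
            return 128


    class ToyExperiment(Trainer):
        def setup_model(self):
            self.model = paddle.nn.Linear(16, 4)
            self.optimizer = paddle.optimizer.Adam(
                learning_rate=self.config.training.lr,  # assumed config key
                parameters=self.model.parameters())

        def setup_dataloader(self):
            self.train_loader = DataLoader(
                RandomDataset(), batch_size=8, shuffle=True)
            self.valid_loader = DataLoader(RandomDataset(), batch_size=8)

        def train_batch(self):
            x, y = self.read_batch()  # rolls over to a new epoch on its own
            loss = paddle.nn.functional.mse_loss(self.model(x), y)
            loss.backward()
            self.optimizer.step()
            self.optimizer.clear_grad()
            self.logger.info("iter %d: loss %f", self.iteration, float(loss))

        @paddle.no_grad()
        def valid(self):
            x, y = next(iter(self.valid_loader))
            loss = paddle.nn.functional.mse_loss(self.model(x), y)
            self.visualizer.add_scalar("valid/loss", float(loss), self.iteration)

Driving such a subclass is the same as in the docstring Examples above: construct it with (config, args), then call setup() and run().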
diff --git a/utils/checkpoint.py b/utils/checkpoint.py
new file mode 100644
index 000000000..5a09f20a1
--- /dev/null
+++ b/utils/checkpoint.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import numpy as np
+import paddle
+from paddle import distributed as dist
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+
+from utils import mp_tools
+
+__all__ = ["load_parameters", "save_parameters"]
+
+
+def _load_latest_checkpoint(checkpoint_dir: str) -> int:
+    """Get the iteration number corresponding to the latest saved checkpoint.
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+    Returns:
+        int: the latest iteration number.
+    """
+    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
+    if not os.path.isfile(checkpoint_record):
+        return 0
+
+    # Fetch the latest checkpoint index.
+    with open(checkpoint_record, "rt") as handle:
+        latest_checkpoint = handle.readlines()[-1].strip()
+        step = latest_checkpoint.split(":")[-1]
+        iteration = int(step.split("-")[-1])
+
+    return iteration
+
+
+def _save_checkpoint(checkpoint_dir: str, iteration: int):
+    """Save the iteration number of the latest model to be checkpointed.
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+        iteration (int): the latest iteration number.
+    Returns:
+        None
+    """
+    checkpoint_record = os.path.join(checkpoint_dir, "checkpoint")
+    # Update the latest checkpoint index.
+    with open(checkpoint_record, "a+") as handle:
+        handle.write("model_checkpoint_path:step-{}\n".format(iteration))
+
+
+def load_parameters(model,
+                    optimizer=None,
+                    checkpoint_dir=None,
+                    checkpoint_path=None):
+    """Load a specific model checkpoint from disk.
+    Args:
+        model (Layer): model to load parameters into.
+        optimizer (Optimizer, optional): optimizer to load states if needed.
+            Defaults to None.
+        checkpoint_dir (str, optional): the directory where checkpoints are saved.
+        checkpoint_path (str, optional): if specified, load the checkpoint
+            stored at checkpoint_path and ignore the 'checkpoint_dir'
+            argument. Defaults to None.
+    Returns:
+        iteration (int): number of iterations that the loaded checkpoint has
+            been trained.
+    """
+    if checkpoint_path is not None:
+        iteration = int(os.path.basename(checkpoint_path).split("-")[-1])
+    elif checkpoint_dir is not None:
+        iteration = _load_latest_checkpoint(checkpoint_dir)
+        if iteration == 0:
+            return iteration
+        checkpoint_path = os.path.join(checkpoint_dir,
+                                       "step-{}".format(iteration))
+    else:
+        raise ValueError(
+            "At least one of 'checkpoint_dir' and 'checkpoint_path' should be specified!"
+        )
+
+    rank = dist.get_rank()
+
+    params_path = checkpoint_path + ".pdparams"
+    model_dict = paddle.load(params_path)
+    model.set_state_dict(model_dict)
+    print(
+        "[checkpoint] Rank {}: loaded model from {}".format(rank, params_path))
+
+    optimizer_path = checkpoint_path + ".pdopt"
+    if optimizer and os.path.isfile(optimizer_path):
+        optimizer_dict = paddle.load(optimizer_path)
+        optimizer.set_state_dict(optimizer_dict)
+        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
+            rank, optimizer_path))
+
+    return iteration
+
+
+@mp_tools.rank_zero_only
+def save_parameters(checkpoint_dir, iteration, model, optimizer=None):
+    """Checkpoint the latest trained model parameters.
+    Args:
+        checkpoint_dir (str): the directory where checkpoints are saved.
+        iteration (int): the latest iteration number.
+        model (Layer): model to be checkpointed.
+        optimizer (Optimizer, optional): optimizer to be checkpointed.
+            Defaults to None.
+    Returns:
+        None
+    """
+    checkpoint_path = os.path.join(checkpoint_dir, "step-{}".format(iteration))
+
+    model_dict = model.state_dict()
+    params_path = checkpoint_path + ".pdparams"
+    paddle.save(model_dict, params_path)
+    print("[checkpoint] Saved model to {}".format(params_path))
+
+    if optimizer:
+        opt_dict = optimizer.state_dict()
+        optimizer_path = checkpoint_path + ".pdopt"
+        paddle.save(opt_dict, optimizer_path)
+        print("[checkpoint] Saved optimizer state to {}".format(optimizer_path))
+
+    _save_checkpoint(checkpoint_dir, iteration)
diff --git a/utils/error_rate.py b/utils/error_rate.py
index d80546ee2..3fb6b769c 100644
--- a/utils/error_rate.py
+++ b/utils/error_rate.py
@@ -14,9 +14,10 @@
 """This module provides functions to calculate error rate in different level.
 e.g. wer for word-level, cer for char-level.
 """
-
 import numpy as np
 
+__all__ = ['word_errors', 'char_errors', 'wer', 'cer']
+
 
 def _levenshtein_distance(ref, hyp):
     """Levenshtein distance is a string metric for measuring the difference
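A usage sketch for the utils/checkpoint.py helpers introduced above; the directory name and step number are made up for illustration. save_parameters() writes step-200.pdparams and step-200.pdopt, then appends "model_checkpoint_path:step-200" to the plain-text "checkpoint" record, whose last line load_parameters() parses to find the newest step:

    import paddle
    from utils import checkpoint

    model = paddle.nn.Linear(16, 4)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())

    checkpoint.save_parameters("exp/checkpoints", 200, model, optimizer)

    # With no explicit checkpoint_path, the record file selects the latest step.
    iteration = checkpoint.load_parameters(
        model, optimizer, checkpoint_dir="exp/checkpoints")
    assert iteration == 200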
""" - err = "PaddlePaddle version 1.6 or higher is required, " \ + err = "PaddlePaddle version 2.0.0 or higher is required, " \ "or a suitable develop version is satisfied as well. \n" \ "Please make sure the version is good with your code." \ try: - fluid.require_version('1.6.0') + fluid.require_version('2.0.0') except Exception as e: print(err) sys.exit(1) diff --git a/utils/mp_tools.py b/utils/mp_tools.py new file mode 100644 index 000000000..0daa62af2 --- /dev/null +++ b/utils/mp_tools.py @@ -0,0 +1,32 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import distributed as dist +from functools import wraps + +__all__ = ["rank_zero_only"] + + +def rank_zero_only(func): + rank = dist.get_rank() + + @wraps(func) + def wrapper(*args, **kwargs): + if rank != 0: + return + result = func(*args, **kwargs) + return result + + return wrapper \ No newline at end of file