diff --git a/.gitignore b/.gitignore index dde3895fc..0bd3d362c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ .DS_Store *.pyc +tools/venv +dataset diff --git a/data_utils/augmentor/augmentation.py b/data_utils/augmentor/augmentation.py index 349cdc564..f36d993e1 100644 --- a/data_utils/augmentor/augmentation.py +++ b/data_utils/augmentor/augmentation.py @@ -25,7 +25,7 @@ from data_utils.augmentor.online_bayesian_normalization import \ OnlineBayesianNormalizationAugmentor -class AugmentationPipeline(object): +class AugmentationPipeline(): """Build a pre-processing pipeline with various augmentation models.Such a data augmentation pipeline is oftern leveraged to augment the training samples to make the model invariant to certain types of perturbations in the diff --git a/data_utils/augmentor/base.py b/data_utils/augmentor/base.py index 5b80be2fe..0f7826cdf 100644 --- a/data_utils/augmentor/base.py +++ b/data_utils/augmentor/base.py @@ -16,7 +16,7 @@ from abc import ABCMeta, abstractmethod -class AugmentorBase(object): +class AugmentorBase(): """Abstract base class for augmentation model (augmentor) class. All augmentor classes should inherit from this class, and implement the following abstract methods. diff --git a/data_utils/dataset.py b/data_utils/dataset.py index 71d3a61e0..151e6ef1c 100644 --- a/data_utils/dataset.py +++ b/data_utils/dataset.py @@ -31,8 +31,10 @@ from data_utils.speech import SpeechSegment from data_utils.normalizer import FeatureNormalizer __all__ = [ - "DeepSpeech2Dataset", "DeepSpeech2DistributedBatchSampler", - "DeepSpeech2BatchSampler" + "DeepSpeech2Dataset", + "DeepSpeech2DistributedBatchSampler", + "DeepSpeech2BatchSampler", + "SpeechCollator", ] @@ -46,9 +48,12 @@ class DeepSpeech2Dataset(Dataset): min_duration=0.0, stride_ms=10.0, window_ms=20.0, + n_fft=None, max_freq=None, + target_sample_rate=16000, specgram_type='linear', use_dB_normalization=True, + target_dB=-20, random_seed=0, keep_transcription_text=False): super().__init__() @@ -63,8 +68,11 @@ class DeepSpeech2Dataset(Dataset): specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, + n_fft=n_fft, max_freq=max_freq, - use_dB_normalization=use_dB_normalization) + target_sample_rate=target_sample_rate, + use_dB_normalization=use_dB_normalization, + target_dB=target_dB) self._rng = random.Random(random_seed) self._keep_transcription_text = keep_transcription_text # for caching tar files info @@ -459,6 +467,51 @@ class DeepSpeech2BatchSampler(BatchSampler): self.epoch = epoch +class SpeechCollator(): + def __init__(self, padding_to=-1): + """ + Padding audio features with zeros to make them have the same shape (or + a user-defined shape) within one bach. + + If ``padding_to`` is -1, the maximun shape in the batch will be used + as the target shape for padding. Otherwise, `padding_to` will be the + target shape (only refers to the second axis). + """ + self._padding_to = padding_to + + def __call__(self, batch): + new_batch = [] + # get target shape + max_length = max([audio.shape[1] for audio, _ in batch]) + if self._padding_to != -1: + if self._padding_to < max_length: + raise ValueError("If padding_to is not -1, it should be larger " + "than any instance's shape in the batch") + max_length = self._padding_to + max_text_length = max([len(text) for _, text in batch]) + # padding + padded_audios = [] + audio_lens = [] + texts, text_lens = [], [] + for audio, text in batch: + # audio + padded_audio = np.zeros([audio.shape[0], max_length]) + padded_audio[:, :audio.shape[1]] = audio + padded_audios.append(padded_audio) + audio_lens.append(audio.shape[1]) + # text + padded_text = np.zeros([max_text_length]) + padded_text[:len(text)] = text + texts.append(padded_text) + text_lens.append(len(text)) + + padded_audios = np.array(padded_audios).astype('float32') + audio_lens = np.array(audio_lens).astype('int64') + texts = np.array(texts).astype('int32') + text_lens = np.array(text_lens).astype('int64') + return padded_audios, texts, audio_lens, text_lens + + def create_dataloader(manifest_path, vocab_filepath, mean_std_filepath, diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py index 7e336969d..ad9901d4a 100644 --- a/data_utils/featurizer/audio_featurizer.py +++ b/data_utils/featurizer/audio_featurizer.py @@ -52,6 +52,7 @@ class AudioFeaturizer(object): specgram_type='linear', stride_ms=10.0, window_ms=20.0, + n_fft=None, max_freq=None, target_sample_rate=16000, use_dB_normalization=True, @@ -63,7 +64,7 @@ class AudioFeaturizer(object): self._target_sample_rate = target_sample_rate self._use_dB_normalization = use_dB_normalization self._target_dB = target_dB - self._fft_point = None + self._fft_point = n_fft def featurize(self, audio_segment, diff --git a/data_utils/featurizer/speech_featurizer.py b/data_utils/featurizer/speech_featurizer.py index 333a40cd1..1bbf2bf58 100644 --- a/data_utils/featurizer/speech_featurizer.py +++ b/data_utils/featurizer/speech_featurizer.py @@ -56,6 +56,7 @@ class SpeechFeaturizer(object): specgram_type='linear', stride_ms=10.0, window_ms=20.0, + n_fft=None, max_freq=None, target_sample_rate=16000, use_dB_normalization=True, @@ -64,6 +65,7 @@ class SpeechFeaturizer(object): specgram_type=specgram_type, stride_ms=stride_ms, window_ms=window_ms, + n_fft=n_fft, max_freq=max_freq, target_sample_rate=target_sample_rate, use_dB_normalization=use_dB_normalization, diff --git a/model_utils/config.py b/model_utils/config.py new file mode 100644 index 000000000..fffec423e --- /dev/null +++ b/model_utils/config.py @@ -0,0 +1,70 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from yacs.config import CfgNode as CN + +_C = CN() +_C.data = CN( + dict( + train_manifest="", + dev_manifest="", + test_manifest="", + vocab_filepath="", + mean_std_filepath="", + augmentation_config='{}', + max_duration=float('inf'), + min_duration=0.0, + stride_ms=10.0, # ms + window_ms=20.0, # ms + n_fft=None, # fft points + max_freq=None, # None for samplerate/2 + specgram_type='linear', # 'linear', 'mfcc' + target_sample_rate=16000, # sample rate + use_dB_normalization=True, + target_dB=-20, + random_seed=0, + keep_transcription_text=False, + batch_size=32, # batch size + num_workers=0, # data loader workers + sortagrad=False, # sorted in first epoch when True + shuffle_method="batch_shuffle", # 'batch_shuffle', 'instance_shuffle' + )) + +_C.model = CN( + dict( + num_conv_layers=2, #Number of stacking convolution layers. + num_rnn_layers=3, #Number of stacking RNN layers. + rnn_layer_size=1024, #RNN layer size (number of RNN cells). + use_gru=True, #Use gru if set True. Use simple rnn if set False. + share_rnn_weights=False #Whether to share input-hidden weights between forward and backward directional RNNs.Notice that for GRU, weight sharing is not supported. + )) + +_C.training = CN( + dict( + lr=5e-4, # learning rate + weight_decay=1e-6, # the coeff of weight decay + global_grad_clip=400.0, # the global norm clip + plot_interval=1000, # plot attention and spectrogram by step + valid_interval=1000, # validation by step + save_interval=1000, # checkpoint by step + max_iteration=500000, # max iteration to train by step + n_epoch=50, # train epochs + )) + + +def get_cfg_defaults(): + """Get a yacs CfgNode object with default values for my_project.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + return _C.clone() diff --git a/model_utils/model.py b/model_utils/model.py index dd621b053..ecd8a3d5e 100644 --- a/model_utils/model.py +++ b/model_utils/model.py @@ -28,61 +28,15 @@ from distutils.dir_util import mkpath import paddle.fluid as fluid from training import Trainer + from model_utils.network import DeepSpeech2 from model_utils.network import DeepSpeech2Loss +from model_utils.network import SpeechCollator from decoders.swig_wrapper import Scorer from decoders.swig_wrapper import ctc_greedy_decoder from decoders.swig_wrapper import ctc_beam_search_decoder_batch -logging.basicConfig( - format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') - - -class SpeechCollator(): - def __init__(self, padding_to=-1): - """ - Padding audio features with zeros to make them have the same shape (or - a user-defined shape) within one bach. - - If ``padding_to`` is -1, the maximun shape in the batch will be used - as the target shape for padding. Otherwise, `padding_to` will be the - target shape (only refers to the second axis). - """ - self._padding_to = padding_to - - def __call__(self, batch): - new_batch = [] - # get target shape - max_length = max([audio.shape[1] for audio, _ in batch]) - if self._padding_to != -1: - if self._padding_to < max_length: - raise ValueError("If padding_to is not -1, it should be larger " - "than any instance's shape in the batch") - max_length = self._padding_to - max_text_length = max([len(text) for _, text in batch]) - # padding - padded_audios = [] - audio_lens = [] - texts, text_lens = [], [] - for audio, text in batch: - # audio - padded_audio = np.zeros([audio.shape[0], max_length]) - padded_audio[:, :audio.shape[1]] = audio - padded_audios.append(padded_audio) - audio_lens.append(audio.shape[1]) - # text - padded_text = np.zeros([max_text_length]) - padded_text[:len(text)] = text - texts.append(padded_text) - text_lens.append(len(text)) - - padded_audios = np.array(padded_audios).astype('float32') - audio_lens = np.array(audio_lens).astype('int64') - texts = np.array(texts).astype('int32') - text_lens = np.array(text_lens).astype('int64') - return padded_audios, texts, audio_lens, text_lens - class DeepSpeech2Trainer(Trainer): def __init__(self): @@ -92,7 +46,7 @@ class DeepSpeech2Trainer(Trainer): config = self.config train_dataset = DeepSpeech2Dataset( - config.data.train_manifest_path, + config.data.train_manifest, config.data.vocab_filepath, config.data.mean_std_filepath, augmentation_config=config.data.augmentation_config, @@ -100,14 +54,17 @@ class DeepSpeech2Trainer(Trainer): min_duration=config.data.min_duration, stride_ms=config.data.stride_ms, window_ms=config.data.window_ms, + n_fft=config.data.n_fft, max_freq=config.data.max_freq, + target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, use_dB_normalization=config.data.use_dB_normalization, + target_dB=config.data.target_dB, random_seed=config.data.random_seed, keep_transcription_text=False) dev_dataset = DeepSpeech2Dataset( - config.data.dev_manifest_path, + config.data.dev_manifest, config.data.vocab_filepath, config.data.mean_std_filepath, augmentation_config=config.data.augmentation_config, @@ -115,9 +72,12 @@ class DeepSpeech2Trainer(Trainer): min_duration=config.data.min_duration, stride_ms=config.data.stride_ms, window_ms=config.data.window_ms, + n_fft=config.data.n_fft, max_freq=config.data.max_freq, + target_sample_rate=config.data.target_sample_rate, specgram_type=config.data.specgram_type, use_dB_normalization=config.data.use_dB_normalization, + target_dB=config.data.target_dB, random_seed=config.data.random_seed, keep_transcription_text=False) @@ -167,14 +127,15 @@ class DeepSpeech2Trainer(Trainer): if self.parallel: model = paddle.DataParallel(model) - grad_clip = paddle.nn.ClipGradByGlobalNorm(config.training.grad_clip) + grad_clip = paddle.nn.ClipGradByGlobalNorm( + config.training.global_grad_clip) optimizer = paddle.optimizer.Adam( learning_rate=config.training.lr, parameters=model.parameters(), weight_decay=paddle.regulaerizer.L2Decay( config.training.weight_decay), - grad_clip=grad_clip, ) + grad_clip=grad_clip) criterion = DeepSpeech2Loss(self.train_loader.vocab_size) @@ -255,7 +216,7 @@ class DeepSpeech2Trainer(Trainer): """ self.model.eval() audio, text, audio_len, text_len = infer_data - logits, probs = self.model.predict(audio, audio_len) + _, probs = self.model.predict(audio, audio_len) return probs def decode_batch_greedy(self, probs_split, vocab_list): diff --git a/train.py b/train.py index 067f6d786..f06c7627d 100644 --- a/train.py +++ b/train.py @@ -16,137 +16,45 @@ import argparse import functools import io -from model_utils.model import DeepSpeech2Model -from model_utils.model_check import check_cuda, check_version -from data_utils.data import DataGenerator -from utils.utility import add_arguments, print_arguments -import paddle.fluid as fluid +from utils.model_check import check_cuda, check_version +from utils.utility import print_arguments +from training.cli import default_argument_parser -parser = argparse.ArgumentParser(description=__doc__) -add_arg = functools.partial(add_arguments, argparser=parser) -# yapf: disable -add_arg('batch_size', int, 256, "Minibatch size.") -add_arg('num_epoch', int, 200, "# of training epochs.") -add_arg('num_conv_layers', int, 2, "# of convolution layers.") -add_arg('num_rnn_layers', int, 3, "# of recurrent layers.") -add_arg('rnn_layer_size', int, 2048, "# of recurrent cells per layer.") -add_arg('num_iter_print', int, 100, "Every # batch for printing " - "train cost.") -add_arg('save_epoch', int, 10, "# Every # batch for save checkpoint and modle params ") -add_arg('num_samples', int, 10000, "The num of train samples.") -add_arg('learning_rate', float, 5e-4, "Learning rate.") -add_arg('max_duration', float, 27.0, "Longest audio duration allowed.") -add_arg('min_duration', float, 0.0, "Shortest audio duration allowed.") -add_arg('test_off', bool, False, "Turn off testing.") -add_arg('use_sortagrad', bool, True, "Use SortaGrad or not.") -add_arg('use_gpu', bool, True, "Use GPU or not.") -add_arg('use_gru', bool, False, "Use GRUs instead of simple RNNs.") -add_arg('is_local', bool, True, "Use pserver or not.") -add_arg('share_rnn_weights',bool, True, "Share input-hidden weights across " - "bi-directional RNNs. Not for GRU.") -add_arg('init_from_pretrained_model',str, - None, - "If None, the training starts from scratch, " - "otherwise, it resumes from the pre-trained model.") +from model_utils.config import get_cfg_defaults +from model_utils.model import DeepSpeech2Trainer as Trainer -add_arg('train_manifest', str, - 'data/librispeech/manifest.train', - "Filepath of train manifest.") -add_arg('dev_manifest', str, - 'data/librispeech/manifest.dev-clean', - "Filepath of validation manifest.") -add_arg('mean_std_path', str, - 'data/librispeech/mean_std.npz', - "Filepath of normalizer's mean & std.") -add_arg('vocab_path', str, - 'data/librispeech/vocab.txt', - "Filepath of vocabulary.") -add_arg('output_model_dir', str, - "./checkpoints/libri", - "Directory for saving checkpoints.") -add_arg('augment_conf_path',str, - 'conf/augmentation.config', - "Filepath of augmentation configuration file (json-format).") -add_arg('specgram_type', str, - 'linear', - "Audio feature type. Options: linear, mfcc.", - choices=['linear', 'mfcc']) -add_arg('shuffle_method', str, - 'batch_shuffle_clipped', - "Shuffle method.", - choices=['instance_shuffle', 'batch_shuffle', 'batch_shuffle_clipped']) -# yapf: disable -args = parser.parse_args() +logging.basicConfig( + format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') -def train(): - """DeepSpeech2 training.""" +def main_sp(config, args): + exp = Trainer(config, args) + exp.setup() + exp.run() + +def main(config, args): # check if set use_gpu=True in paddlepaddle cpu version - check_cuda(args.use_gpu) + check_cuda(args.device == 'gpu') # check if paddlepaddle version is satisfied check_version() - - if args.use_gpu: - place = fluid.CUDAPlace(0) + if args.nprocs > 1 and args.device == "gpu": + dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: - place = fluid.CPUPlace() - - train_generator = DataGenerator( - vocab_filepath=args.vocab_path, - mean_std_filepath=args.mean_std_path, - augmentation_config=io.open(args.augment_conf_path, mode='r', encoding='utf8').read(), - max_duration=args.max_duration, - min_duration=args.min_duration, - specgram_type=args.specgram_type, - place=place) - dev_generator = DataGenerator( - vocab_filepath=args.vocab_path, - mean_std_filepath=args.mean_std_path, - augmentation_config="{}", - specgram_type=args.specgram_type, - place = place) - train_batch_reader = train_generator.batch_reader_creator( - manifest_path=args.train_manifest, - batch_size=args.batch_size, - sortagrad=args.use_sortagrad if args.init_from_pretrained_model is None else False, - shuffle_method=args.shuffle_method) - dev_batch_reader = dev_generator.batch_reader_creator( - manifest_path=args.dev_manifest, - batch_size=args.batch_size, - sortagrad=False, - shuffle_method=None) - - ds2_model = DeepSpeech2Model( - vocab_size=train_generator.vocab_size, - num_conv_layers=args.num_conv_layers, - num_rnn_layers=args.num_rnn_layers, - rnn_layer_size=args.rnn_layer_size, - use_gru=args.use_gru, - share_rnn_weights=args.share_rnn_weights, - place=place, - init_from_pretrained_model=args.init_from_pretrained_model, - output_model_dir=args.output_model_dir) - - ds2_model.train( - train_batch_reader=train_batch_reader, - dev_batch_reader=dev_batch_reader, - feeding_dict=train_generator.feeding, - learning_rate=args.learning_rate, - gradient_clipping=400, - batch_size=args.batch_size, - num_samples=args.num_samples, - num_epoch=args.num_epoch, - save_epoch=args.save_epoch, - num_iterations_print=args.num_iter_print, - test_off=args.test_off) - - -def main(): + main_sp(config, args) + + +if __name__ == "__main__": + config = get_cfg_defaults() + parser = default_argument_parser() + args = parser.parse_args() + if args.config: + config.merge_from_file(args.config) + if args.opts: + config.merge_from_list(args.opts) + config.freeze() + print(config) print_arguments(args) - train() - -if __name__ == '__main__': - main() + main(config, args) diff --git a/training/cli.py b/training/cli.py new file mode 100644 index 000000000..c4a87a7f0 --- /dev/null +++ b/training/cli.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + + +def default_argument_parser(): + r"""A simple yet genral argument parser for experiments with parakeet. + + This is used in examples with parakeet. And it is intended to be used by + other experiments with parakeet. It requires a minimal set of command line + arguments to start a training script. + + The ``--config`` and ``--opts`` are used for overwrite the deault + configuration. + + The ``--data`` and ``--output`` specifies the data path and output path. + Resuming training from existing progress at the output directory is the + intended default behavior. + + The ``--checkpoint_path`` specifies the checkpoint to load from. + + The ``--device`` and ``--nprocs`` specifies how to run the training. + + + See Also + -------- + parakeet.training.experiment + Returns + ------- + argparse.ArgumentParser + the parser + """ + parser = argparse.ArgumentParser() + + # yapf: disable + # data and output + parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") + parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.") + parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") + + # load from saved checkpoint + parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") + + # running + parser.add_argument("--device", type=str, choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") + parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") + + # overwrite extra config and default config + parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") + # yapd: enable + + return parser \ No newline at end of file