From 4dc75c40c977e4a0b73de168e3c358aa4dda6b3f Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Tue, 9 Feb 2021 08:12:35 +0000
Subject: [PATCH] add config and train script

---
 data_utils/dataset.py                     |  50 ++++++-----
 data_utils/featurizer/audio_featurizer.py |   8 +-
 examples/tiny/conf/augmentation.config    |   8 ++
 examples/tiny/conf/deepspeech2.yaml       |  39 ++++++++
 examples/tiny/local/run_train.sh          |  42 ++++-----
 model_utils/config.py                     |   2 +-
 model_utils/model.py                      |  36 ++++----
 model_utils/network2_test.py              | 104 ++++++++++++++++++++++
 requirements.txt                          |   2 +
 train.py                                  |  24 ++---
 training/cli.py                           |   5 +-
 training/trainer.py                       |  15 +++-
 12 files changed, 251 insertions(+), 84 deletions(-)
 create mode 100644 examples/tiny/conf/augmentation.config
 create mode 100644 examples/tiny/conf/deepspeech2.yaml
 create mode 100644 model_utils/network2_test.py

diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 151e6ef1c..e9b581d0b 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -109,6 +109,10 @@ class DeepSpeech2Dataset(Dataset):
         """
         return self._speech_featurizer.vocab_list
 
+    @property
+    def feature_size(self):
+        return self._speech_featurizer.feature_size
+
     def _parse_tar(self, file):
         """Parse a tar file to get a tarfile object
         and a map containing tarinfoes
@@ -200,7 +204,7 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         self._sortagrad = sortagrad
         self._shuffle_method = shuffle_method
 
-    def _batch_shuffle(self, manifest, batch_size, clipped=False):
+    def _batch_shuffle(self, indices, batch_size, clipped=False):
         """Put similarly-sized instances into minibatches for better efficiency
         and make a batch-wise shuffle.
 
@@ -210,8 +214,8 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
            for different epochs. Create minibatches.
         4. Shuffle the minibatches.
 
-        :param manifest: Manifest contents. List of dict.
-        :type manifest: list
+        :param indices: indexes. List of int.
+        :type indices: list
         :param batch_size: Batch size. This size is also used for generate
                            a random number for batch shuffle.
         :type batch_size: int
@@ -222,16 +226,16 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         :rtype: list
         """
         rng = np.random.RandomState(self.epoch)
-        manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
-        rng.shuffle(batch_manifest)
-        batch_manifest = [item for batch in batch_manifest for item in batch]
+        batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
+        rng.shuffle(batch_indices)
+        batch_indices = [item for batch in batch_indices for item in batch]
+        assert (clipped == False)
         if not clipped:
-            res_len = len(manifest) - shift_len - len(batch_manifest)
-            batch_manifest.extend(manifest[-res_len:])
-            batch_manifest.extend(manifest[0:shift_len])
-            return batch_manifest
+            res_len = len(indices) - shift_len - len(batch_indices)
+            batch_indices.extend(indices[-res_len:])
+            batch_indices.extend(indices[0:shift_len])
+            return batch_indices
 
     def __iter__(self):
         num_samples = len(self.dataset)
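
For reference, the index-based scheme the new `_batch_shuffle` implements can be written as a standalone sketch (illustrative only, not part of the patch). The `zip(*[iter(...)] * batch_size)` idiom chops the shifted index list into full batches and silently drops the incomplete tail, which is why the tail and the shifted head must be re-appended when `clipped` is False:

    import numpy as np

    def batch_shuffle(indices, batch_size, epoch, clipped=False):
        rng = np.random.RandomState(epoch)          # reproducible per epoch
        shift_len = rng.randint(0, batch_size - 1)  # random head offset
        # group the shifted indices into full batches; drops the remainder
        batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
        rng.shuffle(batches)                        # shuffle whole batches
        shuffled = [idx for batch in batches for idx in batch]
        if not clipped:
            res_len = len(indices) - shift_len - len(shuffled)
            if res_len > 0:  # guard: indices[-0:] would duplicate everything
                shuffled.extend(indices[-res_len:])
            shuffled.extend(indices[0:shift_len])   # re-append shifted head
        return shuffled

E.g. `batch_shuffle(list(range(10)), batch_size=4, epoch=1)` returns all ten indices, reordered batch-wise.
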
@@ -336,7 +340,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         self._sortagrad = sortagrad
         self._shuffle_method = shuffle_method
 
-    def _batch_shuffle(self, manifest, batch_size, clipped=False):
+    def _batch_shuffle(self, indices, batch_size, clipped=False):
         """Put similarly-sized instances into minibatches for better efficiency
         and make a batch-wise shuffle.
 
@@ -346,8 +350,8 @@ class DeepSpeech2BatchSampler(BatchSampler):
            for different epochs. Create minibatches.
         4. Shuffle the minibatches.
 
-        :param manifest: Manifest contents. List of dict.
-        :type manifest: list
+        :param indices: indexes. List of int.
+        :type indices: list
         :param batch_size: Batch size. This size is also used for generate
                            a random number for batch shuffle.
         :type batch_size: int
@@ -358,16 +362,16 @@ class DeepSpeech2BatchSampler(BatchSampler):
         :rtype: list
         """
         rng = np.random.RandomState(self.epoch)
-        manifest.sort(key=lambda x: x["duration"])
         shift_len = rng.randint(0, batch_size - 1)
-        batch_manifest = list(zip(*[iter(manifest[shift_len:])] * batch_size))
-        rng.shuffle(batch_manifest)
-        batch_manifest = [item for batch in batch_manifest for item in batch]
+        batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
+        rng.shuffle(batch_indices)
+        batch_indices = [item for batch in batch_indices for item in batch]
+        assert (clipped == False)
         if not clipped:
-            res_len = len(manifest) - shift_len - len(batch_manifest)
-            batch_manifest.extend(manifest[-res_len:])
-            batch_manifest.extend(manifest[0:shift_len])
-            return batch_manifest
+            res_len = len(indices) - shift_len - len(batch_indices)
+            batch_indices.extend(indices[-res_len:])
+            batch_indices.extend(indices[0:shift_len])
+            return batch_indices
 
     def __iter__(self):
         num_samples = len(self.dataset)
@@ -377,7 +381,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
 
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
-            if self.epoch == 0 and self.sortagrad:
+            if self.epoch == 0 and self._sortagrad:
                 pass
             else:
                 if self._shuffle_method == "batch_shuffle":
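
The `self.sortagrad` -> `self._sortagrad` change above fixes an attribute-name bug: `__init__` only ever stores the flag as `self._sortagrad`. The per-epoch ordering policy this enables, reduced to a sketch (simplified; the real `__iter__` also supports the `batch_shuffle_clipped` method):

    def order_indices(indices, batch_size, epoch, sortagrad, shuffle_method):
        if epoch == 0 and sortagrad:
            # first epoch: keep the duration-sorted order (SortaGrad warm-up)
            return indices
        if shuffle_method == "batch_shuffle":
            return batch_shuffle(indices, batch_size, epoch)  # sketch above
        raise ValueError("unsupported shuffle method: %s" % shuffle_method)
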
diff --git a/data_utils/featurizer/audio_featurizer.py b/data_utils/featurizer/audio_featurizer.py
index ad9901d4a..7e04a03e2 100644
--- a/data_utils/featurizer/audio_featurizer.py
+++ b/data_utils/featurizer/audio_featurizer.py
@@ -103,15 +103,19 @@ class AudioFeaturizer(object):
     @property
     def feature_size(self):
         """audio feature size"""
+        feat_dim = 0
         if self._specgram_type == 'linear':
             fft_point = self._window_ms if self._fft_point is None else self._fft_point
-            return fft_point * (self._target_sample_rate / 1000) / 2 + 1
+            feat_dim = int(fft_point * (self._target_sample_rate / 1000) / 2
+                           + 1)
         elif self._specgram_type == 'mfcc':
             # mfcc,delta, delta-delta
-            return 13 * 3
+            feat_dim = int(13 * 3)
         else:
             raise ValueError("Unknown specgram_type %s. "
                              "Supported values: linear." % self._specgram_type)
+        print('feat_dim:', feat_dim)
+        return feat_dim
 
     def _compute_specgram(self, samples, sample_rate):
         """Extract various audio features."""
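
A quick check of the linear branch above with the values used elsewhere in this patch (`window_ms: 20.0` and `target_sample_rate: 16000` from conf/deepspeech2.yaml): the FFT window covers 20 * 16 = 320 samples, so the spectrogram has 320 / 2 + 1 = 161 frequency bins, matching the `feat_dim = 161` hard-coded in model_utils/network2_test.py below.

    window_ms = 20.0
    target_sample_rate = 16000
    fft_point = window_ms  # self._fft_point is None in the default config
    feat_dim = int(fft_point * (target_sample_rate / 1000) / 2 + 1)
    assert feat_dim == 161
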
diff --git a/examples/tiny/conf/augmentation.config b/examples/tiny/conf/augmentation.config
new file mode 100644
index 000000000..6c24da549
--- /dev/null
+++ b/examples/tiny/conf/augmentation.config
@@ -0,0 +1,8 @@
+[
+    {
+        "type": "shift",
+        "params": {"min_shift_ms": -5,
+                   "max_shift_ms": 5},
+        "prob": 1.0
+    }
+]
diff --git a/examples/tiny/conf/deepspeech2.yaml b/examples/tiny/conf/deepspeech2.yaml
new file mode 100644
index 000000000..3ec051bf9
--- /dev/null
+++ b/examples/tiny/conf/deepspeech2.yaml
@@ -0,0 +1,39 @@
+# https://yaml.org/type/float.html
+data:
+  train_manifest: data/manifest.tiny
+  dev_manifest: data/manifest.tiny
+  test_manifest: data/manifest.tiny
+  mean_std_filepath: data/mean_std.npz
+  vocab_filepath: data/vocab.txt
+  augmentation_config: conf/augmentation.config
+  batch_size: 4
+  max_duration: 27.0
+  min_duration: 0.0
+  specgram_type: linear
+  target_sample_rate: 16000
+  max_freq: None
+  n_fft: None
+  stride_ms: 10.0
+  window_ms: 20.0
+  use_dB_normalization: True
+  target_dB: -20
+  random_seed: 0
+  keep_transcription_text: False
+  sortagrad: True
+  shuffle_method: batch_shuffle
+  num_workers: 0
+model:
+  num_conv_layers: 2
+  num_rnn_layers: 3
+  rnn_layer_size: 2048
+  use_gru: False
+  share_rnn_weights: True
+training:
+  n_epoch: 20
+  lr: 1e-5
+  weight_decay: 1e-06
+  global_grad_clip: 400.0
+  max_iteration: 500000
+  plot_interval: 1000
+  save_interval: 1000
+  valid_interval: 1000
diff --git a/examples/tiny/local/run_train.sh b/examples/tiny/local/run_train.sh
index de9dcbd74..7880a4bba 100644
--- a/examples/tiny/local/run_train.sh
+++ b/examples/tiny/local/run_train.sh
@@ -3,33 +3,23 @@
 # train model
 # if you wish to resume from an exists model, uncomment --init_from_pretrained_model
 export FLAGS_sync_nccl_allreduce=0
-CUDA_VISIBLE_DEVICES=0,1,2,3 \
+
+#CUDA_VISIBLE_DEVICES=0,1,2,3 \
+#python3 -u ${MAIN_ROOT}/train.py \
+#--num_iter_print=1 \
+#--save_epoch=1 \
+#--num_samples=64 \
+#--test_off=False \
+#--is_local=True \
+#--output_model_dir="./checkpoints/" \
+#--shuffle_method="batch_shuffle_clipped" \
+
+#CUDA_VISIBLE_DEVICES=0,1,2,3 \
+CUDA_VISIBLE_DEVICES=1,2,3 \
 python3 -u ${MAIN_ROOT}/train.py \
---batch_size=4 \
---num_epoch=20 \
---num_conv_layers=2 \
---num_rnn_layers=3 \
---rnn_layer_size=2048 \
---num_iter_print=1 \
---save_epoch=1 \
---num_samples=64 \
---learning_rate=1e-5 \
---max_duration=27.0 \
---min_duration=0.0 \
---test_off=False \
---use_sortagrad=True \
---use_gru=False \
---use_gpu=True \
---is_local=True \
---share_rnn_weights=True \
---train_manifest="data/manifest.tiny" \
---dev_manifest="data/manifest.tiny" \
---mean_std_path="data/mean_std.npz" \
---vocab_path="data/vocab.txt" \
---output_model_dir="./checkpoints/" \
---augment_conf_path="${MAIN_ROOT}/conf/augmentation.config" \
---specgram_type="linear" \
---shuffle_method="batch_shuffle_clipped" \
+--nproc 1 \
+--config conf/deepspeech2.yaml \
+--output ckpt
 
 if [ $? -ne 0 ]; then
     echo "Failed in training!"
diff --git a/model_utils/config.py b/model_utils/config.py
index fffec423e..ead2dbd5f 100644
--- a/model_utils/config.py
+++ b/model_utils/config.py
@@ -22,7 +22,7 @@ _C.data = CN(
         test_manifest="",
         vocab_filepath="",
         mean_std_filepath="",
-        augmentation_config='{}',
+        augmentation_config="",
         max_duration=float('inf'),
         min_duration=0.0,
         stride_ms=10.0,  # ms
diff --git a/model_utils/model.py b/model_utils/model.py
index ecd8a3d5e..d845eb3c2 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -13,25 +13,26 @@
 # limitations under the License.
 """Contains DeepSpeech2 model."""
 
+import io
 import sys
 import os
 import time
-import logging
-import gzip
-import copy
-import inspect
-import collections
-import multiprocessing
 import numpy as np
-from distutils.dir_util import mkpath
-import paddle.fluid as fluid
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+
+from utils import mp_tools
 from training import Trainer
 
 from model_utils.network import DeepSpeech2
 from model_utils.network import DeepSpeech2Loss
-from model_utils.network import SpeechCollator
+
+from data_utils.dataset import SpeechCollator
+from data_utils.dataset import DeepSpeech2Dataset
+from data_utils.dataset import DeepSpeech2DistributedBatchSampler
+from data_utils.dataset import DeepSpeech2BatchSampler
 
 from decoders.swig_wrapper import Scorer
 from decoders.swig_wrapper import ctc_greedy_decoder
@@ -39,7 +40,8 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 
 
 class DeepSpeech2Trainer(Trainer):
-    def __init__(self):
+    def __init__(self, config, args):
+        super().__init__(config, args)
         self._ext_scorer = None
 
     def setup_dataloader(self):
@@ -49,7 +51,9 @@ class DeepSpeech2Trainer(Trainer):
             config.data.train_manifest,
             config.data.vocab_filepath,
             config.data.mean_std_filepath,
-            augmentation_config=config.data.augmentation_config,
+            augmentation_config=io.open(
+                config.data.augmentation_config, mode='r',
+                encoding='utf8').read(),
             max_duration=config.data.max_duration,
             min_duration=config.data.min_duration,
             stride_ms=config.data.stride_ms,
@@ -67,7 +71,7 @@ class DeepSpeech2Trainer(Trainer):
             config.data.dev_manifest,
             config.data.vocab_filepath,
             config.data.mean_std_filepath,
-            augmentation_config=config.data.augmentation_config,
+            augmentation_config="{}",
             max_duration=config.data.max_duration,
             min_duration=config.data.min_duration,
             stride_ms=config.data.stride_ms,
@@ -117,8 +121,8 @@ class DeepSpeech2Trainer(Trainer):
     def setup_model(self):
         config = self.config
         model = DeepSpeech2(
-            feat_size=self.train_loader.feature_size,
-            dict_size=self.train_loader.vocab_size,
+            feat_size=self.train_loader.dataset.feature_size,
+            dict_size=self.train_loader.dataset.vocab_size,
             num_conv_layers=config.model.num_conv_layers,
             num_rnn_layers=config.model.num_rnn_layers,
             rnn_size=config.model.rnn_layer_size,
@@ -133,11 +137,11 @@ class DeepSpeech2Trainer(Trainer):
         optimizer = paddle.optimizer.Adam(
             learning_rate=config.training.lr,
             parameters=model.parameters(),
-            weight_decay=paddle.regulaerizer.L2Decay(
+            weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
 
-        criterion = DeepSpeech2Loss(self.train_loader.vocab_size)
+        criterion = DeepSpeech2Loss(self.train_loader.dataset.vocab_size)
 
         self.model = model
         self.optimizer = optimizer
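
Put together, the refactored trainer is now driven from a YAML config rather than a long flag list. A minimal single-process driver, mirroring what train.py further down does (a sketch; `setup()` and `run()` are the assumed Trainer entry points, which this patch does not show):

    from model_utils.config import get_cfg_defaults
    from model_utils.model import DeepSpeech2Trainer as Trainer
    from training.cli import default_argument_parser

    args = default_argument_parser().parse_args(
        ['--config', 'conf/deepspeech2.yaml', '--output', 'ckpt'])
    config = get_cfg_defaults()
    config.merge_from_file(args.config)
    config.freeze()

    exp = Trainer(config, args)
    exp.setup()  # assumed entry point
    exp.run()    # assumed entry point
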
diff --git a/model_utils/network2_test.py b/model_utils/network2_test.py
new file mode 100644
index 000000000..0064be21d
--- /dev/null
+++ b/model_utils/network2_test.py
@@ -0,0 +1,104 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from network2 import DeepSpeech2
+import paddle
+import numpy as np
+
+if __name__ == '__main__':
+
+    batch_size = 2
+    feat_dim = 161
+    max_len = 100
+    audio = np.random.randn(batch_size, feat_dim, max_len)
+    audio_len = np.random.randint(100, size=batch_size, dtype='int32')
+    audio_len[-1] = 100
+    text = np.array([[1, 2], [1, 2]], dtype='int32')
+    text_len = np.array([2] * batch_size, dtype='int32')
+
+    place = paddle.CUDAPinnedPlace()
+    audio = paddle.to_tensor(
+        audio, dtype='float32', place=place, stop_gradient=True)
+    audio_len = paddle.to_tensor(
+        audio_len, dtype='int64', place=place, stop_gradient=True)
+    text = paddle.to_tensor(
+        text, dtype='int32', place=place, stop_gradient=True)
+    text_len = paddle.to_tensor(
+        text_len, dtype='int64', place=place, stop_gradient=True)
+
+    print(audio.shape)
+    print(audio_len.shape)
+    print(text.shape)
+    print(text_len.shape)
+    print("-----------------")
+
+    model = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=10,
+        num_conv_layers=2,
+        num_rnn_layers=3,
+        rnn_size=1024,
+        use_gru=False,
+        share_rnn_weights=False, )
+    probs = model(audio, text, audio_len, text_len)
+    print('probs.shape', probs.shape)
+    print("-----------------")
+
+    model2 = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=10,
+        num_conv_layers=2,
+        num_rnn_layers=3,
+        rnn_size=1024,
+        use_gru=True,
+        share_rnn_weights=False, )
+    probs = model2(audio, text, audio_len, text_len)
+    print('probs.shape', probs.shape)
+    print("-----------------")
+
+    model3 = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=10,
+        num_conv_layers=2,
+        num_rnn_layers=3,
+        rnn_size=1024,
+        use_gru=False,
+        share_rnn_weights=True, )
+    probs = model3(audio, text, audio_len, text_len)
+    print('probs.shape', probs.shape)
+    print("-----------------")
+
+    model4 = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=10,
+        num_conv_layers=2,
+        num_rnn_layers=3,
+        rnn_size=1024,
+        use_gru=True,
+        share_rnn_weights=True, )
+    probs = model4(audio, text, audio_len, text_len)
+    print('probs.shape', probs.shape)
+    print("-----------------")
+
+    model5 = DeepSpeech2(
+        feat_size=feat_dim,
+        dict_size=10,
+        num_conv_layers=2,
+        num_rnn_layers=3,
+        rnn_size=1024,
+        use_gru=False,
+        share_rnn_weights=False, )
+    probs = model5(audio, text, audio_len, text_len)
+    print('probs.shape', probs.shape)
+    print("-----------------")
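
The five constructions in the new test file smoke-test the `use_gru` x `share_rnn_weights` grid (model5 repeats model1's combination). An equivalent loop, reusing the tensors built above (sketch only):

    import itertools

    for use_gru, share_rnn_weights in itertools.product([False, True], repeat=2):
        model = DeepSpeech2(
            feat_size=feat_dim,
            dict_size=10,
            num_conv_layers=2,
            num_rnn_layers=3,
            rnn_size=1024,
            use_gru=use_gru,
            share_rnn_weights=share_rnn_weights)
        probs = model(audio, text, audio_len, text_len)
        print('probs.shape', probs.shape)
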
"""Trainer for DeepSpeech2 model.""" +import io +import logging import argparse import functools -import io -from utils.model_check import check_cuda, check_version +from paddle import distributed as dist + from utils.utility import print_arguments from training.cli import default_argument_parser from model_utils.config import get_cfg_defaults from model_utils.model import DeepSpeech2Trainer as Trainer -logging.basicConfig( - format='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') - def main_sp(config, args): exp = Trainer(config, args) @@ -35,26 +34,27 @@ def main_sp(config, args): def main(config, args): - # check if set use_gpu=True in paddlepaddle cpu version - check_cuda(args.device == 'gpu') - # check if paddlepaddle version is satisfied - check_version() - if args.nprocs > 1 and args.device == "gpu": + if args.device == "gpu" and args.nprocs > 1: dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs) else: main_sp(config, args) if __name__ == "__main__": - config = get_cfg_defaults() parser = default_argument_parser() args = parser.parse_args() + print_arguments(args) + + # https://yaml.org/type/float.html + config = get_cfg_defaults() if args.config: config.merge_from_file(args.config) if args.opts: config.merge_from_list(args.opts) config.freeze() print(config) - print_arguments(args) + if args.dump_config: + with open(args.dump_config, 'w') as f: + print(config, file=f) main(config, args) diff --git a/training/cli.py b/training/cli.py index c4a87a7f0..e0ebfc7de 100644 --- a/training/cli.py +++ b/training/cli.py @@ -47,6 +47,7 @@ def default_argument_parser(): # yapf: disable # data and output parser.add_argument("--config", metavar="FILE", help="path of the config file to overwrite to default config with.") + parser.add_argument("--dump-config", metavar="FILE", help="dump config to yaml file.") parser.add_argument("--data", metavar="DATA_DIR", help="path to the datatset.") parser.add_argument("--output", metavar="OUTPUT_DIR", help="path to save checkpoint and logs.") @@ -54,11 +55,11 @@ def default_argument_parser(): parser.add_argument("--checkpoint_path", type=str, help="path of the checkpoint to load") # running - parser.add_argument("--device", type=str, choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") + parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") # overwrite extra config and default config parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") # yapd: enable - return parser \ No newline at end of file + return parser diff --git a/training/trainer.py b/training/trainer.py index d4173d5ec..e1b898df5 100644 --- a/training/trainer.py +++ b/training/trainer.py @@ -20,6 +20,7 @@ from collections import defaultdict import paddle from paddle import distributed as dist +from paddle.distributed.utils import get_gpus from tensorboardX import SummaryWriter from utils import checkpoint @@ -238,9 +239,19 @@ class Trainer(): """ logger = logging.getLogger(__name__) logger.setLevel("INFO") - logger.addHandler(logging.StreamHandler()) + + formatter = logging.Formatter( + fmt='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s', + datefmt='%Y/%m/%d %H:%M:%S') + + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + 
diff --git a/training/trainer.py b/training/trainer.py
index d4173d5ec..e1b898df5 100644
--- a/training/trainer.py
+++ b/training/trainer.py
@@ -20,6 +20,7 @@ from collections import defaultdict
 
 import paddle
 from paddle import distributed as dist
+from paddle.distributed.utils import get_gpus
 from tensorboardX import SummaryWriter
 
 from utils import checkpoint
@@ -238,9 +239,19 @@ class Trainer():
         """
         logger = logging.getLogger(__name__)
         logger.setLevel("INFO")
-        logger.addHandler(logging.StreamHandler())
+
+        formatter = logging.Formatter(
+            fmt='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
+            datefmt='%Y/%m/%d %H:%M:%S')
+
+        stream_handler = logging.StreamHandler()
+        stream_handler.setFormatter(formatter)
+        logger.addHandler(stream_handler)
+
         log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
-        logger.addHandler(logging.FileHandler(str(log_file)))
+        file_handler = logging.FileHandler(str(log_file))
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
 
         self.logger = logger
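
The format string here is the one removed from train.py's old `logging.basicConfig` call, so every worker now logs identically to both the console and its own `worker_<rank>.log` file. A standalone check of the resulting format (output line is illustrative):

    import logging

    formatter = logging.Formatter(
        fmt='[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S')
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger = logging.getLogger('demo')
    logger.setLevel('INFO')
    logger.addHandler(handler)
    logger.info('training...')
    # e.g. [INFO 2021/02/09 08:12:35 demo.py:12] training...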