From a94fc3f6edc334995180f9494dd0c7b575a88a8c Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Sat, 20 Feb 2021 10:07:12 +0000
Subject: [PATCH] fix dataset batch shuffle and add batch sampler log print model parameter

---
 data_utils/dataset.py                  | 109 ++++++-------------------
 examples/aishell/conf/deepspeech2.yaml |   2 +-
 examples/tiny/local/run_train.sh       |   2 +-
 model_utils/model.py                   |  27 +++---
 training/trainer.py                    |  16 ++--
 5 files changed, 43 insertions(+), 113 deletions(-)

diff --git a/data_utils/dataset.py b/data_utils/dataset.py
index 4a01b7298..a2c4fdcb5 100644
--- a/data_utils/dataset.py
+++ b/data_utils/dataset.py
@@ -17,13 +17,15 @@ import random
 import tarfile
 import logging
 import numpy as np
+from collections import namedtuple
+from functools import partial
+
 import paddle
 from paddle.io import Dataset
 from paddle.io import DataLoader
 from paddle.io import BatchSampler
 from paddle.io import DistributedBatchSampler
-from collections import namedtuple
-from functools import partial
+from paddle import distributed as dist
 
 from data_utils.utility import read_manifest
 from data_utils.augmentor.augmentation import AugmentationPipeline
@@ -229,8 +231,7 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         :rtype: list
         """
         rng = np.random.RandomState(self.epoch)
-        # must shift at leat by one
-        shift_len = rng.randint(1, batch_size - 1)
+        shift_len = rng.randint(0, batch_size - 1)
         batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
         rng.shuffle(batch_indices)
         batch_indices = [item for batch in batch_indices for item in batch]
@@ -255,8 +256,13 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
             if self.epoch == 0 and self._sortagrad:
-                logger.info(f'dataset sortagrad! epoch {self.epoch}')
+                logger.info(
+                    f'rank: {dist.get_rank()} dataset sortagrad! epoch {self.epoch}'
+                )
             else:
+                logger.info(
+                    f'rank: {dist.get_rank()} dataset shuffle! epoch {self.epoch}'
+                )
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
                         indices, self.batch_size, clipped=False)
@@ -268,7 +274,6 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
             assert len(
                 indices
             ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"
-        self.epoch += 1
 
         # subsample
         def _get_indices_by_batch_size(indices):
@@ -298,6 +303,8 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         for idx in _sample_iter:
             batch_indices.append(idx)
             if len(batch_indices) == self.batch_size:
+                logger.info(
+                    f"rank: {dist.get_rank()} batch index: {batch_indices} ")
                 yield batch_indices
                 batch_indices = []
         if not self.drop_last and len(batch_indices) > 0:
@@ -316,9 +323,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
                  shuffle=False,
                  drop_last=False,
                  sortagrad=False,
-                 shuffle_method="batch_shuffle",
-                 num_replicas=1,
-                 rank=0):
+                 shuffle_method="batch_shuffle"):
         self.dataset = dataset
 
         assert isinstance(batch_size, int) and batch_size > 0, \
@@ -330,24 +335,10 @@ class DeepSpeech2BatchSampler(BatchSampler):
         assert isinstance(drop_last, bool), \
             "drop_last should be a boolean number"
 
-        if num_replicas is not None:
-            assert isinstance(num_replicas, int) and num_replicas > 0, \
-                "num_replicas should be a positive integer"
-            self.nranks = num_replicas
-        else:
-            self.nranks = num_replicas
-
-        if rank is not None:
-            assert isinstance(rank, int) and rank >= 0, \
-                "rank should be a non-negative integer"
-            self.local_rank = rank
-        else:
-            self.local_rank = rank
-
         self.drop_last = drop_last
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
-        self.total_size = self.num_samples * self.nranks
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0))
+        self.total_size = self.num_samples
         self._sortagrad = sortagrad
         self._shuffle_method = shuffle_method
@@ -374,7 +365,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         """
         rng = np.random.RandomState(self.epoch)
         # must shift at leat by one
-        shift_len = rng.randint(1, batch_size - 1)
+        shift_len = rng.randint(0, batch_size - 1)
         batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
         rng.shuffle(batch_indices)
         batch_indices = [item for batch in batch_indices for item in batch]
@@ -401,6 +392,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
             if self.epoch == 0 and self._sortagrad:
                 logger.info(f'dataset sortagrad! epoch {self.epoch}')
             else:
+                logger.info(f'dataset shuffle! epoch {self.epoch}')
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
                         indices, self.batch_size, clipped=False)
@@ -412,28 +404,6 @@ class DeepSpeech2BatchSampler(BatchSampler):
         assert len(
             indices
         ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"
-        self.epoch += 1
-
-        # subsample
-        def _get_indices_by_batch_size(indices):
-            subsampled_indices = []
-            last_batch_size = self.total_size % (self.batch_size * self.nranks)
-            assert last_batch_size % self.nranks == 0
-            last_local_batch_size = last_batch_size // self.nranks
-
-            for i in range(self.local_rank * self.batch_size,
-                           len(indices) - last_batch_size,
-                           self.batch_size * self.nranks):
-                subsampled_indices.extend(indices[i:i + self.batch_size])
-
-            indices = indices[len(indices) - last_batch_size:]
-            subsampled_indices.extend(
-                indices[self.local_rank * last_local_batch_size:(
-                    self.local_rank + 1) * last_local_batch_size])
-            return subsampled_indices
-
-        if self.nranks > 1:
-            indices = _get_indices_by_batch_size(indices)
 
         assert len(indices) == self.num_samples
         _sample_iter = iter(indices)
@@ -442,53 +412,20 @@ class DeepSpeech2BatchSampler(BatchSampler):
         for idx in _sample_iter:
             batch_indices.append(idx)
             if len(batch_indices) == self.batch_size:
+                logger.info(
+                    f"rank: {dist.get_rank()} batch index: {batch_indices} ")
                 yield batch_indices
                 batch_indices = []
         if not self.drop_last and len(batch_indices) > 0:
             yield batch_indices
 
+        self.epoch += 1
+
     def __len__(self):
         num_samples = self.num_samples
         num_samples += int(not self.drop_last) * (self.batch_size - 1)
         return num_samples // self.batch_size
 
-    def set_epoch(self, epoch):
-        """
-        Sets the epoch number. When :attr:`shuffle=True`, this number is used
-        as seeds of random numbers. By default, users may not set this, all
-        replicas (workers) use a different random ordering for each epoch.
-        If set same number at each epoch, this sampler will yield the same
-        ordering at all epoches.
-        Arguments:
-            epoch (int): Epoch number.
-        Examples:
-            .. code-block:: python
-
-                import numpy as np
-
-                from paddle.io import Dataset, DistributedBatchSampler
-
-                # init with dataset
-                class RandomDataset(Dataset):
-                    def __init__(self, num_samples):
-                        self.num_samples = num_samples
-
-                    def __getitem__(self, idx):
-                        image = np.random.random([784]).astype('float32')
-                        label = np.random.randint(0, 9, (1, )).astype('int64')
-                        return image, label
-
-                    def __len__(self):
-                        return self.num_samples
-
-                dataset = RandomDataset(100)
-                sampler = DistributedBatchSampler(dataset, batch_size=64)
-
-                for epoch in range(10):
-                    sampler.set_epoch(epoch)
-        """
-        self.epoch = epoch
-

 
 class SpeechCollator():
     def __init__(self, padding_to=-1):
@@ -532,4 +469,4 @@ class SpeechCollator():
         audio_lens = np.array(audio_lens).astype('int64')
         texts = np.array(texts).astype('int32')
         text_lens = np.array(text_lens).astype('int64')
-        return padded_audios, texts, audio_lens, text_lens
\ No newline at end of file
+        return padded_audios, texts, audio_lens, text_lens
diff --git a/examples/aishell/conf/deepspeech2.yaml b/examples/aishell/conf/deepspeech2.yaml
index 552d114c5..d2d46eb44 100644
--- a/examples/aishell/conf/deepspeech2.yaml
+++ b/examples/aishell/conf/deepspeech2.yaml
@@ -6,7 +6,7 @@ data:
   mean_std_filepath: data/mean_std.npz
   vocab_filepath: data/vocab.txt
   augmentation_config: conf/augmentation.config
-  batch_size: 16 # one gpu
+  batch_size: 64 # one gpu
   max_duration: 27.0
   min_duration: 0.0
   specgram_type: linear
diff --git a/examples/tiny/local/run_train.sh b/examples/tiny/local/run_train.sh
index 9c81e49b5..8899d2fd1 100644
--- a/examples/tiny/local/run_train.sh
+++ b/examples/tiny/local/run_train.sh
@@ -6,7 +6,7 @@ export FLAGS_sync_nccl_allreduce=0
 CUDA_VISIBLE_DEVICES=0,1 \
 python3 -u ${MAIN_ROOT}/train.py \
 --device 'gpu' \
---nproc 1 \
+--nproc 2 \
 --config conf/deepspeech2.yaml \
 --output ckpt
diff --git a/model_utils/model.py b/model_utils/model.py
index f115028d0..d3a41a34f 100644
--- a/model_utils/model.py
+++ b/model_utils/model.py
@@ -41,22 +41,6 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from utils.error_rate import char_errors, word_errors, cer, wer
 
-# class ExponentialDecayOld(LRScheduler):
-#     def __init__(self, learning_rate, gamma, last_epoch=-1,
-#                 decay_steps, decay_rate, staircase=False, verbose=False):
-#         self.learning_rate = learning_rate
-#         self.decay_steps = decay_steps
-#         self.decay_rate = decay_rate
-#         self.staircase = staircase
-#         super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
-#                                                verbose)
-
-#     def get_lr(self):
-#         div_res = self.step_num / self.decay_steps
-#         if self.staircase:
-#             div_res = paddle.floor(div_res)
-#         return self.base_lr * (self.decay_rate**div_res)
-

 
 class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
@@ -114,6 +98,8 @@ class DeepSpeech2Trainer(Trainer):
 
         It includes forward/backward/update and periodical validation and saving.
""" + self.logger.info( + f"Train Total Examples: {len(self.train_loader.dataset)}") self.new_epoch() while self.epoch <= self.config.training.n_epoch: for batch in self.train_loader: @@ -137,6 +123,8 @@ class DeepSpeech2Trainer(Trainer): @mp_tools.rank_zero_only @paddle.no_grad() def valid(self): + self.logger.info( + f"Valid Total Examples: {len(self.valid_loader.dataset)}") self.model.eval() valid_losses = defaultdict(list) for i, batch in enumerate(self.valid_loader): @@ -178,6 +166,10 @@ class DeepSpeech2Trainer(Trainer): if self.parallel: model = paddle.DataParallel(model) + for n, p in model.named_parameters(): + self.logger.info( + f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}") + grad_clip = paddle.nn.ClipGradByGlobalNorm( config.training.global_grad_clip) @@ -341,9 +333,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer): @mp_tools.rank_zero_only @paddle.no_grad() def test(self): + self.logger.info( + f"Test Total Examples: {len(self.test_loader.dataset)}") self.model.eval() losses = defaultdict(list) - cfg = self.config # decoders only accept string encoded in utf-8 vocab_list = self.test_loader.dataset.vocab_list diff --git a/training/trainer.py b/training/trainer.py index 88ef847be..930b82818 100644 --- a/training/trainer.py +++ b/training/trainer.py @@ -254,18 +254,18 @@ class Trainer(): formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S') - #stream_handler = logging.StreamHandler() - #stream_handler.setFormatter(formatter) - #logger.addHandler(stream_handler) + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank()) - file_handler = logging.FileHandler(str(log_file)) - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) + # file_handler = logging.FileHandler(str(log_file)) + # file_handler.setFormatter(formatter) + # logger.addHandler(file_handler) # global logger - stdout = True - save_path = '/dev/null' + stdout = False + save_path = log_file logging.basicConfig( level=logging.DEBUG if stdout else logging.INFO, format=format,