fix dataset batch shuffle and add batch sampler log

print model parameter
pull/522/head
Hui Zhang 5 years ago
parent 6f5b837e54
commit a94fc3f6ed

@@ -17,13 +17,15 @@ import random
 import tarfile
 import logging
 import numpy as np
+from collections import namedtuple
+from functools import partial
 import paddle
 from paddle.io import Dataset
 from paddle.io import DataLoader
 from paddle.io import BatchSampler
 from paddle.io import DistributedBatchSampler
-from collections import namedtuple
-from functools import partial
+from paddle import distributed as dist
 from data_utils.utility import read_manifest
 from data_utils.augmentor.augmentation import AugmentationPipeline
@@ -229,8 +231,7 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         :rtype: list
         """
         rng = np.random.RandomState(self.epoch)
-        # must shift at leat by one
-        shift_len = rng.randint(1, batch_size - 1)
+        shift_len = rng.randint(0, batch_size - 1)
         batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
         rng.shuffle(batch_indices)
         batch_indices = [item for batch in batch_indices for item in batch]
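
The shuffle above keeps utterances of similar duration together: it slices off a small random prefix of `shift_len` indices, groups the rest into consecutive `batch_size`-sized tuples with the zip-of-iterators idiom, and shuffles the batches rather than the individual indices. A minimal, self-contained sketch of the idea on plain Python lists (the real `_batch_shuffle` also takes a `clipped` flag that this sketch ignores):

import numpy as np

def batch_shuffle(indices, batch_size, epoch):
    """Shuffle whole batches of indices, keeping intra-batch order."""
    rng = np.random.RandomState(epoch)           # reproducible per epoch
    shift_len = rng.randint(0, batch_size - 1)   # random offset in [0, batch_size - 2]
    # zip over the same iterator batch_size times -> consecutive groups
    batches = list(zip(*[iter(indices[shift_len:])] * batch_size))
    rng.shuffle(batches)                         # shuffle batch order only
    flattened = [idx for batch in batches for idx in batch]
    # prepend the clipped head back; the tail dropped by zip is left out in this sketch
    return indices[:shift_len] + flattened

print(batch_shuffle(list(range(12)), batch_size=4, epoch=0))
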
@@ -255,8 +256,13 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         # sort (by duration) or batch-wise shuffle the manifest
         if self.shuffle:
             if self.epoch == 0 and self._sortagrad:
-                logger.info(f'dataset sortagrad! epoch {self.epoch}')
+                logger.info(
+                    f'rank: {dist.get_rank()} dataset sortagrad! epoch {self.epoch}'
+                )
             else:
+                logger.info(
+                    f'rank: {dist.get_rank()} dataset shuffle! epoch {self.epoch}'
+                )
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
                         indices, self.batch_size, clipped=False)
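
For readers unfamiliar with SortaGrad: on epoch 0 the manifest stays sorted by duration (short utterances first, which tends to stabilize early training), and only from epoch 1 onward does the batch-wise shuffle apply; the `rank: {dist.get_rank()}` prefix makes each worker's choice visible in a shared log. A small sketch of the same dispatch, assuming `paddle.distributed.get_rank()` and an ordinary module logger (function name is illustrative):

import logging
from paddle import distributed as dist

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def log_epoch_plan(epoch, sortagrad):
    # get_rank() is 0 in a single process, the worker id under launch/spawn
    rank = dist.get_rank()
    if epoch == 0 and sortagrad:
        logger.info(f'rank: {rank} dataset sortagrad! epoch {epoch}')
    else:
        logger.info(f'rank: {rank} dataset shuffle! epoch {epoch}')

log_epoch_plan(epoch=0, sortagrad=True)   # -> rank: 0 dataset sortagrad! epoch 0
log_epoch_plan(epoch=1, sortagrad=True)   # -> rank: 0 dataset shuffle! epoch 1
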
@@ -268,7 +274,6 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
             assert len(
                 indices
             ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"
-        self.epoch += 1

         # subsample
         def _get_indices_by_batch_size(indices):
@@ -298,6 +303,8 @@ class DeepSpeech2DistributedBatchSampler(DistributedBatchSampler):
         for idx in _sample_iter:
             batch_indices.append(idx)
             if len(batch_indices) == self.batch_size:
+                logger.info(
+                    f"rank: {dist.get_rank()} batch index: {batch_indices} ")
                 yield batch_indices
                 batch_indices = []
         if not self.drop_last and len(batch_indices) > 0:
@@ -316,9 +323,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
                  shuffle=False,
                  drop_last=False,
                  sortagrad=False,
-                 shuffle_method="batch_shuffle",
-                 num_replicas=1,
-                 rank=0):
+                 shuffle_method="batch_shuffle"):
         self.dataset = dataset

         assert isinstance(batch_size, int) and batch_size > 0, \
@@ -330,24 +335,10 @@ class DeepSpeech2BatchSampler(BatchSampler):
         assert isinstance(drop_last, bool), \
             "drop_last should be a boolean number"

-        if num_replicas is not None:
-            assert isinstance(num_replicas, int) and num_replicas > 0, \
-                    "num_replicas should be a positive integer"
-            self.nranks = num_replicas
-        else:
-            self.nranks = num_replicas
-
-        if rank is not None:
-            assert isinstance(rank, int) and rank >= 0, \
-                    "rank should be a non-negative integer"
-            self.local_rank = rank
-        else:
-            self.local_rank = rank
-
         self.drop_last = drop_last
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
-        self.total_size = self.num_samples * self.nranks
+        self.num_samples = int(math.ceil(len(self.dataset) * 1.0))
+        self.total_size = self.num_samples
         self._sortagrad = sortagrad
         self._shuffle_method = shuffle_method
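
With `num_replicas` and `rank` removed, `DeepSpeech2BatchSampler` becomes a purely single-process sampler (`num_samples == total_size == len(dataset)`), and cross-GPU sharding is left to `DeepSpeech2DistributedBatchSampler` alone. A hedged usage sketch of how such a batch sampler is typically wired into `paddle.io.DataLoader` (the `train_dataset` object and the keyword values are placeholders, and the collator instance is assumed to be callable):

from paddle.io import DataLoader

# single-GPU / debugging path: plain batch sampler, no sharding
sampler = DeepSpeech2BatchSampler(
    train_dataset,            # placeholder dataset object
    batch_size=64,
    shuffle=True,
    drop_last=True,
    sortagrad=True,
    shuffle_method="batch_shuffle")

loader = DataLoader(
    train_dataset,
    batch_sampler=sampler,    # yields one list of indices per batch
    collate_fn=SpeechCollator(padding_to=-1),
    num_workers=0)
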
@@ -374,7 +365,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
         """
         rng = np.random.RandomState(self.epoch)
         # must shift at leat by one
-        shift_len = rng.randint(1, batch_size - 1)
+        shift_len = rng.randint(0, batch_size - 1)
         batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
         rng.shuffle(batch_indices)
         batch_indices = [item for batch in batch_indices for item in batch]
@@ -401,6 +392,7 @@ class DeepSpeech2BatchSampler(BatchSampler):
             if self.epoch == 0 and self._sortagrad:
                 logger.info(f'dataset sortagrad! epoch {self.epoch}')
             else:
+                logger.info(f'dataset shuffle! epoch {self.epoch}')
                 if self._shuffle_method == "batch_shuffle":
                     indices = self._batch_shuffle(
                         indices, self.batch_size, clipped=False)
@@ -412,28 +404,6 @@ class DeepSpeech2BatchSampler(BatchSampler):
             assert len(
                 indices
             ) == self.total_size, f"batch shuffle examples error: {len(indices)} : {self.total_size}"
-        self.epoch += 1
-
-        # subsample
-        def _get_indices_by_batch_size(indices):
-            subsampled_indices = []
-            last_batch_size = self.total_size % (self.batch_size * self.nranks)
-            assert last_batch_size % self.nranks == 0
-            last_local_batch_size = last_batch_size // self.nranks
-
-            for i in range(self.local_rank * self.batch_size,
-                           len(indices) - last_batch_size,
-                           self.batch_size * self.nranks):
-                subsampled_indices.extend(indices[i:i + self.batch_size])
-
-            indices = indices[len(indices) - last_batch_size:]
-            subsampled_indices.extend(
-                indices[self.local_rank * last_local_batch_size:(
-                    self.local_rank + 1) * last_local_batch_size])
-            return subsampled_indices
-
-        if self.nranks > 1:
-            indices = _get_indices_by_batch_size(indices)

         assert len(indices) == self.num_samples
         _sample_iter = iter(indices)
@@ -442,53 +412,20 @@ class DeepSpeech2BatchSampler(BatchSampler):
         for idx in _sample_iter:
             batch_indices.append(idx)
             if len(batch_indices) == self.batch_size:
+                logger.info(
+                    f"rank: {dist.get_rank()} batch index: {batch_indices} ")
                 yield batch_indices
                 batch_indices = []
         if not self.drop_last and len(batch_indices) > 0:
             yield batch_indices
-        self.epoch += 1

     def __len__(self):
         num_samples = self.num_samples
         num_samples += int(not self.drop_last) * (self.batch_size - 1)
         return num_samples // self.batch_size

-    def set_epoch(self, epoch):
-        """
-        Sets the epoch number. When :attr:`shuffle=True`, this number is used
-        as seeds of random numbers. By default, users may not set this, all
-        replicas (workers) use a different random ordering for each epoch.
-        If set same number at each epoch, this sampler will yield the same
-        ordering at all epoches.
-        Arguments:
-            epoch (int): Epoch number.
-        Examples:
-            .. code-block:: python
-                import numpy as np
-                from paddle.io import Dataset, DistributedBatchSampler
-                # init with dataset
-                class RandomDataset(Dataset):
-                    def __init__(self, num_samples):
-                        self.num_samples = num_samples
-                    def __getitem__(self, idx):
-                        image = np.random.random([784]).astype('float32')
-                        label = np.random.randint(0, 9, (1, )).astype('int64')
-                        return image, label
-                    def __len__(self):
-                        return self.num_samples
-                dataset = RandomDataset(100)
-                sampler = DistributedBatchSampler(dataset, batch_size=64)
-                for epoch in range(10):
-                    sampler.set_epoch(epoch)
-        """
-        self.epoch = epoch


 class SpeechCollator():
     def __init__(self, padding_to=-1):
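With the local `set_epoch` override and the trailing `self.epoch += 1` removed, epoch advancement has to be driven from outside the sampler; the distributed variant still gets `set_epoch` from `paddle.io.DistributedBatchSampler`, and the calling pattern is the one the removed docstring illustrated. Restated as a runnable sketch (the toy dataset is illustrative):

import numpy as np
from paddle.io import Dataset, DistributedBatchSampler

class RandomDataset(Dataset):
    def __init__(self, num_samples):
        self.num_samples = num_samples
    def __getitem__(self, idx):
        image = np.random.random([784]).astype('float32')
        label = np.random.randint(0, 9, (1, )).astype('int64')
        return image, label
    def __len__(self):
        return self.num_samples

dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)
for epoch in range(10):
    sampler.set_epoch(epoch)  # same seed on every rank -> consistent per-epoch shuffle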

@@ -6,7 +6,7 @@ data:
   mean_std_filepath: data/mean_std.npz
   vocab_filepath: data/vocab.txt
   augmentation_config: conf/augmentation.config
-  batch_size: 16 # one gpu
+  batch_size: 64 # one gpu
   max_duration: 27.0
   min_duration: 0.0
   specgram_type: linear
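
Note that `batch_size` here is per GPU (hence the "one gpu" comment): with the two-process launch in the next hunk, the effective global batch is presumably 64 × 2 = 128 utterances per step, while the 16 → 64 change is what raises the memory footprint of each card.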

@@ -6,7 +6,7 @@ export FLAGS_sync_nccl_allreduce=0
 CUDA_VISIBLE_DEVICES=0,1 \
 python3 -u ${MAIN_ROOT}/train.py \
 --device 'gpu' \
---nproc 1 \
+--nproc 2 \
 --config conf/deepspeech2.yaml \
 --output ckpt
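
With `CUDA_VISIBLE_DEVICES=0,1`, `--nproc 2` now matches one trainer process per visible GPU, so both cards are actually used; in general `--nproc` should not exceed the number of devices listed in `CUDA_VISIBLE_DEVICES`.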

@@ -41,22 +41,6 @@ from decoders.swig_wrapper import ctc_beam_search_decoder_batch
 from utils.error_rate import char_errors, word_errors, cer, wer

-# class ExponentialDecayOld(LRScheduler):
-#     def __init__(self, learning_rate, gamma, last_epoch=-1,
-#                  decay_steps, decay_rate, staircase=False, verbose=False):
-#         self.learning_rate = learning_rate
-#         self.decay_steps = decay_steps
-#         self.decay_rate = decay_rate
-#         self.staircase = staircase
-#         super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
-#                                                verbose)
-
-#     def get_lr(self):
-#         div_res = self.step_num / self.decay_steps
-#         if self.staircase:
-#             div_res = paddle.floor(div_res)
-#         return self.base_lr * (self.decay_rate**div_res)

 class DeepSpeech2Trainer(Trainer):
     def __init__(self, config, args):
@@ -114,6 +98,8 @@ class DeepSpeech2Trainer(Trainer):
         It includes forward/backward/update and periodical validation and
         saving.
         """
+        self.logger.info(
+            f"Train Total Examples: {len(self.train_loader.dataset)}")
         self.new_epoch()
         while self.epoch <= self.config.training.n_epoch:
             for batch in self.train_loader:
@@ -137,6 +123,8 @@ class DeepSpeech2Trainer(Trainer):
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def valid(self):
+        self.logger.info(
+            f"Valid Total Examples: {len(self.valid_loader.dataset)}")
         self.model.eval()
         valid_losses = defaultdict(list)
         for i, batch in enumerate(self.valid_loader):
@@ -178,6 +166,10 @@ class DeepSpeech2Trainer(Trainer):
         if self.parallel:
             model = paddle.DataParallel(model)

+        for n, p in model.named_parameters():
+            self.logger.info(
+                f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}")
+
         grad_clip = paddle.nn.ClipGradByGlobalNorm(
             config.training.global_grad_clip)
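
The parameter dump added above walks `model.named_parameters()` once at setup time and records each tensor's name, shape, and `stop_gradient` flag, a quick sanity check that the network was built (and is trainable) as intended. A minimal standalone version of the same idiom on a toy Paddle model (layer sizes are arbitrary):

import logging
import paddle

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("param_dump")

model = paddle.nn.Sequential(
    paddle.nn.Linear(161, 256),
    paddle.nn.ReLU(),
    paddle.nn.Linear(256, 29))

for n, p in model.named_parameters():
    # e.g. "param: 0.weight: shape: [161, 256] stop_grad: False"
    logger.info(f"param: {n}: shape: {p.shape} stop_grad: {p.stop_gradient}")
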
@@ -341,9 +333,10 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
     @mp_tools.rank_zero_only
     @paddle.no_grad()
     def test(self):
+        self.logger.info(
+            f"Test Total Examples: {len(self.test_loader.dataset)}")
         self.model.eval()
         losses = defaultdict(list)
         cfg = self.config
         # decoders only accept string encoded in utf-8
         vocab_list = self.test_loader.dataset.vocab_list

@@ -254,18 +254,18 @@ class Trainer():
         formatter = logging.Formatter(fmt=format, datefmt='%Y/%m/%d %H:%M:%S')

-        #stream_handler = logging.StreamHandler()
-        #stream_handler.setFormatter(formatter)
-        #logger.addHandler(stream_handler)
+        stream_handler = logging.StreamHandler()
+        stream_handler.setFormatter(formatter)
+        logger.addHandler(stream_handler)

         log_file = self.output_dir / 'worker_{}.log'.format(dist.get_rank())
-        file_handler = logging.FileHandler(str(log_file))
-        file_handler.setFormatter(formatter)
-        logger.addHandler(file_handler)
+        # file_handler = logging.FileHandler(str(log_file))
+        # file_handler.setFormatter(formatter)
+        # logger.addHandler(file_handler)

         # global logger
-        stdout = True
-        save_path = '/dev/null'
+        stdout = False
+        save_path = log_file
         logging.basicConfig(
             level=logging.DEBUG if stdout else logging.INFO,
             format=format,
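
After this change the console output goes through an explicit `StreamHandler` on the module logger, while the global `logging.basicConfig` is pointed at the per-rank `worker_{rank}.log` file instead of `/dev/null`. A hedged, standalone sketch of that split (the format string and file path are illustrative, and `basicConfig` is assumed to receive `save_path` via its `filename` argument):

import logging

rank = 0  # in the trainer this comes from paddle.distributed.get_rank()
log_file = f"worker_{rank}.log"
fmt = '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
formatter = logging.Formatter(fmt=fmt, datefmt='%Y/%m/%d %H:%M:%S')

logger = logging.getLogger(__name__)

# console: explicit handler attached to this logger
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

# file: root logger configured once per process, one file per rank
logging.basicConfig(
    level=logging.INFO,
    format=fmt,
    datefmt='%Y/%m/%d %H:%M:%S',
    filename=log_file)

logger.info("hello from rank %d", rank)  # goes to both console and the rank's file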
