From 8873ebe38c5973993471e86a1760b412f92449dd Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Fri, 10 Sep 2021 04:27:54 +0000
Subject: [PATCH] add timer for u2; refactor grad norm type

---
 deepspeech/exps/u2/model.py | 61 +++++++++++++++++++------------------
 deepspeech/modules/loss.py  |  6 ++--
 2 files changed, 35 insertions(+), 32 deletions(-)

diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index 8ab9a26e..2b6e2433 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -34,6 +34,7 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
 from deepspeech.models.u2 import U2Model
 from deepspeech.training.optimizer import OptimizerFactory
 from deepspeech.training.scheduler import LRSchedulerFactory
+from deepspeech.training.timer import Timer
 from deepspeech.training.trainer import Trainer
 from deepspeech.utils import ctc_utils
 from deepspeech.utils import error_rate
@@ -184,40 +185,42 @@ class U2Trainer(Trainer):
             self.save(tag='init')
 
         self.lr_scheduler.step(self.iteration)
 
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
-            self.model.train()
-            try:
-                data_start_time = time.time()
-                for batch_index, batch in enumerate(self.train_loader):
-                    dataload_time = time.time() - data_start_time
-                    msg = "Train: Rank: {}, ".format(dist.get_rank())
-                    msg += "epoch: {}, ".format(self.epoch)
-                    msg += "step: {}, ".format(self.iteration)
-                    msg += "batch : {}/{}, ".format(batch_index + 1,
-                                                    len(self.train_loader))
-                    msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
-                    msg += "data time: {:>.3f}s, ".format(dataload_time)
-                    self.train_batch(batch_index, batch, msg)
-            except Exception as e:
-                logger.error(e)
-                raise e
-
-            total_loss, num_seen_utts = self.valid()
-            if dist.get_world_size() > 1:
-                num_seen_utts = paddle.to_tensor(num_seen_utts)
-                # the default operator in all_reduce function is sum.
-                dist.all_reduce(num_seen_utts)
-                total_loss = paddle.to_tensor(total_loss)
-                dist.all_reduce(total_loss)
-                cv_loss = total_loss / num_seen_utts
-                cv_loss = float(cv_loss)
-            else:
-                cv_loss = total_loss / num_seen_utts
+            with Timer("Epoch-Train Time Cost: {}"):
+                self.model.train()
+                try:
+                    data_start_time = time.time()
+                    for batch_index, batch in enumerate(self.train_loader):
+                        dataload_time = time.time() - data_start_time
+                        msg = "Train: Rank: {}, ".format(dist.get_rank())
+                        msg += "epoch: {}, ".format(self.epoch)
+                        msg += "step: {}, ".format(self.iteration)
+                        msg += "batch : {}/{}, ".format(batch_index + 1,
+                                                        len(self.train_loader))
+                        msg += "lr: {:>.8f}, ".format(self.lr_scheduler())
+                        msg += "data time: {:>.3f}s, ".format(dataload_time)
+                        self.train_batch(batch_index, batch, msg)
+                        data_start_time = time.time()
+                except Exception as e:
+                    logger.error(e)
+                    raise e
+
+            with Timer("Eval Time Cost: {}"):
+                total_loss, num_seen_utts = self.valid()
+                if dist.get_world_size() > 1:
+                    num_seen_utts = paddle.to_tensor(num_seen_utts)
+                    # the default operator in all_reduce function is sum.
+                    dist.all_reduce(num_seen_utts)
+                    total_loss = paddle.to_tensor(total_loss)
+                    dist.all_reduce(total_loss)
+                    cv_loss = total_loss / num_seen_utts
+                    cv_loss = float(cv_loss)
+                else:
+                    cv_loss = total_loss / num_seen_utts
 
             logger.info(
                 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
diff --git a/deepspeech/modules/loss.py b/deepspeech/modules/loss.py
index 399e84e2..023a1923 100644
--- a/deepspeech/modules/loss.py
+++ b/deepspeech/modules/loss.py
@@ -36,16 +36,16 @@ class CTCLoss(nn.Layer):
             f"CTCLoss Loss reduction: {reduction}, div-bs: {batch_average}")
 
         # instance for norm_by_times
-        # batchsize for norm_by_batchsize
+        # batch for norm_by_batchsize
         # frame for norm_by_total_logits_len
-        assert grad_norm_type in ('instance', 'batchsize', 'frame', None)
+        assert grad_norm_type in ('instance', 'batch', 'frame', None)
         self.norm_by_times = False
         self.norm_by_batchsize = False
         self.norm_by_total_logits_len = False
         logger.info(f"CTCLoss Grad Norm Type: {grad_norm_type}")
         if grad_norm_type == 'instance':
             self.norm_by_times = True
-        if grad_norm_type == 'batchsize':
+        if grad_norm_type == 'batch':
             self.norm_by_batchsize = True
         if grad_norm_type == 'frame':
             self.norm_by_total_logits_len = True
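
Note on Timer: deepspeech/training/timer.py is imported above but is not part
of this patch, so its implementation is assumed here. A minimal sketch of a
context manager matching the call sites, where the message template receives
the elapsed wall-clock time on exit (the real module presumably writes through
the repo logger rather than print):

    import time

    class Timer(object):
        """Usage: with Timer("Epoch-Train Time Cost: {}"): ..."""

        def __init__(self, message):
            self.message = message

        def __enter__(self):
            # Record the wall-clock start time when the block is entered.
            self.start = time.time()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Format the elapsed seconds into the message template and report it.
            elapsed = time.time() - self.start
            print(self.message.format(elapsed))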
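
Note on grad_norm_type: after this patch, callers must pass 'batch' instead of
'batchsize'; the assert rejects the old spelling. A hypothetical construction
(the CTCLoss keyword arguments besides grad_norm_type are assumed from the
fields referenced in the hunk above):

    from deepspeech.modules.loss import CTCLoss

    # 'instance' normalizes CTC gradients per utterance, 'batch' per batch,
    # 'frame' by total logits length; None disables gradient normalization.
    ctc_loss = CTCLoss(reduction='sum', batch_average=True, grad_norm_type='batch')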