diff --git a/deepspeech/training/gradclip.py b/deepspeech/training/gradclip.py index f46814eb..6393197a 100644 --- a/deepspeech/training/gradclip.py +++ b/deepspeech/training/gradclip.py @@ -48,9 +48,8 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): sum_square_list.append(sum_square) # debug log - if i < 10: - logger.debug( - f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") + logger.debug( + f"Grad Before Clip: {p.name}: {float(sum_square.sqrt()) }") # all parameters have been filterd out if len(sum_square_list) == 0: @@ -77,9 +76,8 @@ class ClipGradByGlobalNormWithLog(paddle.nn.ClipGradByGlobalNorm): params_and_grads.append((p, new_grad)) # debug log - if i < 10: - logger.debug( - f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" - ) + logger.debug( + f"Grad After Clip: {p.name}: {float(new_grad.square().sum().sqrt())}" + ) return params_and_grads diff --git a/deepspeech/training/timer.py b/deepspeech/training/timer.py index 2076737b..2ca9d638 100644 --- a/deepspeech/training/timer.py +++ b/deepspeech/training/timer.py @@ -27,7 +27,7 @@ class Timer(): do some thing """ - def __init__(self, message): + def __init__(self, message=None): self.message = message def duration(self) -> str: @@ -40,7 +40,8 @@ class Timer(): return self def __exit__(self, type, value, traceback): - logger.info(self.message.format(self.duration())) + if self.message: + logger.info(self.message.format(self.duration())) def __call__(self) -> float: return time.time() - self.start diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 521297d7..25c002df 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -185,46 +185,47 @@ class Trainer(): def train(self): """The training process control by epoch.""" - with Timer("Load/Init Model: {}"): - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init', infos=None) - self.lr_scheduler.step(self.epoch) - if self.parallel and hasattr(self.train_loader, "batch_sampler"): - self.train_loader.batch_sampler.set_epoch(self.epoch) + from_scratch = self.resume_or_scratch() + if from_scratch: + # save init model, i.e. 0 epoch + self.save(tag='init', infos=None) + self.lr_scheduler.step(self.epoch) + if self.parallel and hasattr(self.train_loader, "batch_sampler"): + self.train_loader.batch_sampler.set_epoch(self.epoch) logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: - self.model.train() - try: - data_start_time = time.time() - for batch_index, batch in enumerate(self.train_loader): - dataload_time = time.time() - data_start_time - msg = "Train: Rank: {}, ".format(dist.get_rank()) - msg += "epoch: {}, ".format(self.epoch) - msg += "step: {}, ".format(self.iteration) - msg += "batch : {}/{}, ".format(batch_index + 1, - len(self.train_loader)) - msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) - msg += "data time: {:>.3f}s, ".format(dataload_time) - self.train_batch(batch_index, batch, msg) + with Timer("Epoch-Train Time Cost: {}"): + self.model.train() + try: data_start_time = time.time() - except Exception as e: - logger.error(e) - raise e - - total_loss, num_seen_utts = self.valid() - if dist.get_world_size() > 1: - num_seen_utts = paddle.to_tensor(num_seen_utts) - # the default operator in all_reduce function is sum. - dist.all_reduce(num_seen_utts) - total_loss = paddle.to_tensor(total_loss) - dist.all_reduce(total_loss) - cv_loss = total_loss / num_seen_utts - cv_loss = float(cv_loss) - else: - cv_loss = total_loss / num_seen_utts + for batch_index, batch in enumerate(self.train_loader): + dataload_time = time.time() - data_start_time + msg = "Train: Rank: {}, ".format(dist.get_rank()) + msg += "epoch: {}, ".format(self.epoch) + msg += "step: {}, ".format(self.iteration) + msg += "batch : {}/{}, ".format(batch_index + 1, + len(self.train_loader)) + msg += "lr: {:>.8f}, ".format(self.lr_scheduler()) + msg += "data time: {:>.3f}s, ".format(dataload_time) + self.train_batch(batch_index, batch, msg) + data_start_time = time.time() + except Exception as e: + logger.error(e) + raise e + + with Timer("Eval Time Cost: {}"): + total_loss, num_seen_utts = self.valid() + if dist.get_world_size() > 1: + num_seen_utts = paddle.to_tensor(num_seen_utts) + # the default operator in all_reduce function is sum. + dist.all_reduce(num_seen_utts) + total_loss = paddle.to_tensor(total_loss) + dist.all_reduce(total_loss) + cv_loss = total_loss / num_seen_utts + cv_loss = float(cv_loss) + else: + cv_loss = total_loss / num_seen_utts logger.info( 'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss)) diff --git a/doc/src/deepspeech_architecture.md b/doc/src/deepspeech_architecture.md index dfa60790..c4c102ba 100644 --- a/doc/src/deepspeech_architecture.md +++ b/doc/src/deepspeech_architecture.md @@ -1,8 +1,8 @@ # Deepspeech2 ## Streaming -The implemented arcitecure of Deepspeech2 online model is based on [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes. -The model is mainly composed of 2D convolution subsampling layer and stacked single direction rnn layers. +The implemented arcitecure of Deepspeech2 online model is based on [Deepspeech2 model](https://arxiv.org/pdf/1512.02595.pdf) with some changes. +The model is mainly composed of 2D convolution subsampling layer and stacked single direction rnn layers. To illustrate the model implementation clearly, 3 parts are described in detail. - Data Preparation @@ -11,10 +11,10 @@ To illustrate the model implementation clearly, 3 parts are described in detail. In addition, the training process and the testing process are also introduced. -The arcitecture of the model is shown in Fig.1. +The arcitecture of the model is shown in Fig.1.
-
+
Fig.1 The Arcitecture of deepspeech2 online model
-
+
Fig.2 The Arcitecture of deepspeech2 offline model