diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py
index b228c5e3..a5cef15c 100644
--- a/deepspeech/exps/u2/model.py
+++ b/deepspeech/exps/u2/model.py
@@ -183,15 +183,7 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
 
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.iteration)
-        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py
index 116ab280..bc7cd4fd 100644
--- a/deepspeech/exps/u2_kaldi/model.py
+++ b/deepspeech/exps/u2_kaldi/model.py
@@ -184,11 +184,7 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
 
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py
index eb84d6f1..4f95bc42 100644
--- a/deepspeech/exps/u2_st/model.py
+++ b/deepspeech/exps/u2_st/model.py
@@ -198,14 +198,7 @@ class U2STTrainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)
 
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-
-        self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index f4998fdf..7815ed67 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -179,7 +179,8 @@ class Trainer():
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
         if infos:
-            # restore from ckpt
+            # just restore ckpt
+            # lr will restore from optimizer ckpt
            self.iteration = infos["step"]
            self.epoch = infos["epoch"]
            scratch = False
@@ -190,14 +191,31 @@ class Trainer():
         logger.info("Restore/Init checkpoint!")
         return scratch
 
+    def maybe_batch_sampler_step(self):
+        """Seed the batch_sampler with the current epoch."""
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)
+
+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save init model, i.e. 0 epoch
+            self.save(tag='init', infos=None)
+        else:
+            # resume: continue training from the next epoch and iteration
+            self.epoch += 1
+            self.iteration += 1
+
+        self.maybe_batch_sampler_step()
+
     def new_epoch(self):
         """Reset the train loader seed and increment `epoch`.
         """
+        # `iteration` is increased by the train step
         self.epoch += 1
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            batch_sampler = self.train_loader.batch_sampler
-            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
-                batch_sampler.set_epoch(self.epoch)
+        self.maybe_batch_sampler_step()
 
     def after_train_batch(self):
         if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
@@ -209,15 +227,7 @@ class Trainer():
 
     def train(self):
         """The training process control by epoch."""
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
@@ -275,6 +285,7 @@ class Trainer():
                     'epoch', {'cv_loss': cv_loss,
                               'lr': self.lr_scheduler()}, self.epoch)
 
+            # after epoch
             self.save(tag=self.epoch, infos={'val_loss': cv_loss})
             # step lr every epoch
             self.lr_scheduler.step()