From 5b44de3a7ca2a1974016e07de1ac3e7e932a3fa0 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 18 Sep 2021 09:37:26 +0000 Subject: [PATCH 1/5] resume train with epoch and iteration increase --- deepspeech/exps/u2/model.py | 10 +------- deepspeech/exps/u2_kaldi/model.py | 6 +---- deepspeech/exps/u2_st/model.py | 9 +------ deepspeech/training/trainer.py | 39 ++++++++++++++++++++----------- 4 files changed, 28 insertions(+), 36 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index b228c5e38..a5cef15c8 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -183,15 +183,7 @@ class U2Trainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init', infos=None) - - # lr will resotre from optimizer ckpt - # self.lr_scheduler.step(self.iteration) - if self.parallel and hasattr(self.train_loader, 'batch_sampler'): - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: diff --git a/deepspeech/exps/u2_kaldi/model.py b/deepspeech/exps/u2_kaldi/model.py index 116ab2808..bc7cd4fd3 100644 --- a/deepspeech/exps/u2_kaldi/model.py +++ b/deepspeech/exps/u2_kaldi/model.py @@ -184,11 +184,7 @@ class U2Trainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 
0 epoch - self.save(tag='init') - self.lr_scheduler.step(self.iteration) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index eb84d6f11..4f95bc42b 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -198,14 +198,7 @@ class U2STTrainer(Trainer): # script_model_path = str(self.checkpoint_dir / 'init') # paddle.jit.save(script_model, script_model_path) - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init') - - self.lr_scheduler.step(self.iteration) - if self.parallel: - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index f4998fdf1..7815ed675 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -179,7 +179,8 @@ class Trainer(): checkpoint_dir=self.checkpoint_dir, checkpoint_path=self.args.checkpoint_path) if infos: - # restore from ckpt + # just restore ckpt + # lr will resotre from optimizer ckpt self.iteration = infos["step"] self.epoch = infos["epoch"] scratch = False @@ -190,14 +191,31 @@ class Trainer(): logger.info("Restore/Init checkpoint!") return scratch + def maybe_batch_sampler_step(self): + """ batch_sampler seed by epoch """ + if hasattr(self.train_loader, "batch_sampler"): + batch_sampler = self.train_loader.batch_sampler + if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): + batch_sampler.set_epoch(self.epoch) + + def before_train(self, from_scratch): + from_scratch = self.resume_or_scratch() + if from_scratch: + # scratch: save init model, i.e. 
0 epoch + self.save(tag='init', infos=None) + else: + # resume: train next_epoch and next_iteration + self.epoch += 1 + self.iteration += 1 + + self.maybe_batch_sampler_step() + def new_epoch(self): """Reset the train loader seed and increment `epoch`. """ + # `iteration` increased by train step self.epoch += 1 - if self.parallel and hasattr(self.train_loader, "batch_sampler"): - batch_sampler = self.train_loader.batch_sampler - if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): - batch_sampler.set_epoch(self.epoch) + self.maybe_batch_sampler_step() def after_train_batch(self): if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step: @@ -209,15 +227,7 @@ class Trainer(): def train(self): """The training process control by epoch.""" - from_scratch = self.resume_or_scratch() - if from_scratch: - # save init model, i.e. 0 epoch - self.save(tag='init', infos=None) - - # lr will resotre from optimizer ckpt - # self.lr_scheduler.step(self.epoch) - if self.parallel and hasattr(self.train_loader, "batch_sampler"): - self.train_loader.batch_sampler.set_epoch(self.epoch) + self.before_train() logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}") while self.epoch < self.config.training.n_epoch: @@ -275,6 +285,7 @@ class Trainer(): 'epoch', {'cv_loss': cv_loss, 'lr': self.lr_scheduler()}, self.epoch) + # after epoch self.save(tag=self.epoch, infos={'val_loss': cv_loss}) # step lr every epoch self.lr_scheduler.step() From e4f378160f26d287396bc2121f7e27c44d3bd760 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 18 Sep 2021 09:41:18 +0000 Subject: [PATCH 2/5] more resume ckpt info --- deepspeech/training/trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 7815ed675..22dded6fa 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -184,11 +184,13 @@ class Trainer(): self.iteration = 
infos["step"] self.epoch = infos["epoch"] scratch = False + logger.info( + f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!") else: self.iteration = 0 self.epoch = 0 scratch = True - logger.info("Restore/Init checkpoint!") + logger.info("Init from scratch!") return scratch def maybe_batch_sampler_step(self): @@ -207,6 +209,8 @@ class Trainer(): # resume: train next_epoch and next_iteration self.epoch += 1 self.iteration += 1 + logger.info( + f"Resume train: epoch {self.epoch }, step {self.iteration}!") self.maybe_batch_sampler_step() From 9824ac53a60f3226b96fee1539ae99005b1fe0f7 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 18 Sep 2021 09:42:24 +0000 Subject: [PATCH 3/5] fix bugs --- deepspeech/training/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index 22dded6fa..ffc52775a 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -200,7 +200,7 @@ class Trainer(): if isinstance(batch_sampler, paddle.io.DistributedBatchSampler): batch_sampler.set_epoch(self.epoch) - def before_train(self, from_scratch): + def before_train(self): from_scratch = self.resume_or_scratch() if from_scratch: # scratch: save init model, i.e. 
0 epoch From d9823835cfcac23fca8f58cbc9b616b9b9bcbc5a Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Sat, 18 Sep 2021 09:45:22 +0000 Subject: [PATCH 4/5] not save ckpt when except, since resume train will increase epoch and step --- deepspeech/training/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index ffc52775a..c1afd3629 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -303,7 +303,6 @@ class Trainer(): try: self.train() except KeyboardInterrupt: - self.save() exit(-1) finally: self.destory() From fac81e46a71a803082828a5abc5efec8bf3adcad Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Wed, 22 Sep 2021 06:57:34 +0000 Subject: [PATCH 5/5] fix train log --- deepspeech/exps/u2/model.py | 4 ++-- deepspeech/training/trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index a5cef15c8..811da39b5 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -199,8 +199,8 @@ class U2Trainer(Trainer): report("Rank", dist.get_rank()) report("epoch", self.epoch) report('step', self.iteration) - report('step/total', - (batch_index + 1) / len(self.train_loader)) + report('iter', batch_index + 1) + report('total',len(self.train_loader)) report("lr", self.lr_scheduler()) self.train_batch(batch_index, batch, msg) self.after_train_batch() diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py index c1afd3629..9ff95f29b 100644 --- a/deepspeech/training/trainer.py +++ b/deepspeech/training/trainer.py @@ -247,8 +247,8 @@ class Trainer(): report("Rank", dist.get_rank()) report("epoch", self.epoch) report('step', self.iteration) - report('step/total', - (batch_index + 1) / len(self.train_loader)) + report('iter', batch_index + 1) + report('total',len(self.train_loader)) report("lr", self.lr_scheduler()) self.train_batch(batch_index, batch, msg) 
self.after_train_batch()