resume train with epoch and iteration increase

pull/847/head
Hui Zhang 3 years ago
parent 3432de4347
commit daaa72a606

@@ -183,15 +183,7 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)

-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.iteration)
-        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:

@@ -184,11 +184,7 @@ class U2Trainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)

-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
+        self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:

@@ -198,14 +198,7 @@ class U2STTrainer(Trainer):
         # script_model_path = str(self.checkpoint_dir / 'init')
         # paddle.jit.save(script_model, script_model_path)

-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
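
Editor's note: the three hunks above make the same change in three places: the two U2Trainer variants and U2STTrainer each drop their own copy of the resume/scratch preamble in favor of a single self.before_train() call. A minimal runnable sketch of the shape they converge on; everything except before_train, epoch, and iteration is illustrative and not part of this diff.

class SketchTrainer:
    """Toy stand-in for Trainer, only to show the new call order."""

    def __init__(self, n_epoch):
        self.epoch = 0
        self.iteration = 0
        self.n_epoch = n_epoch

    def before_train(self):
        # resume-or-scratch bookkeeping now lives here (see the trainer.py hunks below)
        print(f"start at epoch={self.epoch}, iteration={self.iteration}")

    def train(self):
        self.before_train()
        while self.epoch < self.n_epoch:
            self.iteration += 1  # `iteration` increased by the train step
            self.epoch += 1      # new_epoch() in the real Trainer


SketchTrainer(n_epoch=2).train()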

@ -179,7 +179,8 @@ class Trainer():
checkpoint_dir=self.checkpoint_dir, checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path) checkpoint_path=self.args.checkpoint_path)
if infos: if infos:
# restore from ckpt # just restore ckpt
# lr will resotre from optimizer ckpt
self.iteration = infos["step"] self.iteration = infos["step"]
self.epoch = infos["epoch"] self.epoch = infos["epoch"]
scratch = False scratch = False
@@ -190,14 +191,31 @@ class Trainer():
         logger.info("Restore/Init checkpoint!")
         return scratch

+    def maybe_batch_sampler_step(self):
+        """ batch_sampler seed by epoch """
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)
+
+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save init model, i.e. 0 epoch
+            self.save(tag='init', infos=None)
+        else:
+            # resume: train next_epoch and next_iteration
+            self.epoch += 1
+            self.iteration += 1
+
+        self.maybe_batch_sampler_step()
+
     def new_epoch(self):
         """Reset the train loader seed and increment `epoch`.
         """
+        # `iteration` increased by train step
         self.epoch += 1
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            batch_sampler = self.train_loader.batch_sampler
-            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
-                batch_sampler.set_epoch(self.epoch)
+        self.maybe_batch_sampler_step()

     def after_train_batch(self):
         if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
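
Editor's note: maybe_batch_sampler_step() folds the old parallel-only check into one helper that reseeds the sampler from the epoch counter. A short sketch of why that call matters, assuming paddle is installed; the toy dataset is illustrative, and the behaviour relied on is paddle.io.DistributedBatchSampler's documented epoch-seeded shuffling.

from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx

sampler = DistributedBatchSampler(ToyDataset(), batch_size=4, shuffle=True)
for epoch in range(2):
    sampler.set_epoch(epoch)     # what maybe_batch_sampler_step() does each epoch
    print(epoch, list(sampler))  # shuffle order is seeded by the epoch counter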
@@ -209,15 +227,7 @@ class Trainer():
     def train(self):
         """The training process control by epoch."""
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()

         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
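
Editor's note: with the preamble gone, train() starts with before_train(), and the commit's title is about what that call does on resume: the counters read back from the checkpoint are advanced by one so training continues with the next epoch and next iteration instead of redoing the last saved ones. A tiny sketch of that bookkeeping with a made-up infos dict; the real values come from the checkpoint in resume_or_scratch().

infos = {"step": 1200, "epoch": 4}  # hypothetical checkpoint metadata

iteration = infos["step"]
epoch = infos["epoch"]

# resume: train next_epoch and next_iteration, as before_train() now does
epoch += 1
iteration += 1

assert (epoch, iteration) == (5, 1201)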
@@ -275,6 +285,7 @@ class Trainer():
                 'epoch', {'cv_loss': cv_loss,
                           'lr': self.lr_scheduler()}, self.epoch)

+            # after epoch
             self.save(tag=self.epoch, infos={'val_loss': cv_loss})
             # step lr every epoch
             self.lr_scheduler.step()
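
Editor's note: the "# after epoch" comment marks the ordering the loop keeps: checkpoint first (with the validation loss in infos), then step the scheduler once per epoch; on resume the lr is restored from the optimizer checkpoint rather than re-derived from the iteration count, per the removed lr_scheduler.step(self.iteration) calls above. A small sketch of that per-epoch stepping; StepDecay is only an example scheduler, not the one any recipe here uses.

import paddle

lr_scheduler = paddle.optimizer.lr.StepDecay(
    learning_rate=0.1, step_size=1, gamma=0.5)

for epoch in range(3):
    # ... train and validate, producing cv_loss ...
    # self.save(tag=epoch, infos={'val_loss': cv_loss})  # after epoch
    lr_scheduler.step()            # step lr every epoch
    print(epoch, lr_scheduler())   # lr halves each epoch under this toy decay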
