@@ -179,7 +179,8 @@ class Trainer():
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
         if infos:
-            # restore from ckpt
+            # just restore ckpt
+            # lr will be restored from the optimizer ckpt
self.iteration = infos["step"]
|
|
|
|
|
self.epoch = infos["epoch"]
|
|
|
|
|
scratch = False
|
|
|
|
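
The added comment above leans on a Paddle behavior worth spelling out: when an optimizer is built with an `LRScheduler`, the scheduler's state travels inside `optimizer.state_dict()`, so restoring the optimizer checkpoint also restores the learning rate. A minimal standalone sketch of that round trip (toy model and `StepDecay` chosen for illustration, not the Trainer's actual config):

```python
import paddle

# Toy model/optimizer; StepDecay stands in for whatever scheduler the config builds.
linear = paddle.nn.Linear(4, 4)
sched = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=2)
opt = paddle.optimizer.Adam(learning_rate=sched, parameters=linear.parameters())

sched.step()
sched.step()              # lr decays to 0.01 after two epochs
state = opt.state_dict()  # the scheduler state rides along in the state dict

# A fresh optimizer/scheduler pair picks the lr back up from the saved state.
opt2 = paddle.optimizer.Adam(
    learning_rate=paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=2),
    parameters=linear.parameters())
opt2.set_state_dict(state)
print(opt2.get_lr())  # 0.01, not the fresh 0.1
```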
@@ -190,14 +191,31 @@ class Trainer():
logger.info("Restore/Init checkpoint!")
|
|
|
|
|
return scratch
|
|
|
|
|
|
|
|
|
|
+    def maybe_batch_sampler_step(self):
+        """Seed the batch_sampler by epoch."""
if hasattr(self.train_loader, "batch_sampler"):
|
|
|
|
|
batch_sampler = self.train_loader.batch_sampler
|
|
|
|
|
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
|
|
|
|
|
batch_sampler.set_epoch(self.epoch)
|
|
|
|
|
|
|
|
|
|
+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save the init model, i.e. epoch 0
+            self.save(tag='init', infos=None)
+        else:
+            # resume: the ckpt holds the last finished step/epoch, so train from the next ones
+            self.epoch += 1
+            self.iteration += 1
+
+        self.maybe_batch_sampler_step()
+
     def new_epoch(self):
         """Reset the train loader seed and increment `epoch`.
         """
         # `iteration` is increased by the train step
         self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
|
|
|
|
|
batch_sampler = self.train_loader.batch_sampler
|
|
|
|
|
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
|
|
|
|
|
batch_sampler.set_epoch(self.epoch)
|
|
|
|
|
+        self.maybe_batch_sampler_step()

     def after_train_batch(self):
         if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
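
The second hunk centralizes epoch seeding in `maybe_batch_sampler_step()` and, notably, drops the old `self.parallel` guard: `paddle.io.DistributedBatchSampler` seeds its shuffle from the epoch passed to `set_epoch` regardless of whether training runs distributed, so re-seeding is useful even in a single process. A standalone sketch (toy dataset, single process):

```python
import paddle
from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    """Eight integer samples, enough to watch the order change."""
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return idx

sampler = DistributedBatchSampler(ToyDataset(), batch_size=4, shuffle=True)
for epoch in range(3):
    sampler.set_epoch(epoch)     # what maybe_batch_sampler_step() does per epoch
    print(epoch, list(sampler))  # a different, but reproducible, index order
```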
@@ -209,15 +227,7 @@ class Trainer():

     def train(self):
         """The training process, controlled by epoch."""
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will be restored from the optimizer ckpt
-        # self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
|
|
|
|
|
while self.epoch < self.config.training.n_epoch:
|
|
|
|
@@ -275,6 +285,7 @@ class Trainer():
                 'epoch', {'cv_loss': cv_loss,
                           'lr': self.lr_scheduler()}, self.epoch)

             # after epoch
             self.save(tag=self.epoch, infos={'val_loss': cv_loss})
+            # step lr every epoch
             self.lr_scheduler.step()
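
A note on how the last hunk ties back to the first: `save(tag=self.epoch, infos={'val_loss': cv_loss})` stores user metadata next to the weights, and `resume_or_scratch()` later reads back an `infos` record carrying at least `"step"` and `"epoch"`. The checkpoint layer itself is not shown in this diff, so the sketch below is purely hypothetical: invented helper names and a JSON layout, just to illustrate the metadata round trip the Trainer relies on.

```python
import json
from pathlib import Path

def save_infos(ckpt_dir, step, epoch, extra=None):
    """Hypothetical: merge the training position with caller metadata and persist it."""
    record = {"step": step, "epoch": epoch, **(extra or {})}
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    Path(ckpt_dir, f"{epoch}.json").write_text(json.dumps(record))

def load_infos(ckpt_path):
    """Hypothetical: what resume_or_scratch() would get back as `infos`."""
    return json.loads(Path(ckpt_path).read_text())

save_infos("/tmp/ckpt", step=3200, epoch=4, extra={"val_loss": 1.23})
infos = load_infos("/tmp/ckpt/4.json")
assert infos["epoch"] == 4 and infos["step"] == 3200  # the fields the Trainer restores
```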