@@ -179,25 +179,47 @@ class Trainer():
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_path=self.args.checkpoint_path)
         if infos:
-            # restore from ckpt
+            # just restore ckpt
+            # lr will restore from optimizer ckpt
             self.iteration = infos["step"]
             self.epoch = infos["epoch"]
             scratch = False
             logger.info(
                 f"Restore ckpt: epoch {self.epoch}, step {self.iteration}!")
         else:
             self.iteration = 0
             self.epoch = 0
             scratch = True
-            logger.info("Restore/Init checkpoint!")
+            logger.info("Init from scratch!")
         return scratch
 
+    def maybe_batch_sampler_step(self):
+        """Seed the batch_sampler by the current epoch."""
+        if hasattr(self.train_loader, "batch_sampler"):
+            batch_sampler = self.train_loader.batch_sampler
+            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
+                batch_sampler.set_epoch(self.epoch)
+
+    def before_train(self):
+        from_scratch = self.resume_or_scratch()
+        if from_scratch:
+            # scratch: save the init model, i.e. epoch 0
+            self.save(tag='init', infos=None)
+        else:
+            # resume: train from the next epoch and the next iteration
+            self.epoch += 1
+            self.iteration += 1
+            logger.info(
+                f"Resume train: epoch {self.epoch}, step {self.iteration}!")
+
+        self.maybe_batch_sampler_step()
+
     def new_epoch(self):
         """Reset the train loader seed and increment `epoch`.
         """
+        # `iteration` is increased by the train step
         self.epoch += 1
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            batch_sampler = self.train_loader.batch_sampler
-            if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
-                batch_sampler.set_epoch(self.epoch)
+        self.maybe_batch_sampler_step()
 
+    def after_train_batch(self):
+        if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
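For readers unfamiliar with Paddle's sampler seeding: `DistributedBatchSampler.set_epoch` reseeds the shuffle so every rank draws the same per-epoch permutation, which is what the new `maybe_batch_sampler_step` relies on. A minimal standalone sketch (the toy dataset and sizes are made up for illustration):

```python
import paddle
from paddle.io import Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    """Hypothetical 16-sample dataset, just to drive the sampler."""
    def __len__(self):
        return 16
    def __getitem__(self, idx):
        return idx

sampler = DistributedBatchSampler(ToyDataset(), batch_size=4, shuffle=True)

# The epoch seeds the permutation: the same epoch yields the same
# batch order on every rank, so resuming at epoch N is reproducible.
sampler.set_epoch(3)
order_a = list(sampler)
sampler.set_epoch(3)
order_b = list(sampler)
assert order_a == order_b

sampler.set_epoch(4)  # a new epoch reshuffles
```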
@@ -209,15 +231,7 @@ class Trainer():
 
     def train(self):
         """The training process, controlled by epoch."""
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
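This hunk folds the resume logic that used to live in `train()` into the new `before_train` hook. A minimal sketch of the decision it now encapsulates, using the `infos` keys from the code above (`load_latest` is a hypothetical stand-in for the trainer's checkpoint helper):

```python
def before_train_sketch(trainer):
    # load_latest is assumed to return {"step": ..., "epoch": ...} when a
    # checkpoint exists, and a falsy value otherwise.
    infos = load_latest(trainer.checkpoint_dir)
    if infos:
        # resume: continue from the *next* epoch and step
        trainer.epoch = infos["epoch"] + 1
        trainer.iteration = infos["step"] + 1
    else:
        # scratch: start at zero and snapshot the untrained model
        trainer.epoch = 0
        trainer.iteration = 0
        trainer.save(tag='init', infos=None)
```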
@@ -233,8 +247,8 @@ class Trainer():
                     report("Rank", dist.get_rank())
                     report("epoch", self.epoch)
                     report('step', self.iteration)
-                    report('step/total',
-                           (batch_index + 1) / len(self.train_loader))
+                    report('iter', batch_index + 1)
+                    report('total', len(self.train_loader))
                     report("lr", self.lr_scheduler())
                     self.train_batch(batch_index, batch, msg)
+                    self.after_train_batch()
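The `report(...)` calls accumulate per-batch observations that later get joined into the log line. The repo has its own reporter; the dict-backed collector below is only an assumed shape, used here to show how `iter`/`total` replace the old `step/total` fraction:

```python
from collections import OrderedDict

_observation = OrderedDict()

def report(key, value):
    # Record one key/value observation for the current step.
    _observation[key] = value

report('iter', 5)      # 1-based batch index within the epoch
report('total', 120)   # batches per epoch; 5/120 was the old fraction
print(", ".join(f"{k}: {v}" for k, v in _observation.items()))
```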
@@ -275,6 +289,7 @@ class Trainer():
                     'epoch', {'cv_loss': cv_loss,
                               'lr': self.lr_scheduler()}, self.epoch)
 
             # after epoch
             self.save(tag=self.epoch, infos={'val_loss': cv_loss})
+            # step lr every epoch
             self.lr_scheduler.step()
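Because `self.lr_scheduler.step()` runs once per epoch rather than per batch, the learning rate is constant within an epoch and decays between epochs. A minimal sketch with a stock Paddle scheduler; the decay type and numbers are assumptions, not the trainer's actual config:

```python
import paddle

scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=2e-3, gamma=0.9)
optimizer = paddle.optimizer.Adam(
    learning_rate=scheduler,
    parameters=paddle.nn.Linear(8, 8).parameters())

for epoch in range(3):
    # ... all batches of this epoch run at a fixed lr ...
    print(f"epoch {epoch}: lr={scheduler.get_lr():.6f}")
    scheduler.step()  # decay once per epoch, mirroring the trainer
```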
@@ -288,7 +303,6 @@ class Trainer():
         try:
             self.train()
         except KeyboardInterrupt:
-            self.save()
             exit(-1)
         finally:
             self.destory()
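The surrounding `try/except/finally` guarantees teardown on every path: after this change a Ctrl-C exits with a non-zero code without saving, and cleanup (`self.destory()` in the source) always runs. A standalone sketch of the same pattern with illustrative names:

```python
def run(train, destroy):
    try:
        train()
    except KeyboardInterrupt:
        # interrupted by the user: exit immediately, skip saving
        exit(-1)
    finally:
        # executes on success, interrupt, or any other exception
        destroy()
```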