Merge pull request #847 from PaddlePaddle/resume_train

Resume training with the epoch and iteration counters incremented
Hui Zhang 3 years ago committed by GitHub
commit ecb5d4f862
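Taken together, the hunks below make resuming explicit: resume_or_scratch() now only restores the epoch/step counters from checkpoint infos, and a new before_train() hook advances both so training continues at the next epoch and iteration rather than repeating the last saved ones. A minimal sketch of that arithmetic in plain Python (the function and dict layout are illustrative, not the Trainer API; only the "epoch"/"step" keys come from this diff):

    def before_train(infos):
        """Restore counters from checkpoint infos, then advance past them."""
        if infos is None:
            # from scratch: start at epoch 0, step 0 (the real code also
            # saves an 'init' checkpoint here, i.e. the 0-epoch model)
            return {"epoch": 0, "iteration": 0}
        # resume: train the next epoch and the next iteration
        return {"epoch": infos["epoch"] + 1, "iteration": infos["step"] + 1}

    # checkpoint written at the end of epoch 3, global step 1200
    print(before_train({"epoch": 3, "step": 1200}))  # {'epoch': 4, 'iteration': 1201}
    print(before_train(None))                        # {'epoch': 0, 'iteration': 0}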

@@ -183,15 +183,7 @@ class U2Trainer(Trainer):
             # script_model_path = str(self.checkpoint_dir / 'init')
             # paddle.jit.save(script_model, script_model_path)
 
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.iteration)
-        if self.parallel and hasattr(self.train_loader, 'batch_sampler'):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
@@ -207,8 +199,8 @@ class U2Trainer(Trainer):
                 report("Rank", dist.get_rank())
                 report("epoch", self.epoch)
                 report('step', self.iteration)
-                report('step/total',
-                       (batch_index + 1) / len(self.train_loader))
+                report('iter', batch_index + 1)
+                report('total', len(self.train_loader))
                 report("lr", self.lr_scheduler())
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()

@@ -184,11 +184,7 @@ class U2Trainer(Trainer):
             # script_model_path = str(self.checkpoint_dir / 'init')
             # paddle.jit.save(script_model, script_model_path)
 
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:

@@ -198,14 +198,7 @@ class U2STTrainer(Trainer):
             # script_model_path = str(self.checkpoint_dir / 'init')
             # paddle.jit.save(script_model, script_model_path)
 
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init')
-        self.lr_scheduler.step(self.iteration)
-
-        if self.parallel:
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:

@ -179,25 +179,47 @@ class Trainer():
checkpoint_dir=self.checkpoint_dir, checkpoint_dir=self.checkpoint_dir,
checkpoint_path=self.args.checkpoint_path) checkpoint_path=self.args.checkpoint_path)
if infos: if infos:
# restore from ckpt # just restore ckpt
# lr will resotre from optimizer ckpt
self.iteration = infos["step"] self.iteration = infos["step"]
self.epoch = infos["epoch"] self.epoch = infos["epoch"]
scratch = False scratch = False
logger.info(
f"Restore ckpt: epoch {self.epoch }, step {self.iteration}!")
else: else:
self.iteration = 0 self.iteration = 0
self.epoch = 0 self.epoch = 0
scratch = True scratch = True
logger.info("Restore/Init checkpoint!") logger.info("Init from scratch!")
return scratch return scratch
def maybe_batch_sampler_step(self):
""" batch_sampler seed by epoch """
if hasattr(self.train_loader, "batch_sampler"):
batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def before_train(self):
from_scratch = self.resume_or_scratch()
if from_scratch:
# scratch: save init model, i.e. 0 epoch
self.save(tag='init', infos=None)
else:
# resume: train next_epoch and next_iteration
self.epoch += 1
self.iteration += 1
logger.info(
f"Resume train: epoch {self.epoch }, step {self.iteration}!")
self.maybe_batch_sampler_step()
def new_epoch(self): def new_epoch(self):
"""Reset the train loader seed and increment `epoch`. """Reset the train loader seed and increment `epoch`.
""" """
# `iteration` increased by train step
self.epoch += 1 self.epoch += 1
if self.parallel and hasattr(self.train_loader, "batch_sampler"): self.maybe_batch_sampler_step()
batch_sampler = self.train_loader.batch_sampler
if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
batch_sampler.set_epoch(self.epoch)
def after_train_batch(self): def after_train_batch(self):
if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step: if self.args.benchmark_max_step and self.iteration > self.args.benchmark_max_step:
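The new maybe_batch_sampler_step() centralizes the set_epoch() call that was previously duplicated across trainers. paddle.io.DistributedBatchSampler derives its shuffle seed from the epoch, so re-seeding with the restored epoch makes a resumed run draw batches in the order an uninterrupted run would have used. A small sketch (assumes a working paddle install; the toy dataset is illustrative):

    import paddle

    class ToyDataset(paddle.io.Dataset):
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return idx

    sampler = paddle.io.DistributedBatchSampler(
        ToyDataset(), batch_size=2, shuffle=True)
    sampler.set_epoch(4)    # the same call maybe_batch_sampler_step() makes
    print(list(sampler))    # batch order is deterministic for a given epoch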
@@ -209,15 +231,7 @@ class Trainer():
     def train(self):
         """The training process control by epoch."""
-        from_scratch = self.resume_or_scratch()
-        if from_scratch:
-            # save init model, i.e. 0 epoch
-            self.save(tag='init', infos=None)
-
-        # lr will resotre from optimizer ckpt
-        # self.lr_scheduler.step(self.epoch)
-        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
-            self.train_loader.batch_sampler.set_epoch(self.epoch)
+        self.before_train()
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
         while self.epoch < self.config.training.n_epoch:
@@ -233,8 +247,8 @@ class Trainer():
                 report("Rank", dist.get_rank())
                 report("epoch", self.epoch)
                 report('step', self.iteration)
-                report('step/total',
-                       (batch_index + 1) / len(self.train_loader))
+                report('iter', batch_index + 1)
+                report('total', len(self.train_loader))
                 report("lr", self.lr_scheduler())
                 self.train_batch(batch_index, batch, msg)
                 self.after_train_batch()
@ -275,6 +289,7 @@ class Trainer():
'epoch', {'cv_loss': cv_loss, 'epoch', {'cv_loss': cv_loss,
'lr': self.lr_scheduler()}, self.epoch) 'lr': self.lr_scheduler()}, self.epoch)
# after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss}) self.save(tag=self.epoch, infos={'val_loss': cv_loss})
# step lr every epoch # step lr every epoch
self.lr_scheduler.step() self.lr_scheduler.step()
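For reference, the infos dict written here is what resume_or_scratch() reads back: "epoch" and "step" come from the save call's bookkeeping, and this hunk adds "val_loss" at save time. A round-trip sketch (the JSON layout and path are assumptions for illustration, not the repo's checkpoint format):

    import json, pathlib, tempfile

    ckpt_dir = pathlib.Path(tempfile.mkdtemp())
    infos = {"epoch": 7, "step": 2800, "val_loss": 3.14}
    (ckpt_dir / "7.json").write_text(json.dumps(infos))

    restored = json.loads((ckpt_dir / "7.json").read_text())
    # before_train() then resumes at the next epoch and iteration:
    print(restored["epoch"] + 1, restored["step"] + 1)   # 8 2801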
@ -288,7 +303,6 @@ class Trainer():
try: try:
self.train() self.train()
except KeyboardInterrupt: except KeyboardInterrupt:
self.save()
exit(-1) exit(-1)
finally: finally:
self.destory() self.destory()
