|
|
|
@ -181,7 +181,7 @@ class Trainer():
|
|
|
|
|
"""Reset the train loader seed and increment `epoch`.
|
|
|
|
|
"""
|
|
|
|
|
self.epoch += 1
|
|
|
|
|
if self.parallel:
|
|
|
|
|
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
|
|
|
|
|
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
|
|
|
|
|
|
|
|
|
def train(self):
|
|
|
|
@ -191,7 +191,7 @@ class Trainer():
|
|
|
|
|
# save init model, i.e. 0 epoch
|
|
|
|
|
self.save(tag='init', infos=None)
|
|
|
|
|
self.lr_scheduler.step(self.epoch)
|
|
|
|
|
if self.parallel:
|
|
|
|
|
if self.parallel and hasattr(self.train_loader, "batch_sampler"):
|
|
|
|
|
self.train_loader.batch_sampler.set_epoch(self.epoch)
|
|
|
|
|
|
|
|
|
|
logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
|
|
|
|
|