fix trainer when dataloader is not using a batch_sampler

pull/788/head
Hui Zhang 3 years ago
parent 44c84e26c3
commit 14ac780658

@@ -181,7 +181,7 @@ class Trainer():
         """Reset the train loader seed and increment `epoch`.
         """
         self.epoch += 1
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
     def train(self):
@@ -191,7 +191,7 @@ class Trainer():
         # save init model, i.e. 0 epoch
         self.save(tag='init', infos=None)
         self.lr_scheduler.step(self.epoch)
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
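
The guard matters because distributed training reshuffles data each epoch through the batch sampler's set_epoch call; a loader built without a batch sampler (for example, one wrapping an iterable-style dataset) has no batch_sampler attribute, so the old unconditional call would raise AttributeError. Below is a minimal, self-contained sketch of the pattern; the dummy loader and sampler classes are hypothetical stand-ins for illustration, not the project's code.

    # Hedged sketch of the hasattr guard; DummySampler, LoaderWithSampler,
    # and LoaderWithoutSampler are hypothetical stand-ins.

    class DummySampler:
        def set_epoch(self, epoch):
            # A distributed batch sampler reseeds its shuffle here.
            print(f"reseeding shuffle for epoch {epoch}")

    class LoaderWithSampler:
        # Mimics a DataLoader constructed with a batch sampler.
        batch_sampler = DummySampler()

    class LoaderWithoutSampler:
        # Mimics a DataLoader with no batch sampler, e.g. one built
        # around an iterable-style dataset.
        pass

    def new_epoch(loader, epoch, parallel=True):
        # Only call set_epoch when a batch sampler actually exists;
        # without the hasattr check the second call below would raise
        # AttributeError.
        if parallel and hasattr(loader, "batch_sampler"):
            loader.batch_sampler.set_epoch(epoch)

    new_epoch(LoaderWithSampler(), 1)     # prints the reseed message
    new_epoch(LoaderWithoutSampler(), 1)  # silently skipped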
