From b7c7ebba5b8f346593019d0027131ab51cb02e5d Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 25 Aug 2021 02:22:48 +0000
Subject: [PATCH] fix trainer when dataloader not using batch_sampler

---
 deepspeech/training/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deepspeech/training/trainer.py b/deepspeech/training/trainer.py
index 2ab7eac03..866be552d 100644
--- a/deepspeech/training/trainer.py
+++ b/deepspeech/training/trainer.py
@@ -181,7 +181,7 @@ class Trainer():
         """Reset the train loader seed and increment `epoch`.
         """
         self.epoch += 1
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
     def train(self):
@@ -191,7 +191,7 @@ class Trainer():
         # save init model, i.e. 0 epoch
         self.save(tag='init', infos=None)
         self.lr_scheduler.step(self.epoch)
-        if self.parallel:
+        if self.parallel and hasattr(self.train_loader, "batch_sampler"):
             self.train_loader.batch_sampler.set_epoch(self.epoch)
 
         logger.info(f"Train Total Examples: {len(self.train_loader.dataset)}")
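
Note on the fix: the unguarded call self.train_loader.batch_sampler.set_epoch(...)
raises AttributeError whenever the trainer is handed a loader that does not expose
a batch_sampler attribute (e.g. a custom iterable loader). The hasattr guard skips
the per-epoch reseed in that case. Below is a minimal, self-contained Python sketch
of the guarded pattern; the loader and sampler classes are hypothetical stand-ins,
and only the hasattr check mirrors the actual change in trainer.py.

    class ShufflingBatchSampler:
        """Hypothetical stand-in for a distributed batch sampler."""

        def __init__(self):
            self.epoch = 0

        def set_epoch(self, epoch):
            # Reseeding per epoch is what keeps shuffling consistent
            # across ranks in distributed training.
            self.epoch = epoch


    class SampledLoader:
        """Loader built around a batch sampler (exposes .batch_sampler)."""

        def __init__(self):
            self.batch_sampler = ShufflingBatchSampler()


    class PlainLoader:
        """Loader without a batch sampler; the unguarded code crashed here."""


    def new_epoch(train_loader, parallel, epoch):
        """Mirrors Trainer.new_epoch() with the patched guard."""
        epoch += 1
        # Only reseed when running in parallel AND the loader actually
        # exposes a batch_sampler; otherwise skip instead of raising.
        if parallel and hasattr(train_loader, "batch_sampler"):
            train_loader.batch_sampler.set_epoch(epoch)
        return epoch


    # Both loader styles now pass through new_epoch without AttributeError.
    assert new_epoch(SampledLoader(), parallel=True, epoch=0) == 1
    assert new_epoch(PlainLoader(), parallel=True, epoch=0) == 1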