|
|
|
@ -55,7 +55,7 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
'train_loss': float(loss),
|
|
|
|
|
}
|
|
|
|
|
msg += "train time: {:>.3f}s, ".format(iteration_time)
|
|
|
|
|
msg += "batch size: {}, ".format(self.config.data.batch_size)
|
|
|
|
|
msg += "batch size: {}, ".format(self.config.collator.batch_size)
|
|
|
|
|
msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
|
|
|
|
for k, v in losses_np.items())
|
|
|
|
|
logger.info(msg)
|
|
|
|
@ -149,31 +149,31 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
|
if self.parallel:
|
|
|
|
|
batch_sampler = SortagradDistributedBatchSampler(
|
|
|
|
|
train_dataset,
|
|
|
|
|
batch_size=config.data.batch_size,
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
num_replicas=None,
|
|
|
|
|
rank=None,
|
|
|
|
|
shuffle=True,
|
|
|
|
|
drop_last=True,
|
|
|
|
|
sortagrad=config.data.sortagrad,
|
|
|
|
|
shuffle_method=config.data.shuffle_method)
|
|
|
|
|
sortagrad=config.collator.sortagrad,
|
|
|
|
|
shuffle_method=config.collator.shuffle_method)
|
|
|
|
|
else:
|
|
|
|
|
batch_sampler = SortagradBatchSampler(
|
|
|
|
|
train_dataset,
|
|
|
|
|
shuffle=True,
|
|
|
|
|
batch_size=config.data.batch_size,
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
drop_last=True,
|
|
|
|
|
sortagrad=config.data.sortagrad,
|
|
|
|
|
shuffle_method=config.data.shuffle_method)
|
|
|
|
|
sortagrad=config.collator.sortagrad,
|
|
|
|
|
shuffle_method=config.collator.shuffle_method)
|
|
|
|
|
|
|
|
|
|
collate_fn = SpeechCollator.from_config(config)
|
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
|
train_dataset,
|
|
|
|
|
batch_sampler=batch_sampler,
|
|
|
|
|
collate_fn=collate_fn,
|
|
|
|
|
num_workers=config.data.num_workers)
|
|
|
|
|
num_workers=config.collator.num_workers)
|
|
|
|
|
self.valid_loader = DataLoader(
|
|
|
|
|
dev_dataset,
|
|
|
|
|
batch_size=config.data.batch_size,
|
|
|
|
|
batch_size=config.collator.batch_size,
|
|
|
|
|
shuffle=False,
|
|
|
|
|
drop_last=False,
|
|
|
|
|
collate_fn=collate_fn)
|
|
|
|
|