|
|
@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
|
|
|
|
sortagrad=config.data.sortagrad,
|
|
|
|
sortagrad=config.data.sortagrad,
|
|
|
|
shuffle_method=config.data.shuffle_method)
|
|
|
|
shuffle_method=config.data.shuffle_method)
|
|
|
|
|
|
|
|
|
|
|
|
collate_fn = SpeechCollator(config, keep_transcription_text=False)
|
|
|
|
collate_fn = SpeechCollator(config=config, keep_transcription_text=False)
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
train_dataset,
|
|
|
|
train_dataset,
|
|
|
|
batch_sampler=batch_sampler,
|
|
|
|
batch_sampler=batch_sampler,
|
|
|
@ -342,7 +342,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
batch_size=config.decoding.batch_size,
|
|
|
|
shuffle=False,
|
|
|
|
shuffle=False,
|
|
|
|
drop_last=False,
|
|
|
|
drop_last=False,
|
|
|
|
collate_fn=SpeechCollator(keep_transcription_text=True))
|
|
|
|
collate_fn=SpeechCollator(config=config, keep_transcription_text=True))
|
|
|
|
logger.info("Setup test Dataloader!")
|
|
|
|
logger.info("Setup test Dataloader!")
|
|
|
|
|
|
|
|
|
|
|
|
def setup_output_dir(self):
|
|
|
|
def setup_output_dir(self):
|
|
|
|