multi worker for dataloader

pull/865/head
Hui Zhang 3 years ago
parent b7b1bda34f
commit 856d641c9c

@ -235,16 +235,18 @@ class DeepSpeech2Trainer(Trainer):
num_workers=config.collator.num_workers)
self.valid_loader = DataLoader(
dev_dataset,
batch_size=int(config.collator.batch_size / 4),
batch_size=int(config.collator.batch_size),
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers)
self.test_loader = DataLoader(
test_dataset,
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_test)
collate_fn=collate_fn_test,
num_workers=config.collator.num_workers)
logger.info("Setup train/valid/test Dataloader!")

@ -292,7 +292,8 @@ class U2Trainer(Trainer):
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers, )
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
@ -314,7 +315,8 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, )
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
@ -322,7 +324,8 @@ class U2Trainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator.from_config(config))
collate_fn=SpeechCollator.from_config(config),
num_workers=config.collator.num_workers, )
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):

@ -292,7 +292,8 @@ class U2STTrainer(Trainer):
batch_size=config.collator.batch_size,
shuffle=False,
drop_last=False,
collate_fn=collate_fn_dev)
collate_fn=collate_fn_dev,
num_workers=config.collator.num_workers, )
# test dataset, return raw text
config.data.manifest = config.data.test_manifest
@ -313,7 +314,8 @@ class U2STTrainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=TestCollator.from_config(config))
collate_fn=TestCollator.from_config(config),
num_workers=config.collator.num_workers, )
# return text token id
config.collator.keep_transcription_text = False
self.align_loader = DataLoader(
@ -321,7 +323,8 @@ class U2STTrainer(Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=TestCollator.from_config(config))
collate_fn=TestCollator.from_config(config),
num_workers=config.collator.num_workers, )
logger.info("Setup train/valid/test/align Dataloader!")
def setup_model(self):

Loading…
Cancel
Save