From 856d641c9ce748766ae53c1939fc995dea6aec9a Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Tue, 28 Sep 2021 11:48:21 +0000 Subject: [PATCH] multi worker for dataloader --- deepspeech/exps/deepspeech2/model.py | 8 +++++--- deepspeech/exps/u2/model.py | 9 ++++++--- deepspeech/exps/u2_st/model.py | 9 ++++++--- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/deepspeech/exps/deepspeech2/model.py b/deepspeech/exps/deepspeech2/model.py index b854a996..e84de615 100644 --- a/deepspeech/exps/deepspeech2/model.py +++ b/deepspeech/exps/deepspeech2/model.py @@ -235,16 +235,18 @@ class DeepSpeech2Trainer(Trainer): num_workers=config.collator.num_workers) self.valid_loader = DataLoader( dev_dataset, - batch_size=int(config.collator.batch_size / 4), + batch_size=int(config.collator.batch_size), shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers) self.test_loader = DataLoader( test_dataset, batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_test) + collate_fn=collate_fn_test, + num_workers=config.collator.num_workers) logger.info("Setup train/valid/test Dataloader!") diff --git a/deepspeech/exps/u2/model.py b/deepspeech/exps/u2/model.py index 1afd9b10..c30f324b 100644 --- a/deepspeech/exps/u2/model.py +++ b/deepspeech/exps/u2/model.py @@ -292,7 +292,8 @@ class U2Trainer(Trainer): batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers, ) # test dataset, return raw text config.data.manifest = config.data.test_manifest @@ -314,7 +315,8 @@ class U2Trainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator.from_config(config)) + collate_fn=SpeechCollator.from_config(config), + num_workers=config.collator.num_workers, ) # return text token id config.collator.keep_transcription_text = False self.align_loader = DataLoader( @@ -322,7 +324,8 @@ class U2Trainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=SpeechCollator.from_config(config)) + collate_fn=SpeechCollator.from_config(config), + num_workers=config.collator.num_workers, ) logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self): diff --git a/deepspeech/exps/u2_st/model.py b/deepspeech/exps/u2_st/model.py index 9a34cbdc..c480499c 100644 --- a/deepspeech/exps/u2_st/model.py +++ b/deepspeech/exps/u2_st/model.py @@ -292,7 +292,8 @@ class U2STTrainer(Trainer): batch_size=config.collator.batch_size, shuffle=False, drop_last=False, - collate_fn=collate_fn_dev) + collate_fn=collate_fn_dev, + num_workers=config.collator.num_workers, ) # test dataset, return raw text config.data.manifest = config.data.test_manifest @@ -313,7 +314,8 @@ class U2STTrainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=TestCollator.from_config(config)) + collate_fn=TestCollator.from_config(config), + num_workers=config.collator.num_workers, ) # return text token id config.collator.keep_transcription_text = False self.align_loader = DataLoader( @@ -321,7 +323,8 @@ class U2STTrainer(Trainer): batch_size=config.decoding.batch_size, shuffle=False, drop_last=False, - collate_fn=TestCollator.from_config(config)) + collate_fn=TestCollator.from_config(config), + num_workers=config.collator.num_workers, ) logger.info("Setup train/valid/test/align Dataloader!") def setup_model(self):