pull/665/head
Haoxin Ma 3 years ago
parent 279348d786
commit c706dfec2a

@ -165,7 +165,7 @@ class DeepSpeech2Trainer(Trainer):
sortagrad=config.data.sortagrad,
shuffle_method=config.data.shuffle_method)
collate_fn = SpeechCollator(config, keep_transcription_text=False)
collate_fn = SpeechCollator(config=config, keep_transcription_text=False)
self.train_loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
@ -342,7 +342,7 @@ class DeepSpeech2Tester(DeepSpeech2Trainer):
batch_size=config.decoding.batch_size,
shuffle=False,
drop_last=False,
collate_fn=SpeechCollator(keep_transcription_text=True))
collate_fn=SpeechCollator(config=config, keep_transcription_text=True))
logger.info("Setup test Dataloader!")
def setup_output_dir(self):

@ -23,6 +23,8 @@ from deepspeech.frontend.speech import SpeechSegment
import io
import time
from collections import namedtuple
__all__ = ["SpeechCollator"]
logger = Log(__name__).getlog()
@ -50,7 +52,7 @@ class SpeechCollator():
aug_file = config.data.augmentation_config
assert isinstance(aug_file, io.StringIO)
self._local_data = TarLocalData(tar2info={}, tar2object={}
self._local_data = TarLocalData(tar2info={}, tar2object={})
self._augmentation_pipeline = AugmentationPipeline(
augmentation_config=aug_file.read(),
random_seed=config.data.random_seed)

Loading…
Cancel
Save