From bc9f444d8a31c4751d4aef5e4f90c37f2c3cc4cb Mon Sep 17 00:00:00 2001
From: Hui Zhang
Date: Wed, 18 Aug 2021 07:53:52 +0000
Subject: [PATCH] add dataloader; check augmenter base class type

---
 deepspeech/frontend/augmentor/augmentation.py |  2 +
 deepspeech/io/dataloader.py                   | 54 ++++++++++++++------
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/deepspeech/frontend/augmentor/augmentation.py b/deepspeech/frontend/augmentor/augmentation.py
index cfebc463c..7b43988e4 100644
--- a/deepspeech/frontend/augmentor/augmentation.py
+++ b/deepspeech/frontend/augmentor/augmentation.py
@@ -18,6 +18,7 @@ from inspect import signature
 
 import numpy as np
 
+from deepspeech.frontend.augmentor.base import AugmentorBase
 from deepspeech.utils.dynamic_import import dynamic_import
 from deepspeech.utils.log import Log
 
@@ -209,6 +210,7 @@ class AugmentationPipeline():
     def _get_augmentor(self, augmentor_type, params):
         """Return an augmentation model by the type name, and pass in params."""
         class_obj = dynamic_import(augmentor_type, import_alias)
+        assert issubclass(class_obj, AugmentorBase)
         try:
             obj = class_obj(self._rng, **params)
         except Exception:
diff --git a/deepspeech/io/dataloader.py b/deepspeech/io/dataloader.py
index 3c4c2d5ef..15ab73157 100644
--- a/deepspeech/io/dataloader.py
+++ b/deepspeech/io/dataloader.py
@@ -15,8 +15,9 @@
+import numpy as np
 from paddle.io import DataLoader
 
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.io.batchfy import make_batchset
+from deepspeech.io.converter import CustomConverter
 from deepspeech.io.dataset import TransformDataset
-from deepspeech.io.reader import CustomConverter
 from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.utils.log import Log
@@ -46,7 +47,6 @@
                  num_encs: int=1):
         self.json_file = json_file
         self.train_mode = train_mode
-
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
         self.batch_size = batch_size
         self.maxlen_in = maxlen_in
@@ -56,20 +56,17 @@
         self.batch_frames_in = batch_frames_in
         self.batch_frames_out = batch_frames_out
         self.batch_frames_inout = batch_frames_inout
-
         self.subsampling_factor = subsampling_factor
         self.num_encs = num_encs
         self.preprocess_conf = preprocess_conf
-
         self.n_iter_processes = n_iter_processes
 
         # read json data
-        data_json = read_manifest(json_file)
-        logger.info(f"load {json_file} file.")
+        self.data_json = read_manifest(json_file)
 
         # make minibatch list (variable length)
-        self.data = make_batchset(
-            data_json,
+        self.minibatches = make_batchset(
+            self.data_json,
             batch_size,
             maxlen_in,
             maxlen_out,
@@ -83,9 +80,9 @@
             batch_frames_inout=batch_frames_inout,
             iaxis=0,
             oaxis=0, )
-        logger.info(f"batchfy data {json_file}: {len(self.data)}.")
 
-        self.load = LoadInputsAndTargets(
+        # data reader
+        self.reader = LoadInputsAndTargets(
             mode="asr",
             load_output=True,
             preprocess_conf=preprocess_conf,
@@ -96,7 +93,7 @@
         # Setup a converter
         if num_encs == 1:
             self.converter = CustomConverter(
-                subsampling_factor=subsampling_factor, dtype=dtype)
+                subsampling_factor=subsampling_factor, dtype=np.float32)
         else:
             assert NotImplementedError("not impl CustomConverterMulEnc.")
 
@@ -104,14 +101,39 @@
         # actual bathsize is included in a list
         # default collate function converts numpy array to pytorch tensor
         # we used an empty collate function instead which returns list
-        self.train_loader = DataLoader(
-            dataset=TransformDataset(
-                self.data,
-                lambda data: self.converter([self.load(data, return_uttid=True)])),
+        self.dataset = TransformDataset(
+            self.minibatches,
+            lambda data: self.converter([self.reader(data, return_uttid=True)]))
+        self.dataloader = DataLoader(
+            dataset=self.dataset,
             batch_size=1,
             shuffle=not use_sortagrad if train_mode else False,
             collate_fn=lambda x: x[0],
             num_workers=n_iter_processes, )
-        logger.info(f"dataloader for {json_file}.")
 
     def __repr__(self):
-        return f"DataLoader {self.json_file}-{self.train_mode}-{self.use_sortagrad}"
+        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
+        echo += f"train_mode: {self.train_mode}, "
+        echo += f"sortagrad: {self.use_sortagrad}, "
+        echo += f"batch_size: {self.batch_size}, "
+        echo += f"maxlen_in: {self.maxlen_in}, "
+        echo += f"maxlen_out: {self.maxlen_out}, "
+        echo += f"batch_count: {self.batch_count}, "
+        echo += f"batch_bins: {self.batch_bins}, "
+        echo += f"batch_frames_in: {self.batch_frames_in}, "
+        echo += f"batch_frames_out: {self.batch_frames_out}, "
+        echo += f"batch_frames_inout: {self.batch_frames_inout}, "
+        echo += f"subsampling_factor: {self.subsampling_factor}, "
+        echo += f"num_encs: {self.num_encs}, "
+        echo += f"num_workers: {self.n_iter_processes}, "
+        echo += f"file: {self.json_file}"
+        return echo
+
+    def __len__(self):
+        return len(self.dataloader)
+
+    def __iter__(self):
+        return self.dataloader.__iter__()
+
+    def __call__(self):
+        return self.__iter__()
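
Note for reviewers: the new issubclass guard in _get_augmentor() rejects a registered type that does not extend AugmentorBase before it is instantiated. A minimal sketch of that behavior follows; only the AugmentorBase stub mirrors deepspeech/frontend/augmentor/base.py, and ToyAugmentor, NotAnAugmentor, and get_augmentor are invented for illustration, not PaddleSpeech code.

import numpy as np


class AugmentorBase:
    # Stand-in for deepspeech.frontend.augmentor.base.AugmentorBase.
    def transform_audio(self, audio_segment):
        raise NotImplementedError


class ToyAugmentor(AugmentorBase):
    # Hypothetical augmentor; satisfies the issubclass check.
    def __init__(self, rng, prob=1.0):
        self._rng = rng
        self._prob = prob

    def transform_audio(self, audio_segment):
        return audio_segment  # identity transform, illustration only


class NotAnAugmentor:
    # Hypothetical stray class; fails the check.
    pass


def get_augmentor(class_obj, rng, **params):
    # Same shape as AugmentationPipeline._get_augmentor after this patch:
    # validate the class before calling class_obj(self._rng, **params).
    assert issubclass(class_obj, AugmentorBase)
    return class_obj(rng, **params)


rng = np.random.RandomState(0)
aug = get_augmentor(ToyAugmentor, rng, prob=0.5)  # ok
try:
    get_augmentor(NotAnAugmentor, rng)
except AssertionError:
    print("rejected: NotAnAugmentor does not extend AugmentorBase")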
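
The dataloader change keeps the ESPnet-style batching scheme: make_batchset() pre-builds variable-sized minibatches, so the paddle DataLoader runs with batch_size=1 and a collate_fn that only unwraps the singleton list. A toy sketch of that pattern, with PreBatched and the sample arrays invented for the example:

import numpy as np
from paddle.io import DataLoader, Dataset


class PreBatched(Dataset):
    """Each item is a whole pre-built minibatch, not a single example."""

    def __init__(self, minibatches):
        super().__init__()
        self.minibatches = minibatches

    def __len__(self):
        return len(self.minibatches)

    def __getitem__(self, idx):
        return self.minibatches[idx]


# Two "minibatches" of different sizes, as make_batchset would produce.
batches = [
    np.zeros((4, 83), dtype="float32"),
    np.zeros((7, 83), dtype="float32"),
]
loader = DataLoader(
    PreBatched(batches),
    batch_size=1,               # one pre-built minibatch per step
    shuffle=False,
    collate_fn=lambda x: x[0])  # unwrap the singleton list
for batch in loader:
    print(batch.shape)          # [4, 83], then [7, 83]: batch sizes vary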
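
Finally, a rough smoke-test sketch for the reworked BatchDataLoader. The manifest path and batching numbers are placeholders, the keyword names are read off this diff, and the remaining constructor arguments are assumed to have usable defaults, so treat this as an approximation rather than a verified call.

from deepspeech.io.dataloader import BatchDataLoader

# All values below are placeholders, not recommended settings.
loader = BatchDataLoader(
    json_file="data/manifest.dev",  # placeholder manifest path
    train_mode=False,
    sortagrad=0,
    batch_size=16,
    maxlen_in=512,
    maxlen_out=150,
    preprocess_conf=None,
    n_iter_processes=2)

print(loader)       # __repr__ now reports every batching knob
print(len(loader))  # number of pre-built minibatches

for batch in loader:  # __iter__ delegates to the inner paddle DataLoader
    break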