add dataloader; check augmenter base class type

pull/756/head
Hui Zhang 3 years ago
parent 64cf538e17
commit 4af774d8f0

@ -18,6 +18,7 @@ from inspect import signature
import numpy as np
from deepspeech.frontend.augmentor.base import AugmentorBase
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.log import Log
@ -209,6 +210,7 @@ class AugmentationPipeline():
def _get_augmentor(self, augmentor_type, params):
"""Return an augmentation model by the type name, and pass in params."""
class_obj = dynamic_import(augmentor_type, import_alias)
assert issubclass(class_obj, AugmentorBase)
try:
obj = class_obj(self._rng, **params)
except Exception:

@ -15,8 +15,8 @@ from paddle.io import DataLoader
from deepspeech.frontend.utility import read_manifest
from deepspeech.io.batchfy import make_batchset
from deepspeech.io.converter import CustomConverter
from deepspeech.io.dataset import TransformDataset
from deepspeech.io.reader import CustomConverter
from deepspeech.io.reader import LoadInputsAndTargets
from deepspeech.utils.log import Log
@ -46,7 +46,6 @@ class BatchDataLoader():
num_encs: int=1):
self.json_file = json_file
self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
self.batch_size = batch_size
self.maxlen_in = maxlen_in
@ -56,20 +55,17 @@ class BatchDataLoader():
self.batch_frames_in = batch_frames_in
self.batch_frames_out = batch_frames_out
self.batch_frames_inout = batch_frames_inout
self.subsampling_factor = subsampling_factor
self.num_encs = num_encs
self.preprocess_conf = preprocess_conf
self.n_iter_processes = n_iter_processes
# read json data
data_json = read_manifest(json_file)
logger.info(f"load {json_file} file.")
self.data_json = read_manifest(json_file)
# make minibatch list (variable length)
self.data = make_batchset(
data_json,
self.minibaches = make_batchset(
self.data_json,
batch_size,
maxlen_in,
maxlen_out,
@ -83,9 +79,9 @@ class BatchDataLoader():
batch_frames_inout=batch_frames_inout,
iaxis=0,
oaxis=0, )
logger.info(f"batchfy data {json_file}: {len(self.data)}.")
self.load = LoadInputsAndTargets(
# data reader
self.reader = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=preprocess_conf,
@ -96,7 +92,7 @@ class BatchDataLoader():
# Setup a converter
if num_encs == 1:
self.converter = CustomConverter(
subsampling_factor=subsampling_factor, dtype=dtype)
subsampling_factor=subsampling_factor, dtype=np.float32)
else:
assert NotImplementedError("not impl CustomConverterMulEnc.")
@ -104,14 +100,39 @@ class BatchDataLoader():
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
self.train_loader = DataLoader(
dataset=TransformDataset(
self.data, lambda data: self.converter([self.load(data, return_uttid=True)])),
self.dataset = TransformDataset(
self.minibaches,
lambda data: self.converter([self.reader(data, return_uttid=True)]))
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=1,
shuffle=not use_sortagrad if train_mode else False,
collate_fn=lambda x: x[0],
num_workers=n_iter_processes, )
logger.info(f"dataloader for {json_file}.")
def __repr__(self):
return f"DataLoader {self.json_file}-{self.train_mode}-{self.use_sortagrad}"
echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
echo += f"train_mode: {self.train_mode}, "
echo += f"sortagrad: {self.use_sortagrad}, "
echo += f"batch_size: {self.batch_size}, "
echo += f"maxlen_in: {self.maxlen_in}, "
echo += f"maxlen_out: {self.maxlen_out}, "
echo += f"batch_count: {self.batch_count}, "
echo += f"batch_bins: {self.batch_bins}, "
echo += f"batch_frames_in: {self.batch_frames_in}, "
echo += f"batch_frames_out: {self.batch_frames_out}, "
echo += f"batch_frames_inout: {self.batch_frames_inout}, "
echo += f"subsampling_factor: {self.subsampling_factor}, "
echo += f"num_encs: {self.num_encs}, "
echo += f"num_workers: {self.n_iter_processes}, "
echo += f"file: {self.json_file}"
return echo
def __len__(self):
return len(self.dataloader)
def __iter__(self):
return self.dataloader.__iter__()
def __call__(self):
return self.__iter__()

Loading…
Cancel
Save