add dataloader; check augmenter base class type

pull/756/head
Hui Zhang 3 years ago
parent 64cf538e17
commit 4af774d8f0

@@ -18,6 +18,7 @@ from inspect import signature
 import numpy as np
+from deepspeech.frontend.augmentor.base import AugmentorBase
 from deepspeech.utils.dynamic_import import dynamic_import
 from deepspeech.utils.log import Log
@@ -209,6 +210,7 @@ class AugmentationPipeline():
     def _get_augmentor(self, augmentor_type, params):
         """Return an augmentation model by the type name, and pass in params."""
         class_obj = dynamic_import(augmentor_type, import_alias)
+        assert issubclass(class_obj, AugmentorBase)
         try:
             obj = class_obj(self._rng, **params)
         except Exception:
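
With this change, _get_augmentor refuses to instantiate anything that dynamic_import resolves to a class outside the AugmentorBase hierarchy, so a bad entry in the augmentation config fails immediately with a clear assertion instead of an obscure error later in the pipeline. A minimal, self-contained sketch of the same guard pattern (the classes and names below are illustrative stand-ins, not the project's real augmentors or registry):

import numpy as np

class AugmentorBase:
    """Stand-in for deepspeech.frontend.augmentor.base.AugmentorBase."""
    pass

class NoiseAugmentor(AugmentorBase):
    """Toy augmentor: would add Gaussian noise with a given standard deviation."""
    def __init__(self, rng, noise_std=0.1):
        self.rng = rng
        self.noise_std = noise_std

def get_augmentor(class_obj, rng, **params):
    # Reject classes that are not augmentors before instantiating them.
    assert issubclass(class_obj, AugmentorBase), f"{class_obj} is not an AugmentorBase"
    return class_obj(rng, **params)

aug = get_augmentor(NoiseAugmentor, np.random.RandomState(0), noise_std=0.05)
try:
    get_augmentor(dict, np.random.RandomState(0))  # not an augmentor
except AssertionError as err:
    print("rejected:", err)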

@@ -15,8 +15,8 @@ from paddle.io import DataLoader
 from deepspeech.frontend.utility import read_manifest
 from deepspeech.io.batchfy import make_batchset
+from deepspeech.io.converter import CustomConverter
 from deepspeech.io.dataset import TransformDataset
-from deepspeech.io.reader import CustomConverter
 from deepspeech.io.reader import LoadInputsAndTargets
 from deepspeech.utils.log import Log
@@ -46,7 +46,6 @@ class BatchDataLoader():
                  num_encs: int=1):
         self.json_file = json_file
         self.train_mode = train_mode
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
         self.batch_size = batch_size
         self.maxlen_in = maxlen_in
@@ -56,20 +55,17 @@ class BatchDataLoader():
         self.batch_frames_in = batch_frames_in
         self.batch_frames_out = batch_frames_out
         self.batch_frames_inout = batch_frames_inout
         self.subsampling_factor = subsampling_factor
         self.num_encs = num_encs
         self.preprocess_conf = preprocess_conf
         self.n_iter_processes = n_iter_processes
         # read json data
-        data_json = read_manifest(json_file)
+        self.data_json = read_manifest(json_file)
+        logger.info(f"load {json_file} file.")
         # make minibatch list (variable length)
-        self.data = make_batchset(
-            data_json,
+        self.minibaches = make_batchset(
+            self.data_json,
             batch_size,
             maxlen_in,
             maxlen_out,
@@ -83,9 +79,9 @@ class BatchDataLoader():
             batch_frames_inout=batch_frames_inout,
             iaxis=0,
             oaxis=0, )
-        logger.info(f"batchfy data {json_file}: {len(self.data)}.")
-        self.load = LoadInputsAndTargets(
+        # data reader
+        self.reader = LoadInputsAndTargets(
             mode="asr",
             load_output=True,
             preprocess_conf=preprocess_conf,
@@ -96,7 +92,7 @@ class BatchDataLoader():
         # Setup a converter
         if num_encs == 1:
             self.converter = CustomConverter(
-                subsampling_factor=subsampling_factor, dtype=dtype)
+                subsampling_factor=subsampling_factor, dtype=np.float32)
         else:
             assert NotImplementedError("not impl CustomConverterMulEnc.")
@@ -104,14 +100,39 @@ class BatchDataLoader():
         # actual bathsize is included in a list
         # default collate function converts numpy array to pytorch tensor
         # we used an empty collate function instead which returns list
-        self.train_loader = DataLoader(
-            dataset=TransformDataset(
-                self.data, lambda data: self.converter([self.load(data, return_uttid=True)])),
+        self.dataset = TransformDataset(
+            self.minibaches,
+            lambda data: self.converter([self.reader(data, return_uttid=True)]))
+        self.dataloader = DataLoader(
+            dataset=self.dataset,
             batch_size=1,
             shuffle=not use_sortagrad if train_mode else False,
             collate_fn=lambda x: x[0],
             num_workers=n_iter_processes, )
+        logger.info(f"dataloader for {json_file}.")

     def __repr__(self):
-        return f"DataLoader {self.json_file}-{self.train_mode}-{self.use_sortagrad}"
+        echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
+        echo += f"train_mode: {self.train_mode}, "
+        echo += f"sortagrad: {self.use_sortagrad}, "
+        echo += f"batch_size: {self.batch_size}, "
+        echo += f"maxlen_in: {self.maxlen_in}, "
+        echo += f"maxlen_out: {self.maxlen_out}, "
+        echo += f"batch_count: {self.batch_count}, "
+        echo += f"batch_bins: {self.batch_bins}, "
+        echo += f"batch_frames_in: {self.batch_frames_in}, "
+        echo += f"batch_frames_out: {self.batch_frames_out}, "
+        echo += f"batch_frames_inout: {self.batch_frames_inout}, "
+        echo += f"subsampling_factor: {self.subsampling_factor}, "
+        echo += f"num_encs: {self.num_encs}, "
+        echo += f"num_workers: {self.n_iter_processes}, "
+        echo += f"file: {self.json_file}"
+        return echo
+
+    def __len__(self):
+        return len(self.dataloader)
+
+    def __iter__(self):
+        return self.dataloader.__iter__()
+
+    def __call__(self):
+        return self.__iter__()
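
The loader keeps the ESPnet-style convention that make_batchset already emits complete, variable-sized minibatches: the paddle DataLoader therefore runs with batch_size=1 and a collate_fn that only unwraps the singleton list, while __len__, __iter__ and __call__ just delegate to that inner DataLoader. A rough, self-contained sketch of this pre-batched pattern with toy data (PreBatchedDataset and the shapes below are made up; they stand in for TransformDataset plus CustomConverter):

import numpy as np
from paddle.io import DataLoader, Dataset

class PreBatchedDataset(Dataset):
    """Each dataset item is already a full minibatch (padded feats + lengths)."""
    def __init__(self, minibatches):
        self.minibatches = minibatches

    def __len__(self):
        return len(self.minibatches)

    def __getitem__(self, idx):
        return self.minibatches[idx]

# two toy minibatches of different batch sizes: (padded_feats, feat_lengths)
minibatches = [
    (np.zeros((2, 7, 83), dtype=np.float32), np.array([7, 5], dtype=np.int64)),
    (np.zeros((3, 9, 83), dtype=np.float32), np.array([9, 6, 4], dtype=np.int64)),
]

loader = DataLoader(
    dataset=PreBatchedDataset(minibatches),
    batch_size=1,                # one pre-built minibatch per step
    shuffle=False,
    collate_fn=lambda x: x[0],   # unwrap the singleton list; do not re-stack
    num_workers=0)

for feats, lens in loader:
    print(feats.shape, lens)     # variable batch sizes pass through untouched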
