|
|
@ -15,8 +15,8 @@ from paddle.io import DataLoader
|
|
|
|
|
|
|
|
|
|
|
|
from deepspeech.frontend.utility import read_manifest
|
|
|
|
from deepspeech.frontend.utility import read_manifest
|
|
|
|
from deepspeech.io.batchfy import make_batchset
|
|
|
|
from deepspeech.io.batchfy import make_batchset
|
|
|
|
|
|
|
|
from deepspeech.io.converter import CustomConverter
|
|
|
|
from deepspeech.io.dataset import TransformDataset
|
|
|
|
from deepspeech.io.dataset import TransformDataset
|
|
|
|
from deepspeech.io.reader import CustomConverter
|
|
|
|
|
|
|
|
from deepspeech.io.reader import LoadInputsAndTargets
|
|
|
|
from deepspeech.io.reader import LoadInputsAndTargets
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
from deepspeech.utils.log import Log
|
|
|
|
|
|
|
|
|
|
|
@ -46,7 +46,6 @@ class BatchDataLoader():
|
|
|
|
num_encs: int=1):
|
|
|
|
num_encs: int=1):
|
|
|
|
self.json_file = json_file
|
|
|
|
self.json_file = json_file
|
|
|
|
self.train_mode = train_mode
|
|
|
|
self.train_mode = train_mode
|
|
|
|
|
|
|
|
|
|
|
|
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
|
|
|
|
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.batch_size = batch_size
|
|
|
|
self.maxlen_in = maxlen_in
|
|
|
|
self.maxlen_in = maxlen_in
|
|
|
@ -56,20 +55,17 @@ class BatchDataLoader():
|
|
|
|
self.batch_frames_in = batch_frames_in
|
|
|
|
self.batch_frames_in = batch_frames_in
|
|
|
|
self.batch_frames_out = batch_frames_out
|
|
|
|
self.batch_frames_out = batch_frames_out
|
|
|
|
self.batch_frames_inout = batch_frames_inout
|
|
|
|
self.batch_frames_inout = batch_frames_inout
|
|
|
|
|
|
|
|
|
|
|
|
self.subsampling_factor = subsampling_factor
|
|
|
|
self.subsampling_factor = subsampling_factor
|
|
|
|
self.num_encs = num_encs
|
|
|
|
self.num_encs = num_encs
|
|
|
|
self.preprocess_conf = preprocess_conf
|
|
|
|
self.preprocess_conf = preprocess_conf
|
|
|
|
|
|
|
|
|
|
|
|
self.n_iter_processes = n_iter_processes
|
|
|
|
self.n_iter_processes = n_iter_processes
|
|
|
|
|
|
|
|
|
|
|
|
# read json data
|
|
|
|
# read json data
|
|
|
|
data_json = read_manifest(json_file)
|
|
|
|
self.data_json = read_manifest(json_file)
|
|
|
|
logger.info(f"load {json_file} file.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# make minibatch list (variable length)
|
|
|
|
# make minibatch list (variable length)
|
|
|
|
self.data = make_batchset(
|
|
|
|
self.minibaches = make_batchset(
|
|
|
|
data_json,
|
|
|
|
self.data_json,
|
|
|
|
batch_size,
|
|
|
|
batch_size,
|
|
|
|
maxlen_in,
|
|
|
|
maxlen_in,
|
|
|
|
maxlen_out,
|
|
|
|
maxlen_out,
|
|
|
@ -83,9 +79,9 @@ class BatchDataLoader():
|
|
|
|
batch_frames_inout=batch_frames_inout,
|
|
|
|
batch_frames_inout=batch_frames_inout,
|
|
|
|
iaxis=0,
|
|
|
|
iaxis=0,
|
|
|
|
oaxis=0, )
|
|
|
|
oaxis=0, )
|
|
|
|
logger.info(f"batchfy data {json_file}: {len(self.data)}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.load = LoadInputsAndTargets(
|
|
|
|
# data reader
|
|
|
|
|
|
|
|
self.reader = LoadInputsAndTargets(
|
|
|
|
mode="asr",
|
|
|
|
mode="asr",
|
|
|
|
load_output=True,
|
|
|
|
load_output=True,
|
|
|
|
preprocess_conf=preprocess_conf,
|
|
|
|
preprocess_conf=preprocess_conf,
|
|
|
@ -96,7 +92,7 @@ class BatchDataLoader():
|
|
|
|
# Setup a converter
|
|
|
|
# Setup a converter
|
|
|
|
if num_encs == 1:
|
|
|
|
if num_encs == 1:
|
|
|
|
self.converter = CustomConverter(
|
|
|
|
self.converter = CustomConverter(
|
|
|
|
subsampling_factor=subsampling_factor, dtype=dtype)
|
|
|
|
subsampling_factor=subsampling_factor, dtype=np.float32)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
assert NotImplementedError("not impl CustomConverterMulEnc.")
|
|
|
|
assert NotImplementedError("not impl CustomConverterMulEnc.")
|
|
|
|
|
|
|
|
|
|
|
@ -104,14 +100,39 @@ class BatchDataLoader():
|
|
|
|
# actual batch size is included in a list
|
|
|
|
# actual batch size is included in a list
|
|
|
|
# default collate function converts numpy array to pytorch tensor
|
|
|
|
# default collate function converts numpy array to pytorch tensor
|
|
|
|
# we used an empty collate function instead which returns list
|
|
|
|
# we used an empty collate function instead which returns list
|
|
|
|
self.train_loader = DataLoader(
|
|
|
|
self.dataset = TransformDataset(
|
|
|
|
dataset=TransformDataset(
|
|
|
|
self.minibaches,
|
|
|
|
self.data, lambda data: self.converter([self.load(data, return_uttid=True)])),
|
|
|
|
lambda data: self.converter([self.reader(data, return_uttid=True)]))
|
|
|
|
|
|
|
|
self.dataloader = DataLoader(
|
|
|
|
|
|
|
|
dataset=self.dataset,
|
|
|
|
batch_size=1,
|
|
|
|
batch_size=1,
|
|
|
|
shuffle=not use_sortagrad if train_mode else False,
|
|
|
|
shuffle=not use_sortagrad if train_mode else False,
|
|
|
|
collate_fn=lambda x: x[0],
|
|
|
|
collate_fn=lambda x: x[0],
|
|
|
|
num_workers=n_iter_processes, )
|
|
|
|
num_workers=n_iter_processes, )
|
|
|
|
logger.info(f"dataloader for {json_file}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
    """Return a one-line, human-readable description of this loader.

    The text starts with the default ``<module.Class object at 0x...>``
    header and is followed by comma-separated ``label: value`` pairs for
    the batching settings stored on the instance, ending with the
    manifest path (``file``).

    Returns:
        str: the description string.
    """
    cls = self.__class__
    header = f"<{cls.__module__}.{cls.__name__} object at {hex(id(self))}> "
    # (label, value) pairs in display order. Labels are part of the output
    # text, so they are kept exactly as they were emitted before.
    fields = (
        ("train_mode", self.train_mode),
        ("sortagrad", self.use_sortagrad),
        ("batch_size", self.batch_size),
        ("maxlen_in", self.maxlen_in),
        ("maxlen_out", self.maxlen_out),
        ("batch_count", self.batch_count),
        ("batch_bins", self.batch_bins),
        ("batch_frames_in", self.batch_frames_in),
        ("batch_frames_out", self.batch_frames_out),
        ("batch_frames_inout", self.batch_frames_inout),
        ("subsampling_factor", self.subsampling_factor),
        ("num_encs", self.num_encs),
        ("num_workers", self.n_iter_processes),
        ("file", self.json_file),
    )
    # Build the string in one pass instead of repeated `+=` concatenation.
    return header + ", ".join(f"{label}: {value}" for label, value in fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __len__(self):
    """Return ``len()`` of the wrapped ``self.dataloader``."""
    loader = self.dataloader
    return len(loader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
    """Return a new iterator obtained from the wrapped ``self.dataloader``."""
    loader = self.dataloader
    return loader.__iter__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self):
    """Calling the loader is shorthand for iterating it; delegates to ``__iter__``."""
    batch_iterator = self.__iter__()
    return batch_iterator
|
|
|
|