|
|
@ -80,7 +80,7 @@ class DataGenerator(object):
|
|
|
|
padding_to=-1,
|
|
|
|
padding_to=-1,
|
|
|
|
flatten=False,
|
|
|
|
flatten=False,
|
|
|
|
sortagrad=False,
|
|
|
|
sortagrad=False,
|
|
|
|
batch_shuffle=False):
|
|
|
|
shuffle_method="batch_shuffle"):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Batch data reader creator for audio data. Return a callable generator
|
|
|
|
Batch data reader creator for audio data. Return a callable generator
|
|
|
|
function to produce batches of data.
|
|
|
|
function to produce batches of data.
|
|
|
@ -104,12 +104,22 @@ class DataGenerator(object):
|
|
|
|
:param sortagrad: If set True, sort the instances by audio duration
|
|
|
|
:param sortagrad: If set True, sort the instances by audio duration
|
|
|
|
in the first epoch to speed up training.
|
|
|
|
in the first epoch to speed up training.
|
|
|
|
:type sortagrad: bool
|
|
|
|
:type sortagrad: bool
|
|
|
|
:param batch_shuffle: If set True, instances are batch-wise shuffled.
|
|
|
|
:param shuffle_method: Shuffle method. Options:
|
|
|
|
For more details, please see
|
|
|
|
'' or None: no shuffle.
|
|
|
|
``_batch_shuffle.__doc__``.
|
|
|
|
'instance_shuffle': instance-wise shuffle.
|
|
|
|
If sortagrad is True, batch_shuffle is disabled
|
|
|
|
'batch_shuffle': similarly-sized instances are
|
|
|
|
|
|
|
|
put into batches, and then
|
|
|
|
|
|
|
|
batch-wise shuffle the batches.
|
|
|
|
|
|
|
|
For more details, please see
|
|
|
|
|
|
|
|
``_batch_shuffle.__doc__``.
|
|
|
|
|
|
|
|
'batch_shuffle_clipped': 'batch_shuffle' with
|
|
|
|
|
|
|
|
head shift and tail
|
|
|
|
|
|
|
|
clipping. For more
|
|
|
|
|
|
|
|
details, please see
|
|
|
|
|
|
|
|
``_batch_shuffle``.
|
|
|
|
|
|
|
|
If sortagrad is True, shuffle is disabled
|
|
|
|
for the first epoch.
|
|
|
|
for the first epoch.
|
|
|
|
:type batch_shuffle: bool
|
|
|
|
:type shuffle_method: None|str
|
|
|
|
:return: Batch reader function, producing batches of data when called.
|
|
|
|
:return: Batch reader function, producing batches of data when called.
|
|
|
|
:rtype: callable
|
|
|
|
:rtype: callable
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -123,8 +133,20 @@ class DataGenerator(object):
|
|
|
|
# sort (by duration) or batch-wise shuffle the manifest
|
|
|
|
# sort (by duration) or batch-wise shuffle the manifest
|
|
|
|
if self._epoch == 0 and sortagrad:
|
|
|
|
if self._epoch == 0 and sortagrad:
|
|
|
|
manifest.sort(key=lambda x: x["duration"])
|
|
|
|
manifest.sort(key=lambda x: x["duration"])
|
|
|
|
elif batch_shuffle:
|
|
|
|
else:
|
|
|
|
manifest = self._batch_shuffle(manifest, batch_size)
|
|
|
|
if shuffle_method == "batch_shuffle":
|
|
|
|
|
|
|
|
manifest = self._batch_shuffle(
|
|
|
|
|
|
|
|
manifest, batch_size, clipped=False)
|
|
|
|
|
|
|
|
elif shuffle_method == "batch_shuffle_clipped":
|
|
|
|
|
|
|
|
manifest = self._batch_shuffle(
|
|
|
|
|
|
|
|
manifest, batch_size, clipped=True)
|
|
|
|
|
|
|
|
elif shuffle_method == "instance_shuffle":
|
|
|
|
|
|
|
|
self._rng.shuffle(manifest)
|
|
|
|
|
|
|
|
elif not shuffle_method:
|
|
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError("Unknown shuffle method %s." %
|
|
|
|
|
|
|
|
shuffle_method)
|
|
|
|
# prepare batches
|
|
|
|
# prepare batches
|
|
|
|
instance_reader = self._instance_reader_creator(manifest)
|
|
|
|
instance_reader = self._instance_reader_creator(manifest)
|
|
|
|
batch = []
|
|
|
|
batch = []
|
|
|
@ -218,7 +240,7 @@ class DataGenerator(object):
|
|
|
|
new_batch.append((padded_audio, text))
|
|
|
|
new_batch.append((padded_audio, text))
|
|
|
|
return new_batch
|
|
|
|
return new_batch
|
|
|
|
|
|
|
|
|
|
|
|
def _batch_shuffle(self, manifest, batch_size):
|
|
|
|
def _batch_shuffle(self, manifest, batch_size, clipped=False):
|
|
|
|
"""Put similarly-sized instances into minibatches for better efficiency
|
|
|
|
"""Put similarly-sized instances into minibatches for better efficiency
|
|
|
|
and make a batch-wise shuffle.
|
|
|
|
and make a batch-wise shuffle.
|
|
|
|
|
|
|
|
|
|
|
@ -233,6 +255,9 @@ class DataGenerator(object):
|
|
|
|
:param batch_size: Batch size. This size is also used to generate
|
|
|
|
:param batch_size: Batch size. This size is also used to generate
|
|
|
|
a random number for batch shuffle.
|
|
|
|
a random number for batch shuffle.
|
|
|
|
:type batch_size: int
|
|
|
|
:type batch_size: int
|
|
|
|
|
|
|
|
:param clipped: Whether to clip the heading (small shift) and trailing
|
|
|
|
|
|
|
|
(incomplete batch) instances.
|
|
|
|
|
|
|
|
:type clipped: bool
|
|
|
|
:return: Batch shuffled manifest.
|
|
|
|
:return: Batch shuffled manifest.
|
|
|
|
:rtype: list
|
|
|
|
:rtype: list
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -241,7 +266,8 @@ class DataGenerator(object):
|
|
|
|
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
|
|
|
|
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
|
|
|
|
self._rng.shuffle(batch_manifest)
|
|
|
|
self._rng.shuffle(batch_manifest)
|
|
|
|
batch_manifest = list(sum(batch_manifest, ()))
|
|
|
|
batch_manifest = list(sum(batch_manifest, ()))
|
|
|
|
res_len = len(manifest) - shift_len - len(batch_manifest)
|
|
|
|
if not clipped:
|
|
|
|
batch_manifest.extend(manifest[-res_len:])
|
|
|
|
res_len = len(manifest) - shift_len - len(batch_manifest)
|
|
|
|
batch_manifest.extend(manifest[0:shift_len])
|
|
|
|
batch_manifest.extend(manifest[-res_len:])
|
|
|
|
|
|
|
|
batch_manifest.extend(manifest[0:shift_len])
|
|
|
|
return batch_manifest
|
|
|
|
return batch_manifest
|
|
|
|