@@ -11,7 +11,6 @@ import multiprocessing
 import numpy as np
 import paddle.v2 as paddle
 from threading import local
-import atexit
 from data_utils.utility import read_manifest
 from data_utils.utility import xmap_readers_mp
 from data_utils.augmentor.augmentation import AugmentationPipeline
@@ -194,15 +193,18 @@ class DataGenerator(object):
                 raise ValueError("Unknown shuffle method %s." %
                                  shuffle_method)
             # prepare batches
-            instance_reader = self._instance_reader_creator(manifest)
+            instance_reader, cleanup = self._instance_reader_creator(manifest)
             batch = []
-            for instance in instance_reader():
-                batch.append(instance)
-                if len(batch) == batch_size:
-                    yield self._padding_batch(batch, padding_to, flatten)
-                    batch = []
-            if len(batch) >= min_batch_size:
-                yield self._padding_batch(batch, padding_to, flatten)
+            try:
+                for instance in instance_reader():
+                    batch.append(instance)
+                    if len(batch) == batch_size:
+                        yield self._padding_batch(batch, padding_to, flatten)
+                        batch = []
+                if len(batch) >= min_batch_size:
+                    yield self._padding_batch(batch, padding_to, flatten)
+            finally:
+                cleanup()
             self._epoch += 1
 
         return batch_reader
@@ -280,10 +282,7 @@ class DataGenerator(object):
             lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
             reader, self._num_threads, 4096)
 
-        # register callback to main process
-        atexit.register(cleanup_callback)
-
-        return reader
+        return reader, cleanup_callback
 
     def _padding_batch(self, batch, padding_to=-1, flatten=False):
         """
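
Illustrative note: with the atexit registration removed, _instance_reader_creator hands the cleanup callback from xmap_readers_mp back to its caller, and batch_reader runs it in a finally block so the worker processes are reclaimed even if iteration stops early. Below is a minimal, self-contained sketch of that reader-plus-cleanup pattern; make_reader and its body are invented for illustration and are not part of the patch.

# Sketch of the pattern introduced above (names are illustrative only).
def make_reader(items):
    def reader():
        for item in items:
            yield item

    def cleanup():
        # Stand-in for terminating the xmap_readers_mp worker processes.
        print("cleanup called")

    return reader, cleanup


reader, cleanup = make_reader(range(5))
try:
    for item in reader():
        print(item)
finally:
    cleanup()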