Simplify the parallel part of data processing and fix abnormal exit.

pull/58/head
yangyaming 7 years ago
parent a835d41206
commit 20e225875c

@@ -290,10 +290,7 @@ class DataGenerator(object):
         reader, cleanup_callback = xmap_readers_mp(
             lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
-            reader,
-            self._num_threads,
-            4096,
-            order=True)
+            reader, self._num_threads, 4096)
         # register callback to main process
         atexit.register(cleanup_callback)
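For context, a minimal usage sketch of the simplified call: xmap_readers_mp now takes mapper, reader, process_num and buffer_size positionally, always preserves sample order, and returns a (reader, cleanup_callback) tuple. The toy reader and the import path below are invented for illustration; only the call shape mirrors the hunk above.

# Hedged usage sketch -- toy_reader and the import path are hypothetical.
import atexit
from data_utils.utility import xmap_readers_mp   # assumed module location

def toy_reader():
    for i in range(8):
        yield {"audio_filepath": "utt_%d.wav" % i, "text": "dummy text"}

wrapped, cleanup_callback = xmap_readers_mp(
    lambda ins: (ins["audio_filepath"], ins["text"]),   # stands in for process_utterance
    toy_reader, 4, 4096)
atexit.register(cleanup_callback)   # join workers when the interpreter exits

for features in wrapped():
    print(features)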

@@ -10,7 +10,7 @@ import tarfile
 import time
 from Queue import Queue
 from threading import Thread
-from multiprocessing import Process, Manager
+from multiprocessing import Process, Manager, Value
 from paddle.v2.dataset.common import md5file
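The newly imported Value is the mechanism the rest of the patch builds on: an integer in shared memory that both the parent and the worker processes can read and write. A tiny illustrative sketch (all names invented here):

# Sketch: a shared integer flag, visible across processes after fork.
import time
from multiprocessing import Process, Value

flag = Value('i', 0)

def worker(flag):
    while flag.value == 0:      # spin until the parent flips the flag
        time.sleep(0.001)

p = Process(target=worker, args=(flag,))
p.start()
flag.value = 1                  # set in the parent, observed in the child
p.join()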
@@ -101,40 +101,35 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
     :type process_num: int
     :param buffer_size: Maximal buffer size.
     :type buffer_size: int
-    :param order: Reserve the order of samples from the given reader.
-    :type order: bool
-    :return: The wrappered reader
-    :rtype: callable
+    :return: The wrappered reader and cleanup callback
+    :rtype: tuple
     """
     end_flag = XmapEndSignal()

-    # define a worker to read samples from reader to in_queue
-    def read_worker(reader, in_queue):
-        for sample in reader():
-            in_queue.put(sample)
-        in_queue.put(end_flag)
+    read_workers = []
+    handle_workers = []
+    flush_workers = []
+
+    read_exit_flag = Value('i', 0)
+    handle_exit_flag = Value('i', 0)
+    flush_exit_flag = Value('i', 0)

     # define a worker to read samples from reader to in_queue with order flag
     def order_read_worker(reader, in_queue):
         for order_id, sample in enumerate(reader()):
+            if read_exit_flag.value == 1: break
             in_queue.put((order_id, sample))
         in_queue.put(end_flag)
+        # the reading worker should not exit until all handling work exited
+        while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
+            time.sleep(0.001)

-    # define a worker to handle samples from in_queue by mapper and put results
-    # to out_queue
-    def handle_worker(in_queue, out_queue, mapper):
-        sample = in_queue.get()
-        while not isinstance(sample, XmapEndSignal):
-            out_queue.put(mapper(sample))
-            sample = in_queue.get()
-        in_queue.put(end_flag)
-        out_queue.put(end_flag)
-
     # define a worker to handle samples from in_queue by mapper and put results
     # to out_queue with order
     def order_handle_worker(in_queue, out_queue, mapper, out_order):
         ins = in_queue.get()
         while not isinstance(ins, XmapEndSignal):
+            if handle_exit_flag.value == 1: break
             order_id, sample = ins
             result = mapper(sample)
             while order_id != out_order[0]:
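The order-preserving trick that order_handle_worker relies on (and that this patch now uses unconditionally) is easier to see in a reduced form. The sketch below swaps processes for threads and the Manager list for a plain list; every name in it is illustrative rather than taken from the repository:

# Sketch of the ordering trick: each result waits until the shared counter
# out_order[0] reaches its order_id, so the output stays in input order.
import time
from Queue import Queue          # Python 2, as in this codebase
from threading import Thread

in_queue, out_queue = Queue(), Queue()
out_order = [0]                  # stands in for manager.list([0])

def handler():
    while True:
        item = in_queue.get()
        if item is None:
            break
        order_id, sample = item
        result = sample * sample
        while order_id != out_order[0]:
            time.sleep(0.001)    # wait until it is this sample's turn
        out_queue.put(result)
        out_order[0] += 1

threads = [Thread(target=handler) for _ in range(4)]
for t in threads:
    t.start()
for i in range(8):
    in_queue.put((i, i))
for _ in threads:
    in_queue.put(None)           # one end sentinel per handler
for t in threads:
    t.join()
print([out_queue.get() for _ in range(8)])   # squares in input order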
@@ -144,22 +139,39 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
             ins = in_queue.get()
         in_queue.put(end_flag)
         out_queue.put(end_flag)
+        # wait for exit of flushing worker
+        while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
+            time.sleep(0.001)
+        read_exit_flag.value = 1
+        handle_exit_flag.value = 1

     # define a thread worker to flush samples from Manager.Queue to Queue
     # for acceleration
     def flush_worker(in_queue, out_queue):
         finish = 0
-        while finish < process_num:
+        while finish < process_num and flush_exit_flag.value == 0:
             sample = in_queue.get()
             if isinstance(sample, XmapEndSignal):
                 finish += 1
             else:
                 out_queue.put(sample)
         out_queue.put(end_flag)
+        handle_exit_flag.value = 1
+        flush_exit_flag.value = 1

     def cleanup():
-        # kill all sub process and threads
-        os._exit(0)
+        # first exit flushing workers
+        flush_exit_flag.value = 1
+        for w in flush_workers:
+            w.join()
+        # next exit handling workers
+        handle_exit_flag.value = 1
+        for w in handle_workers:
+            w.join()
+        # last exit reading workers
+        read_exit_flag.value = 1
+        for w in read_workers:
+            w.join()

     def xreader():
         # prepare shared memory
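flush_worker exists because every get() on a Manager.Queue goes through IPC and is comparatively slow, so a thread drains it into an in-process Queue that the consumer can read cheaply. A reduced, self-contained sketch of that pattern (single producer, one end sentinel, all names invented here):

# Sketch: drain a slow cross-process Manager.Queue into a fast local Queue.
from multiprocessing import Manager, Process
from Queue import Queue
from threading import Thread

def producer(mq):
    for i in range(5):
        mq.put(i)
    mq.put(None)                 # end signal

manager = Manager()
slow_queue = manager.Queue(16)   # crosses the process boundary (slow)
fast_queue = Queue(16)           # local to this process (fast)

def flush(slow_queue, fast_queue):
    while True:
        item = slow_queue.get()
        fast_queue.put(item)     # forward everything, including the sentinel
        if item is None:
            break

p = Process(target=producer, args=(slow_queue,))
p.start()
t = Thread(target=flush, args=(slow_queue, fast_queue))
t.start()

item = fast_queue.get()
while item is not None:
    print(item)
    item = fast_queue.get()
p.join()
t.join()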
@@ -169,27 +181,29 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
         out_order = manager.list([0])

         # start a read worker in a process
-        target = order_read_worker if order else read_worker
+        target = order_read_worker
         p = Process(target=target, args=(reader, in_queue))
         p.daemon = True
         p.start()
+        read_workers.append(p)

         # start handle_workers with multiple processes
-        target = order_handle_worker if order else handle_worker
-        args = (in_queue, out_queue, mapper, out_order) if order else (
-            in_queue, out_queue, mapper)
+        target = order_handle_worker
+        args = (in_queue, out_queue, mapper, out_order)
         workers = [
             Process(target=target, args=args) for _ in xrange(process_num)
         ]
         for w in workers:
             w.daemon = True
             w.start()
+            handle_workers.append(w)

         # start a thread to read data from slow Manager.Queue
         flush_queue = Queue(buffer_size)
         t = Thread(target=flush_worker, args=(out_queue, flush_queue))
         t.daemon = True
         t.start()
+        flush_workers.append(t)

         # get results
         sample = flush_queue.get()
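Taken together, the new cleanup() replaces the old os._exit(0) with cooperative shutdown: flip a shared flag, let each worker leave its loop, then join it, with the whole sequence hooked on atexit exactly as the first hunk shows. A self-contained sketch of that pattern, with invented names and a trivial worker body:

# Sketch: cooperative shutdown via shared flags and joins, registered on atexit.
import atexit
import time
from multiprocessing import Process, Value

stop = Value('i', 0)

def worker(stop):
    while stop.value == 0:       # real work would happen here
        time.sleep(0.001)

workers = [Process(target=worker, args=(stop,)) for _ in range(2)]
for w in workers:
    w.daemon = True
    w.start()

def cleanup():
    stop.value = 1               # ask every worker to leave its loop
    for w in workers:
        w.join()                 # wait for them instead of killing the process

atexit.register(cleanup)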
