Simplify parallel part for data processing and fix abnormal exit.

pull/58/head
yangyaming 7 years ago
parent a835d41206
commit 20e225875c

@ -290,10 +290,7 @@ class DataGenerator(object):
reader, cleanup_callback = xmap_readers_mp(
lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
reader,
self._num_threads,
4096,
order=True)
reader, self._num_threads, 4096)
# register callback to main process
atexit.register(cleanup_callback)

@ -10,7 +10,7 @@ import tarfile
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager
from multiprocessing import Process, Manager, Value
from paddle.v2.dataset.common import md5file
@ -101,40 +101,35 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
:type process_num: int
:param buffer_size: Maximal buffer size.
:type buffer_size: int
:param order: Reserve the order of samples from the given reader.
:type order: bool
:return: The wrappered reader
:rtype: callable
:return: The wrappered reader and cleanup callback
:rtype: tuple
"""
end_flag = XmapEndSignal()
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for sample in reader():
in_queue.put(sample)
in_queue.put(end_flag)
read_workers = []
handle_workers = []
flush_workers = []
read_exit_flag = Value('i', 0)
handle_exit_flag = Value('i', 0)
flush_exit_flag = Value('i', 0)
# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
for order_id, sample in enumerate(reader()):
if read_exit_flag.value == 1: break
in_queue.put((order_id, sample))
in_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results
# to out_queue
def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get()
while not isinstance(sample, XmapEndSignal):
out_queue.put(mapper(sample))
sample = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# the reading worker should not exit until all handling work exited
while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
time.sleep(0.001)
# define a worker to handle samples from in_queue by mapper and put results
# to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
if handle_exit_flag.value == 1: break
order_id, sample = ins
result = mapper(sample)
while order_id != out_order[0]:
@ -144,22 +139,39 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
ins = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# wait for exit of flushing worker
while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
time.sleep(0.001)
read_exit_flag.value = 1
handle_exit_flag.value = 1
# define a thread worker to flush samples from Manager.Queue to Queue
# for acceleration
def flush_worker(in_queue, out_queue):
finish = 0
while finish < process_num:
while finish < process_num and flush_exit_flag.value == 0:
sample = in_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
out_queue.put(sample)
out_queue.put(end_flag)
handle_exit_flag.value = 1
flush_exit_flag.value = 1
def cleanup():
# kill all sub process and threads
os._exit(0)
# first exit flushing workers
flush_exit_flag.value = 1
for w in flush_workers:
w.join()
# next exit handling workers
handle_exit_flag.value = 1
for w in handle_workers:
w.join()
# last exit reading workers
read_exit_flag.value = 1
for w in read_workers:
w.join()
def xreader():
# prepare shared memory
@ -169,27 +181,29 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
out_order = manager.list([0])
# start a read worker in a process
target = order_read_worker if order else read_worker
target = order_read_worker
p = Process(target=target, args=(reader, in_queue))
p.daemon = True
p.start()
read_workers.append(p)
# start handle_workers with multiple processes
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
target = order_handle_worker
args = (in_queue, out_queue, mapper, out_order)
workers = [
Process(target=target, args=args) for _ in xrange(process_num)
]
for w in workers:
w.daemon = True
w.start()
handle_workers.append(w)
# start a thread to read data from slow Manager.Queue
flush_queue = Queue(buffer_size)
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
t.daemon = True
t.start()
flush_workers.append(t)
# get results
sample = flush_queue.get()

Loading…
Cancel
Save