From 20e225875c9c877aafb5b0f254d0ccf4de04afb4 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Tue, 5 Dec 2017 14:14:58 +0800 Subject: [PATCH] Simplify parallel part for data processing and fix abnormal exit. --- data_utils/data.py | 5 +--- data_utils/utility.py | 68 ++++++++++++++++++++++++++----------------- 2 files changed, 42 insertions(+), 31 deletions(-) diff --git a/data_utils/data.py b/data_utils/data.py index 9dd2a91f..af6734f7 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -290,10 +290,7 @@ class DataGenerator(object): reader, cleanup_callback = xmap_readers_mp( lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]), - reader, - self._num_threads, - 4096, - order=True) + reader, self._num_threads, 4096) # register callback to main process atexit.register(cleanup_callback) diff --git a/data_utils/utility.py b/data_utils/utility.py index 2633e1b4..89a74c41 100644 --- a/data_utils/utility.py +++ b/data_utils/utility.py @@ -10,7 +10,7 @@ import tarfile import time from Queue import Queue from threading import Thread -from multiprocessing import Process, Manager +from multiprocessing import Process, Manager, Value from paddle.v2.dataset.common import md5file @@ -101,40 +101,35 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): :type process_num: int :param buffer_size: Maximal buffer size. :type buffer_size: int - :param order: Reserve the order of samples from the given reader. - :type order: bool - :return: The wrappered reader - :rtype: callable + :return: The wrappered reader and cleanup callback + :rtype: tuple """ end_flag = XmapEndSignal() - # define a worker to read samples from reader to in_queue - def read_worker(reader, in_queue): - for sample in reader(): - in_queue.put(sample) - in_queue.put(end_flag) + read_workers = [] + handle_workers = [] + flush_workers = [] + + read_exit_flag = Value('i', 0) + handle_exit_flag = Value('i', 0) + flush_exit_flag = Value('i', 0) # define a worker to read samples from reader to in_queue with order flag def order_read_worker(reader, in_queue): for order_id, sample in enumerate(reader()): + if read_exit_flag.value == 1: break in_queue.put((order_id, sample)) in_queue.put(end_flag) - - # define a worker to handle samples from in_queue by mapper and put results - # to out_queue - def handle_worker(in_queue, out_queue, mapper): - sample = in_queue.get() - while not isinstance(sample, XmapEndSignal): - out_queue.put(mapper(sample)) - sample = in_queue.get() - in_queue.put(end_flag) - out_queue.put(end_flag) + # the reading worker should not exit until all handling work exited + while handle_exit_flag.value == 0 or read_exit_flag.value == 0: + time.sleep(0.001) # define a worker to handle samples from in_queue by mapper and put results # to out_queue with order def order_handle_worker(in_queue, out_queue, mapper, out_order): ins = in_queue.get() while not isinstance(ins, XmapEndSignal): + if handle_exit_flag.value == 1: break order_id, sample = ins result = mapper(sample) while order_id != out_order[0]: @@ -144,22 +139,39 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): ins = in_queue.get() in_queue.put(end_flag) out_queue.put(end_flag) + # wait for exit of flushing worker + while flush_exit_flag.value == 0 or handle_exit_flag.value == 0: + time.sleep(0.001) + read_exit_flag.value = 1 + handle_exit_flag.value = 1 # define a thread worker to flush samples from Manager.Queue to Queue # for acceleration def flush_worker(in_queue, out_queue): finish = 0 - while finish < process_num: + while finish < process_num and flush_exit_flag.value == 0: sample = in_queue.get() if isinstance(sample, XmapEndSignal): finish += 1 else: out_queue.put(sample) out_queue.put(end_flag) + handle_exit_flag.value = 1 + flush_exit_flag.value = 1 def cleanup(): - # kill all sub process and threads - os._exit(0) + # first exit flushing workers + flush_exit_flag.value = 1 + for w in flush_workers: + w.join() + # next exit handling workers + handle_exit_flag.value = 1 + for w in handle_workers: + w.join() + # last exit reading workers + read_exit_flag.value = 1 + for w in read_workers: + w.join() def xreader(): # prepare shared memory @@ -169,27 +181,29 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False): out_order = manager.list([0]) # start a read worker in a process - target = order_read_worker if order else read_worker + target = order_read_worker p = Process(target=target, args=(reader, in_queue)) p.daemon = True p.start() + read_workers.append(p) # start handle_workers with multiple processes - target = order_handle_worker if order else handle_worker - args = (in_queue, out_queue, mapper, out_order) if order else ( - in_queue, out_queue, mapper) + target = order_handle_worker + args = (in_queue, out_queue, mapper, out_order) workers = [ Process(target=target, args=args) for _ in xrange(process_num) ] for w in workers: w.daemon = True w.start() + handle_workers.append(w) # start a thread to read data from slow Manager.Queue flush_queue = Queue(buffer_size) t = Thread(target=flush_worker, args=(out_queue, flush_queue)) t.daemon = True t.start() + flush_workers.append(t) # get results sample = flush_queue.get()