|
|
|
@ -10,7 +10,7 @@ import tarfile
|
|
|
|
|
import time
|
|
|
|
|
from Queue import Queue
|
|
|
|
|
from threading import Thread
|
|
|
|
|
from multiprocessing import Process, Manager
|
|
|
|
|
from multiprocessing import Process, Manager, Value
|
|
|
|
|
from paddle.v2.dataset.common import md5file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -101,40 +101,35 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
|
|
|
|
|
:type process_num: int
|
|
|
|
|
:param buffer_size: Maximal buffer size.
|
|
|
|
|
:type buffer_size: int
|
|
|
|
|
:param order: Reserve the order of samples from the given reader.
|
|
|
|
|
:type order: bool
|
|
|
|
|
:return: The wrappered reader
|
|
|
|
|
:rtype: callable
|
|
|
|
|
:return: The wrappered reader and cleanup callback
|
|
|
|
|
:rtype: tuple
|
|
|
|
|
"""
|
|
|
|
|
end_flag = XmapEndSignal()
|
|
|
|
|
|
|
|
|
|
# define a worker to read samples from reader to in_queue
|
|
|
|
|
def read_worker(reader, in_queue):
|
|
|
|
|
for sample in reader():
|
|
|
|
|
in_queue.put(sample)
|
|
|
|
|
in_queue.put(end_flag)
|
|
|
|
|
read_workers = []
|
|
|
|
|
handle_workers = []
|
|
|
|
|
flush_workers = []
|
|
|
|
|
|
|
|
|
|
read_exit_flag = Value('i', 0)
|
|
|
|
|
handle_exit_flag = Value('i', 0)
|
|
|
|
|
flush_exit_flag = Value('i', 0)
|
|
|
|
|
|
|
|
|
|
# define a worker to read samples from reader to in_queue with order flag
|
|
|
|
|
def order_read_worker(reader, in_queue):
|
|
|
|
|
for order_id, sample in enumerate(reader()):
|
|
|
|
|
if read_exit_flag.value == 1: break
|
|
|
|
|
in_queue.put((order_id, sample))
|
|
|
|
|
in_queue.put(end_flag)
|
|
|
|
|
|
|
|
|
|
# define a worker to handle samples from in_queue by mapper and put results
|
|
|
|
|
# to out_queue
|
|
|
|
|
def handle_worker(in_queue, out_queue, mapper):
|
|
|
|
|
sample = in_queue.get()
|
|
|
|
|
while not isinstance(sample, XmapEndSignal):
|
|
|
|
|
out_queue.put(mapper(sample))
|
|
|
|
|
sample = in_queue.get()
|
|
|
|
|
in_queue.put(end_flag)
|
|
|
|
|
out_queue.put(end_flag)
|
|
|
|
|
# the reading worker should not exit until all handling work exited
|
|
|
|
|
while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
|
|
|
|
|
time.sleep(0.001)
|
|
|
|
|
|
|
|
|
|
# define a worker to handle samples from in_queue by mapper and put results
|
|
|
|
|
# to out_queue with order
|
|
|
|
|
def order_handle_worker(in_queue, out_queue, mapper, out_order):
|
|
|
|
|
ins = in_queue.get()
|
|
|
|
|
while not isinstance(ins, XmapEndSignal):
|
|
|
|
|
if handle_exit_flag.value == 1: break
|
|
|
|
|
order_id, sample = ins
|
|
|
|
|
result = mapper(sample)
|
|
|
|
|
while order_id != out_order[0]:
|
|
|
|
@ -144,22 +139,39 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
|
|
|
|
|
ins = in_queue.get()
|
|
|
|
|
in_queue.put(end_flag)
|
|
|
|
|
out_queue.put(end_flag)
|
|
|
|
|
# wait for exit of flushing worker
|
|
|
|
|
while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
|
|
|
|
|
time.sleep(0.001)
|
|
|
|
|
read_exit_flag.value = 1
|
|
|
|
|
handle_exit_flag.value = 1
|
|
|
|
|
|
|
|
|
|
# define a thread worker to flush samples from Manager.Queue to Queue
|
|
|
|
|
# for acceleration
|
|
|
|
|
def flush_worker(in_queue, out_queue):
|
|
|
|
|
finish = 0
|
|
|
|
|
while finish < process_num:
|
|
|
|
|
while finish < process_num and flush_exit_flag.value == 0:
|
|
|
|
|
sample = in_queue.get()
|
|
|
|
|
if isinstance(sample, XmapEndSignal):
|
|
|
|
|
finish += 1
|
|
|
|
|
else:
|
|
|
|
|
out_queue.put(sample)
|
|
|
|
|
out_queue.put(end_flag)
|
|
|
|
|
handle_exit_flag.value = 1
|
|
|
|
|
flush_exit_flag.value = 1
|
|
|
|
|
|
|
|
|
|
def cleanup():
|
|
|
|
|
# kill all sub process and threads
|
|
|
|
|
os._exit(0)
|
|
|
|
|
# first exit flushing workers
|
|
|
|
|
flush_exit_flag.value = 1
|
|
|
|
|
for w in flush_workers:
|
|
|
|
|
w.join()
|
|
|
|
|
# next exit handling workers
|
|
|
|
|
handle_exit_flag.value = 1
|
|
|
|
|
for w in handle_workers:
|
|
|
|
|
w.join()
|
|
|
|
|
# last exit reading workers
|
|
|
|
|
read_exit_flag.value = 1
|
|
|
|
|
for w in read_workers:
|
|
|
|
|
w.join()
|
|
|
|
|
|
|
|
|
|
def xreader():
|
|
|
|
|
# prepare shared memory
|
|
|
|
@ -169,27 +181,29 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
|
|
|
|
|
out_order = manager.list([0])
|
|
|
|
|
|
|
|
|
|
# start a read worker in a process
|
|
|
|
|
target = order_read_worker if order else read_worker
|
|
|
|
|
target = order_read_worker
|
|
|
|
|
p = Process(target=target, args=(reader, in_queue))
|
|
|
|
|
p.daemon = True
|
|
|
|
|
p.start()
|
|
|
|
|
read_workers.append(p)
|
|
|
|
|
|
|
|
|
|
# start handle_workers with multiple processes
|
|
|
|
|
target = order_handle_worker if order else handle_worker
|
|
|
|
|
args = (in_queue, out_queue, mapper, out_order) if order else (
|
|
|
|
|
in_queue, out_queue, mapper)
|
|
|
|
|
target = order_handle_worker
|
|
|
|
|
args = (in_queue, out_queue, mapper, out_order)
|
|
|
|
|
workers = [
|
|
|
|
|
Process(target=target, args=args) for _ in xrange(process_num)
|
|
|
|
|
]
|
|
|
|
|
for w in workers:
|
|
|
|
|
w.daemon = True
|
|
|
|
|
w.start()
|
|
|
|
|
handle_workers.append(w)
|
|
|
|
|
|
|
|
|
|
# start a thread to read data from slow Manager.Queue
|
|
|
|
|
flush_queue = Queue(buffer_size)
|
|
|
|
|
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
|
|
|
|
|
t.daemon = True
|
|
|
|
|
t.start()
|
|
|
|
|
flush_workers.append(t)
|
|
|
|
|
|
|
|
|
|
# get results
|
|
|
|
|
sample = flush_queue.get()
|
|
|
|
|