Turn on rnn_use_batch of Paddle for acceleration.

Improve xmap_reader_mp by adding a flush thread.
pull/2/head
Xinghai Sun 8 years ago
parent 8e5c2eb969
commit 2e5e9b8c11

@ -9,6 +9,7 @@ import os
import tarfile import tarfile
import time import time
from Queue import Queue from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager from multiprocessing import Process, Manager
from paddle.v2.dataset.common import md5file from paddle.v2.dataset.common import md5file
@ -100,7 +101,8 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
in_queue.put((order_id, sample)) in_queue.put((order_id, sample))
in_queue.put(end_flag) in_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results to out_queue # define a worker to handle samples from in_queue by mapper and put results
# to out_queue
def handle_worker(in_queue, out_queue, mapper): def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get() sample = in_queue.get()
while not isinstance(sample, XmapEndSignal): while not isinstance(sample, XmapEndSignal):
@ -109,7 +111,8 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end_flag) in_queue.put(end_flag)
out_queue.put(end_flag) out_queue.put(end_flag)
# define a worker to handle samples from in_queue by mapper and put results to out_queue with order # define a worker to handle samples from in_queue by mapper and put results
# to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order): def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get() ins = in_queue.get()
while not isinstance(ins, XmapEndSignal): while not isinstance(ins, XmapEndSignal):
@ -123,6 +126,18 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end_flag) in_queue.put(end_flag)
out_queue.put(end_flag) out_queue.put(end_flag)
# define a thread worker to flush samples from Manager.Queue to Queue
# for acceleration
def flush_worker(in_queue, out_queue):
    """Drain the slow Manager.Queue into a fast thread-local Queue.

    Forwards every mapped sample from in_queue to out_queue, counting
    XmapEndSignal markers until one has arrived from each of the
    process_num worker processes, then emits a single end_flag so the
    consumer knows the stream is finished.
    """
    remaining = process_num
    while remaining:
        item = in_queue.get()
        if isinstance(item, XmapEndSignal):
            # One worker process has finished producing.
            remaining -= 1
        else:
            out_queue.put(item)
    # All workers done — signal the reader thread's consumer once.
    out_queue.put(end_flag)
def xreader(): def xreader():
# prepare shared memory # prepare shared memory
manager = Manager() manager = Manager()
@ -147,13 +162,16 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
w.daemon = True w.daemon = True
w.start() w.start()
# start a thread to read data from slow Manager.Queue
flush_queue = Queue(buffer_size)
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
t.daemon = True
t.start()
# get results # get results
finish = 0 sample = flush_queue.get()
while finish < process_num: while not isinstance(sample, XmapEndSignal):
sample = out_queue.get() yield sample
if isinstance(sample, XmapEndSignal): sample = flush_queue.get()
finish += 1
else:
yield sample
return xreader return xreader

@ -116,7 +116,9 @@ def infer():
def main(): def main():
print_arguments(args) print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count)
infer() infer()

@ -119,7 +119,9 @@ def evaluate():
def main(): def main():
print_arguments(args) print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count)
evaluate() evaluate()

@ -217,7 +217,9 @@ def tune():
def main(): def main():
print_arguments(args) print_arguments(args)
paddle.init(use_gpu=args.use_gpu, trainer_count=args.trainer_count) paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count)
tune() tune()

@ -119,6 +119,7 @@ def train():
def main(): def main():
print_arguments(args) print_arguments(args)
paddle.init(use_gpu=args.use_gpu, paddle.init(use_gpu=args.use_gpu,
rnn_use_batch=True,
trainer_count=args.trainer_count, trainer_count=args.trainer_count,
log_clipping=True) log_clipping=True)
train() train()

Loading…
Cancel
Save