You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
PaddleSpeech/data_utils/utility.py

182 lines
6.0 KiB

"""Contains data helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import codecs
import os
import tarfile
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager
from paddle.v2.dataset.common import md5file
def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    The manifest is read line by line; every line is parsed as one JSON
    document and must provide a "duration" entry. Instances with durations
    outside [min_duration, max_duration] will be filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If failed to parse the manifest.
    """
    manifest = []
    # Open through a context manager so the file handle is released even when
    # a malformed line aborts the loop with IOError.
    with codecs.open(manifest_path, 'r', 'utf-8') as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                # Any parsing failure is reported to callers as IOError.
                raise IOError("Error reading manifest: %s" % str(e))
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest
def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum.

    The file is stored in target_dir under the last "/"-separated component
    of url. When a file with that name already exists and md5file() of it
    equals md5sum, nothing is downloaded. The download itself is delegated
    to the external `wget` program (resuming partial files via `-c`).

    :param url: URL of the file to download.
    :type url: basestring
    :param md5sum: Expected checksum, compared against md5file(filepath).
    :param target_dir: Directory to download into; created when missing.
    :type target_dir: basestring
    :return: Path of the downloaded (or already present) file.
    :rtype: basestring
    :raises RuntimeError: If the checksum of the downloaded file does not
                          match md5sum.
    :raises OSError: If the `wget` executable cannot be started.
    """
    # Local import: the module-level import block does not provide subprocess.
    import subprocess

    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])

    if os.path.exists(filepath) and md5file(filepath) == md5sum:
        print("File exists, skip downloading. (%s)" % filepath)
        return filepath

    print("Downloading %s ..." % url)
    # Pass an argument list instead of a concatenated shell string, so that
    # characters such as '&', '?', ';' or spaces in url / target_dir reach
    # wget verbatim and are never interpreted by a shell.
    # wget's exit status is not inspected; a failed or partial download is
    # caught by the checksum comparison below.
    subprocess.call(["wget", "-c", url, "-P", target_dir])
    print("\nMD5 Chesksum %s ..." % filepath)
    if md5file(filepath) != md5sum:
        raise RuntimeError("MD5 checksum failed.")
    return filepath
def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the file to the target_dir.

    :param filepath: Path of the tar archive to extract.
    :type filepath: basestring
    :param target_dir: Directory the archive members are extracted into.
    :type target_dir: basestring
    :param rm_tar: Remove the archive after a successful extraction.
    :type rm_tar: bool
    """
    print("Unpacking %s ..." % filepath)
    # The context manager closes the archive even if extraction fails.
    with tarfile.open(filepath) as tar:
        # NOTE(review): extractall() trusts the member names stored in the
        # archive; only use this on archives from a trusted source (see the
        # warning in the tarfile documentation).
        tar.extractall(target_dir)
    # Deliberately kept as an equality test rather than a truthiness test, so
    # truthy non-boolean values (e.g. the string "False") never delete the
    # archive.
    if rm_tar == True:
        os.remove(filepath)
class XmapEndSignal:
    """Marker type whose instances signal the end of a queued sample stream.

    Consumers detect it with isinstance() rather than by identity.
    """
def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
    """A multiprocessing pipeline wrapper for the data reader.

    One process feeds samples from `reader` into a shared input queue,
    `process_num` processes apply `mapper` to them, and a thread moves the
    mapped results into a local queue from which the returned reader yields.

    :param mapper: Function to map sample.
    :type mapper: callable
    :param reader: Given data reader.
    :type reader: callable
    :param process_num: Number of processes in the pipeline
    :type process_num: int
    :param buffer_size: Maximal buffer size.
    :type buffer_size: int
    :param order: Preserve the order of samples from the given reader.
    :type order: bool
    :return: A pair (xreader, cleanup): the wrapped reader (a generator
             function) and a function that exits the current process.
    :rtype: tuple
    """
    # Sentinel pushed through the queues to mark the end of the stream.
    end_flag = XmapEndSignal()

    # define a worker to read samples from reader to in_queue
    def read_worker(reader, in_queue):
        for sample in reader():
            in_queue.put(sample)
        in_queue.put(end_flag)

    # define a worker to read samples from reader to in_queue with order flag
    def order_read_worker(reader, in_queue):
        # Each sample is tagged with its position in the reader's output.
        for order_id, sample in enumerate(reader()):
            in_queue.put((order_id, sample))
        in_queue.put(end_flag)

    # define a worker to handle samples from in_queue by mapper and put results
    # to out_queue
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            out_queue.put(mapper(sample))
            sample = in_queue.get()
        # Put the end signal back on in_queue (all handle workers share this
        # queue, so the next one can pick it up), then report this worker's
        # completion on out_queue.
        in_queue.put(end_flag)
        out_queue.put(end_flag)

    # define a worker to handle samples from in_queue by mapper and put results
    # to out_queue with order
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order_id, sample = ins
            result = mapper(sample)
            # Poll the shared counter until it is this sample's turn to be
            # emitted.
            while order_id != out_order[0]:
                time.sleep(0.001)
            out_queue.put(result)
            # Advance the shared counter to the next expected order id.
            out_order[0] += 1
            ins = in_queue.get()
        # Same end-of-stream handling as in handle_worker.
        in_queue.put(end_flag)
        out_queue.put(end_flag)

    # define a thread worker to flush samples from Manager.Queue to Queue
    # for acceleration
    def flush_worker(in_queue, out_queue):
        # Number of end signals seen so far; each handle worker emits one.
        finish = 0
        while finish < process_num:
            sample = in_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                out_queue.put(sample)
        # All handle workers are done: forward a single end signal.
        out_queue.put(end_flag)

    def cleanup():
        # kill all sub process and threads
        # NOTE: os._exit(0) terminates the calling process itself right away.
        os._exit(0)

    def xreader():
        # prepare shared memory
        manager = Manager()
        in_queue = manager.Queue(buffer_size)
        out_queue = manager.Queue(buffer_size)
        # Single-element shared list holding the next order id to emit.
        out_order = manager.list([0])

        # start a read worker in a process
        target = order_read_worker if order else read_worker
        p = Process(target=target, args=(reader, in_queue))
        p.daemon = True
        p.start()

        # start handle_workers with multiple processes
        target = order_handle_worker if order else handle_worker
        args = (in_queue, out_queue, mapper, out_order) if order else (
            in_queue, out_queue, mapper)
        workers = [
            Process(target=target, args=args) for _ in xrange(process_num)
        ]
        for w in workers:
            w.daemon = True
            w.start()

        # start a thread to read data from slow Manager.Queue
        flush_queue = Queue(buffer_size)
        t = Thread(target=flush_worker, args=(out_queue, flush_queue))
        t.daemon = True
        t.start()

        # get results
        sample = flush_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = flush_queue.get()

    return xreader, cleanup