"""Contains data helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import codecs
import os
import tarfile
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager, Value
from paddle.v2.dataset.common import md5file


def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
    """Load and parse manifest file.

    Instances with durations outside [min_duration, max_duration] will be
    filtered out.

    :param manifest_path: Manifest file to load and parse.
    :type manifest_path: basestring
    :param max_duration: Maximal duration in seconds for instance filter.
    :type max_duration: float
    :param min_duration: Minimal duration in seconds for instance filter.
    :type min_duration: float
    :return: Manifest parsing results. List of dict.
    :rtype: list
    :raises IOError: If parsing the manifest fails.
    """
    manifest = []
    with codecs.open(manifest_path, 'r', 'utf-8') as manifest_file:
        for json_line in manifest_file:
            try:
                json_data = json.loads(json_line)
            except Exception as e:
                raise IOError("Error reading manifest: %s" % str(e))
            if min_duration <= json_data["duration"] <= max_duration:
                manifest.append(json_data)
    return manifest
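
# A minimal usage sketch for read_manifest, assuming each manifest line is a
# JSON object with at least a "duration" field (the file path and the other
# fields shown below are illustrative):
#
#   {"audio_filepath": "audio/001.wav", "duration": 3.52, "text": "hello"}
#
#   manifest = read_manifest('data/manifest.train', max_duration=20.0)
#   for entry in manifest:
#       print(entry["duration"], entry["text"])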


def getfile_insensitive(path):
    """Get the actual file path when given insensitive filename."""
    directory, filename = os.path.split(path)
    directory, filename = (directory or '.'), filename.lower()
    for f in os.listdir(directory):
        newpath = os.path.join(directory, f)
        if os.path.isfile(newpath) and f.lower() == filename:
            return newpath
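
# For example (illustrative): if the directory contains 'README.TXT', then
# getfile_insensitive('./readme.txt') returns './README.TXT'; if no match is
# found, None is returned.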


def download_multi(url, target_dir, extra_args):
    """Download multiple files from url to target_dir."""
    if not os.path.exists(target_dir): os.makedirs(target_dir)
    print("Downloading %s ..." % url)
    ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
                         target_dir)
    return ret_code
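
# Usage sketch (the URL and the wget arguments below are placeholders):
#
#   ret = download_multi('http://example.com/data/files/', './dataset',
#                        '--no-parent -r')
#   if ret != 0:
#       raise RuntimeError("Downloading failed with exit code %d." % ret)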


def download(url, md5sum, target_dir):
    """Download file from url to target_dir, and check md5sum."""
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)
    filepath = os.path.join(target_dir, url.split("/")[-1])
    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
        print("Downloading %s ..." % url)
        os.system("wget -c " + url + " -P " + target_dir)
        print("\nMD5 Chesksum %s ..." % filepath)
        if not md5file(filepath) == md5sum:
            raise RuntimeError("MD5 checksum failed.")
    else:
        print("File exists, skip downloading. (%s)" % filepath)
    return filepath
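
# Usage sketch (the URL and the checksum below are placeholders):
#
#   filepath = download('http://example.com/data/corpus.tar.gz',
#                       '0123456789abcdef0123456789abcdef', './dataset')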


def unpack(filepath, target_dir, rm_tar=False):
    """Unpack the file to the target_dir."""
    print("Unpacking %s ..." % filepath)
    tar = tarfile.open(filepath)
    tar.extractall(target_dir)
    tar.close()
    if rm_tar:
        os.remove(filepath)
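
# Usage sketch, continuing from the download example above:
#
#   unpack(filepath, './dataset', rm_tar=True)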


class XmapEndSignal(object):
    """Sentinel marking the end of the sample stream in the queues."""
    pass


def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
    """A multiprocessing pipeline wrapper for the data reader.

    :param mapper:  Function to map sample.
    :type mapper: callable
    :param reader: Given data reader.
    :type reader: callable
    :param process_num: Number of processes in the pipeline.
    :type process_num: int
    :param buffer_size: Maximal buffer size.
    :type buffer_size: int
    :param order: Whether to preserve the order of samples; the current
                  implementation always preserves order regardless of this
                  flag.
    :type order: bool
    :return: The wrapped reader and a cleanup callback.
    :rtype: tuple
    """
    end_flag = XmapEndSignal()

    read_workers = []
    handle_workers = []
    flush_workers = []

    read_exit_flag = Value('i', 0)
    handle_exit_flag = Value('i', 0)
    flush_exit_flag = Value('i', 0)

    # define a worker that reads samples from the reader into in_queue,
    # tagging each sample with its order id
    def order_read_worker(reader, in_queue):
        for order_id, sample in enumerate(reader()):
            if read_exit_flag.value == 1: break
            in_queue.put((order_id, sample))
        in_queue.put(end_flag)
        # the reading worker should not exit until all handling workers have
        # exited
        while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
            time.sleep(0.001)

    # define a worker that maps samples from in_queue with the mapper and
    # puts the results into out_queue in the original order
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            if handle_exit_flag.value == 1: break
            order_id, sample = ins
            result = mapper(sample)
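            # spin until it is this sample's turn, so that results are
            # written to out_queue in the original reading order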
            while order_id != out_order[0]:
                time.sleep(0.001)
            out_queue.put(result)
            out_order[0] += 1
            ins = in_queue.get()
        in_queue.put(end_flag)
        out_queue.put(end_flag)
        # wait for the flushing worker to exit
        while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
            time.sleep(0.001)
        read_exit_flag.value = 1
        handle_exit_flag.value = 1

    # define a thread worker that flushes samples from the slow Manager.Queue
    # to a local Queue for acceleration
    def flush_worker(in_queue, out_queue):
        finish = 0
        while finish < process_num and flush_exit_flag.value == 0:
            sample = in_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                out_queue.put(sample)
        out_queue.put(end_flag)
        handle_exit_flag.value = 1
        flush_exit_flag.value = 1

    def cleanup():
        # first, stop the flushing workers
        flush_exit_flag.value = 1
        for w in flush_workers:
            w.join()
        # next, stop the handling workers
        handle_exit_flag.value = 1
        for w in handle_workers:
            w.join()
        # finally, stop the reading workers
        read_exit_flag.value = 1
        for w in read_workers:
            w.join()

    def xreader():
        # prepare shared memory
        manager = Manager()
        in_queue = manager.Queue(buffer_size)
        out_queue = manager.Queue(buffer_size)
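        # shared counter holding the next order id to be written to out_queue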
        out_order = manager.list([0])

        # start a read worker in a process
        target = order_read_worker
        p = Process(target=target, args=(reader, in_queue))
        p.daemon = True
        p.start()
        read_workers.append(p)

        # start handle_workers with multiple processes
        target = order_handle_worker
        args = (in_queue, out_queue, mapper, out_order)
        workers = [
            Process(target=target, args=args) for _ in xrange(process_num)
        ]
        for w in workers:
            w.daemon = True
            w.start()
            handle_workers.append(w)

        # start a thread to drain data from the slow Manager.Queue into a
        # local Queue
        flush_queue = Queue(buffer_size)
        t = Thread(target=flush_worker, args=(out_queue, flush_queue))
        t.daemon = True
        t.start()
        flush_workers.append(t)

        # get results
        sample = flush_queue.get()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = flush_queue.get()

    return xreader, cleanup
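

# A minimal usage sketch for xmap_readers_mp, with a toy reader and mapper
# (both hypothetical):
#
#   def toy_reader():
#       for i in xrange(100):
#           yield i
#
#   def toy_mapper(sample):
#       return sample * 2
#
#   xreader, cleanup = xmap_readers_mp(
#       toy_mapper, toy_reader, process_num=4, buffer_size=16)
#   for sample in xreader():
#       print(sample)
#   cleanup()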