# Copyright (c) 2023 SpeechBrain Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from speechbrain 2023 (https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/dataio/dataloader.py)
"""Paddle compatible DataLoaders

Essentially we extend Paddle DataLoader by adding the ability to save the
data loading state, so that a checkpoint may be saved in the middle of an
epoch.

Authors:
  * Aku Rouhe 2020
"""
import collections

import numpy
import paddle
from paddle.io import DataLoader

from paddlespeech.s2t.io.speechbrain.data_utils import batch_pad_right
from paddlespeech.s2t.io.speechbrain.data_utils import mod_default_collate
from paddlespeech.s2t.io.speechbrain.dataset import DynamicItemDataset
from paddlespeech.s2t.io.speechbrain.sampler import ReproducibleRandomSampler

# Result for padded keys: ``data`` is the padded batch, ``lengths`` holds
# the per-example length information returned by the padding function.
PaddedData = collections.namedtuple("PaddedData", ["data", "lengths"])


class Wav2vec2DataLoader(DataLoader):
    """A thin wrapper around ``paddle.io.DataLoader``.

    The constructor keeps a torch-style DataLoader signature so that
    speechbrain-derived recipes can build it unchanged. ``pin_memory``,
    ``multiprocessing_context`` and ``generator`` are accepted for API
    compatibility only and are ignored.
    """

    def __init__(self,
                 dataset,
                 batch_size=1,
                 shuffle=False,
                 sampler=None,
                 batch_sampler=None,
                 num_workers=0,
                 collate_fn=None,
                 pin_memory=False,
                 drop_last=False,
                 timeout=0,
                 worker_init_fn=None,
                 multiprocessing_context=None,
                 generator=None):
        # Infer return_list from the first example: tuple/list examples
        # are returned as a list of fields, other examples (e.g. dicts
        # from DynamicItemDataset) are not.
        if isinstance(dataset[0], (tuple, list)):
            return_list = True
        else:
            return_list = False

        super().__init__(
            dataset,
            feed_list=None,
            places=None,
            return_list=return_list,
            batch_sampler=batch_sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            collate_fn=collate_fn,
            num_workers=num_workers,
            use_buffer_reader=True,
            use_shared_memory=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn)
        # Paddle builds its own BatchSampler internally; if the caller
        # supplied an index sampler, swap it in after construction.
        if sampler is not None:
            self.batch_sampler.sampler = sampler
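
# A minimal usage sketch (the CSV manifest and the "sig" key are
# hypothetical, and assume a speechbrain-style DynamicItemDataset):
#
#     dataset = DynamicItemDataset.from_csv("train.csv")
#     loader = Wav2vec2DataLoader(
#         dataset, batch_size=8, collate_fn=PaddedBatch)
#     for batch in loader:
#         wavs, wav_lens = batch["sig"]  # a PaddedData (data, lengths) pair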


def PaddedBatch(
        examples,
        padded_keys=None,
        device_prep_keys=None,
        padding_func=batch_pad_right,
        padding_kwargs=None,
        nonpadded_stack=True):
    """Collate a list of example dicts into a batch dict.

    Values that are numpy arrays (or whose key is listed in
    ``padded_keys``) are padded with ``padding_func`` and wrapped in a
    ``PaddedData`` namedtuple; all other values are collated with
    ``mod_default_collate`` when ``nonpadded_stack`` is True.
    ``device_prep_keys`` is accepted for compatibility with speechbrain's
    PaddedBatch but has no effect here.
    """
    # A None default avoids the shared-mutable-default-argument pitfall.
    if padding_kwargs is None:
        padding_kwargs = {}
    res = {}
    for key in examples[0].keys():
        values = [example[key] for example in examples]
        if (padded_keys is not None and key in padded_keys) or (
                padded_keys is None and isinstance(values[0], numpy.ndarray)):
            # Pad to a common shape and keep per-example length info.
            res[key] = PaddedData(*padding_func(values, **padding_kwargs))
        else:
            # Default collate usually does the right thing
            # (convert lists of equal sized tensors to batch tensors, etc.)
            if nonpadded_stack:
                values = mod_default_collate(values)
            res[key] = values
    return res
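
# A minimal collation sketch (shapes are illustrative; the exact form of
# ``lengths`` depends on batch_pad_right):
#
#     examples = [{"id": "ex1", "wav": numpy.ones(16000, dtype="float32")},
#                 {"id": "ex2", "wav": numpy.ones(8000, dtype="float32")}]
#     batch = PaddedBatch(examples)
#     batch["wav"].data.shape   # (2, 16000): "ex2" zero-padded on the right
#     batch["wav"].lengths      # per-example length information
#     batch["id"]               # the non-padded values, e.g. ["ex1", "ex2"]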


def make_dataloader(dataset, stage, **loader_kwargs):
    """Makes a basic DataLoader.

    For DynamicItemDatasets (which return dicts), PaddedBatch is used as
    the default collate_fn.

    Shuffling is implemented by ReproducibleRandomSampler, so the data
    order can be reproduced across runs.

    Arguments
    ---------
    dataset : Dataset
        The dataset to make a DataLoader for.
    stage : object
        The training stage (e.g. train/valid/test). Accepted for
        compatibility with speechbrain's make_dataloader; unused here.
    **loader_kwargs : dict
        Keyword args to DataLoader, see Paddle DataLoader for
        options.

    Returns
    -------
    Wav2vec2DataLoader
        The constructed DataLoader.
    """
    # PaddedBatch as default collation for DynamicItemDataset
    if "collate_fn" not in loader_kwargs and isinstance(dataset,
                                                        DynamicItemDataset):
        loader_kwargs["collate_fn"] = PaddedBatch
    # Reproducible random sampling
    if loader_kwargs.get("shuffle", False):
        if loader_kwargs.get("sampler") is not None:
            raise ValueError("Cannot specify both shuffle=True and a "
                             "sampler in loader_kwargs")
        sampler = ReproducibleRandomSampler(dataset)
        loader_kwargs["sampler"] = sampler
        # Delete shuffle: a DataLoader cannot take both a sampler and
        # shuffle=True. The del is safe even if the caller's options dict
        # is reused elsewhere, because loader_kwargs comes from a **kwargs
        # expansion and is therefore a fresh dict.
        del loader_kwargs["shuffle"]
    # Create the loader
    dataloader = Wav2vec2DataLoader(dataset, **loader_kwargs)
    return dataloader
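
# A minimal usage sketch (the dataset construction and the "train" stage
# value are hypothetical):
#
#     train_set = DynamicItemDataset.from_csv("train.csv")
#     train_loader = make_dataloader(
#         train_set, "train", batch_size=8, shuffle=True, num_workers=2)
#     for batch in train_loader:
#         ...  # a dict; padded keys arrive as PaddedData namedtuples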