# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import collections
import json
import os
from typing import Dict

from paddle.io import Dataset
from tqdm import tqdm

from ..backends import load as load_audio
from ..utils.download import decompress
from ..utils.download import download_and_decompress
from ..utils.env import DATA_HOME
from ..utils.log import logger
from .dataset import feat_funcs

__all__ = ['AISHELL1']


class AISHELL1(Dataset):
    """
    This Open Source Mandarin Speech Corpus, AISHELL-ASR0009-OS1, is 178 hours long.
    It is a part of AISHELL-ASR0009, of which utterance contains 11 domains, including
    smart home, autonomous driving, and industrial production. The whole recording was
    put in quiet indoor environment, using 3 different devices at the same time: high
    fidelity microphone (44.1kHz, 16-bit,); Android-system mobile phone (16kHz, 16-bit),
    iOS-system mobile phone (16kHz, 16-bit). Audios in high fidelity were re-sampled
    to 16kHz to build AISHELL- ASR0009-OS1. 400 speakers from different accent areas
    in China were invited to participate in the recording. The manual transcription
    accuracy rate is above 95%, through professional speech annotation and strict
    quality inspection. The corpus is divided into training, development and testing
    sets.

    Reference:
        AISHELL-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline
        https://arxiv.org/abs/1709.05522
    """

    archieves = [
        {
            'url': 'http://www.openslr.org/resources/33/data_aishell.tgz',
            'md5': '2f494334227864a8a8fec932999db9d8',
        },
    ]
    text_meta = os.path.join('data_aishell', 'transcript',
                             'aishell_transcript_v0.8.txt')
    utt_info = collections.namedtuple('META_INFO',
                                      ('file_path', 'utt_id', 'text'))
    audio_path = os.path.join('data_aishell', 'wav')
    manifest_path = os.path.join('data_aishell', 'manifest')
    subset = ['train', 'dev', 'test']

    def __init__(self, subset: str='train', feat_type: str='raw', **kwargs):
        assert subset in self.subset, 'Dataset subset must be one in {}, but got {}'.format(
            self.subset, subset)
        self.subset = subset
        self.feat_type = feat_type
        self.feat_config = kwargs
        self._data = self._get_data()
        super(AISHELL1, self).__init__()

    def _get_text_info(self) -> Dict[str, str]:
        ret = {}
        with open(os.path.join(DATA_HOME, self.text_meta), 'r') as rf:
            for line in rf.readlines()[1:]:
                utt_id, text = map(str.strip, line.split(' ',
                                                         1))  # utt_id, text
                ret.update({utt_id: ''.join(text.split())})
        return ret

    def _get_data(self):
        if not os.path.isdir(os.path.join(DATA_HOME, self.audio_path)) or \
            not os.path.isfile(os.path.join(DATA_HOME, self.text_meta)):
            download_and_decompress(self.archieves, DATA_HOME)
            # Extract *wav from *.tar.gz.
            for root, _, files in os.walk(
                    os.path.join(DATA_HOME, self.audio_path)):
                for file in files:
                    if file.endswith('.tar.gz'):
                        decompress(os.path.join(root, file))
                        os.remove(os.path.join(root, file))

        text_info = self._get_text_info()

        data = []
        for root, _, files in os.walk(
                os.path.join(DATA_HOME, self.audio_path, self.subset)):
            for file in files:
                if file.endswith('.wav'):
                    utt_id = os.path.splitext(file)[0]
                    if utt_id not in text_info:  # There are some utt_id that without label
                        continue
                    text = text_info[utt_id]
                    file_path = os.path.join(root, file)
                    data.append(self.utt_info(file_path, utt_id, text))

        return data

    def _convert_to_record(self, idx: int):
        sample = self._data[idx]

        record = {}
        # To show all fields in a namedtuple: `type(sample)._fields`
        for field in type(sample)._fields:
            record[field] = getattr(sample, field)

        waveform, sr = load_audio(
            sample[0])  # The first element of sample is file path
        feat_func = feat_funcs[self.feat_type]
        feat = feat_func(
            waveform, sample_rate=sr,
            **self.feat_config) if feat_func else waveform
        record.update({'feat': feat, 'duration': len(waveform) / sr})
        return record

    def create_manifest(self, prefix='manifest'):
        if not os.path.isdir(os.path.join(DATA_HOME, self.manifest_path)):
            os.makedirs(os.path.join(DATA_HOME, self.manifest_path))

        manifest_file = os.path.join(DATA_HOME, self.manifest_path,
                                     f'{prefix}.{self.subset}')
        with codecs.open(manifest_file, 'w', 'utf-8') as f:
            for idx in tqdm(range(len(self))):
                record = self._convert_to_record(idx)
                record_line = json.dumps(
                    {
                        'utt': record['utt_id'],
                        'feat': record['file_path'],
                        'feat_shape': (record['duration'], ),
                        'text': record['text']
                    },
                    ensure_ascii=False)
                f.write(record_line + '\n')
        logger.info(f'Manifest file {manifest_file} created.')

    def __getitem__(self, idx):
        record = self._convert_to_record(idx)
        return tuple(record.values())

    def __len__(self):
        return len(self._data)