# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import re
from collections import OrderedDict
from typing import List
from typing import Optional
from typing import Union

import paddle

from ...s2t.utils.dynamic_import import dynamic_import
from ..executor import BaseExecutor
from ..log import logger
from ..utils import cli_register
from ..utils import stats_wrapper
from .pretrained_models import model_alias
from .pretrained_models import pretrained_models
from .pretrained_models import tokenizer_alias

__all__ = ['TextExecutor']


@cli_register(name='paddlespeech.text', description='Text infer command.')
class TextExecutor(BaseExecutor):
    def __init__(self):
        super().__init__()
        self.model_alias = model_alias
        self.pretrained_models = pretrained_models
        self.tokenizer_alias = tokenizer_alias

        self.parser = argparse.ArgumentParser(
            prog='paddlespeech.text', add_help=True)
        self.parser.add_argument(
            '--input', type=str, default=None, help='Input text.')
        self.parser.add_argument(
            '--task',
            type=str,
            default='punc',
            choices=['punc'],
            help='Choose text task.')
        self.parser.add_argument(
            '--model',
            type=str,
            default='ernie_linear_p7_wudao',
            choices=[
                tag[:tag.index('-')] for tag in self.pretrained_models.keys()
            ],
            help='Choose model type of text task.')
        self.parser.add_argument(
            '--lang',
            type=str,
            default='zh',
            choices=['zh', 'en'],
            help='Choose model language.')
        self.parser.add_argument(
            '--config',
            type=str,
            default=None,
            help='Config of cls task. Use deault config when it is None.')
        self.parser.add_argument(
            '--ckpt_path',
            type=str,
            default=None,
            help='Checkpoint file of model.')
        self.parser.add_argument(
            '--punc_vocab',
            type=str,
            default=None,
            help='Vocabulary file of punctuation restoration task.')
        self.parser.add_argument(
            '--device',
            type=str,
            default=paddle.get_device(),
            help='Choose device to execute model inference.')
        self.parser.add_argument(
            '-d',
            '--job_dump_result',
            action='store_true',
            help='Save job result into file.')
        self.parser.add_argument(
            '-v',
            '--verbose',
            action='store_true',
            help='Increase logger verbosity of current task.')

    def _init_from_path(self,
                        task: str='punc',
                        model_type: str='ernie_linear_p7_wudao',
                        lang: str='zh',
                        cfg_path: Optional[os.PathLike]=None,
                        ckpt_path: Optional[os.PathLike]=None,
                        vocab_file: Optional[os.PathLike]=None):
        """
            Init model and other resources from a specific path.
        """
        if hasattr(self, 'model'):
            logger.info('Model had been initialized.')
            return

        self.task = task

        if cfg_path is None or ckpt_path is None or vocab_file is None:
            tag = '-'.join([model_type, task, lang])
            self.res_path = self._get_pretrained_path(tag)
            self.cfg_path = os.path.join(
                self.res_path, self.pretrained_models[tag]['cfg_path'])
            self.ckpt_path = os.path.join(
                self.res_path, self.pretrained_models[tag]['ckpt_path'])
            self.vocab_file = os.path.join(
                self.res_path, self.pretrained_models[tag]['vocab_file'])
        else:
            self.cfg_path = os.path.abspath(cfg_path)
            self.ckpt_path = os.path.abspath(ckpt_path)
            self.vocab_file = os.path.abspath(vocab_file)

        model_name = model_type[:model_type.rindex('_')]
        if self.task == 'punc':
            # punc list
            self._punc_list = []
            with open(self.vocab_file, 'r') as f:
                for line in f:
                    self._punc_list.append(line.strip())

            # model
            model_class = dynamic_import(model_name, self.model_alias)
            tokenizer_class = dynamic_import(model_name, self.tokenizer_alias)
            self.model = model_class(
                cfg_path=self.cfg_path, ckpt_path=self.ckpt_path)
            self.tokenizer = tokenizer_class.from_pretrained('ernie-1.0')
        else:
            raise NotImplementedError

        self.model.eval()

    def _clean_text(self, text):
        text = text.lower()
        text = re.sub('[^A-Za-z0-9\u4e00-\u9fa5]', '', text)
        text = re.sub(f'[{"".join([p for p in self._punc_list][1:])}]', '',
                      text)
        return text

    def preprocess(self, text: Union[str, os.PathLike]):
        """
            Input preprocess and return paddle.Tensor stored in self.input.
            Input content can be a text(tts), a file(asr, cls) or a streaming(not supported yet).
        """
        if self.task == 'punc':
            clean_text = self._clean_text(text)
            assert len(clean_text) > 0, f'Invalid input string: {text}'

            tokenized_input = self.tokenizer(
                list(clean_text), return_length=True, is_split_into_words=True)

            self._inputs['input_ids'] = tokenized_input['input_ids']
            self._inputs['seg_ids'] = tokenized_input['token_type_ids']
            self._inputs['seq_len'] = tokenized_input['seq_len']
        else:
            raise NotImplementedError

    @paddle.no_grad()
    def infer(self):
        """
            Model inference and result stored in self.output.
        """
        if self.task == 'punc':
            input_ids = paddle.to_tensor(self._inputs['input_ids']).unsqueeze(0)
            seg_ids = paddle.to_tensor(self._inputs['seg_ids']).unsqueeze(0)
            logits, _ = self.model(input_ids, seg_ids)
            preds = paddle.argmax(logits, axis=-1).squeeze(0)

            self._outputs['preds'] = preds
        else:
            raise NotImplementedError

    def postprocess(self) -> Union[str, os.PathLike]:
        """
            Output postprocess and return human-readable results such as texts and audio files.
        """
        if self.task == 'punc':
            input_ids = self._inputs['input_ids']
            seq_len = self._inputs['seq_len']
            preds = self._outputs['preds']

            tokens = self.tokenizer.convert_ids_to_tokens(
                input_ids[1:seq_len - 1])
            labels = preds[1:seq_len - 1].tolist()
            assert len(tokens) == len(labels)

            text = ''
            for t, l in zip(tokens, labels):
                text += t
                if l != 0:  # Non punc.
                    text += self._punc_list[l]

            return text
        else:
            raise NotImplementedError

    def execute(self, argv: List[str]) -> bool:
        """
            Command line entry.
        """
        parser_args = self.parser.parse_args(argv)

        task = parser_args.task
        model_type = parser_args.model
        lang = parser_args.lang
        cfg_path = parser_args.config
        ckpt_path = parser_args.ckpt_path
        punc_vocab = parser_args.punc_vocab
        device = parser_args.device

        if not parser_args.verbose:
            self.disable_task_loggers()

        task_source = self.get_task_source(parser_args.input)
        task_results = OrderedDict()
        has_exceptions = False

        for id_, input_ in task_source.items():
            try:
                res = self(input_, task, model_type, lang, cfg_path, ckpt_path,
                           punc_vocab, device)
                task_results[id_] = res
            except Exception as e:
                has_exceptions = True
                task_results[id_] = f'{e.__class__.__name__}: {e}'

        self.process_task_results(parser_args.input, task_results,
                                  parser_args.job_dump_result)

        if has_exceptions:
            return False
        else:
            return True

    @stats_wrapper
    def __call__(
            self,
            text: str,
            task: str='punc',
            model: str='ernie_linear_p7_wudao',
            lang: str='zh',
            config: Optional[os.PathLike]=None,
            ckpt_path: Optional[os.PathLike]=None,
            punc_vocab: Optional[os.PathLike]=None,
            device: str=paddle.get_device(), ):
        """
            Python API to call an executor.
        """
        paddle.set_device(device)
        self._init_from_path(task, model, lang, config, ckpt_path, punc_vocab)
        self.preprocess(text)
        self.infer()
        res = self.postprocess()  # Retrieve result of text task.

        return res