# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat

import sentencepiece as spm

from ..utility import BLANK
from ..utility import EOS
from ..utility import load_dict
from ..utility import MASKCTC
from ..utility import SOS
from ..utility import SPACE
from ..utility import UNK
from deepspeech.utils.log import Log

# Module-level logger, obtained through the project's `Log` wrapper and keyed
# by this module's name.
logger = Log(__name__).getlog()

# Names exported by a star-import of this module.
__all__ = ["TextFeaturizer"]


class TextFeaturizer():
    """Tokenize text and convert tokens to and from vocabulary indices.

    Three tokenization granularities are supported, selected by `unit_type`:
    'char' (one token per character), 'word' (whitespace separated) and
    'spm' (sentence pieces produced by a SentencePiece model).
    """

    def __init__(self,
                 unit_type,
                 vocab_filepath,
                 spm_model_prefix=None,
                 maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and conversion into
        a list of token indices. Note that the token indexing order follows the
        given vocabulary file.

        The vocabulary attributes (`vocab_dict`, `_id2token`, `vocab_list`,
        `unk_id`, `eos_id`, `vocab_size`) are only set when `vocab_filepath`
        is truthy.

        Args:
            unit_type (str): unit type, e.g. char, word, spm
            vocab_filepath (str): Filepath to load vocabulary for token indices conversion.
            spm_model_prefix (str, optional): spm model prefix; '.model' is
                appended to it to get the model path. Required when
                `unit_type` is 'spm'. Defaults to None.
            maskctc (bool, optional): forwarded to the vocabulary loader.
                Defaults to False.

        Raises:
            ValueError: if `unit_type` is 'spm' and `spm_model_prefix` is None.
        """
        assert unit_type in ('char', 'spm', 'word'), (
            f"unsupported unit_type: {unit_type!r}")
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab_filepath:
            (self.vocab_dict, self._id2token, self.vocab_list, self.unk_id,
             self.eos_id) = self._load_vocabulary_from_file(vocab_filepath,
                                                            maskctc)
            self.vocab_size = len(self.vocab_list)

        if unit_type == 'spm':
            # Fail with a clear message instead of the `None + str` TypeError.
            if spm_model_prefix is None:
                raise ValueError(
                    "spm_model_prefix is required when unit_type is 'spm'")
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text, replace_space=True):
        """Split text into tokens according to `self.unit_type`.

        Args:
            text (str): text string.
            replace_space (bool): only used by the char tokenizer, see
                `char_tokenize`.

        Returns:
            List[str]: tokens. The spm tokenizer returns None for blank text.
        """
        if self.unit_type == 'char':
            return self.char_tokenize(text, replace_space)
        if self.unit_type == 'word':
            return self.word_tokenize(text)
        # spm
        return self.spm_tokenize(text)

    def detokenize(self, tokens):
        """Join tokens back into text according to `self.unit_type`.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text string.
        """
        if self.unit_type == 'char':
            return self.char_detokenize(tokens)
        if self.unit_type == 'word':
            return self.word_detokenize(tokens)
        # spm
        return self.spm_detokenize(tokens)

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Tokens missing from the vocabulary are mapped to the index of the
        unknown token.

        Args:
            text (str): Text to process.

        Returns:
            List[int]: List of token indices.
        """
        # spm_tokenize returns None for blank text; treat it as "no tokens"
        # rather than failing on iteration.
        tokens = self.tokenize(text) or []
        vocab = self.vocab_dict
        return [
            vocab[token if token in vocab else self.unk] for token in tokens
        ]

    def defeaturize(self, idxs):
        """Convert a list of token indices to text string,
        ignore index after eos_id.

        Args:
            idxs (List[int]): List of token indices.

        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            # Stop at the first end-of-sentence index; it is not emitted.
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        return self.detokenize(tokens)

    def char_tokenize(self, text, replace_space=True):
        """Character tokenizer.

        Leading and trailing whitespace is stripped before splitting.

        Args:
            text (str): text string.
            replace_space (bool): when True, each " " character is replaced
                by the SPACE symbol. False only used by build_vocab.py.

        Returns:
            List[str]: tokens.
        """
        chars = list(text.strip())
        if not replace_space:
            return chars
        return [SPACE if char == " " else char for char in chars]

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Args:
            tokens (List[str]): tokens.

        Returns:
           str: text string.
        """
        # Join first: `tokens` is a list, which has no `replace` method.
        return "".join(tokens).replace(SPACE, " ")

    def word_tokenize(self, text):
        """Word tokenizer, separate by <space>."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer, separate by <space>."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence pieces str code, or None when `text` is
                empty after stripping whitespace.
        """
        line = text.strip()
        if not line:
            return None
        return self.sp.EncodeAsPieces(line)

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            tokens (List[str] or List[int]): sentence pieces when
                `input_format` is 'piece', piece ids when it is 'id'.
            input_format (str): 'piece' or 'id'. Defaults to 'piece'.

        Returns:
            str: text

        Raises:
            ValueError: if `input_format` is neither 'piece' nor 'id'.
        """
        if input_format == "piece":
            return "".join(self.sp.DecodePieces(tokens))
        if input_format == "id":
            return "".join(self.sp.DecodeIds(tokens))
        raise ValueError(
            f"unsupported input_format: {input_format!r}, "
            "expected 'piece' or 'id'")

    def _load_vocabulary_from_file(self, vocab_filepath: str, maskctc: bool):
        """Load vocabulary from file.

        Args:
            vocab_filepath (str): path passed to `load_dict`.
            maskctc (bool): passed to `load_dict`.

        Returns:
            tuple: (token2id, id2token, vocab_list, unk_id, eos_id). An id is
                -1 when the corresponding special token is not in the
                vocabulary.
        """
        vocab_list = load_dict(vocab_filepath, maskctc)
        assert vocab_list is not None
        logger.info(f"Vocab: {pformat(vocab_list)}")

        id2token = {idx: token for idx, token in enumerate(vocab_list)}
        token2id = {token: idx for idx, token in enumerate(vocab_list)}

        def index_or_missing(token):
            # First position of `token` in the vocabulary, -1 when absent.
            return vocab_list.index(token) if token in vocab_list else -1

        blank_id = index_or_missing(BLANK)
        maskctc_id = index_or_missing(MASKCTC)
        unk_id = index_or_missing(UNK)
        eos_id = index_or_missing(EOS)
        sos_id = index_or_missing(SOS)
        space_id = index_or_missing(SPACE)

        logger.info(f"BLANK id: {blank_id}")
        logger.info(f"UNK id: {unk_id}")
        logger.info(f"EOS id: {eos_id}")
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id