# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import List
from typing import Tuple

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import LMInterface
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
    def __init__(self,
                 n_vocab: int,
                 pos_enc: str=None,
                 embed_unit: int=128,
                 att_unit: int=256,
                 head: int=2,
                 unit: int=1024,
                 layer: int=4,
                 dropout_rate: float=0.5,
                 emb_dropout_rate: float=0.0,
                 att_dropout_rate: float=0.0,
                 tie_weights: bool=False,
                 **kwargs):
        nn.Layer.__init__(self)

        if pos_enc == "sinusoidal":
            pos_enc_layer_type = "abs_pos"
        elif pos_enc is None:
            pos_enc_layer_type = "no_pos"
        else:
            raise ValueError(f"unknown pos-enc option: {pos_enc}")

        self.embed = nn.Embedding(n_vocab, embed_unit)

        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        self.encoder = TransformerEncoder(
            input_size=embed_unit,
            output_size=att_unit,
            attention_heads=head,
            linear_units=unit,
            num_blocks=layer,
            dropout_rate=dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_layer_type=pos_enc_layer_type,
            concat_after=False,
            static_chunk_size=1,
            use_dynamic_chunk=False,
            use_dynamic_left_chunk=False)

        self.decoder = nn.Linear(att_unit, n_vocab)

        logger.info("Tie weights set to {}".format(tie_weights))
        logger.info("Dropout set to {}".format(dropout_rate))
        logger.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logger.info("Att Dropout set to {}".format(att_dropout_rate))

        if tie_weights:
            assert (
                att_unit == embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.decoder.weight = self.embed.weight

    def _target_mask(self, ys_in_pad):
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(self, x: paddle.Tensor, t: paddle.Tensor
                ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (paddle.Tensor): Input ids. (batch, len)
            t (paddle.Tensor): Target ids. (batch, len)

        Returns:
            tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

        """
        batch_size = x.size(0)
        xm = x != 0
        xlen = xm.sum(axis=1)
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)
        h, _ = self.encoder(emb, xlen)
        y = self.decoder(h)
        loss = F.cross_entropy(
            y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(loss.dtype)
        logp = loss * mask.view(-1)
        nll = logp.view(batch_size, -1).sum(-1)
        nll_count = mask.sum(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count, nll, nll_count

    # beam search API (see ScorerInterface)
    def score(self, y: paddle.Tensor, state: Any,
              x: paddle.Tensor) -> Tuple[paddle.Tensor, Any]:
        """Score new token.

        Args:
            y (paddle.Tensor): 1D paddle.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (paddle.Tensor): encoder feature that generates ys.

        Returns:
            tuple[paddle.Tensor, Any]: Tuple of
                paddle.float32 scores for next token (n_vocab)
                and next state for ys

        """
        y = y.unsqueeze(0)

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(y))
        else:
            emb = self.embed(y)

        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = F.log_softmax(h).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self,
                    ys: paddle.Tensor,
                    states: List[Any],
                    xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (paddle.Tensor): paddle.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (paddle.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[paddle.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                paddle.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(ys))
        else:
            emb = self.embed(ys)

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = F.log_softmax(h)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)]
                      for b in range(n_batch)]
        return logp, state_list


if __name__ == "__main__":
    tlm = TransformerLM(
        n_vocab=5002,
        pos_enc=None,
        embed_unit=128,
        att_unit=512,
        head=8,
        unit=2048,
        layer=16,
        dropout_rate=0.5, )

    #     n_vocab: int,
    # pos_enc: str=None,
    # embed_unit: int=128,
    # att_unit: int=256,
    # head: int=2,
    # unit: int=1024,
    # layer: int=4,
    # dropout_rate: float=0.5,
    # emb_dropout_rate: float = 0.0,
    # att_dropout_rate: float = 0.0,
    # tie_weights: bool = False,):
    paddle.set_device("cpu")
    model_dict = paddle.load("transformerLM.pdparams")
    tlm.set_state_dict(model_dict)

    tlm.eval()
    #Test the score
    input2 = np.array([5])
    input2 = paddle.to_tensor(input2)
    state = None
    output, state = tlm.score(input2, state, None)

    input3 = np.array([5, 10])
    input3 = paddle.to_tensor(input3)
    output, state = tlm.score(input3, state, None)

    input4 = np.array([5, 10, 0])
    input4 = paddle.to_tensor(input4)
    output, state = tlm.score(input4, state, None)
    print("output", output)
    """
    #Test the batch score
    batch_size = 2
    inp2 = np.array([[5], [10]])
    inp2 = paddle.to_tensor(inp2)
    output, states = tlm.batch_score(
        inp2, [(None,None,0)] * batch_size)
    inp3 = np.array([[100], [30]])
    inp3 = paddle.to_tensor(inp3)
    output, states = tlm.batch_score(
        inp3, states)
    print("output", output)
    #print("cache", cache)
    #np.save("output_pd.npy", output)
    """