# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ScorerInterface implementation for CTC."""
import numpy as np
import paddle

from .ctc_prefix_score import CTCPrefixScore
from .ctc_prefix_score import CTCPrefixScorePD
from .scorer_interface import BatchPartialScorerInterface


class CTCPrefixScorer(BatchPartialScorerInterface):
    """Decoder interface wrapper for CTC prefix scoring.

    Two back-ends are used, depending on which ``*init_state`` method is
    called before scoring:

    * :meth:`init_state` builds a numpy-based ``CTCPrefixScore`` that is
      driven through :meth:`score_partial` (one hypothesis at a time).
    * :meth:`batch_init_state` builds a ``CTCPrefixScorePD`` that is driven
      through :meth:`batch_score_partial` (all hypotheses at once).

    The active back-end is stored in ``self.impl``; it is ``None`` until one
    of the ``*init_state`` methods has been called.
    """

    def __init__(self, ctc: paddle.nn.Layer, eos: int):
        """Initialize class.

        Args:
            ctc (paddle.nn.Layer): The CTC implementation.
                For example, :class:`deepspeech.modules.ctc.CTC`.
                Only its ``log_softmax`` method is used by this class.
            eos (int): The end-of-sequence id.

        """
        self.ctc = ctc
        self.eos = eos
        # Scoring back-end; created by init_state() / batch_init_state().
        self.impl = None

    def init_state(self, x: paddle.Tensor):
        """Get an initial state for decoding.

        Creates the numpy-based ``CTCPrefixScore`` back-end from the CTC
        log-probabilities of ``x`` and stores it in ``self.impl``.

        Args:
            x (paddle.Tensor): The encoded feature tensor

        Returns:
            tuple: ``(0, impl_state)`` where ``0`` is the initial previous
                score (see :meth:`score_partial`) and ``impl_state`` is the
                back-end's own initial state.

        """
        # A leading axis is added for the log_softmax call and removed again
        # before converting to numpy.
        logp = self.ctc.log_softmax(x.unsqueeze(0)).squeeze(0).numpy()
        # TODO(karita): use CTCPrefixScorePD
        # NOTE(review): the literal 0 is presumably the blank id -- confirm
        # against the CTCPrefixScore constructor.
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i, new_id=None):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens. Either ``None``, a
                2-tuple (``CTCPrefixScore``), a 5-tuple
                (``CTCPrefixScorePD``) or any other object indexable by ``i``.
            i (int): Index to select a state in the main beam search
            new_id (int): New label id to select a state if necessary

        Returns:
            state: pruned state

        """
        if state is None:
            return None

        # isinstance (rather than an exact type comparison) so that tuple
        # subclasses such as namedtuples are handled as tuples as well.
        if not isinstance(state, tuple):
            return state[i]

        if len(state) == 2:  # for CTCPrefixScore
            score, impl_state = state
            return score[i], impl_state[i]

        # for CTCPrefixScorePD (need new_id > 0)
        r, log_psi, f_min, f_max, scoring_idmap = state
        # NOTE(review): relies on ``Tensor.size`` being callable with an axis
        # argument -- presumably provided elsewhere in the project; confirm.
        s = log_psi[i, new_id].expand(log_psi.size(1))
        if scoring_idmap is not None:
            # Translate the label id through the id map before indexing r.
            return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
        return r[:, :, i, new_id], s, f_min, f_max

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (paddle.Tensor): 1D prefix token
            ids (paddle.Tensor): paddle.int64 next token to score
            state: decoder state for prefix tokens, as a
                ``(previous score, back-end state)`` pair
            x (paddle.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[paddle.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys

        """
        prev_score, state = state
        # The back-end created by init_state() works on host data.
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        # Only the increment over the previous score is reported; the full
        # score is kept in the returned state for the next step.
        tscore = paddle.to_tensor(
            presub_score - prev_score, place=x.place, dtype=x.dtype)
        return tscore, (presub_score, new_st)

    def batch_init_state(self, x: paddle.Tensor):
        """Get an initial state for batched decoding.

        Creates the ``CTCPrefixScorePD`` back-end from the CTC
        log-probabilities of ``x`` and stores it in ``self.impl``.

        Args:
            x (paddle.Tensor): The encoded feature tensor

        Returns:
            None: the batched back-end starts from an empty state.

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
        xlen = paddle.to_tensor([logp.size(1)])
        self.impl = CTCPrefixScorePD(logp, xlen, 0, self.eos)
        return None

    def batch_score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (paddle.Tensor): 1D prefix token
            ids (paddle.Tensor): paddle.int64 next token to score
            state: sequence of per-hypothesis decoder states for prefix
                tokens; ``state[0]`` is ``None`` on the first step
            x (paddle.Tensor): 2D encoder feature that generates ys
                (not used by this method)

        Returns:
            tuple[paddle.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys

        """
        # Merge the per-hypothesis states produced by select_state() into a
        # single batched state; the last two entries are taken from the first
        # hypothesis only.
        batch_state = (
            (paddle.stack([s[0] for s in state], axis=2),
             paddle.stack([s[1] for s in state]), state[0][2], state[0][3], )
            if state[0] is not None else None)
        return self.impl(y, batch_state, ids)

    def extend_prob(self, x: paddle.Tensor):
        """Extend probs for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            x (paddle.Tensor): The encoded feature tensor

        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))
        self.impl.extend_prob(logp)

    def extend_state(self, state):
        """Extend state for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            state: The states of hyps

        Returns:
            list: extended states, one per entry of ``state`` and in the
                same order

        """
        return [self.impl.extend_state(s) for s in state]