# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import


class STInterface(ASRInterface):
    """ST Interface model implementation.

    NOTE: This class is inherited from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.

    Both decoding entry points below are abstract: they raise
    ``NotImplementedError`` and must be overridden by concrete ST models.
    """

    def translate(self,
                  x,
                  trans_args,
                  char_list=None,
                  rnnlm=None,
                  ensemble_models=None):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :param list ensemble_models: presumably extra models for ensemble
            decoding (unused in this base class); ``None`` means no ensemble.
            Defaults to ``None`` rather than ``[]`` so that overriding
            implementations never share one mutable default list across calls.
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("translate method is not implemented")

    def translate_batch(self, x, trans_args, char_list=None, rnnlm=None):
        """Beam search implementation for batch.

        :param paddle.Tensor x: encoder hidden state sequences (B, Tmax, Henc)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param paddle.nn.Layer rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Batch decoding is not supported yet.")


# Short aliases that ``dynamic_import_st`` accepts in place of a full
# "module_name:class_name" path, mapped to the ST model class they name.
predefined_st = dict(
    transformer="deepspeech.models.u2_st:U2STModel",
)


def dynamic_import_st(module):
    """Import ST models dynamically.

    Args:
        module (str): either ``"module_name:class_name"`` or an alias key of
            ``predefined_st`` (e.g. ``"transformer"``).

    Returns:
        type: the resolved ST model class, verified to be a subclass of
        ``STInterface``.

    Raises:
        TypeError: if the resolved object is not a class deriving from
            ``STInterface``.
    """
    model_class = dynamic_import(module, predefined_st)
    # Use an explicit check rather than ``assert`` so the validation still runs
    # under ``python -O``. The ``isinstance`` guard stops ``issubclass`` from
    # raising a cryptic builtin error when the import resolves to a non-class.
    is_st_class = isinstance(model_class, type) and issubclass(
        model_class, STInterface)
    if not is_st_class:
        raise TypeError(f"{module} does not implement STInterface")
    return model_class