# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from collections import OrderedDict

import numpy as np
import paddle
from paddleaudio.backends import load as load_audio
from paddleaudio.compliance.librosa import melspectrogram

from paddlespeech.cli.log import logger
from paddlespeech.cli.vector.infer import VectorExecutor
from paddlespeech.server.engine.base_engine import BaseEngine
from paddlespeech.vector.io.batch import feature_normalize


class PaddleVectorConnectionHandler:
    """Per-request handler that runs the vector (speaker embedding) model.

    The handler reuses the executor, task name, model and config held by an
    already initialised vector engine; it does not load anything itself.
    """

    def __init__(self, vector_engine):
        """The PaddleSpeech Vector Server Connection Handler
           This connection processes every server request

        Args:
            vector_engine (VectorEngine): The Vector engine; its ``executor``
                must expose the ``task``, ``model`` and ``config`` attributes.
        """
        super().__init__()
        logger.info(
            "Create PaddleVectorConnectionHandler to process the vector request")
        self.vector_engine = vector_engine
        self.executor = self.vector_engine.executor
        self.task = self.vector_engine.executor.task
        self.model = self.vector_engine.executor.model
        self.config = self.vector_engine.executor.config

        # Kept for interface parity; nothing in this class reads or writes them.
        self._inputs = OrderedDict()
        self._outputs = OrderedDict()

    @paddle.no_grad()
    def run(self, audio_data, task="spk"):
        """Process the audio of one http request.

        Args:
            audio_data (bytes): the audio bytes (per the original note, the
                result of ``base64.b64decode`` on the request payload).
            task (str, optional): the task requested by the client.
                Defaults to "spk".

        Returns:
            np.ndarray: the audio embedding when both the server task and the
                requested task are "spk"; otherwise the placeholder
                ``np.array([0.0])`` after the mismatch has been logged.
        """
        logger.info(
            f"start to extract the do vector {self.task} from the http request")
        if self.task == "spk" and task == "spk":
            return self.extract_audio_embedding(audio_data)

        logger.error("The request task is not matched with server model task")
        logger.error(
            f"The server model task is: {self.task}, but the request task is: {task}"
        )
        return np.array([
            0.0,
        ])

    @paddle.no_grad()
    def get_enroll_test_score(self, enroll_audio, test_audio):
        """Get the enroll and test audio score

        Args:
            enroll_audio (bytes): the enroll audio data; it is wrapped in
                ``io.BytesIO`` by ``extract_audio_embedding``.
            test_audio (bytes): the test audio data, handled the same way.

        Returns:
            the value returned by ``executor.get_embeddings_score`` for the
            two embeddings (documented upstream as a float score).
        """
        logger.info("start to extract the enroll audio embedding")
        enroll_emb = self.extract_audio_embedding(enroll_audio)

        logger.info("start to extract the test audio embedding")
        test_emb = self.extract_audio_embedding(test_audio)

        logger.info(
            "start to get the score between the enroll and test embedding")
        score = self.executor.get_embeddings_score(enroll_emb, test_emb)

        logger.info(f"get the enroll vs test score: {score}")
        return score

    @paddle.no_grad()
    def extract_audio_embedding(self, audio: bytes, sample_rate: int=16000):
        """extract the audio embedding

        Args:
            audio (bytes): the audio data; must be bytes-like because it is
                wrapped in ``io.BytesIO``.
            sample_rate (int, optional): the audio sample rate handed to the
                executor's check. Defaults to 16000.

        Returns:
            np.ndarray: the squeezed output of ``model.backbone``, or
                ``np.array([0.0])`` when the executor's audio check fails.

        Raises:
            Exception: whatever ``melspectrogram`` raised is logged and
                re-raised, so the caller decides how to answer the request.
        """
        # A fresh io.BytesIO is built for every consumer: soundfile leaves the
        # stream positioned at its end, so a shared buffer cannot be reused.
        if not self.executor._check(io.BytesIO(audio), sample_rate):
            logger.info("check the audio sample rate occurs error")
            return np.array([0.0])

        # stage 1: load the waveform (the returned sample rate is not used)
        waveform, _ = load_audio(io.BytesIO(audio))
        logger.info(f"load the audio sample points, shape is: {waveform.shape}")

        # stage 2: get the audio feat
        # Note: Now we only support fbank feature
        try:
            feats = melspectrogram(
                x=waveform,
                sr=self.config.sr,
                n_mels=self.config.n_mels,
                window_size=self.config.window_size,
                hop_length=self.config.hop_size)
        except Exception as e:
            # The original called sys.exit(-1) here without importing sys,
            # which raised NameError; exiting would also kill the server for
            # a single bad request. Propagate the real error instead.
            logger.error(f"feats occurs exception {e}")
            raise
        else:
            logger.info(f"extract the audio feats, shape is: {feats.shape}")

        feats = paddle.to_tensor(feats).unsqueeze(0)
        # in inference period, the lengths is all one without padding
        lengths = paddle.ones([1])

        # stage 3: we do feature normalize,
        #          Now we assume that the feats must do normalize
        feats = feature_normalize(feats, mean_norm=True, std_norm=False)
        logger.info(f"feats shape: {feats.shape}")
        logger.info("audio extract the feats success")

        # stage 4: run the backbone on the normalized feats
        logger.info("start to extract the audio embedding")
        embedding = self.model.backbone(feats, lengths).squeeze().numpy()
        logger.info(f"embedding size: {embedding.shape}")

        return embedding


class VectorServerExecutor(VectorExecutor):
    """Server-side wrapper around ``VectorExecutor``; adds no behavior."""

    def __init__(self):
        """Delegate construction entirely to ``VectorExecutor.__init__``."""
        super(VectorServerExecutor, self).__init__()


class VectorEngine(BaseEngine):
    """Server engine that selects the device and owns the vector executor."""

    def __init__(self):
        """The Vector Engine
        """
        super(VectorEngine, self).__init__()
        logger.info("Create the VectorEngine Instance")

    def init(self, config: dict):
        """Init the Vector Engine

        Args:
            config (dict): The server configuration. It is read through
                attribute access (``device``, ``model_type``, ``cfg_path``,
                ``ckpt_path``, ``task``), so it must support that.

        Returns:
            bool: True when the device was set and the executor was
                initialised; False when selecting the device failed.
        """
        logger.info("Init the vector engine")
        try:
            self.config = config
            if self.config.device:
                self.device = self.config.device
            else:
                self.device = paddle.get_device()

            paddle.set_device(self.device)
            logger.info(f"Vector Engine set the device: {self.device}")
        except Exception as e:
            # Exception (not BaseException) so that KeyboardInterrupt and
            # SystemExit are not swallowed during start-up.
            logger.error(
                "Set device failed, please check if device is already used and the parameter 'device' in the yaml file"
            )
            # self.device may be unset when reading config.device itself
            # failed, so it is looked up defensively for the log message.
            logger.error("Initialize Vector server engine Failed on device: %s."
                         % (getattr(self, "device", None)))
            logger.error(f"Set device exception: {e}")
            return False

        self.executor = VectorServerExecutor()

        self.executor._init_from_path(
            model_type=config.model_type,
            cfg_path=config.cfg_path,
            ckpt_path=config.ckpt_path,
            task=config.task)

        logger.info("Init the Vector engine successfully")
        return True