# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
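"""U2 (hybrid CTC/attention) trainer components: a StandardEvaluator and a
StandardUpdater with gradient-accumulation support."""
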
from contextlib import nullcontext

import paddle
from paddle import distributed as dist

from deepspeech.training.extensions.evaluator import StandardEvaluator
from deepspeech.training.reporter import report
from deepspeech.training.timer import Timer
from deepspeech.training.updaters.standard_updater import StandardUpdater
from deepspeech.utils import layer_tools
from deepspeech.utils.log import Log

logger = Log(__name__).getlog()


class U2Evaluator(StandardEvaluator):
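    """Evaluator for U2 models.

    Accumulates the utterance-weighted total loss and the number of seen
    utterances across batches, and reports per-batch loss components via
    `report`.
    """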
    def __init__(self, model, dataloader):
        super().__init__(model, dataloader)
        self.msg = ""
        self.num_seen_utts = 0
        self.total_loss = 0.0

    def evaluate_core(self, batch):
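        """Evaluate one batch.

        Args:
            batch (List[Object]): utts, xs, xlens, ys, ylens

        Returns:
            Tuple[float, int]: running (total_loss, num_seen_utts) over all
                batches evaluated so far.
        """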
        self.msg = "Valid: Rank: {}, ".format(dist.get_rank())
        losses_dict = {}

        loss, attention_loss, ctc_loss = self.model(*batch[1:])
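        # Only accumulate statistics when the loss is finite; inf/nan batches
        # are skipped so they do not corrupt the running totals.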
        if paddle.isfinite(loss):
            num_utts = batch[1].shape[0]
            self.num_seen_utts += num_utts
            self.total_loss += float(loss) * num_utts

            losses_dict['loss'] = float(loss)
            if attention_loss:
                losses_dict['att_loss'] = float(attention_loss)
            if ctc_loss:
                losses_dict['ctc_loss'] = float(ctc_loss)

            for k, v in losses_dict.items():
                report("eval/" + k, v)

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        logger.info(self.msg)
        return self.total_loss, self.num_seen_utts


class U2Updater(StandardUpdater):
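    """Updater for U2 training.

    Runs one forward/backward pass per call to `update_core` and applies the
    optimizer every `accum_grad` passes (gradient accumulation). A minimal
    usage sketch (the surrounding trainer wiring is not shown and is an
    assumption)::

        updater = U2Updater(model, optimizer, scheduler, train_loader,
                            accum_grad=4)
    """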
    def __init__(self,
                 model,
                 optimizer,
                 scheduler,
                 dataloader,
                 init_state=None,
                 accum_grad=1,
                 **kwargs):
        super().__init__(
            model, optimizer, scheduler, dataloader, init_state=init_state)
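        # Run `accum_grad` forward/backward passes before each optimizer
        # update; `forward_count` counts passes since the last update.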
        self.accum_grad = accum_grad
        self.forward_count = 0
        self.msg = ""

    def update_core(self, batch):
        """One Step

        Args:
            batch (List[Object]): utts, xs, xlens, ys, ylens
        """
        losses_dict = {}
        self.msg = "Rank: {}, ".format(dist.get_rank())

        # forward
        batch_size = batch[1].shape[0]
        loss, attention_loss, ctc_loss = self.model(*batch[1:])
        # The model returns a per-batch loss; dividing by `accum_grad` here
        # makes the accumulated gradients equivalent to one update over
        # `batch_size * accum_grad` samples.
        loss /= self.accum_grad

        # loss backward
        if (self.forward_count + 1) != self.accum_grad:
            # Disable gradient synchronization across DDP processes: within
            # this context gradients accumulate on the module's parameters and
            # are synchronized later. Note that `no_sync` is only available
            # when the model is wrapped in paddle.DataParallel.
            context = self.model.no_sync
        else:
            # Single-GPU training, or the accumulation step on which DDP
            # gradients are synchronized.
            context = nullcontext

        with context():
            loss.backward()
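            # Gradient inspection hook; with print_func=None nothing is
            # expected to be printed here.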
            layer_tools.print_grads(self.model, print_func=None)

        # loss info
        losses_dict['loss'] = float(loss) * self.accum_grad
        if attention_loss:
            losses_dict['att_loss'] = float(attention_loss)
        if ctc_loss:
            losses_dict['ctc_loss'] = float(ctc_loss)
        # report loss
        for k, v in losses_dict.items():
            report("train/" + k, v)
        # loss msg
        self.msg += "batch size: {}, ".format(batch_size)
        self.msg += "accum: {}, ".format(self.accum_grad)
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())

        # Drop the autograd graph: `detach()` returns a new tensor, so rebind
        # the name to release the graph held by `loss`.
        loss = loss.detach()

        # update parameters
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0

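        # Apply the accumulated gradients, clear them, and advance the
        # learning-rate schedule.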
        self.optimizer.step()
        self.optimizer.clear_grad()
        self.scheduler.step()

    def update(self):
        # The model is in train mode by default.

        # One training step: read a batch, then run forward/backward; the
        # optimizer only steps every `accum_grad` calls (see `update_core`).
        with Timer("data time cost:{}"):
            batch = self.read_batch()
        with Timer("step time cost:{}"):
            self.update_core(batch)

        # With accum_grad > 1, one iteration corresponds to one optimizer
        # update, i.e. only when `forward_count` has wrapped back to zero.
        # Ref.: https://github.com/espnet/espnet/issues/777
        if self.forward_count == 0:
            self.state.iteration += 1
        if self.updates_per_epoch is not None:
            if self.state.iteration % self.updates_per_epoch == 0:
                self.state.epoch += 1