# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Adversarial loss modules.""" import paddle import paddle.nn.functional as F from paddle import nn class GeneratorAdversarialLoss(nn.Layer): """Generator adversarial loss module.""" def __init__( self, average_by_discriminators=True, loss_type="mse", ): """Initialize GeneratorAversarialLoss module.""" super().__init__() self.average_by_discriminators = average_by_discriminators assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." if loss_type == "mse": self.criterion = self._mse_loss else: self.criterion = self._hinge_loss def forward(self, outputs): """Calcualate generator adversarial loss. Parameters ---------- outputs: Tensor or List Discriminator outputs or list of discriminator outputs. Returns ---------- Tensor Generator adversarial loss value. """ if isinstance(outputs, (tuple, list)): adv_loss = 0.0 for i, outputs_ in enumerate(outputs): if isinstance(outputs_, (tuple, list)): # case including feature maps outputs_ = outputs_[-1] adv_loss += self.criterion(outputs_) if self.average_by_discriminators: adv_loss /= i + 1 else: adv_loss = self.criterion(outputs) return adv_loss def _mse_loss(self, x): return F.mse_loss(x, paddle.ones_like(x)) def _hinge_loss(self, x): return -x.mean() class DiscriminatorAdversarialLoss(nn.Layer): """Discriminator adversarial loss module.""" def __init__( self, average_by_discriminators=True, loss_type="mse", ): """Initialize DiscriminatorAversarialLoss module.""" super().__init__() self.average_by_discriminators = average_by_discriminators assert loss_type in ["mse"], f"{loss_type} is not supported." if loss_type == "mse": self.fake_criterion = self._mse_fake_loss self.real_criterion = self._mse_real_loss def forward(self, outputs_hat, outputs): """Calcualate discriminator adversarial loss. Parameters ---------- outputs_hat : Tensor or list Discriminator outputs or list of discriminator outputs calculated from generator outputs. outputs : Tensor or list Discriminator outputs or list of discriminator outputs calculated from groundtruth. Returns ---------- Tensor Discriminator real loss value. Tensor Discriminator fake loss value. """ if isinstance(outputs, (tuple, list)): real_loss = 0.0 fake_loss = 0.0 for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): if isinstance(outputs_hat_, (tuple, list)): # case including feature maps outputs_hat_ = outputs_hat_[-1] outputs_ = outputs_[-1] real_loss += self.real_criterion(outputs_) fake_loss += self.fake_criterion(outputs_hat_) if self.average_by_discriminators: fake_loss /= i + 1 real_loss /= i + 1 else: real_loss = self.real_criterion(outputs) fake_loss = self.fake_criterion(outputs_hat) return real_loss, fake_loss def _mse_real_loss(self, x): return F.mse_loss(x, paddle.ones_like(x)) def _mse_fake_loss(self, x): return F.mse_loss(x, paddle.zeros_like(x))