# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math from typing import Callable from typing import List from typing import Optional from typing import Tuple from typing import Union import librosa import numpy as np import paddle from paddle import nn from paddle.nn import functional as F from scipy import signal from scipy.stats import betabinom from typeguard import typechecked from paddlespeech.audiotools.core.audio_signal import AudioSignal from paddlespeech.audiotools.core.audio_signal import STFTParams from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.predictor.duration_predictor import ( DurationPredictorLoss, # noqa: H301 ) # Losses for WaveRNN def log_sum_exp(x): """ numerically stable log_sum_exp implementation that prevents overflow """ # TF ordering axis = len(x.shape) - 1 m = paddle.max(x, axis=axis) m2 = paddle.max(x, axis=axis, keepdim=True) return m + paddle.log(paddle.sum(paddle.exp(x - m2), axis=axis)) # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): if log_scale_min is None: log_scale_min = float(np.log(1e-14)) y_hat = y_hat.transpose([0, 2, 1]) assert y_hat.dim() == 3 assert y_hat.shape[1] % 3 == 0 nr_mix = y_hat.shape[1] // 3 # (B x T x C) y_hat = y_hat.transpose([0, 2, 1]) # unpack parameters. (B, T, num_mixtures) x 3 logit_probs = y_hat[:, :, :nr_mix] means = y_hat[:, :, nr_mix:2 * nr_mix] log_scales = paddle.clip( y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) # B x T x 1 -> B x T x num_mixtures y = y.expand_as(means) centered_y = paddle.cast(y, dtype=paddle.get_default_dtype()) - means inv_stdv = paddle.exp(-log_scales) plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) cdf_plus = F.sigmoid(plus_in) min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) cdf_min = F.sigmoid(min_in) # log probability for edge case of 0 (before scaling) # equivalent: torch.log(F.sigmoid(plus_in)) # softplus: log(1+ e^{-x}) log_cdf_plus = plus_in - F.softplus(plus_in) # log probability for edge case of 255 (before scaling) # equivalent: (1 - F.sigmoid(min_in)).log() log_one_minus_cdf_min = -F.softplus(min_in) # probability for all other cases cdf_delta = cdf_plus - cdf_min mid_in = inv_stdv * centered_y # log probability in the center of the bin, to be used in extreme cases # (not actually used in our code) log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value # for num_classes=65536 case? 1e-7? not sure.. inner_inner_cond = cdf_delta > 1e-5 inner_inner_cond = paddle.cast( inner_inner_cond, dtype=paddle.get_default_dtype()) # inner_inner_out = inner_inner_cond * \ # paddle.log(paddle.clip(cdf_delta, min=1e-12)) + \ # (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) inner_inner_out = inner_inner_cond * paddle.log( paddle.clip(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * ( log_pdf_mid - np.log((num_classes - 1) / 2)) inner_cond = y > 0.999 inner_cond = paddle.cast(inner_cond, dtype=paddle.get_default_dtype()) inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond ) * inner_inner_out cond = y < -0.999 cond = paddle.cast(cond, dtype=paddle.get_default_dtype()) log_probs = cond * log_cdf_plus + (1. - cond) * inner_out log_probs = log_probs + F.log_softmax(logit_probs, -1) if reduce: return -paddle.mean(log_sum_exp(log_probs)) else: return -log_sum_exp(log_probs).unsqueeze(-1) def sample_from_discretized_mix_logistic(y, log_scale_min=None): """ Sample from discretized mixture of logistic distributions Args: y(Tensor): (B, C, T) log_scale_min(float, optional): (Default value = None) Returns: Tensor: sample in range of [-1, 1]. """ if log_scale_min is None: log_scale_min = float(np.log(1e-14)) assert y.shape[1] % 3 == 0 nr_mix = y.shape[1] // 3 # (B, T, C) y = y.transpose([0, 2, 1]) logit_probs = y[:, :, :nr_mix] # sample mixture indicator from softmax temp = paddle.uniform( logit_probs.shape, dtype=logit_probs.dtype, min=1e-5, max=1.0 - 1e-5) temp = logit_probs - paddle.log(-paddle.log(temp)) argmax = paddle.argmax(temp, axis=-1) # (B, T) -> (B, T, nr_mix) one_hot = F.one_hot(argmax, nr_mix) one_hot = paddle.cast(one_hot, dtype=paddle.get_default_dtype()) # select logistic parameters means = paddle.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) log_scales = paddle.clip( paddle.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), min=log_scale_min) # sample from logistic & clip to interval # we don't actually round to the nearest 8bit value when sampling u = paddle.uniform(means.shape, min=1e-5, max=1.0 - 1e-5) x = means + paddle.exp(log_scales) * (paddle.log(u) - paddle.log(1. - u)) x = paddle.clip(x, min=-1., max=-1.) return x # Loss for Tacotron2 class GuidedAttentionLoss(nn.Layer): """Guided attention loss function module. This module calculates the guided attention loss described in `Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`_, which forces the attention to be diagonal. .. _`Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention`: https://arxiv.org/abs/1710.08969 """ def __init__(self, sigma=0.4, alpha=1.0, reset_always=True): """Initialize guided attention loss module. Args: sigma (float, optional): Standard deviation to control how close attention to a diagonal. alpha (float, optional): Scaling coefficient (lambda). reset_always (bool, optional): Whether to always reset masks. """ super().__init__() self.sigma = sigma self.alpha = alpha self.reset_always = reset_always self.guided_attn_masks = None self.masks = None def _reset_masks(self): self.guided_attn_masks = None self.masks = None def forward(self, att_ws, ilens, olens): """Calculate forward propagation. Args: att_ws(Tensor): Batch of attention weights (B, T_max_out, T_max_in). ilens(Tensor(int64)): Batch of input lenghts (B,). olens(Tensor(int64)): Batch of output lenghts (B,). Returns: Tensor: Guided attention loss value. """ if self.guided_attn_masks is None: self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens) if self.masks is None: self.masks = self._make_masks(ilens, olens) losses = self.guided_attn_masks * att_ws loss = paddle.mean( losses.masked_select(self.masks.broadcast_to(losses.shape))) if self.reset_always: self._reset_masks() return self.alpha * loss def _make_guided_attention_masks(self, ilens, olens): n_batches = len(ilens) max_ilen = max(ilens) max_olen = max(olens) guided_attn_masks = paddle.zeros((n_batches, max_olen, max_ilen)) for idx, (ilen, olen) in enumerate(zip(ilens, olens)): guided_attn_masks[idx, :olen, : ilen] = self._make_guided_attention_mask( ilen, olen, self.sigma) return guided_attn_masks @staticmethod def _make_guided_attention_mask(ilen, olen, sigma): """Make guided attention mask. Examples ---------- >>> guided_attn_mask =_make_guided_attention(5, 5, 0.4) >>> guided_attn_mask.shape [5, 5] >>> guided_attn_mask tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647], [0.1175, 0.0000, 0.1175, 0.3935, 0.6753], [0.3935, 0.1175, 0.0000, 0.1175, 0.3935], [0.6753, 0.3935, 0.1175, 0.0000, 0.1175], [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]]) >>> guided_attn_mask =_make_guided_attention(3, 6, 0.4) >>> guided_attn_mask.shape [6, 3] >>> guided_attn_mask tensor([[0.0000, 0.2934, 0.7506], [0.0831, 0.0831, 0.5422], [0.2934, 0.0000, 0.2934], [0.5422, 0.0831, 0.0831], [0.7506, 0.2934, 0.0000], [0.8858, 0.5422, 0.0831]]) """ grid_x, grid_y = paddle.meshgrid( paddle.arange(olen), paddle.arange(ilen)) grid_x = grid_x.cast(dtype=paddle.float32) grid_y = grid_y.cast(dtype=paddle.float32) return 1.0 - paddle.exp(-( (grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2))) @staticmethod def _make_masks(ilens, olens): """Make masks indicating non-padded part. Args: ilens(Tensor(int64) or List): Batch of lengths (B,). olens(Tensor(int64) or List): Batch of lengths (B,). Returns: Tensor: Mask tensor indicating non-padded part. Examples: >>> ilens, olens = [5, 2], [8, 5] >>> _make_mask(ilens, olens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]], [[1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]], dtype=paddle.uint8) """ # (B, T_in) in_masks = make_non_pad_mask(ilens) # (B, T_out) out_masks = make_non_pad_mask(olens) # (B, T_out, T_in) return paddle.logical_and( out_masks.unsqueeze(-1), in_masks.unsqueeze(-2)) class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): """Guided attention loss function module for multi head attention. Args: sigma (float, optional): Standard deviation to controlGuidedAttentionLoss how close attention to a diagonal. alpha (float, optional): Scaling coefficient (lambda). reset_always (bool, optional): Whether to always reset masks. """ def forward(self, att_ws, ilens, olens): """Calculate forward propagation. Args: att_ws(Tensor): Batch of multi head attention weights (B, H, T_max_out, T_max_in). ilens(Tensor): Batch of input lenghts (B,). olens(Tensor): Batch of output lenghts (B,). Returns: Tensor: Guided attention loss value. """ if self.guided_attn_masks is None: self.guided_attn_masks = ( self._make_guided_attention_masks(ilens, olens).unsqueeze(1)) if self.masks is None: self.masks = self._make_masks(ilens, olens).unsqueeze(1) losses = self.guided_attn_masks * att_ws loss = paddle.mean( losses.masked_select(self.masks.broadcast_to(losses.shape))) if self.reset_always: self._reset_masks() return self.alpha * loss class Tacotron2Loss(nn.Layer): """Loss function module for Tacotron2.""" def __init__(self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0): """Initialize Tactoron2 loss module. Args: use_masking (bool): Whether to apply masking for padded part in loss calculation. use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. bce_pos_weight (float): Weight of positive sample of stop token. """ super().__init__() assert (use_masking != use_weighted_masking) or not use_masking self.use_masking = use_masking self.use_weighted_masking = use_weighted_masking # define criterions reduction = "none" if self.use_weighted_masking else "mean" self.l1_criterion = nn.L1Loss(reduction=reduction) self.mse_criterion = nn.MSELoss(reduction=reduction) self.bce_criterion = nn.BCEWithLogitsLoss( reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight)) def forward(self, after_outs, before_outs, logits, ys, stop_labels, olens): """Calculate forward propagation. Args: after_outs(Tensor): Batch of outputs after postnets (B, Lmax, odim). before_outs(Tensor): Batch of outputs before postnets (B, Lmax, odim). logits(Tensor): Batch of stop logits (B, Lmax). ys(Tensor): Batch of padded target features (B, Lmax, odim). stop_labels(Tensor(int64)): Batch of the sequences of stop token labels (B, Lmax). olens(Tensor(int64)): Returns: Tensor: L1 loss value. Tensor: Mean square error loss value. Tensor: Binary cross entropy loss value. """ # make mask and apply it if self.use_masking: masks = make_non_pad_mask(olens).unsqueeze(-1) ys = ys.masked_select(masks.broadcast_to(ys.shape)) after_outs = after_outs.masked_select( masks.broadcast_to(after_outs.shape)) before_outs = before_outs.masked_select( masks.broadcast_to(before_outs.shape)) stop_labels = stop_labels.masked_select( masks[:, :, 0].broadcast_to(stop_labels.shape)) logits = logits.masked_select( masks[:, :, 0].broadcast_to(logits.shape)) # calculate loss l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion( before_outs, ys) mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion( before_outs, ys) bce_loss = self.bce_criterion(logits, stop_labels) # make weighted mask and apply it if self.use_weighted_masking: masks = make_non_pad_mask(olens).unsqueeze(-1) weights = masks.float() / masks.sum(axis=1, keepdim=True).float() out_weights = weights.divide( paddle.shape(ys)[0] * paddle.shape(ys)[2]) logit_weights = weights.divide(paddle.shape(ys)[0]) # apply weight l1_loss = l1_loss.multiply(out_weights) l1_loss = l1_loss.masked_select(masks.broadcast_to(l1_loss)).sum() mse_loss = mse_loss.multiply(out_weights) mse_loss = mse_loss.masked_select( masks.broadcast_to(mse_loss)).sum() bce_loss = bce_loss.multiply(logit_weights.squeeze(-1)) bce_loss = bce_loss.masked_select( masks.squeeze(-1).broadcast_to(bce_loss)).sum() return l1_loss, mse_loss, bce_loss # Losses for GAN Vocoder def stft(x, fft_size, hop_length=None, win_length=None, window='hann', center=True, pad_mode='reflect'): """Perform STFT and convert to magnitude spectrogram. Args: x(Tensor): Input signal tensor (B, T). fft_size(int): FFT size. hop_size(int): Hop size. win_length(int, optional): window (str, optional): (Default value = None) window(str, optional): Name of window function, see `scipy.signal.get_window` for more details. Defaults to "hann". center(bool, optional, optional): center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of :math:`t`-th frame. Default: `True`. pad_mode(str, optional, optional): (Default value = 'reflect') hop_length: (Default value = None) Returns: Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). """ # calculate window window = signal.get_window(window, win_length, fftbins=True) window = paddle.to_tensor(window, dtype=x.dtype) x_stft = paddle.signal.stft( x, fft_size, hop_length, win_length, window=window, center=center, pad_mode=pad_mode) real = x_stft.real() imag = x_stft.imag() return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose( [0, 2, 1]) class SpectralConvergenceLoss(nn.Layer): """Spectral convergence loss module.""" def __init__(self): """Initilize spectral convergence loss module.""" super().__init__() def forward(self, x_mag, y_mag): """Calculate forward propagation. Args: x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). Returns: Tensor: Spectral convergence loss value. """ return paddle.norm( y_mag - x_mag, p="fro") / paddle.clip( paddle.norm(y_mag, p="fro"), min=1e-10) class LogSTFTMagnitudeLoss(nn.Layer): """Log STFT magnitude loss module.""" def __init__(self, epsilon=1e-7): """Initilize los STFT magnitude loss module.""" super().__init__() self.epsilon = epsilon def forward(self, x_mag, y_mag): """Calculate forward propagation. Args: x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). Returns: Tensor: Log STFT magnitude loss value. """ return F.l1_loss( paddle.log(paddle.clip(y_mag, min=self.epsilon)), paddle.log(paddle.clip(x_mag, min=self.epsilon))) class STFTLoss(nn.Layer): """STFT loss module.""" def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann"): """Initialize STFT loss module.""" super().__init__() self.fft_size = fft_size self.shift_size = shift_size self.win_length = win_length self.window = window self.spectral_convergence_loss = SpectralConvergenceLoss() self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() def forward(self, x, y): """Calculate forward propagation. Args: x (Tensor): Predicted signal (B, T). y (Tensor): Groundtruth signal (B, T). Returns: Tensor: Spectral convergence loss value. Tensor: Log STFT magnitude loss value. """ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) sc_loss = self.spectral_convergence_loss(x_mag, y_mag) mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) return sc_loss, mag_loss class MultiResolutionSTFTLoss(nn.Layer): """Multi resolution STFT loss module.""" def __init__( self, fft_sizes=[1024, 2048, 512], hop_sizes=[120, 240, 50], win_lengths=[600, 1200, 240], window="hann", ): """Initialize Multi resolution STFT loss module. Args: fft_sizes (list): List of FFT sizes. hop_sizes (list): List of hop sizes. win_lengths (list): List of window lengths. window (str): Window function type. """ super().__init__() assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) self.stft_losses = nn.LayerList() for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): self.stft_losses.append(STFTLoss(fs, ss, wl, window)) def forward(self, x, y): """Calculate forward propagation. Args: x (Tensor): Predicted signal (B, T) or (B, #subband, T). y (Tensor): Groundtruth signal (B, T) or (B, #subband, T). Returns: Tensor: Multi resolution spectral convergence loss value. Tensor: Multi resolution log STFT magnitude loss value. """ if len(x.shape) == 3: # (B, C, T) -> (B x C, T) x = x.reshape([-1, x.shape[2]]) # (B, C, T) -> (B x C, T) y = y.reshape([-1, y.shape[2]]) sc_loss = 0.0 mag_loss = 0.0 for f in self.stft_losses: sc_l, mag_l = f(x, y) sc_loss += sc_l mag_loss += mag_l sc_loss /= len(self.stft_losses) mag_loss /= len(self.stft_losses) return sc_loss, mag_loss class GeneratorAdversarialLoss(nn.Layer): """Generator adversarial loss module.""" def __init__( self, average_by_discriminators=True, loss_type="mse", ): """Initialize GeneratorAversarialLoss module.""" super().__init__() self.average_by_discriminators = average_by_discriminators assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported." if loss_type == "mse": self.criterion = self._mse_loss else: self.criterion = self._hinge_loss def forward(self, outputs): """Calcualate generator adversarial loss. Args: outputs (Tensor or List): Discriminator outputs or list of discriminator outputs. Returns: Tensor: Generator adversarial loss value. """ if isinstance(outputs, (tuple, list)): adv_loss = 0.0 for i, outputs_ in enumerate(outputs): if isinstance(outputs_, (tuple, list)): # case including feature maps outputs_ = outputs_[-1] adv_loss += self.criterion(outputs_) if self.average_by_discriminators: adv_loss /= i + 1 else: adv_loss = self.criterion(outputs) return adv_loss def _mse_loss(self, x): return F.mse_loss(x, paddle.ones_like(x)) def _hinge_loss(self, x): return -x.mean() class DiscriminatorAdversarialLoss(nn.Layer): """Discriminator adversarial loss module.""" def __init__( self, average_by_discriminators=True, loss_type="mse", ): """Initialize DiscriminatorAversarialLoss module.""" super().__init__() self.average_by_discriminators = average_by_discriminators assert loss_type in ["mse"], f"{loss_type} is not supported." if loss_type == "mse": self.fake_criterion = self._mse_fake_loss self.real_criterion = self._mse_real_loss def forward(self, outputs_hat, outputs): """Calcualate discriminator adversarial loss. Args: outputs_hat (Tensor or list): Discriminator outputs or list of discriminator outputs calculated from generator outputs. outputs (Tensor or list): Discriminator outputs or list of discriminator outputs calculated from groundtruth. Returns: Tensor: Discriminator real loss value. Tensor: Discriminator fake loss value. """ if isinstance(outputs, (tuple, list)): real_loss = 0.0 fake_loss = 0.0 for i, (outputs_hat_, outputs_) in enumerate(zip(outputs_hat, outputs)): if isinstance(outputs_hat_, (tuple, list)): # case including feature maps outputs_hat_ = outputs_hat_[-1] outputs_ = outputs_[-1] real_loss += self.real_criterion(outputs_) fake_loss += self.fake_criterion(outputs_hat_) if self.average_by_discriminators: fake_loss /= i + 1 real_loss /= i + 1 else: real_loss = self.real_criterion(outputs) fake_loss = self.fake_criterion(outputs_hat) return real_loss, fake_loss def _mse_real_loss(self, x): return F.mse_loss(x, paddle.ones_like(x)) def _mse_fake_loss(self, x): return F.mse_loss(x, paddle.zeros_like(x)) # Losses for SpeedySpeech # Structural Similarity Index Measure (SSIM) def gaussian(window_size, sigma): gauss = paddle.to_tensor([ math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) for x in range(window_size) ]) return gauss / gauss.sum() def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = paddle.matmul(_1D_window, paddle.transpose( _1D_window, [1, 0])).unsqueeze([0, 1]) window = paddle.expand(_2D_window, [channel, 1, window_size, window_size]) return window def _ssim(img1, img2, window, window_size, channel, size_average=True): mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 sigma1_sq = F.conv2d( img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq sigma2_sq = F.conv2d( img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq sigma12 = F.conv2d( img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) \ / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) def ssim(img1, img2, window_size=11, size_average=True): (_, channel, _, _) = img1.shape window = create_window(window_size, channel) return _ssim(img1, img2, window, window_size, channel, size_average) def weighted_mean(input, weight): """Weighted mean. It can also be used as masked mean. Args: input(Tensor): The input tensor. weight(Tensor): The weight tensor with broadcastable shape with the input. Returns: Tensor: Weighted mean tensor with the same dtype as input. shape=(1,) """ weight = paddle.cast(weight, input.dtype) # paddle.Tensor.size is different with torch.size() and has been overrided in s2t.__init__ broadcast_ratio = input.numel() / weight.numel() return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_ratio) def masked_l1_loss(prediction, target, mask): """Compute maksed L1 loss. Args: prediction(Tensor): The prediction. target(Tensor): The target. The shape should be broadcastable to ``prediction``. mask(Tensor): The mask. The shape should be broadcatable to the broadcasted shape of ``prediction`` and ``target``. Returns: Tensor: The masked L1 loss. shape=(1,) """ abs_error = F.l1_loss(prediction, target, reduction='none') loss = weighted_mean(abs_error, mask) return loss class MelSpectrogram(nn.Layer): """Calculate Mel-spectrogram.""" def __init__( self, fs=22050, fft_size=1024, hop_size=256, win_length=None, window="hann", num_mels=80, fmin=80, fmax=7600, center=True, normalized=False, onesided=True, eps=1e-10, log_base=10.0, ): """Initialize MelSpectrogram module.""" super().__init__() self.fft_size = fft_size if win_length is None: self.win_length = fft_size else: self.win_length = win_length self.hop_size = hop_size self.center = center self.normalized = normalized self.onesided = onesided if window is not None and not hasattr(signal.windows, f"{window}"): raise ValueError(f"{window} window is not implemented") self.window = window self.eps = eps fmin = 0 if fmin is None else fmin fmax = fs / 2 if fmax is None else fmax melmat = librosa.filters.mel( sr=fs, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax, ) self.melmat = paddle.to_tensor(melmat.T) self.stft_params = { "n_fft": self.fft_size, "win_length": self.win_length, "hop_length": self.hop_size, "center": self.center, "normalized": self.normalized, "onesided": self.onesided, } self.log_base = log_base if self.log_base is None: self.log = paddle.log elif self.log_base == 2.0: self.log = paddle.log2 elif self.log_base == 10.0: self.log = paddle.log10 else: raise ValueError(f"log_base: {log_base} is not supported.") def forward(self, x): """Calculate Mel-spectrogram. Args: x (Tensor): Input waveform tensor (B, T) or (B, 1, T). Returns: Tensor: Mel-spectrogram (B, #mels, #frames). """ if len(x.shape) == 3: # (B, C, T) -> (B*C, T) x = x.reshape([-1, paddle.shape(x)[2]]) if self.window is not None: # calculate window window = signal.get_window( self.window, self.win_length, fftbins=True) window = paddle.to_tensor(window, dtype=x.dtype) else: window = None x_stft = paddle.signal.stft(x, window=window, **self.stft_params) real = x_stft.real() imag = x_stft.imag() # (B, #freqs, #frames) -> (B, $frames, #freqs) real = real.transpose([0, 2, 1]) imag = imag.transpose([0, 2, 1]) x_power = real**2 + imag**2 x_amp = paddle.sqrt(paddle.clip(x_power, min=self.eps)) x_mel = paddle.matmul(x_amp, self.melmat) x_mel = paddle.clip(x_mel, min=self.eps) return self.log(x_mel).transpose([0, 2, 1]) class MelSpectrogramLoss(nn.Layer): """Mel-spectrogram loss.""" def __init__( self, fs=22050, fft_size=1024, hop_size=256, win_length=None, window="hann", num_mels=80, fmin=80, fmax=7600, center=True, normalized=False, onesided=True, eps=1e-10, log_base=10.0, ): """Initialize Mel-spectrogram loss.""" super().__init__() self.mel_spectrogram = MelSpectrogram( fs=fs, fft_size=fft_size, hop_size=hop_size, win_length=win_length, window=window, num_mels=num_mels, fmin=fmin, fmax=fmax, center=center, normalized=normalized, onesided=onesided, eps=eps, log_base=log_base, ) def forward(self, y_hat, y): """Calculate Mel-spectrogram loss. Args: y_hat(Tensor): Generated single tensor (B, 1, T). y(Tensor): Groundtruth single tensor (B, 1, T). Returns: Tensor: Mel-spectrogram loss value. """ mel_hat = self.mel_spectrogram(y_hat) mel = self.mel_spectrogram(y) mel_loss = F.l1_loss(mel_hat, mel) return mel_loss class FeatureMatchLoss(nn.Layer): """Feature matching loss module.""" def __init__( self, average_by_layers=True, average_by_discriminators=True, include_final_outputs=False, ): """Initialize FeatureMatchLoss module.""" super().__init__() self.average_by_layers = average_by_layers self.average_by_discriminators = average_by_discriminators self.include_final_outputs = include_final_outputs def forward(self, feats_hat, feats): """Calcualate feature matching loss. Args: feats_hat(list): List of list of discriminator outputs calcuated from generater outputs. feats(list): List of list of discriminator outputs Returns: Tensor: Feature matching loss value. """ feat_match_loss = 0.0 for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)): feat_match_loss_ = 0.0 if not self.include_final_outputs: feats_hat_ = feats_hat_[:-1] feats_ = feats_[:-1] for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)): feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach()) if self.average_by_layers: feat_match_loss_ /= j + 1 feat_match_loss += feat_match_loss_ if self.average_by_discriminators: feat_match_loss /= i + 1 return feat_match_loss # loss for VITS class KLDivergenceLoss(nn.Layer): """KL divergence loss.""" def forward( self, z_p: paddle.Tensor, logs_q: paddle.Tensor, m_p: paddle.Tensor, logs_p: paddle.Tensor, z_mask: paddle.Tensor, ) -> paddle.Tensor: """Calculate KL divergence loss. Args: z_p (Tensor): Flow hidden representation (B, H, T_feats). logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats). m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats). logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats). z_mask (Tensor): Mask tensor (B, 1, T_feats). Returns: Tensor: KL divergence loss. """ z_p = paddle.cast(z_p, 'float32') logs_q = paddle.cast(logs_q, 'float32') m_p = paddle.cast(m_p, 'float32') logs_p = paddle.cast(logs_p, 'float32') z_mask = paddle.cast(z_mask, 'float32') kl = logs_p - logs_q - 0.5 kl += 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p) kl = paddle.sum(kl * z_mask) loss = kl / paddle.sum(z_mask) return loss # loss for ERNIE SAT class MLMLoss(nn.Layer): def __init__(self, odim: int, vocab_size: int=0, lsm_weight: float=0.1, ignore_id: int=-1, text_masking: bool=False): super().__init__() if text_masking: self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id) if lsm_weight > 50: self.l1_loss_func = nn.MSELoss() else: self.l1_loss_func = nn.L1Loss(reduction='none') self.text_masking = text_masking self.odim = odim self.vocab_size = vocab_size def forward( self, speech: paddle.Tensor, before_outs: paddle.Tensor, after_outs: paddle.Tensor, masked_pos: paddle.Tensor, # for text_loss when text_masking == True text: paddle.Tensor=None, text_outs: paddle.Tensor=None, text_masked_pos: paddle.Tensor=None): xs_pad = speech mlm_loss_pos = masked_pos > 0 loss = paddle.sum( self.l1_loss_func( paddle.reshape(before_outs, (-1, self.odim)), paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) if after_outs is not None: loss += paddle.sum( self.l1_loss_func( paddle.reshape(after_outs, (-1, self.odim)), paddle.reshape(xs_pad, (-1, self.odim))), axis=-1) mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype) mlm_loss = paddle.sum((loss * paddle.reshape( mlm_loss_pos, [-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10) text_mlm_loss = None if self.text_masking: assert text is not None assert text_outs is not None assert text_masked_pos is not None text_outs = paddle.reshape(text_outs, [-1, self.vocab_size]) text = paddle.reshape(text, [-1]) text_mlm_loss = self.text_mlm_loss(text_outs, text) text_masked_pos_reshape = paddle.reshape(text_masked_pos, [-1]) text_mlm_loss = paddle.sum( text_mlm_loss * text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10) return mlm_loss, text_mlm_loss class VarianceLoss(nn.Layer): @typechecked def __init__(self, use_masking: bool=True, use_weighted_masking: bool=False): """Initialize JETS variance loss module. Args: use_masking (bool): Whether to apply masking for padded part in loss calculation. use_weighted_masking (bool): Whether to weighted masking in loss calculation. """ super().__init__() assert (use_masking != use_weighted_masking) or not use_masking self.use_masking = use_masking self.use_weighted_masking = use_weighted_masking # define criterions reduction = "none" if self.use_weighted_masking else "mean" self.mse_criterion = nn.MSELoss(reduction=reduction) self.duration_criterion = DurationPredictorLoss(reduction=reduction) def forward( self, d_outs: paddle.Tensor, ds: paddle.Tensor, p_outs: paddle.Tensor, ps: paddle.Tensor, e_outs: paddle.Tensor, es: paddle.Tensor, ilens: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: """Calculate forward propagation. Args: d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text). ds (LongTensor): Batch of durations (B, T_text). p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1). ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1). e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1). es (Tensor): Batch of target token-averaged energy (B, T_text, 1). ilens (LongTensor): Batch of the lengths of each input (B,). Returns: Tensor: Duration predictor loss value. Tensor: Pitch predictor loss value. Tensor: Energy predictor loss value. """ # apply mask to remove padded part if self.use_masking: duration_masks = paddle.to_tensor( make_non_pad_mask(ilens), place=ds.place) d_outs = d_outs.masked_select(duration_masks) ds = ds.masked_select(duration_masks) pitch_masks = paddle.to_tensor( make_non_pad_mask(ilens).unsqueeze(-1), place=ds.place) p_outs = p_outs.masked_select(pitch_masks) e_outs = e_outs.masked_select(pitch_masks) ps = ps.masked_select(pitch_masks) es = es.masked_select(pitch_masks) # calculate loss duration_loss = self.duration_criterion(d_outs, ds) pitch_loss = self.mse_criterion(p_outs, ps) energy_loss = self.mse_criterion(e_outs, es) # make weighted mask and apply it if self.use_weighted_masking: duration_masks = paddle.to_tensor( make_non_pad_mask(ilens), place=ds.place) duration_weights = (duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()) duration_weights /= ds.size(0) # apply weight duration_loss = (duration_loss.mul(duration_weights).masked_select( duration_masks).sum()) pitch_masks = duration_masks.unsqueeze(-1) pitch_weights = duration_weights.unsqueeze(-1) pitch_loss = pitch_loss.mul(pitch_weights).masked_select( pitch_masks).sum() energy_loss = ( energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum()) return duration_loss, pitch_loss, energy_loss class ForwardSumLoss(nn.Layer): """ https://openreview.net/forum?id=0NQwnnwAORi """ def __init__(self, cache_prior: bool=True): """ Args: cache_prior (bool): Whether to cache beta-binomial prior """ super().__init__() self.cache_prior = cache_prior self._cache = {} def forward( self, log_p_attn: paddle.Tensor, ilens: paddle.Tensor, olens: paddle.Tensor, blank_prob: float=np.e**-1, ) -> paddle.Tensor: """ Args: log_p_attn (Tensor): Batch of log probability of attention matrix (B, T_feats, T_text). ilens (Tensor): Batch of the lengths of each input (B,). olens (Tensor): Batch of the lengths of each target (B,). blank_prob (float): Blank symbol probability Returns: Tensor: forwardsum loss value. """ B = log_p_attn.shape[0] # add beta-binomial prior bb_prior = self._generate_prior(ilens, olens) bb_prior = paddle.to_tensor( bb_prior, dtype=log_p_attn.dtype, place=log_p_attn.place) log_p_attn = log_p_attn + bb_prior # a row must be added to the attention matrix to account for blank token of CTC loss # (B,T_feats,T_text+1) log_p_attn_pd = F.pad( log_p_attn, (0, 0, 0, 0, 1, 0), value=np.log(blank_prob)) loss = 0 for bidx in range(B): # construct target sequnece. # Every text token is mapped to a unique sequnece number. target_seq = paddle.arange( 1, ilens[bidx] + 1, dtype="int32").unsqueeze(0) cur_log_p_attn_pd = log_p_attn_pd[bidx, :olens[bidx], :ilens[ bidx] + 1].unsqueeze(1) # (T_feats,1,T_text+1) # The input of ctc_loss API need to be fixed loss += F.ctc_loss( log_probs=cur_log_p_attn_pd, labels=target_seq, input_lengths=olens[bidx:bidx + 1], label_lengths=ilens[bidx:bidx + 1]) loss = loss / B return loss def _generate_prior(self, text_lengths, feats_lengths, w=1) -> paddle.Tensor: """Generate alignment prior formulated as beta-binomial distribution Args: text_lengths (Tensor): Batch of the lengths of each input (B,). feats_lengths (Tensor): Batch of the lengths of each target (B,). w (float): Scaling factor; lower -> wider the width Returns: Tensor: Batched 2d static prior matrix (B, T_feats, T_text) """ B = len(text_lengths) T_text = text_lengths.max() T_feats = feats_lengths.max() bb_prior = paddle.full((B, T_feats, T_text), fill_value=-np.inf) for bidx in range(B): T = feats_lengths[bidx].item() N = text_lengths[bidx].item() key = str(T) + ',' + str(N) if self.cache_prior and key in self._cache: prob = self._cache[key] else: alpha = w * np.arange(1, T + 1, dtype=float) # (T,) beta = w * np.array([T - t + 1 for t in alpha]) k = np.arange(N) batched_k = k[..., None] # (N,1) prob = betabinom.pmf(batched_k, N, alpha, beta) # (N,T) # store cache if self.cache_prior and key not in self._cache: self._cache[key] = prob prob = paddle.to_tensor( prob, place=text_lengths.place, dtype="float32").transpose( (1, 0)) # -> (T,N) bb_prior[bidx, :T, :N] = prob return bb_prior class MultiScaleSTFTLoss(nn.Layer): """Computes the multi-scale STFT loss from [1]. References ---------- 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. "DDSP: Differentiable Digital Signal Processing." International Conference on Learning Representations. 2019. Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py """ def __init__( self, window_lengths: List[int]=[2048, 512], loss_fn: Callable=nn.L1Loss(), clamp_eps: float=1e-5, mag_weight: float=1.0, log_weight: float=1.0, pow: float=2.0, weight: float=1.0, match_stride: bool=False, window_type: Optional[str]=None, ): """ Args: window_lengths : List[int], optional Length of each window of each STFT, by default [2048, 512] loss_fn : typing.Callable, optional How to compare each loss, by default nn.L1Loss() clamp_eps : float, optional Clamp on the log magnitude, below, by default 1e-5 mag_weight : float, optional Weight of raw magnitude portion of loss, by default 1.0 log_weight : float, optional Weight of log magnitude portion of loss, by default 1.0 pow : float, optional Power to raise magnitude to before taking log, by default 2.0 weight : float, optional Weight of this loss, by default 1.0 match_stride : bool, optional Whether to match the stride of convolutional layers, by default False window_type : str, optional Type of window to use, by default None. """ super().__init__() self.stft_params = [ STFTParams( window_length=w, hop_length=w // 4, match_stride=match_stride, window_type=window_type, ) for w in window_lengths ] self.loss_fn = loss_fn self.log_weight = log_weight self.mag_weight = mag_weight self.clamp_eps = clamp_eps self.weight = weight self.pow = pow def forward(self, x: AudioSignal, y: AudioSignal): """Computes multi-scale STFT between an estimate and a reference signal. Args: x : AudioSignal Estimate signal y : AudioSignal Reference signal Returns: paddle.Tensor Multi-scale STFT loss. Example: >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal >>> import paddle >>> x = AudioSignal("https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav", 2_05) >>> y = x * 0.01 >>> loss = MultiScaleSTFTLoss() >>> loss(x, y).numpy() 7.562150 """ for s in self.stft_params: x.stft(s.window_length, s.hop_length, s.window_type) y.stft(s.window_length, s.hop_length, s.window_type) loss += self.log_weight * self.loss_fn( x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), ) loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) return loss class GANLoss(nn.Layer): """ Computes a discriminator loss, given a discriminator on generated waveforms/spectrograms compared to ground truth waveforms/spectrograms. Computes the loss for both the discriminator and the generator in separate functions. Example: >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal >>> import paddle >>> x = AudioSignal("https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav", 2_05) >>> y = x * 0.01 >>> class My_discriminator0: >>> def __call__(self, x): >>> return x.sum() >>> loss = GANLoss(My_discriminator0()) >>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()] [-0.102722, -0.001027] >>> class My_discriminator1: >>> def __call__(self, x): >>> return x.sum() >>> loss = GANLoss(My_discriminator1()) >>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()] [1.00019, 0] >>> loss.discriminator_loss(x, y) 1.000200 """ def __init__(self, discriminator): """ Args: discriminator : paddle.nn.layer Discriminator model """ super().__init__() self.discriminator = discriminator def forward(self, fake: Union[AudioSignal, paddle.Tensor], real: Union[AudioSignal, paddle.Tensor]): if isinstance(fake, AudioSignal): d_fake = self.discriminator(fake.audio_data) else: d_fake = self.discriminator(fake) if isinstance(real, AudioSignal): d_real = self.discriminator(real.audio_data) else: d_real = self.discriminator(real) return d_fake, d_real def discriminator_loss(self, fake, real): d_fake, d_real = self.forward(fake, real) loss_d = 0 for x_fake, x_real in zip(d_fake, d_real): loss_d += paddle.mean(x_fake[-1]**2) loss_d += paddle.mean((1 - x_real[-1])**2) return loss_d def generator_loss(self, fake, real): d_fake, d_real = self.forward(fake, real) loss_g = 0 for x_fake in d_fake: loss_g += paddle.mean((1 - x_fake[-1])**2) loss_feature = 0 for i in range(len(d_fake)): for j in range(len(d_fake[i]) - 1): loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j]()) return loss_g, loss_feature class SISDRLoss(nn.Layer): """ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch of estimated and reference audio signals or aligned features. Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py Example: >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal >>> import paddle >>> x = AudioSignal("https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav", 2_05) >>> y = x * 0.01 >>> sisdr = SISDRLoss() >>> sisdr(x, y).numpy() -145.377640 """ def __init__( self, scaling: bool=True, reduction: str="mean", zero_mean: bool=True, clip_min: Optional[int]=None, weight: float=1.0, ): """ Args: scaling : bool, optional Whether to use scale-invariant (True) or signal-to-noise ratio (False), by default True reduction : str, optional How to reduce across the batch (either 'mean', 'sum', or none).], by default ' mean' zero_mean : bool, optional Zero mean the references and estimates before computing the loss, by default True clip_min : int, optional The minimum possible loss value. Helps network to not focus on making already good examples better, by default None weight : float, optional Weight of this loss, defaults to 1.0. """ self.scaling = scaling self.reduction = reduction self.zero_mean = zero_mean self.clip_min = clip_min self.weight = weight super().__init__() def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): eps = 1e-8 # B, C, T if isinstance(x, AudioSignal): references = x.audio_data estimates = y.audio_data else: references = x estimates = y nb = references.shape[0] references = references.reshape([nb, 1, -1]).transpose([0, 2, 1]) estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1]) # samples now on axis 1 if self.zero_mean: mean_reference = references.mean(axis=1, keepdim=True) mean_estimate = estimates.mean(axis=1, keepdim=True) else: mean_reference = 0 mean_estimate = 0 _references = references - mean_reference _estimates = estimates - mean_estimate references_projection = (_references**2).sum(axis=-2) + eps references_on_estimates = (_estimates * _references).sum(axis=-2) + eps scale = ( (references_on_estimates / references_projection).unsqueeze(axis=1) if self.scaling else 1) e_true = scale * _references e_res = _estimates - e_true signal = (e_true**2).sum(axis=1) noise = (e_res**2).sum(axis=1) sdr = -10 * paddle.log10(signal / noise + eps) if self.clip_min != None: sdr = paddle.clip(sdr, min=self.clip_min) if self.reduction == "mean": sdr = sdr.mean() elif self.reduction == "sum": sdr = sdr.sum() return sdr