You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/t2s/modules/losses.py

1062 lines
36 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import librosa
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from scipy import signal
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
# Losses for WaveRNN
def log_sum_exp(x):
    """Compute ``log(sum(exp(x)))`` over the last axis without overflowing.

    The per-row maximum is subtracted before exponentiating and added back
    afterwards, so the largest exponent evaluated is ``exp(0)``.

    Args:
        x(Tensor): Input tensor; the reduction runs over its last axis.
    Returns:
        Tensor: Result with the last axis removed.
    """
    last_axis = len(x.shape) - 1
    # Keep-dim copy broadcasts against ``x``; the squeezed copy matches the output.
    peak_keepdim = paddle.max(x, axis=last_axis, keepdim=True)
    peak = paddle.max(x, axis=last_axis)
    shifted_sum = paddle.sum(paddle.exp(x - peak_keepdim), axis=last_axis)
    return peak + paddle.log(shifted_sum)
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat,
                                  y,
                                  num_classes=65536,
                                  log_scale_min=None,
                                  reduce=True):
    """Negative log-likelihood of ``y`` under a discretized mixture of logistics.

    Args:
        y_hat(Tensor): Mixture parameters. The last axis is sliced into three
            equal groups (mixture logits, means, log-scales), so it must be a
            3-D tensor whose last axis is divisible by 3.
        y(Tensor): Target, expanded to the shape of the means. The edge-case
            thresholds (+-0.999) suggest values in [-1, 1] -- TODO confirm.
        num_classes(int, optional): Number of quantization levels; sets the bin half-width.
        log_scale_min(float, optional): Lower clip for the log-scales.
            Defaults to ``log(1e-14)`` when ``None``.
        reduce(bool, optional): Return the mean loss if ``True``, otherwise the
            per-position loss with a trailing singleton axis.
    Returns:
        Tensor: Negative log-likelihood.
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    float_dtype = paddle.get_default_dtype()

    # Shape checks are done on the axis-swapped view, then the layout is restored.
    swapped = y_hat.transpose([0, 2, 1])
    assert swapped.dim() == 3
    assert swapped.shape[1] % 3 == 0
    nr_mix = swapped.shape[1] // 3
    y_hat = swapped.transpose([0, 2, 1])

    # Unpack the three parameter groups along the last axis.
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = paddle.clip(
        y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

    # Broadcast the target across the mixture components.
    y = y.expand_as(means)
    centered_y = paddle.cast(y, dtype=float_dtype) - means
    inv_stdv = paddle.exp(-log_scales)

    # Logistic CDF arguments at the upper edge, lower edge and centre of the bin.
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    mid_in = inv_stdv * centered_y

    # Probability mass of an interior bin.
    cdf_delta = F.sigmoid(plus_in) - F.sigmoid(min_in)
    # log(sigmoid(plus_in)): used for the lowest bin.
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log(1 - sigmoid(min_in)): used for the highest bin.
    log_one_minus_cdf_min = -F.softplus(min_in)
    # Log-density at the bin centre: fallback when the bin mass is too small.
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    def _as_float(flags):
        return paddle.cast(flags, dtype=float_dtype)

    # TODO (kept from the original): 1e-5 may not suit num_classes=65536.
    mass_ok = _as_float(cdf_delta > 1e-5)
    at_upper = _as_float(y > 0.999)
    at_lower = _as_float(y < -0.999)

    # Arithmetic blending of the cases (instead of branching).
    interior = mass_ok * paddle.log(paddle.clip(cdf_delta, min=1e-12)) + (
        1. - mass_ok) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    upper_or_interior = at_upper * log_one_minus_cdf_min + (
        1. - at_upper) * interior
    log_probs = at_lower * log_cdf_plus + (1. - at_lower) * upper_or_interior

    # Add the log mixture weights, then marginalise over the components.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)
    mixture_ll = log_sum_exp(log_probs)
    if reduce:
        return -paddle.mean(mixture_ll)
    return -mixture_ll.unsqueeze(-1)
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """Sample from a discretized mixture of logistic distributions.

    Args:
        y(Tensor): Mixture parameters (B, C, T), where C is three times the
            number of mixtures (logits, means, log-scales).
        log_scale_min(float, optional): Lower clip for the log-scales.
            Defaults to ``log(1e-14)`` when ``None``.
    Returns:
        Tensor: Sample (B, T) clipped to the range [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.shape[1] % 3 == 0
    nr_mix = y.shape[1] // 3

    # (B, C, T) -> (B, T, C)
    y = y.transpose([0, 2, 1])
    logit_probs = y[:, :, :nr_mix]

    # Pick one mixture component per position via the Gumbel-max trick.
    noise = paddle.uniform(
        logit_probs.shape, dtype=logit_probs.dtype, min=1e-5, max=1.0 - 1e-5)
    perturbed = logit_probs - paddle.log(-paddle.log(noise))
    argmax = paddle.argmax(perturbed, axis=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix)
    one_hot = paddle.cast(one_hot, dtype=paddle.get_default_dtype())

    # Select the parameters of the chosen component.
    means = paddle.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
    log_scales = paddle.clip(
        paddle.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1),
        min=log_scale_min)

    # Inverse-CDF sampling from the logistic distribution; the result is not
    # rounded to a quantization level.
    u = paddle.uniform(means.shape, min=1e-5, max=1.0 - 1e-5)
    x = means + paddle.exp(log_scales) * (paddle.log(u) - paddle.log(1. - u))

    # Fix: the upper bound was -1., which forced every sample to exactly -1.
    x = paddle.clip(x, min=-1., max=1.)
    return x
# Loss for Tacotron2
class GuidedAttentionLoss(nn.Layer):
    """Guided attention loss.

    Implements the loss from `Efficiently Trainable Text-to-Speech System Based
    on Deep Convolutional Networks with Guided Attention`_, which penalises
    attention weights that lie far from the diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969
    """

    def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
        """Set up the loss.

        Args:
            sigma (float, optional): Standard deviation controlling how tightly attention
                is pulled towards the diagonal.
            alpha (float, optional): Scaling coefficient (lambda) applied to the loss.
            reset_always (bool, optional): Whether to discard the cached masks after every call.
        """
        super().__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reset_always = reset_always
        # Lazily built in ``forward`` and reused until reset.
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        """Drop both cached masks so the next call rebuilds them."""
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """Compute the guided attention loss.

        Args:
            att_ws(Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens(Tensor(int64)): Batch of input lengths (B,).
            olens(Tensor(int64)): Batch of output lengths (B,).
        Returns:
            Tensor: Guided attention loss value.
        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = self._make_guided_attention_masks(
                ilens, olens)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens)

        penalised = self.guided_attn_masks * att_ws
        # Average over the non-padded positions only.
        valid = self.masks.broadcast_to(penalised.shape)
        loss = paddle.mean(penalised.masked_select(valid))

        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Stack the per-utterance penalty masks into one zero-padded batch tensor."""
        n_batches = len(ilens)
        max_ilen = max(ilens)
        max_olen = max(olens)
        batch_masks = paddle.zeros((n_batches, max_olen, max_ilen))
        for b, (ilen, olen) in enumerate(zip(ilens, olens)):
            batch_masks[b, :olen, :ilen] = self._make_guided_attention_mask(
                ilen, olen, self.sigma)
        return batch_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Build the penalty mask for one utterance.

        The result has shape (olen, ilen): zero on the normalised diagonal and
        approaching one away from it.

        Examples
        ----------
        >>> guided_attn_mask = _make_guided_attention_mask(5, 5, 0.4)
        >>> guided_attn_mask.shape
        [5, 5]
        >>> guided_attn_mask
        tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
                [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
                [0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
                [0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
                [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
        >>> guided_attn_mask = _make_guided_attention_mask(3, 6, 0.4)
        >>> guided_attn_mask.shape
        [6, 3]
        >>> guided_attn_mask
        tensor([[0.0000, 0.2934, 0.7506],
                [0.0831, 0.0831, 0.5422],
                [0.2934, 0.0000, 0.2934],
                [0.5422, 0.0831, 0.0831],
                [0.7506, 0.2934, 0.0000],
                [0.8858, 0.5422, 0.0831]])
        """
        out_idx, in_idx = paddle.meshgrid(
            paddle.arange(olen), paddle.arange(ilen))
        out_idx = out_idx.cast(dtype=paddle.float32)
        in_idx = in_idx.cast(dtype=paddle.float32)
        # Squared distance between normalised input and output positions.
        sq_dist = (in_idx / ilen - out_idx / olen)**2
        return 1.0 - paddle.exp(-(sq_dist) / (2 * (sigma**2)))

    @staticmethod
    def _make_masks(ilens, olens):
        """Build the mask marking non-padded (output, input) positions.

        Args:
            ilens(Tensor(int64) or List): Batch of input lengths (B,).
            olens(Tensor(int64) or List): Batch of output lengths (B,).
        Returns:
            Tensor: Mask (B, T_out, T_in), true where both positions are valid.
        Examples:
            >>> ilens, olens = [5, 2], [8, 5]
            >>> _make_masks(ilens, olens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],
                    [[1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]])
        """
        # (B, T_in) and (B, T_out)
        in_masks = make_non_pad_mask(ilens)
        out_masks = make_non_pad_mask(olens)
        # Outer "and" of the two masks -> (B, T_out, T_in)
        return paddle.logical_and(
            out_masks.unsqueeze(-1), in_masks.unsqueeze(-2))
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """Guided attention loss for multi-head attention.

    Args:
        sigma (float, optional): Standard deviation controlling how close attention
            stays to the diagonal.
        alpha (float, optional): Scaling coefficient (lambda).
        reset_always (bool, optional): Whether to always reset masks.
    """

    def forward(self, att_ws, ilens, olens):
        """Compute the guided attention loss over all heads.

        Args:
            att_ws(Tensor): Batch of multi head attention weights (B, H, T_max_out, T_max_in).
            ilens(Tensor): Batch of input lengths (B,).
            olens(Tensor): Batch of output lengths (B,).
        Returns:
            Tensor: Guided attention loss value.
        """
        # Both masks get a singleton head axis so they broadcast across H.
        if self.guided_attn_masks is None:
            penalty = self._make_guided_attention_masks(ilens, olens)
            self.guided_attn_masks = penalty.unsqueeze(1)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).unsqueeze(1)

        penalised = self.guided_attn_masks * att_ws
        valid = self.masks.broadcast_to(penalised.shape)
        loss = paddle.mean(penalised.masked_select(valid))

        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss
class Tacotron2Loss(nn.Layer):
    """Loss function module for Tacotron2."""

    def __init__(self,
                 use_masking=True,
                 use_weighted_masking=False,
                 bce_pos_weight=20.0):
        """Initialize Tacotron2 loss module.

        Args:
            use_masking (bool): Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
            bce_pos_weight (float): Weight of positive sample of stop token.
        """
        super().__init__()
        # The two masking modes are mutually exclusive (both may be disabled).
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # Weighted masking needs element-wise losses so they can be re-weighted.
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = nn.L1Loss(reduction=reduction)
        self.mse_criterion = nn.MSELoss(reduction=reduction)
        self.bce_criterion = nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight))

    def forward(self, after_outs, before_outs, logits, ys, stop_labels, olens):
        """Calculate forward propagation.

        Args:
            after_outs(Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits(Tensor): Batch of stop logits (B, Lmax).
            ys(Tensor): Batch of padded target features (B, Lmax, odim).
            stop_labels(Tensor(int64)): Batch of the sequences of stop token labels (B, Lmax).
            olens(Tensor(int64)): Batch of output lengths, passed to ``make_non_pad_mask``.
        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.
        """
        # Plain masking: drop padded positions before computing the losses.
        if self.use_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1)
            ys = ys.masked_select(masks.broadcast_to(ys.shape))
            after_outs = after_outs.masked_select(
                masks.broadcast_to(after_outs.shape))
            before_outs = before_outs.masked_select(
                masks.broadcast_to(before_outs.shape))
            stop_labels = stop_labels.masked_select(
                masks[:, :, 0].broadcast_to(stop_labels.shape))
            logits = logits.masked_select(
                masks[:, :, 0].broadcast_to(logits.shape))

        # calculate loss
        l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(
            before_outs, ys)
        mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
            before_outs, ys)
        bce_loss = self.bce_criterion(logits, stop_labels)

        # Weighted masking: weight each frame by 1 / utterance length, then
        # average over the batch (and feature dimension for the output losses).
        if self.use_weighted_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1)
            # Explicit cast instead of ``.float()``, which is not a stock
            # paddle Tensor method.
            float_masks = paddle.cast(masks, dtype=ys.dtype)
            weights = float_masks / float_masks.sum(axis=1, keepdim=True)
            # Cast the dynamic shape entries so the division stays in float.
            batch_size = paddle.cast(paddle.shape(ys)[0], weights.dtype)
            feat_dim = paddle.cast(paddle.shape(ys)[2], weights.dtype)
            out_weights = weights / (batch_size * feat_dim)
            logit_weights = weights / batch_size

            # Fix: ``broadcast_to`` takes a target shape, not a tensor.
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                masks.broadcast_to(l1_loss.shape)).sum()
            mse_loss = mse_loss.multiply(out_weights)
            mse_loss = mse_loss.masked_select(
                masks.broadcast_to(mse_loss.shape)).sum()
            bce_loss = bce_loss.multiply(logit_weights.squeeze(-1))
            bce_loss = bce_loss.masked_select(
                masks.squeeze(-1).broadcast_to(bce_loss.shape)).sum()

        return l1_loss, mse_loss, bce_loss
# Losses for GAN Vocoder
def stft(x,
         fft_size,
         hop_length=None,
         win_length=None,
         window='hann',
         center=True,
         pad_mode='reflect'):
    r"""Perform STFT and convert to magnitude spectrogram.

    Args:
        x(Tensor): Input signal tensor (B, T).
        fft_size(int): FFT size.
        hop_length(int, optional): Hop size, forwarded to ``paddle.signal.stft``.
            (Default value = None)
        win_length(int, optional): Window length. Falls back to ``fft_size``
            when ``None``. (Default value = None)
        window(str, optional): Name of window function, see `scipy.signal.get_window` for more
            details. Defaults to "hann".
        center(bool, optional): Whether to pad `x` to make that the
            :math:`t \times hop\_length` at the center of :math:`t`-th frame. Default: `True`.
        pad_mode(str, optional): Padding mode used when ``center`` is true.
            (Default value = 'reflect')
    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # ``signal.get_window`` needs a concrete length, so resolve the default here.
    if win_length is None:
        win_length = fft_size
    # calculate window
    window = signal.get_window(window, win_length, fftbins=True)
    window = paddle.to_tensor(window, dtype=x.dtype)
    x_stft = paddle.signal.stft(
        x,
        fft_size,
        hop_length,
        win_length,
        window=window,
        center=center,
        pad_mode=pad_mode)
    real = x_stft.real()
    imag = x_stft.imag()
    # Clip the power before the square root so the result stays strictly
    # positive; then (B, #freqs, #frames) -> (B, #frames, #freqs).
    return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose(
        [0, 2, 1])
class SpectralConvergenceLoss(nn.Layer):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the relative Frobenius-norm error between two spectrograms.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Spectral convergence loss value.
        """
        error_norm = paddle.norm(y_mag - x_mag, p="fro")
        # Lower-bound the reference norm to avoid dividing by zero.
        reference_norm = paddle.clip(paddle.norm(y_mag, p="fro"), min=1e-10)
        return error_norm / reference_norm
class LogSTFTMagnitudeLoss(nn.Layer):
    """Log STFT magnitude loss module."""

    def __init__(self, epsilon=1e-7):
        """Initialize log STFT magnitude loss module.

        Args:
            epsilon (float): Lower bound applied to magnitudes before the log.
        """
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag, y_mag):
        """Compute the L1 distance between log magnitudes.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        log_y = paddle.log(paddle.clip(y_mag, min=self.epsilon))
        log_x = paddle.log(paddle.clip(x_mag, min=self.epsilon))
        return F.l1_loss(log_y, log_x)
class STFTLoss(nn.Layer):
    """STFT loss module."""

    def __init__(self,
                 fft_size=1024,
                 shift_size=120,
                 win_length=600,
                 window="hann"):
        """Initialize STFT loss module.

        Args:
            fft_size (int): FFT size.
            shift_size (int): Hop size passed to ``stft``.
            win_length (int): Window length.
            window (str): Window function name.
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = window
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Compute both STFT-based losses for one resolution.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).
        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        stft_args = (self.fft_size, self.shift_size, self.win_length,
                     self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)
        return (self.spectral_convergence_loss(x_mag, y_mag),
                self.log_stft_magnitude_loss(x_mag, y_mag))
class MultiResolutionSTFTLoss(nn.Layer):
    """Multi resolution STFT loss module."""

    def __init__(
            self,
            fft_sizes=[1024, 2048, 512],
            hop_sizes=[120, 240, 50],
            win_lengths=[600, 1200, 240],
            window="hann", ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
        """
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        # One STFTLoss per (fft, hop, window-length) triple.
        self.stft_losses = nn.LayerList()
        for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes,
                                                  win_lengths):
            self.stft_losses.append(
                STFTLoss(fft_size, hop_size, win_length, window))

    def forward(self, x, y):
        """Average the STFT losses over all resolutions.

        Args:
            x (Tensor): Predicted signal (B, T) or (B, #subband, T).
            y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        if len(x.shape) == 3:
            # Fold sub-bands into the batch axis: (B, C, T) -> (B x C, T)
            x = x.reshape([-1, x.shape[2]])
            y = y.reshape([-1, y.shape[2]])

        sc_total = 0.0
        mag_total = 0.0
        for resolution_loss in self.stft_losses:
            sc_part, mag_part = resolution_loss(x, y)
            sc_total += sc_part
            mag_total += mag_part

        n_resolutions = len(self.stft_losses)
        sc_total /= n_resolutions
        mag_total /= n_resolutions
        return sc_total, mag_total
class GeneratorAdversarialLoss(nn.Layer):
    """Generator adversarial loss module."""

    def __init__(
            self,
            average_by_discriminators=True,
            loss_type="mse", ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to divide the summed loss
                by the number of discriminators.
            loss_type (str): Either "mse" or "hinge".
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        self.criterion = (self._mse_loss
                          if loss_type == "mse" else self._hinge_loss)

    def forward(self, outputs):
        """Calculate generator adversarial loss.

        Args:
            outputs (Tensor or List): Discriminator outputs or list of discriminator outputs.
        Returns:
            Tensor: Generator adversarial loss value.
        """
        # Single discriminator: nothing to aggregate.
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)

        adv_loss = 0.0
        for idx, disc_out in enumerate(outputs):
            if isinstance(disc_out, (tuple, list)):
                # Feature maps are included; the final output is the last item.
                disc_out = disc_out[-1]
            adv_loss += self.criterion(disc_out)
        if self.average_by_discriminators:
            adv_loss /= idx + 1
        return adv_loss

    def _mse_loss(self, x):
        """Least-squares loss pushing discriminator outputs towards one."""
        return F.mse_loss(x, paddle.ones_like(x))

    def _hinge_loss(self, x):
        """Hinge generator loss: negative mean of the discriminator outputs."""
        return -x.mean()
class DiscriminatorAdversarialLoss(nn.Layer):
    """Discriminator adversarial loss module."""

    def __init__(
            self,
            average_by_discriminators=True,
            loss_type="mse", ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool): Whether to divide the summed losses
                by the number of discriminators.
            loss_type (str): Only "mse" is accepted.
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss

    def forward(self, outputs_hat, outputs):
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from generator outputs.
            outputs (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from groundtruth.
        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.
        """
        # Single discriminator: evaluate both criteria directly.
        if not isinstance(outputs, (tuple, list)):
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)
            return real_loss, fake_loss

        real_loss = 0.0
        fake_loss = 0.0
        for idx, (fake_out, real_out) in enumerate(zip(outputs_hat, outputs)):
            if isinstance(fake_out, (tuple, list)):
                # Feature maps are included; the final output is the last item.
                fake_out = fake_out[-1]
                real_out = real_out[-1]
            real_loss += self.real_criterion(real_out)
            fake_loss += self.fake_criterion(fake_out)
        if self.average_by_discriminators:
            n_discriminators = idx + 1
            fake_loss /= n_discriminators
            real_loss /= n_discriminators
        return real_loss, fake_loss

    def _mse_real_loss(self, x):
        """Least-squares loss pushing outputs on real data towards one."""
        return F.mse_loss(x, paddle.ones_like(x))

    def _mse_fake_loss(self, x):
        """Least-squares loss pushing outputs on generated data towards zero."""
        return F.mse_loss(x, paddle.zeros_like(x))
# Losses for SpeedySpeech
# Structural Similarity Index Measure (SSIM)
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to one.

    Args:
        window_size(int): Number of taps; the peak sits at ``window_size // 2``.
        sigma(float): Standard deviation of the Gaussian.
    Returns:
        Tensor: Normalised kernel of shape (window_size,).
    """
    center = window_size // 2
    taps = []
    for x in range(window_size):
        taps.append(math.exp(-(x - center)**2 / float(2 * sigma**2)))
    gauss = paddle.to_tensor(taps)
    return gauss / gauss.sum()
def create_window(window_size, channel):
    """Build a per-channel 2-D Gaussian window (sigma fixed at 1.5).

    Args:
        window_size(int): Side length of the square window.
        channel(int): Number of channels to replicate the window for.
    Returns:
        Tensor: Window of shape (channel, 1, window_size, window_size).
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    row = paddle.transpose(column, [1, 0])
    # Outer product of the 1-D kernel with itself, plus two leading axes.
    kernel_2d = paddle.matmul(column, row).unsqueeze([0, 1])
    return paddle.expand(kernel_2d,
                         [channel, 1, window_size, window_size])
def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches using a given smoothing window.

    Args:
        img1(Tensor): First image batch, fed to ``F.conv2d`` with ``groups=channel``.
        img2(Tensor): Second image batch, same layout as ``img1``.
        window(Tensor): Convolution kernel used for local statistics.
        window_size(int): Kernel side length; padding is ``window_size // 2``.
        channel(int): Number of channels (convolution groups).
        size_average(bool, optional): Return the global mean if ``True``, otherwise
            the mean over axis 1 taken three times in a row.
    Returns:
        Tensor: SSIM value(s).
    """
    pad = window_size // 2

    def _local_mean(img):
        # Depth-wise smoothing with the supplied window.
        return F.conv2d(img, window, padding=pad, groups=channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances: E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    # Stabilising constants.
    C1 = 0.01**2
    C2 = 0.03**2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)
def ssim(img1, img2, window_size=11, size_average=True):
    """Compute SSIM between two 4-D image batches with a Gaussian window.

    Args:
        img1(Tensor): First image batch; axis 1 is taken as the channel axis.
        img2(Tensor): Second image batch.
        window_size(int, optional): Side length of the Gaussian window.
        size_average(bool, optional): Forwarded to ``_ssim``.
    Returns:
        Tensor: SSIM value(s).
    """
    _, channel, _, _ = img1.shape
    gauss_window = create_window(window_size, channel)
    return _ssim(img1, img2, gauss_window, window_size, channel,
                 size_average)
def weighted_mean(input, weight):
    """Weighted mean; with a 0/1 weight it acts as a masked mean.

    Args:
        input(Tensor): The input tensor.
        weight(Tensor): The weight tensor, broadcastable to the shape of the input.
    Returns:
        Tensor: Weighted mean tensor with the same dtype as input.
    """
    weight = paddle.cast(weight, input.dtype)
    # Each weight entry covers ``broadcast_ratio`` input entries after
    # broadcasting, so the weight total is scaled up accordingly.
    broadcast_ratio = input.numel() / weight.numel()
    weighted_total = paddle.sum(input * weight)
    weight_total = paddle.sum(weight) * broadcast_ratio
    return weighted_total / weight_total
def masked_l1_loss(prediction, target, mask):
    """Compute masked L1 loss.

    Args:
        prediction(Tensor): The prediction.
        target(Tensor): The target. The shape should be broadcastable to ``prediction``.
        mask(Tensor): The mask. The shape should be broadcastable to the broadcasted shape of
            ``prediction`` and ``target``.
    Returns:
        Tensor: The masked L1 loss.
    """
    # Element-wise absolute error, then average over the masked-in entries.
    elementwise_error = F.l1_loss(prediction, target, reduction='none')
    return weighted_mean(elementwise_error, mask)
class MelSpectrogram(nn.Layer):
    """Calculate Mel-spectrogram."""

    def __init__(
            self,
            fs=22050,
            fft_size=1024,
            hop_size=256,
            win_length=None,
            window="hann",
            num_mels=80,
            fmin=80,
            fmax=7600,
            center=True,
            normalized=False,
            onesided=True,
            eps=1e-10,
            log_base=10.0, ):
        """Initialize MelSpectrogram module.

        Args:
            fs (int): Sampling rate used to build the mel filter bank.
            fft_size (int): FFT size.
            hop_size (int): Hop size.
            win_length (int, optional): Window length; ``fft_size`` when ``None``.
            window (str, optional): Name of a ``scipy.signal.windows`` function, or
                ``None`` to pass no window to the STFT.
            num_mels (int): Number of mel bins.
            fmin (int, optional): Lowest filter frequency; 0 when ``None``.
            fmax (int, optional): Highest filter frequency; ``fs / 2`` when ``None``.
            center (bool): Forwarded to ``paddle.signal.stft``.
            normalized (bool): Forwarded to ``paddle.signal.stft``.
            onesided (bool): Forwarded to ``paddle.signal.stft``.
            eps (float): Lower clip applied before the square root and the log.
            log_base (float, optional): ``None`` (natural log), 2.0 or 10.0.
        """
        super().__init__()
        self.fft_size = fft_size
        self.win_length = fft_size if win_length is None else win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided

        window_known = window is None or hasattr(signal.windows, f"{window}")
        if not window_known:
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        if fmin is None:
            fmin = 0
        if fmax is None:
            fmax = fs / 2
        mel_basis = librosa.filters.mel(
            sr=fs,
            n_fft=fft_size,
            n_mels=num_mels,
            fmin=fmin,
            fmax=fmax, )
        # Stored transposed so it can right-multiply (#frames, #freqs) frames.
        self.melmat = paddle.to_tensor(mel_basis.T)

        self.stft_params = {
            "n_fft": self.fft_size,
            "win_length": self.win_length,
            "hop_length": self.hop_size,
            "center": self.center,
            "normalized": self.normalized,
            "onesided": self.onesided,
        }

        self.log_base = log_base
        if self.log_base is None:
            self.log = paddle.log
        elif self.log_base == 2.0:
            self.log = paddle.log2
        elif self.log_base == 10.0:
            self.log = paddle.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")

    def forward(self, x):
        """Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).
        """
        if len(x.shape) == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape([-1, paddle.shape(x)[2]])

        window = None
        if self.window is not None:
            # Build the analysis window in the dtype of the input.
            window = paddle.to_tensor(
                signal.get_window(self.window, self.win_length, fftbins=True),
                dtype=x.dtype)

        x_stft = paddle.signal.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames) -> (B, #frames, #freqs)
        real = x_stft.real().transpose([0, 2, 1])
        imag = x_stft.imag().transpose([0, 2, 1])

        x_power = real**2 + imag**2
        x_amp = paddle.sqrt(paddle.clip(x_power, min=self.eps))

        # Project onto the mel basis, clip, and take the log.
        x_mel = paddle.clip(paddle.matmul(x_amp, self.melmat), min=self.eps)
        return self.log(x_mel).transpose([0, 2, 1])
class MelSpectrogramLoss(nn.Layer):
    """Mel-spectrogram loss."""

    def __init__(
            self,
            fs=22050,
            fft_size=1024,
            hop_size=256,
            win_length=None,
            window="hann",
            num_mels=80,
            fmin=80,
            fmax=7600,
            center=True,
            normalized=False,
            onesided=True,
            eps=1e-10,
            log_base=10.0, ):
        """Initialize Mel-spectrogram loss.

        All arguments are forwarded unchanged to ``MelSpectrogram``.
        """
        super().__init__()
        mel_config = dict(
            fs=fs,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_length,
            window=window,
            num_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            eps=eps,
            log_base=log_base, )
        self.mel_spectrogram = MelSpectrogram(**mel_config)

    def forward(self, y_hat, y):
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat(Tensor): Generated single tensor (B, 1, T).
            y(Tensor): Groundtruth single tensor (B, 1, T).
        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        predicted_mel = self.mel_spectrogram(y_hat)
        reference_mel = self.mel_spectrogram(y)
        return F.l1_loss(predicted_mel, reference_mel)
class FeatureMatchLoss(nn.Layer):
    """Feature matching loss module."""

    def __init__(
            self,
            average_by_layers=True,
            average_by_discriminators=True,
            include_final_outputs=False, ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool): Whether to divide each discriminator's loss
                by its number of compared layers.
            average_by_discriminators (bool): Whether to divide the total by the
                number of discriminators.
            include_final_outputs (bool): Whether the last entry of each feature
                list takes part in the loss.
        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(self, feats_hat, feats):
        """Calculate feature matching loss.

        Args:
            feats_hat(list): List of list of discriminator outputs
                calculated from generator outputs.
            feats(list): List of list of discriminator outputs
        Returns:
            Tensor: Feature matching loss value.
        """
        total = 0.0
        for disc_idx, (layers_hat, layers) in enumerate(zip(feats_hat, feats)):
            if not self.include_final_outputs:
                # Drop the last entry of each list.
                layers_hat = layers_hat[:-1]
                layers = layers[:-1]

            per_disc = 0.0
            for layer_idx, (layer_hat, layer) in enumerate(
                    zip(layers_hat, layers)):
                # The second argument is detached, so no gradient flows through it.
                per_disc += F.l1_loss(layer_hat, layer.detach())
            if self.average_by_layers:
                per_disc /= layer_idx + 1
            total += per_disc

        if self.average_by_discriminators:
            total /= disc_idx + 1
        return total
# loss for VITS
class KLDivergenceLoss(nn.Layer):
    """KL divergence loss."""

    def forward(
            self,
            z_p: paddle.Tensor,
            logs_q: paddle.Tensor,
            m_p: paddle.Tensor,
            logs_p: paddle.Tensor,
            z_mask: paddle.Tensor, ) -> paddle.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).
        Returns:
            Tensor: KL divergence loss.
        """
        # All terms are evaluated in float32 regardless of the input dtype.
        z_p, logs_q, m_p, logs_p, z_mask = (paddle.cast(tensor, 'float32')
                                            for tensor in (z_p, logs_q, m_p,
                                                           logs_p, z_mask))

        log_ratio_term = logs_p - logs_q - 0.5
        quadratic_term = 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p)
        kl = log_ratio_term + quadratic_term

        # Masked sum divided by the mask total.
        masked_total = paddle.sum(kl * z_mask)
        return masked_total / paddle.sum(z_mask)
# loss for ERNIE SAT
class MLMLoss(nn.Layer):
    """Masked reconstruction loss for ERNIE SAT.

    Combines a per-frame regression loss on masked speech positions with an
    optional cross-entropy loss on masked text positions.
    """

    def __init__(self,
                 lsm_weight: float=0.1,
                 ignore_id: int=-1,
                 text_masking: bool=False):
        """Initialize the loss.

        Args:
            lsm_weight (float): Selects the regression criterion: ``nn.MSELoss()`` when
                greater than 50, otherwise element-wise ``nn.L1Loss``.
            ignore_id (int): Label index ignored by the text cross-entropy.
            text_masking (bool): Whether ``forward`` also returns a text loss.
        """
        super().__init__()
        if text_masking:
            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        if lsm_weight > 50:
            self.l1_loss_func = nn.MSELoss()
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.text_masking = text_masking

    def forward(self,
                speech: paddle.Tensor,
                before_outs: paddle.Tensor,
                after_outs: paddle.Tensor,
                masked_pos: paddle.Tensor,
                text: paddle.Tensor=None,
                text_outs: paddle.Tensor=None,
                text_masked_pos: paddle.Tensor=None):
        """Calculate the masked reconstruction loss.

        Args:
            speech (Tensor): Target features; the last axis is the feature dimension.
            before_outs (Tensor): Predictions before the postnet, same feature dimension.
            after_outs (Tensor): Predictions after the postnet, or ``None`` to skip them.
            masked_pos (Tensor): Positions with a value > 0 contribute to the loss.
            text (Tensor, optional): Text targets (needed when ``text_masking``).
            text_outs (Tensor, optional): Text logits; the last axis is the vocabulary.
            text_masked_pos (Tensor, optional): Weights for the text loss.
        Returns:
            Tensor: Speech loss, or a ``(speech loss, text loss)`` tuple when
            ``text_masking`` is enabled.
        Raises:
            ValueError: If ``text_masking`` is enabled but a text input is missing.
        """
        # Fix: ``self.odim`` is never set by ``__init__``. Honour it if a caller
        # assigned it externally, otherwise read it from the target tensor.
        odim = getattr(self, "odim", None)
        if odim is None:
            odim = speech.shape[-1]

        target = paddle.reshape(speech, (-1, odim))
        loss = paddle.sum(
            self.l1_loss_func(paddle.reshape(before_outs, (-1, odim)), target),
            axis=-1)
        if after_outs is not None:
            loss = loss + paddle.sum(
                self.l1_loss_func(
                    paddle.reshape(after_outs, (-1, odim)), target),
                axis=-1)

        # Cast the boolean mask so it can be multiplied with the float loss.
        mlm_mask = paddle.cast(masked_pos > 0, loss.dtype)
        loss_mlm = paddle.sum(loss * paddle.reshape(
            mlm_mask, [-1])) / paddle.sum(mlm_mask + 1e-10)

        if not self.text_masking:
            return loss_mlm

        if text is None or text_outs is None or text_masked_pos is None:
            raise ValueError(
                "text, text_outs and text_masked_pos are required when "
                "text_masking is enabled.")

        # Fix: same fallback for the never-initialised ``self.vocab_size``.
        vocab_size = getattr(self, "vocab_size", None)
        if vocab_size is None:
            vocab_size = text_outs.shape[-1]

        text_mask = paddle.cast(text_masked_pos, loss.dtype)
        # Fix: ``(-1)`` is the int -1, not a shape; use a real shape list.
        text_ce = self.text_mlm_loss(
            paddle.reshape(text_outs, (-1, vocab_size)),
            paddle.reshape(text, [-1]))
        loss_text = paddle.sum(text_ce * paddle.reshape(
            text_mask, [-1])) / paddle.sum(text_mask + 1e-10)
        return loss_mlm, loss_text