You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1062 lines
36 KiB
1062 lines
36 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import math
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import paddle
|
|
from paddle import nn
|
|
from paddle.nn import functional as F
|
|
from scipy import signal
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
|
|
|
|
# Losses for WaveRNN
|
|
def log_sum_exp(x):
    """Compute ``log(sum(exp(x)))`` over the last axis without overflowing.

    The maximum along the last axis is subtracted before exponentiation and
    added back after the logarithm.

    Args:
        x(Tensor): Input tensor; the reduction runs over its last axis.

    Returns:
        Tensor: Reduced tensor with the last axis removed.
    """
    last_axis = len(x.shape) - 1
    # Same maximum twice: one copy keeps the axis for broadcasting against x,
    # the other already has the shape of the result.
    peak_keepdim = paddle.max(x, axis=last_axis, keepdim=True)
    peak = paddle.max(x, axis=last_axis)
    shifted_sum = paddle.sum(paddle.exp(x - peak_keepdim), axis=last_axis)
    return peak + paddle.log(shifted_sum)
|
|
|
|
|
|
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
|
|
def discretized_mix_logistic_loss(y_hat,
                                  y,
                                  num_classes=65536,
                                  log_scale_min=None,
                                  reduce=True):
    """Negative log-likelihood under a discretized mixture of logistics.

    Args:
        y_hat(Tensor): Predicted mixture parameters (B, T, C). The last axis
            holds ``[mixture logits | means | log-scales]``, so C must be a
            multiple of 3.
        y(Tensor): Target (B, T, 1); it is expanded to the shape of the means.
            Values below -0.999 / above 0.999 are treated as the edge bins.
        num_classes(int, optional): Number of quantization levels.
            (Default value = 65536)
        log_scale_min(float, optional): Lower clipping bound for the
            log-scales. ``None`` means ``log(1e-14)``.
        reduce(bool, optional): If True return the mean loss, otherwise the
            per-position loss with a trailing axis of size 1.

    Returns:
        Tensor: The loss value(s).
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    # Validate the channel count in channel-first layout, then return to
    # the (B, T, C) layout used below.
    channel_first = y_hat.transpose([0, 2, 1])
    assert channel_first.dim() == 3
    assert channel_first.shape[1] % 3 == 0
    nr_mix = channel_first.shape[1] // 3
    y_hat = channel_first.transpose([0, 2, 1])

    # Unpack parameters: three tensors of shape (B, T, num_mixtures).
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    raw_log_scales = y_hat[:, :, 2 * nr_mix:3 * nr_mix]
    log_scales = paddle.clip(raw_log_scales, min=log_scale_min)

    float_dtype = paddle.get_default_dtype()

    def _as_float(condition):
        # Turn a boolean condition into a 0/1 blending weight.
        return paddle.cast(condition, dtype=float_dtype)

    # (B, T, 1) -> (B, T, num_mixtures)
    y = y.expand_as(means)
    centered_y = paddle.cast(y, dtype=float_dtype) - means
    inv_stdv = paddle.exp(-log_scales)

    # Logistic CDF arguments at the upper edge, lower edge and center of
    # the bin around y.
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    mid_in = inv_stdv * centered_y

    cdf_plus = F.sigmoid(plus_in)
    cdf_min = F.sigmoid(min_in)
    # Probability mass of an interior bin.
    cdf_delta = cdf_plus - cdf_min

    # log(sigmoid(x)) == x - softplus(x): lowest-bin edge case.
    log_cdf_plus = plus_in - F.softplus(plus_in)
    # log(1 - sigmoid(x)) == -softplus(x): highest-bin edge case.
    log_one_minus_cdf_min = -F.softplus(min_in)
    # Log-density at the bin center, used when cdf_delta is too small.
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the
    # value for num_classes=65536 case? 1e-7? not sure..
    delta_is_safe = _as_float(cdf_delta > 1e-5)
    is_top_bin = _as_float(y > 0.999)
    is_bottom_bin = _as_float(y < -0.999)

    interior_out = delta_is_safe * paddle.log(
        paddle.clip(cdf_delta, min=1e-12)) + (1. - delta_is_safe) * (
            log_pdf_mid - np.log((num_classes - 1) / 2))
    upper_out = is_top_bin * log_one_minus_cdf_min + (
        1. - is_top_bin) * interior_out
    log_probs = is_bottom_bin * log_cdf_plus + (1. - is_bottom_bin) * upper_out

    # Add the mixture log-weights before marginalizing over components.
    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if not reduce:
        return -log_sum_exp(log_probs).unsqueeze(-1)
    return -paddle.mean(log_sum_exp(log_probs))
|
|
|
|
|
|
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """Sample from discretized mixture of logistic distributions.

    Args:
        y(Tensor): Mixture parameters (B, C, T). The channel axis holds
            ``[mixture logits | means | log-scales]``, so C must be a
            multiple of 3.
        log_scale_min(float, optional): Lower clipping bound for the
            log-scales. ``None`` means ``log(1e-14)``. (Default value = None)

    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))

    assert y.shape[1] % 3 == 0
    nr_mix = y.shape[1] // 3

    # (B, C, T) -> (B, T, C)
    y = y.transpose([0, 2, 1])
    logit_probs = y[:, :, :nr_mix]

    # Sample the mixture indicator from the softmax via the Gumbel-max trick:
    # argmax(logits - log(-log(U))) with U ~ Uniform(0, 1).
    temp = paddle.uniform(
        logit_probs.shape, dtype=logit_probs.dtype, min=1e-5, max=1.0 - 1e-5)
    temp = logit_probs - paddle.log(-paddle.log(temp))
    argmax = paddle.argmax(temp, axis=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix)
    one_hot = paddle.cast(one_hot, dtype=paddle.get_default_dtype())

    # Select the logistic parameters of the chosen component.
    means = paddle.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
    log_scales = paddle.clip(
        paddle.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1),
        min=log_scale_min)

    # Sample from the logistic by inverse-CDF and clip to the valid interval.
    # We don't actually round to the nearest quantized value when sampling.
    u = paddle.uniform(means.shape, min=1e-5, max=1.0 - 1e-5)
    x = means + paddle.exp(log_scales) * (paddle.log(u) - paddle.log(1. - u))
    # The upper bound must be +1; clipping to [-1, -1] would force every
    # sample to -1.
    x = paddle.clip(x, min=-1., max=1.)

    return x
|
|
|
|
|
|
# Loss for Tacotron2
|
|
class GuidedAttentionLoss(nn.Layer):
    """Guided attention loss function module.

    Implements the guided attention loss from `Efficiently Trainable
    Text-to-Speech System Based on Deep Convolutional Networks with Guided
    Attention`_, which penalizes attention weights far from the diagonal.

    .. _`Efficiently Trainable Text-to-Speech System
        Based on Deep Convolutional Networks with Guided Attention`:
        https://arxiv.org/abs/1710.08969

    """

    def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
        """Initialize guided attention loss module.

        Args:
            sigma (float, optional): Standard deviation controlling how close
                the attention must stay to the diagonal.
            alpha (float, optional): Scaling coefficient (lambda).
            reset_always (bool, optional): Whether to rebuild the masks on
                every forward call.

        """
        super().__init__()
        self.sigma = sigma
        self.alpha = alpha
        self.reset_always = reset_always
        # Lazily built in forward(); see _reset_masks().
        self.guided_attn_masks = None
        self.masks = None

    def _reset_masks(self):
        """Drop the cached masks so the next forward call rebuilds them."""
        self.guided_attn_masks = None
        self.masks = None

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Args:
            att_ws(Tensor): Batch of attention weights (B, T_max_out, T_max_in).
            ilens(Tensor(int64)): Batch of input lengths (B,).
            olens(Tensor(int64)): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.

        """
        if self.guided_attn_masks is None:
            self.guided_attn_masks = self._make_guided_attention_masks(
                ilens, olens)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens)

        penalized = self.guided_attn_masks * att_ws
        # Average only over the non-padded positions.
        valid = self.masks.broadcast_to(penalized.shape)
        loss = paddle.mean(penalized.masked_select(valid))

        if self.reset_always:
            self._reset_masks()
        return self.alpha * loss

    def _make_guided_attention_masks(self, ilens, olens):
        """Build the zero-padded (B, T_max_out, T_max_in) penalty masks."""
        batch_size = len(ilens)
        longest_in = max(ilens)
        longest_out = max(olens)
        batch_masks = paddle.zeros((batch_size, longest_out, longest_in))

        for b, (ilen, olen) in enumerate(zip(ilens, olens)):
            single = self._make_guided_attention_mask(ilen, olen, self.sigma)
            batch_masks[b, :olen, :ilen] = single
        return batch_masks

    @staticmethod
    def _make_guided_attention_mask(ilen, olen, sigma):
        """Make guided attention mask for one (olen, ilen) pair.

        Entry ``[o, i]`` equals
        ``1 - exp(-(i / ilen - o / olen)**2 / (2 * sigma**2))``.

        Examples
        ----------
        >>> guided_attn_mask = _make_guided_attention_mask(5, 5, 0.4)
        >>> guided_attn_mask.shape
        [5, 5]
        >>> guided_attn_mask
        tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
                [0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
                [0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
                [0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
                [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
        >>> guided_attn_mask = _make_guided_attention_mask(3, 6, 0.4)
        >>> guided_attn_mask.shape
        [6, 3]
        >>> guided_attn_mask
        tensor([[0.0000, 0.2934, 0.7506],
                [0.0831, 0.0831, 0.5422],
                [0.2934, 0.0000, 0.2934],
                [0.5422, 0.0831, 0.0831],
                [0.7506, 0.2934, 0.0000],
                [0.8858, 0.5422, 0.0831]])

        """
        out_idx, in_idx = paddle.meshgrid(
            paddle.arange(olen), paddle.arange(ilen))
        out_idx = out_idx.cast(dtype=paddle.float32)
        in_idx = in_idx.cast(dtype=paddle.float32)
        neg_sq_offset = -((in_idx / ilen - out_idx / olen)**2)
        return 1.0 - paddle.exp(neg_sq_offset / (2 * (sigma**2)))

    @staticmethod
    def _make_masks(ilens, olens):
        """Make masks indicating non-padded part.

        Args:
            ilens(Tensor(int64) or List): Batch of lengths (B,).
            olens(Tensor(int64) or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor indicating non-padded part.

        Examples:
            >>> ilens, olens = [5, 2], [8, 5]
            >>> _make_masks(ilens, olens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1],
                     [1, 1, 1, 1, 1]],

                    [[1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0],
                     [0, 0, 0, 0, 0]]], dtype=paddle.uint8)

        """
        # (B, T_in)
        valid_in = make_non_pad_mask(ilens)
        # (B, T_out)
        valid_out = make_non_pad_mask(olens)
        # (B, T_out, 1) & (B, 1, T_in) -> (B, T_out, T_in)
        return paddle.logical_and(
            valid_out.unsqueeze(-1), valid_in.unsqueeze(-2))
|
|
|
|
|
|
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
    """Guided attention loss function module for multi head attention.

    Args:
        sigma (float, optional): Standard deviation controlling how close
            the attention must stay to the diagonal.
        alpha (float, optional): Scaling coefficient (lambda).
        reset_always (bool, optional): Whether to always reset masks.

    """

    def forward(self, att_ws, ilens, olens):
        """Calculate forward propagation.

        Args:
            att_ws(Tensor): Batch of multi head attention weights
                (B, H, T_max_out, T_max_in).
            ilens(Tensor): Batch of input lengths (B,).
            olens(Tensor): Batch of output lengths (B,).

        Returns:
            Tensor: Guided attention loss value.

        """
        if self.guided_attn_masks is None:
            per_batch = self._make_guided_attention_masks(ilens, olens)
            # Insert a head axis so the mask broadcasts over all heads.
            self.guided_attn_masks = per_batch.unsqueeze(1)
        if self.masks is None:
            self.masks = self._make_masks(ilens, olens).unsqueeze(1)

        penalized = self.guided_attn_masks * att_ws
        valid = self.masks.broadcast_to(penalized.shape)
        loss = paddle.mean(penalized.masked_select(valid))

        if self.reset_always:
            self._reset_masks()

        return self.alpha * loss
|
|
|
|
|
|
class Tacotron2Loss(nn.Layer):
    """Loss function module for Tacotron2."""

    def __init__(self,
                 use_masking=True,
                 use_weighted_masking=False,
                 bce_pos_weight=20.0):
        """Initialize Tacotron2 loss module.

        Args:
            use_masking (bool): Whether to apply masking for padded part in
                loss calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in
                loss calculation. Must not be enabled together with
                ``use_masking``.
            bce_pos_weight (float): Weight of positive sample of stop token.
        """
        super().__init__()
        # The two masking modes are mutually exclusive.
        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # Weighted masking needs element-wise losses so the weights can be
        # applied before the final reduction.
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = nn.L1Loss(reduction=reduction)
        self.mse_criterion = nn.MSELoss(reduction=reduction)
        self.bce_criterion = nn.BCEWithLogitsLoss(
            reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight))

    def forward(self, after_outs, before_outs, logits, ys, stop_labels, olens):
        """Calculate forward propagation.

        Args:
            after_outs(Tensor): Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor): Batch of outputs before postnets (B, Lmax, odim).
            logits(Tensor): Batch of stop logits (B, Lmax).
            ys(Tensor): Batch of padded target features (B, Lmax, odim).
            stop_labels(Tensor(int64)): Batch of the sequences of stop token
                labels (B, Lmax).
            olens(Tensor(int64)): Batch of output lengths, used to build the
                non-padding mask.

        Returns:
            Tensor: L1 loss value.
            Tensor: Mean square error loss value.
            Tensor: Binary cross entropy loss value.
        """
        # make mask and apply it
        if self.use_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1)
            ys = ys.masked_select(masks.broadcast_to(ys.shape))
            after_outs = after_outs.masked_select(
                masks.broadcast_to(after_outs.shape))
            before_outs = before_outs.masked_select(
                masks.broadcast_to(before_outs.shape))
            stop_labels = stop_labels.masked_select(
                masks[:, :, 0].broadcast_to(stop_labels.shape))
            logits = logits.masked_select(
                masks[:, :, 0].broadcast_to(logits.shape))

        # calculate loss
        l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(
            before_outs, ys)
        mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
            before_outs, ys)
        bce_loss = self.bce_criterion(logits, stop_labels)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            masks = make_non_pad_mask(olens).unsqueeze(-1)
            # Each valid frame gets weight 1 / (number of valid frames in
            # its utterance). paddle.cast is used instead of a torch-style
            # ``.float()`` method so this does not rely on Tensor patches.
            float_masks = paddle.cast(masks, "float32")
            valid_counts = paddle.cast(
                masks.sum(axis=1, keepdim=True), "float32")
            weights = float_masks / valid_counts
            out_weights = weights.divide(
                paddle.shape(ys)[0] * paddle.shape(ys)[2])
            logit_weights = weights.divide(paddle.shape(ys)[0])

            # apply weight; broadcast_to takes the target *shape*, not the
            # tensor itself.
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                masks.broadcast_to(l1_loss.shape)).sum()
            mse_loss = mse_loss.multiply(out_weights)
            mse_loss = mse_loss.masked_select(
                masks.broadcast_to(mse_loss.shape)).sum()
            bce_loss = bce_loss.multiply(logit_weights.squeeze(-1))
            bce_loss = bce_loss.masked_select(
                masks.squeeze(-1).broadcast_to(bce_loss.shape)).sum()

        return l1_loss, mse_loss, bce_loss
|
|
|
|
|
|
# Losses for GAN Vocoder
|
|
def stft(x,
         fft_size,
         hop_length=None,
         win_length=None,
         window='hann',
         center=True,
         pad_mode='reflect'):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x(Tensor): Input signal tensor (B, T).
        fft_size(int): FFT size.
        hop_length(int, optional): Hop size between frames, forwarded to
            ``paddle.signal.stft``. (Default value = None)
        win_length(int, optional): Window length. ``None`` falls back to
            ``fft_size``. (Default value = None)
        window(str, optional): Name of window function, see
            `scipy.signal.get_window` for more details. Defaults to "hann".
        center(bool, optional): Whether to pad `x` so that frame `t` is
            centered at time `t * hop_length`. Default: `True`.
        pad_mode(str, optional): Padding mode used when ``center`` is True.
            (Default value = 'reflect')

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
    """
    # scipy needs a concrete integer window length, so resolve the documented
    # None default here instead of passing it through.
    if win_length is None:
        win_length = fft_size

    # calculate window
    window = signal.get_window(window, win_length, fftbins=True)
    window = paddle.to_tensor(window, dtype=x.dtype)
    x_stft = paddle.signal.stft(
        x,
        fft_size,
        hop_length,
        win_length,
        window=window,
        center=center,
        pad_mode=pad_mode)

    real = x_stft.real()
    imag = x_stft.imag()

    # Clip the power before the square root so the result (and its gradient)
    # stays finite for silent bins; then (B, #freqs, #frames) ->
    # (B, #frames, #freqs).
    return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose(
        [0, 2, 1])
|
|
|
|
|
|
class SpectralConvergenceLoss(nn.Layer):
    """Spectral convergence loss module."""

    def __init__(self):
        """Initialize spectral convergence loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value.
        """
        error_norm = paddle.norm(y_mag - x_mag, p="fro")
        # Clip the reference norm so an all-zero target cannot divide by zero.
        reference_norm = paddle.clip(paddle.norm(y_mag, p="fro"), min=1e-10)
        return error_norm / reference_norm
|
|
|
|
|
|
class LogSTFTMagnitudeLoss(nn.Layer):
    """Log STFT magnitude loss module."""

    def __init__(self, epsilon=1e-7):
        """Initialize log STFT magnitude loss module.

        Args:
            epsilon (float, optional): Lower bound applied to the magnitudes
                before taking the logarithm.
        """
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.
        """
        log_target = paddle.log(paddle.clip(y_mag, min=self.epsilon))
        log_predicted = paddle.log(paddle.clip(x_mag, min=self.epsilon))
        return F.l1_loss(log_target, log_predicted)
|
|
|
|
|
|
class STFTLoss(nn.Layer):
    """STFT loss module."""

    def __init__(self,
                 fft_size=1024,
                 shift_size=120,
                 win_length=600,
                 window="hann"):
        """Initialize STFT loss module.

        Args:
            fft_size (int, optional): FFT size.
            shift_size (int, optional): Hop size between frames.
            win_length (int, optional): Window length.
            window (str, optional): Window function name.
        """
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        self.window = window
        self.spectral_convergence_loss = SpectralConvergenceLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.
        """
        # Both signals go through the same STFT configuration.
        stft_args = (self.fft_size, self.shift_size, self.win_length,
                     self.window)
        x_mag = stft(x, *stft_args)
        y_mag = stft(y, *stft_args)

        sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
        return sc_loss, mag_loss
|
|
|
|
|
|
class MultiResolutionSTFTLoss(nn.Layer):
    """Multi resolution STFT loss module."""

    def __init__(
            self,
            fft_sizes=[1024, 2048, 512],
            hop_sizes=[120, 240, 50],
            win_lengths=[600, 1200, 240],
            window="hann", ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.
        """
        super().__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        # One STFTLoss per (fft size, hop size, window length) triple.
        self.stft_losses = nn.LayerList()
        for resolution in zip(fft_sizes, hop_sizes, win_lengths):
            fft_size, hop_size, win_length = resolution
            self.stft_losses.append(
                STFTLoss(fft_size, hop_size, win_length, window))

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T) or (B, #subband, T).
            y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        if len(x.shape) == 3:
            # (B, C, T) -> (B x C, T) for both signals
            x = x.reshape([-1, x.shape[2]])
            y = y.reshape([-1, y.shape[2]])

        sc_loss, mag_loss = 0.0, 0.0
        for stft_loss in self.stft_losses:
            sc_part, mag_part = stft_loss(x, y)
            sc_loss += sc_part
            mag_loss += mag_part

        # Average over the resolutions.
        num_resolutions = len(self.stft_losses)
        sc_loss /= num_resolutions
        mag_loss /= num_resolutions

        return sc_loss, mag_loss
|
|
|
|
|
|
class GeneratorAdversarialLoss(nn.Layer):
    """Generator adversarial loss module."""

    def __init__(
            self,
            average_by_discriminators=True,
            loss_type="mse", ):
        """Initialize GeneratorAdversarialLoss module.

        Args:
            average_by_discriminators (bool, optional): Whether to divide the
                summed loss by the number of discriminators.
            loss_type (str, optional): Either "mse" or "hinge".
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
        self.criterion = (self._mse_loss
                          if loss_type == "mse" else self._hinge_loss)

    def forward(self, outputs):
        """Calculate generator adversarial loss.

        Args:
            outputs (Tensor or List): Discriminator outputs or list of
                discriminator outputs.

        Returns:
            Tensor: Generator adversarial loss value.
        """
        # Single discriminator: nothing to accumulate.
        if not isinstance(outputs, (tuple, list)):
            return self.criterion(outputs)

        adv_loss = 0.0
        for idx, disc_out in enumerate(outputs):
            if isinstance(disc_out, (tuple, list)):
                # Feature maps are included; the final output comes last.
                disc_out = disc_out[-1]
            adv_loss += self.criterion(disc_out)
        if self.average_by_discriminators:
            adv_loss /= idx + 1

        return adv_loss

    def _mse_loss(self, x):
        """Least-squares loss pushing discriminator outputs towards 1."""
        target = paddle.ones_like(x)
        return F.mse_loss(x, target)

    def _hinge_loss(self, x):
        """Hinge generator loss: negative mean discriminator output."""
        return -x.mean()
|
|
|
|
|
|
class DiscriminatorAdversarialLoss(nn.Layer):
    """Discriminator adversarial loss module."""

    def __init__(
            self,
            average_by_discriminators=True,
            loss_type="mse", ):
        """Initialize DiscriminatorAdversarialLoss module.

        Args:
            average_by_discriminators (bool, optional): Whether to divide the
                summed losses by the number of discriminators.
            loss_type (str, optional): Only "mse" is supported.
        """
        super().__init__()
        self.average_by_discriminators = average_by_discriminators
        assert loss_type in ["mse"], f"{loss_type} is not supported."
        if loss_type == "mse":
            self.fake_criterion = self._mse_fake_loss
            self.real_criterion = self._mse_real_loss

    def forward(self, outputs_hat, outputs):
        """Calculate discriminator adversarial loss.

        Args:
            outputs_hat (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from generator outputs.
            outputs (Tensor or list): Discriminator outputs or list of
                discriminator outputs calculated from groundtruth.

        Returns:
            Tensor: Discriminator real loss value.
            Tensor: Discriminator fake loss value.
        """
        # Single discriminator: evaluate both criteria directly.
        if not isinstance(outputs, (tuple, list)):
            real_loss = self.real_criterion(outputs)
            fake_loss = self.fake_criterion(outputs_hat)
            return real_loss, fake_loss

        real_loss, fake_loss = 0.0, 0.0
        for idx, (fake_out, real_out) in enumerate(zip(outputs_hat, outputs)):
            if isinstance(fake_out, (tuple, list)):
                # Feature maps are included; the final output comes last.
                fake_out = fake_out[-1]
                real_out = real_out[-1]
            real_loss += self.real_criterion(real_out)
            fake_loss += self.fake_criterion(fake_out)
        if self.average_by_discriminators:
            num_discriminators = idx + 1
            fake_loss /= num_discriminators
            real_loss /= num_discriminators

        return real_loss, fake_loss

    def _mse_real_loss(self, x):
        """Least-squares loss pushing outputs on real data towards 1."""
        target = paddle.ones_like(x)
        return F.mse_loss(x, target)

    def _mse_fake_loss(self, x):
        """Least-squares loss pushing outputs on generated data towards 0."""
        target = paddle.zeros_like(x)
        return F.mse_loss(x, target)
|
|
|
|
|
|
# Losses for SpeedySpeech
|
|
# Structural Similarity Index Measure (SSIM)
|
|
def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel.

    Args:
        window_size(int): Number of taps; the peak sits at ``window_size // 2``.
        sigma(float): Standard deviation of the Gaussian.

    Returns:
        Tensor: Kernel of length ``window_size`` whose entries sum to 1.
    """
    taps = []
    for x in range(window_size):
        taps.append(
            math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2)))
    kernel = paddle.to_tensor(taps)
    return kernel / kernel.sum()
|
|
|
|
|
|
def create_window(window_size, channel):
    """Build the per-channel 2-D Gaussian window used by SSIM.

    Args:
        window_size(int): Side length of the square window.
        channel(int): Number of channels the window is expanded to.

    Returns:
        Tensor: Window of shape (channel, 1, window_size, window_size).
    """
    # Column vector (window_size, 1) with a fixed sigma of 1.5.
    column = gaussian(window_size, 1.5).unsqueeze(1)
    row = paddle.transpose(column, [1, 0])
    # Outer product gives the separable 2-D kernel; add two leading axes.
    kernel_2d = paddle.matmul(column, row).unsqueeze([0, 1])
    return paddle.expand(kernel_2d,
                         [channel, 1, window_size, window_size])
|
|
|
|
|
|
def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute SSIM between two image batches with a given window.

    Args:
        img1(Tensor): First input, in the layout expected by ``F.conv2d``.
        img2(Tensor): Second input with the same layout.
        window(Tensor): Depthwise filter, e.g. from ``create_window``.
        window_size(int): Side length of the window; sets the padding.
        channel(int): Number of channels, used as the conv group count.
        size_average(bool, optional): If True return the mean over all
            elements; otherwise reduce every axis except the first.

    Returns:
        Tensor: SSIM value(s).
    """
    half = window_size // 2

    def _local_mean(image):
        # Windowed (depthwise) average of ``image``.
        return F.conv2d(image, window, padding=half, groups=channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local variances / covariance as E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2

    # Stabilizing constants of the SSIM formula.
    C1 = 0.01**2
    C2 = 0.03**2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    ssim_map = numerator / denominator

    if not size_average:
        return ssim_map.mean(1).mean(1).mean(1)
    return ssim_map.mean()
|
|
|
|
|
|
def ssim(img1, img2, window_size=11, size_average=True):
    """Structural similarity between two 4-D image batches.

    Args:
        img1(Tensor): First input; must have exactly four axes, the second
            being the channel axis.
        img2(Tensor): Second input with the same shape.
        window_size(int, optional): Side length of the Gaussian window.
        size_average(bool, optional): Forwarded to ``_ssim``.

    Returns:
        Tensor: SSIM value(s).
    """
    _, channel, _, _ = img1.shape
    return _ssim(img1, img2,
                 create_window(window_size, channel), window_size, channel,
                 size_average)
|
|
|
|
|
|
def weighted_mean(input, weight):
    """Weighted mean. It can also be used as masked mean.

    Args:
        input(Tensor): The input tensor.
        weight(Tensor): The weight tensor with broadcastable shape with the input.

    Returns:
        Tensor: Weighted mean tensor with the same dtype as input. shape=(1,)

    """
    weight = paddle.cast(weight, input.dtype)
    # Each weight covers ``broadcast_ratio`` input elements once broadcast,
    # so the normalizer is scaled by the same factor.
    # paddle.Tensor.size is different with torch.size() and has been
    # overrided in s2t.__init__
    broadcast_ratio = input.numel() / weight.numel()
    weighted_total = paddle.sum(input * weight)
    normalizer = paddle.sum(weight) * broadcast_ratio
    return weighted_total / normalizer
|
|
|
|
|
|
def masked_l1_loss(prediction, target, mask):
    """Compute masked L1 loss.

    Args:
        prediction(Tensor): The prediction.
        target(Tensor): The target. The shape should be broadcastable to
            ``prediction``.
        mask(Tensor): The mask. The shape should be broadcastable to the
            broadcasted shape of ``prediction`` and ``target``.

    Returns:
        Tensor: The masked L1 loss. shape=(1,)

    """
    # Element-wise absolute error, averaged using the mask as weights.
    return weighted_mean(
        F.l1_loss(prediction, target, reduction='none'), mask)
|
|
|
|
|
|
class MelSpectrogram(nn.Layer):
    """Calculate Mel-spectrogram."""

    def __init__(
            self,
            fs=22050,
            fft_size=1024,
            hop_size=256,
            win_length=None,
            window="hann",
            num_mels=80,
            fmin=80,
            fmax=7600,
            center=True,
            normalized=False,
            onesided=True,
            eps=1e-10,
            log_base=10.0, ):
        """Initialize MelSpectrogram module.

        Args:
            fs (int, optional): Sampling rate.
            fft_size (int, optional): FFT size.
            hop_size (int, optional): Hop size between frames.
            win_length (int, optional): Window length; ``None`` means
                ``fft_size``.
            window (str, optional): Name of a window in
                ``scipy.signal.windows``, or ``None`` for no window.
            num_mels (int, optional): Number of mel bands.
            fmin (int, optional): Lowest mel filter frequency; ``None``
                means 0.
            fmax (int, optional): Highest mel filter frequency; ``None``
                means ``fs / 2``.
            center (bool, optional): Forwarded to ``paddle.signal.stft``.
            normalized (bool, optional): Forwarded to ``paddle.signal.stft``.
            onesided (bool, optional): Forwarded to ``paddle.signal.stft``.
            eps (float, optional): Lower clipping bound before sqrt and log.
            log_base (float, optional): ``None`` (natural log), 2.0 or 10.0.

        Raises:
            ValueError: If ``window`` or ``log_base`` is not supported.
        """
        super().__init__()
        self.fft_size = fft_size
        self.win_length = fft_size if win_length is None else win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided

        window_known = hasattr(signal.windows, f"{window}")
        if window is not None and not window_known:
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        if fmin is None:
            fmin = 0
        if fmax is None:
            fmax = fs / 2
        mel_filters = librosa.filters.mel(
            sr=fs,
            n_fft=fft_size,
            n_mels=num_mels,
            fmin=fmin,
            fmax=fmax, )

        # Stored transposed so forward() can right-multiply by it.
        self.melmat = paddle.to_tensor(mel_filters.T)
        self.stft_params = dict(
            n_fft=self.fft_size,
            win_length=self.win_length,
            hop_length=self.hop_size,
            center=self.center,
            normalized=self.normalized,
            onesided=self.onesided, )

        self.log_base = log_base
        if self.log_base is None:
            log_fn = paddle.log
        elif self.log_base == 2.0:
            log_fn = paddle.log2
        elif self.log_base == 10.0:
            log_fn = paddle.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")
        self.log = log_fn

    def forward(self, x):
        """Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).
        """
        if len(x.shape) == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape([-1, paddle.shape(x)[2]])

        window = None
        if self.window is not None:
            # Build the analysis window in the dtype of the input.
            window = paddle.to_tensor(
                signal.get_window(
                    self.window, self.win_length, fftbins=True),
                dtype=x.dtype)

        x_stft = paddle.signal.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames) -> (B, #frames, #freqs)
        real = x_stft.real().transpose([0, 2, 1])
        imag = x_stft.imag().transpose([0, 2, 1])

        x_power = real**2 + imag**2
        x_amp = paddle.sqrt(paddle.clip(x_power, min=self.eps))
        # Project onto the mel filters and clip so the log stays finite.
        x_mel = paddle.clip(paddle.matmul(x_amp, self.melmat), min=self.eps)

        return self.log(x_mel).transpose([0, 2, 1])
|
|
|
|
|
|
class MelSpectrogramLoss(nn.Layer):
    """Mel-spectrogram loss."""

    def __init__(
            self,
            fs=22050,
            fft_size=1024,
            hop_size=256,
            win_length=None,
            window="hann",
            num_mels=80,
            fmin=80,
            fmax=7600,
            center=True,
            normalized=False,
            onesided=True,
            eps=1e-10,
            log_base=10.0, ):
        """Initialize Mel-spectrogram loss.

        All arguments are forwarded unchanged to ``MelSpectrogram``.
        """
        super().__init__()
        mel_config = dict(
            fs=fs,
            fft_size=fft_size,
            hop_size=hop_size,
            win_length=win_length,
            window=window,
            num_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
            center=center,
            normalized=normalized,
            onesided=onesided,
            eps=eps,
            log_base=log_base, )
        self.mel_spectrogram = MelSpectrogram(**mel_config)

    def forward(self, y_hat, y):
        """Calculate Mel-spectrogram loss.

        Args:
            y_hat(Tensor): Generated single tensor (B, 1, T).
            y(Tensor): Groundtruth single tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        mel_hat = self.mel_spectrogram(y_hat)
        mel = self.mel_spectrogram(y)
        # L1 distance between the two log-mel spectrograms.
        return F.l1_loss(mel_hat, mel)
|
|
|
|
|
|
class FeatureMatchLoss(nn.Layer):
    """Feature matching loss module."""

    def __init__(
            self,
            average_by_layers=True,
            average_by_discriminators=True,
            include_final_outputs=False, ):
        """Initialize FeatureMatchLoss module.

        Args:
            average_by_layers (bool, optional): Whether to divide each
                discriminator's loss by its number of compared layers.
            average_by_discriminators (bool, optional): Whether to divide the
                total by the number of discriminators.
            include_final_outputs (bool, optional): Whether the last entry of
                each feature list takes part in the loss.
        """
        super().__init__()
        self.average_by_layers = average_by_layers
        self.average_by_discriminators = average_by_discriminators
        self.include_final_outputs = include_final_outputs

    def forward(self, feats_hat, feats):
        """Calculate feature matching loss.

        Args:
            feats_hat(list): List of list of discriminator outputs
                calculated from generator outputs.
            feats(list): List of list of discriminator outputs
                calculated from groundtruth.

        Returns:
            Tensor: Feature matching loss value.

        """
        total = 0.0
        for disc_idx, (fake_feats, real_feats) in enumerate(
                zip(feats_hat, feats)):
            if not self.include_final_outputs:
                # Drop the last entry of each list (the final output).
                fake_feats = fake_feats[:-1]
                real_feats = real_feats[:-1]

            per_disc = 0.0
            for layer_idx, (fake_feat, real_feat) in enumerate(
                    zip(fake_feats, real_feats)):
                # The groundtruth features are detached: no gradient
                # flows through them.
                per_disc += F.l1_loss(fake_feat, real_feat.detach())
            if self.average_by_layers:
                per_disc /= layer_idx + 1
            total += per_disc

        if self.average_by_discriminators:
            total /= disc_idx + 1

        return total
|
|
|
|
|
|
# loss for VITS
|
|
class KLDivergenceLoss(nn.Layer):
    """KL divergence loss."""

    def forward(
            self,
            z_p: paddle.Tensor,
            logs_q: paddle.Tensor,
            m_p: paddle.Tensor,
            logs_p: paddle.Tensor,
            z_mask: paddle.Tensor, ) -> paddle.Tensor:
        """Calculate KL divergence loss.

        Args:
            z_p (Tensor): Flow hidden representation (B, H, T_feats).
            logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
            m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
            logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
            z_mask (Tensor): Mask tensor (B, 1, T_feats).

        Returns:
            Tensor: KL divergence loss.

        """
        # All inputs are computed in float32.
        z_p, logs_q, m_p, logs_p, z_mask = (paddle.cast(t, 'float32')
                                            for t in (z_p, logs_q, m_p,
                                                      logs_p, z_mask))

        scale_term = logs_p - logs_q - 0.5
        mean_term = 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p)
        kl = scale_term + mean_term

        # Masked sum normalized by the sum of the mask.
        masked_total = paddle.sum(kl * z_mask)
        return masked_total / paddle.sum(z_mask)
|
|
|
|
|
|
# loss for ERNIE SAT
|
|
class MLMLoss(nn.Layer):
    """Masked-reconstruction loss for speech with an optional text term."""

    def __init__(self,
                 lsm_weight: float=0.1,
                 ignore_id: int=-1,
                 text_masking: bool=False,
                 odim: int=None,
                 vocab_size: int=None):
        """Initialize MLMLoss module.

        Args:
            lsm_weight (float, optional): Values above 50 select an MSE
                criterion; otherwise an element-wise L1 criterion is used.
            ignore_id (int, optional): Target index ignored by the text
                cross-entropy loss.
            text_masking (bool, optional): Whether to also compute the text
                loss in ``forward``.
            odim (int, optional): Feature dimension of the speech tensors.
                ``None`` means it is read from the last axis of ``speech``.
            vocab_size (int, optional): Size of the text vocabulary. ``None``
                means it is read from the last axis of ``text_outs``.
        """
        super().__init__()
        if text_masking:
            self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
        if lsm_weight > 50:
            self.l1_loss_func = nn.MSELoss()
        else:
            self.l1_loss_func = nn.L1Loss(reduction='none')
        self.text_masking = text_masking
        # forward() reads both attributes, so they must always exist.
        self.odim = odim
        self.vocab_size = vocab_size

    def forward(self,
                speech: paddle.Tensor,
                before_outs: paddle.Tensor,
                after_outs: paddle.Tensor,
                masked_pos: paddle.Tensor,
                text: paddle.Tensor=None,
                text_outs: paddle.Tensor=None,
                text_masked_pos: paddle.Tensor=None):
        """Calculate the masked reconstruction loss(es).

        Args:
            speech (Tensor): Target speech features; the last axis is the
                feature dimension.
            before_outs (Tensor): Predicted features, same layout as
                ``speech``.
            after_outs (Tensor): Refined predicted features, or ``None`` to
                skip this term.
            masked_pos (Tensor): Positions with a value > 0 contribute to
                the loss.
            text (Tensor, optional): Text targets; required when
                ``text_masking`` is enabled.
            text_outs (Tensor, optional): Text predictions; the last axis is
                the vocabulary dimension.
            text_masked_pos (Tensor, optional): Weights for the text loss.

        Returns:
            Tensor: Speech loss, or a ``(speech loss, text loss)`` tuple when
            ``text_masking`` is enabled.
        """
        odim = self.odim if self.odim is not None else speech.shape[-1]

        xs_pad = paddle.reshape(speech, (-1, odim))
        mlm_loss_pos = masked_pos > 0
        loss = paddle.sum(
            self.l1_loss_func(paddle.reshape(before_outs, (-1, odim)), xs_pad),
            axis=-1)
        if after_outs is not None:
            loss += paddle.sum(
                self.l1_loss_func(
                    paddle.reshape(after_outs, (-1, odim)), xs_pad),
                axis=-1)

        # Cast the boolean mask explicitly instead of relying on implicit
        # bool/float promotion in the arithmetic below.
        speech_mask = paddle.cast(mlm_loss_pos, loss.dtype)
        # NOTE: the epsilon is added per element before the sum, as in the
        # original formulation.
        loss_mlm = paddle.sum(loss * paddle.reshape(
            speech_mask, [-1])) / paddle.sum(speech_mask + 1e-10)

        if self.text_masking:
            vocab_size = (self.vocab_size if self.vocab_size is not None else
                          text_outs.shape[-1])
            # paddle.reshape needs a list/tuple shape; ``(-1)`` is just -1.
            text_loss = self.text_mlm_loss(
                paddle.reshape(text_outs, (-1, vocab_size)),
                paddle.reshape(text, [-1]))
            text_mask = paddle.cast(text_masked_pos, text_loss.dtype)
            loss_text = paddle.sum(text_loss * paddle.reshape(
                text_mask, [-1])) / paddle.sum(text_mask + 1e-10)

            return loss_mlm, loss_text

        return loss_mlm
|