PaddleSpeech/paddlespeech/t2s/modules/losses.py


# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import librosa
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from scipy import signal
from scipy.stats import betabinom
from typeguard import typechecked
from paddlespeech.audiotools.core.audio_signal import AudioSignal
from paddlespeech.audiotools.core.audio_signal import STFTParams
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301
)
# Losses for WaveRNN
def log_sum_exp(x):
""" numerically stable log_sum_exp implementation that prevents overflow """
# TF ordering
axis = len(x.shape) - 1
m = paddle.max(x, axis=axis)
m2 = paddle.max(x, axis=axis, keepdim=True)
return m + paddle.log(paddle.sum(paddle.exp(x - m2), axis=axis))
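# The identity used above is log(sum_i exp(x_i)) = m + log(sum_i exp(x_i - m))
# with m = max_i x_i, so the exponentials never overflow. Illustrative check
# (not part of the module):
#   x = paddle.to_tensor([[1000.0, 1000.0]])
#   log_sum_exp(x)  # ~= 1000.6931, whereas paddle.log(paddle.exp(x).sum(-1)) overflows to inf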
# It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat,
y,
num_classes=65536,
log_scale_min=None,
reduce=True):
if log_scale_min is None:
log_scale_min = float(np.log(1e-14))
y_hat = y_hat.transpose([0, 2, 1])
assert y_hat.dim() == 3
assert y_hat.shape[1] % 3 == 0
nr_mix = y_hat.shape[1] // 3
# (B x T x C)
y_hat = y_hat.transpose([0, 2, 1])
# unpack parameters. (B, T, num_mixtures) x 3
logit_probs = y_hat[:, :, :nr_mix]
means = y_hat[:, :, nr_mix:2 * nr_mix]
log_scales = paddle.clip(
y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
# B x T x 1 -> B x T x num_mixtures
y = y.expand_as(means)
centered_y = paddle.cast(y, dtype=paddle.get_default_dtype()) - means
inv_stdv = paddle.exp(-log_scales)
plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
cdf_plus = F.sigmoid(plus_in)
min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
cdf_min = F.sigmoid(min_in)
# log probability for edge case of 0 (before scaling)
# equivalent: torch.log(F.sigmoid(plus_in))
# softplus: log(1+ e^{-x})
log_cdf_plus = plus_in - F.softplus(plus_in)
# log probability for edge case of 255 (before scaling)
# equivalent: (1 - F.sigmoid(min_in)).log()
log_one_minus_cdf_min = -F.softplus(min_in)
# probability for all other cases
cdf_delta = cdf_plus - cdf_min
mid_in = inv_stdv * centered_y
# log probability in the center of the bin, to be used in extreme cases
# (not actually used in our code)
log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
# TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
# for num_classes=65536 case? 1e-7? not sure..
inner_inner_cond = cdf_delta > 1e-5
inner_inner_cond = paddle.cast(
inner_inner_cond, dtype=paddle.get_default_dtype())
# inner_inner_out = inner_inner_cond * \
# paddle.log(paddle.clip(cdf_delta, min=1e-12)) + \
# (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
inner_inner_out = inner_inner_cond * paddle.log(
paddle.clip(cdf_delta, min=1e-12)) + (1. - inner_inner_cond) * (
log_pdf_mid - np.log((num_classes - 1) / 2))
inner_cond = y > 0.999
inner_cond = paddle.cast(inner_cond, dtype=paddle.get_default_dtype())
inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond
) * inner_inner_out
cond = y < -0.999
cond = paddle.cast(cond, dtype=paddle.get_default_dtype())
log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
log_probs = log_probs + F.log_softmax(logit_probs, -1)
if reduce:
return -paddle.mean(log_sum_exp(log_probs))
else:
return -log_sum_exp(log_probs).unsqueeze(-1)
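# Illustrative usage sketch (shapes are assumptions, not taken from a recipe):
# with nr_mix = 10 the network emits 3 * 10 values per timestep
# (mixture logits, means, log scales), and targets lie in [-1, 1].
#   y_hat = paddle.randn([2, 100, 30])                  # (B, T, 3 * nr_mix)
#   y = paddle.uniform([2, 100, 1], min=-1.0, max=1.0)  # (B, T, 1)
#   loss = discretized_mix_logistic_loss(y_hat, y)      # scalar when reduce=True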
def sample_from_discretized_mix_logistic(y, log_scale_min=None):
"""
Sample from discretized mixture of logistic distributions
Args:
y(Tensor): (B, C, T)
log_scale_min(float, optional): (Default value = None)
Returns:
Tensor: sample in range of [-1, 1].
"""
if log_scale_min is None:
log_scale_min = float(np.log(1e-14))
assert y.shape[1] % 3 == 0
nr_mix = y.shape[1] // 3
# (B, T, C)
y = y.transpose([0, 2, 1])
logit_probs = y[:, :, :nr_mix]
# sample mixture indicator from softmax
temp = paddle.uniform(
logit_probs.shape, dtype=logit_probs.dtype, min=1e-5, max=1.0 - 1e-5)
temp = logit_probs - paddle.log(-paddle.log(temp))
argmax = paddle.argmax(temp, axis=-1)
# (B, T) -> (B, T, nr_mix)
one_hot = F.one_hot(argmax, nr_mix)
one_hot = paddle.cast(one_hot, dtype=paddle.get_default_dtype())
# select logistic parameters
means = paddle.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1)
log_scales = paddle.clip(
paddle.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1),
min=log_scale_min)
# sample from logistic & clip to interval
# we don't actually round to the nearest 8bit value when sampling
u = paddle.uniform(means.shape, min=1e-5, max=1.0 - 1e-5)
x = means + paddle.exp(log_scales) * (paddle.log(u) - paddle.log(1. - u))
x = paddle.clip(x, min=-1., max=-1.)
return x
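# Illustrative usage sketch (assumed shapes): sampling expects channel-first
# input (B, C, T) with C = 3 * nr_mix and returns a waveform in [-1, 1].
#   y = paddle.randn([2, 30, 100])                  # (B, 3 * nr_mix, T)
#   wav = sample_from_discretized_mix_logistic(y)   # (B, T)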
# Loss for Tacotron2
class GuidedAttentionLoss(nn.Layer):
"""Guided attention loss function module.
This module calculates the guided attention loss described
in `Efficiently Trainable Text-to-Speech System Based
on Deep Convolutional Networks with Guided Attention`_,
which forces the attention to be diagonal.
.. _`Efficiently Trainable Text-to-Speech System
Based on Deep Convolutional Networks with Guided Attention`:
https://arxiv.org/abs/1710.08969
"""
def __init__(self, sigma=0.4, alpha=1.0, reset_always=True):
"""Initialize guided attention loss module.
Args:
sigma (float, optional): Standard deviation to control how close attention is to a diagonal.
alpha (float, optional): Scaling coefficient (lambda).
reset_always (bool, optional): Whether to always reset masks.
"""
super().__init__()
self.sigma = sigma
self.alpha = alpha
self.reset_always = reset_always
self.guided_attn_masks = None
self.masks = None
def _reset_masks(self):
self.guided_attn_masks = None
self.masks = None
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Args:
att_ws(Tensor): Batch of attention weights (B, T_max_out, T_max_in).
ilens(Tensor(int64)): Batch of input lengths (B,).
olens(Tensor(int64)): Batch of output lengths (B,).
Returns:
Tensor: Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = self._make_guided_attention_masks(ilens,
olens)
if self.masks is None:
self.masks = self._make_masks(ilens, olens)
losses = self.guided_attn_masks * att_ws
loss = paddle.mean(
losses.masked_select(self.masks.broadcast_to(losses.shape)))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
def _make_guided_attention_masks(self, ilens, olens):
n_batches = len(ilens)
max_ilen = max(ilens)
max_olen = max(olens)
guided_attn_masks = paddle.zeros((n_batches, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
guided_attn_masks[idx, :olen, :
ilen] = self._make_guided_attention_mask(
ilen, olen, self.sigma)
return guided_attn_masks
@staticmethod
def _make_guided_attention_mask(ilen, olen, sigma):
"""Make guided attention mask.
Examples
----------
>>> guided_attn_mask = _make_guided_attention_mask(5, 5, 0.4)
>>> guided_attn_mask.shape
[5, 5]
>>> guided_attn_mask
tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647],
[0.1175, 0.0000, 0.1175, 0.3935, 0.6753],
[0.3935, 0.1175, 0.0000, 0.1175, 0.3935],
[0.6753, 0.3935, 0.1175, 0.0000, 0.1175],
[0.8647, 0.6753, 0.3935, 0.1175, 0.0000]])
>>> guided_attn_mask = _make_guided_attention_mask(3, 6, 0.4)
>>> guided_attn_mask.shape
[6, 3]
>>> guided_attn_mask
tensor([[0.0000, 0.2934, 0.7506],
[0.0831, 0.0831, 0.5422],
[0.2934, 0.0000, 0.2934],
[0.5422, 0.0831, 0.0831],
[0.7506, 0.2934, 0.0000],
[0.8858, 0.5422, 0.0831]])
"""
grid_x, grid_y = paddle.meshgrid(
paddle.arange(olen), paddle.arange(ilen))
grid_x = grid_x.cast(dtype=paddle.float32)
grid_y = grid_y.cast(dtype=paddle.float32)
return 1.0 - paddle.exp(-(
(grid_y / ilen - grid_x / olen)**2) / (2 * (sigma**2)))
@staticmethod
def _make_masks(ilens, olens):
"""Make masks indicating non-padded part.
Args:
ilens(Tensor(int64) or List):
Batch of lengths (B,).
olens(Tensor(int64) or List):
Batch of lengths (B,).
Returns:
Tensor: Mask tensor indicating non-padded part.
Examples:
>>> ilens, olens = [5, 2], [8, 5]
>>> _make_masks(ilens, olens)
tensor([[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]],
[[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]], dtype=paddle.uint8)
"""
# (B, T_in)
in_masks = make_non_pad_mask(ilens)
# (B, T_out)
out_masks = make_non_pad_mask(olens)
# (B, T_out, T_in)
return paddle.logical_and(
out_masks.unsqueeze(-1), in_masks.unsqueeze(-2))
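# Illustrative usage sketch (shapes are assumptions): decoder attention weights
# are penalized when they stray from the diagonal.
#   criterion = GuidedAttentionLoss(sigma=0.4, alpha=1.0)
#   att_ws = paddle.rand([2, 50, 20])                # (B, T_max_out, T_max_in)
#   ilens = paddle.to_tensor([20, 15], dtype="int64")
#   olens = paddle.to_tensor([50, 40], dtype="int64")
#   attn_loss = criterion(att_ws, ilens, olens)      # scalar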
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
"""Guided attention loss function module for multi head attention.
Args:
sigma (float, optional): Standard deviation to control
how close attention is to a diagonal.
alpha (float, optional): Scaling coefficient (lambda).
reset_always (bool, optional): Whether to always reset masks.
"""
def forward(self, att_ws, ilens, olens):
"""Calculate forward propagation.
Args:
att_ws(Tensor):
Batch of multi head attention weights (B, H, T_max_out, T_max_in).
ilens(Tensor):
Batch of input lengths (B,).
olens(Tensor):
Batch of output lengths (B,).
Returns:
Tensor: Guided attention loss value.
"""
if self.guided_attn_masks is None:
self.guided_attn_masks = (
self._make_guided_attention_masks(ilens, olens).unsqueeze(1))
if self.masks is None:
self.masks = self._make_masks(ilens, olens).unsqueeze(1)
losses = self.guided_attn_masks * att_ws
loss = paddle.mean(
losses.masked_select(self.masks.broadcast_to(losses.shape)))
if self.reset_always:
self._reset_masks()
return self.alpha * loss
class Tacotron2Loss(nn.Layer):
"""Loss function module for Tacotron2."""
def __init__(self,
use_masking=True,
use_weighted_masking=False,
bce_pos_weight=20.0):
"""Initialize Tactoron2 loss module.
Args:
use_masking (bool):
Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool):
Whether to apply weighted masking in loss calculation.
bce_pos_weight (float):
Weight of positive sample of stop token.
"""
super().__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = nn.L1Loss(reduction=reduction)
self.mse_criterion = nn.MSELoss(reduction=reduction)
self.bce_criterion = nn.BCEWithLogitsLoss(
reduction=reduction, pos_weight=paddle.to_tensor(bce_pos_weight))
def forward(self, after_outs, before_outs, logits, ys, stop_labels, olens):
"""Calculate forward propagation.
Args:
after_outs(Tensor):
Batch of outputs after postnets (B, Lmax, odim).
before_outs(Tensor):
Batch of outputs before postnets (B, Lmax, odim).
logits(Tensor):
Batch of stop logits (B, Lmax).
ys(Tensor):
Batch of padded target features (B, Lmax, odim).
stop_labels(Tensor(int64)):
Batch of the sequences of stop token labels (B, Lmax).
olens(Tensor(int64)):
Batch of the lengths of each target (B,).
Returns:
Tensor:
L1 loss value.
Tensor:
Mean square error loss value.
Tensor:
Binary cross entropy loss value.
"""
# make mask and apply it
if self.use_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1)
ys = ys.masked_select(masks.broadcast_to(ys.shape))
after_outs = after_outs.masked_select(
masks.broadcast_to(after_outs.shape))
before_outs = before_outs.masked_select(
masks.broadcast_to(before_outs.shape))
stop_labels = stop_labels.masked_select(
masks[:, :, 0].broadcast_to(stop_labels.shape))
logits = logits.masked_select(
masks[:, :, 0].broadcast_to(logits.shape))
# calculate loss
l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(
before_outs, ys)
mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
before_outs, ys)
bce_loss = self.bce_criterion(logits, stop_labels)
# make weighted mask and apply it
if self.use_weighted_masking:
masks = make_non_pad_mask(olens).unsqueeze(-1)
weights = masks.float() / masks.sum(axis=1, keepdim=True).float()
out_weights = weights.divide(
paddle.shape(ys)[0] * paddle.shape(ys)[2])
logit_weights = weights.divide(paddle.shape(ys)[0])
# apply weight
l1_loss = l1_loss.multiply(out_weights)
l1_loss = l1_loss.masked_select(masks.broadcast_to(l1_loss.shape)).sum()
mse_loss = mse_loss.multiply(out_weights)
mse_loss = mse_loss.masked_select(
masks.broadcast_to(mse_loss.shape)).sum()
bce_loss = bce_loss.multiply(logit_weights.squeeze(-1))
bce_loss = bce_loss.masked_select(
masks.squeeze(-1).broadcast_to(bce_loss.shape)).sum()
return l1_loss, mse_loss, bce_loss
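# Illustrative usage sketch (tensor shapes are assumptions):
#   criterion = Tacotron2Loss(use_masking=True, bce_pos_weight=20.0)
#   B, Lmax, odim = 2, 50, 80
#   after_outs = paddle.randn([B, Lmax, odim])
#   before_outs = paddle.randn([B, Lmax, odim])
#   logits = paddle.randn([B, Lmax])
#   ys = paddle.randn([B, Lmax, odim])
#   stop_labels = paddle.zeros([B, Lmax])
#   olens = paddle.to_tensor([50, 40], dtype="int64")
#   l1_loss, mse_loss, bce_loss = criterion(
#       after_outs, before_outs, logits, ys, stop_labels, olens)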
# Losses for GAN Vocoder
def stft(x,
fft_size,
hop_length=None,
win_length=None,
window='hann',
center=True,
pad_mode='reflect'):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x(Tensor):
Input signal tensor (B, T).
fft_size(int):
FFT size.
hop_length(int, optional):
Hop size. (Default value = None)
win_length(int, optional):
Window length. (Default value = None)
window(str, optional):
Name of window function, see `scipy.signal.get_window` for more details. Defaults to "hann".
center(bool, optional):
Whether to pad `x` so that the :math:`t \times hop\\_length`-th sample lies at the center of the :math:`t`-th frame. Default: `True`.
pad_mode(str, optional):
Padding mode for the centered STFT. (Default value = 'reflect')
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
# calculate window
window = signal.get_window(window, win_length, fftbins=True)
window = paddle.to_tensor(window, dtype=x.dtype)
x_stft = paddle.signal.stft(
x,
fft_size,
hop_length,
win_length,
window=window,
center=center,
pad_mode=pad_mode)
real = x_stft.real()
imag = x_stft.imag()
return paddle.sqrt(paddle.clip(real**2 + imag**2, min=1e-7)).transpose(
[0, 2, 1])
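# Illustrative usage sketch (parameter values are assumptions):
#   x = paddle.randn([2, 16000])    # (B, T) waveform
#   mag = stft(x, fft_size=1024, hop_length=256, win_length=1024)
#   # mag: (B, #frames, 1024 // 2 + 1) magnitude spectrogram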
class SpectralConvergenceLoss(nn.Layer):
"""Spectral convergence loss module."""
def __init__(self):
"""Initilize spectral convergence loss module."""
super().__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor):
Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor):
Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Spectral convergence loss value.
"""
return paddle.norm(
y_mag - x_mag, p="fro") / paddle.clip(
paddle.norm(y_mag, p="fro"), min=1e-10)
class LogSTFTMagnitudeLoss(nn.Layer):
"""Log STFT magnitude loss module."""
def __init__(self, epsilon=1e-7):
"""Initilize los STFT magnitude loss module."""
super().__init__()
self.epsilon = epsilon
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor):
Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor):
Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Log STFT magnitude loss value.
"""
return F.l1_loss(
paddle.log(paddle.clip(y_mag, min=self.epsilon)),
paddle.log(paddle.clip(x_mag, min=self.epsilon)))
class STFTLoss(nn.Layer):
"""STFT loss module."""
def __init__(self,
fft_size=1024,
shift_size=120,
win_length=600,
window="hann"):
"""Initialize STFT loss module."""
super().__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = window
self.spectral_convergence_loss = SpectralConvergenceLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor):
Predicted signal (B, T).
y (Tensor):
Groundtruth signal (B, T).
Returns:
Tensor:
Spectral convergence loss value.
Tensor:
Log STFT magnitude loss value.
"""
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
self.window)
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
self.window)
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
return sc_loss, mag_loss
class MultiResolutionSTFTLoss(nn.Layer):
"""Multi resolution STFT loss module."""
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann", ):
"""Initialize Multi resolution STFT loss module.
Args:
fft_sizes (list):
List of FFT sizes.
hop_sizes (list):
List of hop sizes.
win_lengths (list):
List of window lengths.
window (str):
Window function type.
"""
super().__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
self.stft_losses = nn.LayerList()
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
self.stft_losses.append(STFTLoss(fs, ss, wl, window))
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor):
Predicted signal (B, T) or (B, #subband, T).
y (Tensor):
Groundtruth signal (B, T) or (B, #subband, T).
Returns:
Tensor:
Multi resolution spectral convergence loss value.
Tensor:
Multi resolution log STFT magnitude loss value.
"""
if len(x.shape) == 3:
# (B, C, T) -> (B x C, T)
x = x.reshape([-1, x.shape[2]])
# (B, C, T) -> (B x C, T)
y = y.reshape([-1, y.shape[2]])
sc_loss = 0.0
mag_loss = 0.0
for f in self.stft_losses:
sc_l, mag_l = f(x, y)
sc_loss += sc_l
mag_loss += mag_l
sc_loss /= len(self.stft_losses)
mag_loss /= len(self.stft_losses)
return sc_loss, mag_loss
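# Illustrative usage sketch (default resolutions 1024/2048/512 as above):
#   criterion = MultiResolutionSTFTLoss()
#   y_hat = paddle.randn([2, 8192])          # generated waveform (B, T)
#   y = paddle.randn([2, 8192])              # groundtruth waveform (B, T)
#   sc_loss, mag_loss = criterion(y_hat, y)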
class GeneratorAdversarialLoss(nn.Layer):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type="mse", ):
"""Initialize GeneratorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse", "hinge"], f"{loss_type} is not supported."
if loss_type == "mse":
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(self, outputs):
"""Calcualate generator adversarial loss.
Args:
outputs (Tensor or List):
Discriminator outputs or list of discriminator outputs.
Returns:
Tensor:
Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
if isinstance(outputs_, (tuple, list)):
# case including feature maps
outputs_ = outputs_[-1]
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, paddle.ones_like(x))
def _hinge_loss(self, x):
return -x.mean()
class DiscriminatorAdversarialLoss(nn.Layer):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type="mse", ):
"""Initialize DiscriminatorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ["mse"], f"{loss_type} is not supported."
if loss_type == "mse":
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
def forward(self, outputs_hat, outputs):
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Tensor or list):
Discriminator outputs or list of discriminator outputs calculated from generator outputs.
outputs (Tensor or list):
Discriminator outputs or list of discriminator outputs calculated from groundtruth.
Returns:
Tensor:
Discriminator real loss value.
Tensor:
Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_,
outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x):
return F.mse_loss(x, paddle.ones_like(x))
def _mse_fake_loss(self, x):
return F.mse_loss(x, paddle.zeros_like(x))
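# Illustrative usage sketch for the two adversarial losses above (discriminator
# outputs are assumed to be a list of per-discriminator tensors):
#   gen_adv_criterion = GeneratorAdversarialLoss(loss_type="mse")
#   dis_adv_criterion = DiscriminatorAdversarialLoss(loss_type="mse")
#   outputs_hat = [paddle.rand([2, 1, 100]), paddle.rand([2, 1, 50])]  # D(G(z))
#   outputs = [paddle.rand([2, 1, 100]), paddle.rand([2, 1, 50])]      # D(y)
#   adv_loss = gen_adv_criterion(outputs_hat)
#   real_loss, fake_loss = dis_adv_criterion(outputs_hat, outputs)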
# Losses for SpeedySpeech
# Structural Similarity Index Measure (SSIM)
def gaussian(window_size, sigma):
gauss = paddle.to_tensor([
math.exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
for x in range(window_size)
])
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = paddle.matmul(_1D_window, paddle.transpose(
_1D_window, [1, 0])).unsqueeze([0, 1])
window = paddle.expand(_2D_window, [channel, 1, window_size, window_size])
return window
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(
img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(
img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(
img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) \
/ ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
def ssim(img1, img2, window_size=11, size_average=True):
(_, channel, _, _) = img1.shape
window = create_window(window_size, channel)
return _ssim(img1, img2, window, window_size, channel, size_average)
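# Illustrative usage sketch (shapes are assumptions): here SSIM compares
# mel-spectrograms treated as single-channel images.
#   img1 = paddle.rand([2, 1, 80, 200])    # (B, C, #mels, #frames)
#   img2 = paddle.rand([2, 1, 80, 200])
#   score = ssim(img1, img2, window_size=11)   # scalar similarity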
def weighted_mean(input, weight):
"""Weighted mean. It can also be used as masked mean.
Args:
input(Tensor): The input tensor.
weight(Tensor): The weight tensor with broadcastable shape with the input.
Returns:
Tensor: Weighted mean tensor with the same dtype as input. shape=(1,)
"""
weight = paddle.cast(weight, input.dtype)
# paddle.Tensor.size is different from torch.Tensor.size() and has been overridden in s2t.__init__
broadcast_ratio = input.numel() / weight.numel()
return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_ratio)
def masked_l1_loss(prediction, target, mask):
"""Compute maksed L1 loss.
Args:
prediction(Tensor):
The prediction.
target(Tensor):
The target. The shape should be broadcastable to ``prediction``.
mask(Tensor):
The mask. The shape should be broadcastable to the broadcasted shape of
``prediction`` and ``target``.
Returns:
Tensor: The masked L1 loss. shape=(1,)
"""
abs_error = F.l1_loss(prediction, target, reduction='none')
loss = weighted_mean(abs_error, mask)
return loss
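# Illustrative usage sketch (shapes are assumptions): only non-padded frames
# contribute to the loss.
#   prediction = paddle.randn([2, 50, 80])
#   target = paddle.randn([2, 50, 80])
#   mask = paddle.ones([2, 50, 1])    # broadcastable (B, T, 1) mask
#   loss = masked_l1_loss(prediction, target, mask)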
class MelSpectrogram(nn.Layer):
"""Calculate Mel-spectrogram."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window="hann",
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0, ):
"""Initialize MelSpectrogram module."""
super().__init__()
self.fft_size = fft_size
if win_length is None:
self.win_length = fft_size
else:
self.win_length = win_length
self.hop_size = hop_size
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(signal.windows, f"{window}"):
raise ValueError(f"{window} window is not implemented")
self.window = window
self.eps = eps
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
melmat = librosa.filters.mel(
sr=fs,
n_fft=fft_size,
n_mels=num_mels,
fmin=fmin,
fmax=fmax, )
self.melmat = paddle.to_tensor(melmat.T)
self.stft_params = {
"n_fft": self.fft_size,
"win_length": self.win_length,
"hop_length": self.hop_size,
"center": self.center,
"normalized": self.normalized,
"onesided": self.onesided,
}
self.log_base = log_base
if self.log_base is None:
self.log = paddle.log
elif self.log_base == 2.0:
self.log = paddle.log2
elif self.log_base == 10.0:
self.log = paddle.log10
else:
raise ValueError(f"log_base: {log_base} is not supported.")
def forward(self, x):
"""Calculate Mel-spectrogram.
Args:
x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
Returns:
Tensor: Mel-spectrogram (B, #mels, #frames).
"""
if len(x.shape) == 3:
# (B, C, T) -> (B*C, T)
x = x.reshape([-1, paddle.shape(x)[2]])
if self.window is not None:
# calculate window
window = signal.get_window(
self.window, self.win_length, fftbins=True)
window = paddle.to_tensor(window, dtype=x.dtype)
else:
window = None
x_stft = paddle.signal.stft(x, window=window, **self.stft_params)
real = x_stft.real()
imag = x_stft.imag()
# (B, #freqs, #frames) -> (B, #frames, #freqs)
real = real.transpose([0, 2, 1])
imag = imag.transpose([0, 2, 1])
x_power = real**2 + imag**2
x_amp = paddle.sqrt(paddle.clip(x_power, min=self.eps))
x_mel = paddle.matmul(x_amp, self.melmat)
x_mel = paddle.clip(x_mel, min=self.eps)
return self.log(x_mel).transpose([0, 2, 1])
class MelSpectrogramLoss(nn.Layer):
"""Mel-spectrogram loss."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window="hann",
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0, ):
"""Initialize Mel-spectrogram loss."""
super().__init__()
self.mel_spectrogram = MelSpectrogram(
fs=fs,
fft_size=fft_size,
hop_size=hop_size,
win_length=win_length,
window=window,
num_mels=num_mels,
fmin=fmin,
fmax=fmax,
center=center,
normalized=normalized,
onesided=onesided,
eps=eps,
log_base=log_base, )
def forward(self, y_hat, y):
"""Calculate Mel-spectrogram loss.
Args:
y_hat(Tensor):
Generated signal tensor (B, 1, T).
y(Tensor):
Groundtruth signal tensor (B, 1, T).
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_hat = self.mel_spectrogram(y_hat)
mel = self.mel_spectrogram(y)
mel_loss = F.l1_loss(mel_hat, mel)
return mel_loss
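# Illustrative usage sketch (default analysis parameters as above):
#   criterion = MelSpectrogramLoss(fs=22050, fft_size=1024, hop_size=256)
#   y_hat = paddle.randn([2, 1, 8192])    # generated waveform (B, 1, T)
#   y = paddle.randn([2, 1, 8192])        # groundtruth waveform (B, 1, T)
#   mel_loss = criterion(y_hat, y)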
class FeatureMatchLoss(nn.Layer):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers=True,
average_by_discriminators=True,
include_final_outputs=False, ):
"""Initialize FeatureMatchLoss module."""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
self.include_final_outputs = include_final_outputs
def forward(self, feats_hat, feats):
"""Calcualate feature matching loss.
Args:
feats_hat(list):
List of lists of discriminator outputs
calculated from generator outputs.
feats(list):
List of lists of discriminator outputs
calculated from groundtruth.
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
if not self.include_final_outputs:
feats_hat_ = feats_hat_[:-1]
feats_ = feats_[:-1]
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
# loss for VITS
class KLDivergenceLoss(nn.Layer):
"""KL divergence loss."""
def forward(
self,
z_p: paddle.Tensor,
logs_q: paddle.Tensor,
m_p: paddle.Tensor,
logs_p: paddle.Tensor,
z_mask: paddle.Tensor, ) -> paddle.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor):
Flow hidden representation (B, H, T_feats).
logs_q (Tensor):
Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor):
Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor):
Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor):
Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = paddle.cast(z_p, 'float32')
logs_q = paddle.cast(logs_q, 'float32')
m_p = paddle.cast(m_p, 'float32')
logs_p = paddle.cast(logs_p, 'float32')
z_mask = paddle.cast(z_mask, 'float32')
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * paddle.exp(-2.0 * logs_p)
kl = paddle.sum(kl * z_mask)
loss = kl / paddle.sum(z_mask)
return loss
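# Illustrative usage sketch (shapes are assumptions, H = latent channels):
#   criterion = KLDivergenceLoss()
#   B, H, T_feats = 2, 192, 100
#   z_p = paddle.randn([B, H, T_feats])
#   logs_q = paddle.randn([B, H, T_feats])
#   m_p = paddle.randn([B, H, T_feats])
#   logs_p = paddle.randn([B, H, T_feats])
#   z_mask = paddle.ones([B, 1, T_feats])
#   kl_loss = criterion(z_p, logs_q, m_p, logs_p, z_mask)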
# loss for ERNIE SAT
class MLMLoss(nn.Layer):
def __init__(self,
odim: int,
vocab_size: int=0,
lsm_weight: float=0.1,
ignore_id: int=-1,
text_masking: bool=False):
super().__init__()
if text_masking:
self.text_mlm_loss = nn.CrossEntropyLoss(ignore_index=ignore_id)
if lsm_weight > 50:
self.l1_loss_func = nn.MSELoss()
else:
self.l1_loss_func = nn.L1Loss(reduction='none')
self.text_masking = text_masking
self.odim = odim
self.vocab_size = vocab_size
def forward(
self,
speech: paddle.Tensor,
before_outs: paddle.Tensor,
after_outs: paddle.Tensor,
masked_pos: paddle.Tensor,
# for text_loss when text_masking == True
text: paddle.Tensor=None,
text_outs: paddle.Tensor=None,
text_masked_pos: paddle.Tensor=None):
xs_pad = speech
mlm_loss_pos = masked_pos > 0
loss = paddle.sum(
self.l1_loss_func(
paddle.reshape(before_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
if after_outs is not None:
loss += paddle.sum(
self.l1_loss_func(
paddle.reshape(after_outs, (-1, self.odim)),
paddle.reshape(xs_pad, (-1, self.odim))),
axis=-1)
mlm_loss_pos = (mlm_loss_pos).astype(loss.dtype)
mlm_loss = paddle.sum((loss * paddle.reshape(
mlm_loss_pos,
[-1]).astype(loss.dtype))) / paddle.sum((mlm_loss_pos) + 1e-10)
text_mlm_loss = None
if self.text_masking:
assert text is not None
assert text_outs is not None
assert text_masked_pos is not None
text_outs = paddle.reshape(text_outs, [-1, self.vocab_size])
text = paddle.reshape(text, [-1])
text_mlm_loss = self.text_mlm_loss(text_outs, text)
text_masked_pos_reshape = paddle.reshape(text_masked_pos, [-1])
text_mlm_loss = paddle.sum(
text_mlm_loss *
text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10)
return mlm_loss, text_mlm_loss
class VarianceLoss(nn.Layer):
@typechecked
def __init__(self, use_masking: bool=True,
use_weighted_masking: bool=False):
"""Initialize JETS variance loss module.
Args:
use_masking (bool): Whether to apply masking for padded part in loss
calculation.
use_weighted_masking (bool): Whether to weighted masking in loss
calculation.
"""
super().__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.mse_criterion = nn.MSELoss(reduction=reduction)
self.duration_criterion = DurationPredictorLoss(reduction=reduction)
def forward(
self,
d_outs: paddle.Tensor,
ds: paddle.Tensor,
p_outs: paddle.Tensor,
ps: paddle.Tensor,
e_outs: paddle.Tensor,
es: paddle.Tensor,
ilens: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
d_outs (LongTensor): Batch of outputs of duration predictor (B, T_text).
ds (LongTensor): Batch of durations (B, T_text).
p_outs (Tensor): Batch of outputs of pitch predictor (B, T_text, 1).
ps (Tensor): Batch of target token-averaged pitch (B, T_text, 1).
e_outs (Tensor): Batch of outputs of energy predictor (B, T_text, 1).
es (Tensor): Batch of target token-averaged energy (B, T_text, 1).
ilens (LongTensor): Batch of the lengths of each input (B,).
Returns:
Tensor: Duration predictor loss value.
Tensor: Pitch predictor loss value.
Tensor: Energy predictor loss value.
"""
# apply mask to remove padded part
if self.use_masking:
duration_masks = paddle.to_tensor(
make_non_pad_mask(ilens), place=ds.place)
d_outs = d_outs.masked_select(duration_masks)
ds = ds.masked_select(duration_masks)
pitch_masks = paddle.to_tensor(
make_non_pad_mask(ilens).unsqueeze(-1), place=ds.place)
p_outs = p_outs.masked_select(pitch_masks)
e_outs = e_outs.masked_select(pitch_masks)
ps = ps.masked_select(pitch_masks)
es = es.masked_select(pitch_masks)
# calculate loss
duration_loss = self.duration_criterion(d_outs, ds)
pitch_loss = self.mse_criterion(p_outs, ps)
energy_loss = self.mse_criterion(e_outs, es)
# make weighted mask and apply it
if self.use_weighted_masking:
duration_masks = paddle.to_tensor(
make_non_pad_mask(ilens), place=ds.place)
duration_weights = (duration_masks.float() /
duration_masks.sum(axis=1, keepdim=True).float())
duration_weights /= ds.shape[0]
# apply weight
duration_loss = (duration_loss.mul(duration_weights).masked_select(
duration_masks).sum())
pitch_masks = duration_masks.unsqueeze(-1)
pitch_weights = duration_weights.unsqueeze(-1)
pitch_loss = pitch_loss.mul(pitch_weights).masked_select(
pitch_masks).sum()
energy_loss = (
energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum())
return duration_loss, pitch_loss, energy_loss
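# Illustrative usage sketch (shapes are assumptions, T_text = number of tokens):
#   criterion = VarianceLoss(use_masking=True)
#   B, T_text = 2, 20
#   d_outs = paddle.randn([B, T_text])
#   ds = paddle.randint(1, 10, [B, T_text])
#   p_outs, ps = paddle.randn([B, T_text, 1]), paddle.randn([B, T_text, 1])
#   e_outs, es = paddle.randn([B, T_text, 1]), paddle.randn([B, T_text, 1])
#   ilens = paddle.to_tensor([20, 15], dtype="int64")
#   duration_loss, pitch_loss, energy_loss = criterion(
#       d_outs, ds, p_outs, ps, e_outs, es, ilens)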
class ForwardSumLoss(nn.Layer):
"""
https://openreview.net/forum?id=0NQwnnwAORi
"""
def __init__(self, cache_prior: bool=True):
"""
Args:
cache_prior (bool): Whether to cache beta-binomial prior
"""
super().__init__()
self.cache_prior = cache_prior
self._cache = {}
def forward(
self,
log_p_attn: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor,
blank_prob: float=np.e**-1, ) -> paddle.Tensor:
"""
Args:
log_p_attn (Tensor): Batch of log probability of attention matrix (B, T_feats, T_text).
ilens (Tensor): Batch of the lengths of each input (B,).
olens (Tensor): Batch of the lengths of each target (B,).
blank_prob (float): Blank symbol probability
Returns:
Tensor: forwardsum loss value.
"""
B = log_p_attn.shape[0]
# add beta-binomial prior
bb_prior = self._generate_prior(ilens, olens)
bb_prior = paddle.to_tensor(
bb_prior, dtype=log_p_attn.dtype, place=log_p_attn.place)
log_p_attn = log_p_attn + bb_prior
# a row must be added to the attention matrix to account for blank token of CTC loss
# (B,T_feats,T_text+1)
log_p_attn_pd = F.pad(
log_p_attn, (0, 0, 0, 0, 1, 0), value=np.log(blank_prob))
loss = 0
for bidx in range(B):
# construct target sequence.
# Every text token is mapped to a unique sequence number.
target_seq = paddle.arange(
1, ilens[bidx] + 1, dtype="int32").unsqueeze(0)
cur_log_p_attn_pd = log_p_attn_pd[bidx, :olens[bidx], :ilens[
bidx] + 1].unsqueeze(1) # (T_feats,1,T_text+1)
# The input of the ctc_loss API needs to be fixed
loss += F.ctc_loss(
log_probs=cur_log_p_attn_pd,
labels=target_seq,
input_lengths=olens[bidx:bidx + 1],
label_lengths=ilens[bidx:bidx + 1])
loss = loss / B
return loss
def _generate_prior(self, text_lengths, feats_lengths,
w=1) -> paddle.Tensor:
"""Generate alignment prior formulated as beta-binomial distribution
Args:
text_lengths (Tensor): Batch of the lengths of each input (B,).
feats_lengths (Tensor): Batch of the lengths of each target (B,).
w (float): Scaling factor; a lower value gives a wider prior.
Returns:
Tensor: Batched 2d static prior matrix (B, T_feats, T_text)
"""
B = len(text_lengths)
T_text = text_lengths.max()
T_feats = feats_lengths.max()
bb_prior = paddle.full((B, T_feats, T_text), fill_value=-np.inf)
for bidx in range(B):
T = feats_lengths[bidx].item()
N = text_lengths[bidx].item()
key = str(T) + ',' + str(N)
if self.cache_prior and key in self._cache:
prob = self._cache[key]
else:
alpha = w * np.arange(1, T + 1, dtype=float) # (T,)
beta = w * np.array([T - t + 1 for t in alpha])
k = np.arange(N)
batched_k = k[..., None] # (N,1)
prob = betabinom.pmf(batched_k, N, alpha, beta) # (N,T)
# store cache
if self.cache_prior and key not in self._cache:
self._cache[key] = prob
prob = paddle.to_tensor(
prob, place=text_lengths.place, dtype="float32").transpose(
(1, 0)) # -> (T,N)
bb_prior[bidx, :T, :N] = prob
return bb_prior
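# Illustrative usage sketch (shapes are assumptions): log_p_attn comes from an
# alignment module as per-frame log-probabilities over text tokens.
#   criterion = ForwardSumLoss()
#   B, T_feats, T_text = 2, 100, 20
#   log_p_attn = F.log_softmax(paddle.randn([B, T_feats, T_text]), axis=-1)
#   ilens = paddle.to_tensor([20, 15], dtype="int64")
#   olens = paddle.to_tensor([100, 80], dtype="int64")
#   forwardsum_loss = criterion(log_p_attn, ilens, olens)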
class MultiScaleSTFTLoss(nn.Layer):
"""Computes the multi-scale STFT loss from [1].
References
----------
1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
"DDSP: Differentiable Digital Signal Processing."
International Conference on Learning Representations. 2019.
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py
"""
def __init__(
self,
window_lengths: List[int]=[2048, 512],
loss_fn: Callable=nn.L1Loss(),
clamp_eps: float=1e-5,
mag_weight: float=1.0,
log_weight: float=1.0,
pow: float=2.0,
weight: float=1.0,
match_stride: bool=False,
window_type: Optional[str]=None, ):
"""
Args:
window_lengths : List[int], optional
Length of each window of each STFT, by default [2048, 512]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 1.0
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 2.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
window_type : str, optional
Type of window to use, by default None.
"""
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type, ) for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
self.mag_weight = mag_weight
self.clamp_eps = clamp_eps
self.weight = weight
self.pow = pow
def forward(self, x: AudioSignal, y: AudioSignal):
"""Computes multi-scale STFT between an estimate and a reference
signal.
Args:
x : AudioSignal
Estimate signal
y : AudioSignal
Reference signal
Returns:
paddle.Tensor
Multi-scale STFT loss.
Example:
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
>>> import paddle
>>> x = AudioSignal("https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav", 2_05)
>>> y = x * 0.01
>>> loss = MultiScaleSTFTLoss()
>>> loss(x, y).numpy()
7.562150
"""
loss = 0.0
for s in self.stft_params:
x.stft(s.window_length, s.hop_length, s.window_type)
y.stft(s.window_length, s.hop_length, s.window_type)
loss += self.log_weight * self.loss_fn(
x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(),
y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), )
loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
return loss
class GANLoss(nn.Layer):
"""
Computes a discriminator loss, given a discriminator on
generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions.
Example:
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
>>> import paddle
>>> x = AudioSignal("https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav", 2_05)
>>> y = x * 0.01
>>> class My_discriminator0:
>>> def __call__(self, x):
>>> return x.sum()
>>> loss = GANLoss(My_discriminator0())
>>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()]
[-0.102722, -0.001027]
>>> class My_discriminator1:
>>> def __call__(self, x):
>>> return x.sum()
>>> loss = GANLoss(My_discriminator1())
>>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()]
[1.00019, 0]
>>> loss.discriminator_loss(x, y)
1.000200
"""
def __init__(self, discriminator):
"""
Args:
discriminator : paddle.nn.layer
Discriminator model
"""
super().__init__()
self.discriminator = discriminator
def forward(self,
fake: Union[AudioSignal, paddle.Tensor],
real: Union[AudioSignal, paddle.Tensor]):
if isinstance(fake, AudioSignal):
d_fake = self.discriminator(fake.audio_data)
else:
d_fake = self.discriminator(fake)
if isinstance(real, AudioSignal):
d_real = self.discriminator(real.audio_data)
else:
d_real = self.discriminator(real)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += paddle.mean(x_fake[-1]**2)
loss_d += paddle.mean((1 - x_real[-1])**2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += paddle.mean((1 - x_fake[-1])**2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature
class SISDRLoss(nn.Layer):
"""
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
of estimated and reference audio signals or aligned features.
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py
Example:
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
>>> import paddle
>>> x = AudioSignal("https://paddlespeech.cdn.bcebos.com/PaddleAudio/en.wav", 2_05)
>>> y = x * 0.01
>>> sisdr = SISDRLoss()
>>> sisdr(x, y).numpy()
-145.377640
"""
def __init__(
self,
scaling: bool=True,
reduction: str="mean",
zero_mean: bool=True,
clip_min: Optional[int]=None,
weight: float=1.0, ):
"""
Args:
scaling : bool, optional
Whether to use scale-invariant (True) or
signal-to-noise ratio (False), by default True
reduction : str, optional
How to reduce across the batch (either 'mean',
'sum', or 'none'), by default 'mean'.
zero_mean : bool, optional
Zero mean the references and estimates before
computing the loss, by default True
clip_min : int, optional
The minimum possible loss value. Helps network
to not focus on making already good examples better, by default None
weight : float, optional
Weight of this loss, defaults to 1.0.
"""
self.scaling = scaling
self.reduction = reduction
self.zero_mean = zero_mean
self.clip_min = clip_min
self.weight = weight
super().__init__()
def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
eps = 1e-8
# B, C, T
if isinstance(x, AudioSignal):
references = x.audio_data
estimates = y.audio_data
else:
references = x
estimates = y
nb = references.shape[0]
references = references.reshape([nb, 1, -1]).transpose([0, 2, 1])
estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1])
# samples now on axis 1
if self.zero_mean:
mean_reference = references.mean(axis=1, keepdim=True)
mean_estimate = estimates.mean(axis=1, keepdim=True)
else:
mean_reference = 0
mean_estimate = 0
_references = references - mean_reference
_estimates = estimates - mean_estimate
references_projection = (_references**2).sum(axis=-2) + eps
references_on_estimates = (_estimates * _references).sum(axis=-2) + eps
scale = (
(references_on_estimates / references_projection).unsqueeze(axis=1)
if self.scaling else 1)
e_true = scale * _references
e_res = _estimates - e_true
signal = (e_true**2).sum(axis=1)
noise = (e_res**2).sum(axis=1)
sdr = -10 * paddle.log10(signal / noise + eps)
if self.clip_min is not None:
sdr = paddle.clip(sdr, min=self.clip_min)
if self.reduction == "mean":
sdr = sdr.mean()
elif self.reduction == "sum":
sdr = sdr.sum()
return sdr