|
|
@ -12,7 +12,11 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
|
|
|
|
from typing import Callable
|
|
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from typing import Optional
|
|
|
|
from typing import Tuple
|
|
|
|
from typing import Tuple
|
|
|
|
|
|
|
|
from typing import Union
|
|
|
|
|
|
|
|
|
|
|
|
import librosa
|
|
|
|
import librosa
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
@ -23,6 +27,8 @@ from scipy import signal
|
|
|
|
from scipy.stats import betabinom
|
|
|
|
from scipy.stats import betabinom
|
|
|
|
from typeguard import typechecked
|
|
|
|
from typeguard import typechecked
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from paddlespeech.audiotools.core.audio_signal import AudioSignal
|
|
|
|
|
|
|
|
from paddlespeech.audiotools.core.audio_signal import STFTParams
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
|
|
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
|
|
|
from paddlespeech.t2s.modules.predictor.duration_predictor import (
|
|
|
|
from paddlespeech.t2s.modules.predictor.duration_predictor import (
|
|
|
|
DurationPredictorLoss, # noqa: H301
|
|
|
|
DurationPredictorLoss, # noqa: H301
|
|
|
@ -1326,3 +1332,276 @@ class ForwardSumLoss(nn.Layer):
|
|
|
|
bb_prior[bidx, :T, :N] = prob
|
|
|
|
bb_prior[bidx, :T, :N] = prob
|
|
|
|
|
|
|
|
|
|
|
|
return bb_prior
|
|
|
|
return bb_prior
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiScaleSTFTLoss(nn.Layer):
|
|
|
|
|
|
|
|
"""Computes the multi-scale STFT loss from [1].
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
References
|
|
|
|
|
|
|
|
----------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
|
|
|
|
|
|
|
|
"DDSP: Differentiable Digital Signal Processing."
|
|
|
|
|
|
|
|
International Conference on Learning Representations. 2019.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
window_lengths: List[int]=[2048, 512],
|
|
|
|
|
|
|
|
loss_fn: Callable=nn.L1Loss(),
|
|
|
|
|
|
|
|
clamp_eps: float=1e-5,
|
|
|
|
|
|
|
|
mag_weight: float=1.0,
|
|
|
|
|
|
|
|
log_weight: float=1.0,
|
|
|
|
|
|
|
|
pow: float=2.0,
|
|
|
|
|
|
|
|
weight: float=1.0,
|
|
|
|
|
|
|
|
match_stride: bool=False,
|
|
|
|
|
|
|
|
window_type: Optional[str]=None, ):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
window_lengths : List[int], optional
|
|
|
|
|
|
|
|
Length of each window of each STFT, by default [2048, 512]
|
|
|
|
|
|
|
|
loss_fn : typing.Callable, optional
|
|
|
|
|
|
|
|
How to compare each loss, by default nn.L1Loss()
|
|
|
|
|
|
|
|
clamp_eps : float, optional
|
|
|
|
|
|
|
|
Clamp on the log magnitude, below, by default 1e-5
|
|
|
|
|
|
|
|
mag_weight : float, optional
|
|
|
|
|
|
|
|
Weight of raw magnitude portion of loss, by default 1.0
|
|
|
|
|
|
|
|
log_weight : float, optional
|
|
|
|
|
|
|
|
Weight of log magnitude portion of loss, by default 1.0
|
|
|
|
|
|
|
|
pow : float, optional
|
|
|
|
|
|
|
|
Power to raise magnitude to before taking log, by default 2.0
|
|
|
|
|
|
|
|
weight : float, optional
|
|
|
|
|
|
|
|
Weight of this loss, by default 1.0
|
|
|
|
|
|
|
|
match_stride : bool, optional
|
|
|
|
|
|
|
|
Whether to match the stride of convolutional layers, by default False
|
|
|
|
|
|
|
|
window_type : str, optional
|
|
|
|
|
|
|
|
Type of window to use, by default None.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.stft_params = [
|
|
|
|
|
|
|
|
STFTParams(
|
|
|
|
|
|
|
|
window_length=w,
|
|
|
|
|
|
|
|
hop_length=w // 4,
|
|
|
|
|
|
|
|
match_stride=match_stride,
|
|
|
|
|
|
|
|
window_type=window_type, ) for w in window_lengths
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
self.loss_fn = loss_fn
|
|
|
|
|
|
|
|
self.log_weight = log_weight
|
|
|
|
|
|
|
|
self.mag_weight = mag_weight
|
|
|
|
|
|
|
|
self.clamp_eps = clamp_eps
|
|
|
|
|
|
|
|
self.weight = weight
|
|
|
|
|
|
|
|
self.pow = pow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: AudioSignal, y: AudioSignal):
|
|
|
|
|
|
|
|
"""Computes multi-scale STFT between an estimate and a reference
|
|
|
|
|
|
|
|
signal.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
x : AudioSignal
|
|
|
|
|
|
|
|
Estimate signal
|
|
|
|
|
|
|
|
y : AudioSignal
|
|
|
|
|
|
|
|
Reference signal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
paddle.Tensor
|
|
|
|
|
|
|
|
Multi-scale STFT loss.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
|
|
|
|
|
|
|
|
>>> import paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
>>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
|
|
|
|
|
|
|
|
>>> y = x * 0.01
|
|
|
|
|
|
|
|
>>> loss = MultiScaleSTFTLoss()
|
|
|
|
|
|
|
|
>>> loss(x, y).numpy()
|
|
|
|
|
|
|
|
7.562150
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
for s in self.stft_params:
|
|
|
|
|
|
|
|
x.stft(s.window_length, s.hop_length, s.window_type)
|
|
|
|
|
|
|
|
y.stft(s.window_length, s.hop_length, s.window_type)
|
|
|
|
|
|
|
|
loss += self.log_weight * self.loss_fn(
|
|
|
|
|
|
|
|
x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(),
|
|
|
|
|
|
|
|
y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), )
|
|
|
|
|
|
|
|
loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
|
|
|
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GANLoss(nn.Layer):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Computes a discriminator loss, given a discriminator on
|
|
|
|
|
|
|
|
generated waveforms/spectrograms compared to ground truth
|
|
|
|
|
|
|
|
waveforms/spectrograms. Computes the loss for both the
|
|
|
|
|
|
|
|
discriminator and the generator in separate functions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
|
|
|
|
|
|
|
|
>>> import paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
>>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
|
|
|
|
|
|
|
|
>>> y = x * 0.01
|
|
|
|
|
|
|
|
>>> class My_discriminator0:
|
|
|
|
|
|
|
|
>>> def __call__(self, x):
|
|
|
|
|
|
|
|
>>> return x.sum()
|
|
|
|
|
|
|
|
>>> loss = GANLoss(My_discriminator0())
|
|
|
|
|
|
|
|
>>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()]
|
|
|
|
|
|
|
|
[-0.102722, -0.001027]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
>>> class My_discriminator1:
|
|
|
|
|
|
|
|
>>> def __call__(self, x):
|
|
|
|
|
|
|
|
>>> return x.sum()
|
|
|
|
|
|
|
|
>>> loss = GANLoss(My_discriminator1())
|
|
|
|
|
|
|
|
>>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()]
|
|
|
|
|
|
|
|
[1.00019, 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
>>> loss.discriminator_loss(x, y)
|
|
|
|
|
|
|
|
1.000200
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, discriminator):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
discriminator : paddle.nn.layer
|
|
|
|
|
|
|
|
Discriminator model
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
self.discriminator = discriminator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self,
|
|
|
|
|
|
|
|
fake: Union[AudioSignal, paddle.Tensor],
|
|
|
|
|
|
|
|
real: Union[AudioSignal, paddle.Tensor]):
|
|
|
|
|
|
|
|
if isinstance(fake, AudioSignal):
|
|
|
|
|
|
|
|
d_fake = self.discriminator(fake.audio_data)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
d_fake = self.discriminator(fake)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(real, AudioSignal):
|
|
|
|
|
|
|
|
d_real = self.discriminator(real.audio_data)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
d_real = self.discriminator(real)
|
|
|
|
|
|
|
|
return d_fake, d_real
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def discriminator_loss(self, fake, real):
|
|
|
|
|
|
|
|
d_fake, d_real = self.forward(fake, real)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss_d = 0
|
|
|
|
|
|
|
|
for x_fake, x_real in zip(d_fake, d_real):
|
|
|
|
|
|
|
|
loss_d += paddle.mean(x_fake[-1]**2)
|
|
|
|
|
|
|
|
loss_d += paddle.mean((1 - x_real[-1])**2)
|
|
|
|
|
|
|
|
return loss_d
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generator_loss(self, fake, real):
|
|
|
|
|
|
|
|
d_fake, d_real = self.forward(fake, real)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss_g = 0
|
|
|
|
|
|
|
|
for x_fake in d_fake:
|
|
|
|
|
|
|
|
loss_g += paddle.mean((1 - x_fake[-1])**2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
loss_feature = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(d_fake)):
|
|
|
|
|
|
|
|
for j in range(len(d_fake[i]) - 1):
|
|
|
|
|
|
|
|
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j]())
|
|
|
|
|
|
|
|
return loss_g, loss_feature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SISDRLoss(nn.Layer):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
|
|
|
|
|
|
|
|
of estimated and reference audio signals or aligned features.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
|
|
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
|
|
|
|
|
|
|
|
>>> import paddle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
>>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
|
|
|
|
|
|
|
|
>>> y = x * 0.01
|
|
|
|
|
|
|
|
>>> sisdr = SISDRLoss()
|
|
|
|
|
|
|
|
>>> sisdr(x, y).numpy()
|
|
|
|
|
|
|
|
-145.377640
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
|
|
self,
|
|
|
|
|
|
|
|
scaling: bool=True,
|
|
|
|
|
|
|
|
reduction: str="mean",
|
|
|
|
|
|
|
|
zero_mean: bool=True,
|
|
|
|
|
|
|
|
clip_min: Optional[int]=None,
|
|
|
|
|
|
|
|
weight: float=1.0, ):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
scaling : bool, optional
|
|
|
|
|
|
|
|
Whether to use scale-invariant (True) or
|
|
|
|
|
|
|
|
signal-to-noise ratio (False), by default True
|
|
|
|
|
|
|
|
reduction : str, optional
|
|
|
|
|
|
|
|
How to reduce across the batch (either 'mean',
|
|
|
|
|
|
|
|
'sum', or none).], by default ' mean'
|
|
|
|
|
|
|
|
zero_mean : bool, optional
|
|
|
|
|
|
|
|
Zero mean the references and estimates before
|
|
|
|
|
|
|
|
computing the loss, by default True
|
|
|
|
|
|
|
|
clip_min : int, optional
|
|
|
|
|
|
|
|
The minimum possible loss value. Helps network
|
|
|
|
|
|
|
|
to not focus on making already good examples better, by default None
|
|
|
|
|
|
|
|
weight : float, optional
|
|
|
|
|
|
|
|
Weight of this loss, defaults to 1.0.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
self.scaling = scaling
|
|
|
|
|
|
|
|
self.reduction = reduction
|
|
|
|
|
|
|
|
self.zero_mean = zero_mean
|
|
|
|
|
|
|
|
self.clip_min = clip_min
|
|
|
|
|
|
|
|
self.weight = weight
|
|
|
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self,
|
|
|
|
|
|
|
|
x: Union[AudioSignal, paddle.Tensor],
|
|
|
|
|
|
|
|
y: Union[AudioSignal, paddle.Tensor]):
|
|
|
|
|
|
|
|
eps = 1e-8
|
|
|
|
|
|
|
|
# B, C, T
|
|
|
|
|
|
|
|
if isinstance(x, AudioSignal):
|
|
|
|
|
|
|
|
references = x.audio_data
|
|
|
|
|
|
|
|
estimates = y.audio_data
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
references = x
|
|
|
|
|
|
|
|
estimates = y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nb = references.shape[0]
|
|
|
|
|
|
|
|
references = references.reshape([nb, 1, -1]).transpose([0, 2, 1])
|
|
|
|
|
|
|
|
estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# samples now on axis 1
|
|
|
|
|
|
|
|
if self.zero_mean:
|
|
|
|
|
|
|
|
mean_reference = references.mean(axis=1, keepdim=True)
|
|
|
|
|
|
|
|
mean_estimate = estimates.mean(axis=1, keepdim=True)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
mean_reference = 0
|
|
|
|
|
|
|
|
mean_estimate = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_references = references - mean_reference
|
|
|
|
|
|
|
|
_estimates = estimates - mean_estimate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
references_projection = (_references**2).sum(axis=-2) + eps
|
|
|
|
|
|
|
|
references_on_estimates = (_estimates * _references).sum(axis=-2) + eps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scale = (
|
|
|
|
|
|
|
|
(references_on_estimates / references_projection).unsqueeze(axis=1)
|
|
|
|
|
|
|
|
if self.scaling else 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
e_true = scale * _references
|
|
|
|
|
|
|
|
e_res = _estimates - e_true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
signal = (e_true**2).sum(axis=1)
|
|
|
|
|
|
|
|
noise = (e_res**2).sum(axis=1)
|
|
|
|
|
|
|
|
sdr = -10 * paddle.log10(signal / noise + eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.clip_min != None:
|
|
|
|
|
|
|
|
sdr = paddle.clip(sdr, min=self.clip_min)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.reduction == "mean":
|
|
|
|
|
|
|
|
sdr = sdr.mean()
|
|
|
|
|
|
|
|
elif self.reduction == "sum":
|
|
|
|
|
|
|
|
sdr = sdr.sum()
|
|
|
|
|
|
|
|
return sdr
|
|
|
|