[Hackathon 7th No.56] 在 PaddleSpeech 中复现 DAC 训练需要用到的 loss

pull/3954/head
suzakuwcx 9 months ago
parent 8ee3a7ee40
commit 2c04e0ac92
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
from typing import Tuple from typing import Tuple, Callable, List, Union
import librosa import librosa
import numpy as np import numpy as np
@ -28,6 +28,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301 DurationPredictorLoss, # noqa: H301
) )
from paddleaudio.audiotools import AudioSignal, STFTParams
# Losses for WaveRNN # Losses for WaveRNN
def log_sum_exp(x): def log_sum_exp(x):
@ -984,6 +986,108 @@ class MelSpectrogramLoss(nn.Layer):
return mel_loss return mel_loss
class MultiMelSpectrogramLoss(nn.Layer):
"""Compute distance between mel spectrograms. Can be used
in a multi-scale way.
Parameters
----------
n_mels : List[int]
Number of mels per STFT, by default [150, 80],
window_lengths : List[int], optional
Length of each window of each STFT, by default [2048, 512]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 1.0
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 2.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
"""
def __init__(
self,
n_mels: List[int] = [150, 80],
window_lengths: List[int] = [2048, 512],
loss_fn: Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
mel_fmin: List[float] = [0.0, 0.0],
mel_fmax: List[float] = [None, None],
window_type: str = None,
fs: int = 44100,
):
super().__init__()
self.mel_loss_fns = [
MelSpectrogramLoss(
fs=fs,
fft_size=w,
hop_size=w // 4,
num_mels=n_mel,
fmin=fmin,
fmax=fmax,
eps=clamp_eps,
)
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax)
]
self.n_mels = n_mels
self.loss_fn = loss_fn
self.clamp_eps = clamp_eps
self.log_weight = log_weight
self.mag_weight = mag_weight
self.weight = weight
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi mel loss between an estimate and a reference
signal.
Parameters
----------
x : AudioSignal or Tensor
Estimate signal
y : AudioSignal or Tensor
Reference signal
Returns
-------
paddle.Tensor
Mel loss.
"""
loss = 0.0
for i, mel_loss_fn in enumerate(self.mel_loss_fns):
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
loss += self.log_weight * mel_loss_fn(x, y)
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
s = mel_loss_fn.mel_spectrogram.stft_params
x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10()
)
else:
raise ValueError('\'x\' amd \'y\' should be the same type')
return loss
class FeatureMatchLoss(nn.Layer): class FeatureMatchLoss(nn.Layer):
"""Feature matching loss module.""" """Feature matching loss module."""
@ -1326,3 +1430,118 @@ class ForwardSumLoss(nn.Layer):
bb_prior[bidx, :T, :N] = prob bb_prior[bidx, :T, :N] = prob
return bb_prior return bb_prior
class MultiScaleSTFTLoss(nn.Layer):
"""Multi resolution STFT loss module."""
def __init__(
self,
window_lengths: List[int] = [2048, 512],
loss_fn: Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
window_type: str = 'hann',
):
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
self.mag_weight = mag_weight
self.clamp_eps = clamp_eps
self.weight = weight
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi-scale STFT between an estimate and a reference
signal.
Parameters
----------
x : AudioSignal or Tensor
Estimate signal
y : AudioSignal or Tensor
Reference signal
Returns
-------
paddle.Tensor
Multi-scale STFT loss.
"""
loss = 0.0
for s in self.stft_params:
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
x_mag = x_mag.transpose([0, 2, 1])
y_mag = y_mag.transpose([0, 2, 1])
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
x.stft(s.window_length, s.hop_length, s.window_type)
y.stft(s.window_length, s.hop_length, s.window_type)
x_mag = x.magnitude
y_mag = y.magnitude
else:
raise ValueError('\'x\' amd \'y\' should be the same type')
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(),
)
loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
return loss
class GANLoss(nn.Layer):
"""
Computes a discriminator loss, given a discriminator on
generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions.
"""
def __init__(self, discriminator):
super().__init__()
self.discriminator = discriminator
def forward(self, fake, real):
d_fake = self.discriminator(fake.audio_data)
d_real = self.discriminator(real.audio_data)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake.clone().detach(), real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += paddle.mean(x_fake[-1] ** 2)
loss_d += paddle.mean((1 - x_real[-1]) ** 2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += paddle.mean((1 - x_fake[-1]) ** 2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature

@ -0,0 +1,22 @@
import torch
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss
from paddleaudio.audiotools.core.audio_signal import AudioSignal
def test_dac_losses():
for i in range(10):
loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')
loss_fn_1 = MultiScaleSTFTLoss()
loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None])
#
# Test AudioSignal
#
assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5
assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5
#
# Test Tensor
#
assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3
assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3
Loading…
Cancel
Save