[Hackathon 7th No.56] 在 PaddleSpeech 中复现 DAC 训练需要用到的 loss

pull/3954/head
suzakuwcx 9 months ago
parent 8ee3a7ee40
commit 2c04e0ac92
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
from typing import Tuple, Callable, List, Union
import librosa
import numpy as np
@ -28,6 +28,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301
)
from paddleaudio.audiotools import AudioSignal, STFTParams
# Losses for WaveRNN
def log_sum_exp(x):
@ -984,6 +986,108 @@ class MelSpectrogramLoss(nn.Layer):
return mel_loss
class MultiMelSpectrogramLoss(nn.Layer):
"""Compute distance between mel spectrograms. Can be used
in a multi-scale way.
Parameters
----------
n_mels : List[int]
Number of mels per STFT, by default [150, 80],
window_lengths : List[int], optional
Length of each window of each STFT, by default [2048, 512]
loss_fn : typing.Callable, optional
How to compare each loss, by default nn.L1Loss()
clamp_eps : float, optional
Clamp on the log magnitude, below, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 1.0
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 2.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
"""
def __init__(
self,
n_mels: List[int] = [150, 80],
window_lengths: List[int] = [2048, 512],
loss_fn: Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
mel_fmin: List[float] = [0.0, 0.0],
mel_fmax: List[float] = [None, None],
window_type: str = None,
fs: int = 44100,
):
super().__init__()
self.mel_loss_fns = [
MelSpectrogramLoss(
fs=fs,
fft_size=w,
hop_size=w // 4,
num_mels=n_mel,
fmin=fmin,
fmax=fmax,
eps=clamp_eps,
)
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax)
]
self.n_mels = n_mels
self.loss_fn = loss_fn
self.clamp_eps = clamp_eps
self.log_weight = log_weight
self.mag_weight = mag_weight
self.weight = weight
self.mel_fmin = mel_fmin
self.mel_fmax = mel_fmax
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi mel loss between an estimate and a reference
signal.
Parameters
----------
x : AudioSignal or Tensor
Estimate signal
y : AudioSignal or Tensor
Reference signal
Returns
-------
paddle.Tensor
Mel loss.
"""
loss = 0.0
for i, mel_loss_fn in enumerate(self.mel_loss_fns):
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
loss += self.log_weight * mel_loss_fn(x, y)
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
s = mel_loss_fn.mel_spectrogram.stft_params
x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10()
)
else:
raise ValueError('\'x\' amd \'y\' should be the same type')
return loss
class FeatureMatchLoss(nn.Layer):
"""Feature matching loss module."""
@ -1326,3 +1430,118 @@ class ForwardSumLoss(nn.Layer):
bb_prior[bidx, :T, :N] = prob
return bb_prior
class MultiScaleSTFTLoss(nn.Layer):
"""Multi resolution STFT loss module."""
def __init__(
self,
window_lengths: List[int] = [2048, 512],
loss_fn: Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
window_type: str = 'hann',
):
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
self.mag_weight = mag_weight
self.clamp_eps = clamp_eps
self.weight = weight
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi-scale STFT between an estimate and a reference
signal.
Parameters
----------
x : AudioSignal or Tensor
Estimate signal
y : AudioSignal or Tensor
Reference signal
Returns
-------
paddle.Tensor
Multi-scale STFT loss.
"""
loss = 0.0
for s in self.stft_params:
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
x_mag = x_mag.transpose([0, 2, 1])
y_mag = y_mag.transpose([0, 2, 1])
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
x.stft(s.window_length, s.hop_length, s.window_type)
y.stft(s.window_length, s.hop_length, s.window_type)
x_mag = x.magnitude
y_mag = y.magnitude
else:
raise ValueError('\'x\' amd \'y\' should be the same type')
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(),
)
loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
return loss
class GANLoss(nn.Layer):
"""
Computes a discriminator loss, given a discriminator on
generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions.
"""
def __init__(self, discriminator):
super().__init__()
self.discriminator = discriminator
def forward(self, fake, real):
d_fake = self.discriminator(fake.audio_data)
d_real = self.discriminator(real.audio_data)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake.clone().detach(), real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += paddle.mean(x_fake[-1] ** 2)
loss_d += paddle.mean((1 - x_real[-1]) ** 2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += paddle.mean((1 - x_fake[-1]) ** 2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature

@ -0,0 +1,22 @@
import torch
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss
from paddleaudio.audiotools.core.audio_signal import AudioSignal
def test_dac_losses():
for i in range(10):
loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')
loss_fn_1 = MultiScaleSTFTLoss()
loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None])
#
# Test AudioSignal
#
assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5
assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5
#
# Test Tensor
#
assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3
assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3
Loading…
Cancel
Save