From 2c04e0ac9219ec84c197c4a76b8972c23f205a60 Mon Sep 17 00:00:00 2001 From: suzakuwcx Date: Tue, 17 Dec 2024 08:50:06 +0800 Subject: [PATCH] =?UTF-8?q?[Hackathon=207th=20No.56]=20=E5=9C=A8=20PaddleS?= =?UTF-8?q?peech=20=E4=B8=AD=E5=A4=8D=E7=8E=B0=20DAC=20=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E9=9C=80=E8=A6=81=E7=94=A8=E5=88=B0=E7=9A=84=20loss?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddlespeech/t2s/modules/losses.py | 221 ++++++++++++++++++++++++++++- tests/unit/tts/test_losses.py | 22 +++ 2 files changed, 242 insertions(+), 1 deletion(-) create mode 100644 tests/unit/tts/test_losses.py diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index e675dcab7..581e17adf 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple +from typing import Tuple, Callable, List, Union import librosa import numpy as np @@ -28,6 +28,8 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import ( DurationPredictorLoss, # noqa: H301 ) +from paddleaudio.audiotools import AudioSignal, STFTParams + # Losses for WaveRNN def log_sum_exp(x): @@ -984,6 +986,108 @@ class MelSpectrogramLoss(nn.Layer): return mel_loss +class MultiMelSpectrogramLoss(nn.Layer): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + fs: int = 44100, + ): + super().__init__() + + self.mel_loss_fns = [ + MelSpectrogramLoss( + fs=fs, + fft_size=w, + hop_size=w // 4, + num_mels=n_mel, + fmin=fmin, + fmax=fmax, + eps=clamp_eps, + ) + for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax) + ] + + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + """Computes multi mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal or Tensor + Estimate signal + y : AudioSignal or Tensor + Reference signal + + Returns + ------- + paddle.Tensor + Mel loss. + """ + loss = 0.0 + for i, mel_loss_fn in enumerate(self.mel_loss_fns): + if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + loss += self.log_weight * mel_loss_fn(x, y) + elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): + s = mel_loss_fn.mel_spectrogram.stft_params + x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) + y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) + loss += self.log_weight * self.loss_fn( + paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(), + paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10() + ) + else: + raise ValueError('\'x\' amd \'y\' should be the same type') + return loss + + class FeatureMatchLoss(nn.Layer): """Feature matching loss module.""" @@ -1326,3 +1430,118 @@ class ForwardSumLoss(nn.Layer): bb_prior[bidx, :T, :N] = prob return bb_prior + + +class MultiScaleSTFTLoss(nn.Layer): + """Multi resolution STFT loss module.""" + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = 'hann', + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + + def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal or Tensor + Estimate signal + y : AudioSignal or Tensor + Reference signal + + Returns + ------- + paddle.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + + for s in self.stft_params: + if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): + x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) + y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) + x_mag = x_mag.transpose([0, 2, 1]) + y_mag = y_mag.transpose([0, 2, 1]) + elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + x_mag = x.magnitude + y_mag = y.magnitude + else: + raise ValueError('\'x\' amd \'y\' should be the same type') + + loss += self.log_weight * self.loss_fn( + paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(), + paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), + ) + + loss += self.mag_weight * self.loss_fn(x_mag, y_mag) + return loss + + +class GANLoss(nn.Layer): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += paddle.mean(x_fake[-1] ** 2) + loss_d += paddle.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += paddle.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py new file mode 100644 index 000000000..8b379d747 --- /dev/null +++ b/tests/unit/tts/test_losses.py @@ -0,0 +1,22 @@ +import torch +from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss +from paddleaudio.audiotools.core.audio_signal import AudioSignal + +def test_dac_losses(): + for i in range(10): + loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt') + recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav') + signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav') + loss_fn_1 = MultiScaleSTFTLoss() + loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None]) + # + # Test AudioSignal + # + assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5 + assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5 + + # + # Test Tensor + # + assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3 + assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3