diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py index 581e17adf..029ad1be1 100644 --- a/paddlespeech/t2s/modules/losses.py +++ b/paddlespeech/t2s/modules/losses.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Tuple, Callable, List, Union +from typing import Callable +from typing import List +from typing import Tuple +from typing import Union import librosa import numpy as np import paddle from paddle import nn from paddle.nn import functional as F +from paddleaudio.audiotools import AudioSignal +from paddleaudio.audiotools import STFTParams from scipy import signal from scipy.stats import betabinom from typeguard import check_argument_types @@ -28,8 +33,6 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import ( DurationPredictorLoss, # noqa: H301 ) -from paddleaudio.audiotools import AudioSignal, STFTParams - # Losses for WaveRNN def log_sum_exp(x): @@ -1015,21 +1018,20 @@ class MultiMelSpectrogramLoss(nn.Layer): """ def __init__( - self, - n_mels: List[int] = [150, 80], - window_lengths: List[int] = [2048, 512], - loss_fn: Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - mel_fmin: List[float] = [0.0, 0.0], - mel_fmax: List[float] = [None, None], - window_type: str = None, - fs: int = 44100, - ): + self, + n_mels: List[int]=[150, 80], + window_lengths: List[int]=[2048, 512], + loss_fn: Callable=nn.L1Loss(), + clamp_eps: float=1e-5, + mag_weight: float=1.0, + log_weight: float=1.0, + pow: float=2.0, + weight: float=1.0, + match_stride: bool=False, + mel_fmin: List[float]=[0.0, 0.0], + mel_fmax: List[float]=[None, None], + window_type: str=None, + fs: int=44100, ): super().__init__() self.mel_loss_fns = [ @@ -1040,11 +1042,11 @@ class MultiMelSpectrogramLoss(nn.Layer): num_mels=n_mel, fmin=fmin, fmax=fmax, - eps=clamp_eps, - ) - for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax) + eps=clamp_eps, ) + for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, + mel_fmax) ] - + self.n_mels = n_mels self.loss_fn = loss_fn self.clamp_eps = clamp_eps @@ -1055,7 +1057,9 @@ class MultiMelSpectrogramLoss(nn.Layer): self.mel_fmax = mel_fmax self.pow = pow - def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + def forward(self, + x: Union[AudioSignal, paddle.Tensor], + y: Union[AudioSignal, paddle.Tensor]): """Computes multi mel loss between an estimate and a reference signal. @@ -1077,12 +1081,21 @@ class MultiMelSpectrogramLoss(nn.Layer): loss += self.log_weight * mel_loss_fn(x, y) elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): s = mel_loss_fn.mel_spectrogram.stft_params - x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) - y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) + x_mels = x.mel_spectrogram( + self.n_mels[i], + mel_fmin=self.mel_fmin[i], + mel_fmax=self.mel_fmax[i], + window_length=s['n_fft'], + hop_length=s['hop_length']) + y_mels = y.mel_spectrogram( + self.n_mels[i], + mel_fmin=self.mel_fmin[i], + mel_fmax=self.mel_fmax[i], + window_length=s['n_fft'], + hop_length=s['hop_length']) loss += self.log_weight * self.loss_fn( paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(), - paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10() - ) + paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10()) else: raise ValueError('\'x\' amd \'y\' should be the same type') return loss @@ -1436,26 +1449,23 @@ class MultiScaleSTFTLoss(nn.Layer): """Multi resolution STFT loss module.""" def __init__( - self, - window_lengths: List[int] = [2048, 512], - loss_fn: Callable = nn.L1Loss(), - clamp_eps: float = 1e-5, - mag_weight: float = 1.0, - log_weight: float = 1.0, - pow: float = 2.0, - weight: float = 1.0, - match_stride: bool = False, - window_type: str = 'hann', - ): + self, + window_lengths: List[int]=[2048, 512], + loss_fn: Callable=nn.L1Loss(), + clamp_eps: float=1e-5, + mag_weight: float=1.0, + log_weight: float=1.0, + pow: float=2.0, + weight: float=1.0, + match_stride: bool=False, + window_type: str='hann', ): super().__init__() self.stft_params = [ STFTParams( window_length=w, hop_length=w // 4, match_stride=match_stride, - window_type=window_type, - ) - for w in window_lengths + window_type=window_type, ) for w in window_lengths ] self.loss_fn = loss_fn self.log_weight = log_weight @@ -1464,8 +1474,9 @@ class MultiScaleSTFTLoss(nn.Layer): self.weight = weight self.pow = pow - - def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): + def forward(self, + x: Union[AudioSignal, paddle.Tensor], + y: Union[AudioSignal, paddle.Tensor]): """Computes multi-scale STFT between an estimate and a reference signal. @@ -1482,11 +1493,21 @@ class MultiScaleSTFTLoss(nn.Layer): Multi-scale STFT loss. """ loss = 0.0 - + for s in self.stft_params: if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): - x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) - y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) + x_mag = stft( + x.reshape([-1, x.shape[-1]]), + fft_size=s.window_length, + hop_length=s.hop_length, + win_length=s.window_length, + window=s.window_type) + y_mag = stft( + y.reshape([-1, y.shape[-1]]), + fft_size=s.window_length, + hop_length=s.hop_length, + win_length=s.window_length, + window=s.window_type) x_mag = x_mag.transpose([0, 2, 1]) y_mag = y_mag.transpose([0, 2, 1]) elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): @@ -1499,8 +1520,7 @@ class MultiScaleSTFTLoss(nn.Layer): loss += self.log_weight * self.loss_fn( paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(), - paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), - ) + paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), ) loss += self.mag_weight * self.loss_fn(x_mag, y_mag) return loss @@ -1528,8 +1548,8 @@ class GANLoss(nn.Layer): loss_d = 0 for x_fake, x_real in zip(d_fake, d_real): - loss_d += paddle.mean(x_fake[-1] ** 2) - loss_d += paddle.mean((1 - x_real[-1]) ** 2) + loss_d += paddle.mean(x_fake[-1]**2) + loss_d += paddle.mean((1 - x_real[-1])**2) return loss_d def generator_loss(self, fake, real): @@ -1537,11 +1557,12 @@ class GANLoss(nn.Layer): loss_g = 0 for x_fake in d_fake: - loss_g += paddle.mean((1 - x_fake[-1]) ** 2) + loss_g += paddle.mean((1 - x_fake[-1])**2) loss_feature = 0 for i in range(len(d_fake)): for j in range(len(d_fake[i]) - 1): - loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach()) + loss_feature += paddle.nn.functional.l1_loss( + d_fake[i][j], d_real[i][j].detach()) return loss_g, loss_feature diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py index 8b379d747..9480c0069 100644 --- a/tests/unit/tts/test_losses.py +++ b/tests/unit/tts/test_losses.py @@ -1,22 +1,39 @@ import torch -from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss from paddleaudio.audiotools.core.audio_signal import AudioSignal +from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss +from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss + + def test_dac_losses(): for i in range(10): loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt') recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav') signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav') loss_fn_1 = MultiScaleSTFTLoss() - loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None]) + loss_fn_2 = MultiMelSpectrogramLoss( + n_mels=[5, 10, 20, 40, 80, 160, 320], + window_lengths=[32, 64, 128, 256, 512, 1024, 2048], + mag_weight=0.0, + pow=1.0, + mel_fmin=[0, 0, 0, 0, 0, 0, 0], + mel_fmax=[None, None, None, None, None, None, None]) # # Test AudioSignal # - assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5 - assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5 + assert abs( + loss_fn_1(recons, signal).item() - loss_origin['stft/loss'] + .item()) < 1e-5 + assert abs( + loss_fn_2(recons, signal).item() - loss_origin['mel/loss'] + .item()) < 1e-5 # # Test Tensor # - assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3 - assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3 + assert abs( + loss_fn_1(recons.audio_data, signal.audio_data).item() - + loss_origin['stft/loss'].item()) < 1e-3 + assert abs( + loss_fn_2(recons.audio_data, signal.audio_data).item() - + loss_origin['mel/loss'].item()) < 1e-3