Using pre-commit to format code

pull/3954/head
suzakuwcx 9 months ago
parent 2c04e0ac92
commit b741545f5e
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
from typing import Tuple, Callable, List, Union from typing import Callable
from typing import List
from typing import Tuple
from typing import Union
import librosa import librosa
import numpy as np import numpy as np
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddleaudio.audiotools import AudioSignal
from paddleaudio.audiotools import STFTParams
from scipy import signal from scipy import signal
from scipy.stats import betabinom from scipy.stats import betabinom
from typeguard import check_argument_types from typeguard import check_argument_types
@ -28,8 +33,6 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301 DurationPredictorLoss, # noqa: H301
) )
from paddleaudio.audiotools import AudioSignal, STFTParams
# Losses for WaveRNN # Losses for WaveRNN
def log_sum_exp(x): def log_sum_exp(x):
@ -1015,21 +1018,20 @@ class MultiMelSpectrogramLoss(nn.Layer):
""" """
def __init__( def __init__(
self, self,
n_mels: List[int] = [150, 80], n_mels: List[int]=[150, 80],
window_lengths: List[int] = [2048, 512], window_lengths: List[int]=[2048, 512],
loss_fn: Callable = nn.L1Loss(), loss_fn: Callable=nn.L1Loss(),
clamp_eps: float = 1e-5, clamp_eps: float=1e-5,
mag_weight: float = 1.0, mag_weight: float=1.0,
log_weight: float = 1.0, log_weight: float=1.0,
pow: float = 2.0, pow: float=2.0,
weight: float = 1.0, weight: float=1.0,
match_stride: bool = False, match_stride: bool=False,
mel_fmin: List[float] = [0.0, 0.0], mel_fmin: List[float]=[0.0, 0.0],
mel_fmax: List[float] = [None, None], mel_fmax: List[float]=[None, None],
window_type: str = None, window_type: str=None,
fs: int = 44100, fs: int=44100, ):
):
super().__init__() super().__init__()
self.mel_loss_fns = [ self.mel_loss_fns = [
@ -1040,11 +1042,11 @@ class MultiMelSpectrogramLoss(nn.Layer):
num_mels=n_mel, num_mels=n_mel,
fmin=fmin, fmin=fmin,
fmax=fmax, fmax=fmax,
eps=clamp_eps, eps=clamp_eps, )
) for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin,
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax) mel_fmax)
] ]
self.n_mels = n_mels self.n_mels = n_mels
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.clamp_eps = clamp_eps self.clamp_eps = clamp_eps
@ -1055,7 +1057,9 @@ class MultiMelSpectrogramLoss(nn.Layer):
self.mel_fmax = mel_fmax self.mel_fmax = mel_fmax
self.pow = pow self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi mel loss between an estimate and a reference """Computes multi mel loss between an estimate and a reference
signal. signal.
@ -1077,12 +1081,21 @@ class MultiMelSpectrogramLoss(nn.Layer):
loss += self.log_weight * mel_loss_fn(x, y) loss += self.log_weight * mel_loss_fn(x, y)
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
s = mel_loss_fn.mel_spectrogram.stft_params s = mel_loss_fn.mel_spectrogram.stft_params
x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) x_mels = x.mel_spectrogram(
y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length']) self.n_mels[i],
mel_fmin=self.mel_fmin[i],
mel_fmax=self.mel_fmax[i],
window_length=s['n_fft'],
hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(
self.n_mels[i],
mel_fmin=self.mel_fmin[i],
mel_fmax=self.mel_fmax[i],
window_length=s['n_fft'],
hop_length=s['hop_length'])
loss += self.log_weight * self.loss_fn( loss += self.log_weight * self.loss_fn(
paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(), paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10() paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10())
)
else: else:
raise ValueError('\'x\' amd \'y\' should be the same type') raise ValueError('\'x\' amd \'y\' should be the same type')
return loss return loss
@ -1436,26 +1449,23 @@ class MultiScaleSTFTLoss(nn.Layer):
"""Multi resolution STFT loss module.""" """Multi resolution STFT loss module."""
def __init__( def __init__(
self, self,
window_lengths: List[int] = [2048, 512], window_lengths: List[int]=[2048, 512],
loss_fn: Callable = nn.L1Loss(), loss_fn: Callable=nn.L1Loss(),
clamp_eps: float = 1e-5, clamp_eps: float=1e-5,
mag_weight: float = 1.0, mag_weight: float=1.0,
log_weight: float = 1.0, log_weight: float=1.0,
pow: float = 2.0, pow: float=2.0,
weight: float = 1.0, weight: float=1.0,
match_stride: bool = False, match_stride: bool=False,
window_type: str = 'hann', window_type: str='hann', ):
):
super().__init__() super().__init__()
self.stft_params = [ self.stft_params = [
STFTParams( STFTParams(
window_length=w, window_length=w,
hop_length=w // 4, hop_length=w // 4,
match_stride=match_stride, match_stride=match_stride,
window_type=window_type, window_type=window_type, ) for w in window_lengths
)
for w in window_lengths
] ]
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.log_weight = log_weight self.log_weight = log_weight
@ -1464,8 +1474,9 @@ class MultiScaleSTFTLoss(nn.Layer):
self.weight = weight self.weight = weight
self.pow = pow self.pow = pow
def forward(self,
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]): x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi-scale STFT between an estimate and a reference """Computes multi-scale STFT between an estimate and a reference
signal. signal.
@ -1482,11 +1493,21 @@ class MultiScaleSTFTLoss(nn.Layer):
Multi-scale STFT loss. Multi-scale STFT loss.
""" """
loss = 0.0 loss = 0.0
for s in self.stft_params: for s in self.stft_params:
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor): if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) x_mag = stft(
y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type) x.reshape([-1, x.shape[-1]]),
fft_size=s.window_length,
hop_length=s.hop_length,
win_length=s.window_length,
window=s.window_type)
y_mag = stft(
y.reshape([-1, y.shape[-1]]),
fft_size=s.window_length,
hop_length=s.hop_length,
win_length=s.window_length,
window=s.window_type)
x_mag = x_mag.transpose([0, 2, 1]) x_mag = x_mag.transpose([0, 2, 1])
y_mag = y_mag.transpose([0, 2, 1]) y_mag = y_mag.transpose([0, 2, 1])
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal): elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
@ -1499,8 +1520,7 @@ class MultiScaleSTFTLoss(nn.Layer):
loss += self.log_weight * self.loss_fn( loss += self.log_weight * self.loss_fn(
paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(), paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), )
)
loss += self.mag_weight * self.loss_fn(x_mag, y_mag) loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
return loss return loss
@ -1528,8 +1548,8 @@ class GANLoss(nn.Layer):
loss_d = 0 loss_d = 0
for x_fake, x_real in zip(d_fake, d_real): for x_fake, x_real in zip(d_fake, d_real):
loss_d += paddle.mean(x_fake[-1] ** 2) loss_d += paddle.mean(x_fake[-1]**2)
loss_d += paddle.mean((1 - x_real[-1]) ** 2) loss_d += paddle.mean((1 - x_real[-1])**2)
return loss_d return loss_d
def generator_loss(self, fake, real): def generator_loss(self, fake, real):
@ -1537,11 +1557,12 @@ class GANLoss(nn.Layer):
loss_g = 0 loss_g = 0
for x_fake in d_fake: for x_fake in d_fake:
loss_g += paddle.mean((1 - x_fake[-1]) ** 2) loss_g += paddle.mean((1 - x_fake[-1])**2)
loss_feature = 0 loss_feature = 0
for i in range(len(d_fake)): for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1): for j in range(len(d_fake[i]) - 1):
loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach()) loss_feature += paddle.nn.functional.l1_loss(
d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature return loss_g, loss_feature

@ -1,22 +1,39 @@
import torch import torch
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss
from paddleaudio.audiotools.core.audio_signal import AudioSignal from paddleaudio.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
def test_dac_losses(): def test_dac_losses():
for i in range(10): for i in range(10):
loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt') loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav') recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav') signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')
loss_fn_1 = MultiScaleSTFTLoss() loss_fn_1 = MultiScaleSTFTLoss()
loss_fn_2 = MultiMelSpectrogramLoss(n_mels=[5, 10, 20, 40, 80, 160, 320], window_lengths=[32, 64, 128, 256, 512, 1024, 2048], mag_weight=0.0, pow=1.0, mel_fmin=[0, 0, 0, 0, 0, 0, 0], mel_fmax=[None, None, None, None, None, None, None]) loss_fn_2 = MultiMelSpectrogramLoss(
n_mels=[5, 10, 20, 40, 80, 160, 320],
window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
mag_weight=0.0,
pow=1.0,
mel_fmin=[0, 0, 0, 0, 0, 0, 0],
mel_fmax=[None, None, None, None, None, None, None])
# #
# Test AudioSignal # Test AudioSignal
# #
assert abs(loss_fn_1(recons, signal).item() - loss_origin['stft/loss'].item()) < 1e-5 assert abs(
assert abs(loss_fn_2(recons, signal).item() - loss_origin['mel/loss'].item()) < 1e-5 loss_fn_1(recons, signal).item() - loss_origin['stft/loss']
.item()) < 1e-5
assert abs(
loss_fn_2(recons, signal).item() - loss_origin['mel/loss']
.item()) < 1e-5
# #
# Test Tensor # Test Tensor
# #
assert abs(loss_fn_1(recons.audio_data, signal.audio_data).item() - loss_origin['stft/loss'].item()) < 1e-3 assert abs(
assert abs(loss_fn_2(recons.audio_data, signal.audio_data).item() - loss_origin['mel/loss'].item()) < 1e-3 loss_fn_1(recons.audio_data, signal.audio_data).item() -
loss_origin['stft/loss'].item()) < 1e-3
assert abs(
loss_fn_2(recons.audio_data, signal.audio_data).item() -
loss_origin['mel/loss'].item()) < 1e-3

Loading…
Cancel
Save