Using pre-commit to format code

pull/3954/head
suzakuwcx 9 months ago
parent 2c04e0ac92
commit b741545f5e
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple, Callable, List, Union
from typing import Callable
from typing import List
from typing import Tuple
from typing import Union
import librosa
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddleaudio.audiotools import AudioSignal
from paddleaudio.audiotools import STFTParams
from scipy import signal
from scipy.stats import betabinom
from typeguard import check_argument_types
@ -28,8 +33,6 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301
)
from paddleaudio.audiotools import AudioSignal, STFTParams
# Losses for WaveRNN
def log_sum_exp(x):
@ -1015,21 +1018,20 @@ class MultiMelSpectrogramLoss(nn.Layer):
"""
def __init__(
self,
n_mels: List[int] = [150, 80],
window_lengths: List[int] = [2048, 512],
loss_fn: Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
mel_fmin: List[float] = [0.0, 0.0],
mel_fmax: List[float] = [None, None],
window_type: str = None,
fs: int = 44100,
):
self,
n_mels: List[int]=[150, 80],
window_lengths: List[int]=[2048, 512],
loss_fn: Callable=nn.L1Loss(),
clamp_eps: float=1e-5,
mag_weight: float=1.0,
log_weight: float=1.0,
pow: float=2.0,
weight: float=1.0,
match_stride: bool=False,
mel_fmin: List[float]=[0.0, 0.0],
mel_fmax: List[float]=[None, None],
window_type: str=None,
fs: int=44100, ):
super().__init__()
self.mel_loss_fns = [
@ -1040,11 +1042,11 @@ class MultiMelSpectrogramLoss(nn.Layer):
num_mels=n_mel,
fmin=fmin,
fmax=fmax,
eps=clamp_eps,
)
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax)
eps=clamp_eps, )
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin,
mel_fmax)
]
self.n_mels = n_mels
self.loss_fn = loss_fn
self.clamp_eps = clamp_eps
@ -1055,7 +1057,9 @@ class MultiMelSpectrogramLoss(nn.Layer):
self.mel_fmax = mel_fmax
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi mel loss between an estimate and a reference
signal.
@ -1077,12 +1081,21 @@ class MultiMelSpectrogramLoss(nn.Layer):
loss += self.log_weight * mel_loss_fn(x, y)
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
s = mel_loss_fn.mel_spectrogram.stft_params
x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
x_mels = x.mel_spectrogram(
self.n_mels[i],
mel_fmin=self.mel_fmin[i],
mel_fmax=self.mel_fmax[i],
window_length=s['n_fft'],
hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(
self.n_mels[i],
mel_fmin=self.mel_fmin[i],
mel_fmax=self.mel_fmax[i],
window_length=s['n_fft'],
hop_length=s['hop_length'])
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10()
)
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10())
else:
raise ValueError('\'x\' amd \'y\' should be the same type')
return loss
@ -1436,26 +1449,23 @@ class MultiScaleSTFTLoss(nn.Layer):
"""Multi resolution STFT loss module."""
def __init__(
self,
window_lengths: List[int] = [2048, 512],
loss_fn: Callable = nn.L1Loss(),
clamp_eps: float = 1e-5,
mag_weight: float = 1.0,
log_weight: float = 1.0,
pow: float = 2.0,
weight: float = 1.0,
match_stride: bool = False,
window_type: str = 'hann',
):
self,
window_lengths: List[int]=[2048, 512],
loss_fn: Callable=nn.L1Loss(),
clamp_eps: float=1e-5,
mag_weight: float=1.0,
log_weight: float=1.0,
pow: float=2.0,
weight: float=1.0,
match_stride: bool=False,
window_type: str='hann', ):
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
window_type=window_type, ) for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
@ -1464,8 +1474,9 @@ class MultiScaleSTFTLoss(nn.Layer):
self.weight = weight
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi-scale STFT between an estimate and a reference
signal.
@ -1482,11 +1493,21 @@ class MultiScaleSTFTLoss(nn.Layer):
Multi-scale STFT loss.
"""
loss = 0.0
for s in self.stft_params:
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
x_mag = stft(
x.reshape([-1, x.shape[-1]]),
fft_size=s.window_length,
hop_length=s.hop_length,
win_length=s.window_length,
window=s.window_type)
y_mag = stft(
y.reshape([-1, y.shape[-1]]),
fft_size=s.window_length,
hop_length=s.hop_length,
win_length=s.window_length,
window=s.window_type)
x_mag = x_mag.transpose([0, 2, 1])
y_mag = y_mag.transpose([0, 2, 1])
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
@ -1499,8 +1520,7 @@ class MultiScaleSTFTLoss(nn.Layer):
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(),
)
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), )
loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
return loss
@ -1528,8 +1548,8 @@ class GANLoss(nn.Layer):
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += paddle.mean(x_fake[-1] ** 2)
loss_d += paddle.mean((1 - x_real[-1]) ** 2)
loss_d += paddle.mean(x_fake[-1]**2)
loss_d += paddle.mean((1 - x_real[-1])**2)
return loss_d
def generator_loss(self, fake, real):
@ -1537,11 +1557,12 @@ class GANLoss(nn.Layer):
loss_g = 0
for x_fake in d_fake:
loss_g += paddle.mean((1 - x_fake[-1]) ** 2)
loss_g += paddle.mean((1 - x_fake[-1])**2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach())
loss_feature += paddle.nn.functional.l1_loss(
d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature

@ -1,22 +1,39 @@
import torch
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss
from paddleaudio.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
def test_dac_losses():
    """Compare the Paddle DAC losses with the values stored on disk.

    For each of the ten fixtures under ``tests/unit/tts/data`` the test loads
    ``{i}-loss.pt`` with ``torch.load`` and the ``{i}-recons.wav`` /
    ``{i}-signal.wav`` pair as ``AudioSignal`` objects. It then checks that
    ``MultiScaleSTFTLoss`` matches the stored ``'stft/loss'`` entry and that
    ``MultiMelSpectrogramLoss`` matches the stored ``'mel/loss'`` entry.

    Both losses are evaluated twice: on the ``AudioSignal`` objects
    (tolerance ``1e-5``) and on their ``audio_data`` tensors
    (tolerance ``1e-3``).
    """
    # The loss modules are configured identically for every fixture, so they
    # are built once instead of once per loop iteration.
    window_lengths = [32, 64, 128, 256, 512, 1024, 2048]
    loss_fn_1 = MultiScaleSTFTLoss()
    loss_fn_2 = MultiMelSpectrogramLoss(
        n_mels=[5, 10, 20, 40, 80, 160, 320],
        window_lengths=window_lengths,
        mag_weight=0.0,
        pow=1.0,
        mel_fmin=[0] * len(window_lengths),
        mel_fmax=[None] * len(window_lengths))

    for i in range(10):
        loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
        recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
        signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')

        # Read each stored value once; it is compared against both input kinds.
        expected_stft = loss_origin['stft/loss'].item()
        expected_mel = loss_origin['mel/loss'].item()

        #
        # Test AudioSignal
        #
        assert abs(loss_fn_1(recons, signal).item() - expected_stft) < 1e-5
        assert abs(loss_fn_2(recons, signal).item() - expected_mel) < 1e-5

        #
        # Test Tensor
        #
        assert abs(
            loss_fn_1(recons.audio_data, signal.audio_data).item() -
            expected_stft) < 1e-3
        assert abs(
            loss_fn_2(recons.audio_data, signal.audio_data).item() -
            expected_mel) < 1e-3

Loading…
Cancel
Save