Using pre-commit to format code

pull/3954/head
suzakuwcx 9 months ago
parent 2c04e0ac92
commit b741545f5e
No known key found for this signature in database
GPG Key ID: FA07FC9584DD32FE

@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple, Callable, List, Union
from typing import Callable
from typing import List
from typing import Tuple
from typing import Union
import librosa
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddleaudio.audiotools import AudioSignal
from paddleaudio.audiotools import STFTParams
from scipy import signal
from scipy.stats import betabinom
from typeguard import check_argument_types
@ -28,8 +33,6 @@ from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301
)
from paddleaudio.audiotools import AudioSignal, STFTParams
# Losses for WaveRNN
def log_sum_exp(x):
@ -1028,8 +1031,7 @@ class MultiMelSpectrogramLoss(nn.Layer):
mel_fmin: List[float]=[0.0, 0.0],
mel_fmax: List[float]=[None, None],
window_type: str=None,
fs: int = 44100,
):
fs: int=44100, ):
super().__init__()
self.mel_loss_fns = [
@ -1040,9 +1042,9 @@ class MultiMelSpectrogramLoss(nn.Layer):
num_mels=n_mel,
fmin=fmin,
fmax=fmax,
eps=clamp_eps,
)
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin, mel_fmax)
eps=clamp_eps, )
for n_mel, w, fmin, fmax in zip(n_mels, window_lengths, mel_fmin,
mel_fmax)
]
self.n_mels = n_mels
@ -1055,7 +1057,9 @@ class MultiMelSpectrogramLoss(nn.Layer):
self.mel_fmax = mel_fmax
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi mel loss between an estimate and a reference
signal.
@ -1077,12 +1081,21 @@ class MultiMelSpectrogramLoss(nn.Layer):
loss += self.log_weight * mel_loss_fn(x, y)
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
s = mel_loss_fn.mel_spectrogram.stft_params
x_mels = x.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(self.n_mels[i], mel_fmin=self.mel_fmin[i], mel_fmax=self.mel_fmax[i], window_length=s['n_fft'], hop_length=s['hop_length'])
x_mels = x.mel_spectrogram(
self.n_mels[i],
mel_fmin=self.mel_fmin[i],
mel_fmax=self.mel_fmax[i],
window_length=s['n_fft'],
hop_length=s['hop_length'])
y_mels = y.mel_spectrogram(
self.n_mels[i],
mel_fmin=self.mel_fmin[i],
mel_fmax=self.mel_fmax[i],
window_length=s['n_fft'],
hop_length=s['hop_length'])
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mels, self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10()
)
paddle.clip(y_mels, self.clamp_eps).pow(self.pow).log10())
else:
raise ValueError('\'x\' amd \'y\' should be the same type')
return loss
@ -1445,17 +1458,14 @@ class MultiScaleSTFTLoss(nn.Layer):
pow: float=2.0,
weight: float=1.0,
match_stride: bool=False,
window_type: str = 'hann',
):
window_type: str='hann', ):
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type,
)
for w in window_lengths
window_type=window_type, ) for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
@ -1464,8 +1474,9 @@ class MultiScaleSTFTLoss(nn.Layer):
self.weight = weight
self.pow = pow
def forward(self, x: Union[AudioSignal, paddle.Tensor], y: Union[AudioSignal, paddle.Tensor]):
def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
"""Computes multi-scale STFT between an estimate and a reference
signal.
@ -1485,8 +1496,18 @@ class MultiScaleSTFTLoss(nn.Layer):
for s in self.stft_params:
if isinstance(x, paddle.Tensor) and isinstance(y, paddle.Tensor):
x_mag = stft(x.reshape([-1, x.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
y_mag = stft(y.reshape([-1, y.shape[-1]]), fft_size=s.window_length, hop_length=s.hop_length, win_length=s.window_length, window=s.window_type)
x_mag = stft(
x.reshape([-1, x.shape[-1]]),
fft_size=s.window_length,
hop_length=s.hop_length,
win_length=s.window_length,
window=s.window_type)
y_mag = stft(
y.reshape([-1, y.shape[-1]]),
fft_size=s.window_length,
hop_length=s.hop_length,
win_length=s.window_length,
window=s.window_type)
x_mag = x_mag.transpose([0, 2, 1])
y_mag = y_mag.transpose([0, 2, 1])
elif isinstance(x, AudioSignal) and isinstance(y, AudioSignal):
@ -1499,8 +1520,7 @@ class MultiScaleSTFTLoss(nn.Layer):
loss += self.log_weight * self.loss_fn(
paddle.clip(x_mag, min=self.clamp_eps).pow(self.pow).log10(),
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(),
)
paddle.clip(y_mag, min=self.clamp_eps).pow(self.pow).log10(), )
loss += self.mag_weight * self.loss_fn(x_mag, y_mag)
return loss
@ -1543,5 +1563,6 @@ class GANLoss(nn.Layer):
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += paddle.nn.functional.l1_loss(d_fake[i][j], d_real[i][j].detach())
loss_feature += paddle.nn.functional.l1_loss(
d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature

@ -1,22 +1,39 @@
import torch
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss, MultiMelSpectrogramLoss
from paddleaudio.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import MultiMelSpectrogramLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
def test_dac_losses():
    """Check the DAC loss modules against stored reference loss values.

    For each of the ten fixture triples in ``tests/unit/tts/data``
    (``{i}-loss.pt``, ``{i}-recons.wav``, ``{i}-signal.wav``) the
    multi-scale STFT loss and the multi-resolution mel loss are evaluated
    twice: once on ``AudioSignal`` objects and once on their ``audio_data``
    tensors. Each result is compared with the value stored in the loaded
    ``.pt`` file under ``'stft/loss'`` / ``'mel/loss'``.

    The ``AudioSignal`` path is held to an absolute tolerance of ``1e-5``;
    the tensor path to ``1e-3``.
    """
    # The loss modules are built from constants only, so construct them once
    # instead of once per fixture.
    stft_loss_fn = MultiScaleSTFTLoss()
    mel_loss_fn = MultiMelSpectrogramLoss(
        n_mels=[5, 10, 20, 40, 80, 160, 320],
        window_lengths=[32, 64, 128, 256, 512, 1024, 2048],
        mag_weight=0.0,
        pow=1.0,
        mel_fmin=[0, 0, 0, 0, 0, 0, 0],
        mel_fmax=[None, None, None, None, None, None, None])

    for i in range(10):
        loss_origin = torch.load(f'tests/unit/tts/data/{i}-loss.pt')
        recons = AudioSignal(f'tests/unit/tts/data/{i}-recons.wav')
        signal = AudioSignal(f'tests/unit/tts/data/{i}-signal.wav')

        expected_stft = loss_origin['stft/loss'].item()
        expected_mel = loss_origin['mel/loss'].item()

        #
        # Test AudioSignal
        #
        stft_from_signal = stft_loss_fn(recons, signal).item()
        mel_from_signal = mel_loss_fn(recons, signal).item()
        assert abs(stft_from_signal - expected_stft) < 1e-5, (
            f'fixture {i}: STFT loss on AudioSignal is {stft_from_signal}, '
            f'expected {expected_stft}')
        assert abs(mel_from_signal - expected_mel) < 1e-5, (
            f'fixture {i}: mel loss on AudioSignal is {mel_from_signal}, '
            f'expected {expected_mel}')

        #
        # Test Tensor
        #
        stft_from_tensor = stft_loss_fn(recons.audio_data,
                                        signal.audio_data).item()
        mel_from_tensor = mel_loss_fn(recons.audio_data,
                                      signal.audio_data).item()
        assert abs(stft_from_tensor - expected_stft) < 1e-3, (
            f'fixture {i}: STFT loss on tensors is {stft_from_tensor}, '
            f'expected {expected_stft}')
        assert abs(mel_from_tensor - expected_mel) < 1e-3, (
            f'fixture {i}: mel loss on tensors is {mel_from_tensor}, '
            f'expected {expected_mel}')

Loading…
Cancel
Save