diff --git a/paddlespeech/__init__.py b/paddlespeech/__init__.py
index 969d189f5..6c7e75c1f 100644
--- a/paddlespeech/__init__.py
+++ b/paddlespeech/__init__.py
@@ -13,7 +13,3 @@
 # limitations under the License.
 import _locale
 _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
-
-__version__ = '0.0.0'
-
-__commit__ = '9cf8c1985a98bb380c183116123672976bdfe5c9'
diff --git a/paddlespeech/audiotools/core/__init__.py b/paddlespeech/audiotools/core/__init__.py
index 609d6a34a..3443a7676 100644
--- a/paddlespeech/audiotools/core/__init__.py
+++ b/paddlespeech/audiotools/core/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from . import util
-from ._julius import fft_conv1d
-from ._julius import FFTConv1D
+from ...t2s.modules import fft_conv1d
+from ...t2s.modules import FFTConv1D
 from ._julius import highpass_filter
 from ._julius import highpass_filters
 from ._julius import lowpass_filter
diff --git a/paddlespeech/audiotools/core/_julius.py b/paddlespeech/audiotools/core/_julius.py
index aef51f98f..113475cdd 100644
--- a/paddlespeech/audiotools/core/_julius.py
+++ b/paddlespeech/audiotools/core/_julius.py
@@ -20,8 +20,6 @@
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
-from paddlespeech.t2s.modules import fft_conv1d
-from paddlespeech.t2s.modules import FFTConv1D
 from paddlespeech.utils import satisfy_paddle_version
 
 __all__ = [
@@ -312,6 +310,7 @@ class LowPassFilters(nn.Layer):
             mode="replicate",
             data_format="NCL")
         if self.fft:
+            from paddlespeech.t2s.modules import fft_conv1d
             out = fft_conv1d(_input, self.filters, stride=self.stride)
         else:
             out = F.conv1d(_input, self.filters, stride=self.stride)
diff --git a/paddlespeech/audiotools/core/util.py b/paddlespeech/audiotools/core/util.py
index 6da927a6f..676d57704 100644
--- a/paddlespeech/audiotools/core/util.py
+++ b/paddlespeech/audiotools/core/util.py
@@ -32,7 +32,6 @@ import soundfile
 from flatten_dict import flatten
 from flatten_dict import unflatten
 
-from .audio_signal import AudioSignal
 from paddlespeech.utils import satisfy_paddle_version
 from paddlespeech.vector.training.seeding import seed_everything
@@ -232,8 +231,7 @@ def ensure_tensor(
 
 def _get_value(other):
     #
-    from . import AudioSignal
-
+    from .audio_signal import AudioSignal
     if isinstance(other, AudioSignal):
         return other.audio_data
     return other
@@ -784,6 +782,8 @@ def collate(list_of_dicts: list, n_splits: int=None):
         Dictionary containing batched data.
     """
+    from .audio_signal import AudioSignal
+
     batches = []
     list_len = len(list_of_dicts)
@@ -873,7 +873,7 @@ def generate_chord_dataset(
     """
     import librosa
 
-    from . import AudioSignal
+    from .audio_signal import AudioSignal
     from ..data.preprocess import create_csv
 
     min_midi = librosa.note_to_midi(min_note)
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index f819352d6..a1a65a9dc 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from typing import Callable
+from typing import List
+from typing import Optional
 from typing import Tuple
+from typing import Union
 
 import librosa
 import numpy as np
@@ -23,6 +27,8 @@
 from scipy import signal
 from scipy.stats import betabinom
 from typeguard import typechecked
+from paddlespeech.audiotools.core.audio_signal import AudioSignal
+from paddlespeech.audiotools.core.audio_signal import STFTParams
 from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
 from paddlespeech.t2s.modules.predictor.duration_predictor import (
     DurationPredictorLoss,  # noqa: H301
@@ -1326,3 +1332,276 @@ class ForwardSumLoss(nn.Layer):
                 bb_prior[bidx, :T, :N] = prob
 
         return bb_prior
+
+
+class MultiScaleSTFTLoss(nn.Layer):
+    """Computes the multi-scale STFT loss from [1].
+
+    References
+    ----------
+
+    1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
+       "DDSP: Differentiable Digital Signal Processing."
+       International Conference on Learning Representations. 2019.
+
+    Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+            self,
+            window_lengths: List[int]=[2048, 512],
+            loss_fn: Callable=nn.L1Loss(),
+            clamp_eps: float=1e-5,
+            mag_weight: float=1.0,
+            log_weight: float=1.0,
+            pow: float=2.0,
+            weight: float=1.0,
+            match_stride: bool=False,
+            window_type: Optional[str]=None, ):
+        """
+        Args:
+            window_lengths : List[int], optional
+                Length of each window of each STFT, by default [2048, 512]
+            loss_fn : typing.Callable, optional
+                How to compare each loss, by default nn.L1Loss()
+            clamp_eps : float, optional
+                Clamp on the log magnitude, below, by default 1e-5
+            mag_weight : float, optional
+                Weight of raw magnitude portion of loss, by default 1.0
+            log_weight : float, optional
+                Weight of log magnitude portion of loss, by default 1.0
+            pow : float, optional
+                Power to raise magnitude to before taking log, by default 2.0
+            weight : float, optional
+                Weight of this loss, by default 1.0
+            match_stride : bool, optional
+                Whether to match the stride of convolutional layers, by default False
+            window_type : str, optional
+                Type of window to use, by default None.
+        """
+        super().__init__()
+
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type, ) for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes multi-scale STFT between an estimate and a reference
+        signal.
+
+        Args:
+            x : AudioSignal
+                Estimate signal
+            y : AudioSignal
+                Reference signal
+
+        Returns:
+            paddle.Tensor
+                Multi-scale STFT loss.
+
+        Example:
+            >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
+            >>> import paddle
+
+            >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
+            >>> y = x * 0.01
+            >>> loss = MultiScaleSTFTLoss()
+            >>> loss(x, y).numpy()
+            7.562150
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
+class GANLoss(nn.Layer):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+
+    Example:
+        >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
+        >>> import paddle
+
+        >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
+        >>> y = x * 0.01
+        >>> class My_discriminator0:
+        ...     def __call__(self, x):
+        ...         return x.sum()
+        >>> loss = GANLoss(My_discriminator0())
+        >>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()]
+        [-0.102722, -0.001027]
+
+        >>> class My_discriminator1:
+        ...     def __call__(self, x):
+        ...         return x * (-0.2)
+        >>> loss = GANLoss(My_discriminator1())
+        >>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()]
+        [1.00019, 0]
+
+        >>> loss.discriminator_loss(x, y)
+        1.000200
+    """
+
+    def __init__(self, discriminator):
+        """
+        Args:
+            discriminator : paddle.nn.Layer
+                Discriminator model
+        """
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self,
+                fake: Union[AudioSignal, paddle.Tensor],
+                real: Union[AudioSignal, paddle.Tensor]):
+        if isinstance(fake, AudioSignal):
+            d_fake = self.discriminator(fake.audio_data)
+        else:
+            d_fake = self.discriminator(fake)
+
+        if isinstance(real, AudioSignal):
+            d_real = self.discriminator(real.audio_data)
+        else:
+            d_real = self.discriminator(real)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += paddle.mean(x_fake[-1]**2)
+            loss_d += paddle.mean((1 - x_real[-1])**2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += paddle.mean((1 - x_fake[-1])**2)
+
+        loss_feature = 0
+
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j])
+        return loss_g, loss_feature
+
+
+class SISDRLoss(nn.Layer):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py
+
+    Example:
+        >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
+        >>> import paddle
+
+        >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
+        >>> y = x * 0.01
+        >>> sisdr = SISDRLoss()
+        >>> sisdr(x, y).numpy()
+        -145.377640
+    """
+
+    def __init__(
+            self,
+            scaling: bool=True,
+            reduction: str="mean",
+            zero_mean: bool=True,
+            clip_min: Optional[int]=None,
+            weight: float=1.0, ):
+        """
+        Args:
+            scaling : bool, optional
+                Whether to use scale-invariant (True) or
+                signal-to-noise ratio (False), by default True
+            reduction : str, optional
+                How to reduce across the batch (either 'mean',
+                'sum', or 'none'), by default 'mean'
+            zero_mean : bool, optional
+                Zero mean the references and estimates before
+                computing the loss, by default True
+            clip_min : int, optional
+                The minimum possible loss value. Helps network
+                to not focus on making already good examples better, by default None
+            weight : float, optional
+                Weight of this loss, defaults to 1.0.
+        """
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self,
+                x: Union[AudioSignal, paddle.Tensor],
+                y: Union[AudioSignal, paddle.Tensor]):
+        eps = 1e-8
+        # B, C, T
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape([nb, 1, -1]).transpose([0, 2, 1])
+        estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1])
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(axis=1, keepdim=True)
+            mean_estimate = estimates.mean(axis=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(axis=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(axis=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(axis=1)
+            if self.scaling else 1)
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(axis=1)
+        noise = (e_res**2).sum(axis=1)
+        sdr = -10 * paddle.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = paddle.clip(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
diff --git a/tests/unit/audiotools/core/test_util.py b/tests/unit/audiotools/core/test_util.py
index 155686acd..16e5d5e92 100644
--- a/tests/unit/audiotools/core/test_util.py
+++ b/tests/unit/audiotools/core/test_util.py
@@ -13,7 +13,6 @@ import pytest
 
 from paddlespeech.audiotools import util
 from paddlespeech.audiotools.core.audio_signal import AudioSignal
-from paddlespeech.vector.training.seeding import seed_everything
 
 
 def test_check_random_state():
@@ -36,12 +35,12 @@ def test_check_random_state():
 
 
 def test_seed():
-    seed_everything(0)
+    util.seed_everything(0)
     paddle_result_a = paddle.randn([1])
     np_result_a = np.random.randn(1)
     py_result_a = random.random()
 
-    seed_everything(0)
+    util.seed_everything(0)
     paddle_result_b = paddle.randn([1])
     np_result_b = np.random.randn(1)
     py_result_b = random.random()
diff --git a/tests/unit/audiotools/test_audiotools.sh b/tests/unit/audiotools/test_audiotools.sh
index 3a0161900..f69447d62 100644
--- a/tests/unit/audiotools/test_audiotools.sh
+++ b/tests/unit/audiotools/test_audiotools.sh
@@ -1,4 +1,3 @@
-python -m pip install -r ../../../paddlespeech/audiotools/requirements.txt
 wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/audio.tar.gz
 wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/regression.tar.gz
 tar -zxvf audio.tar.gz
diff --git a/tests/unit/ci.sh b/tests/unit/ci.sh
index 6beff0707..567af2210 100644
--- a/tests/unit/ci.sh
+++ b/tests/unit/ci.sh
@@ -1,6 +1,7 @@
 function main(){
     set -ex
     speech_ci_path=`pwd`
+    python -m pip install -r ../../paddlespeech/audiotools/requirements.txt
 
     echo "Start asr"
     cd ${speech_ci_path}/asr
@@ -16,6 +17,7 @@ function main(){
     python test_enfrontend.py
     python test_fftconv1d.py
    python test_mixfrontend.py
+    python test_losses.py
     echo "End TTS"
 
     echo "Start Vector"
diff --git a/tests/unit/tts/test_losses.py b/tests/unit/tts/test_losses.py
new file mode 100644
index 000000000..f99d15d1c
--- /dev/null
+++ b/tests/unit/tts/test_losses.py
@@ -0,0 +1,61 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+from paddlespeech.audiotools.core.audio_signal import AudioSignal
+from paddlespeech.t2s.modules.losses import GANLoss
+from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
+from paddlespeech.t2s.modules.losses import SISDRLoss
+
+
+def get_input():
+    x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav",
+                    2_05)
+    y = x * 0.01
+    return x, y
+
+
+def test_multi_scale_stft_loss():
+    x, y = get_input()
+    loss = MultiScaleSTFTLoss()
+    pd_loss = loss(x, y)
+    assert np.abs(pd_loss.numpy() - 7.562150) < 1e-06
+
+
+def test_sisdr_loss():
+    x, y = get_input()
+    loss = SISDRLoss()
+    pd_loss = loss(x, y)
+    assert np.abs(pd_loss.numpy() - (-145.377640)) < 1e-06
+
+
+def test_gan_loss():
+    class My_discriminator0:
+        def __call__(self, x):
+            return x.sum()
+
+    class My_discriminator1:
+        def __call__(self, x):
+            return x * (-0.2)
+
+    x, y = get_input()
+    loss = GANLoss(My_discriminator0())
+    pd_loss0, pd_loss1 = loss(x, y)
+    assert np.abs(pd_loss0.numpy() - (-0.102722)) < 1e-06
+    assert np.abs(pd_loss1.numpy() - (-0.001027)) < 1e-06
+    loss = GANLoss(My_discriminator1())
+    pd_loss0, _ = loss.generator_loss(x, y)
+    assert np.abs(pd_loss0.numpy() - 1.000199) < 1e-06
+    pd_loss = loss.discriminator_loss(x, y)
+    assert np.abs(pd_loss.numpy() - 1.000200) < 1e-06
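
A minimal usage sketch for the losses added above, for readers who want to try them without the en.wav fixture the doctests rely on. It assumes the AudioSignal constructor also accepts an in-memory [batch, channels, samples] tensor plus a sample_rate keyword, as the audiotools port does for array input; the printed values will naturally differ from the fixture-based figures in the doctests.

# Usage sketch; assumes AudioSignal(tensor, sample_rate=...) accepts in-memory audio.
import paddle

from paddlespeech.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
from paddlespeech.t2s.modules.losses import SISDRLoss

x = AudioSignal(paddle.randn([1, 1, 44100]), sample_rate=44100)  # 1 s of noise as reference
y = x * 0.01                                                     # heavily attenuated estimate

print(MultiScaleSTFTLoss()(x, y))  # multi-scale STFT distance between estimate and reference
print(SISDRLoss()(x, y))           # negative SI-SDR, lower is better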
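
On the indexing in GANLoss.discriminator_loss and generator_loss: the code is written for discriminators that return a list with one entry per sub-discriminator, each entry a list of intermediate feature maps with the final logits last, so x_fake[-1] picks the logits and the inner j-loop matches every feature map except the last (the feature-matching term). The doctest discriminators get away with returning plain tensors only because tensors index the same way. A hypothetical discriminator honouring that contract might look like the sketch below; the class name and layer sizes are illustrative, not part of the patch.

# Illustrative only: a discriminator whose output matches what GANLoss indexes,
# i.e. an outer list of sub-discriminators, each an inner list [features..., logits].
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class TinyDiscriminator(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1D(1, 8, 15, stride=4, padding=7)
        self.conv2 = nn.Conv1D(8, 1, 15, stride=4, padding=7)

    def forward(self, x):                 # x: [B, C, T] waveform
        feat = F.leaky_relu(self.conv1(x))
        logits = self.conv2(feat)
        return [[feat, logits]]           # one sub-discriminator

With this shape contract, GANLoss(TinyDiscriminator()).discriminator_loss(fake, real) reduces the logits of each sub-discriminator, and generator_loss adds the L1 feature-matching term over every d_fake[i][j] except the logits.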