add DAC loss

pull/3988/head
cchenhaifeng 7 months ago
parent 0479cce8ff
commit d92c45ec4f

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import util
from ._julius import fft_conv1d
from ._julius import FFTConv1D
from ._julius import highpass_filter
from ._julius import highpass_filters
from ._julius import lowpass_filter
@@ -26,3 +24,5 @@ from ._julius import SplitBands
from .audio_signal import AudioSignal
from .audio_signal import STFTParams
from .loudness import Meter
from paddlespeech.t2s.modules import fft_conv1d
from paddlespeech.t2s.modules import FFTConv1D

@@ -20,8 +20,6 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.t2s.modules import fft_conv1d
from paddlespeech.t2s.modules import FFTConv1D
from paddlespeech.utils import satisfy_paddle_version
__all__ = [
@@ -312,6 +310,7 @@ class LowPassFilters(nn.Layer):
mode="replicate",
data_format="NCL")
if self.fft:
from paddlespeech.t2s.modules import fft_conv1d
out = fft_conv1d(_input, self.filters, stride=self.stride)
else:
out = F.conv1d(_input, self.filters, stride=self.stride)

@@ -32,7 +32,6 @@ import soundfile
from flatten_dict import flatten
from flatten_dict import unflatten
from .audio_signal import AudioSignal
from paddlespeech.utils import satisfy_paddle_version
from paddlespeech.vector.training.seeding import seed_everything
@@ -232,7 +231,7 @@ def ensure_tensor(
def _get_value(other):
#
from . import AudioSignal
from .audio_signal import AudioSignal
if isinstance(other, AudioSignal):
return other.audio_data
@@ -873,7 +872,7 @@ def generate_chord_dataset(
"""
import librosa
from . import AudioSignal
from .audio_signal import AudioSignal
from ..data.preprocess import create_csv
min_midi = librosa.note_to_midi(min_note)

@@ -12,7 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import librosa
import numpy as np
@@ -23,6 +27,8 @@ from scipy import signal
from scipy.stats import betabinom
from typeguard import typechecked
from paddlespeech.audiotools.core.audio_signal import AudioSignal
from paddlespeech.audiotools.core.audio_signal import STFTParams
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.predictor.duration_predictor import (
DurationPredictorLoss, # noqa: H301
@@ -1326,3 +1332,233 @@ class ForwardSumLoss(nn.Layer):
bb_prior[bidx, :T, :N] = prob
return bb_prior
class MultiScaleSTFTLoss(nn.Layer):
"""Computes the multi-scale STFT loss from [1].
References
----------
1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
"DDSP: Differentiable Digital Signal Processing."
International Conference on Learning Representations. 2019.
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py
"""
def __init__(
self,
window_lengths: List[int]=[2048, 512],
loss_fn: Callable=nn.L1Loss(),
clamp_eps: float=1e-5,
mag_weight: float=1.0,
log_weight: float=1.0,
pow: float=2.0,
weight: float=1.0,
match_stride: bool=False,
window_type: Optional[str]=None, ):
"""
Args:
window_lengths : List[int], optional
Length of each window of each STFT, by default [2048, 512]
loss_fn : typing.Callable, optional
Distance function used to compare each pair of spectrograms, by default nn.L1Loss()
clamp_eps : float, optional
Lower clamp applied to the magnitude before taking the log, by default 1e-5
mag_weight : float, optional
Weight of raw magnitude portion of loss, by default 1.0
log_weight : float, optional
Weight of log magnitude portion of loss, by default 1.0
pow : float, optional
Power to raise magnitude to before taking log, by default 2.0
weight : float, optional
Weight of this loss, by default 1.0
match_stride : bool, optional
Whether to match the stride of convolutional layers, by default False
window_type : str, optional
Type of window to use, by default None.
"""
super().__init__()
self.stft_params = [
STFTParams(
window_length=w,
hop_length=w // 4,
match_stride=match_stride,
window_type=window_type, ) for w in window_lengths
]
self.loss_fn = loss_fn
self.log_weight = log_weight
self.mag_weight = mag_weight
self.clamp_eps = clamp_eps
self.weight = weight
self.pow = pow
def forward(self, x: AudioSignal, y: AudioSignal):
"""Computes multi-scale STFT between an estimate and a reference
signal.
Args:
x : AudioSignal
Estimate signal
y : AudioSignal
Reference signal
Returns:
paddle.Tensor
Multi-scale STFT loss.
"""
loss = 0.0
for s in self.stft_params:
x.stft(s.window_length, s.hop_length, s.window_type)
y.stft(s.window_length, s.hop_length, s.window_type)
loss += self.log_weight * self.loss_fn(
x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(),
y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), )
loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
return loss
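A minimal usage sketch (not part of the diff) that mirrors the inputs of the new test file below; it assumes AudioSignal can load the sample wav directly from the URL, as the test does, so only the exact loss value is illustrative:

from paddlespeech.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss

# reference signal and a heavily attenuated "estimate" (same pattern as test_losses.py)
ref = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav")
est = ref * 0.01
stft_loss = MultiScaleSTFTLoss()   # default windows 2048 and 512, hop = window // 4
value = stft_loss(est, ref)        # scalar paddle.Tensor; forward takes (estimate, reference)
print(value.item())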
class GANLoss(nn.Layer):
"""
Computes an adversarial loss, given a discriminator that scores
generated waveforms/spectrograms against ground-truth
waveforms/spectrograms. The discriminator and generator losses
are computed in separate methods.
"""
def __init__(self, discriminator):
"""
Args:
discriminator : paddle.nn.Layer
Discriminator model
"""
super().__init__()
self.discriminator = discriminator
def forward(self,
fake: Union[AudioSignal, paddle.Tensor],
real: Union[AudioSignal, paddle.Tensor]):
if isinstance(fake, AudioSignal):
d_fake = self.discriminator(fake.audio_data)
else:
d_fake = self.discriminator(fake)
if isinstance(real, AudioSignal):
d_real = self.discriminator(real.audio_data)
else:
d_real = self.discriminator(real)
return d_fake, d_real
def discriminator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_d = 0
for x_fake, x_real in zip(d_fake, d_real):
loss_d += paddle.mean(x_fake[-1]**2)
loss_d += paddle.mean((1 - x_real[-1])**2)
return loss_d
def generator_loss(self, fake, real):
d_fake, d_real = self.forward(fake, real)
loss_g = 0
for x_fake in d_fake:
loss_g += paddle.mean((1 - x_fake[-1])**2)
loss_feature = 0
for i in range(len(d_fake)):
for j in range(len(d_fake[i]) - 1):
loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
return loss_g, loss_feature
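A hedged sketch of the contract GANLoss expects from its discriminator: a list with one entry per sub-discriminator, each entry a list of intermediate feature maps ending with the final logits. ToyDiscriminator below is hypothetical and only illustrates the shapes; a real vocoder recipe would pass its multi-scale/multi-period discriminator instead:

import paddle
import paddle.nn as nn
from paddlespeech.t2s.modules.losses import GANLoss

class ToyDiscriminator(nn.Layer):
    """Hypothetical stand-in: one scale, one feature map plus final logits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1D(1, 4, 3, padding=1)
        self.conv2 = nn.Conv1D(4, 1, 3, padding=1)

    def forward(self, x):
        feat = self.conv1(x)
        logits = self.conv2(feat)
        return [[feat, logits]]  # list of scales, each a list of features

gan_loss = GANLoss(ToyDiscriminator())
real = paddle.randn([2, 1, 16000])   # (batch, channels, samples)
fake = paddle.randn([2, 1, 16000])
loss_d = gan_loss.discriminator_loss(fake, real)          # used to update the discriminator
loss_g, loss_feat = gan_loss.generator_loss(fake, real)   # used to update the generator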
class SISDRLoss(nn.Layer):
"""
Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
of estimated and reference audio signals or aligned features.
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py
"""
def __init__(
self,
scaling: bool=True,
reduction: str="mean",
zero_mean: bool=True,
clip_min: Optional[int]=None,
weight: float=1.0, ):
"""
Args:
scaling : bool, optional
Whether to compute the scale-invariant SDR (True) or the
plain signal-to-noise ratio (False), by default True
reduction : str, optional
How to reduce across the batch (either 'mean',
'sum', or 'none'), by default 'mean'
zero_mean : bool, optional
Zero mean the references and estimates before
computing the loss, by default True
clip_min : int, optional
The minimum possible loss value. Keeps the network from
focusing on making already-good examples better, by default None
weight : float, optional
Weight of this loss, defaults to 1.0.
"""
self.scaling = scaling
self.reduction = reduction
self.zero_mean = zero_mean
self.clip_min = clip_min
self.weight = weight
super().__init__()
def forward(self,
x: Union[AudioSignal, paddle.Tensor],
y: Union[AudioSignal, paddle.Tensor]):
eps = 1e-8
# B, C, T
if isinstance(x, AudioSignal):
references = x.audio_data
estimates = y.audio_data
else:
references = x
estimates = y
nb = references.shape[0]
references = references.reshape([nb, 1, -1]).transpose([0, 2, 1])
estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1])
# samples now on axis 1
if self.zero_mean:
mean_reference = references.mean(axis=1, keepdim=True)
mean_estimate = estimates.mean(axis=1, keepdim=True)
else:
mean_reference = 0
mean_estimate = 0
_references = references - mean_reference
_estimates = estimates - mean_estimate
references_projection = (_references**2).sum(axis=-2) + eps
references_on_estimates = (_estimates * _references).sum(axis=-2) + eps
scale = (
(references_on_estimates / references_projection).unsqueeze(axis=1)
if self.scaling else 1)
e_true = scale * _references
e_res = _estimates - e_true
signal = (e_true**2).sum(axis=1)
noise = (e_res**2).sum(axis=1)
sdr = -10 * paddle.log10(signal / noise + eps)
if self.clip_min is not None:
sdr = paddle.clip(sdr, min=self.clip_min)
if self.reduction == "mean":
sdr = sdr.mean()
elif self.reduction == "sum":
sdr = sdr.sum()
return sdr
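A quick sanity check on plain tensors (a sketch, not from the PR): the loss is the negated SI-SDR, so a scaled copy of the reference gives a large negative value (good), while unrelated noise gives a large positive value (bad). Note that forward() takes the reference first and the estimate second:

import paddle
from paddlespeech.t2s.modules.losses import SISDRLoss

sisdr = SISDRLoss()
ref = paddle.randn([1, 1, 16000])                # (B, C, T) reference
good = sisdr(ref, 0.5 * ref)                     # scaled copy -> large negative loss (good)
bad = sisdr(ref, paddle.randn([1, 1, 16000]))    # unrelated noise -> large positive loss (bad)
print(good.item(), bad.item())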

@@ -1,6 +1,7 @@
function main(){
set -ex
speech_ci_path=`pwd`
pip install ffmpeg flatten_dict ffmpy
echo "Start asr"
cd ${speech_ci_path}/asr
@@ -16,6 +17,7 @@ function main(){
python test_enfrontend.py
python test_fftconv1d.py
python test_mixfrontend.py
python test_losses.py
echo "End TTS"
echo "Start Vector"

@@ -0,0 +1,61 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddlespeech.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import GANLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
from paddlespeech.t2s.modules.losses import SISDRLoss
def get_input():
x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav",
2_05)
y = x * 0.01
return x, y
def test_multi_scale_stft_loss():
x, y = get_input()
loss = MultiScaleSTFTLoss()
pd_loss = loss(x, y)
np.allclose(pd_loss.numpy(), 7.5622)
def test_sisdr_loss():
x, y = get_input()
loss = SISDRLoss()
pd_loss = loss(x, y)
np.allclose(pd_loss.numpy(), -145.3776)
def test_gan_loss():
class My_discriminator0:
def __call__(self, x):
return x.sum()
class My_discriminator1:
def __call__(self, x):
return x * (-0.2)
x, y = get_input()
loss = GANLoss(My_discriminator0())
pd_loss0, pd_loss1 = loss(x, y)
np.allclose(pd_loss0.numpy(), -0.1027)
np.allclose(pd_loss1.numpy(), -0.0010)
loss = GANLoss(My_discriminator1())
pd_loss0, _ = loss.generator_loss(x, y)
np.allclose(pd_loss0.numpy(), 1.0002)
pd_loss = loss.discriminator_loss(x, y)
np.allclose(pd_loss.numpy(), 1.0002)