fix codestyle

pull/3988/head
cchenhaifeng 7 months ago
parent 2862ae5bc4
commit 50e4f4ead2

@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
import _locale import _locale
_locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8']) _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])

@ -231,7 +231,6 @@ def ensure_tensor(
def _get_value(other): def _get_value(other):
# #
from .audio_signal import AudioSignal
if isinstance(other, AudioSignal): if isinstance(other, AudioSignal):
return other.audio_data return other.audio_data
return other return other
@ -800,7 +799,6 @@ def collate(list_of_dicts: list, n_splits: int=None):
batch = {} batch = {}
for k, v in dict_of_lists.items(): for k, v in dict_of_lists.items():
if isinstance(v, list): if isinstance(v, list):
from .audio_signal import AudioSignal
if all(isinstance(s, AudioSignal) for s in v): if all(isinstance(s, AudioSignal) for s in v):
batch[k] = AudioSignal.batch(v, pad_signals=True) batch[k] = AudioSignal.batch(v, pad_signals=True)
else: else:
@ -872,7 +870,6 @@ def generate_chord_dataset(
""" """
import librosa import librosa
from .audio_signal import AudioSignal
from ..data.preprocess import create_csv from ..data.preprocess import create_csv
min_midi = librosa.note_to_midi(min_note) min_midi = librosa.note_to_midi(min_note)

@ -1408,6 +1408,16 @@ class MultiScaleSTFTLoss(nn.Layer):
Returns: Returns:
paddle.Tensor paddle.Tensor
Multi-scale STFT loss. Multi-scale STFT loss.
Example:
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
>>> import paddle
>>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
>>> y = x * 0.01
>>> loss = MultiScaleSTFTLoss()
>>> loss(x, y).numpy()
7.562150
""" """
for s in self.stft_params: for s in self.stft_params:
x.stft(s.window_length, s.hop_length, s.window_type) x.stft(s.window_length, s.hop_length, s.window_type)
@ -1425,6 +1435,29 @@ class GANLoss(nn.Layer):
generated waveforms/spectrograms compared to ground truth generated waveforms/spectrograms compared to ground truth
waveforms/spectrograms. Computes the loss for both the waveforms/spectrograms. Computes the loss for both the
discriminator and the generator in separate functions. discriminator and the generator in separate functions.
Example:
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
>>> import paddle
>>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
>>> y = x * 0.01
>>> class My_discriminator0:
...     def __call__(self, x):
...         return x.sum()
>>> loss = GANLoss(My_discriminator0())
>>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()]
[-0.102722, -0.001027]
>>> class My_discriminator1:
...     def __call__(self, x):
...         return x.sum()
>>> loss = GANLoss(My_discriminator1())
>>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()]
[1.00019, 0]
>>> loss.discriminator_loss(x, y)
1.000200
""" """
def __init__(self, discriminator): def __init__(self, discriminator):
@ -1480,6 +1513,16 @@ class SISDRLoss(nn.Layer):
of estimated and reference audio signals or aligned features. of estimated and reference audio signals or aligned features.
Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py
Example:
>>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
>>> import paddle
>>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
>>> y = x * 0.01
>>> sisdr = SISDRLoss()
>>> sisdr(x, y).numpy()
-145.377640
""" """
def __init__( def __init__(

Loading…
Cancel
Save