You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
173 lines
4.8 KiB
173 lines
4.8 KiB
# MIT License, Copyright (c) 2023-Present, Descript.
|
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_grad.py)
|
|
import sys
|
|
from typing import Callable
|
|
|
|
import numpy as np
|
|
import paddle
|
|
import pytest
|
|
|
|
from paddlespeech.audiotools import AudioSignal
|
|
|
|
|
|
def test_audio_grad():
    """Check that gradients propagate through AudioSignal operations.

    Each case names an ``AudioSignal`` attribute (method or property), says
    whether a gradient is expected on the source waveform afterwards, and
    optionally supplies keyword arguments for the call.
    """
    audio_path = "./audio/spk/f10_script4_produced.wav"
    ir_path = "./audio/ir/h179_Bar_1txts.wav"

    def _test_audio_grad(attr: str, target=True, kwargs: dict = None):
        """Run one attribute of a fresh signal and check the resulting gradient.

        Args:
            attr: Name of the ``AudioSignal`` attribute to exercise.
            target: Whether ``signal.audio_data.grad`` is expected to be
                populated after the backward pass.
            kwargs: Keyword arguments passed to the attribute when it is
                callable. ``None`` means no arguments.
        """
        # A fresh dict per call; avoids sharing one mutable default object.
        kwargs = {} if kwargs is None else kwargs

        signal = AudioSignal(audio_path)
        # In Paddle, stop_gradient=False marks the tensor as requiring grad.
        signal.audio_data.stop_gradient = False

        assert signal.audio_data.grad is None

        # Avoid overwriting leaf tensor by cloning signal
        member = getattr(signal.clone(), attr)
        result = member(**kwargs) if callable(member) else member

        try:
            if isinstance(result, AudioSignal):
                # If necessary, propagate spectrogram changes to waveform
                if result.stft_data is not None:
                    result.istft()
                waveform = result.audio_data
                if paddle.is_complex(waveform):
                    # paddle.real() works whether Tensor.real is exposed as a
                    # method or not, and matches the tensor branch below.
                    paddle.real(waveform).sum().backward()
                else:
                    waveform.sum().backward()
            elif paddle.is_complex(result):
                paddle.real(result).sum().backward()
            else:
                result.sum().backward()
        except RuntimeError:
            # A failing backward pass is only acceptable when no gradient
            # was expected for this case.
            assert not target
        else:
            assert signal.audio_data.grad is not None or not target

    cases = [
        ["mix", True, {
            "other": AudioSignal(audio_path),
            "snr": 0
        }],
        ["convolve", True, {
            "other": AudioSignal(ir_path)
        }],
        [
            "apply_ir",
            True,
            {
                "ir": AudioSignal(ir_path),
                "drr": 0.1,
                "ir_eq": paddle.randn([6])
            },
        ],
        ["ensure_max_of_audio", True],
        ["normalize", True],
        ["volume_change", True, {
            "db": 1
        }],
        # NOTE(review): the three cases below are disabled in this port;
        # the reason is not visible here -- confirm before re-enabling.
        # ["pitch_shift", False, {"n_semitones": 1}],
        # ["time_stretch", False, {"factor": 2}],
        # ["apply_codec", False],
        ["equalizer", True, {
            "db": paddle.randn([6])
        }],
        ["clip_distortion", True, {
            "clip_percentile": 0.5
        }],
        ["quantization", True, {
            "quantization_channels": 8
        }],
        ["mulaw_quantization", True, {
            "quantization_channels": 8
        }],
        ["resample", True, {
            "sample_rate": 16000
        }],
        ["low_pass", True, {
            "cutoffs": 1000
        }],
        ["high_pass", True, {
            "cutoffs": 1000
        }],
        ["to_mono", True],
        ["zero_pad", True, {
            "before": 10,
            "after": 10
        }],
        ["magnitude", True],
        ["phase", True],
        ["log_magnitude", True],
        ["loudness", False],
        ["stft", True],
        ["clone", True],
        ["mel_spectrogram", True],
        ["zero_pad_to", True, {
            "length": 100000
        }],
        ["truncate_samples", True, {
            "length_in_samples": 1000
        }],
        ["corrupt_phase", True, {
            "scale": 0.5
        }],
        ["shift_phase", True, {
            "shift": 1
        }],
        ["mask_low_magnitudes", True, {
            "db_cutoff": 0
        }],
        ["mask_frequencies", True, {
            "fmin_hz": 100,
            "fmax_hz": 1000
        }],
        ["mask_timesteps", True, {
            "tmin_s": 0.1,
            "tmax_s": 0.5
        }],
        ["__add__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__iadd__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__radd__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__sub__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__isub__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__mul__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__imul__", True, {
            "other": AudioSignal(audio_path)
        }],
        ["__rmul__", True, {
            "other": AudioSignal(audio_path)
        }],
    ]

    for case in cases:
        _test_audio_grad(*case)
|
|
|
|
|
|
def test_batch_grad():
    """Check that a backward pass through a batch reaches the source signal.

    Clones of one signal are batched with ``AudioSignal.batch``, the batched
    waveform is summed and back-propagated, and the original signal's
    ``audio_data`` is then expected to hold a gradient.
    """
    wav_file = "./audio/spk/f10_script4_produced.wav"
    num_copies = 16

    source = AudioSignal(wav_file)
    # In Paddle, stop_gradient=False marks the tensor as requiring grad.
    source.audio_data.stop_gradient = False

    # No backward pass has run yet, so no gradient should be present.
    assert source.audio_data.grad is None

    # Each batch entry is its own clone of the source signal.
    copies = []
    for _ in range(num_copies):
        copies.append(source.clone())
    batched = AudioSignal.batch(copies)

    total = batched.audio_data.sum()
    total.backward()

    assert source.audio_data.grad is not None
|