add readme && __all__

pull/3900/head
drryanhuang 9 months ago
parent f0b557648e
commit 643f1c6071

@ -1,23 +1,13 @@
# PaddleAudio
Installation: pip install paddleaudio
## Environment
Currently supported platforms: Linux, Mac, Windows
## Build wheel
cmd: python setup.py bdist_wheel
Linux test build whl environment:
* os - Ubuntu 16.04.7 LTS
* gcc/g++ - 8.2.0
* cmake - 3.18.0 (need install)
Mac test build whl environment:
* os
* gcc/g++ - 12.2.0
* cpu - Intel Xeon E5 x86_64
Windows:
The paddleaudio C++ extension lib (sox io, kaldi native fbank) is not supported.
Audiotools is a comprehensive toolkit designed for audio processing and analysis, providing robust solutions for audio signal processing, data management, model training, and evaluation.
### Directory Structure
- **core directory**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals.
- **data directory**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data.
- **metrics directory**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms.
- **ml directory**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio.
This project aims to provide developers and researchers with an efficient and flexible framework to foster innovation and exploration across various domains of audio technology.
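A minimal usage sketch; the class and method names below are taken from the test suite in this commit, so treat the exact signatures as indicative rather than authoritative:

```python
from audiotools import AudioSignal

# Load a 10-second excerpt starting at 10 s, normalize loudness, and resample.
signal = AudioSignal("tests/audiotools/audio/spk/f10_script4_produced.wav",
                     offset=10, duration=10)
signal = signal.normalize(-24)   # target loudness in dB
signal = signal.resample(16000)  # new sample rate in Hz
print(signal.loudness())
```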

@ -7,9 +7,9 @@ from ._julius import lowpass_filter
from ._julius import LowPassFilter
from ._julius import LowPassFilters
from ._julius import pure_tone
from ._julius import resample_frac
from ._julius import split_bands
from ._julius import SplitBands
from .audio_signal import AudioSignal
from .audio_signal import STFTParams
from .loudness import Meter
from .resample import resample_frac

@ -7,6 +7,7 @@ This module implements efficient FFT based convolutions for such cases. A typica
application is for evaluating FIR filters with a long receptive field, typically
evaluated with a stride of 1.
"""
import inspect
import math
import typing
from typing import Optional
@ -16,7 +17,203 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .resample import sinc
__all__ = [
'fft_conv1d', 'FFTConv1d', 'highpass_filter', 'highpass_filters',
'lowpass_filter', 'LowPassFilter', 'LowPassFilters', 'pure_tone',
'resample_frac', 'split_bands', 'SplitBands'
]
def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}):
"""
Return a simple representation string for `obj`.
If `attrs` is not None, it should be a list of attributes to include.
"""
params = inspect.signature(obj.__class__).parameters
attrs_repr = []
if attrs is None:
attrs = list(params.keys())
for attr in attrs:
display = False
if attr in overrides:
value = overrides[attr]
elif hasattr(obj, attr):
value = getattr(obj, attr)
else:
continue
if attr in params:
param = params[attr]
if param.default is inspect._empty or value != param.default: # type: ignore
display = True
else:
display = True
if display:
attrs_repr.append(f"{attr}={value}")
return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
def sinc(x: paddle.Tensor):
"""
Implementation of sinc, i.e. sin(x) / x
__Warning__: the input is not multiplied by `pi`!
"""
return paddle.where(
x == 0,
paddle.to_tensor(1.0, dtype=x.dtype, place=x.place),
paddle.sin(x) / x, )
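# e.g. sinc(paddle.to_tensor([0.0, math.pi])) -> [1.0, ~0.0]: the input is in
# radians and is not pre-multiplied by pi.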
class ResampleFrac(paddle.nn.Layer):
"""
Resampling from the sample rate `old_sr` to `new_sr`.
"""
def __init__(self,
old_sr: int,
new_sr: int,
zeros: int=24,
rolloff: float=0.945):
"""
Args:
old_sr (int): sample rate of the input signal x.
new_sr (int): sample rate of the output.
zeros (int): number of zero crossing to keep in the sinc filter.
rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`,
to ensure sufficient margin due to the imperfection of the FIR filter used.
Lowering this value will reduce anti-aliasing, but will reduce some of the
highest frequencies.
Shape:
- Input: `[*, T]`
- Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`
.. caution::
After dividing `old_sr` and `new_sr` by their GCD, both should be small
for this implementation to be fast.
>>> import paddle
>>> resample = ResampleFrac(4, 5)
>>> x = paddle.randn([1000])
>>> print(len(resample(x)))
1250
"""
super(ResampleFrac, self).__init__()
if not isinstance(old_sr, int) or not isinstance(new_sr, int):
raise ValueError("old_sr and new_sr should be integers")
gcd = math.gcd(old_sr, new_sr)
self.old_sr = old_sr // gcd
self.new_sr = new_sr // gcd
self.zeros = zeros
self.rolloff = rolloff
self._init_kernels()
def _init_kernels(self):
if self.old_sr == self.new_sr:
return
kernels = []
sr = min(self.new_sr, self.old_sr)
sr *= self.rolloff
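# rolloff performs anti-aliasing filtering by removing the highest frequencies;
# without it, upsampling also shows edge artifacts (the edge behaves like zero
# padding, which adds high-frequency content).
# Key idea: x(t) can be exactly reconstructed from x[i] via sinc interpolation,
#   x(t) = sum_i x[i] sinc(pi * old_sr * (i / old_sr - t)),
# and resampling evaluates y[j] = x(j / new_sr). Each output phase j gets its
# own FIR approximation of this filter (truncated after `zeros` zero crossings),
# and y[j + new_sr] reuses the filter of y[j] on x shifted by old_sr -- hence
# the F.conv1d with stride=old_sr in `forward`.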
self._width = math.ceil(self.zeros * self.old_sr / sr)
idx = paddle.arange(
-self._width, self._width + self.old_sr, dtype="float32")
for i in range(self.new_sr):
t = (-i / self.new_sr + idx / self.old_sr) * sr
t = paddle.clip(t, -self.zeros, self.zeros)
t *= math.pi
window = paddle.cos(t / self.zeros / 2)**2
kernel = sinc(t) * window
# Renormalize kernel to ensure a constant signal is preserved.
kernel = kernel / kernel.sum()
kernels.append(kernel)
_kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1])
self.kernel = self.create_parameter(
shape=_kernel.shape,
dtype=_kernel.dtype, )
self.kernel.set_value(_kernel)
def forward(
self,
x: paddle.Tensor,
output_length: Optional[int]=None,
full: bool=False, ):
"""
Resample x.
Args:
x (Tensor): signal to resample, time should be the last dimension
output_length (None or int): This can be set to the desired output length
(last dimension). Allowed values are between 0 and
ceil(length * new_sr / old_sr). When None (default) is specified, the
floored output length will be used. In order to select the largest possible
size, use the `full` argument.
full (bool): return the longest possible output from the input. This can be useful
if you chain resampling operations, and want to give the `output_length` only
for the last one, while passing `full=True` to all the other ones.
"""
if self.old_sr == self.new_sr:
return x
shape = x.shape
_dtype = x.dtype
length = x.shape[-1]
x = x.reshape([-1, length])
x = F.pad(
x.unsqueeze(1),
[self._width, self._width + self.old_sr],
mode="replicate",
data_format="NCL", ).astype(self.kernel.dtype)
ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL")
y = ys.transpose(
[0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype)
float_output_length = paddle.to_tensor(
self.new_sr * length / self.old_sr, dtype="float32")
max_output_length = paddle.ceil(float_output_length).astype("int64")
default_output_length = paddle.floor(float_output_length).astype(
"int64")
if output_length is None:
applied_output_length = (max_output_length
if full else default_output_length)
elif output_length < 0 or output_length > max_output_length:
raise ValueError(
f"output_length must be between 0 and {max_output_length.numpy()}"
)
else:
applied_output_length = paddle.to_tensor(
output_length, dtype="int64")
if full:
raise ValueError(
"You cannot pass both full=True and output_length")
return y[..., :applied_output_length]
def __repr__(self):
return simple_repr(self)
def resample_frac(
x: paddle.Tensor,
old_sr: int,
new_sr: int,
zeros: int=24,
rolloff: float=0.945,
output_length: Optional[int]=None,
full: bool=False, ):
"""
Functional version of `ResampleFrac`; refer to its documentation for more information.
.. warning::
If you repeatedly call this function with the same sample rates, the
resampling kernel will be recomputed every time. For best performance, create
and cache an instance of `ResampleFrac`.
"""
return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full)
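# Sketch: when resampling many tensors at the same rates, cache the module
# instead of using the functional form, which rebuilds the sinc kernel on
# every call:
#     resampler = ResampleFrac(44100, 16000)
#     ys = [resampler(x) for x in xs]             # kernel built once
#     # vs. resample_frac(x, 44100, 16000) in a loop, which rebuilds it each time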
def pad_to(tensor: paddle.Tensor,

@ -16,18 +16,20 @@ import paddle
import soundfile
from . import util
from ._julius import resample_frac
from .dsp import DSPMixin
from .effects import EffectMixin
from .effects import ImpulseResponseMixin
from .ffmpeg import FFMPEGMixin
from .loudness import LoudnessMixin
from .resample import resample_frac
# from .display import DisplayMixin
# from .playback import PlayMixin
# from .whisper import WhisperMixin
__all__ = ['STFTParams', 'AudioSignal']
def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> paddle.Tensor:
r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),

@ -207,7 +207,8 @@ class EffectMixin:
"""
peak = self.audio_data.abs().max(axis=-1, keepdim=True)
peak_gain = paddle.ones_like(peak)
peak_gain[peak > _max] = _max / peak[peak > _max]
# peak_gain[peak > _max] = _max / peak[peak > _max]
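# Where the peak exceeds _max, scale by _max / peak; otherwise keep unit gain
# (out-of-place equivalent of the masked assignment above).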
peak_gain = paddle.where(peak > _max, _max / peak, peak_gain)
self.audio_data = self.audio_data * peak_gain
return self
@ -476,70 +477,72 @@ class EffectMixin:
return self
# def quantization(
# self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int]
# ):
# """Applies quantization to the input waveform.
def quantization(self,
quantization_channels: typing.Union[paddle.Tensor,
np.ndarray, int]):
"""Applies quantization to the input waveform.
# Parameters
# ----------
# quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
# Number of evenly spaced quantization channels to quantize
# to.
# Returns
# -------
# AudioSignal
# Quantized AudioSignal.
# """
# quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
# x = self.audio_data
# x = (x + 1) / 2
# x = x * quantization_channels
# x = x.floor()
# x = x / quantization_channels
# x = 2 * x - 1
Parameters
----------
quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
Number of evenly spaced quantization channels to quantize
to.
# residual = (self.audio_data - x).detach()
# self.audio_data = self.audio_data - residual
# return self
Returns
-------
AudioSignal
Quantized AudioSignal.
"""
quantization_channels = util.ensure_tensor(
quantization_channels, ndim=3)
x = self.audio_data
x = (x + 1) / 2
x = x * quantization_channels
x = x.floor()
x = x / quantization_channels
x = 2 * x - 1
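# Straight-through estimator: the forward pass returns the quantized signal,
# while the (original - quantized) residual is detached so gradients flow to
# audio_data as if quantization were the identity.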
residual = (self.audio_data - x).detach()
self.audio_data = self.audio_data - residual
return self
# def mulaw_quantization(
# self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int]
# ):
# """Applies mu-law quantization to the input waveform.
def mulaw_quantization(self,
quantization_channels: typing.Union[
paddle.Tensor, np.ndarray, int]):
"""Applies mu-law quantization to the input waveform.
# Parameters
# ----------
# quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
# Number of mu-law spaced quantization channels to quantize
# to.
Parameters
----------
quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
Number of mu-law spaced quantization channels to quantize
to.
# Returns
# -------
# AudioSignal
# Quantized AudioSignal.
# """
# mu = quantization_channels - 1.0
# mu = util.ensure_tensor(mu, ndim=3)
Returns
-------
AudioSignal
Quantized AudioSignal.
"""
mu = quantization_channels - 1.0
mu = util.ensure_tensor(mu, ndim=3)
# x = self.audio_data
x = self.audio_data
# # quantize
# x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
# x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
# quantize
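# mu-law companding: F(x) = sign(x) * log1p(mu * |x|) / log1p(mu),
# then map [-1, 1] onto the integer levels [0, mu].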
x = paddle.sign(x) * paddle.log1p(mu * paddle.abs(x)) / paddle.log1p(mu)
x = ((x + 1) / 2 * mu + 0.5).astype("int64")
# # unquantize
# x = (x / mu) * 2 - 1.0
# x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
# unquantize
x = (x / mu) * 2 - 1.0
x = paddle.sign(x) * (
paddle.exp(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu
# residual = (self.audio_data - x).detach()
# self.audio_data = self.audio_data - residual
# return self
residual = (self.audio_data - x).detach()
self.audio_data = self.audio_data - residual
return self
# def __matmul__(self, other):
# return self.convolve(other)
def __matmul__(self, other):
return self.convolve(other)
import paddle
@ -591,13 +594,16 @@ class ImpulseResponseMixin:
# direct path and windowed residual.
window = paddle.zeros_like(self.audio_data)
window_idx = paddle.nonzero(early_idx)
for idx in range(self.batch_size):
window_idx = early_idx[idx, 0]
# window_idx = early_idx[idx, 0]
# ----- Just for this -----
# window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item())
indices = paddle.nonzero(window_idx).reshape(
[-1]) # shape: [num_true], dtype: int64
# indices = paddle.nonzero(window_idx).reshape(
# [-1]) # shape: [num_true], dtype: int64
indices = window_idx[window_idx[:, 0] == idx][:, -1]
temp_window = self.get_window("hann", indices.shape[0])
window_slice = window[idx, 0]
@ -639,9 +645,9 @@ class ImpulseResponseMixin:
l_sq = late_field**2
a = (wd_sq * e_sq).sum(axis=-1)
b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1)
c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(
10 * paddle.ones_like(target_drr), target_drr / 10) * l_sq.sum(
axis=-1)
c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(10 * paddle.ones_like(
target_drr, dtype="float32"), target_drr.cast("float32") /
10) * l_sq.sum(axis=-1)
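# a, b, c are the coefficients of a quadratic in alpha: the altered early
# response is alpha * wd * early + (1 - wd) * early, and requiring the
# resulting DRR to equal target_drr gives a * alpha**2 + b * alpha + c = 0;
# the larger root is taken below.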
expr = ((b**2) - 4 * a * c).sqrt()
alpha = paddle.maximum(

@ -1,231 +0,0 @@
import inspect
import math
from typing import Optional
from typing import Sequence
import paddle
import paddle.nn.functional as F
def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}):
"""
Return a simple representation string for `obj`.
If `attrs` is not None, it should be a list of attributes to include.
"""
params = inspect.signature(obj.__class__).parameters
attrs_repr = []
if attrs is None:
attrs = list(params.keys())
for attr in attrs:
display = False
if attr in overrides:
value = overrides[attr]
elif hasattr(obj, attr):
value = getattr(obj, attr)
else:
continue
if attr in params:
param = params[attr]
if param.default is inspect._empty or value != param.default: # type: ignore
display = True
else:
display = True
if display:
attrs_repr.append(f"{attr}={value}")
return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
def sinc(x: paddle.Tensor):
"""
Implementation of sinc, i.e. sin(x) / x
__Warning__: the input is not multiplied by `pi`!
"""
return paddle.where(
x == 0,
paddle.to_tensor(1.0, dtype=x.dtype, place=x.place),
paddle.sin(x) / x, )
class ResampleFrac(paddle.nn.Layer):
"""
Resampling from the sample rate `old_sr` to `new_sr`.
"""
def __init__(self,
old_sr: int,
new_sr: int,
zeros: int=24,
rolloff: float=0.945):
"""
Args:
old_sr (int): sample rate of the input signal x.
new_sr (int): sample rate of the output.
zeros (int): number of zero crossing to keep in the sinc filter.
rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`,
to ensure sufficient margin due to the imperfection of the FIR filter used.
Lowering this value will reduce anti-aliasing, but will reduce some of the
highest frequencies.
Shape:
- Input: `[*, T]`
- Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`
.. caution::
After dividing `old_sr` and `new_sr` by their GCD, both should be small
for this implementation to be fast.
>>> import paddle
>>> resample = ResampleFrac(4, 5)
>>> x = paddle.randn([1000])
>>> print(len(resample(x)))
1250
"""
super(ResampleFrac, self).__init__()
if not isinstance(old_sr, int) or not isinstance(new_sr, int):
raise ValueError("old_sr and new_sr should be integers")
gcd = math.gcd(old_sr, new_sr)
self.old_sr = old_sr // gcd
self.new_sr = new_sr // gcd
self.zeros = zeros
self.rolloff = rolloff
self._init_kernels()
def _init_kernels(self):
if self.old_sr == self.new_sr:
return
kernels = []
sr = min(self.new_sr, self.old_sr)
# rolloff will perform antialiasing filtering by removing the highest frequencies.
# At first I thought I only needed this when downsampling, but when upsampling
# you will get edge artifacts without this, the edge is equivalent to zero padding,
# which will add high freq artifacts.
sr *= self.rolloff
# The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
# using the sinc interpolation formula:
# x(t) = sum_i x[i] sinc(pi * old_sr * (i / old_sr - t))
# We can then sample the function x(t) with a different sample rate:
# y[j] = x(j / new_sr)
# or,
# y[j] = sum_i x[i] sinc(pi * old_sr * (i / old_sr - j / new_sr))
# We see here that y[j] is the convolution of x[i] with a specific filter, for which
# we take an FIR approximation, stopping when we see at least `zeros` zeros crossing.
# But y[j+1] is going to have a different set of weights and so on, until y[j + new_sr].
# Indeed:
# y[j + new_sr] = sum_i x[i] sinc(pi * old_sr * ((i / old_sr - (j + new_sr) / new_sr))
# = sum_i x[i] sinc(pi * old_sr * ((i - old_sr) / old_sr - j / new_sr))
# = sum_i x[i + old_sr] sinc(pi * old_sr * (i / old_sr - j / new_sr))
# so y[j+new_sr] uses the same filter as y[j], but on a shifted version of x by `old_sr`.
# This will explain the F.conv1d after, with a stride of old_sr.
self._width = math.ceil(self.zeros * self.old_sr / sr)
# If old_sr is still big after GCD reduction, most filters will be very unbalanced, i.e.,
# they will have a lot of almost zero values to the left or to the right...
# There is probably a way to evaluate those filters more efficiently, but this is kept for
# future work.
idx = paddle.arange(
-self._width, self._width + self.old_sr, dtype="float32")
for i in range(self.new_sr):
t = (-i / self.new_sr + idx / self.old_sr) * sr
t = paddle.clip(t, -self.zeros, self.zeros)
t *= math.pi
window = paddle.cos(t / self.zeros / 2)**2
kernel = sinc(t) * window
# Renormalize kernel to ensure a constant signal is preserved.
kernel = kernel / kernel.sum()
kernels.append(kernel)
_kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1])
self.kernel = self.create_parameter(
shape=_kernel.shape,
dtype=_kernel.dtype, )
self.kernel.set_value(_kernel)
def forward(
self,
x: paddle.Tensor,
output_length: Optional[int]=None,
full: bool=False, ):
"""
Resample x.
Args:
x (Tensor): signal to resample, time should be the last dimension
output_length (None or int): This can be set to the desired output length
(last dimension). Allowed values are between 0 and
ceil(length * new_sr / old_sr). When None (default) is specified, the
floored output length will be used. In order to select the largest possible
size, use the `full` argument.
full (bool): return the longest possible output from the input. This can be useful
if you chain resampling operations, and want to give the `output_length` only
for the last one, while passing `full=True` to all the other ones.
"""
if self.old_sr == self.new_sr:
return x
shape = x.shape
_dtype = x.dtype
length = x.shape[-1]
x = x.reshape([-1, length])
x = F.pad(
x.unsqueeze(1),
[self._width, self._width + self.old_sr],
mode="replicate",
data_format="NCL", ).astype(self.kernel.dtype)
ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL")
y = ys.transpose(
[0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype)
float_output_length = paddle.to_tensor(
self.new_sr * length / self.old_sr, dtype="float32")
max_output_length = paddle.ceil(float_output_length).astype("int64")
default_output_length = paddle.floor(float_output_length).astype(
"int64")
if output_length is None:
applied_output_length = (max_output_length
if full else default_output_length)
elif output_length < 0 or output_length > max_output_length:
raise ValueError(
f"output_length must be between 0 and {max_output_length.numpy()}"
)
else:
applied_output_length = paddle.to_tensor(
output_length, dtype="int64")
if full:
raise ValueError(
"You cannot pass both full=True and output_length")
return y[..., :applied_output_length]
def __repr__(self):
return simple_repr(self)
def resample_frac(
x: paddle.Tensor,
old_sr: int,
new_sr: int,
zeros: int=24,
rolloff: float=0.945,
output_length: Optional[int]=None,
full: bool=False, ):
"""
Functional version of `ResampleFrac`, refer to its documentation for more information.
..warning::
If you call repeatidly this functions with the same sample rates, then the
resampling kernel will be recomputed everytime. For best performance, you should use
and cache an instance of `ResampleFrac`.
"""
return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full)
if __name__ == "__main__":
resample = ResampleFrac(4, 5)
x = paddle.randn([1000])
print(len(resample(x)))

@ -0,0 +1,363 @@
import sys
import numpy as np
import paddle
import pytest
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import AudioSignal
def test_normalize():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=10)
signal = signal.normalize()
assert np.allclose(signal.loudness(), -24, atol=1e-1)
array = np.random.randn(1, 2, 32000)
array = array / np.abs(array).max()
signal = AudioSignal(array, sample_rate=16000)
for db_incr in np.arange(10, 75, 5):
db = -80 + db_incr
signal = signal.normalize(db)
loudness = signal.loudness()
assert np.allclose(loudness, db, atol=1) # TODO, atol=1e-1
batch_size = 16
db = -60 + paddle.linspace(10, 30, batch_size)
array = np.random.randn(batch_size, 2, 32000)
array = array / np.abs(array).max()
signal = AudioSignal(array, sample_rate=16000)
signal = signal.normalize(db)
assert np.allclose(signal.loudness(), db, 1e-1)
def test_volume_change():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=10)
boost = 3
before_db = signal.loudness().clone()
signal = signal.volume_change(boost)
after_db = signal.loudness()
assert np.allclose(before_db + boost, after_db)
signal._loudness = None
after_db = signal.loudness()
assert np.allclose(before_db + boost, after_db, 1e-1)
def test_mix():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
nz = AudioSignal(audio_path, offset=10, duration=10)
spk.deepcopy().mix(nz, snr=-10)
snr = spk.loudness() - nz.loudness()
assert np.allclose(snr, -10, atol=1)
# Test in batch
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
nz = AudioSignal(audio_path, offset=10, duration=10)
batch_size = 4
tgt_snr = paddle.linspace(-10, 10, batch_size)
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
nz_batch = AudioSignal.batch([nz.deepcopy() for _ in range(batch_size)])
spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr)
snr = spk_batch.loudness() - nz_batch.loudness()
assert np.allclose(snr, tgt_snr, atol=1)
# Test with "EQing" the other signal
db = 0 + 0 * paddle.rand([10])
spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr, other_eq=db)
snr = spk_batch.loudness() - nz_batch.loudness()
assert np.allclose(snr, tgt_snr, atol=1)
def test_convolve():
np.random.seed(6) # Found a failing seed
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
impulse = np.zeros((1, 16000), dtype="float32")
impulse[..., 0] = 1
ir = AudioSignal(impulse, 16000)
batch_size = 4
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
ir_batch = AudioSignal.batch(
[
ir.deepcopy().zero_pad(np.random.randint(1000), 0)
for _ in range(batch_size)
],
pad_signals=True, )
convolved = spk_batch.deepcopy().convolve(ir_batch)
assert convolved == spk_batch
# Short duration
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=0.1)
impulse = np.zeros((1, 16000), dtype="float32")
impulse[..., 0] = 1
ir = AudioSignal(impulse, 16000)
batch_size = 4
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
ir_batch = AudioSignal.batch(
[
ir.deepcopy().zero_pad(np.random.randint(1000), 0)
for _ in range(batch_size)
],
pad_signals=True, )
convolved = spk_batch.deepcopy().convolve(ir_batch)
assert convolved == spk_batch
def test_pipeline():
# An actual IR, no batching
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=5)
audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav"
ir = AudioSignal(audio_path)
spk.deepcopy().convolve(ir)
audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
nz = AudioSignal(audio_path, offset=10, duration=5)
batch_size = 16
tgt_snr = paddle.linspace(20, 30, batch_size)
(spk @ ir).mix(nz, snr=tgt_snr)
# def test_codec():
# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
# spk = AudioSignal(audio_path, offset=10, duration=10)
# with pytest.raises(ValueError):
# spk.apply_codec("unknown preset")
# out = spk.deepcopy().apply_codec("Ogg")
# out = spk.deepcopy().apply_codec("8-bit")
# def test_pitch_shift():
# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
# spk = AudioSignal(audio_path, offset=10, duration=1)
# single = spk.deepcopy().pitch_shift(5)
# batch_size = 4
# spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
# batched = spk_batch.deepcopy().pitch_shift(5)
# assert np.allclose(batched[0].audio_data, single[0].audio_data)
# def test_time_stretch():
# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
# spk = AudioSignal(audio_path, offset=10, duration=1)
# single = spk.deepcopy().time_stretch(0.8)
# batch_size = 4
# spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
# batched = spk_batch.deepcopy().time_stretch(0.8)
# assert np.allclose(batched[0].audio_data, single[0].audio_data)
@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16])
def test_mel_filterbank(n_bands):
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=1)
fbank = spk.deepcopy().mel_filterbank(n_bands)
assert paddle.allclose(fbank.sum(-1), spk.audio_data, atol=1e-6)
# Check if it works in batches.
spk_batch = AudioSignal.batch([
AudioSignal.excerpt(
"tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
fbank = spk_batch.deepcopy().mel_filterbank(n_bands)
summed = fbank.sum(-1)
assert paddle.allclose(summed, spk_batch.audio_data, atol=1e-6)
@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16])
def test_equalizer(n_bands):
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
db = -3 + 1 * paddle.rand([n_bands])
spk.deepcopy().equalizer(db)
db = -3 + 1 * np.random.rand(n_bands)
spk.deepcopy().equalizer(db)
audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav"
ir = AudioSignal(audio_path)
db = -3 + 1 * paddle.rand([n_bands])
spk.deepcopy().convolve(ir.equalizer(db))
spk_batch = AudioSignal.batch([
AudioSignal.excerpt(
"tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
db = paddle.zeros([spk_batch.batch_size, n_bands])
output = spk_batch.deepcopy().equalizer(db)
assert output == spk_batch
def test_clip_distortion():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
clipped = spk.deepcopy().clip_distortion(0.05)
spk_batch = AudioSignal.batch([
AudioSignal.excerpt(
"tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
percs = paddle.to_tensor(np.random.uniform(size=(16, ))).astype("float32")
clipped_batch = spk_batch.deepcopy().clip_distortion(percs)
assert clipped.audio_data.abs().max() < 1.0
assert clipped_batch.audio_data.abs().max() < 1.0
@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128])
def test_quantization(quant_ch):
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
quantized = spk.deepcopy().quantization(quant_ch)
# Need to round audio_data off because paddle ops with the straight-
# through estimator are sometimes a bit off past 3 decimal places.
found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3)))
assert found_quant_ch <= quant_ch
spk_batch = AudioSignal.batch([
AudioSignal.excerpt(
"tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
quant_ch = np.random.choice(
[2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True)
quantized = spk_batch.deepcopy().quantization(quant_ch)
for i, q_ch in enumerate(quant_ch):
found_quant_ch = len(
np.unique(np.around(quantized.audio_data[i], decimals=3)))
assert found_quant_ch <= q_ch
@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128])
def test_mulaw_quantization(quant_ch):
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
quantized = spk.deepcopy().mulaw_quantization(quant_ch)
# Need to round audio_data off because paddle ops with the straight-
# through estimator are sometimes a bit off past 3 decimal places.
found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3)))
assert found_quant_ch <= quant_ch
spk_batch = AudioSignal.batch([
AudioSignal.excerpt(
"tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
quant_ch = np.random.choice(
[2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True)
quantized = spk_batch.deepcopy().mulaw_quantization(quant_ch)
for i, q_ch in enumerate(quant_ch):
found_quant_ch = len(
np.unique(np.around(quantized.audio_data[i], decimals=3)))
assert found_quant_ch <= q_ch
def test_impulse_response_augmentation():
audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav"
batch_size = 16
ir = AudioSignal(audio_path)
ir_batch = AudioSignal.batch([ir for _ in range(batch_size)])
early_response, late_field, window = ir_batch.decompose_ir()
assert early_response.shape == late_field.shape
assert late_field.shape == window.shape
drr = ir_batch.measure_drr()
alpha = AudioSignal.solve_alpha(early_response, late_field, window, drr)
assert np.allclose(alpha, np.ones_like(alpha), 1e-5)
target_drr = 5
out = ir_batch.deepcopy().alter_drr(target_drr)
drr = out.measure_drr()
assert np.allclose(drr, np.ones_like(drr) * target_drr)
target_drr = np.random.rand(batch_size).astype("float32") * 50
altered_ir = ir_batch.deepcopy().alter_drr(target_drr)
drr = altered_ir.measure_drr()
assert np.allclose(drr.flatten(), target_drr.flatten())
def test_apply_ir():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
ir_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
ir = AudioSignal(ir_path)
db = 0 + 0 * paddle.rand([10])
output = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=db)
assert np.allclose(ir.measure_drr().flatten(), 10)
output = spk.deepcopy().apply_ir(
ir, drr=10, ir_eq=db, use_original_phase=True)
def test_ensure_max_of_audio():
spk = AudioSignal(paddle.randn([1, 1, 44100]), 44100)
max_vals = [1.0] + [np.random.rand() for _ in range(10)]
for val in max_vals:
after = spk.deepcopy().ensure_max_of_audio(val)
assert after.audio_data.abs().max() <= val + 1e-3
# Make sure it does nothing to a tiny signal
spk = AudioSignal(paddle.rand([1, 1, 44100]), 44100)
spk.audio_data = spk.audio_data * 0.5
after = spk.deepcopy().ensure_max_of_audio()
assert paddle.allclose(after.audio_data, spk.audio_data)
test_normalize()

@ -0,0 +1,168 @@
import sys
from typing import Callable
import numpy as np
import paddle
import pytest
sys.path.append("/home/aistudio/PaddleSpeech/audio")
from audiotools import AudioSignal
def test_audio_grad():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
ir_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav"
def _test_audio_grad(attr: str, target=True, kwargs: dict={}):
signal = AudioSignal(audio_path)
signal.audio_data.stop_gradient = False
assert signal.audio_data.grad is None
# Avoid overwriting leaf tensor by cloning signal
attr = getattr(signal.clone(), attr)
result = attr(**kwargs) if isinstance(attr, Callable) else attr
try:
if isinstance(result, AudioSignal):
# If necessary, propagate spectrogram changes to waveform
if result.stft_data is not None:
result.istft()
# if result.audio_data.dtype.is_complex:
if paddle.is_complex(result.audio_data):
result.audio_data.real.sum().backward()
else:
result.audio_data.sum().backward()
else:
# if result.dtype.is_complex:
if paddle.is_complex(result):
result.real().sum().backward()
else:
result.sum().backward()
assert signal.audio_data.grad is not None or not target
except RuntimeError:
assert not target
for a in [
["mix", True, {
"other": AudioSignal(audio_path),
"snr": 0
}],
["convolve", True, {
"other": AudioSignal(ir_path)
}],
[
"apply_ir",
True,
{
"ir": AudioSignal(ir_path),
"drr": 0.1,
"ir_eq": paddle.randn([6])
},
],
["ensure_max_of_audio", True],
["normalize", True],
["volume_change", True, {
"db": 1
}],
# ["pitch_shift", False, {"n_semitones": 1}],
# ["time_stretch", False, {"factor": 2}],
# ["apply_codec", False],
["equalizer", True, {
"db": paddle.randn([6])
}],
["clip_distortion", True, {
"clip_percentile": 0.5
}],
["quantization", True, {
"quantization_channels": 8
}],
["mulaw_quantization", True, {
"quantization_channels": 8
}],
["resample", True, {
"sample_rate": 16000
}],
["low_pass", True, {
"cutoffs": 1000
}],
["high_pass", True, {
"cutoffs": 1000
}],
["to_mono", True],
["zero_pad", True, {
"before": 10,
"after": 10
}],
["magnitude", True],
["phase", True],
["log_magnitude", True],
["loudness", False],
["stft", True],
["clone", True],
["mel_spectrogram", True],
["zero_pad_to", True, {
"length": 100000
}],
["truncate_samples", True, {
"length_in_samples": 1000
}],
["corrupt_phase", True, {
"scale": 0.5
}],
["shift_phase", True, {
"shift": 1
}],
["mask_low_magnitudes", True, {
"db_cutoff": 0
}],
["mask_frequencies", True, {
"fmin_hz": 100,
"fmax_hz": 1000
}],
["mask_timesteps", True, {
"tmin_s": 0.1,
"tmax_s": 0.5
}],
["__add__", True, {
"other": AudioSignal(audio_path)
}],
["__iadd__", True, {
"other": AudioSignal(audio_path)
}],
["__radd__", True, {
"other": AudioSignal(audio_path)
}],
["__sub__", True, {
"other": AudioSignal(audio_path)
}],
["__isub__", True, {
"other": AudioSignal(audio_path)
}],
["__mul__", True, {
"other": AudioSignal(audio_path)
}],
["__imul__", True, {
"other": AudioSignal(audio_path)
}],
["__rmul__", True, {
"other": AudioSignal(audio_path)
}],
]:
_test_audio_grad(*a)
def test_batch_grad():
audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.audio_data.stop_gradient = False
assert signal.audio_data.grad is None
batch_size = 16
batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
batch.audio_data.sum().backward()
assert signal.audio_data.grad is not None