[Hackathon 7th No.55] Add `audiotools` to `PaddleSpeech` (#3900)

* add AudioSignal && util

* fix codestyle

* add basemodel && decorator

* add util && add quality

* add acc && data && transforms

* add utils

* fix dir

* add *.py; wo unitest

* add unitest

* fix codestyle

* fix cuda error

* add readme && __all__

* add 2 file test

* change download dir

* fix CI download path

* add tar -zxvf

* change requirements path

* add audiotools path

* fix place error

* fix paddle2.5 verion Q

* FFTConv1d -> FFTConv1D

* FFTConv1d -> FFTConv1D

* mv unfold

* add _unfold1d 2 loudness

* fix stupid device variable

* bias -> bias_attr

* () -> []

* fix .to()

* rm 

* fix exp

* deepcopy -> clone

* fix dim error

* fix slice && tensor.to

* fix paddle2.5 index bug

* git rm std

* rm comment && 

* rm some useless comment

* add __all__

* fix codestyle

* fix soundfile.info error

* fix sth

* add License

* fix cycle import

* Adapt to paddle3.0 && update readme

* fix License

* fix License

* rm duplicate requirements

* fix trasform problems

* rm disp

* Update test_transforms.py

* change path

* rm notebook && add audio path

* rm import

* add comment

* fix cycle import && rm TYPE_CHECKING

* rm IPython

* rm sth useless

* rm uesless deps

* Update requirements.txt
pull/3971/head
Ryan 8 months ago committed by GitHub
parent 553a9db374
commit cb15e382cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -0,0 +1,68 @@
Audiotools is a comprehensive toolkit designed for audio processing and analysis, providing robust solutions for audio signal processing, data management, model training, and evaluation.
### Directory Structure
```
.
├── audiotools
│ ├── README.md
│ ├── __init__.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── _julius.py
│ │ ├── audio_signal.py
│ │ ├── display.py
│ │ ├── dsp.py
│ │ ├── effects.py
│ │ ├── ffmpeg.py
│ │ ├── loudness.py
│ │ └── util.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── datasets.py
│ │ ├── preprocess.py
│ │ └── transforms.py
│ ├── metrics
│ │ ├── __init__.py
│ │ └── quality.py
│ ├── ml
│ │ ├── __init__.py
│ │ ├── accelerator.py
│ │ ├── basemodel.py
│ │ └── decorators.py
│ ├── requirements.txt
│ └── post.py
├── tests
│ └── audiotools
│ ├── core
│ │ ├── test_audio_signal.py
│ │ ├── test_bands.py
│ │ ├── test_display.py
│ │ ├── test_dsp.py
│ │ ├── test_effects.py
│ │ ├── test_fftconv.py
│ │ ├── test_grad.py
│ │ ├── test_highpass.py
│ │ ├── test_loudness.py
│ │ ├── test_lowpass.py
│ │ └── test_util.py
│ ├── data
│ │ ├── test_datasets.py
│ │ ├── test_preprocess.py
│ │ └── test_transforms.py
│ ├── ml
│ │ ├── test_decorators.py
│ │ └── test_model.py
│ └── test_post.py
```
- **core**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals.
- **data**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data.
- **metrics**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms.
- **ml**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio.
This project aims to provide developers and researchers with an efficient and flexible framework to foster innovation and exploration across various domains of audio technology.

@ -0,0 +1,25 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import metrics
from . import ml
from . import post
from .core import AudioSignal
from .core import highpass_filter
from .core import highpass_filters
from .core import Meter
from .core import STFTParams
from .core import util
from .data import datasets
from .data import preprocess
from .data import transforms

@ -0,0 +1,28 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import util
from ._julius import fft_conv1d
from ._julius import FFTConv1D
from ._julius import highpass_filter
from ._julius import highpass_filters
from ._julius import lowpass_filter
from ._julius import LowPassFilter
from ._julius import LowPassFilters
from ._julius import pure_tone
from ._julius import resample_frac
from ._julius import split_bands
from ._julius import SplitBands
from .audio_signal import AudioSignal
from .audio_signal import STFTParams
from .loudness import Meter

@ -0,0 +1,666 @@
# MIT License, Copyright (c) 2020 Alexandre Défossez.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from julius(https://github.com/adefossez/julius/tree/main/julius)
"""
Implementation of a FFT based 1D convolution in PaddlePaddle.
While FFT is used in some cases for small kernel sizes, it is not the default for long ones, e.g. 512.
This module implements efficient FFT based convolutions for such cases. A typical
application is for evaluating FIR filters with a long receptive field, typically
evaluated with a stride of 1.
"""
import inspect
import math
import sys
import typing
from typing import Optional
from typing import Sequence
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddlespeech.t2s.modules import fft_conv1d
from paddlespeech.t2s.modules import FFTConv1D
from paddlespeech.utils import satisfy_paddle_version
__all__ = [
'highpass_filter', 'highpass_filters', 'lowpass_filter', 'LowPassFilter',
'LowPassFilters', 'pure_tone', 'resample_frac', 'split_bands', 'SplitBands'
]
def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}):
"""
Return a simple representation string for `obj`.
If `attrs` is not None, it should be a list of attributes to include.
"""
params = inspect.signature(obj.__class__).parameters
attrs_repr = []
if attrs is None:
attrs = list(params.keys())
for attr in attrs:
display = False
if attr in overrides:
value = overrides[attr]
elif hasattr(obj, attr):
value = getattr(obj, attr)
else:
continue
if attr in params:
param = params[attr]
if param.default is inspect._empty or value != param.default: # type: ignore
display = True
else:
display = True
if display:
attrs_repr.append(f"{attr}={value}")
return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
def sinc(x: paddle.Tensor):
"""
Implementation of sinc, i.e. sin(x) / x
__Warning__: the input is not multiplied by `pi`!
"""
if satisfy_paddle_version("3.0"):
return paddle.sinc(x)
return paddle.where(
x == 0,
paddle.to_tensor(1.0, dtype=x.dtype, place=x.place),
paddle.sin(x) / x, )
class ResampleFrac(paddle.nn.Layer):
"""
Resampling from the sample rate `old_sr` to `new_sr`.
"""
def __init__(self,
old_sr: int,
new_sr: int,
zeros: int=24,
rolloff: float=0.945):
"""
Args:
old_sr (int): sample rate of the input signal x.
new_sr (int): sample rate of the output.
zeros (int): number of zero crossing to keep in the sinc filter.
rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`,
to ensure sufficient margin due to the imperfection of the FIR filter used.
Lowering this value will reduce anti-aliasing, but will reduce some of the
highest frequencies.
Shape:
- Input: `[*, T]`
- Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`
.. caution::
After dividing `old_sr` and `new_sr` by their GCD, both should be small
for this implementation to be fast.
>>> import paddle
>>> resample = ResampleFrac(4, 5)
>>> x = paddle.randn([1000])
>>> print(len(resample(x)))
1250
"""
super().__init__()
if not isinstance(old_sr, int) or not isinstance(new_sr, int):
raise ValueError("old_sr and new_sr should be integers")
gcd = math.gcd(old_sr, new_sr)
self.old_sr = old_sr // gcd
self.new_sr = new_sr // gcd
self.zeros = zeros
self.rolloff = rolloff
self._init_kernels()
def _init_kernels(self):
if self.old_sr == self.new_sr:
return
kernels = []
sr = min(self.new_sr, self.old_sr)
sr *= self.rolloff
self._width = math.ceil(self.zeros * self.old_sr / sr)
idx = paddle.arange(
-self._width, self._width + self.old_sr, dtype="float32")
for i in range(self.new_sr):
t = (-i / self.new_sr + idx / self.old_sr) * sr
t = paddle.clip(t, -self.zeros, self.zeros)
t *= math.pi
window = paddle.cos(t / self.zeros / 2)**2
kernel = sinc(t) * window
# Renormalize kernel to ensure a constant signal is preserved.
kernel = kernel / kernel.sum()
kernels.append(kernel)
_kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1])
self.kernel = self.create_parameter(
shape=_kernel.shape,
dtype=_kernel.dtype, )
self.kernel.set_value(_kernel)
def forward(
self,
x: paddle.Tensor,
output_length: Optional[int]=None,
full: bool=False, ):
"""
Resample x.
Args:
x (Tensor): signal to resample, time should be the last dimension
output_length (None or int): This can be set to the desired output length
(last dimension). Allowed values are between 0 and
ceil(length * new_sr / old_sr). When None (default) is specified, the
floored output length will be used. In order to select the largest possible
size, use the `full` argument.
full (bool): return the longest possible output from the input. This can be useful
if you chain resampling operations, and want to give the `output_length` only
for the last one, while passing `full=True` to all the other ones.
"""
if self.old_sr == self.new_sr:
return x
shape = x.shape
_dtype = x.dtype
length = x.shape[-1]
x = x.reshape([-1, length])
x = F.pad(
x.unsqueeze(1),
[self._width, self._width + self.old_sr],
mode="replicate",
data_format="NCL", ).astype(self.kernel.dtype)
ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL")
y = ys.transpose(
[0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype)
float_output_length = paddle.to_tensor(
self.new_sr * length / self.old_sr, dtype="float32")
max_output_length = paddle.ceil(float_output_length).astype("int64")
default_output_length = paddle.floor(float_output_length).astype(
"int64")
if output_length is None:
applied_output_length = (max_output_length
if full else default_output_length)
elif output_length < 0 or output_length > max_output_length:
raise ValueError(
f"output_length must be between 0 and {max_output_length.numpy()}"
)
else:
applied_output_length = paddle.to_tensor(
output_length, dtype="int64")
if full:
raise ValueError(
"You cannot pass both full=True and output_length")
return y[..., :applied_output_length]
def __repr__(self):
return simple_repr(self)
def resample_frac(
x: paddle.Tensor,
old_sr: int,
new_sr: int,
zeros: int=24,
rolloff: float=0.945,
output_length: Optional[int]=None,
full: bool=False, ):
"""
Functional version of `ResampleFrac`, refer to its documentation for more information.
..warning::
If you call repeatidly this functions with the same sample rates, then the
resampling kernel will be recomputed everytime. For best performance, you should use
and cache an instance of `ResampleFrac`.
"""
return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full)
def pad_to(tensor: paddle.Tensor,
target_length: int,
mode: str="constant",
value: float=0.0):
"""
Pad the given tensor to the given length, with 0s on the right.
"""
return F.pad(
tensor, (0, target_length - tensor.shape[-1]),
mode=mode,
value=value,
data_format="NCL")
def pure_tone(freq: float, sr: float=128, dur: float=4, device=None):
"""
Return a pure tone, i.e. cosine.
Args:
freq (float): frequency (in Hz)
sr (float): sample rate (in Hz)
dur (float): duration (in seconds)
"""
time = paddle.arange(int(sr * dur), dtype="float32") / sr
return paddle.cos(2 * math.pi * freq * time)
class LowPassFilters(nn.Layer):
"""
Bank of low pass filters.
"""
def __init__(self,
cutoffs: Sequence[float],
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None,
dtype="float32"):
super().__init__()
self.cutoffs = list(cutoffs)
if min(self.cutoffs) < 0:
raise ValueError("Minimum cutoff must be larger than zero.")
if max(self.cutoffs) > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.stride = stride
self.pad = pad
self.zeros = zeros
self.half_size = int(zeros / min([c for c in self.cutoffs if c > 0]) /
2)
if fft is None:
fft = self.half_size > 32
self.fft = fft
# Create filters
window = paddle.audio.functional.get_window(
"hann", 2 * self.half_size + 1, fftbins=False, dtype=dtype)
time = paddle.arange(
-self.half_size, self.half_size + 1, dtype="float32")
filters = []
for cutoff in cutoffs:
if cutoff == 0:
filter_ = paddle.zeros_like(time)
else:
filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi *
time)
# Normalize filter
filter_ /= paddle.sum(filter_)
filters.append(filter_)
filters = paddle.stack(filters)[:, None]
self.filters = self.create_parameter(
shape=filters.shape,
default_initializer=nn.initializer.Constant(value=0.0),
dtype="float32",
is_bias=False,
attr=paddle.ParamAttr(trainable=False), )
self.filters.set_value(filters)
def forward(self, _input):
shape = list(_input.shape)
_input = _input.reshape([-1, 1, shape[-1]])
if self.pad:
_input = F.pad(
_input, (self.half_size, self.half_size),
mode="replicate",
data_format="NCL")
if self.fft:
out = fft_conv1d(_input, self.filters, stride=self.stride)
else:
out = F.conv1d(_input, self.filters, stride=self.stride)
shape.insert(0, len(self.cutoffs))
shape[-1] = out.shape[-1]
return out.transpose([1, 0, 2]).reshape(shape)
class LowPassFilter(nn.Layer):
"""
Same as `LowPassFilters` but applies a single low pass filter.
"""
def __init__(self,
cutoff: float,
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None):
super().__init__()
self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft)
@property
def cutoff(self):
return self._lowpasses.cutoffs[0]
@property
def stride(self):
return self._lowpasses.stride
@property
def pad(self):
return self._lowpasses.pad
@property
def zeros(self):
return self._lowpasses.zeros
@property
def fft(self):
return self._lowpasses.fft
def forward(self, _input):
return self._lowpasses(_input)[0]
def lowpass_filters(
_input: paddle.Tensor,
cutoffs: Sequence[float],
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None, ):
"""
Functional version of `LowPassFilters`, refer to this class for more information.
"""
return LowPassFilters(cutoffs, stride, pad, zeros, fft)(_input)
def lowpass_filter(_input: paddle.Tensor,
cutoff: float,
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None):
"""
Same as `lowpass_filters` but with a single cutoff frequency.
Output will not have a dimension inserted in the front.
"""
return lowpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]
class HighPassFilters(paddle.nn.Layer):
"""
Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more
details on the implementation.
Args:
cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
f_s is the samplerate and `f` is the cutoff frequency.
The upper limit is 0.5, because a signal sampled at `f_s` contains only
frequencies under `f_s / 2`.
stride (int): how much to decimate the output. Probably not a good idea
to do so with a high pass filters though...
pad (bool): if True, appropriately pad the _input with zero over the edge. If `stride=1`,
the output will have the same length as the _input.
zeros (float): Number of zero crossings to keep.
Controls the receptive field of the Finite Impulse Response filter.
For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
it is a bad idea to set this to a high value.
This is likely appropriate for most use. Lower values
will result in a faster filter, but with a slower attenuation around the
cutoff frequency.
fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
If False, uses PyTorch convolutions. If None, either one will be chosen automatically
depending on the effective filter size.
..warning::
All the filters will use the same filter size, aligned on the lowest
frequency provided. If you combine a lot of filters with very diverse frequencies, it might
be more efficient to split them over multiple modules with similar frequencies.
Shape:
- Input: `[*, T]`
- Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
`F` is the numer of cutoff frequencies.
>>> highpass = HighPassFilters([1/4])
>>> x = paddle.randn([4, 12, 21, 1024])
>>> list(highpass(x).shape)
[1, 4, 12, 21, 1024]
"""
def __init__(self,
cutoffs: Sequence[float],
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None):
super().__init__()
self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft)
@property
def cutoffs(self):
return self._lowpasses.cutoffs
@property
def stride(self):
return self._lowpasses.stride
@property
def pad(self):
return self._lowpasses.pad
@property
def zeros(self):
return self._lowpasses.zeros
@property
def fft(self):
return self._lowpasses.fft
def forward(self, _input):
lows = self._lowpasses(_input)
# We need to extract the right portion of the _input in case
# pad is False or stride > 1
if self.pad:
start, end = 0, _input.shape[-1]
else:
start = self._lowpasses.half_size
end = -start
_input = _input[..., start:end:self.stride]
highs = _input - lows
return highs
class HighPassFilter(paddle.nn.Layer):
"""
Same as `HighPassFilters` but applies a single high pass filter.
Shape:
- Input: `[*, T]`
- Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.
>>> highpass = HighPassFilter(1/4, stride=1)
>>> x = paddle.randn([4, 124])
>>> list(highpass(x).shape)
[4, 124]
"""
def __init__(self,
cutoff: float,
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None):
super().__init__()
self._highpasses = HighPassFilters([cutoff], stride, pad, zeros, fft)
@property
def cutoff(self):
return self._highpasses.cutoffs[0]
@property
def stride(self):
return self._highpasses.stride
@property
def pad(self):
return self._highpasses.pad
@property
def zeros(self):
return self._highpasses.zeros
@property
def fft(self):
return self._highpasses.fft
def forward(self, _input):
return self._highpasses(_input)[0]
def highpass_filters(
_input: paddle.Tensor,
cutoffs: Sequence[float],
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None, ):
"""
Functional version of `HighPassFilters`, refer to this class for more information.
"""
return HighPassFilters(cutoffs, stride, pad, zeros, fft)(_input)
def highpass_filter(_input: paddle.Tensor,
cutoff: float,
stride: int=1,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None):
"""
Functional version of `HighPassFilter`, refer to this class for more information.
Output will not have a dimension inserted in the front.
"""
return highpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]
class SplitBands(paddle.nn.Layer):
"""
Decomposes a signal over the given frequency bands in the waveform domain using
a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`.
You can either specify explicitly the frequency cutoffs, or just the number of bands,
in which case the frequency cutoffs will be spread out evenly in mel scale.
Args:
sample_rate (float): Sample rate of the input signal in Hz.
n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`.
In that case, the cutoff frequencies will be evenly spaced in mel-space.
cutoffs (list[float] or None): list of frequency cutoffs in Hz.
pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
the output will have the same length as the input.
zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more informations.
fft (bool or None): See `LowPassFilters` for more info.
..note::
The sum of all the bands will always be the input signal.
..warning::
Unlike `julius.lowpass.LowPassFilters`, the cutoffs frequencies must be provided in Hz along
with the sample rate.
Shape:
- Input: `[*, T]`
- Output: `[B, *, T']`, with `T'=T` if `pad` is True.
If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1`
>>> bands = SplitBands(sample_rate=128, n_bands=10)
>>> x = paddle.randn(shape=[6, 4, 1024])
>>> list(bands(x).shape)
[10, 6, 4, 1024]
"""
def __init__(
self,
sample_rate: float,
n_bands: Optional[int]=None,
cutoffs: Optional[Sequence[float]]=None,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None, ):
super().__init__()
if (cutoffs is None) + (n_bands is None) != 1:
raise ValueError(
"You must provide either n_bands, or cutoffs, but not both.")
self.sample_rate = sample_rate
self.n_bands = n_bands
self._cutoffs = list(cutoffs) if cutoffs is not None else None
self.pad = pad
self.zeros = zeros
self.fft = fft
if cutoffs is None:
if n_bands is None:
raise ValueError("You must provide one of n_bands or cutoffs.")
if not n_bands >= 1:
raise ValueError(
f"n_bands must be greater than one (got {n_bands})")
cutoffs = paddle.audio.functional.mel_frequencies(
n_bands + 1, 0, sample_rate / 2)[1:-1]
else:
if max(cutoffs) > 0.5 * sample_rate:
raise ValueError(
"A cutoff above sample_rate/2 does not make sense.")
if len(cutoffs) > 0:
self.lowpass = LowPassFilters(
[c / sample_rate for c in cutoffs],
pad=pad,
zeros=zeros,
fft=fft)
else:
self.lowpass = None # type: ignore
def forward(self, input):
if self.lowpass is None:
return input[None]
lows = self.lowpass(input)
low = lows[0]
bands = [low]
for low_and_band in lows[1:]:
# Get a bandpass filter by subtracting lowpasses
band = low_and_band - low
bands.append(band)
low = low_and_band
# Last band is whatever is left in the signal
bands.append(input - low)
return paddle.stack(bands)
@property
def cutoffs(self):
if self._cutoffs is not None:
return self._cutoffs
elif self.lowpass is not None:
return [c * self.sample_rate for c in self.lowpass.cutoffs]
else:
return []
def split_bands(
signal: paddle.Tensor,
sample_rate: float,
n_bands: Optional[int]=None,
cutoffs: Optional[Sequence[float]]=None,
pad: bool=True,
zeros: float=8,
fft: Optional[bool]=None, ):
"""
Functional version of `SplitBands`, refer to this class for more information.
>>> x = paddle.randn(shape=[6, 4, 1024])
>>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape)
[3, 6, 4, 1024]
"""
return SplitBands(sample_rate, n_bands, cutoffs, pad, zeros, fft)(signal)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,195 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/display.py)
import inspect
import typing
from functools import wraps
from . import util
def format_figure(func):
"""Decorator for formatting figures produced by the code below.
See :py:func:`audiotools.core.util.format_figure` for more.
Parameters
----------
func : Callable
Plotting function that is decorated by this function.
"""
@wraps(func)
def wrapper(*args, **kwargs):
f_keys = inspect.signature(util.format_figure).parameters.keys()
f_kwargs = {}
for k, v in list(kwargs.items()):
if k in f_keys:
kwargs.pop(k)
f_kwargs[k] = v
func(*args, **kwargs)
util.format_figure(**f_kwargs)
return wrapper
class DisplayMixin:
@format_figure
def specshow(
self,
preemphasis: bool=False,
x_axis: str="time",
y_axis: str="linear",
n_mels: int=128,
**kwargs, ):
"""Displays a spectrogram, using ``librosa.display.specshow``.
Parameters
----------
preemphasis : bool, optional
Whether or not to apply preemphasis, which makes high
frequency detail easier to see, by default False
x_axis : str, optional
How to label the x axis, by default "time"
y_axis : str, optional
How to label the y axis, by default "linear"
n_mels : int, optional
If displaying a mel spectrogram with ``y_axis = "mel"``,
this controls the number of mels, by default 128.
kwargs : dict, optional
Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
"""
import librosa
import librosa.display
# Always re-compute the STFT data before showing it, in case
# it changed.
signal = self.clone()
signal.stft_data = None
if preemphasis:
signal.preemphasis()
ref = signal.magnitude.max()
log_mag = signal.log_magnitude(ref_value=ref)
if y_axis == "mel":
log_mag = 20 * signal.mel_spectrogram(n_mels).clip(1e-5).log10()
log_mag -= log_mag.max()
librosa.display.specshow(
log_mag.numpy()[0].mean(axis=0),
x_axis=x_axis,
y_axis=y_axis,
sr=signal.sample_rate,
**kwargs, )
@format_figure
def waveplot(self, x_axis: str="time", **kwargs):
"""Displays a waveform plot, using ``librosa.display.waveshow``.
Parameters
----------
x_axis : str, optional
How to label the x axis, by default "time"
kwargs : dict, optional
Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
"""
import librosa
import librosa.display
audio_data = self.audio_data[0].mean(axis=0)
audio_data = audio_data.cpu().numpy()
plot_fn = "waveshow" if hasattr(librosa.display,
"waveshow") else "waveplot"
wave_plot_fn = getattr(librosa.display, plot_fn)
wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
@format_figure
def wavespec(self, x_axis: str="time", **kwargs):
"""Displays a waveform plot, using ``librosa.display.waveshow``.
Parameters
----------
x_axis : str, optional
How to label the x axis, by default "time"
kwargs : dict, optional
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
"""
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
gs = GridSpec(6, 1)
plt.subplot(gs[0, :])
self.waveplot(x_axis=x_axis)
plt.subplot(gs[1:, :])
self.specshow(x_axis=x_axis, **kwargs)
def write_audio_to_tb(
self,
tag: str,
writer,
step: int=None,
plot_fn: typing.Union[typing.Callable, str]="specshow",
**kwargs, ):
"""Writes a signal and its spectrogram to Tensorboard. Will show up
under the Audio and Images tab in Tensorboard.
Parameters
----------
tag : str
Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
writer : SummaryWriter
A SummaryWriter object from PyTorch library.
step : int, optional
The step to write the signal to, by default None
plot_fn : typing.Union[typing.Callable, str], optional
How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
kwargs : dict, optional
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
whatever ``plot_fn`` is set to.
"""
import matplotlib.pyplot as plt
audio_data = self.audio_data[0, 0].detach().cpu().numpy()
sample_rate = self.sample_rate
writer.add_audio(tag, audio_data, step, sample_rate)
if plot_fn is not None:
if isinstance(plot_fn, str):
plot_fn = getattr(self, plot_fn)
fig = plt.figure()
plt.clf()
plot_fn(**kwargs)
writer.add_figure(tag.replace("wav", "png"), fig, step)
def save_image(
self,
image_path: str,
plot_fn: typing.Union[typing.Callable, str]="specshow",
**kwargs, ):
"""Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
a specified file.
Parameters
----------
image_path : str
Where to save the file to.
plot_fn : typing.Union[typing.Callable, str], optional
How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
kwargs : dict, optional
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
whatever ``plot_fn`` is set to.
"""
import matplotlib.pyplot as plt
if isinstance(plot_fn, str):
plot_fn = getattr(self, plot_fn)
plt.clf()
plot_fn(**kwargs)
plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
plt.close()

@ -0,0 +1,467 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/dsp.py)
import typing
import numpy as np
import paddle
from . import _julius
from . import util
def _unfold(x, kernel_sizes, strides):
# https://github.com/PaddlePaddle/Paddle/pull/70102
if 1 == kernel_sizes[0]:
x_zeros = paddle.zeros_like(x)
x = paddle.concat([x, x_zeros], axis=2)
kernel_sizes = [2, kernel_sizes[1]]
strides = list(strides)
unfolded = paddle.nn.functional.unfold(
x,
kernel_sizes=kernel_sizes,
strides=strides, )
if 2 == kernel_sizes[0]:
unfolded = unfolded[:, :kernel_sizes[1]]
return unfolded
def _fold(x, output_sizes, kernel_sizes, strides):
# https://github.com/PaddlePaddle/Paddle/pull/70102
if 1 == output_sizes[0] and 1 == kernel_sizes[0]:
x_zeros = paddle.zeros_like(x)
x = paddle.concat([x, x_zeros], axis=1)
output_sizes = (2, output_sizes[1])
kernel_sizes = (2, kernel_sizes[1])
fold = paddle.nn.functional.fold(
x,
output_sizes=output_sizes,
kernel_sizes=kernel_sizes,
strides=strides, )
if 2 == kernel_sizes[0]:
fold = fold[:, :, :1]
return fold
class DSPMixin:
_original_batch_size = None
_original_num_channels = None
_padded_signal_length = None
def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
self._original_batch_size = self.batch_size
self._original_num_channels = self.num_channels
window_length = int(window_duration * self.sample_rate)
hop_length = int(hop_duration * self.sample_rate)
if window_length % hop_length != 0:
factor = window_length // hop_length
window_length = factor * hop_length
self.zero_pad(hop_length, hop_length)
self._padded_signal_length = self.signal_length
return window_length, hop_length
def windows(self,
window_duration: float,
hop_duration: float,
preprocess: bool=True):
"""Generator which yields windows of specified duration from signal with a specified
hop length.
Parameters
----------
window_duration : float
Duration of every window in seconds.
hop_duration : float
Hop between windows in seconds.
preprocess : bool, optional
Whether to preprocess the signal, so that the first sample is in
the middle of the first window, by default True
Yields
------
AudioSignal
Each window is returned as an AudioSignal.
"""
if preprocess:
window_length, hop_length = self._preprocess_signal_for_windowing(
window_duration, hop_duration)
self.audio_data = self.audio_data.reshape([-1, 1, self.signal_length])
for b in range(self.batch_size):
i = 0
start_idx = i * hop_length
while True:
start_idx = i * hop_length
i += 1
end_idx = start_idx + window_length
if end_idx > self.signal_length:
break
yield self[b, ..., start_idx:end_idx]
def collect_windows(self,
window_duration: float,
hop_duration: float,
preprocess: bool=True):
"""Reshapes signal into windows of specified duration from signal with a specified
hop length. Window are placed along the batch dimension. Use with
:py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
original signal.
Parameters
----------
window_duration : float
Duration of every window in seconds.
hop_duration : float
Hop between windows in seconds.
preprocess : bool, optional
Whether to preprocess the signal, so that the first sample is in
the middle of the first window, by default True
Returns
-------
AudioSignal
AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
"""
if preprocess:
window_length, hop_length = self._preprocess_signal_for_windowing(
window_duration, hop_duration)
# self.audio_data: (nb, nch, nt).
# unfolded = paddle.nn.functional.unfold(
# self.audio_data.reshape([-1, 1, 1, self.signal_length]),
# kernel_sizes=(1, window_length),
# strides=(1, hop_length),
# )
unfolded = _unfold(
self.audio_data.reshape([-1, 1, 1, self.signal_length]),
kernel_sizes=(1, window_length),
strides=(1, hop_length), )
# unfolded: (nb * nch, window_length, num_windows).
# -> (nb * nch * num_windows, 1, window_length)
unfolded = unfolded.transpose([0, 2, 1]).reshape([-1, 1, window_length])
self.audio_data = unfolded
return self
def overlap_and_add(self, hop_duration: float):
"""Function which takes a list of windows and overlap adds them into a
signal the same length as ``audio_signal``.
Parameters
----------
hop_duration : float
How much to shift for each window
(overlap is window_duration - hop_duration) in seconds.
Returns
-------
AudioSignal
overlap-and-added signal.
"""
hop_length = int(hop_duration * self.sample_rate)
window_length = self.signal_length
nb, nch = self._original_batch_size, self._original_num_channels
unfolded = self.audio_data.reshape(
[nb * nch, -1, window_length]).transpose([0, 2, 1])
# folded = paddle.nn.functional.fold(
# unfolded,
# output_sizes=(1, self._padded_signal_length),
# kernel_sizes=(1, window_length),
# strides=(1, hop_length),
# )
folded = _fold(
unfolded,
output_sizes=(1, self._padded_signal_length),
kernel_sizes=(1, window_length),
strides=(1, hop_length), )
norm = paddle.ones_like(unfolded)
# norm = paddle.nn.functional.fold(
# norm,
# output_sizes=(1, self._padded_signal_length),
# kernel_sizes=(1, window_length),
# strides=(1, hop_length),
# )
norm = _fold(
norm,
output_sizes=(1, self._padded_signal_length),
kernel_sizes=(1, window_length),
strides=(1, hop_length), )
folded = folded / norm
folded = folded.reshape([nb, nch, -1])
self.audio_data = folded
self.trim(hop_length, hop_length)
return self
def low_pass(self,
cutoffs: typing.Union[paddle.Tensor, np.ndarray, float],
zeros: int=51):
"""Low-passes the signal in-place. Each item in the batch
can have a different low-pass cutoff, if the input
to this signal is an array or tensor. If a float, all
items are given the same low-pass filter.
Parameters
----------
cutoffs : typing.Union[paddle.Tensor, np.ndarray, float]
Cutoff in Hz of low-pass filter.
zeros : int, optional
Number of taps to use in low-pass filter, by default 51
Returns
-------
AudioSignal
Low-passed AudioSignal.
"""
cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
cutoffs = cutoffs / self.sample_rate
filtered = paddle.empty_like(self.audio_data)
for i, cutoff in enumerate(cutoffs):
lp_filter = _julius.LowPassFilter(cutoff.cpu(), zeros=zeros)
filtered[i] = lp_filter(self.audio_data[i])
self.audio_data = filtered
self.stft_data = None
return self
def high_pass(self,
cutoffs: typing.Union[paddle.Tensor, np.ndarray, float],
zeros: int=51):
"""High-passes the signal in-place. Each item in the batch
can have a different high-pass cutoff, if the input
to this signal is an array or tensor. If a float, all
items are given the same high-pass filter.
Parameters
----------
cutoffs : typing.Union[paddle.Tensor, np.ndarray, float]
Cutoff in Hz of high-pass filter.
zeros : int, optional
Number of taps to use in high-pass filter, by default 51
Returns
-------
AudioSignal
High-passed AudioSignal.
"""
cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
cutoffs = cutoffs / self.sample_rate
filtered = paddle.empty_like(self.audio_data)
for i, cutoff in enumerate(cutoffs):
hp_filter = _julius.HighPassFilter(cutoff.cpu(), zeros=zeros)
filtered[i] = hp_filter(self.audio_data[i])
self.audio_data = filtered
self.stft_data = None
return self
def mask_frequencies(
self,
fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float],
fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float],
val: float=0.0, ):
"""Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
with the value specified by ``val``. Useful for implementing SpecAug.
The min and max can be different for every item in the batch.
Parameters
----------
fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float]
Lower end of band to mask out.
fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float]
Upper end of band to mask out.
val : float, optional
Value to fill in, by default 0.0
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
# SpecAug
mag, phase = self.magnitude, self.phase
fmin_hz = util.ensure_tensor(
fmin_hz,
ndim=mag.ndim, )
fmax_hz = util.ensure_tensor(
fmax_hz,
ndim=mag.ndim, )
assert paddle.all(fmin_hz < fmax_hz)
# build mask
nbins = mag.shape[-2]
bins_hz = paddle.linspace(
0,
self.sample_rate / 2,
nbins, )
bins_hz = bins_hz[None, None, :, None].tile(
[self.batch_size, 1, 1, mag.shape[-1]])
fmin_hz, fmax_hz = fmin_hz.astype(bins_hz.dtype), fmax_hz.astype(
bins_hz.dtype)
mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
mag = paddle.where(mask, paddle.full_like(mag, val), mag)
phase = paddle.where(mask, paddle.full_like(phase, val), phase)
self.stft_data = mag * util.exp_compat(1j * phase)
return self
def mask_timesteps(
self,
tmin_s: typing.Union[paddle.Tensor, np.ndarray, float],
tmax_s: typing.Union[paddle.Tensor, np.ndarray, float],
val: float=0.0, ):
"""Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
with the value specified by ``val``. Useful for implementing SpecAug.
The min and max can be different for every item in the batch.
Parameters
----------
tmin_s : typing.Union[paddle.Tensor, np.ndarray, float]
Lower end of timesteps to mask out.
tmax_s : typing.Union[paddle.Tensor, np.ndarray, float]
Upper end of timesteps to mask out.
val : float, optional
Value to fill in, by default 0.0
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
# SpecAug
mag, phase = self.magnitude, self.phase
tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
assert paddle.all(tmin_s < tmax_s)
# build mask
nt = mag.shape[-1]
bins_t = paddle.linspace(
0,
self.signal_duration,
nt, )
bins_t = bins_t[None, None, None, :].tile(
[self.batch_size, 1, mag.shape[-2], 1])
mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
# mag = mag.masked_fill(mask, val)
# phase = phase.masked_fill(mask, val)
mag = paddle.where(mask, paddle.full_like(mag, val), mag)
phase = paddle.where(mask, paddle.full_like(phase, val), phase)
self.stft_data = mag * util.exp_compat(1j * phase)
return self
def mask_low_magnitudes(
self,
db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float],
val: float=0.0):
"""Mask away magnitudes below a specified threshold, which
can be different for every item in the batch.
Parameters
----------
db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float]
Decibel value for which things below it will be masked away.
val : float, optional
Value to fill in for masked portions, by default 0.0
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
mag = self.magnitude
log_mag = self.log_magnitude()
db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
db_cutoff = db_cutoff.astype(log_mag.dtype)
mask = log_mag < db_cutoff
# mag = mag.masked_fill(mask, val)
mag = paddle.where(mask, mag, val * paddle.ones_like(mag))
self.magnitude = mag
return self
def shift_phase(self,
shift: typing.Union[paddle.Tensor, np.ndarray, float]):
"""Shifts the phase by a constant value.
Parameters
----------
shift : typing.Union[paddle.Tensor, np.ndarray, float]
What to shift the phase by.
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
shift = shift.astype(self.phase.dtype)
self.phase = self.phase + shift
return self
def corrupt_phase(self,
scale: typing.Union[paddle.Tensor, np.ndarray, float]):
"""Corrupts the phase randomly by some scaled value.
Parameters
----------
scale : typing.Union[paddle.Tensor, np.ndarray, float]
Standard deviation of noise to add to the phase.
Returns
-------
AudioSignal
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
masked audio data.
"""
scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
self.phase = self.phase + scale * paddle.randn(
shape=self.phase.shape, dtype=self.phase.dtype)
return self
def preemphasis(self, coef: float=0.85):
"""Applies pre-emphasis to audio signal.
Parameters
----------
coef : float, optional
How much pre-emphasis to apply, lower values do less. 0 does nothing.
by default 0.85
Returns
-------
AudioSignal
Pre-emphasized signal.
"""
kernel = paddle.to_tensor([1, -coef, 0]).reshape([1, 1, -1])
x = self.audio_data.reshape([-1, 1, self.signal_length])
x = paddle.nn.functional.conv1d(
x.astype(kernel.dtype), kernel, padding=1)
self.audio_data = x.reshape(self.audio_data.shape)
return self

@ -0,0 +1,539 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/effects.py)
import typing
import numpy as np
import paddle
from . import util
from ._julius import SplitBands
class EffectMixin:
GAIN_FACTOR = np.log(10) / 20
"""Gain factor for converting between amplitude and decibels."""
CODEC_PRESETS = {
"8-bit": {
"format": "wav",
"encoding": "ULAW",
"bits_per_sample": 8
},
"GSM-FR": {
"format": "gsm"
},
"MP3": {
"format": "mp3",
"compression": -9
},
"Vorbis": {
"format": "vorbis",
"compression": -1
},
"Ogg": {
"format": "ogg",
"compression": -1,
},
"Amr-nb": {
"format": "amr-nb"
},
}
"""Presets for applying codecs via torchaudio."""
def mix(
self,
other,
snr: typing.Union[paddle.Tensor, np.ndarray, float]=10,
other_eq: typing.Union[paddle.Tensor, np.ndarray]=None, ):
"""Mixes noise with signal at specified
signal-to-noise ratio. Optionally, the
other signal can be equalized in-place.
Parameters
----------
other : AudioSignal
AudioSignal object to mix with.
snr : typing.Union[paddle.Tensor, np.ndarray, float], optional
Signal to noise ratio, by default 10
other_eq : typing.Union[paddle.Tensor, np.ndarray], optional
EQ curve to apply to other signal, if any, by default None
Returns
-------
AudioSignal
In-place modification of AudioSignal.
"""
snr = util.ensure_tensor(snr)
pad_len = max(0, self.signal_length - other.signal_length)
other.zero_pad(0, pad_len)
other.truncate_samples(self.signal_length)
if other_eq is not None:
other = other.equalizer(other_eq)
tgt_loudness = self.loudness() - snr
other = other.normalize(tgt_loudness)
self.audio_data = self.audio_data + other.audio_data
return self
def convolve(self, other, start_at_max: bool=True):
"""Convolves self with other.
This function uses FFTs to do the convolution.
Parameters
----------
other : AudioSignal
Signal to convolve with.
start_at_max : bool, optional
Whether to start at the max value of other signal, to
avoid inducing delays, by default True
Returns
-------
AudioSignal
Convolved signal, in-place.
"""
from . import AudioSignal
pad_len = self.signal_length - other.signal_length
if pad_len > 0:
other.zero_pad(0, pad_len)
else:
other.truncate_samples(self.signal_length)
if start_at_max:
# Use roll to rotate over the max for every item
# so that the impulse responses don't induce any
# delay.
idx = paddle.argmax(paddle.abs(other.audio_data), axis=-1)
irs = paddle.zeros_like(other.audio_data)
for i in range(other.batch_size):
irs[i] = paddle.roll(
other.audio_data[i], shifts=-idx[i].item(), axis=-1)
other = AudioSignal(irs, other.sample_rate)
delta = paddle.zeros_like(other.audio_data)
delta[..., 0] = 1
length = self.signal_length
delta_fft = paddle.fft.rfft(delta, n=length)
other_fft = paddle.fft.rfft(other.audio_data, n=length)
self_fft = paddle.fft.rfft(self.audio_data, n=length)
convolved_fft = other_fft * self_fft
convolved_audio = paddle.fft.irfft(convolved_fft, n=length)
delta_convolved_fft = other_fft * delta_fft
delta_audio = paddle.fft.irfft(delta_convolved_fft, n=length)
# Use the delta to rescale the audio exactly as needed.
delta_max = paddle.max(paddle.abs(delta_audio), axis=-1, keepdim=True)
scale = 1 / paddle.clip(delta_max, min=1e-5)
convolved_audio = convolved_audio * scale
self.audio_data = convolved_audio
return self
def apply_ir(
self,
ir,
drr: typing.Union[paddle.Tensor, np.ndarray, float]=None,
ir_eq: typing.Union[paddle.Tensor, np.ndarray]=None,
use_original_phase: bool=False, ):
"""Applies an impulse response to the signal. If ` is`ir_eq``
is specified, the impulse response is equalized before
it is applied, using the given curve.
Parameters
----------
ir : AudioSignal
Impulse response to convolve with.
drr : typing.Union[paddle.Tensor, np.ndarray, float], optional
Direct-to-reverberant ratio that impulse response will be
altered to, if specified, by default None
ir_eq : typing.Union[paddle.Tensor, np.ndarray], optional
Equalization that will be applied to impulse response
if specified, by default None
use_original_phase : bool, optional
Whether to use the original phase, instead of the convolved
phase, by default False
Returns
-------
AudioSignal
Signal with impulse response applied to it
"""
if ir_eq is not None:
ir = ir.equalizer(ir_eq)
if drr is not None:
ir = ir.alter_drr(drr)
# Save the peak before
max_spk = self.audio_data.abs().max(axis=-1, keepdim=True)
# Augment the impulse response to simulate microphone effects
# and with varying direct-to-reverberant ratio.
phase = self.phase
self.convolve(ir)
# Use the input phase
if use_original_phase:
self.stft()
self.stft_data = self.magnitude * util.exp_compat(1j * phase)
self.istft()
# Rescale to the input's amplitude
max_transformed = self.audio_data.abs().max(axis=-1, keepdim=True)
scale_factor = max_spk.clip(1e-8) / max_transformed.clip(1e-8)
self = self * scale_factor
return self
def ensure_max_of_audio(self, _max: float=1.0):
"""Ensures that ``abs(audio_data) <= max``.
Parameters
----------
max : float, optional
Max absolute value of signal, by default 1.0
Returns
-------
AudioSignal
Signal with values scaled between -max and max.
"""
peak = self.audio_data.abs().max(axis=-1, keepdim=True)
peak_gain = paddle.ones_like(peak)
# peak_gain[peak > _max] = _max / peak[peak > _max]
peak_gain = paddle.where(peak > _max, _max / peak, peak_gain)
self.audio_data = self.audio_data * peak_gain
return self
def normalize(self,
db: typing.Union[paddle.Tensor, np.ndarray, float]=-24.0):
"""Normalizes the signal's volume to the specified db, in LUFS.
This is GPU-compatible, making for very fast loudness normalization.
Parameters
----------
db : typing.Union[paddle.Tensor, np.ndarray, float], optional
Loudness to normalize to, by default -24.0
Returns
-------
AudioSignal
Normalized audio signal.
"""
db = util.ensure_tensor(db)
ref_db = self.loudness()
gain = db.astype(ref_db.dtype) - ref_db
gain = util.exp_compat(gain * self.GAIN_FACTOR)
self.audio_data = self.audio_data * gain[:, None, None]
return self
def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]):
"""Change volume of signal by some amount, in dB.
Parameters
----------
db : typing.Union[paddle.Tensor, np.ndarray, float]
Amount to change volume by.
Returns
-------
AudioSignal
Signal at new volume.
"""
db = util.ensure_tensor(db, ndim=1)
gain = util.exp_compat(db * self.GAIN_FACTOR)
self.audio_data = self.audio_data * gain[:, None, None]
return self
def mel_filterbank(self, n_bands: int):
"""Breaks signal into mel bands.
Parameters
----------
n_bands : int
Number of mel bands to use.
Returns
-------
paddle.Tensor
Mel-filtered bands, with last axis being the band index.
"""
filterbank = SplitBands(self.sample_rate, n_bands)
filtered = filterbank(self.audio_data)
return filtered.transpose([1, 2, 3, 0])
def equalizer(self, db: typing.Union[paddle.Tensor, np.ndarray]):
"""Applies a mel-spaced equalizer to the audio signal.
Parameters
----------
db : typing.Union[paddle.Tensor, np.ndarray]
EQ curve to apply.
Returns
-------
AudioSignal
AudioSignal with equalization applied.
"""
db = util.ensure_tensor(db)
n_bands = db.shape[-1]
fbank = self.mel_filterbank(n_bands)
# If there's a batch dimension, make sure it's the same.
if db.ndim == 2:
if db.shape[0] != 1:
assert db.shape[0] == fbank.shape[0]
else:
db = db.unsqueeze(0)
weights = (10**db).astype("float32")
fbank = fbank * weights[:, None, None, :]
eq_audio_data = fbank.sum(-1)
self.audio_data = eq_audio_data
return self
def clip_distortion(
self,
clip_percentile: typing.Union[paddle.Tensor, np.ndarray, float]):
"""Clips the signal at a given percentile. The higher it is,
the lower the threshold for clipping.
Parameters
----------
clip_percentile : typing.Union[paddle.Tensor, np.ndarray, float]
Values are between 0.0 to 1.0. Typical values are 0.1 or below.
Returns
-------
AudioSignal
Audio signal with clipped audio data.
"""
clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
clip_percentile = clip_percentile.cpu().numpy()
min_thresh = paddle.quantile(
self.audio_data, (clip_percentile / 2).tolist(), axis=-1)[None]
max_thresh = paddle.quantile(
self.audio_data, (1 - clip_percentile / 2).tolist(), axis=-1)[None]
nc = self.audio_data.shape[1]
min_thresh = min_thresh[:, :nc, :]
max_thresh = max_thresh[:, :nc, :]
self.audio_data = self.audio_data.clip(min_thresh, max_thresh)
return self
def quantization(self,
quantization_channels: typing.Union[paddle.Tensor,
np.ndarray, int]):
"""Applies quantization to the input waveform.
Parameters
----------
quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
Number of evenly spaced quantization channels to quantize
to.
Returns
-------
AudioSignal
Quantized AudioSignal.
"""
quantization_channels = util.ensure_tensor(
quantization_channels, ndim=3)
x = self.audio_data
quantization_channels = quantization_channels.astype(x.dtype)
x = (x + 1) / 2
x = x * quantization_channels
x = x.floor()
x = x / quantization_channels
x = 2 * x - 1
residual = (self.audio_data - x).detach()
self.audio_data = self.audio_data - residual
return self
def mulaw_quantization(self,
quantization_channels: typing.Union[
paddle.Tensor, np.ndarray, int]):
"""Applies mu-law quantization to the input waveform.
Parameters
----------
quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
Number of mu-law spaced quantization channels to quantize
to.
Returns
-------
AudioSignal
Quantized AudioSignal.
"""
mu = quantization_channels - 1.0
mu = util.ensure_tensor(mu, ndim=3)
x = self.audio_data
# quantize
x = paddle.sign(x) * paddle.log1p(mu * paddle.abs(x)) / paddle.log1p(mu)
x = ((x + 1) / 2 * mu + 0.5).astype("int64")
# unquantize
x = (x.astype(mu.dtype) / mu) * 2 - 1.0
x = paddle.sign(x) * (
util.exp_compat(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu
residual = (self.audio_data - x).detach()
self.audio_data = self.audio_data - residual
return self
def __matmul__(self, other):
return self.convolve(other)
class ImpulseResponseMixin:
"""These functions are generally only used with AudioSignals that are derived
from impulse responses, not other sources like music or speech. These methods
are used to replicate the data augmentation described in [1].
1. Bryan, Nicholas J. "Impulse response data augmentation and deep
neural networks for blind room acoustic parameter estimation."
ICASSP 2020-2020 IEEE International Conference on Acoustics,
Speech and Signal Processing (ICASSP). IEEE, 2020.
"""
def decompose_ir(self):
"""Decomposes an impulse response into early and late
field responses.
"""
# Equations 1 and 2
# -----------------
# Breaking up into early
# response + late field response.
td = paddle.argmax(self.audio_data, axis=-1, keepdim=True)
t0 = int(self.sample_rate * 0.0025)
idx = paddle.arange(self.audio_data.shape[-1])[None, None, :]
idx = idx.expand([self.batch_size, -1, -1])
early_idx = (idx >= td - t0) * (idx <= td + t0)
early_response = paddle.zeros_like(self.audio_data)
# early_response[early_idx] = self.audio_data[early_idx]
early_response = paddle.where(early_idx, self.audio_data,
early_response)
late_idx = ~early_idx
late_field = paddle.zeros_like(self.audio_data)
# late_field[late_idx] = self.audio_data[late_idx]
late_field = paddle.where(late_idx, self.audio_data, late_field)
# Equation 4
# ----------
# Decompose early response into windowed
# direct path and windowed residual.
window = paddle.zeros_like(self.audio_data)
window_idx = paddle.nonzero(early_idx)
for idx in range(self.batch_size):
# window_idx = early_idx[idx, 0]
# ----- Just for this -----
# window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item())
# indices = paddle.nonzero(window_idx).reshape(
# [-1]) # shape: [num_true], dtype: int64
indices = window_idx[window_idx[:, 0] == idx][:, -1]
temp_window = self.get_window("hann", indices.shape[0])
window_slice = window[idx, 0]
updated_window_slice = paddle.scatter(
window_slice, index=indices, updates=temp_window)
window[idx, 0] = updated_window_slice
# ----- Just for that -----
return early_response, late_field, window
def measure_drr(self):
"""Measures the direct-to-reverberant ratio of the impulse
response.
Returns
-------
float
Direct-to-reverberant ratio
"""
early_response, late_field, _ = self.decompose_ir()
num = (early_response**2).sum(axis=-1)
den = (late_field**2).sum(axis=-1)
drr = 10 * paddle.log10(num / den)
return drr
@staticmethod
def solve_alpha(early_response, late_field, wd, target_drr):
"""Used to solve for the alpha value, which is used
to alter the drr.
"""
# Equation 5
# ----------
# Apply the good ol' quadratic formula.
wd_sq = wd**2
wd_sq_1 = (1 - wd)**2
e_sq = early_response**2
l_sq = late_field**2
a = (wd_sq * e_sq).sum(axis=-1)
b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1)
c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(10 * paddle.ones_like(
target_drr, dtype="float32"), target_drr.cast("float32") /
10) * l_sq.sum(axis=-1)
expr = ((b**2) - 4 * a * c).sqrt()
alpha = paddle.maximum(
(-b - expr) / (2 * a),
(-b + expr) / (2 * a), )
return alpha
def alter_drr(self, drr: typing.Union[paddle.Tensor, np.ndarray, float]):
"""Alters the direct-to-reverberant ratio of the impulse response.
Parameters
----------
drr : typing.Union[paddle.Tensor, np.ndarray, float]
Direct-to-reverberant ratio that impulse response will be
altered to, if specified, by default None
Returns
-------
AudioSignal
Altered impulse response.
"""
drr = util.ensure_tensor(
drr, 2, self.batch_size
) # Assuming util.ensure_tensor is adapted or equivalent exists
early_response, late_field, window = self.decompose_ir()
alpha = self.solve_alpha(early_response, late_field, window, drr)
min_alpha = late_field.abs().max(axis=-1)[0] / early_response.abs().max(
axis=-1)[0]
alpha = paddle.maximum(alpha, min_alpha)[..., None]
aug_ir_data = alpha * window * early_response + (
(1 - window) * early_response) + late_field
self.audio_data = aug_ir_data
self.ensure_max_of_audio(
) # Assuming ensure_max_of_audio is a method defined elsewhere
return self

@ -0,0 +1,119 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/ffmpeg.py)
import json
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Tuple
import ffmpy
import numpy as np
import paddle
def r128stats(filepath: str, quiet: bool):
"""Takes a path to an audio file, returns a dict with the loudness
stats computed by the ffmpeg ebur128 filter.
Parameters
----------
filepath : str
Path to compute loudness stats on.
quiet : bool
Whether to show FFMPEG output during computation.
Returns
-------
dict
Dictionary containing loudness stats.
"""
ffargs = [
"ffmpeg",
"-nostats",
"-i",
filepath,
"-filter_complex",
"ebur128",
"-f",
"null",
"-",
]
if quiet:
ffargs += ["-hide_banner"]
proc = subprocess.Popen(
ffargs, stderr=subprocess.PIPE, universal_newlines=True)
stats = proc.communicate()[1]
summary_index = stats.rfind("Summary:")
summary_list = stats[summary_index:].split()
i_lufs = float(summary_list[summary_list.index("I:") + 1])
i_thresh = float(summary_list[summary_list.index("I:") + 4])
lra = float(summary_list[summary_list.index("LRA:") + 1])
lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
lra_low = float(summary_list[summary_list.index("low:") + 1])
lra_high = float(summary_list[summary_list.index("high:") + 1])
stats_dict = {
"I": i_lufs,
"I Threshold": i_thresh,
"LRA": lra,
"LRA Threshold": lra_thresh,
"LRA Low": lra_low,
"LRA High": lra_high,
}
return stats_dict
def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
"""Given a path to a file, returns the start time offset and codec of
the first audio stream.
"""
ff = ffmpy.FFprobe(
inputs={path: None},
global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
)
streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
seconds_offset = 0.0
codec = None
# Get the offset and codec of the first audio stream we find
# and return its start time, if it has one.
for stream in streams:
if stream["codec_type"] == "audio":
seconds_offset = stream.get("start_time", 0.0)
codec = stream.get("codec_name")
break
return float(seconds_offset), codec
class FFMPEGMixin:
_loudness = None
def ffmpeg_loudness(self, quiet: bool=True):
"""Computes loudness of audio file using FFMPEG.
Parameters
----------
quiet : bool, optional
Whether to show FFMPEG output during computation,
by default True
Returns
-------
paddle.Tensor
Loudness of every item in the batch, computed via
FFMPEG.
"""
loudness = []
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
for i in range(self.batch_size):
self[i].write(f.name)
loudness_stats = r128stats(f.name, quiet=quiet)
loudness.append(loudness_stats["I"])
self._loudness = paddle.to_tensor(np.array(loudness)).astype("float32")
return self.loudness()

@ -0,0 +1,387 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/loudness.py)
import copy
import math
import typing
import numpy as np
import paddle
import paddle.nn.functional as F
import scipy
from . import _julius
def _unfold1d(x, kernel_size, stride):
# https://github.com/PaddlePaddle/Paddle/pull/70102
"""1D only unfolding similar to the one from Paddlepaddle.
Given an _input tensor of size `[*, T]` this will return
a tensor `[*, F, K]` with `K` the kernel size, and `F` the number
of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`.
This will automatically pad the _input to cover at least once all entries in `_input`.
Args:
_input (Tensor): tensor for which to return the frames.
kernel_size (int): size of each frame.
stride (int): stride between each frame.
Shape:
- Inputs: `_input` is `[*, T]`
- Output: `[*, F, kernel_size]` with `F = 1 + ceil((T - kernel_size) / stride)`
"""
if 3 != x.dim():
raise NotImplementedError
N, C, length = x.shape
x = x.reshape([N * C, 1, length])
n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
tgt_length = (n_frames - 1) * stride + kernel_size
x = F.pad(x, (0, tgt_length - length), data_format="NCL")
x = x.unsqueeze(-1)
unfolded = paddle.nn.functional.unfold(
x,
kernel_sizes=[kernel_size, 1],
strides=[stride, 1], )
unfolded = unfolded.transpose([0, 2, 1])
unfolded = unfolded.reshape([N, C, *unfolded.shape[1:]])
return unfolded
class Meter(paddle.nn.Layer):
"""Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
Parameters
----------
rate : int
Sample rate of audio.
filter_class : str, optional
Class of weighting filter used.
K-weighting' (default), 'Fenton/Lee 1'
'Fenton/Lee 2', 'Dash et al.'
by default "K-weighting"
block_size : float, optional
Gating block size in seconds, by default 0.400
zeros : int, optional
Number of zeros to use in FIR approximation of
IIR filters, by default 512
use_fir : bool, optional
Whether to use FIR approximation or exact IIR formulation.
If computing on GPU, ``use_fir=True`` will be used, as its
much faster, by default False
"""
def __init__(
self,
rate: int,
filter_class: str="K-weighting",
block_size: float=0.400,
zeros: int=512,
use_fir: bool=False, ):
super().__init__()
self.rate = rate
self.filter_class = filter_class
self.block_size = block_size
self.use_fir = use_fir
G = paddle.to_tensor(
np.array([1.0, 1.0, 1.0, 1.41, 1.41]), stop_gradient=True)
self.register_buffer("G", G)
# Compute impulse responses so that filtering is fast via
# a convolution at runtime, on GPU, unlike lfilter.
impulse = np.zeros((zeros, ))
impulse[..., 0] = 1.0
firs = np.zeros((len(self._filters), 1, zeros))
# passband_gain = torch.zeros(len(self._filters))
passband_gain = paddle.zeros([len(self._filters)], dtype="float32")
for i, (_, filter_stage) in enumerate(self._filters.items()):
firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a,
impulse)
passband_gain[i] = filter_stage.passband_gain
firs = paddle.to_tensor(
firs[..., ::-1].copy(), dtype="float32", stop_gradient=True)
self.register_buffer("firs", firs)
self.register_buffer("passband_gain", passband_gain)
def apply_filter_gpu(self, data: paddle.Tensor):
"""Performs FIR approximation of loudness computation.
Parameters
----------
data : paddle.Tensor
Audio data of shape (nb, nch, nt).
Returns
-------
paddle.Tensor
Filtered audio data.
"""
# Data is of shape (nb, nch, nt)
# Reshape to (nb*nch, 1, nt)
nb, nt, nch = data.shape
data = data.transpose([0, 2, 1])
data = data.reshape([nb * nch, 1, nt])
# Apply padding
pad_length = self.firs.shape[-1]
# Apply filtering in sequence
for i in range(self.firs.shape[0]):
data = F.pad(data, (pad_length, pad_length), data_format="NCL")
data = _julius.fft_conv1d(data, self.firs[i, None, ...])
data = self.passband_gain[i] * data
data = data[..., 1:nt + 1]
data = data.transpose([0, 2, 1])
data = data[:, :nt, :]
return data
@staticmethod
def scipy_lfilter(waveform, a_coeffs, b_coeffs, clamp: bool=True):
# 使用 scipy.signal.lfilter 进行滤波(处理三维数据)
output = np.zeros_like(waveform)
for batch_idx in range(waveform.shape[0]):
for channel_idx in range(waveform.shape[2]):
output[batch_idx, :, channel_idx] = scipy.signal.lfilter(
b_coeffs, a_coeffs, waveform[batch_idx, :, channel_idx])
return output
def apply_filter_cpu(self, data: paddle.Tensor):
"""Performs IIR formulation of loudness computation.
Parameters
----------
data : paddle.Tensor
Audio data of shape (nb, nch, nt).
Returns
-------
paddle.Tensor
Filtered audio data.
"""
_data = data.cpu().numpy().copy()
for _, filter_stage in self._filters.items():
passband_gain = filter_stage.passband_gain
a_coeffs = filter_stage.a
b_coeffs = filter_stage.b
filtered = self.scipy_lfilter(_data, a_coeffs, b_coeffs)
_data[:] = passband_gain * filtered
data = paddle.to_tensor(_data)
return data
def apply_filter(self, data: paddle.Tensor):
"""Applies filter on either CPU or GPU, depending
on if the audio is on GPU or is on CPU, or if
``self.use_fir`` is True.
Parameters
----------
data : paddle.Tensor
Audio data of shape (nb, nch, nt).
Returns
-------
paddle.Tensor
Filtered audio data.
"""
# if data.place.is_gpu_place() or self.use_fir:
# data = self.apply_filter_gpu(data)
# else:
# data = self.apply_filter_cpu(data)
data = self.apply_filter_cpu(data)
return data
def forward(self, data: paddle.Tensor):
"""Computes integrated loudness of data.
Parameters
----------
data : paddle.Tensor
Audio data of shape (nb, nch, nt).
Returns
-------
paddle.Tensor
Filtered audio data.
"""
return self.integrated_loudness(data)
def _unfold(self, input_data):
T_g = self.block_size
overlap = 0.75 # overlap of 75% of the block duration
step = 1.0 - overlap # step size by percentage
kernel_size = int(T_g * self.rate)
stride = int(T_g * self.rate * step)
unfolded = _unfold1d(
input_data.transpose([0, 2, 1]), kernel_size, stride)
unfolded = unfolded.transpose([0, 1, 3, 2])
return unfolded
def integrated_loudness(self, data: paddle.Tensor):
"""Computes integrated loudness of data.
Parameters
----------
data : paddle.Tensor
Audio data of shape (nb, nch, nt).
Returns
-------
paddle.Tensor
Filtered audio data.
"""
if not paddle.is_tensor(data):
data = paddle.to_tensor(data, dtype="float32")
else:
data = data.astype("float32")
input_data = data.clone()
# Data always has a batch and channel dimension.
# Is of shape (nb, nt, nch)
if input_data.ndim < 2:
input_data = input_data.unsqueeze(-1)
if input_data.ndim < 3:
input_data = input_data.unsqueeze(0)
nb, nt, nch = input_data.shape
# Apply frequency weighting filters - account
# for the acoustic respose of the head and auditory system
input_data = self.apply_filter(input_data)
G = self.G # channel gains
T_g = self.block_size # 400 ms gating block standard
Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
unfolded = self._unfold(input_data)
z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
l = -0.691 + 10.0 * paddle.log10(
(G[None, :nch, None] * z).sum(1, keepdim=True))
l = l.expand_as(z)
# find gating block indices above absolute threshold
z_avg_gated = z
z_avg_gated[l <= Gamma_a] = 0
masked = l > Gamma_a
z_avg_gated = z_avg_gated.sum(2) / masked.sum(2).astype("float32")
# calculate the relative threshold value (see eq. 6)
Gamma_r = -0.691 + 10.0 * paddle.log10(
(z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
Gamma_r = Gamma_r[:, None, None]
Gamma_r = Gamma_r.expand([nb, nch, l.shape[-1]])
# find gating block indices above relative and absolute thresholds (end of eq. 7)
z_avg_gated = z
z_avg_gated[l <= Gamma_a] = 0
z_avg_gated[l <= Gamma_r] = 0
masked = (l > Gamma_a) * (l > Gamma_r)
z_avg_gated = z_avg_gated.sum(2) / (masked.sum(2) + 10e-6)
# TODO Currently, paddle has a segmentation fault bug in this section of the code
# z_avg_gated = paddle.nan_to_num(z_avg_gated)
# z_avg_gated = paddle.where(
# paddle.isnan(z_avg_gated),
# paddle.zeros_like(z_avg_gated), z_avg_gated)
z_avg_gated[z_avg_gated == float("inf")] = float(
np.finfo(np.float32).max)
z_avg_gated[z_avg_gated == -float("inf")] = float(
np.finfo(np.float32).min)
LUFS = -0.691 + 10.0 * paddle.log10(
(G[None, :nch] * z_avg_gated).sum(1))
return LUFS.astype("float32")
@property
def filter_class(self):
return self._filter_class
@filter_class.setter
def filter_class(self, value):
from pyloudnorm import Meter
meter = Meter(self.rate)
meter.filter_class = value
self._filter_class = value
self._filters = meter._filters
class LoudnessMixin:
_loudness = None
MIN_LOUDNESS = -70
"""Minimum loudness possible."""
def loudness(self,
filter_class: str="K-weighting",
block_size: float=0.400,
**kwargs):
"""Calculates loudness using an implementation of ITU-R BS.1770-4.
Allows control over gating block size and frequency weighting filters for
additional control. Measure the integrated gated loudness of a signal.
API is derived from PyLoudnorm, but this implementation is ported to PyTorch
and is tensorized across batches. When on GPU, an FIR approximation of the IIR
filters is used to compute loudness for speed.
Uses the weighting filters and block size defined by the meter
the integrated loudness is measured based upon the gating algorithm
defined in the ITU-R BS.1770-4 specification.
Parameters
----------
filter_class : str, optional
Class of weighting filter used.
K-weighting' (default), 'Fenton/Lee 1'
'Fenton/Lee 2', 'Dash et al.'
by default "K-weighting"
block_size : float, optional
Gating block size in seconds, by default 0.400
kwargs : dict, optional
Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
Returns
-------
paddle.Tensor
Loudness of audio data.
"""
if self._loudness is not None:
return self._loudness # .to(self.device)
original_length = self.signal_length
if self.signal_duration < 0.5:
pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
self.zero_pad(0, pad_len)
# create BS.1770 meter
meter = Meter(
self.sample_rate,
filter_class=filter_class,
block_size=block_size,
**kwargs)
# meter = meter.to(self.device)
# measure loudness
loudness = meter.integrated_loudness(
self.audio_data.transpose([0, 2, 1]))
self.truncate_samples(original_length)
min_loudness = paddle.ones_like(loudness) * self.MIN_LOUDNESS
self._loudness = paddle.maximum(loudness, min_loudness)
return self._loudness # .to(self.device)

@ -0,0 +1,921 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/util.py)
import collections
import csv
import glob
import math
import numbers
import os
import random
import typing
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union
import ffmpeg
import librosa
import numpy as np
import paddle
import soundfile
from flatten_dict import flatten
from flatten_dict import unflatten
from .audio_signal import AudioSignal
from paddlespeech.utils import satisfy_paddle_version
from paddlespeech.vector.training.seeding import seed_everything
__all__ = [
"exp_compat",
"bool_index_compat",
"bool_setitem_compat",
"Info",
"info",
"ensure_tensor",
"random_state",
"find_audio",
"read_sources",
"choose_from_list_of_lists",
"chdir",
"move_to_device",
"prepare_batch",
"sample_from_dist",
"format_figure",
"default_collate",
"collate",
"hz_to_bin",
"generate_chord_dataset",
]
def exp_compat(x):
"""
Compute the exponential of the input tensor `x`.
This function is designed to handle compatibility issues with PaddlePaddle versions below 2.6,
which do not support the `exp` operation for complex tensors. In such cases, the computation
is offloaded to NumPy.
Args:
x (paddle.Tensor): The input tensor for which to compute the exponential.
Returns:
paddle.Tensor: The result of the exponential operation, as a PaddlePaddle tensor.
Notes:
- If the PaddlePaddle version is 2.6 or above, the function uses `paddle.exp` directly.
- For versions below 2.6, the tensor is first converted to a NumPy array, the exponential
is computed using `np.exp`, and the result is then converted back to a PaddlePaddle tensor.
"""
if satisfy_paddle_version("2.6"):
return paddle.exp(x)
else:
x_np = x.cpu().numpy()
return paddle.to_tensor(np.exp(x_np))
def bool_index_compat(x, mask):
"""
Perform boolean indexing on the input tensor `x` using the provided `mask`.
This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean indexing
may not be fully supported. For older versions, the operation is performed using NumPy.
Args:
x (paddle.Tensor): The input tensor to be indexed.
mask (paddle.Tensor or int): The boolean mask or integer index used for indexing.
Returns:
paddle.Tensor: The result of the boolean indexing operation, as a PaddlePaddle tensor.
Notes:
- If the PaddlePaddle version is 2.6 or above, or if `mask` is an integer, the function uses
Paddle's native indexing directly.
- For versions below 2.6, the tensor and mask are converted to NumPy arrays, the indexing
operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor.
"""
if satisfy_paddle_version("2.6") or isinstance(mask, (int, list, slice)):
return x[mask]
else:
x_np = x.cpu().numpy()[mask.cpu().numpy()]
return paddle.to_tensor(x_np)
def bool_setitem_compat(x, mask, y):
"""
Perform boolean assignment on the input tensor `x` using the provided `mask` and values `y`.
This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean assignment
may not be fully supported. For older versions, the operation is performed using NumPy.
Args:
x (paddle.Tensor): The input tensor to be modified.
mask (paddle.Tensor): The boolean mask used for assignment.
y (paddle.Tensor): The values to assign to the selected elements of `x`.
Returns:
paddle.Tensor: The modified tensor after the assignment operation.
Notes:
- If the PaddlePaddle version is 2.6 or above, the function uses Paddle's native assignment directly.
- For versions below 2.6, the tensor, mask, and values are converted to NumPy arrays, the assignment
operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor.
"""
if satisfy_paddle_version("2.6"):
x[mask] = y
return x
else:
x_np = x.cpu().numpy()
x_np[mask.cpu().numpy()] = y.cpu().numpy()
return paddle.to_tensor(x_np)
@dataclass
class Info:
sample_rate: float
num_frames: int
@property
def duration(self) -> float:
return self.num_frames / self.sample_rate
def info_ffmpeg(audio_path: str):
"""
Parameters
----------
audio_path : str
Path to audio file.
"""
probe = ffmpeg.probe(audio_path)
audio_streams = [
stream for stream in probe['streams'] if stream['codec_type'] == 'audio'
]
if not audio_streams:
raise ValueError("No audio stream found in the file.")
audio_stream = audio_streams[0]
sample_rate = int(audio_stream['sample_rate'])
duration = float(audio_stream['duration'])
num_frames = int(duration * sample_rate)
info = Info(sample_rate=sample_rate, num_frames=num_frames)
return info
def info(audio_path: str):
"""
Parameters
----------
audio_path : str
Path to audio file.
"""
try:
info = soundfile.info(str(audio_path))
info = Info(sample_rate=info.samplerate, num_frames=info.frames)
except:
info = info_ffmpeg(str(audio_path))
return info
def ensure_tensor(
x: typing.Union[np.ndarray, paddle.Tensor, float, int],
ndim: int=None,
batch_size: int=None, ):
"""Ensures that the input ``x`` is a tensor of specified
dimensions and batch size.
Parameters
----------
x : typing.Union[np.ndarray, paddle.Tensor, float, int]
Data that will become a tensor on its way out.
ndim : int, optional
How many dimensions should be in the output, by default None
batch_size : int, optional
The batch size of the output, by default None
Returns
-------
paddle.Tensor
Modified version of ``x`` as a tensor.
"""
if not paddle.is_tensor(x):
x = paddle.to_tensor(x)
if ndim is not None:
assert x.ndim <= ndim
while x.ndim < ndim:
x = x.unsqueeze(-1)
if batch_size is not None:
if x.shape[0] != batch_size:
shape = list(x.shape)
shape[0] = batch_size
x = paddle.expand(x, shape)
return x
def _get_value(other):
#
from . import AudioSignal
if isinstance(other, AudioSignal):
return other.audio_data
return other
def random_state(seed: typing.Union[int, np.random.RandomState]):
"""
Turn seed into a np.random.RandomState instance.
Parameters
----------
seed : typing.Union[int, np.random.RandomState] or None
If seed is None, return the RandomState singleton used by np.random.
If seed is an int, return a new RandomState instance seeded with seed.
If seed is already a RandomState instance, return it.
Otherwise raise ValueError.
Returns
-------
np.random.RandomState
Random state object.
Raises
------
ValueError
If seed is not valid, an error is thrown.
"""
if seed is None or seed is np.random:
return np.random.mtrand._rand
elif isinstance(seed, (numbers.Integral, np.integer, int)):
return np.random.RandomState(seed)
elif isinstance(seed, np.random.RandomState):
return seed
else:
raise ValueError("%r cannot be used to seed a numpy.random.RandomState"
" instance" % seed)
@contextmanager
def _close_temp_files(tmpfiles: list):
"""Utility function for creating a context and closing all temporary files
once the context is exited. For correct functionality, all temporary file
handles created inside the context must be appended to the ```tmpfiles```
list.
This function is taken wholesale from Scaper.
Parameters
----------
tmpfiles : list
List of temporary file handles
"""
def _close():
for t in tmpfiles:
try:
t.close()
os.unlink(t.name)
except:
pass
try:
yield
except:
_close()
raise
_close()
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3"]
def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS):
"""Finds all audio files in a directory recursively.
Returns a list.
Parameters
----------
folder : str
Folder to look for audio files in, recursively.
ext : List[str], optional
Extensions to look for without the ., by default
``['.wav', '.flac', '.mp3', '.mp4']``.
"""
folder = Path(folder)
# Take care of case where user has passed in an audio file directly
# into one of the calling functions.
if str(folder).endswith(tuple(ext)):
# if, however, there's a glob in the path, we need to
# return the glob, not the file.
if "*" in str(folder):
return glob.glob(str(folder), recursive=("**" in str(folder)))
else:
return [folder]
files = []
for x in ext:
files += folder.glob(f"**/*{x}")
return files
def read_sources(
sources: List[str],
remove_empty: bool=True,
relative_path: str="",
ext: List[str]=AUDIO_EXTENSIONS, ):
"""Reads audio sources that can either be folders
full of audio files, or CSV files that contain paths
to audio files. CSV files that adhere to the expected
format can be generated by
:py:func:`audiotools.data.preprocess.create_csv`.
Parameters
----------
sources : List[str]
List of audio sources to be converted into a
list of lists of audio files.
remove_empty : bool, optional
Whether or not to remove rows with an empty "path"
from each CSV file, by default True.
Returns
-------
list
List of lists of rows of CSV files.
"""
files = []
relative_path = Path(relative_path)
for source in sources:
source = str(source)
_files = []
if source.endswith(".csv"):
with open(source, "r") as f:
reader = csv.DictReader(f)
for x in reader:
if remove_empty and x["path"] == "":
continue
if x["path"] != "":
x["path"] = str(relative_path / x["path"])
_files.append(x)
else:
for x in find_audio(source, ext=ext):
x = str(relative_path / x)
_files.append({"path": x})
files.append(sorted(_files, key=lambda x: x["path"]))
return files
def choose_from_list_of_lists(state: np.random.RandomState,
list_of_lists: list,
p: float=None):
"""Choose a single item from a list of lists.
Parameters
----------
state : np.random.RandomState
Random state to use when choosing an item.
list_of_lists : list
A list of lists from which items will be drawn.
p : float, optional
Probabilities of each list, by default None
Returns
-------
typing.Any
An item from the list of lists.
"""
source_idx = state.choice(list(range(len(list_of_lists))), p=p)
item_idx = state.randint(len(list_of_lists[source_idx]))
return list_of_lists[source_idx][item_idx], source_idx, item_idx
@contextmanager
def chdir(newdir: typing.Union[Path, str]):
"""
Context manager for switching directories to run a
function. Useful for when you want to use relative
paths to different runs.
Parameters
----------
newdir : typing.Union[Path, str]
Directory to switch to.
"""
curdir = os.getcwd()
try:
os.chdir(newdir)
yield
finally:
os.chdir(curdir)
def move_to_device(data, device):
if device is None or device == "":
return data
elif device == 'cpu':
return paddle.to_tensor(data, place=paddle.CPUPlace())
elif device in ('gpu', 'cuda'):
return paddle.to_tensor(data, place=paddle.CUDAPlace())
else:
device = device.replace("cuda", "gpu") if "cuda" in device else device
return data.to(device)
def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor],
device: str="cpu"):
"""Moves items in a batch (typically generated by a DataLoader as a list
or a dict) to the specified device. This works even if dictionaries
are nested.
Parameters
----------
batch : typing.Union[dict, list, paddle.Tensor]
Batch, typically generated by a dataloader, that will be moved to
the device.
device : str, optional
Device to move batch to, by default "cpu"
Returns
-------
typing.Union[dict, list, paddle.Tensor]
Batch with all values moved to the specified device.
"""
device = device.replace("cuda", "gpu")
if isinstance(batch, dict):
batch = flatten(batch)
for key, val in batch.items():
try:
# batch[key] = val.to(device)
batch[key] = move_to_device(val, device)
except:
pass
batch = unflatten(batch)
elif paddle.is_tensor(batch):
# batch = batch.to(device)
batch = move_to_device(batch, device)
elif isinstance(batch, list):
for i in range(len(batch)):
try:
batch[i] = batch[i].to(device)
except:
pass
return batch
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None):
"""Samples from a distribution defined by a tuple. The first
item in the tuple is the distribution type, and the rest of the
items are arguments to that distribution. The distribution function
is gotten from the ``np.random.RandomState`` object.
Parameters
----------
dist_tuple : tuple
Distribution tuple
state : np.random.RandomState, optional
Random state, or seed to use, by default None
Returns
-------
typing.Union[float, int, str]
Draw from the distribution.
Examples
--------
Sample from a uniform distribution:
>>> dist_tuple = ("uniform", 0, 1)
>>> sample_from_dist(dist_tuple)
Sample from a constant distribution:
>>> dist_tuple = ("const", 0)
>>> sample_from_dist(dist_tuple)
Sample from a normal distribution:
>>> dist_tuple = ("normal", 0, 0.5)
>>> sample_from_dist(dist_tuple)
"""
if dist_tuple[0] == "const":
return dist_tuple[1]
state = random_state(state)
dist_fn = getattr(state, dist_tuple[0])
return dist_fn(*dist_tuple[1:])
BASE_SIZE = 864
DEFAULT_FIG_SIZE = (9, 3)
def format_figure(
fig_size: tuple=None,
title: str=None,
fig=None,
format_axes: bool=True,
format: bool=True,
font_color: str="white", ):
"""Prettifies the spectrogram and waveform plots. A title
can be inset into the top right corner, and the axes can be
inset into the figure, allowing the data to take up the entire
image. Used in
- :py:func:`audiotools.core.display.DisplayMixin.specshow`
- :py:func:`audiotools.core.display.DisplayMixin.waveplot`
- :py:func:`audiotools.core.display.DisplayMixin.wavespec`
Parameters
----------
fig_size : tuple, optional
Size of figure, by default (9, 3)
title : str, optional
Title to inset in top right, by default None
fig : matplotlib.figure.Figure, optional
Figure object, if None ``plt.gcf()`` will be used, by default None
format_axes : bool, optional
Format the axes to be inside the figure, by default True
format : bool, optional
This formatting can be skipped entirely by passing ``format=False``
to any of the plotting functions that use this formater, by default True
font_color : str, optional
Color of font of axes, by default "white"
"""
import matplotlib
import matplotlib.pyplot as plt
if fig_size is None:
fig_size = DEFAULT_FIG_SIZE
if not format:
return
if fig is None:
fig = plt.gcf()
fig.set_size_inches(*fig_size)
axs = fig.axes
pixels = (fig.get_size_inches() * fig.dpi)[0]
font_scale = pixels / BASE_SIZE
if format_axes:
axs = fig.axes
for ax in axs:
ymin, _ = ax.get_ylim()
xmin, _ = ax.get_xlim()
ticks = ax.get_yticks()
for t in ticks[2:-1]:
t = axs[0].annotate(
f"{(t / 1000):2.1f}k",
xy=(xmin, t),
xycoords="data",
xytext=(5, -5),
textcoords="offset points",
ha="left",
va="top",
color=font_color,
fontsize=12 * font_scale,
alpha=0.75, )
ticks = ax.get_xticks()[2:]
for t in ticks[:-1]:
t = axs[0].annotate(
f"{t:2.1f}s",
xy=(t, ymin),
xycoords="data",
xytext=(5, 5),
textcoords="offset points",
ha="center",
va="bottom",
color=font_color,
fontsize=12 * font_scale,
alpha=0.75, )
ax.margins(0, 0)
ax.set_axis_off()
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())
plt.subplots_adjust(
top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
if title is not None:
t = axs[0].annotate(
title,
xy=(1, 1),
xycoords="axes fraction",
fontsize=20 * font_scale,
xytext=(-5, -5),
textcoords="offset points",
ha="right",
va="top",
color="white", )
t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
_default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
def collate_tensor_fn(
batch,
*,
collate_fn_map: Optional[Dict[Union[type, Tuple[type, ...]],
Callable]]=None, ):
out = paddle.stack(batch, axis=0)
return out
def collate_float_fn(
batch,
*,
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
Callable]]=None, ):
return paddle.to_tensor(batch, dtype=paddle.float64)
def collate_int_fn(
batch,
*,
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
Callable]]=None, ):
return paddle.to_tensor(batch)
def collate_str_fn(
batch,
*,
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
Callable]]=None, ):
return batch
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
paddle.Tensor: collate_tensor_fn
}
default_collate_fn_map[float] = collate_float_fn
default_collate_fn_map[int] = collate_int_fn
default_collate_fn_map[str] = collate_str_fn
default_collate_fn_map[bytes] = collate_str_fn
def default_collate(batch,
*,
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
Callable]]=None):
r"""
General collate function that handles collection type of element within each batch.
The function also opens function registry to deal with specific element types. `default_collate_fn_map`
provides default collate functions for tensors, numpy arrays, numbers and strings.
Args:
batch: a single batch to be collated
collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
If the element type isn't present in this dictionary,
this function will go through each key of the dictionary in the insertion order to
invoke the corresponding collate function if the element type is a subclass of the key.
Note:
Each collate function requires a positional argument for batch and a keyword argument
for the dictionary of collate functions as `collate_fn_map`.
"""
elem = batch[0]
elem_type = type(elem)
if collate_fn_map is not None:
if elem_type in collate_fn_map:
return collate_fn_map[elem_type](
batch, collate_fn_map=collate_fn_map)
for collate_type in collate_fn_map:
if isinstance(elem, collate_type):
return collate_fn_map[collate_type](
batch, collate_fn_map=collate_fn_map)
if isinstance(elem, collections.abc.Mapping):
try:
return elem_type({
key: default_collate(
[d[key] for d in batch], collate_fn_map=collate_fn_map)
for key in elem
})
except TypeError:
# The mapping type may not support `__init__(iterable)`.
return {
key: default_collate(
[d[key] for d in batch], collate_fn_map=collate_fn_map)
for key in elem
}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
return elem_type(*(default_collate(
samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError(
"each element in list of batch should be of equal size")
transposed = list(zip(
*batch)) # It may be accessed twice, so we use a list.
if isinstance(elem, tuple):
return [
default_collate(samples, collate_fn_map=collate_fn_map)
for samples in transposed
] # Backwards compatibility.
else:
try:
return elem_type([
default_collate(samples, collate_fn_map=collate_fn_map)
for samples in transposed
])
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return [
default_collate(samples, collate_fn_map=collate_fn_map)
for samples in transposed
]
raise TypeError(_default_collate_err_msg_format.format(elem_type))
def collate(list_of_dicts: list, n_splits: int=None):
"""Collates a list of dictionaries (e.g. as returned by a
dataloader) into a dictionary with batched values. This routine
uses the default torch collate function for everything
except AudioSignal objects, which are handled by the
:py:func:`audiotools.core.audio_signal.AudioSignal.batch`
function.
This function takes n_splits to enable splitting a batch
into multiple sub-batches for the purposes of gradient accumulation,
etc.
Parameters
----------
list_of_dicts : list
List of dictionaries to be collated.
n_splits : int
Number of splits to make when creating the batches (split into
sub-batches). Useful for things like gradient accumulation.
Returns
-------
dict
Dictionary containing batched data.
"""
batches = []
list_len = len(list_of_dicts)
return_list = False if n_splits is None else True
n_splits = 1 if n_splits is None else n_splits
n_items = int(math.ceil(list_len / n_splits))
for i in range(0, list_len, n_items):
# Flatten the dictionaries to avoid recursion.
list_of_dicts_ = [flatten(d) for d in list_of_dicts[i:i + n_items]]
dict_of_lists = {
k: [dic[k] for dic in list_of_dicts_]
for k in list_of_dicts_[0]
}
batch = {}
for k, v in dict_of_lists.items():
if isinstance(v, list):
if all(isinstance(s, AudioSignal) for s in v):
batch[k] = AudioSignal.batch(v, pad_signals=True)
else:
batch[k] = default_collate(
v, collate_fn_map=default_collate_fn_map)
batches.append(unflatten(batch))
batches = batches[0] if not return_list else batches
return batches
def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int):
"""Closest frequency bin given a frequency, number
of bins, and a sampling rate.
Parameters
----------
hz : paddle.Tensor
Tensor of frequencies in Hz.
n_fft : int
Number of FFT bins.
sample_rate : int
Sample rate of audio.
Returns
-------
paddle.Tensor
Closest bins to the data.
"""
shape = hz.shape
hz = hz.reshape([-1])
freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2)
hz = paddle.clip(hz, max=sample_rate / 2).astype(freqs.dtype)
closest = (hz[None, :] - freqs[:, None]).abs()
closest_bins = closest.argmin(axis=0)
return closest_bins.reshape(shape)
def generate_chord_dataset(
max_voices: int=8,
sample_rate: int=44100,
num_items: int=5,
duration: float=1.0,
min_note: str="C2",
max_note: str="C6",
output_dir: Path="chords", ):
"""
Generates a toy multitrack dataset of chords, synthesized from sine waves.
Parameters
----------
max_voices : int, optional
Maximum number of voices in a chord, by default 8
sample_rate : int, optional
Sample rate of audio, by default 44100
num_items : int, optional
Number of items to generate, by default 5
duration : float, optional
Duration of each item, by default 1.0
min_note : str, optional
Minimum note in the dataset, by default "C2"
max_note : str, optional
Maximum note in the dataset, by default "C6"
output_dir : Path, optional
Directory to save the dataset, by default "chords"
"""
import librosa
from . import AudioSignal
from ..data.preprocess import create_csv
min_midi = librosa.note_to_midi(min_note)
max_midi = librosa.note_to_midi(max_note)
tracks = []
for idx in range(num_items):
track = {}
# figure out how many voices to put in this track
num_voices = random.randint(1, max_voices)
for voice_idx in range(num_voices):
# choose some random params
midinote = random.randint(min_midi, max_midi)
dur = random.uniform(0.85 * duration, duration)
sig = AudioSignal.wave(
frequency=librosa.midi_to_hz(midinote),
duration=dur,
sample_rate=sample_rate,
shape="sine", )
track[f"voice_{voice_idx}"] = sig
tracks.append(track)
# save the tracks to disk
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
for idx, track in enumerate(tracks):
track_dir = output_dir / f"track_{idx}"
track_dir.mkdir(exist_ok=True)
for voice_name, sig in track.items():
sig.write(track_dir / f"{voice_name}.wav")
all_voices = list(set([k for track in tracks for k in track.keys()]))
voice_lists = {voice: [] for voice in all_voices}
for track in tracks:
for voice_name in all_voices:
if voice_name in track:
voice_lists[voice_name].append(track[voice_name].path_to_file)
else:
voice_lists[voice_name].append("")
for voice_name, paths in voice_lists.items():
create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
return output_dir

@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import datasets
from . import preprocess
from . import transforms

@ -0,0 +1,548 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/data/datasets.py)
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import List
from typing import Union
import numpy as np
import paddle
from paddle.io import DistributedBatchSampler
from paddle.io import SequenceSampler
from ..core import AudioSignal
from ..core import util
__all__ = [
"AudioLoader", "AudioDataset", "ConcatDataset",
"ResumableDistributedSampler", "ResumableSequentialSampler"
]
class AudioLoader:
"""Loads audio endlessly from a list of audio sources
containing paths to audio files. Audio sources can be
folders full of audio files (which are found via file
extension) or by providing a CSV file which contains paths
to audio files.
Parameters
----------
sources : List[str], optional
Sources containing folders, or CSVs with
paths to audio files, by default None
weights : List[float], optional
Weights to sample audio files from each source, by default None
relative_path : str, optional
Path audio should be loaded relative to, by default ""
transform : Callable, optional
Transform to instantiate alongside audio sample,
by default None
ext : List[str]
List of extensions to find audio within each source by. Can
also be a file name (e.g. "vocals.wav"). by default
``['.wav', '.flac', '.mp3', '.mp4']``.
shuffle: bool
Whether to shuffle the files within the dataloader. Defaults to True.
shuffle_state: int
State to use to seed the shuffle of the files.
"""
def __init__(
self,
sources: List[str]=None,
weights: List[float]=None,
transform: Callable=None,
relative_path: str="",
ext: List[str]=util.AUDIO_EXTENSIONS,
shuffle: bool=True,
shuffle_state: int=0, ):
self.audio_lists = util.read_sources(
sources, relative_path=relative_path, ext=ext)
self.audio_indices = [(src_idx, item_idx)
for src_idx, src in enumerate(self.audio_lists)
for item_idx in range(len(src))]
if shuffle:
state = util.random_state(shuffle_state)
state.shuffle(self.audio_indices)
self.sources = sources
self.weights = weights
self.transform = transform
def __call__(
self,
state,
sample_rate: int,
duration: float,
loudness_cutoff: float=-40,
num_channels: int=1,
offset: float=None,
source_idx: int=None,
item_idx: int=None,
global_idx: int=None, ):
if source_idx is not None and item_idx is not None:
try:
audio_info = self.audio_lists[source_idx][item_idx]
except:
audio_info = {"path": "none"}
elif global_idx is not None:
source_idx, item_idx = self.audio_indices[global_idx %
len(self.audio_indices)]
audio_info = self.audio_lists[source_idx][item_idx]
else:
audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
state, self.audio_lists, p=self.weights)
path = audio_info["path"]
signal = AudioSignal.zeros(duration, sample_rate, num_channels)
if path != "none":
if offset is None:
signal = AudioSignal.salient_excerpt(
path,
duration=duration,
state=state,
loudness_cutoff=loudness_cutoff, )
else:
signal = AudioSignal(
path,
offset=offset,
duration=duration, )
if num_channels == 1:
signal = signal.to_mono()
signal = signal.resample(sample_rate)
if signal.duration < duration:
signal = signal.zero_pad_to(int(duration * sample_rate))
for k, v in audio_info.items():
signal.metadata[k] = v
item = {
"signal": signal,
"source_idx": source_idx,
"item_idx": item_idx,
"source": str(self.sources[source_idx]),
"path": str(path),
}
if self.transform is not None:
item["transform_args"] = self.transform.instantiate(
state, signal=signal)
return item
def default_matcher(x, y):
return Path(x).parent == Path(y).parent
def align_lists(lists, matcher: Callable=default_matcher):
longest_list = lists[np.argmax([len(l) for l in lists])]
for i, x in enumerate(longest_list):
for l in lists:
if i >= len(l):
l.append({"path": "none"})
elif not matcher(l[i]["path"], x["path"]):
l.insert(i, {"path": "none"})
return lists
class AudioDataset:
"""Loads audio from multiple loaders (with associated transforms)
for a specified number of samples. Excerpts are drawn randomly
of the specified duration, above a specified loudness threshold
and are resampled on the fly to the desired sample rate
(if it is different from the audio source sample rate).
This takes either a single AudioLoader object,
a dictionary of AudioLoader objects, or a dictionary of AudioLoader
objects. Each AudioLoader is called by the dataset, and the
result is placed in the output dictionary. A transform can also be
specified for the entire dataset, rather than for each specific
loader. This transform can be applied to the output of all the
loaders if desired.
AudioLoader objects can be specified as aligned, which means the
loaders correspond to multitrack audio (e.g. a vocals, bass,
drums, and other loader for multitrack music mixtures).
Parameters
----------
loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
AudioLoaders to sample audio from.
sample_rate : int
Desired sample rate.
n_examples : int, optional
Number of examples (length of dataset), by default 1000
duration : float, optional
Duration of audio samples, by default 0.5
loudness_cutoff : float, optional
Loudness cutoff threshold for audio samples, by default -40
num_channels : int, optional
Number of channels in output audio, by default 1
transform : Callable, optional
Transform to instantiate alongside each dataset item, by default None
aligned : bool, optional
Whether the loaders should be sampled in an aligned manner (e.g. same
offset, duration, and matched file name), by default False
shuffle_loaders : bool, optional
Whether to shuffle the loaders before sampling from them, by default False
matcher : Callable
How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
by default uses the parent directory of each file.
without_replacement : bool
Whether to choose files with or without replacement, by default True.
Examples
--------
>>> from audio.audiotools.data.datasets import AudioLoader
>>> from audio.audiotools.data.datasets import AudioDataset
>>> from audio.audiotools import transforms as tfm
>>> import numpy as np
>>>
>>> loaders = [
>>> AudioLoader(
>>> sources=[f"tests/audiotools/audio/spk"],
>>> transform=tfm.Equalizer(),
>>> ext=["wav"],
>>> )
>>> for i in range(5)
>>> ]
>>>
>>> dataset = AudioDataset(
>>> loaders = loaders,
>>> sample_rate = 44100,
>>> duration = 1.0,
>>> transform = tfm.RescaleAudio(),
>>> )
>>>
>>> item = dataset[np.random.randint(len(dataset))]
>>>
>>> for i in range(len(loaders)):
>>> item[i]["signal"] = loaders[i].transform(
>>> item[i]["signal"], **item[i]["transform_args"]
>>> )
>>> item[i]["signal"].widget(i)
>>>
>>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
>>> mix = dataset.transform(mix, **item["transform_args"])
>>> mix.widget("mix")
Below is an example of how one could load MUSDB multitrack data:
>>> from audio import audiotools as at
>>> from pathlib import Path
>>> from audio.audiotools import transforms as tfm
>>> import numpy as np
>>> import torch
>>>
>>> def build_dataset(
>>> sample_rate: int = 44100,
>>> duration: float = 5.0,
>>> musdb_path: str = "~/.data/musdb/",
>>> ):
>>> musdb_path = Path(musdb_path).expanduser()
>>> loaders = {
>>> src: at.datasets.AudioLoader(
>>> sources=[musdb_path],
>>> transform=tfm.Compose(
>>> tfm.VolumeNorm(("uniform", -20, -10)),
>>> tfm.Silence(prob=0.1),
>>> ),
>>> ext=[f"{src}.wav"],
>>> )
>>> for src in ["vocals", "bass", "drums", "other"]
>>> }
>>>
>>> dataset = at.datasets.AudioDataset(
>>> loaders=loaders,
>>> sample_rate=sample_rate,
>>> duration=duration,
>>> num_channels=1,
>>> aligned=True,
>>> transform=tfm.RescaleAudio(),
>>> shuffle_loaders=True,
>>> )
>>> return dataset, list(loaders.keys())
>>>
>>> train_data, sources = build_dataset()
>>> dataloader = torch.utils.data.DataLoader(
>>> train_data,
>>> batch_size=16,
>>> num_workers=0,
>>> collate_fn=train_data.collate,
>>> )
>>> batch = next(iter(dataloader))
>>>
>>> for k in sources:
>>> src = batch[k]
>>> src["transformed"] = train_data.loaders[k].transform(
>>> src["signal"].clone(), **src["transform_args"]
>>> )
>>>
>>> mixture = sum(batch[k]["transformed"] for k in sources)
>>> mixture = train_data.transform(mixture, **batch["transform_args"])
>>>
>>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
>>> # Construct the targets:
>>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
Similarly, here's example code for loading Slakh data:
>>> from audio import audiotools as at
>>> from pathlib import Path
>>> from audio.audiotools import transforms as tfm
>>> import numpy as np
>>> import torch
>>> import glob
>>>
>>> def build_dataset(
>>> sample_rate: int = 16000,
>>> duration: float = 10.0,
>>> slakh_path: str = "~/.data/slakh/",
>>> ):
>>> slakh_path = Path(slakh_path).expanduser()
>>>
>>> # Find the max number of sources in Slakh
>>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
>>> n_sources = len(list(set(src_names)))
>>>
>>> loaders = {
>>> f"S{i:02d}": at.datasets.AudioLoader(
>>> sources=[slakh_path],
>>> transform=tfm.Compose(
>>> tfm.VolumeNorm(("uniform", -20, -10)),
>>> tfm.Silence(prob=0.1),
>>> ),
>>> ext=[f"S{i:02d}.wav"],
>>> )
>>> for i in range(n_sources)
>>> }
>>> dataset = at.datasets.AudioDataset(
>>> loaders=loaders,
>>> sample_rate=sample_rate,
>>> duration=duration,
>>> num_channels=1,
>>> aligned=True,
>>> transform=tfm.RescaleAudio(),
>>> shuffle_loaders=False,
>>> )
>>>
>>> return dataset, list(loaders.keys())
>>>
>>> train_data, sources = build_dataset()
>>> dataloader = torch.utils.data.DataLoader(
>>> train_data,
>>> batch_size=16,
>>> num_workers=0,
>>> collate_fn=train_data.collate,
>>> )
>>> batch = next(iter(dataloader))
>>>
>>> for k in sources:
>>> src = batch[k]
>>> src["transformed"] = train_data.loaders[k].transform(
>>> src["signal"].clone(), **src["transform_args"]
>>> )
>>>
>>> mixture = sum(batch[k]["transformed"] for k in sources)
>>> mixture = train_data.transform(mixture, **batch["transform_args"])
"""
def __init__(
self,
loaders: Union[AudioLoader, List[AudioLoader], Dict[str,
AudioLoader]],
sample_rate: int,
n_examples: int=1000,
duration: float=0.5,
offset: float=None,
loudness_cutoff: float=-40,
num_channels: int=1,
transform: Callable=None,
aligned: bool=False,
shuffle_loaders: bool=False,
matcher: Callable=default_matcher,
without_replacement: bool=True, ):
# Internally we convert loaders to a dictionary
if isinstance(loaders, list):
loaders = {i: l for i, l in enumerate(loaders)}
elif isinstance(loaders, AudioLoader):
loaders = {0: loaders}
self.loaders = loaders
self.loudness_cutoff = loudness_cutoff
self.num_channels = num_channels
self.length = n_examples
self.transform = transform
self.sample_rate = sample_rate
self.duration = duration
self.offset = offset
self.aligned = aligned
self.shuffle_loaders = shuffle_loaders
self.without_replacement = without_replacement
if aligned:
loaders_list = list(loaders.values())
for i in range(len(loaders_list[0].audio_lists)):
input_lists = [l.audio_lists[i] for l in loaders_list]
# Alignment happens in-place
align_lists(input_lists, matcher)
def __getitem__(self, idx):
state = util.random_state(idx)
offset = None if self.offset is None else self.offset
item = {}
keys = list(self.loaders.keys())
if self.shuffle_loaders:
state.shuffle(keys)
loader_kwargs = {
"state": state,
"sample_rate": self.sample_rate,
"duration": self.duration,
"loudness_cutoff": self.loudness_cutoff,
"num_channels": self.num_channels,
"global_idx": idx if self.without_replacement else None,
}
# Draw item from first loader
loader = self.loaders[keys[0]]
item[keys[0]] = loader(**loader_kwargs)
for key in keys[1:]:
loader = self.loaders[key]
if self.aligned:
# Path mapper takes the current loader + everything
# returned by the first loader.
offset = item[keys[0]]["signal"].metadata["offset"]
loader_kwargs.update({
"offset": offset,
"source_idx": item[keys[0]]["source_idx"],
"item_idx": item[keys[0]]["item_idx"],
})
item[key] = loader(**loader_kwargs)
# Sort dictionary back into original order
keys = list(self.loaders.keys())
item = {k: item[k] for k in keys}
item["idx"] = idx
if self.transform is not None:
item["transform_args"] = self.transform.instantiate(
state=state, signal=item[keys[0]]["signal"])
# If there's only one loader, pop it up
# to the main dictionary, instead of keeping it
# nested.
if len(keys) == 1:
item.update(item.pop(keys[0]))
return item
def __len__(self):
return self.length
@staticmethod
def collate(list_of_dicts: Union[list, dict], n_splits: int=None):
"""Collates items drawn from this dataset. Uses
:py:func:`audiotools.core.util.collate`.
Parameters
----------
list_of_dicts : typing.Union[list, dict]
Data drawn from each item.
n_splits : int
Number of splits to make when creating the batches (split into
sub-batches). Useful for things like gradient accumulation.
Returns
-------
dict
Dictionary of batched data.
"""
return util.collate(list_of_dicts, n_splits=n_splits)
class ConcatDataset(AudioDataset):
#
def __init__(self, datasets: list):
self.datasets = datasets
def __len__(self):
return sum([len(d) for d in self.datasets])
def __getitem__(self, idx):
dataset = self.datasets[idx % len(self.datasets)]
return dataset[idx // len(self.datasets)]
class ResumableDistributedSampler(DistributedBatchSampler):
"""Distributed sampler that can be resumed from a given start index."""
def __init__(self,
dataset,
batch_size,
start_idx: int=None,
num_replicas=None,
rank=None,
shuffle=False,
drop_last=False):
super().__init__(
dataset=dataset,
batch_size=batch_size,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
drop_last=drop_last, )
# Start index, allows to resume an experiment at the index it was
if start_idx is not None:
self.start_idx = start_idx // self.num_replicas
else:
self.start_idx = 0
# 重新计算样本总数,因为 DistributedBatchSampler 的 __len__ 方法是基于 shuffle 后的样本总数计算的
self.total_size = len(self.dataset) if not shuffle else len(
self.indices)
def __iter__(self):
# 由于 Paddle 的 DistributedBatchSampler 直接返回 batch我们需要将其展开为单个索引
indices_iter = iter(super().__iter__())
# 跳过前面的 start_idx 个 batch
for _ in range(self.start_idx):
next(indices_iter)
current_idx = 0
while True:
batch_indices = next(indices_iter, None)
if batch_indices is None:
break
for idx in batch_indices:
if current_idx >= self.start_idx * self.batch_size: # 调整判断条件,确保从 start_idx 开始
yield idx
current_idx += 1
self.start_idx = 0 # set the index back to 0 so for the next epoch
class ResumableSequentialSampler(SequenceSampler):
"""Sequential sampler that can be resumed from a given start index."""
def __init__(self, dataset, start_idx: int=None, **kwargs):
super().__init__(dataset, **kwargs)
# Start index, allows to resume an experiment at the index it was
self.start_idx = start_idx if start_idx is not None else 0
def __iter__(self):
for i, idx in enumerate(super().__iter__()):
if i >= self.start_idx:
yield idx
self.start_idx = 0 # set the index back to 0 so for the next epoch

@ -0,0 +1,87 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/data/preprocess.py)
import csv
import os
from pathlib import Path
from tqdm import tqdm
from ..core import AudioSignal
def create_csv(audio_files: list,
output_csv: Path,
loudness: bool=False,
data_path: str=None):
"""Converts a folder of audio files to a CSV file. If ``loudness = True``,
the output of this function will create a CSV file that looks something
like:
.. csv-table::
:header: path,loudness
daps/produced/f1_script1_produced.wav,-16.299999237060547
daps/produced/f1_script2_produced.wav,-16.600000381469727
daps/produced/f1_script3_produced.wav,-17.299999237060547
daps/produced/f1_script4_produced.wav,-16.100000381469727
daps/produced/f1_script5_produced.wav,-16.700000762939453
daps/produced/f3_script1_produced.wav,-16.5
.. note::
The paths above are written relative to the ``data_path`` argument
which defaults to the environment variable ``PATH_TO_DATA`` if
it isn't passed to this function, and defaults to the empty string
if that environment variable is not set.
You can produce a CSV file from a directory of audio files via:
>>> from audio import audiotools
>>> directory = ...
>>> audio_files = audiotools.util.find_audio(directory)
>>> output_path = "train.csv"
>>> audiotools.data.preprocess.create_csv(
>>> audio_files, output_csv, loudness=True
>>> )
Note that you can create empty rows in the CSV file by passing an empty
string or None in the ``audio_files`` list. This is useful if you want to
sync multiple CSV files in a multitrack setting. The loudness of these
empty rows will be set to -inf.
Parameters
----------
audio_files : list
List of audio files.
output_csv : Path
Output CSV, with each row containing the relative path of every file
to ``data_path``, if specified (defaults to None).
loudness : bool
Compute loudness of entire file and store alongside path.
"""
info = []
pbar = tqdm(audio_files)
for af in pbar:
af = Path(af)
pbar.set_description(f"Processing {af.name}")
_info = {}
if af.name == "":
_info["path"] = ""
if loudness:
_info["loudness"] = -float("inf")
else:
_info["path"] = af.relative_to(
data_path) if data_path is not None else af
if loudness:
_info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
info.append(_info)
with open(output_csv, "w") as f:
writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
writer.writeheader()
for item in info:
writer.writerow(item)

File diff suppressed because it is too large Load Diff

@ -0,0 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Functions for comparing AudioSignal objects to one another.
"""
from . import quality

@ -0,0 +1,74 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/quality.py)
import os
import numpy as np
import paddle
from ..core import AudioSignal
def visqol(
estimates: AudioSignal,
references: AudioSignal,
mode: str="audio", ):
"""ViSQOL score.
Parameters
----------
estimates : AudioSignal
Degraded AudioSignal
references : AudioSignal
Reference AudioSignal
mode : str, optional
'audio' or 'speech', by default 'audio'
Returns
-------
Tensor[float]
ViSQOL score (MOS-LQO)
"""
try:
from pyvisqol import visqol_lib_py
from pyvisqol.pb2 import visqol_config_pb2
from pyvisqol.pb2 import similarity_result_pb2
except ImportError:
from visqol import visqol_lib_py
from visqol.pb2 import visqol_config_pb2
from visqol.pb2 import similarity_result_pb2
config = visqol_config_pb2.VisqolConfig()
if mode == "audio":
target_sr = 48000
config.options.use_speech_scoring = False
svr_model_path = "libsvm_nu_svr_model.txt"
elif mode == "speech":
target_sr = 16000
config.options.use_speech_scoring = True
svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
else:
raise ValueError(f"Unrecognized mode: {mode}")
config.audio.sample_rate = target_sr
config.options.svr_model_path = os.path.join(
os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path)
api = visqol_lib_py.VisqolApi()
api.Create(config)
estimates = estimates.clone().to_mono().resample(target_sr)
references = references.clone().to_mono().resample(target_sr)
visqols = []
for i in range(estimates.batch_size):
_visqol = api.Measure(
references.audio_data[i, 0].detach().cpu().numpy().astype(float),
estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), )
visqols.append(_visqol.moslqo)
return paddle.to_tensor(np.array(visqols))
if __name__ == "__main__":
signal = AudioSignal(paddle.randn([44100]), 44100)
print(visqol(signal, signal))

@ -0,0 +1,16 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import decorators
from .accelerator import Accelerator
from .basemodel import BaseModel

@ -0,0 +1,199 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/accelerator.py)
import os
import typing
import paddle
import paddle.distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.io import SequenceSampler
class ResumableDistributedSampler(DistributedBatchSampler):
"""Distributed sampler that can be resumed from a given start index."""
def __init__(self, dataset, start_idx: int=None, **kwargs):
super().__init__(dataset, **kwargs)
# Start index, allows to resume an experiment at the index it was
self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
def __iter__(self):
for i, idx in enumerate(super().__iter__()):
if i >= self.start_idx:
yield idx
self.start_idx = 0 # set the index back to 0 so for the next epoch
class ResumableSequentialSampler(SequenceSampler):
"""Sequential sampler that can be resumed from a given start index."""
def __init__(self, dataset, start_idx: int=None, **kwargs):
super().__init__(dataset, **kwargs)
# Start index, allows to resume an experiment at the index it was
self.start_idx = start_idx if start_idx is not None else 0
def __iter__(self):
for i, idx in enumerate(super().__iter__()):
if i >= self.start_idx:
yield idx
self.start_idx = 0 # set the index back to 0 so for the next epoch
class Accelerator:
"""This class is used to prepare models and dataloaders for
usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
prepare the respective objects. In the case of models, they are moved to
the appropriate GPU. In the case of
dataloaders, a sampler is created and the dataloader is initialized with
that sampler.
If the world size is 1, prepare_model and prepare_dataloader are
no-ops. If the environment variable ``PADDLE_TRAINER_ID`` is not set, then the
script was launched without ``paddle.distributed.launch``, and ``DataParallel``
will be used instead of ``DistributedDataParallel`` (not recommended), if
the world size (number of GPUs) is greater than 1.
Parameters
----------
amp : bool, optional
Whether or not to enable automatic mixed precision, by default False
(Note: This is a placeholder as PaddlePaddle doesn't have native support for AMP as of now)
"""
def __init__(self, amp: bool=False):
trainer_id = os.getenv("PADDLE_TRAINER_ID", None)
self.world_size = paddle.distributed.get_world_size()
self.use_ddp = self.world_size > 1 and trainer_id is not None
self.use_dp = self.world_size > 1 and trainer_id is None
self.device = "cpu" if self.world_size == 0 else "cuda"
if self.use_ddp:
trainer_id = int(trainer_id)
dist.init_parallel_env()
self.local_rank = 0 if trainer_id is None else int(trainer_id)
self.amp = amp
class DummyScaler:
def __init__(self):
pass
def step(self, optimizer):
optimizer.step()
def scale(self, loss):
return loss
def unscale_(self, optimizer):
return optimizer
def update(self):
pass
self.scaler = paddle.amp.GradScaler() if self.amp else DummyScaler()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def prepare_model(self, model: paddle.nn.Layer, **kwargs):
"""Prepares model for DDP or DP. The model is moved to
the device of the correct rank.
Parameters
----------
model : paddle.nn.Layer
Model that is converted for DDP or DP.
Returns
-------
paddle.nn.Layer
Wrapped model, or original model if DDP and DP are turned off.
"""
if self.use_ddp:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = paddle.DataParallel(model, **kwargs)
elif self.use_dp:
model = paddle.DataParallel(model, **kwargs)
return model
def autocast(self, *args, **kwargs):
return paddle.amp.auto_cast(self.amp, *args, **kwargs)
def backward(self, loss: paddle.Tensor):
"""Backwards pass.
Parameters
----------
loss : paddle.Tensor
Loss value.
"""
scaled = self.scaler.scale(loss) # scale the loss
scaled.backward()
def step(self, optimizer: paddle.optimizer.Optimizer):
"""Steps the optimizer.
Parameters
----------
optimizer : paddle.optimizer.Optimizer
Optimizer to step forward.
"""
self.scaler.step(optimizer)
def update(self):
# https://www.paddlepaddle.org.cn/documentation/docs/zh/2.6/api/paddle/amp/GradScaler_cn.html#step-optimizer
self.scaler.update()
def prepare_dataloader(self,
dataset: typing.Iterable,
start_idx: int=None,
**kwargs):
"""Wraps a dataset with a DataLoader, using the correct sampler if DDP is
enabled.
Parameters
----------
dataset : typing.Iterable
Dataset to build Dataloader around.
start_idx : int, optional
Start index of sampler, useful if resuming from some epoch,
by default None
Returns
-------
DataLoader
Wrapped DataLoader.
"""
if self.use_ddp:
sampler = ResumableDistributedSampler(
dataset,
start_idx,
batch_size=kwargs.get("batch_size", 1),
shuffle=kwargs.get("shuffle", True),
drop_last=kwargs.get("drop_last", False),
num_replicas=self.world_size,
rank=self.local_rank, )
if "num_workers" in kwargs:
kwargs["num_workers"] = max(kwargs["num_workers"] //
self.world_size, 1)
else:
sampler = ResumableSequentialSampler(dataset, start_idx)
dataloader = DataLoader(
dataset,
batch_sampler=sampler if self.use_ddp else None,
sampler=sampler if not self.use_ddp else None,
**kwargs, )
return dataloader
@staticmethod
def unwrap(model):
return model

@ -0,0 +1,272 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/layers/base.py)
import inspect
import shutil
import tempfile
import typing
from pathlib import Path
import paddle
from paddle import nn
class BaseModel(nn.Layer):
"""This is a class that adds useful save/load functionality to a
``paddle.nn.Layer`` object. ``BaseModel`` objects can be saved
as ``package`` easily, making them super easy to port between
machines without requiring a ton of dependencies. Files can also be
saved as just weights, in the standard way.
>>> class Model(ml.BaseModel):
>>> def __init__(self, arg1: float = 1.0):
>>> super().__init__()
>>> self.arg1 = arg1
>>> self.linear = nn.Linear(1, 1)
>>>
>>> def forward(self, x):
>>> return self.linear(x)
>>>
>>> model1 = Model()
>>>
>>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
>>> model1.save(
>>> f.name,
>>> )
>>> model2 = Model.load(f.name)
>>> out2 = seed_and_run(model2, x)
>>> assert paddle.allclose(out1, out2)
>>>
>>> model1.save(f.name, package=True)
>>> model2 = Model.load(f.name)
>>> model2.save(f.name, package=False)
>>> model3 = Model.load(f.name)
>>> out3 = seed_and_run(model3, x)
>>>
>>> with tempfile.TemporaryDirectory() as d:
>>> model1.save_to_folder(d, {"data": 1.0})
>>> Model.load_from_folder(d)
"""
def save(
self,
path: str,
metadata: dict=None,
package: bool=False,
intern: list=[],
extern: list=[],
mock: list=[], ):
"""Saves the model, either as a package, or just as
weights, alongside some specified metadata.
Parameters
----------
path : str
Path to save model to.
metadata : dict, optional
Any metadata to save alongside the model,
by default None
package : bool, optional
Whether to use ``package`` to save the model in
a format that is portable, by default True
intern : list, optional
List of additional libraries that are internal
to the model, used with package, by default []
extern : list, optional
List of additional libraries that are external to
the model, used with package, by default []
mock : list, optional
List of libraries to mock, used with package,
by default []
Returns
-------
str
Path to saved model.
"""
sig = inspect.signature(self.__class__)
args = {}
for key, val in sig.parameters.items():
arg_val = val.default
if arg_val is not inspect.Parameter.empty:
args[key] = arg_val
# Look up attibutes in self, and if any of them are in args,
# overwrite them in args.
for attribute in dir(self):
if attribute in args:
args[attribute] = getattr(self, attribute)
metadata = {} if metadata is None else metadata
metadata["kwargs"] = args
if not hasattr(self, "metadata"):
self.metadata = {}
self.metadata.update(metadata)
if not package:
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
paddle.save(state_dict, str(path))
else:
self._save_package(path, intern=intern, extern=extern, mock=mock)
return path
@property
def device(self):
"""Gets the device the model is on by looking at the device of
the first parameter. May not be valid if model is split across
multiple devices.
"""
return list(self.parameters())[0].place
@classmethod
def load(
cls,
location: str,
*args,
package_name: str=None,
strict: bool=False,
**kwargs, ):
"""Load model from a path. Tries first to load as a package, and if
that fails, tries to load as weights. The arguments to the class are
specified inside the model weights file.
Parameters
----------
location : str
Path to file.
package_name : str, optional
Name of package, by default ``cls.__name__``.
strict : bool, optional
Ignore unmatched keys, by default False
kwargs : dict
Additional keyword arguments to the model instantiation, if
not loading from package.
Returns
-------
BaseModel
A model that inherits from BaseModel.
"""
try:
model = cls._load_package(location, package_name=package_name)
except:
model_dict = paddle.load(location)
metadata = model_dict["metadata"]
metadata["kwargs"].update(kwargs)
sig = inspect.signature(cls)
class_keys = list(sig.parameters.keys())
for k in list(metadata["kwargs"].keys()):
if k not in class_keys:
metadata["kwargs"].pop(k)
model = cls(*args, **metadata["kwargs"])
model.set_state_dict(model_dict["state_dict"])
model.metadata = metadata
return model
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
raise NotImplementedError("Currently Paddle does not support packaging")
@classmethod
def _load_package(cls, path, package_name=None):
raise NotImplementedError("Currently Paddle does not support packaging")
def save_to_folder(
self,
folder: typing.Union[str, Path],
extra_data: dict=None,
package: bool=False, ):
"""Dumps a model into a folder, as both a package
and as weights, as well as anything specified in
``extra_data``. ``extra_data`` is a dictionary of other
pickleable files, with the keys being the paths
to save them in. The model is saved under a subfolder
specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
if the model name was ``Generator``).
>>> with tempfile.TemporaryDirectory() as d:
>>> extra_data = {
>>> "optimizer.pth": optimizer.state_dict()
>>> }
>>> model.save_to_folder(d, extra_data)
>>> Model.load_from_folder(d)
Parameters
----------
folder : typing.Union[str, Path]
_description_
extra_data : dict, optional
_description_, by default None
Returns
-------
str
Path to folder
"""
extra_data = {} if extra_data is None else extra_data
model_name = type(self).__name__.lower()
target_base = Path(f"{folder}/{model_name}/")
target_base.mkdir(exist_ok=True, parents=True)
if package:
package_path = target_base / f"package.pth"
self.save(package_path)
weights_path = target_base / f"weights.pth"
self.save(weights_path, package=False)
for path, obj in extra_data.items():
paddle.save(obj, str(target_base / path))
return target_base
@classmethod
def load_from_folder(
cls,
folder: typing.Union[str, Path],
package: bool=False,
strict: bool=False,
**kwargs, ):
"""Loads the model from a folder generated by
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
Like that function, this one looks for a subfolder that has
the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
model name was ``Generator``).
Parameters
----------
folder : typing.Union[str, Path]
_description_
package : bool, optional
Whether to use ``package`` to load the model,
loading the model from ``package.pth``.
strict : bool, optional
Ignore unmatched keys, by default False
Returns
-------
tuple
tuple of model and extra data as saved by
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
"""
folder = Path(folder) / cls.__name__.lower()
model_pth = "package.pth" if package else "weights.pth"
model_pth = folder / model_pth
model = cls.load(str(model_pth))
extra_data = {}
excluded = ["package.pth", "weights.pth"]
files = [
x for x in folder.glob("*")
if x.is_file() and x.name not in excluded
]
for f in files:
extra_data[f.name] = paddle.load(str(f), **kwargs)
return model, extra_data

@ -0,0 +1,446 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/decorators.py)
import math
import os
import time
from collections import defaultdict
from functools import wraps
import paddle
import paddle.distributed as dist
from rich import box
from rich.console import Console
from rich.console import Group
from rich.live import Live
from rich.markdown import Markdown
from rich.padding import Padding
from rich.panel import Panel
from rich.progress import BarColumn
from rich.progress import Progress
from rich.progress import SpinnerColumn
from rich.progress import TimeElapsedColumn
from rich.progress import TimeRemainingColumn
from rich.rule import Rule
from rich.table import Table
from visualdl import LogWriter
# This is here so that the history can be pickled.
def default_list():
return []
class Mean:
"""Keeps track of the running mean, along with the latest
value.
"""
def __init__(self):
self.reset()
def __call__(self):
mean = self.total / max(self.count, 1)
return mean
def reset(self):
self.count = 0
self.total = 0
def update(self, val):
if math.isfinite(val):
self.count += 1
self.total += val
def when(condition):
"""Runs a function only when the condition is met. The condition is
a function that is run.
Parameters
----------
condition : Callable
Function to run to check whether or not to run the decorated
function.
Example
-------
Checkpoint only runs every 100 iterations, and only if the
local rank is 0.
>>> i = 0
>>> rank = 0
>>>
>>> @when(lambda: i % 100 == 0 and rank == 0)
>>> def checkpoint():
>>> print("Saving to /runs/exp1")
>>>
>>> for i in range(1000):
>>> checkpoint()
"""
def decorator(fn):
@wraps(fn)
def decorated(*args, **kwargs):
if condition():
return fn(*args, **kwargs)
return decorated
return decorator
def timer(prefix: str="time"):
"""Adds execution time to the output dictionary of the decorated
function. The function decorated by this must output a dictionary.
The key added will follow the form "[prefix]/[name_of_function]"
Parameters
----------
prefix : str, optional
The key added will follow the form "[prefix]/[name_of_function]",
by default "time".
"""
def decorator(fn):
@wraps(fn)
def decorated(*args, **kwargs):
s = time.perf_counter()
output = fn(*args, **kwargs)
assert isinstance(output, dict)
e = time.perf_counter()
output[f"{prefix}/{fn.__name__}"] = e - s
return output
return decorated
return decorator
class Tracker:
"""
A tracker class that helps to monitor the progress of training and logging the metrics.
Attributes
----------
metrics : dict
A dictionary containing the metrics for each label.
history : dict
A dictionary containing the history of metrics for each label.
writer : LogWriter
A LogWriter object for logging the metrics.
rank : int
The rank of the current process.
step : int
The current step of the training.
tasks : dict
A dictionary containing the progress bars and tables for each label.
pbar : Progress
A progress bar object for displaying the progress.
consoles : list
A list of console objects for logging.
live : Live
A Live object for updating the display live.
Methods
-------
print(msg: str)
Prints the given message to all consoles.
update(label: str, fn_name: str)
Updates the progress bar and table for the given label.
done(label: str, title: str)
Resets the progress bar and table for the given label and prints the final result.
track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
A decorator for tracking the progress and metrics of a function.
log(label: str, value_type: str = "value", history: bool = True)
A decorator for logging the metrics of a function.
is_best(label: str, key: str) -> bool
Checks if the latest value of the given key in the label is the best so far.
state_dict() -> dict
Returns a dictionary containing the state of the tracker.
load_state_dict(state_dict: dict) -> Tracker
Loads the state of the tracker from the given state dictionary.
"""
def __init__(
self,
writer: LogWriter=None,
log_file: str=None,
rank: int=0,
console_width: int=100,
step: int=0, ):
"""
Initializes the Tracker object.
Parameters
----------
writer : LogWriter, optional
A LogWriter object for logging the metrics, by default None.
log_file : str, optional
The path to the log file, by default None.
rank : int, optional
The rank of the current process, by default 0.
console_width : int, optional
The width of the console, by default 100.
step : int, optional
The current step of the training, by default 0.
"""
self.metrics = {}
self.history = {}
self.writer = writer
self.rank = rank
self.step = step
# Create progress bars etc.
self.tasks = {}
self.pbar = Progress(
SpinnerColumn(),
"[progress.description]{task.description}",
"{task.completed}/{task.total}",
BarColumn(),
TimeElapsedColumn(),
"/",
TimeRemainingColumn(), )
self.consoles = [Console(width=console_width)]
self.live = Live(console=self.consoles[0], refresh_per_second=10)
if log_file is not None:
self.consoles.append(
Console(width=console_width, file=open(log_file, "a")))
def print(self, msg):
"""
Prints the given message to all consoles.
Parameters
----------
msg : str
The message to be printed.
"""
if self.rank == 0:
for c in self.consoles:
c.log(msg)
def update(self, label, fn_name):
"""
Updates the progress bar and table for the given label.
Parameters
----------
label : str
The label of the progress bar and table to be updated.
fn_name : str
The name of the function associated with the label.
"""
if self.rank == 0:
self.pbar.advance(self.tasks[label]["pbar"])
# Create table
table = Table(title=label, expand=True, box=box.MINIMAL)
table.add_column("key", style="cyan")
table.add_column("value", style="bright_blue")
table.add_column("mean", style="bright_green")
keys = self.metrics[label]["value"].keys()
for k in keys:
value = self.metrics[label]["value"][k]
mean = self.metrics[label]["mean"][k]()
table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
self.tasks[label]["table"] = table
tables = [t["table"] for t in self.tasks.values()]
group = Group(*tables, self.pbar)
self.live.update(
Group(
Padding("", (0, 0)),
Rule(f"[italic]{fn_name}()", style="white"),
Padding("", (0, 0)),
Panel.fit(
group,
padding=(0, 5),
title="[b]Progress",
border_style="blue", ), ))
def done(self, label: str, title: str):
"""
Resets the progress bar and table for the given label and prints the final result.
Parameters
----------
label : str
The label of the progress bar and table to be reset.
title : str
The title to be displayed when printing the final result.
"""
for label in self.metrics:
for v in self.metrics[label]["mean"].values():
v.reset()
if self.rank == 0:
self.pbar.reset(self.tasks[label]["pbar"])
tables = [t["table"] for t in self.tasks.values()]
group = Group(Markdown(f"# {title}"), *tables, self.pbar)
self.print(group)
def track(
self,
label: str,
length: int,
completed: int=0,
op: dist.ReduceOp=dist.ReduceOp.AVG,
ddp_active: bool="LOCAL_RANK" in os.environ, ):
"""
A decorator for tracking the progress and metrics of a function.
Parameters
----------
label : str
The label to be associated with the progress and metrics.
length : int
The total number of iterations to be completed.
completed : int, optional
The number of iterations already completed, by default 0.
op : dist.ReduceOp, optional
The reduce operation to be used, by default dist.ReduceOp.AVG.
ddp_active : bool, optional
Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
"""
self.tasks[label] = {
"pbar":
self.pbar.add_task(
f"[white]Iteration ({label})",
total=length,
completed=completed),
"table":
Table(),
}
self.metrics[label] = {
"value": defaultdict(),
"mean": defaultdict(lambda: Mean()),
}
def decorator(fn):
@wraps(fn)
def decorated(*args, **kwargs):
output = fn(*args, **kwargs)
if not isinstance(output, dict):
self.update(label, fn.__name__)
return output
# Collect across all DDP processes
scalar_keys = []
for k, v in output.items():
if isinstance(v, (int, float)):
v = paddle.to_tensor([v])
if not paddle.is_tensor(v):
continue
if ddp_active and v.is_cuda:
dist.all_reduce(v, op=op)
output[k] = v.detach()
if paddle.numel(v) == 1:
scalar_keys.append(k)
output[k] = v.item()
# Save the outputs to tracker
for k, v in output.items():
if k not in scalar_keys:
continue
self.metrics[label]["value"][k] = v
# Update the running mean
self.metrics[label]["mean"][k].update(v)
self.update(label, fn.__name__)
return output
return decorated
return decorator
def log(self, label: str, value_type: str="value", history: bool=True):
"""
A decorator for logging the metrics of a function.
Parameters
----------
label : str
The label to be associated with the logging.
value_type : str, optional
The type of value to be logged, by default "value".
history : bool, optional
Whether to save the history of the metrics, by default True.
"""
assert value_type in ["mean", "value"]
if history:
if label not in self.history:
self.history[label] = defaultdict(default_list)
def decorator(fn):
@wraps(fn)
def decorated(*args, **kwargs):
output = fn(*args, **kwargs)
if self.rank == 0:
nonlocal value_type, label
metrics = self.metrics[label][value_type]
for k, v in metrics.items():
v = v() if isinstance(v, Mean) else v
if self.writer is not None:
self.writer.add_scalar(
tag=f"{k}/{label}", value=v, step=self.step)
if label in self.history:
self.history[label][k].append(v)
if label in self.history:
self.history[label]["step"].append(self.step)
return output
return decorated
return decorator
def is_best(self, label, key):
"""
Checks if the latest value of the given key in the label is the best so far.
Parameters
----------
label : str
The label of the metrics to be checked.
key : str
The key of the metric to be checked.
Returns
-------
bool
True if the latest value is the best so far, otherwise False.
"""
return self.history[label][key][-1] == min(self.history[label][key])
def state_dict(self):
"""
Returns a dictionary containing the state of the tracker.
Returns
-------
dict
A dictionary containing the history and step of the tracker.
"""
return {"history": self.history, "step": self.step}
def load_state_dict(self, state_dict):
"""
Loads the state of the tracker from the given state dictionary.
Parameters
----------
state_dict : dict
A dictionary containing the history and step of the tracker.
Returns
-------
Tracker
The tracker object with the loaded state.
"""
self.history = state_dict["history"]
self.step = state_dict["step"]
return self

@ -0,0 +1,88 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/post.py)
import typing
import paddle
from audio.audiotools.core import AudioSignal
def audio_table(
audio_dict: dict,
first_column: str=None,
format_fn: typing.Callable=None,
**kwargs, ):
"""Embeds an audio table into HTML, or as the output cell
in a notebook.
Parameters
----------
audio_dict : dict
Dictionary of data to embed.
first_column : str, optional
The label for the first column of the table, by default None
format_fn : typing.Callable, optional
How to format the data, by default None
Returns
-------
str
Table as a string
Examples
--------
>>> audio_dict = {}
>>> for i in range(signal_batch.batch_size):
>>> audio_dict[i] = {
>>> "input": signal_batch[i],
>>> "output": output_batch[i]
>>> }
>>> audiotools.post.audio_zip(audio_dict)
"""
output = []
columns = None
def _default_format_fn(label, x, **kwargs):
if paddle.is_tensor(x):
x = x.tolist()
if x is None:
return "."
elif isinstance(x, AudioSignal):
return x.embed(display=False, return_html=True, **kwargs)
else:
return str(x)
if format_fn is None:
format_fn = _default_format_fn
if first_column is None:
first_column = "."
for k, v in audio_dict.items():
if not isinstance(v, dict):
v = {"Audio": v}
v_keys = list(v.keys())
if columns is None:
columns = [first_column] + v_keys
output.append(" | ".join(columns))
layout = "|---" + len(v_keys) * "|:-:"
output.append(layout)
formatted_audio = []
for col in columns[1:]:
formatted_audio.append(format_fn(col, v[col], **kwargs))
row = f"| {k} | "
row += " | ".join(formatted_audio)
output.append(row)
output = "\n" + "\n".join(output)
return output

@ -0,0 +1,6 @@
ffmpeg-python
ffmpy
flatten_dict
pyloudnorm
pytest
rich

@ -0,0 +1,615 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_audio_signal.py)
import pathlib
import sys
import tempfile
import librosa
import numpy as np
import paddle
import pytest
import rich
from audio import audiotools
from audio.audiotools import AudioSignal
from audio.audiotools import util
def test_io():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(pathlib.Path(audio_path))
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
signal.write(f.name)
signal_from_file = AudioSignal(f.name)
mp3_signal = AudioSignal(audio_path.replace("wav", "mp3"))
print(mp3_signal)
assert signal == signal_from_file
print(signal)
print(signal.markdown())
mp3_signal = AudioSignal.excerpt(
audio_path.replace("wav", "mp3"), offset=5, duration=5)
assert mp3_signal.signal_duration == 5.0
assert mp3_signal.duration == 5.0
assert mp3_signal.length == mp3_signal.signal_length
rich.print(signal)
array = np.random.randn(2, 16000)
signal = AudioSignal(array, sample_rate=16000)
assert np.allclose(signal.numpy(), array)
signal = AudioSignal(array, 44100)
assert signal.sample_rate == 44100
signal.shape
with pytest.raises(ValueError):
signal = AudioSignal(5, sample_rate=16000)
signal = AudioSignal(audio_path, offset=10, duration=10)
assert np.allclose(signal.signal_duration, 10.0)
assert np.allclose(signal.duration, 10.0)
signal = AudioSignal.excerpt(audio_path, offset=5, duration=5)
assert signal.signal_duration == 5.0
assert signal.duration == 5.0
assert "offset" in signal.metadata
assert "duration" in signal.metadata
signal = AudioSignal(paddle.randn([1000]), 44100)
assert signal.audio_data.ndim == 3
assert paddle.all(signal.samples == signal.audio_data)
audio_path = "./audio/spk/f10_script4_produced.wav"
assert AudioSignal(audio_path).hash() == AudioSignal(audio_path).hash()
assert AudioSignal(audio_path).hash() != AudioSignal(audio_path).normalize(
-20).hash()
with pytest.raises(RuntimeError):
AudioSignal(audio_path, offset=100000, duration=3)
def test_copy_and_clone():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.stft()
signal.loudness()
copied = signal.copy()
deep_copied = signal.deepcopy()
cloned = signal.clone()
for a in ["audio_data", "stft_data", "_loudness"]:
a1 = getattr(signal, a)
a2 = getattr(cloned, a)
a3 = getattr(copied, a)
a4 = getattr(deep_copied, a)
assert id(a1) != id(a2)
assert id(a1) == id(a3)
assert id(a1) != id(a4)
assert np.allclose(a1, a2)
assert np.allclose(a1, a3)
assert np.allclose(a1, a4)
for a in ["path_to_file", "metadata"]:
a1 = getattr(signal, a)
a2 = getattr(cloned, a)
a3 = getattr(copied, a)
a4 = getattr(deep_copied, a)
assert id(a1) == id(a2) if isinstance(a1, str) else id(a1) != id(a2)
assert id(a1) == id(a3)
assert id(a1) == id(a4) if isinstance(a1, str) else id(a1) != id(a2)
# for clone, id should differ if path is list, and should differ always for metadata
# if path is string, id should remain same...
assert signal.original_signal_length == copied.original_signal_length
assert signal.original_signal_length == deep_copied.original_signal_length
assert signal.original_signal_length == cloned.original_signal_length
signal = signal.detach()
@pytest.mark.parametrize("loudness_cutoff", [-np.inf, -160, -80, -40, -20])
def test_salient_excerpt(loudness_cutoff):
MAP = {-np.inf: 0.0, -160: 0.0, -80: 0.001, -40: 0.01, -20: 0.1}
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
sr = 44100
signal = AudioSignal(paddle.zeros([sr * 60]), sr)
signal[..., sr * 20:sr * 21] = MAP[loudness_cutoff] * paddle.randn(
[44100])
signal.write(f.name)
signal = AudioSignal.salient_excerpt(
f.name, loudness_cutoff=loudness_cutoff, duration=1, num_tries=None)
assert "offset" in signal.metadata
assert "duration" in signal.metadata
assert signal.loudness() >= loudness_cutoff
signal = AudioSignal.salient_excerpt(
f.name, loudness_cutoff=np.inf, duration=1, num_tries=10)
signal = AudioSignal.salient_excerpt(
f.name,
loudness_cutoff=None,
duration=1, )
def test_arithmetic():
def _make_signals():
array = np.random.randn(2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
array = np.random.randn(2, 16000)
sig2 = AudioSignal(array, sample_rate=16000)
return sig1, sig2
# Addition (with a copy)
sig1, sig2 = _make_signals()
sig3 = sig1 + sig2
assert paddle.allclose(sig3.audio_data, sig1.audio_data + sig2.audio_data)
# Addition (rmul)
sig1, _ = _make_signals()
sig3 = 5.0 + sig1
assert paddle.allclose(sig3.audio_data, sig1.audio_data + 5.0)
# In place addition
sig3, sig2 = _make_signals()
sig1 = sig3.deepcopy()
sig3 += sig2
assert paddle.allclose(sig3.audio_data, sig1.audio_data + sig2.audio_data)
# Subtraction (with a copy)
sig1, sig2 = _make_signals()
sig3 = sig1 - sig2
assert paddle.allclose(sig3.audio_data, sig1.audio_data - sig2.audio_data)
# In place subtraction
sig3, sig2 = _make_signals()
sig1 = sig3.deepcopy()
sig3 -= sig2
assert paddle.allclose(sig3.audio_data, sig1.audio_data - sig2.audio_data)
# Multiplication (element-wise)
sig1, sig2 = _make_signals()
sig3 = sig1 * sig2
assert paddle.allclose(sig3.audio_data, sig1.audio_data * sig2.audio_data)
# Multiplication (gain)
sig1, _ = _make_signals()
sig3 = sig1 * 5.0
assert paddle.allclose(sig3.audio_data, sig1.audio_data * 5.0)
# Multiplication (rmul)
sig1, _ = _make_signals()
sig3 = 5.0 * sig1
assert paddle.allclose(sig3.audio_data, sig1.audio_data * 5.0)
# Multiplication (in-place)
sig3, sig2 = _make_signals()
sig1 = sig3.deepcopy()
sig3 *= sig2
assert paddle.allclose(sig3.audio_data, sig1.audio_data * sig2.audio_data)
def test_equality():
array = np.random.randn(2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig2 = AudioSignal(array, sample_rate=16000)
assert sig1 == sig2
array = np.random.randn(2, 16000)
sig3 = AudioSignal(array, sample_rate=16000)
assert sig1 != sig3
assert not np.allclose(sig1.numpy(), sig3.numpy())
def test_indexing():
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
assert np.allclose(sig1[0].audio_data, array[0])
assert np.allclose(sig1[0, :, 8000].audio_data, array[0, :, 8000])
# Test with the associated STFT data.
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1.loudness()
sig1.stft()
indexed = sig1[0]
assert np.allclose(indexed.audio_data, array[0])
assert np.allclose(indexed.stft_data, sig1.stft_data[0])
assert np.allclose(indexed._loudness, sig1._loudness[0])
indexed = sig1[0:2]
assert np.allclose(indexed.audio_data, array[0:2])
assert np.allclose(indexed.stft_data, sig1.stft_data[0:2])
assert np.allclose(indexed._loudness, sig1._loudness[0:2])
# Test using a boolean tensor to index batch
mask = paddle.to_tensor([True, False, True, False])
indexed = sig1[mask]
assert np.allclose(indexed.audio_data, sig1.audio_data[mask])
# assert np.allclose(indexed.stft_data, sig1.stft_data[mask])
assert np.allclose(indexed.stft_data,
util.bool_index_compat(sig1.stft_data, mask))
assert np.allclose(indexed._loudness, sig1._loudness[mask])
# Set parts of signal using tensor
other_array = paddle.to_tensor(np.random.randn(4, 2, 16000))
sig1 = AudioSignal(array, sample_rate=16000)
sig1[0, :, 6000:8000] = other_array[0, :, 6000:8000]
assert np.allclose(sig1[0, :, 6000:8000].audio_data,
other_array[0, :, 6000:8000])
# Set parts of signal using AudioSignal
sig2 = AudioSignal(other_array, sample_rate=16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1[0, :, 6000:8000] = sig2[0, :, 6000:8000]
assert np.allclose(sig1[0, :, 6000:8000].audio_data,
sig2[0, :, 6000:8000].audio_data)
# Check that loudnesses and stft_data get set as well, if only the batch
# dim is indexed.
sig2 = AudioSignal(other_array, sample_rate=16000)
sig2.stft()
sig2.loudness()
sig1 = AudioSignal(array, sample_rate=16000)
sig1.stft()
sig1.loudness()
# Test using a boolean tensor to index batch
mask = paddle.to_tensor([True, False, True, False])
sig1[mask] = sig2[mask]
for k in ["stft_data", "audio_data", "_loudness"]:
a1 = getattr(sig1, k)
a2 = getattr(sig2, k)
# assert np.allclose(a1[mask], a2[mask])
assert np.allclose(
util.bool_index_compat(a1, mask), util.bool_index_compat(a2, mask))
def test_zeros():
x = AudioSignal.zeros(0.5, 44100)
assert x.signal_duration == 0.5
assert x.duration == 0.5
assert x.sample_rate == 44100
@pytest.mark.parametrize("shape",
["sine", "square", "sawtooth", "triangle", "beep"])
def test_waves(shape: str):
# error case
if shape == "beep":
with pytest.raises(ValueError):
AudioSignal.wave(440, 0.5, 44100, shape=shape)
return
x = AudioSignal.wave(440, 0.5, 44100, shape=shape)
assert x.duration == 0.5
assert x.sample_rate == 44100
# test the default shape arg
x = AudioSignal.wave(440, 0.5, 44100)
assert x.duration == 0.5
assert x.sample_rate == 44100
def test_zero_pad():
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1.zero_pad(100, 100)
zeros = paddle.zeros([4, 2, 100], dtype="float64")
assert paddle.allclose(sig1.audio_data[..., :100], zeros)
assert paddle.allclose(sig1.audio_data[..., -100:], zeros)
def test_zero_pad_to():
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1.zero_pad_to(16100)
zeros = paddle.zeros([4, 2, 100], dtype="float64")
assert paddle.allclose(sig1.audio_data[..., -100:], zeros)
assert sig1.signal_length == 16100
sig1 = AudioSignal(array, sample_rate=16000)
sig1.zero_pad_to(15000)
assert sig1.signal_length == 16000
sig1 = AudioSignal(array, sample_rate=16000)
sig1.zero_pad_to(16100, mode="before")
zeros = paddle.zeros([4, 2, 100], dtype="float64")
assert paddle.allclose(sig1.audio_data[..., :100], zeros)
assert sig1.signal_length == 16100
sig1 = AudioSignal(array, sample_rate=16000)
sig1.zero_pad_to(15000, mode="before")
assert sig1.signal_length == 16000
def test_truncate():
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1.truncate_samples(100)
assert sig1.signal_length == 100
assert np.allclose(sig1.audio_data, array[..., :100])
def test_trim():
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1.trim(100, 100)
assert sig1.signal_length == 16000 - 200
assert np.allclose(sig1.audio_data, array[..., 100:-100])
array = np.random.randn(4, 2, 16000)
sig1 = AudioSignal(array, sample_rate=16000)
sig1.trim(0, 0)
assert np.allclose(sig1.audio_data, array)
def test_to_from_ops():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.stft()
signal.loudness()
signal = signal.to("cpu")
assert str(signal.audio_data.place) == "Place(cpu)"
assert isinstance(signal.numpy(), np.ndarray)
signal.cpu()
# signal.cuda()
signal.float()
def test_device():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.to("cpu")
assert str(signal.device) == "Place(cpu)"
@pytest.mark.parametrize("window_length", [2048, 512])
@pytest.mark.parametrize("hop_length", [512, 128])
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None])
def test_stft(window_length, hop_length, window_type):
if hop_length >= window_length:
hop_length = window_length // 2
audio_path = "./audio/spk/f10_script4_produced.wav"
stft_params = audiotools.STFTParams(
window_length=window_length,
hop_length=hop_length,
window_type=window_type)
for _stft_params in [None, stft_params]:
signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params)
with pytest.raises(RuntimeError):
signal.istft()
stft_data = signal.stft()
# assert paddle.allclose(signal.stft_data, stft_data)
assert np.allclose(signal.stft_data.cpu().numpy(),
stft_data.cpu().numpy())
copied_signal = signal.deepcopy()
copied_signal.stft()
copied_signal = copied_signal.istft()
assert copied_signal == signal
mag = signal.magnitude
phase = signal.phase
recon_stft = mag * util.exp_compat(1j * phase)
# assert paddle.allclose(recon_stft, signal.stft_data)
assert np.allclose(recon_stft.cpu().numpy(),
signal.stft_data.cpu().numpy())
signal.stft_data = None
mag = signal.magnitude
signal.stft_data = None
phase = signal.phase
recon_stft = mag * util.exp_compat(1j * phase)
# assert paddle.allclose(recon_stft, signal.stft_data)
assert np.allclose(recon_stft.cpu().numpy(),
signal.stft_data.cpu().numpy())
# Test with match_stride=True, ignoring the beginning and end.
s = signal.stft_params
if s.hop_length == s.window_length // 4:
og_signal = signal.clone()
stft_data = signal.stft(match_stride=True)
recon_data = signal.istft(match_stride=True)
discard = window_length * 2
right_pad, _ = signal.compute_stft_padding(
s.window_length, s.hop_length, match_stride=True)
length = signal.signal_length + right_pad
assert stft_data.shape[-1] == length // s.hop_length
assert paddle.allclose(
recon_data.audio_data[..., discard:-discard],
og_signal.audio_data[..., discard:-discard],
atol=1e-6, )
def test_log_magnitude():
audio_path = "./audio/spk/f10_script4_produced.wav"
for _ in range(10):
signal = AudioSignal.excerpt(audio_path, duration=5.0)
magnitude = signal.magnitude.numpy()[0, 0]
librosa_log_mag = librosa.amplitude_to_db(magnitude)
log_mag = signal.log_magnitude().numpy()[0, 0]
# print(abs((log_mag - librosa_log_mag)).max())
assert np.allclose(log_mag, librosa_log_mag, atol=10e-7)
@pytest.mark.parametrize("n_mels", [40, 80, 128])
@pytest.mark.parametrize("window_length", [2048, 512])
@pytest.mark.parametrize("hop_length", [512, 128])
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None])
def test_mel_spectrogram(n_mels, window_length, hop_length, window_type):
if hop_length >= window_length:
hop_length = window_length // 2
audio_path = "./audio/spk/f10_script4_produced.wav"
stft_params = audiotools.STFTParams(
window_length=window_length,
hop_length=hop_length,
window_type=window_type)
for _stft_params in [None, stft_params]:
signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params)
mel_spec = signal.mel_spectrogram(n_mels=n_mels)
assert mel_spec.shape[2] == n_mels
@pytest.mark.parametrize("n_mfcc", [20, 40])
@pytest.mark.parametrize("n_mels", [40, 80, 128])
@pytest.mark.parametrize("window_length", [2048, 512])
@pytest.mark.parametrize("hop_length", [512, 128])
def test_mfcc(n_mfcc, n_mels, window_length, hop_length):
if hop_length >= window_length:
hop_length = window_length // 2
audio_path = "./audio/spk/f10_script4_produced.wav"
stft_params = audiotools.STFTParams(
window_length=window_length, hop_length=hop_length)
for _stft_params in [None, stft_params]:
signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params)
mfcc = signal.mfcc(n_mfcc=n_mfcc, n_mels=n_mels)
assert mfcc.shape[2] == n_mfcc
def test_to_mono():
array = np.random.randn(4, 2, 16000)
sr = 16000
signal = AudioSignal(array, sample_rate=sr)
assert signal.num_channels == 2
signal = signal.to_mono()
assert signal.num_channels == 1
def test_float():
array = np.random.randn(4, 1, 16000).astype("float64")
sr = 1600
signal = AudioSignal(array, sample_rate=sr)
signal = signal.float()
assert signal.audio_data.dtype == paddle.float32
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100, 48000])
def test_resample(sample_rate):
array = np.random.randn(4, 2, 16000)
sr = 16000
signal = AudioSignal(array, sample_rate=sr)
signal = signal.resample(sample_rate)
assert signal.sample_rate == sample_rate
assert signal.signal_length == sample_rate
def test_batching():
signals = []
batch_size = 16
# All same length, same sample rate.
for _ in range(batch_size):
array = np.random.randn(2, 16000)
signal = AudioSignal(array, sample_rate=16000)
signals.append(signal)
batched_signal = AudioSignal.batch(signals)
assert batched_signal.batch_size == batch_size
signals = []
# All different lengths, same sample rate, pad signals
for _ in range(batch_size):
L = np.random.randint(8000, 32000)
array = np.random.randn(2, L)
signal = AudioSignal(array, sample_rate=16000)
signals.append(signal)
with pytest.raises(RuntimeError):
batched_signal = AudioSignal.batch(signals)
signal_lengths = [x.signal_length for x in signals]
max_length = max(signal_lengths)
batched_signal = AudioSignal.batch(signals, pad_signals=True)
assert batched_signal.signal_length == max_length
assert batched_signal.batch_size == batch_size
signals = []
# All different lengths, same sample rate, truncate signals
for _ in range(batch_size):
L = np.random.randint(8000, 32000)
array = np.random.randn(2, L)
signal = AudioSignal(array, sample_rate=16000)
signals.append(signal)
with pytest.raises(RuntimeError):
batched_signal = AudioSignal.batch(signals)
signal_lengths = [x.signal_length for x in signals]
min_length = min(signal_lengths)
batched_signal = AudioSignal.batch(signals, truncate_signals=True)
assert batched_signal.signal_length == min_length
assert batched_signal.batch_size == batch_size
signals = []
# All different lengths, different sample rate, pad signals
for _ in range(batch_size):
L = np.random.randint(8000, 32000)
sr = np.random.choice([8000, 16000, 32000])
array = np.random.randn(2, L)
signal = AudioSignal(array, sample_rate=int(sr))
signals.append(signal)
with pytest.raises(RuntimeError):
batched_signal = AudioSignal.batch(signals)
signal_lengths = [x.signal_length for x in signals]
max_length = max(signal_lengths)
for i, x in enumerate(signals):
x.path_to_file = i
batched_signal = AudioSignal.batch(signals, resample=True, pad_signals=True)
assert batched_signal.signal_length == max_length
assert batched_signal.batch_size == batch_size
assert batched_signal.path_to_file == list(range(len(signals)))
assert batched_signal.path_to_input_file == batched_signal.path_to_file

@ -0,0 +1,54 @@
# MIT License, Copyright (c) 2020 Alexandre Défossez.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_bands.py)
import random
import sys
import unittest
import paddle
from audio.audiotools.core import pure_tone
from audio.audiotools.core import split_bands
from audio.audiotools.core import SplitBands
def delta(a, b, ref, fraction=0.9):
length = a.shape[-1]
compare_length = int(length * fraction)
offset = (length - compare_length) // 2
a = a[..., offset:offset + length]
b = b[..., offset:offset + length]
return 100 * paddle.abs(a - b).mean() / ref.std()
TOLERANCE = 0.5 # Tolerance to errors as percentage of the std of the input signal
class _BaseTest(unittest.TestCase):
def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE):
self.assertLessEqual(delta(a, b, ref), tol, msg)
class TestLowPassFilters(_BaseTest):
def setUp(self):
paddle.seed(1234)
random.seed(1234)
def test_keep_or_kill(self):
sr = 256
low = pure_tone(10, sr)
mid = pure_tone(40, sr)
high = pure_tone(100, sr)
x = low + mid + high
decomp = split_bands(x, sr, cutoffs=[20, 70])
self.assertEqual(len(decomp), 3)
for est, gt, name in zip(decomp, [low, mid, high],
["low", "mid", "high"]):
self.assertSimilar(est, gt, gt, name)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,51 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_display.py)
import sys
from pathlib import Path
import numpy as np
from visualdl import LogWriter
from audio.audiotools import AudioSignal
def test_specshow():
array = np.zeros((1, 16000))
AudioSignal(array, sample_rate=16000).specshow()
AudioSignal(array, sample_rate=16000).specshow(preemphasis=True)
AudioSignal(
array, sample_rate=16000).specshow(
title="test", preemphasis=True)
AudioSignal(
array, sample_rate=16000).specshow(
format=False, preemphasis=True)
AudioSignal(
array, sample_rate=16000).specshow(
format=False, preemphasis=False, y_axis="mel")
def test_waveplot():
array = np.zeros((1, 16000))
AudioSignal(array, sample_rate=16000).waveplot()
def test_wavespec():
array = np.zeros((1, 16000))
AudioSignal(array, sample_rate=16000).wavespec()
def test_write_audio_to_tb():
signal = AudioSignal("./audio/spk/f10_script4_produced.mp3", duration=5)
Path("./scratch").mkdir(parents=True, exist_ok=True)
writer = LogWriter("./scratch/")
signal.write_audio_to_tb("tag", writer)
def test_save_image():
signal = AudioSignal(
"./audio/spk/f10_script4_produced.wav", duration=10, offset=10)
Path("./scratch").mkdir(parents=True, exist_ok=True)
signal.save_image("./scratch/image.png")

@ -0,0 +1,181 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_dsp.py)
import sys
import numpy as np
import paddle
import pytest
from audio.audiotools import AudioSignal
from audio.audiotools.core.util import sample_from_dist
@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100])
@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0])
def test_overlap_add(duration, sample_rate, window_duration):
np.random.seed(0)
if duration > window_duration:
spk_signal = AudioSignal.batch([
AudioSignal.excerpt(
"./audio/spk/f10_script4_produced.wav", duration=duration)
for _ in range(16)
])
spk_signal.resample(sample_rate)
noise = paddle.randn([16, 1, int(duration * sample_rate)])
nz_signal = AudioSignal(noise, sample_rate=sample_rate)
def _test(signal):
hop_duration = window_duration / 2
windowed_signal = signal.clone().collect_windows(window_duration,
hop_duration)
recombined = windowed_signal.overlap_and_add(hop_duration)
assert recombined == signal
assert np.allclose(recombined.audio_data, signal.audio_data, 1e-3)
_test(nz_signal)
_test(spk_signal)
@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0])
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100])
@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0])
def test_inplace_overlap_add(duration, sample_rate, window_duration):
np.random.seed(0)
if duration > window_duration:
spk_signal = AudioSignal.batch([
AudioSignal.excerpt(
"./audio/spk/f10_script4_produced.wav", duration=duration)
for _ in range(16)
])
spk_signal.resample(sample_rate)
noise = paddle.randn([16, 1, int(duration * sample_rate)])
nz_signal = AudioSignal(noise, sample_rate=sample_rate)
def _test(signal):
hop_duration = window_duration / 2
windowed_signal = signal.clone().collect_windows(window_duration,
hop_duration)
# Compare in-place with unfold results
for i, window in enumerate(
signal.clone().windows(window_duration, hop_duration)):
assert np.allclose(window.audio_data,
windowed_signal.audio_data[i])
_test(nz_signal)
_test(spk_signal)
def test_low_pass():
sample_rate = 44100
f = 440
t = paddle.arange(0, 1, 1 / sample_rate)
sine_wave = paddle.sin(2 * np.pi * f * t)
window = AudioSignal.get_window("hann", sine_wave.shape[-1])
sine_wave = sine_wave * window
signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
out = signal.clone().low_pass(220)
assert out.audio_data.abs().max() < 1e-4
out = signal.clone().low_pass(880)
assert (out - signal).audio_data.abs().max() < 1e-3
batch = AudioSignal.batch([signal.clone(), signal.clone(), signal.clone()])
cutoffs = [220, 880, 220]
out = batch.clone().low_pass(cutoffs)
assert out.audio_data[0].abs().max() < 1e-4
assert out.audio_data[2].abs().max() < 1e-4
assert (out - batch).audio_data[1].abs().max() < 1e-3
def test_high_pass():
sample_rate = 44100
f = 440
t = paddle.arange(0, 1, 1 / sample_rate)
sine_wave = paddle.sin(2 * np.pi * f * t)
window = AudioSignal.get_window("hann", sine_wave.shape[-1])
sine_wave = sine_wave * window
signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
out = signal.clone().high_pass(220)
assert (signal - out).audio_data.abs().max() < 1e-4
def test_mask_frequencies():
sample_rate = 44100
fs = paddle.to_tensor([500.0, 2000.0, 8000.0, 32000.0])[None]
t = paddle.arange(0, 1, 1 / sample_rate)[:, None]
sine_wave = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1)
sine_wave = AudioSignal(sine_wave, sample_rate)
masked_sine_wave = sine_wave.mask_frequencies(fmin_hz=1500, fmax_hz=10000)
fs2 = paddle.to_tensor([500.0, 32000.0])[None]
sine_wave2 = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1)
sine_wave2 = AudioSignal(sine_wave2, sample_rate)
assert paddle.allclose(masked_sine_wave.audio_data, sine_wave2.audio_data)
def test_mask_timesteps():
sample_rate = 44100
f = 440
t = paddle.linspace(0, 1, sample_rate)
sine_wave = paddle.sin(2 * np.pi * f * t)
sine_wave = AudioSignal(sine_wave, sample_rate)
masked_sine_wave = sine_wave.mask_timesteps(tmin_s=0.25, tmax_s=0.75)
masked_sine_wave.istft()
mask = ((0.3 < t) & (t < 0.7))[None, None]
assert paddle.allclose(
masked_sine_wave.audio_data[mask],
paddle.zeros_like(masked_sine_wave.audio_data[mask]), )
def test_shift_phase():
sample_rate = 44100
f = 440
t = paddle.linspace(0, 1, sample_rate)
sine_wave = paddle.sin(2 * np.pi * f * t)
sine_wave = AudioSignal(sine_wave, sample_rate)
sine_wave2 = sine_wave.clone()
shifted_sine_wave = sine_wave.shift_phase(np.pi)
shifted_sine_wave.istft()
sine_wave2.phase = sine_wave2.phase + np.pi
sine_wave2.istft()
assert paddle.allclose(shifted_sine_wave.audio_data, sine_wave2.audio_data)
def test_corrupt_phase():
sample_rate = 44100
f = 440
t = paddle.linspace(0, 1, sample_rate)
sine_wave = paddle.sin(2 * np.pi * f * t)
sine_wave = AudioSignal(sine_wave, sample_rate)
sine_wave2 = sine_wave.clone()
shifted_sine_wave = sine_wave.corrupt_phase(scale=np.pi)
shifted_sine_wave.istft()
assert (sine_wave2.phase - shifted_sine_wave.phase).abs().mean() > 0.0
assert ((sine_wave2.phase - shifted_sine_wave.phase).std() / np.pi) < 1.0
def test_preemphasis():
x = AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5)
import matplotlib.pyplot as plt
x.specshow(preemphasis=False)
x.specshow(preemphasis=True)
x.preemphasis()

@ -0,0 +1,321 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_effects.py)
import sys
import numpy as np
import paddle
import pytest
from audio.audiotools import AudioSignal
def test_normalize():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=10)
signal = signal.normalize()
assert np.allclose(signal.loudness(), -24, atol=1e-1)
array = np.random.randn(1, 2, 32000)
array = array / np.abs(array).max()
signal = AudioSignal(array, sample_rate=16000)
for db_incr in np.arange(10, 75, 5):
db = -80 + db_incr
signal = signal.normalize(db)
loudness = signal.loudness()
assert np.allclose(loudness, db, atol=1) # TODO, atol=1e-1
batch_size = 16
db = -60 + paddle.linspace(10, 30, batch_size)
array = np.random.randn(batch_size, 2, 32000)
array = array / np.abs(array).max()
signal = AudioSignal(array, sample_rate=16000)
signal = signal.normalize(db)
assert np.allclose(signal.loudness(), db, 1e-1)
def test_volume_change():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=10)
boost = 3
before_db = signal.loudness().clone()
signal = signal.volume_change(boost)
after_db = signal.loudness()
assert np.allclose(before_db + boost, after_db)
signal._loudness = None
after_db = signal.loudness()
assert np.allclose(before_db + boost, after_db, 1e-1)
def test_mix():
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
nz = AudioSignal(audio_path, offset=10, duration=10)
spk.deepcopy().mix(nz, snr=-10)
snr = spk.loudness() - nz.loudness()
assert np.allclose(snr, -10, atol=1)
# Test in batch
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
nz = AudioSignal(audio_path, offset=10, duration=10)
batch_size = 4
tgt_snr = paddle.linspace(-10, 10, batch_size)
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
nz_batch = AudioSignal.batch([nz.deepcopy() for _ in range(batch_size)])
spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr)
snr = spk_batch.loudness() - nz_batch.loudness()
assert np.allclose(snr, tgt_snr, atol=1)
# Test with "EQing" the other signal
db = 0 + 0 * paddle.rand([10])
spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr, other_eq=db)
snr = spk_batch.loudness() - nz_batch.loudness()
assert np.allclose(snr, tgt_snr, atol=1)
def test_convolve():
np.random.seed(6) # Found a failing seed
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
impulse = np.zeros((1, 16000), dtype="float32")
impulse[..., 0] = 1
ir = AudioSignal(impulse, 16000)
batch_size = 4
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
ir_batch = AudioSignal.batch(
[
ir.deepcopy().zero_pad(np.random.randint(1000), 0)
for _ in range(batch_size)
],
pad_signals=True, )
convolved = spk_batch.deepcopy().convolve(ir_batch)
assert convolved == spk_batch
# Short duration
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=0.1)
impulse = np.zeros((1, 16000), dtype="float32")
impulse[..., 0] = 1
ir = AudioSignal(impulse, 16000)
batch_size = 4
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
ir_batch = AudioSignal.batch(
[
ir.deepcopy().zero_pad(np.random.randint(1000), 0)
for _ in range(batch_size)
],
pad_signals=True, )
convolved = spk_batch.deepcopy().convolve(ir_batch)
assert convolved == spk_batch
def test_pipeline():
# An actual IR, no batching
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=5)
audio_path = "./audio/ir/h179_Bar_1txts.wav"
ir = AudioSignal(audio_path)
spk.deepcopy().convolve(ir)
audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
nz = AudioSignal(audio_path, offset=10, duration=5)
batch_size = 16
tgt_snr = paddle.linspace(20, 30, batch_size)
(spk @ ir).mix(nz, snr=tgt_snr)
@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16])
def test_mel_filterbank(n_bands):
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=1)
fbank = spk.deepcopy().mel_filterbank(n_bands)
assert paddle.allclose(fbank.sum(-1), spk.audio_data, atol=1e-6)
# Check if it works in batches.
spk_batch = AudioSignal.batch([
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
fbank = spk_batch.deepcopy().mel_filterbank(n_bands)
summed = fbank.sum(-1)
assert paddle.allclose(summed, spk_batch.audio_data, atol=1e-6)
@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16])
def test_equalizer(n_bands):
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=10)
db = -3 + 1 * paddle.rand([n_bands])
spk.deepcopy().equalizer(db)
db = -3 + 1 * np.random.rand(n_bands)
spk.deepcopy().equalizer(db)
audio_path = "./audio/ir/h179_Bar_1txts.wav"
ir = AudioSignal(audio_path)
db = -3 + 1 * paddle.rand([n_bands])
spk.deepcopy().convolve(ir.equalizer(db))
spk_batch = AudioSignal.batch([
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
db = paddle.zeros([spk_batch.batch_size, n_bands])
output = spk_batch.deepcopy().equalizer(db)
assert output == spk_batch
def test_clip_distortion():
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
clipped = spk.deepcopy().clip_distortion(0.05)
spk_batch = AudioSignal.batch([
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
percs = paddle.to_tensor(np.random.uniform(size=(16, ))).astype("float32")
clipped_batch = spk_batch.deepcopy().clip_distortion(percs)
assert clipped.audio_data.abs().max() < 1.0
assert clipped_batch.audio_data.abs().max() < 1.0
@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128])
def test_quantization(quant_ch):
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
quantized = spk.deepcopy().quantization(quant_ch)
# Need to round audio_data off because torch ops with straight
# through estimator are sometimes a bit off past 3 decimal places.
found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3)))
assert found_quant_ch <= quant_ch
spk_batch = AudioSignal.batch([
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
quant_ch = np.random.choice(
[2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True)
quantized = spk_batch.deepcopy().quantization(quant_ch)
for i, q_ch in enumerate(quant_ch):
found_quant_ch = len(
np.unique(np.around(quantized.audio_data[i], decimals=3)))
assert found_quant_ch <= q_ch
@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128])
def test_mulaw_quantization(quant_ch):
audio_path = "./audio/spk/f10_script4_produced.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
quantized = spk.deepcopy().mulaw_quantization(quant_ch)
# Need to round audio_data off because torch ops with straight
# through estimator are sometimes a bit off past 3 decimal places.
found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3)))
assert found_quant_ch <= quant_ch
spk_batch = AudioSignal.batch([
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
for _ in range(16)
])
quant_ch = np.random.choice(
[2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True)
quantized = spk_batch.deepcopy().mulaw_quantization(quant_ch)
for i, q_ch in enumerate(quant_ch):
found_quant_ch = len(
np.unique(np.around(quantized.audio_data[i], decimals=3)))
assert found_quant_ch <= q_ch
def test_impulse_response_augmentation():
audio_path = "./audio/ir/h179_Bar_1txts.wav"
batch_size = 16
ir = AudioSignal(audio_path)
ir_batch = AudioSignal.batch([ir for _ in range(batch_size)])
early_response, late_field, window = ir_batch.decompose_ir()
assert early_response.shape == late_field.shape
assert late_field.shape == window.shape
drr = ir_batch.measure_drr()
alpha = AudioSignal.solve_alpha(early_response, late_field, window, drr)
assert np.allclose(alpha, np.ones_like(alpha), 1e-5)
target_drr = 5
out = ir_batch.deepcopy().alter_drr(target_drr)
drr = out.measure_drr()
assert np.allclose(drr, np.ones_like(drr) * target_drr)
target_drr = np.random.rand(batch_size).astype("float32") * 50
altered_ir = ir_batch.deepcopy().alter_drr(target_drr)
drr = altered_ir.measure_drr()
assert np.allclose(drr.flatten(), target_drr.flatten())
def test_apply_ir():
audio_path = "./audio/spk/f10_script4_produced.wav"
ir_path = "./audio/ir/h179_Bar_1txts.wav"
spk = AudioSignal(audio_path, offset=10, duration=2)
ir = AudioSignal(ir_path)
db = 0 + 0 * paddle.rand([10])
output = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=db)
assert np.allclose(ir.measure_drr().flatten(), 10)
output = spk.deepcopy().apply_ir(
ir, drr=10, ir_eq=db, use_original_phase=True)
def test_ensure_max_of_audio():
spk = AudioSignal(paddle.randn([1, 1, 44100]), 44100)
max_vals = [1.0] + [np.random.rand() for _ in range(10)]
for val in max_vals:
after = spk.deepcopy().ensure_max_of_audio(val)
assert after.audio_data.abs().max() <= val + 1e-3
# Make sure it does nothing to a tiny signal
spk = AudioSignal(paddle.rand([1, 1, 44100]), 44100)
spk.audio_data = spk.audio_data * 0.5
after = spk.deepcopy().ensure_max_of_audio()
assert paddle.allclose(after.audio_data, spk.audio_data)

@ -0,0 +1,85 @@
# MIT License, Copyright (c) 2020 Alexandre Défossez.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_fftconv.py)
import random
import sys
import unittest
import paddle
import paddle.nn.functional as F
from audio.audiotools.core import fft_conv1d
from audio.audiotools.core import FFTConv1D
TOLERANCE = 1e-4 # as relative delta in percentage
class _BaseTest(unittest.TestCase):
def setUp(self):
paddle.seed(1234)
random.seed(1234)
def assertSimilar(self, a, b, msg=None, tol=TOLERANCE):
delta = 100 * paddle.norm(a - b, p=2) / paddle.norm(b, p=2)
self.assertLessEqual(delta.numpy(), tol, msg)
def compare_paddle(self, *args, msg=None, tol=TOLERANCE, **kwargs):
y_ref = F.conv1d(*args, **kwargs)
y = fft_conv1d(*args, **kwargs)
self.assertEqual(list(y.shape), list(y_ref.shape), msg)
self.assertSimilar(y, y_ref, msg, tol)
class TestFFTConv1d(_BaseTest):
def test_same_as_paddle(self):
for _ in range(5):
kernel_size = random.randrange(4, 128)
batch_size = random.randrange(1, 6)
length = random.randrange(kernel_size, 1024)
chin = random.randrange(1, 12)
chout = random.randrange(1, 12)
bias = random.random() < 0.5
if random.random() < 0.5:
padding = 0
else:
padding = random.randrange(kernel_size // 2, 2 * kernel_size)
x = paddle.randn([batch_size, chin, length])
w = paddle.randn([chout, chin, kernel_size])
keys = ["length", "kernel_size", "chin", "chout", "bias"]
loc = locals()
state = {key: loc[key] for key in keys}
if bias:
bias = paddle.randn([chout])
else:
bias = None
for stride in [1, 2, 5]:
state["stride"] = stride
self.compare_paddle(
x, w, bias, stride, padding, msg=repr(state))
def test_small_input(self):
x = paddle.randn([1, 5, 19])
w = paddle.randn([10, 5, 32])
with self.assertRaises(RuntimeError):
fft_conv1d(x, w)
x = paddle.randn([1, 5, 19])
w = paddle.randn([10, 5, 19])
self.assertEqual(list(fft_conv1d(x, w).shape), [1, 10, 1])
def test_module(self):
x = paddle.randn([16, 4, 1024])
mod = FFTConv1D(4, 5, 8, bias_attr=True)
mod(x)
mod = FFTConv1D(4, 5, 8, bias_attr=False)
mod(x)
def test_dynamic_graph(self):
x = paddle.randn([16, 4, 1024])
mod = FFTConv1D(4, 5, 8, bias_attr=True)
self.assertEqual(list(mod(x).shape), [16, 5, 1024 - 8 + 1])
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,172 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_grad.py)
import sys
from typing import Callable
import numpy as np
import paddle
import pytest
from audio.audiotools import AudioSignal
def test_audio_grad():
audio_path = "./audio/spk/f10_script4_produced.wav"
ir_path = "./audio/ir/h179_Bar_1txts.wav"
def _test_audio_grad(attr: str, target=True, kwargs: dict={}):
signal = AudioSignal(audio_path)
signal.audio_data.stop_gradient = False
assert signal.audio_data.grad is None
# Avoid overwriting leaf tensor by cloning signal
attr = getattr(signal.clone(), attr)
result = attr(**kwargs) if isinstance(attr, Callable) else attr
try:
if isinstance(result, AudioSignal):
# If necessary, propagate spectrogram changes to waveform
if result.stft_data is not None:
result.istft()
# if result.audio_data.dtype.is_complex:
if paddle.is_complex(result.audio_data):
result.audio_data.real.sum().backward()
else:
result.audio_data.sum().backward()
else:
# if result.dtype.is_complex:
if paddle.is_complex(result):
result.real().sum().backward()
else:
result.sum().backward()
assert signal.audio_data.grad is not None or not target
except RuntimeError:
assert not target
for a in [
["mix", True, {
"other": AudioSignal(audio_path),
"snr": 0
}],
["convolve", True, {
"other": AudioSignal(ir_path)
}],
[
"apply_ir",
True,
{
"ir": AudioSignal(ir_path),
"drr": 0.1,
"ir_eq": paddle.randn([6])
},
],
["ensure_max_of_audio", True],
["normalize", True],
["volume_change", True, {
"db": 1
}],
# ["pitch_shift", False, {"n_semitones": 1}],
# ["time_stretch", False, {"factor": 2}],
# ["apply_codec", False],
["equalizer", True, {
"db": paddle.randn([6])
}],
["clip_distortion", True, {
"clip_percentile": 0.5
}],
["quantization", True, {
"quantization_channels": 8
}],
["mulaw_quantization", True, {
"quantization_channels": 8
}],
["resample", True, {
"sample_rate": 16000
}],
["low_pass", True, {
"cutoffs": 1000
}],
["high_pass", True, {
"cutoffs": 1000
}],
["to_mono", True],
["zero_pad", True, {
"before": 10,
"after": 10
}],
["magnitude", True],
["phase", True],
["log_magnitude", True],
["loudness", False],
["stft", True],
["clone", True],
["mel_spectrogram", True],
["zero_pad_to", True, {
"length": 100000
}],
["truncate_samples", True, {
"length_in_samples": 1000
}],
["corrupt_phase", True, {
"scale": 0.5
}],
["shift_phase", True, {
"shift": 1
}],
["mask_low_magnitudes", True, {
"db_cutoff": 0
}],
["mask_frequencies", True, {
"fmin_hz": 100,
"fmax_hz": 1000
}],
["mask_timesteps", True, {
"tmin_s": 0.1,
"tmax_s": 0.5
}],
["__add__", True, {
"other": AudioSignal(audio_path)
}],
["__iadd__", True, {
"other": AudioSignal(audio_path)
}],
["__radd__", True, {
"other": AudioSignal(audio_path)
}],
["__sub__", True, {
"other": AudioSignal(audio_path)
}],
["__isub__", True, {
"other": AudioSignal(audio_path)
}],
["__mul__", True, {
"other": AudioSignal(audio_path)
}],
["__imul__", True, {
"other": AudioSignal(audio_path)
}],
["__rmul__", True, {
"other": AudioSignal(audio_path)
}],
]:
_test_audio_grad(*a)
def test_batch_grad():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path)
signal.audio_data.stop_gradient = False
assert signal.audio_data.grad is None
batch_size = 16
batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
batch.audio_data.sum().backward()
assert signal.audio_data.grad is not None

@ -0,0 +1,104 @@
# MIT License, Copyright (c) 2020 Alexandre Défossez.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_filters.py)
import math
import random
import sys
import unittest
import paddle
from audio.audiotools.core import highpass_filter
from audio.audiotools.core import highpass_filters
def pure_tone(freq: float, sr: float=128, dur: float=4, device=None):
"""
Return a pure tone, i.e. cosine.
Args:
freq (float): frequency (in Hz)
sr (float): sample rate (in Hz)
dur (float): duration (in seconds)
"""
time = paddle.arange(int(sr * dur), dtype="float32") / sr
return paddle.cos(2 * math.pi * freq * time)
def delta(a, b, ref, fraction=0.9):
length = a.shape[-1]
compare_length = int(length * fraction)
offset = (length - compare_length) // 2
a = a[..., offset:offset + length]
b = b[..., offset:offset + length]
# 计算绝对差值均值然后除以ref的标准差乘以100
return 100 * paddle.mean(paddle.abs(a - b)) / paddle.std(ref)
TOLERANCE = 1 # Tolerance to errors as percentage of the std of the input signal
class _BaseTest(unittest.TestCase):
def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE):
self.assertLessEqual(delta(a, b, ref), tol, msg)
class TestHighPassFilters(_BaseTest):
def setUp(self):
paddle.seed(1234)
random.seed(1234)
def test_keep_or_kill(self):
for _ in range(10):
freq = random.uniform(0.01, 0.4)
sr = 1024
tone = pure_tone(freq * sr, sr=sr, dur=10)
# For this test we accept 5% tolerance in amplitude, or -26dB in power.
tol = 5
zeros = 16
# If cutoff frequency is under freq, output should be input
y_pass = highpass_filter(tone, 0.9 * freq, zeros=zeros)
self.assertSimilar(
y_pass, tone, tone, f"freq={freq}, pass", tol=tol)
# If cutoff frequency is over freq, output should be zero
y_killed = highpass_filter(tone, 1.1 * freq, zeros=zeros)
self.assertSimilar(
y_killed, 0 * tone, tone, f"freq={freq}, kill", tol=tol)
def test_fft_nofft(self):
for _ in range(10):
x = paddle.randn([1024])
freq = random.uniform(0.01, 0.5)
y_fft = highpass_filter(x, freq, fft=True)
y_ref = highpass_filter(x, freq, fft=False)
self.assertSimilar(y_fft, y_ref, x, f"freq={freq}", tol=0.01)
def test_constant(self):
x = paddle.ones([2048])
for zeros in [4, 10]:
for freq in [0.01, 0.1]:
y_high = highpass_filter(x, freq, zeros=zeros)
self.assertLessEqual(y_high.abs().mean(), 1e-6, (zeros, freq))
def test_stride(self):
x = paddle.randn([1024])
y = highpass_filters(x, [0.1, 0.2], stride=1)[:, ::3]
y2 = highpass_filters(x, [0.1, 0.2], stride=3)
self.assertEqual(y.shape, y2.shape)
self.assertSimilar(y, y2, x)
y = highpass_filters(x, [0.1, 0.2], stride=1, pad=False)[:, ::3]
y2 = highpass_filters(x, [0.1, 0.2], stride=3, pad=False)
self.assertEqual(y.shape, y2.shape)
self.assertSimilar(y, y2, x)
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,274 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_loudness.py)
import sys
import numpy as np
import pyloudnorm
import soundfile as sf
from audio.audiotools import AudioSignal
from audio.audiotools import datasets
from audio.audiotools import Meter
from audio.audiotools import transforms
ATOL = 1e-1
def test_loudness_against_pyln():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=5, duration=10)
signal_loudness = signal.loudness()
meter = pyloudnorm.Meter(
signal.sample_rate, filter_class="K-weighting", block_size=0.4)
py_loudness = meter.integrated_loudness(signal.numpy()[0].T)
assert np.allclose(signal_loudness, py_loudness)
def test_loudness_short():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=0.25)
signal_loudness = signal.loudness()
def test_batch_loudness():
np.random.seed(0)
array = np.random.randn(16, 2, 16000)
array /= np.abs(array).max()
gains = np.random.rand(array.shape[0])[:, None, None]
array = array * gains
meter = pyloudnorm.Meter(16000)
py_loudness = [
meter.integrated_loudness(array[i].T) for i in range(array.shape[0])
]
meter = Meter(16000)
meter.filter_class
at_loudness_iso = [
meter.integrated_loudness(array[i].T).item()
for i in range(array.shape[0])
]
assert np.allclose(py_loudness, at_loudness_iso, atol=1e-1)
signal = AudioSignal(array, sample_rate=16000)
at_loudness_batch = signal.loudness()
assert np.allclose(py_loudness, at_loudness_batch, atol=1e-1)
# Tests below are copied from pyloudnorm
def test_integrated_loudness():
data, rate = sf.read("./audio/loudness/sine_1000.wav")
meter = Meter(rate)
loudness = meter(data)
targetLoudness = -3.0523438444331137
assert np.allclose(loudness, targetLoudness)
def test_rel_gate_test():
data, rate = sf.read("./audio/loudness/1770-2_Comp_RelGateTest.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -10.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_abs_gate_test():
data, rate = sf.read("./audio/loudness/1770-2_Comp_AbsGateTest.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -69.5
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_24LKFS_25Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_24LKFS_100Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_24LKFS_500Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_24LKFS_1000Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_24LKFS_2000Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_24LKFS_10000Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_23LKFS_25Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_23LKFS_100Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_23LKFS_500Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_23LKFS_1000Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_23LKFS_2000Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_23LKFS_10000Hz_2ch():
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_18LKFS_frequency_sweep():
data, rate = sf.read(
"./audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -18.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_conf_stereo_vinL_R_23LKFS():
data, rate = sf.read(
"./audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_conf_monovoice_music_24LKFS():
data, rate = sf.read(
"./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def conf_monovoice_music_24LKFS():
data, rate = sf.read(
"./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -24.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_conf_monovoice_music_23LKFS():
data, rate = sf.read(
"./audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav")
meter = Meter(rate)
loudness = meter.integrated_loudness(data)
targetLoudness = -23.0
assert np.allclose(loudness, targetLoudness, atol=ATOL)
def test_fir_accuracy():
transform = transforms.Compose(
transforms.ClippingDistortion(prob=0.5),
transforms.LowPass(prob=0.5),
transforms.HighPass(prob=0.5),
transforms.Equalizer(prob=0.5),
prob=0.5, )
loader = datasets.AudioLoader(sources=["./audio/spk.csv"])
dataset = datasets.AudioDataset(
loader,
44100,
10,
5.0,
transform=transform, )
for i in range(20):
item = dataset[i]
kwargs = item["transform_args"]
signal = item["signal"]
signal = transform(signal, **kwargs)
signal._loudness = None
iir_db = signal.clone().loudness()
fir_db = signal.clone().loudness(use_fir=True)
assert np.allclose(iir_db, fir_db, atol=1e-2)

@ -0,0 +1,109 @@
# MIT License, Copyright (c) 2020 Alexandre Défossez.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_lowpass.py)
import math
import random
import sys
import unittest
import numpy as np
import paddle
from audio.audiotools.core import lowpass_filter
from audio.audiotools.core import LowPassFilter
from audio.audiotools.core import LowPassFilters
from audio.audiotools.core import resample_frac
def pure_tone(freq: float, sr: float=128, dur: float=4, device=None):
"""
Return a pure tone, i.e. cosine.
Args:
freq (float): frequency (in Hz)
sr (float): sample rate (in Hz)
dur (float): duration (in seconds)
"""
time = paddle.arange(int(sr * dur), dtype="float32") / sr
return paddle.cos(2 * math.pi * freq * time)
def delta(a, b, ref, fraction=0.9):
length = a.shape[-1]
compare_length = int(length * fraction)
offset = (length - compare_length) // 2
a = a[..., offset:offset + length]
b = b[..., offset:offset + length]
# 计算绝对差值均值然后除以ref的标准差乘以100
return 100 * paddle.mean(paddle.abs(a - b)) / paddle.std(ref)
TOLERANCE = 1 # Tolerance to errors as percentage of the std of the input signal
class _BaseTest(unittest.TestCase):
def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE):
self.assertLessEqual(delta(a, b, ref), tol, msg)
class TestLowPassFilters(_BaseTest):
def setUp(self):
paddle.seed(1234)
random.seed(1234)
def test_keep_or_kill(self):
for _ in range(10):
freq = random.uniform(0.01, 0.4)
sr = 1024
tone = pure_tone(freq * sr, sr=sr, dur=10)
# For this test we accept 5% tolerance in amplitude, or -26dB in power.
tol = 5
zeros = 16
# If cutoff frequency is under freq, output should be zero
y_killed = lowpass_filter(tone, 0.9 * freq, zeros=zeros)
self.assertSimilar(
y_killed, 0 * y_killed, tone, f"freq={freq}, kill", tol=tol)
# If cutoff frequency is under freq, output should be input
y_pass = lowpass_filter(tone, 1.1 * freq, zeros=zeros)
self.assertSimilar(
y_pass, tone, tone, f"freq={freq}, pass", tol=tol)
def test_same_as_downsample(self):
for _ in range(10):
x = paddle.randn([2 * 3 * 4 * 100])
x = paddle.ones_like(x)
np.random.seed(1234)
x = paddle.to_tensor(
np.random.randn(2 * 3 * 4 * 100), dtype="float32")
rolloff = 0.945
for old_sr in [2, 3, 4]:
y_resampled = resample_frac(
x, old_sr, 1, rolloff=rolloff, zeros=16)
y_lowpass = lowpass_filter(
x, rolloff / old_sr / 2, stride=old_sr, zeros=16)
self.assertSimilar(y_resampled, y_lowpass, x,
f"old_sr={old_sr}")
def test_fft_nofft(self):
for _ in range(10):
x = paddle.randn([1024])
freq = random.uniform(0.01, 0.5)
y_fft = lowpass_filter(x, freq, fft=True)
y_ref = lowpass_filter(x, freq, fft=False)
self.assertSimilar(y_fft, y_ref, x, f"freq={freq}", tol=0.01)
def test_constant(self):
x = paddle.ones([2048])
for zeros in [4, 10]:
for freq in [0.01, 0.1]:
y_low = lowpass_filter(x, freq, zeros=zeros)
self.assertLessEqual((y_low - 1).abs().mean(), 1e-6,
(zeros, freq))
if __name__ == "__main__":
unittest.main()

@ -0,0 +1,157 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_util.py)
import os
import random
import sys
import tempfile
import numpy as np
import paddle
import pytest
from audio.audiotools import util
from audio.audiotools.core.audio_signal import AudioSignal
from paddlespeech.vector.training.seeding import seed_everything
def test_check_random_state():
# seed is None
rng_type = type(np.random.RandomState(10))
rng = util.random_state(None)
assert type(rng) == rng_type
# seed is int
rng = util.random_state(10)
assert type(rng) == rng_type
# seed is RandomState
rng_test = np.random.RandomState(10)
rng = util.random_state(rng_test)
assert type(rng) == rng_type
# seed is none of the above : error
pytest.raises(ValueError, util.random_state, "random")
def test_seed():
seed_everything(0)
paddle_result_a = paddle.randn([1])
np_result_a = np.random.randn(1)
py_result_a = random.random()
seed_everything(0)
paddle_result_b = paddle.randn([1])
np_result_b = np.random.randn(1)
py_result_b = random.random()
assert paddle_result_a == paddle_result_b
assert np_result_a == np_result_b
assert py_result_a == py_result_b
def test_hz_to_bin():
hz = paddle.to_tensor(np.array([100, 200, 300]), dtype="float32")
sr = 1000
n_fft = 2048
bins = util.hz_to_bin(hz, n_fft, sr)
assert (((bins / n_fft) * sr) - hz).abs().max() < 1
def test_find_audio():
wav_files = util.find_audio("tests/", ["wav"])
for a in wav_files:
assert "wav" in str(a)
audio_files = util.find_audio("tests/", ["flac"])
assert not audio_files
# Make sure it works with single audio files
audio_files = util.find_audio("./audio/spk//f10_script4_produced.wav")
# Make sure it works with globs
audio_files = util.find_audio("tests/**/*.wav")
assert len(audio_files) == len(wav_files)
def test_chdir():
with tempfile.TemporaryDirectory(suffix="tmp") as d:
with util.chdir(d):
assert os.path.samefile(d, os.path.realpath("."))
def test_prepare_batch():
batch = {"tensor": paddle.randn([1]), "non_tensor": np.random.randn(1)}
util.prepare_batch(batch)
batch = paddle.randn([1])
util.prepare_batch(batch)
batch = [paddle.randn([1]), np.random.randn(1)]
util.prepare_batch(batch)
def test_sample_dist():
state = util.random_state(0)
v1 = state.uniform(0.0, 1.0)
v2 = util.sample_from_dist(("uniform", 0.0, 1.0), 0)
assert v1 == v2
assert util.sample_from_dist(("const", 1.0)) == 1.0
dist_tuple = ("choice", [8, 16, 32])
assert util.sample_from_dist(dist_tuple) in [8, 16, 32]
def test_collate():
batch_size = 16
def _one_item():
return {
"signal": AudioSignal(paddle.randn([1, 1, 44100]), 44100),
"tensor": paddle.randn([1]),
"string": "Testing",
"dict": {
"nested_signal":
AudioSignal(paddle.randn([1, 1, 44100]), 44100),
},
}
items = [_one_item() for _ in range(batch_size)]
collated = util.collate(items)
assert collated["signal"].batch_size == batch_size
assert collated["tensor"].shape[0] == batch_size
assert len(collated["string"]) == batch_size
assert collated["dict"]["nested_signal"].batch_size == batch_size
# test collate with splitting (evenly)
batch_size = 16
n_splits = 4
items = [_one_item() for _ in range(batch_size)]
collated = util.collate(items, n_splits=n_splits)
for x in collated:
assert x["signal"].batch_size == batch_size // n_splits
assert x["tensor"].shape[0] == batch_size // n_splits
assert len(x["string"]) == batch_size // n_splits
assert x["dict"]["nested_signal"].batch_size == batch_size // n_splits
# test collate with splitting (unevenly)
batch_size = 15
n_splits = 4
items = [_one_item() for _ in range(batch_size)]
collated = util.collate(items, n_splits=n_splits)
tlen = [4, 4, 4, 3]
for x, t in zip(collated, tlen):
assert x["signal"].batch_size == t
assert x["tensor"].shape[0] == t
assert len(x["string"]) == t
assert x["dict"]["nested_signal"].batch_size == t

@ -0,0 +1,208 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_datasets.py)
import sys
import tempfile
from pathlib import Path
import numpy as np
import paddle
import pytest
from audio import audiotools
from audio.audiotools.data import transforms as tfm
def test_align_lists():
input_lists = [
["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"],
["a/2.wav", "c/2.wav"],
["c/3.wav"],
]
target_lists = [
["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"],
["a/2.wav", "none", "c/2.wav", "none"],
["none", "none", "c/3.wav", "none"],
]
def _preprocess(lists):
output = []
for x in lists:
output.append([])
for y in x:
output[-1].append({"path": y})
return output
input_lists = _preprocess(input_lists)
target_lists = _preprocess(target_lists)
aligned_lists = audiotools.datasets.align_lists(input_lists)
assert target_lists == aligned_lists
def test_audio_dataset():
transform = tfm.Compose(
[
tfm.VolumeNorm(),
tfm.Silence(prob=0.5),
], )
loader = audiotools.data.datasets.AudioLoader(
sources=["./audio/spk.csv"],
transform=transform, )
dataset = audiotools.data.datasets.AudioDataset(
loader,
44100,
n_examples=100,
transform=transform, )
dataloader = paddle.io.DataLoader(
dataset,
batch_size=16,
num_workers=0,
collate_fn=dataset.collate, )
for batch in dataloader:
kwargs = batch["transform_args"]
signal = batch["signal"]
original = signal.clone()
signal = dataset.transform(signal, **kwargs)
original = dataset.transform(original, **kwargs)
mask = kwargs["Compose"]["1.Silence"]["mask"]
zeros_ = paddle.zeros_like(signal[mask].audio_data)
original_ = original[~mask].audio_data
assert paddle.allclose(signal[mask].audio_data, zeros_)
assert paddle.allclose(signal[~mask].audio_data, original_)
def test_aligned_audio_dataset():
with tempfile.TemporaryDirectory() as d:
dataset_dir = Path(d)
audiotools.util.generate_chord_dataset(
max_voices=8, num_items=3, output_dir=dataset_dir)
loaders = [
audiotools.data.datasets.AudioLoader([dataset_dir / f"track_{i}"])
for i in range(3)
]
dataset = audiotools.data.datasets.AudioDataset(
loaders, 44100, n_examples=1000, aligned=True, shuffle_loaders=True)
dataloader = paddle.io.DataLoader(
dataset,
batch_size=16,
num_workers=0,
collate_fn=dataset.collate, )
# Make sure the voice tracks are aligned.
for batch in dataloader:
paths = []
for i in range(len(loaders)):
_paths = [p.split("/")[-1] for p in batch[i]["path"]]
paths.append(_paths)
paths = np.array(paths)
for i in range(paths.shape[1]):
col = paths[:, i]
col = col[col != "none"]
assert np.all(col == col[0])
def test_loader_without_replacement():
with tempfile.TemporaryDirectory() as d:
dataset_dir = Path(d)
num_items = 100
audiotools.util.generate_chord_dataset(
max_voices=1,
num_items=num_items,
output_dir=dataset_dir,
duration=0.01, )
loader = audiotools.data.datasets.AudioLoader(
[dataset_dir], shuffle=False)
dataset = audiotools.data.datasets.AudioDataset(loader, 44100)
for idx in range(num_items):
item = dataset[idx]
assert item["item_idx"] == idx
def test_loader_with_replacement():
with tempfile.TemporaryDirectory() as d:
dataset_dir = Path(d)
num_items = 100
audiotools.util.generate_chord_dataset(
max_voices=1,
num_items=num_items,
output_dir=dataset_dir,
duration=0.01, )
loader = audiotools.data.datasets.AudioLoader([dataset_dir])
dataset = audiotools.data.datasets.AudioDataset(
loader, 44100, without_replacement=False)
for idx in range(num_items):
item = dataset[idx]
def test_loader_out_of_range():
with tempfile.TemporaryDirectory() as d:
dataset_dir = Path(d)
num_items = 100
audiotools.util.generate_chord_dataset(
max_voices=1,
num_items=num_items,
output_dir=dataset_dir,
duration=0.01, )
loader = audiotools.data.datasets.AudioLoader([dataset_dir])
item = loader(
sample_rate=44100,
duration=0.01,
state=audiotools.util.random_state(0),
source_idx=0,
item_idx=101, )
assert item["path"] == "none"
def test_dataset_pipeline():
transform = tfm.Compose([
tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]),
tfm.BackgroundNoise(sources=["./audio/noises.csv"]),
])
loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"])
dataset = audiotools.data.datasets.AudioDataset(
loader,
44100,
n_examples=10,
transform=transform, )
dataloader = paddle.io.DataLoader(
dataset, num_workers=0, batch_size=1, collate_fn=dataset.collate)
for batch in dataloader:
batch = audiotools.core.util.prepare_batch(batch, device="cpu")
kwargs = batch["transform_args"]
signal = batch["signal"]
batch = dataset.transform(signal, **kwargs)
class NumberDataset:
def __init__(self):
pass
def __len__(self):
return 10
def __getitem__(self, idx):
return {"idx": idx}
def test_concat_dataset():
d1 = NumberDataset()
d2 = NumberDataset()
d3 = NumberDataset()
d = audiotools.datasets.ConcatDataset([d1, d2, d3])
x = d.collate([d[i] for i in range(len(d))])["idx"].tolist()
t = []
for i in range(10):
t += [i, i, i]
assert x == t

@ -0,0 +1,33 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_preprocess.py)
import sys
import tempfile
from pathlib import Path
import paddle
from audio.audiotools.core.util import find_audio
from audio.audiotools.core.util import read_sources
from audio.audiotools.data import preprocess
def test_create_csv():
with tempfile.NamedTemporaryFile(suffix=".csv") as f:
preprocess.create_csv(
find_audio("././audio/spk", ext=["wav"]), f.name, loudness=True)
def test_create_csv_with_empty_rows():
audio_files = find_audio("././audio/spk", ext=["wav"])
audio_files.insert(0, "")
audio_files.insert(2, "")
with tempfile.NamedTemporaryFile(suffix=".csv") as f:
preprocess.create_csv(audio_files, f.name, loudness=True)
audio_files = read_sources([f.name], remove_empty=True)
assert len(audio_files[0]) == 1
audio_files = read_sources([f.name], remove_empty=False)
assert len(audio_files[0]) == 3

@ -0,0 +1,453 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_transforms.py)
import inspect
import sys
import warnings
from pathlib import Path
import numpy as np
import paddle
import pytest
from audio import audiotools
from audio.audiotools import AudioSignal
from audio.audiotools import util
from audio.audiotools.data import transforms as tfm
from audio.audiotools.data.datasets import AudioDataset
from paddlespeech.vector.training.seeding import seed_everything
non_deterministic_transforms = ["TimeNoise", "FrequencyNoise"]
transforms_to_test = []
for x in dir(tfm):
if hasattr(getattr(tfm, x), "transform"):
if x not in [
"Compose",
"Choose",
"Repeat",
"RepeatUpTo",
# The above 4 transforms are currently excluded from testing at 1e-4 precision due to potential accuracy issues
"BackgroundNoise",
"Equalizer",
"FrequencyNoise",
"RoomImpulseResponse"
]:
transforms_to_test.append(x)
def _compare_transform(transform_name, signal):
regression_data = Path(f"regression/transforms/{transform_name}.wav")
regression_data.parent.mkdir(exist_ok=True, parents=True)
if regression_data.exists():
regression_signal = AudioSignal(regression_data)
assert paddle.allclose(
signal.audio_data, regression_signal.audio_data, atol=1e-4)
else:
signal.write(regression_data)
@pytest.mark.parametrize("transform_name", transforms_to_test)
def test_transform(transform_name):
seed = 0
seed_everything(seed)
transform_cls = getattr(tfm, transform_name)
kwargs = {}
if transform_name == "BackgroundNoise":
kwargs["sources"] = ["./audio/noises.csv"]
if transform_name == "RoomImpulseResponse":
kwargs["sources"] = ["./audio/irs.csv"]
if transform_name == "CrossTalk":
kwargs["sources"] = ["./audio/spk.csv"]
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
signal.metadata["loudness"] = AudioSignal(
audio_path).ffmpeg_loudness().item()
transform = transform_cls(prob=1.0, **kwargs)
kwargs = transform.instantiate(seed, signal)
for k in kwargs[transform_name]:
assert k in transform.keys
output = transform(signal, **kwargs)
assert isinstance(output, AudioSignal)
_compare_transform(transform_name, output)
if transform_name in non_deterministic_transforms:
return
# Test that if you make a batch of signals and call it,
# the first item in the batch is still the same as above.
batch_size = 4
signal = AudioSignal(audio_path, offset=10, duration=2)
signal_batch = AudioSignal.batch(
[signal.clone() for _ in range(batch_size)])
signal_batch.metadata["loudness"] = AudioSignal(
audio_path).ffmpeg_loudness().item()
states = [seed + idx for idx in list(range(batch_size))]
kwargs = transform.batch_instantiate(states, signal_batch)
batch_output = transform(signal_batch, **kwargs)
assert batch_output[0] == output
## Test that you can apply transform with the same args twice.
signal = AudioSignal(audio_path, offset=10, duration=2)
signal.metadata["loudness"] = AudioSignal(
audio_path).ffmpeg_loudness().item()
kwargs = transform.instantiate(seed, signal)
output_a = transform(signal.clone(), **kwargs)
output_b = transform(signal.clone(), **kwargs)
assert output_a == output_b
def test_compose_basic():
seed = 0
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
transform = tfm.Compose(
[
tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]),
tfm.BackgroundNoise(sources=["./audio/noises.csv"]),
], )
kwargs = transform.instantiate(seed, signal)
output = transform(signal, **kwargs)
# Due to precision issues with RoomImpulseResponse and BackgroundNoise used in the Compose test,
# we only perform logical testing for Compose and skip precision testing of the final output
# _compare_transform("Compose", output)
assert isinstance(transform[0], tfm.RoomImpulseResponse)
assert isinstance(transform[1], tfm.BackgroundNoise)
assert len(transform) == 2
# Make sure __iter__ works
for _tfm in transform:
pass
class MulTransform(tfm.BaseTransform):
def __init__(self, num, name=None):
self.num = num
super().__init__(name=name, keys=["num"])
def _transform(self, signal, num):
if not num.dim():
num = num.unsqueeze(axis=0)
signal.audio_data = signal.audio_data * num[:, None, None]
return signal
def _instantiate(self, state):
return {"num": self.num}
def test_compose_with_duplicate_transforms():
muls = [0.5, 0.25, 0.125]
transform = tfm.Compose([MulTransform(x) for x in muls])
full_mul = np.prod(muls)
kwargs = transform.instantiate(0)
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
output = transform(signal.clone(), **kwargs)
expected_output = signal.audio_data * full_mul
assert paddle.allclose(output.audio_data, expected_output)
def test_nested_compose():
muls = [0.5, 0.25, 0.125]
transform = tfm.Compose([
MulTransform(muls[0]),
tfm.Compose(
[MulTransform(muls[1]), tfm.Compose([MulTransform(muls[2])])]),
])
full_mul = np.prod(muls)
kwargs = transform.instantiate(0)
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
output = transform(signal.clone(), **kwargs)
expected_output = signal.audio_data * full_mul
assert paddle.allclose(output.audio_data, expected_output)
def test_compose_filtering():
muls = [0.5, 0.25, 0.125]
transform = tfm.Compose([MulTransform(x, name=str(x)) for x in muls])
kwargs = transform.instantiate(0)
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
for s in range(len(muls)):
for _ in range(10):
_muls = np.random.choice(muls, size=s, replace=False).tolist()
full_mul = np.prod(_muls)
with transform.filter(*[str(x) for x in _muls]):
output = transform(signal.clone(), **kwargs)
expected_output = signal.audio_data * full_mul
assert paddle.allclose(output.audio_data, expected_output)
def test_sequential_compose():
muls = [0.5, 0.25, 0.125]
transform = tfm.Compose([
tfm.Compose([MulTransform(muls[0])]),
tfm.Compose([MulTransform(muls[1]), MulTransform(muls[2])]),
])
full_mul = np.prod(muls)
kwargs = transform.instantiate(0)
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
output = transform(signal.clone(), **kwargs)
expected_output = signal.audio_data * full_mul
assert paddle.allclose(output.audio_data, expected_output)
def test_choose_basic():
seed = 0
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
transform = tfm.Choose([
tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]),
tfm.BackgroundNoise(sources=["./audio/noises.csv"]),
])
kwargs = transform.instantiate(seed, signal)
output = transform(signal.clone(), **kwargs)
# Due to precision issues with RoomImpulseResponse and BackgroundNoise used in the Choose test,
# we only perform logical testing for Choose and skip precision testing of the final output
# _compare_transform("Choose", output)
transform = tfm.Choose([
MulTransform(0.0),
MulTransform(2.0),
])
targets = [signal.clone() * 0.0, signal.clone() * 2.0]
for seed in range(10):
kwargs = transform.instantiate(seed, signal)
output = transform(signal.clone(), **kwargs)
assert any([output == target for target in targets])
# Test that if you make a batch of signals and call it,
# the first item in the batch is still the same as above.
batch_size = 4
signal = AudioSignal(audio_path, offset=10, duration=2)
signal_batch = AudioSignal.batch(
[signal.clone() for _ in range(batch_size)])
states = [seed + idx for idx in list(range(batch_size))]
kwargs = transform.batch_instantiate(states, signal_batch)
batch_output = transform(signal_batch, **kwargs)
for nb in range(batch_size):
assert batch_output[nb] in targets
def test_choose_weighted():
seed = 0
audio_path = "./audio/spk/f10_script4_produced.wav"
transform = tfm.Choose(
[
MulTransform(0.0),
MulTransform(2.0),
],
weights=[0.0, 1.0], )
# Test that if you make a batch of signals and call it,
# the first item in the batch is still the same as above.
batch_size = 4
signal = AudioSignal(audio_path, offset=10, duration=2)
signal_batch = AudioSignal.batch(
[signal.clone() for _ in range(batch_size)])
targets = [signal.clone() * 0.0, signal.clone() * 2.0]
states = [seed + idx for idx in list(range(batch_size))]
kwargs = transform.batch_instantiate(states, signal_batch)
batch_output = transform(signal_batch, **kwargs)
for nb in range(batch_size):
assert batch_output[nb] == targets[1]
def test_choose_with_compose():
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
transform = tfm.Choose([
tfm.Compose([MulTransform(0.0)]),
tfm.Compose([MulTransform(2.0)]),
])
targets = [signal.clone() * 0.0, signal.clone() * 2.0]
for seed in range(10):
kwargs = transform.instantiate(seed, signal)
output = transform(signal, **kwargs)
assert output in targets
def test_repeat():
seed = 0
audio_path = "./audio/spk/f10_script4_produced.wav"
signal = AudioSignal(audio_path, offset=10, duration=2)
kwargs = {}
kwargs["transform"] = tfm.Compose(
tfm.FrequencyMask(),
tfm.TimeMask(), )
kwargs["n_repeat"] = 5
transform = tfm.Repeat(**kwargs)
kwargs = transform.instantiate(seed, signal)
output = transform(signal.clone(), **kwargs)
_compare_transform("Repeat", output)
kwargs = {}
kwargs["transform"] = tfm.Compose(
tfm.FrequencyMask(),
tfm.TimeMask(), )
kwargs["max_repeat"] = 10
transform = tfm.RepeatUpTo(**kwargs)
kwargs = transform.instantiate(seed, signal)
output = transform(signal.clone(), **kwargs)
_compare_transform("RepeatUpTo", output)
# Make sure repeat does what it says
transform = tfm.Repeat(MulTransform(0.5), n_repeat=3)
kwargs = transform.instantiate(seed, signal)
signal = AudioSignal(paddle.randn([1, 1, 100]).clip(1e-5), 44100)
output = transform(signal.clone(), **kwargs)
scale = (output.audio_data / signal.audio_data).mean()
assert scale == (0.5**3)
class DummyData(paddle.io.Dataset):
def __init__(self, audio_path):
super().__init__()
self.audio_path = audio_path
self.length = 100
self.transform = tfm.Silence(prob=0.5)
def __getitem__(self, idx):
state = util.random_state(idx)
signal = AudioSignal.salient_excerpt(
self.audio_path, state=state, duration=1.0).resample(44100)
item = self.transform.instantiate(state, signal=signal)
item["signal"] = signal
return item
def __len__(self):
return self.length
def test_masking():
dataset = DummyData("./audio/spk/f10_script4_produced.wav")
dataloader = paddle.io.DataLoader(
dataset,
batch_size=16,
num_workers=0,
collate_fn=util.collate, )
for batch in dataloader:
signal = batch.pop("signal")
original = signal.clone()
signal = dataset.transform(signal, **batch)
original = dataset.transform(original, **batch)
mask = batch["Silence"]["mask"]
zeros_ = paddle.zeros_like(signal[mask].audio_data)
original_ = original[~mask].audio_data
assert paddle.allclose(signal[mask].audio_data, zeros_)
assert paddle.allclose(original[~mask].audio_data, original_)
def test_nested_masking():
transform = tfm.Compose(
[
tfm.VolumeNorm(prob=0.5),
tfm.Silence(prob=0.9),
],
prob=0.9, )
loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"])
dataset = audiotools.data.datasets.AudioDataset(
loader,
44100,
n_examples=100,
transform=transform, )
dataloader = paddle.io.DataLoader(
dataset, num_workers=0, batch_size=10, collate_fn=dataset.collate)
for batch in dataloader:
batch = util.prepare_batch(batch, device="cpu")
signal = batch["signal"]
kwargs = batch["transform_args"]
with paddle.no_grad():
output = dataset.transform(signal, **kwargs)
def test_smoothing_edge_case():
transform = tfm.Smoothing()
zeros = paddle.zeros([1, 1, 44100])
signal = AudioSignal(zeros, 44100)
kwargs = transform.instantiate(0, signal)
output = transform(signal, **kwargs)
assert paddle.allclose(output.audio_data, zeros)
def test_global_volume_norm():
signal = AudioSignal.wave(440, 1, 44100, 1)
# signal with -inf loudness should be unchanged
signal.metadata["loudness"] = float("-inf")
transform = tfm.GlobalVolumeNorm(db=("const", -100))
kwargs = transform.instantiate(0, signal)
output = transform(signal.clone(), **kwargs)
assert paddle.allclose(output.samples, signal.samples)
# signal without a loudness key should be unchanged
signal.metadata.pop("loudness")
kwargs = transform.instantiate(0, signal)
output = transform(signal.clone(), **kwargs)
assert paddle.allclose(output.samples, signal.samples)
# signal with the actual loudness should be normalized
signal.metadata["loudness"] = signal.ffmpeg_loudness()
kwargs = transform.instantiate(0, signal)
output = transform(signal.clone(), **kwargs)
assert not paddle.allclose(output.samples, signal.samples)

@ -0,0 +1,110 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_decorators.py)
import sys
import time
import paddle
from visualdl import LogWriter
from audio.audiotools import util
from audio.audiotools.ml.decorators import timer
from audio.audiotools.ml.decorators import Tracker
from audio.audiotools.ml.decorators import when
def test_all_decorators():
rank = 0
max_iters = 100
writer = LogWriter("/tmp/logs")
tracker = Tracker(writer, log_file="/tmp/log.txt")
train_data = range(100)
val_data = range(100)
@tracker.log("train", "value", history=False)
@tracker.track("train", max_iters, tracker.step)
@timer()
def train_loop():
i = tracker.step
time.sleep(0.01)
return {
"loss":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"mel":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"stft":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"waveform":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"not_scalar":
paddle.arange(start=0, end=10, step=1, dtype="int64"),
}
@tracker.track("val", len(val_data))
@timer()
def val_loop():
i = tracker.step
time.sleep(0.01)
return {
"loss":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"mel":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"stft":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"waveform":
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
"not_scalar":
paddle.arange(10, dtype="int64"),
"string":
"string",
}
@when(lambda: tracker.step % 1000 == 0 and rank == 0)
@paddle.no_grad()
def save_samples():
tracker.print("Saving samples to TensorBoard.")
@when(lambda: tracker.step % 100 == 0 and rank == 0)
def checkpoint():
save_samples()
if tracker.is_best("val", "mel"):
tracker.print("Best model so far.")
tracker.print("Saving to /runs/exp1")
tracker.done("val", f"Iteration {tracker.step}")
@when(lambda: tracker.step % 100 == 0)
@tracker.log("val", "mean")
@paddle.no_grad()
def validate():
for _ in range(len(val_data)):
output = val_loop()
return output
with tracker.live:
for tracker.step in range(max_iters):
validate()
checkpoint()
train_loop()
state_dict = tracker.state_dict()
tracker.load_state_dict(state_dict)
# If train loop returned not a dict
@tracker.track("train", max_iters, tracker.step)
def train_loop_2():
i = tracker.step
time.sleep(0.01)
with tracker.live:
for tracker.step in range(max_iters):
validate()
checkpoint()
train_loop_2()
if __name__ == "__main__":
test_all_decorators()

@ -0,0 +1,89 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_model.py)
import sys
import tempfile
import paddle
from paddle import nn
from audio.audiotools import ml
from audio.audiotools import util
from paddlespeech.vector.training.seeding import seed_everything
SEED = 0
def seed_and_run(model, *args, **kwargs):
seed_everything(SEED)
return model(*args, **kwargs)
class Model(ml.BaseModel):
def __init__(self, arg1: float=1.0):
super().__init__()
self.arg1 = arg1
self.linear = nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
class OtherModel(ml.BaseModel):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 1)
def forward(self, x):
return self.linear(x)
def test_base_model():
# Save and load
# ml.BaseModel.EXTERN += ["test_model"]
x = paddle.randn([10, 1])
model1 = Model()
# assert str(model1.device) == 'Place(cpu)'
out1 = seed_and_run(model1, x)
with tempfile.NamedTemporaryFile(suffix=".pdparams") as f:
model1.save(
f.name, )
model2 = Model.load(f.name)
out2 = seed_and_run(model2, x)
assert paddle.allclose(out1, out2)
# test re-export
model2.save(f.name)
model3 = Model.load(f.name)
out3 = seed_and_run(model3, x)
assert paddle.allclose(out1, out3)
# make sure legacy/save load works
model1.save(f.name, package=False)
model2 = Model.load(f.name)
out2 = seed_and_run(model2, x)
assert paddle.allclose(out1, out2)
# make sure new way -> legacy save -> legacy load works
model1.save(f.name, package=False)
model2 = Model.load(f.name)
model2.save(f.name, package=False)
model3 = Model.load(f.name)
out3 = seed_and_run(model3, x)
# save/load without package, but with model2 being a model
# without an argument of arg1 to its instantiation.
model1.save(f.name, package=False)
model2 = OtherModel.load(f.name)
out2 = seed_and_run(model2, x)
assert paddle.allclose(out1, out2)
assert paddle.allclose(out1, out3)
with tempfile.TemporaryDirectory() as d:
model1.save_to_folder(d, {"data": 1.0})
Model.load_from_folder(d)

@ -0,0 +1,7 @@
python -m pip install -r ../../audiotools/requirements.txt
export PYTHONPATH=$PYTHONPATH:$(realpath ../../..) # this is root path of `PaddleSpeech`
wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/audio.tar.gz
wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/regression.tar.gz
tar -zxvf audio.tar.gz
tar -zxvf regression.tar.gz
python -m pytest

@ -0,0 +1,30 @@
# MIT License, Copyright (c) 2023-Present, Descript.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/test_post.py)
import sys
from pathlib import Path
from audio.audiotools import AudioSignal
from audio.audiotools import post
from audio.audiotools import transforms
def test_audio_table():
tfm = transforms.LowPass()
audio_dict = {}
audio_dict["inputs"] = [
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5)
for _ in range(3)
]
audio_dict["outputs"] = []
for i in range(3):
x = audio_dict["inputs"][i]
kwargs = tfm.instantiate()
output = tfm(x.clone(), **kwargs)
audio_dict["outputs"].append(output)
post.audio_table(audio_dict)

@ -32,6 +32,13 @@ function main(){
cd ${speech_ci_path}/server/offline
bash test_server_client.sh
echo "End server"
echo "Start testing audiotools"
cd ${speech_ci_path}/../../audio/tests/audiotools
bash test_audiotools.sh
echo "End testing audiotools"
}
main

Loading…
Cancel
Save