[Hackathon 7th No.55] Add `audiotools` to `PaddleSpeech` (#3900)
* add AudioSignal && util * fix codestyle * add basemodel && decorator * add util && add quality * add acc && data && transforms * add utils * fix dir * add *.py; wo unitest * add unitest * fix codestyle * fix cuda error * add readme && __all__ * add 2 file test * change download dir * fix CI download path * add tar -zxvf * change requirements path * add audiotools path * fix place error * fix paddle2.5 verion Q * FFTConv1d -> FFTConv1D * FFTConv1d -> FFTConv1D * mv unfold * add _unfold1d 2 loudness * fix stupid device variable * bias -> bias_attr * () -> [] * fix .to() * rm ✅ * fix exp * deepcopy -> clone * fix dim error * fix slice && tensor.to * fix paddle2.5 index bug * git rm std * rm comment && ✅ * rm some useless comment * add __all__ * fix codestyle * fix soundfile.info error * fix sth * add License * fix cycle import * Adapt to paddle3.0 && update readme * fix License * fix License * rm duplicate requirements * fix trasform problems * rm disp * Update test_transforms.py * change path * rm notebook && add audio path * rm import * add comment * fix cycle import && rm TYPE_CHECKING * rm IPython * rm sth useless * rm uesless deps * Update requirements.txtpull/3971/head
parent
553a9db374
commit
cb15e382cb
@ -0,0 +1,68 @@
|
|||||||
|
Audiotools is a comprehensive toolkit designed for audio processing and analysis, providing robust solutions for audio signal processing, data management, model training, and evaluation.
|
||||||
|
|
||||||
|
### Directory Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
├── audiotools
|
||||||
|
│ ├── README.md
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── core
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── _julius.py
|
||||||
|
│ │ ├── audio_signal.py
|
||||||
|
│ │ ├── display.py
|
||||||
|
│ │ ├── dsp.py
|
||||||
|
│ │ ├── effects.py
|
||||||
|
│ │ ├── ffmpeg.py
|
||||||
|
│ │ ├── loudness.py
|
||||||
|
│ │ └── util.py
|
||||||
|
│ ├── data
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── datasets.py
|
||||||
|
│ │ ├── preprocess.py
|
||||||
|
│ │ └── transforms.py
|
||||||
|
│ ├── metrics
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ └── quality.py
|
||||||
|
│ ├── ml
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── accelerator.py
|
||||||
|
│ │ ├── basemodel.py
|
||||||
|
│ │ └── decorators.py
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── post.py
|
||||||
|
├── tests
|
||||||
|
│ └── audiotools
|
||||||
|
│ ├── core
|
||||||
|
│ │ ├── test_audio_signal.py
|
||||||
|
│ │ ├── test_bands.py
|
||||||
|
│ │ ├── test_display.py
|
||||||
|
│ │ ├── test_dsp.py
|
||||||
|
│ │ ├── test_effects.py
|
||||||
|
│ │ ├── test_fftconv.py
|
||||||
|
│ │ ├── test_grad.py
|
||||||
|
│ │ ├── test_highpass.py
|
||||||
|
│ │ ├── test_loudness.py
|
||||||
|
│ │ ├── test_lowpass.py
|
||||||
|
│ │ └── test_util.py
|
||||||
|
│ ├── data
|
||||||
|
│ │ ├── test_datasets.py
|
||||||
|
│ │ ├── test_preprocess.py
|
||||||
|
│ │ └── test_transforms.py
|
||||||
|
│ ├── ml
|
||||||
|
│ │ ├── test_decorators.py
|
||||||
|
│ │ └── test_model.py
|
||||||
|
│ └── test_post.py
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
- **core**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals.
|
||||||
|
|
||||||
|
- **data**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data.
|
||||||
|
|
||||||
|
- **metrics**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms.
|
||||||
|
|
||||||
|
- **ml**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio.
|
||||||
|
|
||||||
|
This project aims to provide developers and researchers with an efficient and flexible framework to foster innovation and exploration across various domains of audio technology.
|
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from . import metrics
|
||||||
|
from . import ml
|
||||||
|
from . import post
|
||||||
|
from .core import AudioSignal
|
||||||
|
from .core import highpass_filter
|
||||||
|
from .core import highpass_filters
|
||||||
|
from .core import Meter
|
||||||
|
from .core import STFTParams
|
||||||
|
from .core import util
|
||||||
|
from .data import datasets
|
||||||
|
from .data import preprocess
|
||||||
|
from .data import transforms
|
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from . import util
|
||||||
|
from ._julius import fft_conv1d
|
||||||
|
from ._julius import FFTConv1D
|
||||||
|
from ._julius import highpass_filter
|
||||||
|
from ._julius import highpass_filters
|
||||||
|
from ._julius import lowpass_filter
|
||||||
|
from ._julius import LowPassFilter
|
||||||
|
from ._julius import LowPassFilters
|
||||||
|
from ._julius import pure_tone
|
||||||
|
from ._julius import resample_frac
|
||||||
|
from ._julius import split_bands
|
||||||
|
from ._julius import SplitBands
|
||||||
|
from .audio_signal import AudioSignal
|
||||||
|
from .audio_signal import STFTParams
|
||||||
|
from .loudness import Meter
|
@ -0,0 +1,666 @@
|
|||||||
|
# MIT License, Copyright (c) 2020 Alexandre Défossez.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from julius(https://github.com/adefossez/julius/tree/main/julius)
|
||||||
|
"""
|
||||||
|
Implementation of a FFT based 1D convolution in PaddlePaddle.
|
||||||
|
While FFT is used in some cases for small kernel sizes, it is not the default for long ones, e.g. 512.
|
||||||
|
This module implements efficient FFT based convolutions for such cases. A typical
|
||||||
|
application is for evaluating FIR filters with a long receptive field, typically
|
||||||
|
evaluated with a stride of 1.
|
||||||
|
"""
|
||||||
|
import inspect
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import typing
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
from paddlespeech.t2s.modules import fft_conv1d
|
||||||
|
from paddlespeech.t2s.modules import FFTConv1D
|
||||||
|
from paddlespeech.utils import satisfy_paddle_version
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'highpass_filter', 'highpass_filters', 'lowpass_filter', 'LowPassFilter',
|
||||||
|
'LowPassFilters', 'pure_tone', 'resample_frac', 'split_bands', 'SplitBands'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}):
|
||||||
|
"""
|
||||||
|
Return a simple representation string for `obj`.
|
||||||
|
If `attrs` is not None, it should be a list of attributes to include.
|
||||||
|
"""
|
||||||
|
params = inspect.signature(obj.__class__).parameters
|
||||||
|
attrs_repr = []
|
||||||
|
if attrs is None:
|
||||||
|
attrs = list(params.keys())
|
||||||
|
for attr in attrs:
|
||||||
|
display = False
|
||||||
|
if attr in overrides:
|
||||||
|
value = overrides[attr]
|
||||||
|
elif hasattr(obj, attr):
|
||||||
|
value = getattr(obj, attr)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if attr in params:
|
||||||
|
param = params[attr]
|
||||||
|
if param.default is inspect._empty or value != param.default: # type: ignore
|
||||||
|
display = True
|
||||||
|
else:
|
||||||
|
display = True
|
||||||
|
|
||||||
|
if display:
|
||||||
|
attrs_repr.append(f"{attr}={value}")
|
||||||
|
return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
|
||||||
|
|
||||||
|
|
||||||
|
def sinc(x: paddle.Tensor):
|
||||||
|
"""
|
||||||
|
Implementation of sinc, i.e. sin(x) / x
|
||||||
|
|
||||||
|
__Warning__: the input is not multiplied by `pi`!
|
||||||
|
"""
|
||||||
|
if satisfy_paddle_version("3.0"):
|
||||||
|
return paddle.sinc(x)
|
||||||
|
|
||||||
|
return paddle.where(
|
||||||
|
x == 0,
|
||||||
|
paddle.to_tensor(1.0, dtype=x.dtype, place=x.place),
|
||||||
|
paddle.sin(x) / x, )
|
||||||
|
|
||||||
|
|
||||||
|
class ResampleFrac(paddle.nn.Layer):
|
||||||
|
"""
|
||||||
|
Resampling from the sample rate `old_sr` to `new_sr`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
old_sr: int,
|
||||||
|
new_sr: int,
|
||||||
|
zeros: int=24,
|
||||||
|
rolloff: float=0.945):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
old_sr (int): sample rate of the input signal x.
|
||||||
|
new_sr (int): sample rate of the output.
|
||||||
|
zeros (int): number of zero crossing to keep in the sinc filter.
|
||||||
|
rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`,
|
||||||
|
to ensure sufficient margin due to the imperfection of the FIR filter used.
|
||||||
|
Lowering this value will reduce anti-aliasing, but will reduce some of the
|
||||||
|
highest frequencies.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
|
||||||
|
- Input: `[*, T]`
|
||||||
|
- Output: `[*, T']` with `T' = int(new_sr * T / old_sr)`
|
||||||
|
|
||||||
|
|
||||||
|
.. caution::
|
||||||
|
After dividing `old_sr` and `new_sr` by their GCD, both should be small
|
||||||
|
for this implementation to be fast.
|
||||||
|
|
||||||
|
>>> import paddle
|
||||||
|
>>> resample = ResampleFrac(4, 5)
|
||||||
|
>>> x = paddle.randn([1000])
|
||||||
|
>>> print(len(resample(x)))
|
||||||
|
1250
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(old_sr, int) or not isinstance(new_sr, int):
|
||||||
|
raise ValueError("old_sr and new_sr should be integers")
|
||||||
|
gcd = math.gcd(old_sr, new_sr)
|
||||||
|
self.old_sr = old_sr // gcd
|
||||||
|
self.new_sr = new_sr // gcd
|
||||||
|
self.zeros = zeros
|
||||||
|
self.rolloff = rolloff
|
||||||
|
|
||||||
|
self._init_kernels()
|
||||||
|
|
||||||
|
def _init_kernels(self):
|
||||||
|
if self.old_sr == self.new_sr:
|
||||||
|
return
|
||||||
|
|
||||||
|
kernels = []
|
||||||
|
sr = min(self.new_sr, self.old_sr)
|
||||||
|
sr *= self.rolloff
|
||||||
|
|
||||||
|
self._width = math.ceil(self.zeros * self.old_sr / sr)
|
||||||
|
idx = paddle.arange(
|
||||||
|
-self._width, self._width + self.old_sr, dtype="float32")
|
||||||
|
for i in range(self.new_sr):
|
||||||
|
t = (-i / self.new_sr + idx / self.old_sr) * sr
|
||||||
|
t = paddle.clip(t, -self.zeros, self.zeros)
|
||||||
|
t *= math.pi
|
||||||
|
window = paddle.cos(t / self.zeros / 2)**2
|
||||||
|
kernel = sinc(t) * window
|
||||||
|
# Renormalize kernel to ensure a constant signal is preserved.
|
||||||
|
kernel = kernel / kernel.sum()
|
||||||
|
kernels.append(kernel)
|
||||||
|
|
||||||
|
_kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1])
|
||||||
|
self.kernel = self.create_parameter(
|
||||||
|
shape=_kernel.shape,
|
||||||
|
dtype=_kernel.dtype, )
|
||||||
|
self.kernel.set_value(_kernel)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: paddle.Tensor,
|
||||||
|
output_length: Optional[int]=None,
|
||||||
|
full: bool=False, ):
|
||||||
|
"""
|
||||||
|
Resample x.
|
||||||
|
Args:
|
||||||
|
x (Tensor): signal to resample, time should be the last dimension
|
||||||
|
output_length (None or int): This can be set to the desired output length
|
||||||
|
(last dimension). Allowed values are between 0 and
|
||||||
|
ceil(length * new_sr / old_sr). When None (default) is specified, the
|
||||||
|
floored output length will be used. In order to select the largest possible
|
||||||
|
size, use the `full` argument.
|
||||||
|
full (bool): return the longest possible output from the input. This can be useful
|
||||||
|
if you chain resampling operations, and want to give the `output_length` only
|
||||||
|
for the last one, while passing `full=True` to all the other ones.
|
||||||
|
"""
|
||||||
|
if self.old_sr == self.new_sr:
|
||||||
|
return x
|
||||||
|
shape = x.shape
|
||||||
|
_dtype = x.dtype
|
||||||
|
length = x.shape[-1]
|
||||||
|
x = x.reshape([-1, length])
|
||||||
|
x = F.pad(
|
||||||
|
x.unsqueeze(1),
|
||||||
|
[self._width, self._width + self.old_sr],
|
||||||
|
mode="replicate",
|
||||||
|
data_format="NCL", ).astype(self.kernel.dtype)
|
||||||
|
ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL")
|
||||||
|
y = ys.transpose(
|
||||||
|
[0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype)
|
||||||
|
|
||||||
|
float_output_length = paddle.to_tensor(
|
||||||
|
self.new_sr * length / self.old_sr, dtype="float32")
|
||||||
|
max_output_length = paddle.ceil(float_output_length).astype("int64")
|
||||||
|
default_output_length = paddle.floor(float_output_length).astype(
|
||||||
|
"int64")
|
||||||
|
|
||||||
|
if output_length is None:
|
||||||
|
applied_output_length = (max_output_length
|
||||||
|
if full else default_output_length)
|
||||||
|
elif output_length < 0 or output_length > max_output_length:
|
||||||
|
raise ValueError(
|
||||||
|
f"output_length must be between 0 and {max_output_length.numpy()}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
applied_output_length = paddle.to_tensor(
|
||||||
|
output_length, dtype="int64")
|
||||||
|
if full:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot pass both full=True and output_length")
|
||||||
|
return y[..., :applied_output_length]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return simple_repr(self)
|
||||||
|
|
||||||
|
|
||||||
|
def resample_frac(
|
||||||
|
x: paddle.Tensor,
|
||||||
|
old_sr: int,
|
||||||
|
new_sr: int,
|
||||||
|
zeros: int=24,
|
||||||
|
rolloff: float=0.945,
|
||||||
|
output_length: Optional[int]=None,
|
||||||
|
full: bool=False, ):
|
||||||
|
"""
|
||||||
|
Functional version of `ResampleFrac`, refer to its documentation for more information.
|
||||||
|
|
||||||
|
..warning::
|
||||||
|
If you call repeatidly this functions with the same sample rates, then the
|
||||||
|
resampling kernel will be recomputed everytime. For best performance, you should use
|
||||||
|
and cache an instance of `ResampleFrac`.
|
||||||
|
"""
|
||||||
|
return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full)
|
||||||
|
|
||||||
|
|
||||||
|
def pad_to(tensor: paddle.Tensor,
|
||||||
|
target_length: int,
|
||||||
|
mode: str="constant",
|
||||||
|
value: float=0.0):
|
||||||
|
"""
|
||||||
|
Pad the given tensor to the given length, with 0s on the right.
|
||||||
|
"""
|
||||||
|
return F.pad(
|
||||||
|
tensor, (0, target_length - tensor.shape[-1]),
|
||||||
|
mode=mode,
|
||||||
|
value=value,
|
||||||
|
data_format="NCL")
|
||||||
|
|
||||||
|
|
||||||
|
def pure_tone(freq: float, sr: float=128, dur: float=4, device=None):
|
||||||
|
"""
|
||||||
|
Return a pure tone, i.e. cosine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
freq (float): frequency (in Hz)
|
||||||
|
sr (float): sample rate (in Hz)
|
||||||
|
dur (float): duration (in seconds)
|
||||||
|
"""
|
||||||
|
time = paddle.arange(int(sr * dur), dtype="float32") / sr
|
||||||
|
return paddle.cos(2 * math.pi * freq * time)
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilters(nn.Layer):
|
||||||
|
"""
|
||||||
|
Bank of low pass filters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cutoffs: Sequence[float],
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None,
|
||||||
|
dtype="float32"):
|
||||||
|
super().__init__()
|
||||||
|
self.cutoffs = list(cutoffs)
|
||||||
|
if min(self.cutoffs) < 0:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if max(self.cutoffs) > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.stride = stride
|
||||||
|
self.pad = pad
|
||||||
|
self.zeros = zeros
|
||||||
|
self.half_size = int(zeros / min([c for c in self.cutoffs if c > 0]) /
|
||||||
|
2)
|
||||||
|
if fft is None:
|
||||||
|
fft = self.half_size > 32
|
||||||
|
self.fft = fft
|
||||||
|
|
||||||
|
# Create filters
|
||||||
|
window = paddle.audio.functional.get_window(
|
||||||
|
"hann", 2 * self.half_size + 1, fftbins=False, dtype=dtype)
|
||||||
|
time = paddle.arange(
|
||||||
|
-self.half_size, self.half_size + 1, dtype="float32")
|
||||||
|
filters = []
|
||||||
|
for cutoff in cutoffs:
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = paddle.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * sinc(2 * cutoff * math.pi *
|
||||||
|
time)
|
||||||
|
# Normalize filter
|
||||||
|
filter_ /= paddle.sum(filter_)
|
||||||
|
filters.append(filter_)
|
||||||
|
filters = paddle.stack(filters)[:, None]
|
||||||
|
self.filters = self.create_parameter(
|
||||||
|
shape=filters.shape,
|
||||||
|
default_initializer=nn.initializer.Constant(value=0.0),
|
||||||
|
dtype="float32",
|
||||||
|
is_bias=False,
|
||||||
|
attr=paddle.ParamAttr(trainable=False), )
|
||||||
|
self.filters.set_value(filters)
|
||||||
|
|
||||||
|
def forward(self, _input):
|
||||||
|
shape = list(_input.shape)
|
||||||
|
_input = _input.reshape([-1, 1, shape[-1]])
|
||||||
|
if self.pad:
|
||||||
|
_input = F.pad(
|
||||||
|
_input, (self.half_size, self.half_size),
|
||||||
|
mode="replicate",
|
||||||
|
data_format="NCL")
|
||||||
|
if self.fft:
|
||||||
|
out = fft_conv1d(_input, self.filters, stride=self.stride)
|
||||||
|
else:
|
||||||
|
out = F.conv1d(_input, self.filters, stride=self.stride)
|
||||||
|
|
||||||
|
shape.insert(0, len(self.cutoffs))
|
||||||
|
shape[-1] = out.shape[-1]
|
||||||
|
return out.transpose([1, 0, 2]).reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter(nn.Layer):
|
||||||
|
"""
|
||||||
|
Same as `LowPassFilters` but applies a single low pass filter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cutoff: float,
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None):
|
||||||
|
super().__init__()
|
||||||
|
self._lowpasses = LowPassFilters([cutoff], stride, pad, zeros, fft)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cutoff(self):
|
||||||
|
return self._lowpasses.cutoffs[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stride(self):
|
||||||
|
return self._lowpasses.stride
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad(self):
|
||||||
|
return self._lowpasses.pad
|
||||||
|
|
||||||
|
@property
|
||||||
|
def zeros(self):
|
||||||
|
return self._lowpasses.zeros
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fft(self):
|
||||||
|
return self._lowpasses.fft
|
||||||
|
|
||||||
|
def forward(self, _input):
|
||||||
|
return self._lowpasses(_input)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def lowpass_filters(
|
||||||
|
_input: paddle.Tensor,
|
||||||
|
cutoffs: Sequence[float],
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None, ):
|
||||||
|
"""
|
||||||
|
Functional version of `LowPassFilters`, refer to this class for more information.
|
||||||
|
"""
|
||||||
|
return LowPassFilters(cutoffs, stride, pad, zeros, fft)(_input)
|
||||||
|
|
||||||
|
|
||||||
|
def lowpass_filter(_input: paddle.Tensor,
|
||||||
|
cutoff: float,
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None):
|
||||||
|
"""
|
||||||
|
Same as `lowpass_filters` but with a single cutoff frequency.
|
||||||
|
Output will not have a dimension inserted in the front.
|
||||||
|
"""
|
||||||
|
return lowpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class HighPassFilters(paddle.nn.Layer):
|
||||||
|
"""
|
||||||
|
Bank of high pass filters. See `julius.lowpass.LowPassFilters` for more
|
||||||
|
details on the implementation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cutoffs (list[float]): list of cutoff frequencies, in [0, 0.5] expressed as `f/f_s` where
|
||||||
|
f_s is the samplerate and `f` is the cutoff frequency.
|
||||||
|
The upper limit is 0.5, because a signal sampled at `f_s` contains only
|
||||||
|
frequencies under `f_s / 2`.
|
||||||
|
stride (int): how much to decimate the output. Probably not a good idea
|
||||||
|
to do so with a high pass filters though...
|
||||||
|
pad (bool): if True, appropriately pad the _input with zero over the edge. If `stride=1`,
|
||||||
|
the output will have the same length as the _input.
|
||||||
|
zeros (float): Number of zero crossings to keep.
|
||||||
|
Controls the receptive field of the Finite Impulse Response filter.
|
||||||
|
For filters with low cutoff frequency, e.g. 40Hz at 44.1kHz,
|
||||||
|
it is a bad idea to set this to a high value.
|
||||||
|
This is likely appropriate for most use. Lower values
|
||||||
|
will result in a faster filter, but with a slower attenuation around the
|
||||||
|
cutoff frequency.
|
||||||
|
fft (bool or None): if True, uses `julius.fftconv` rather than PyTorch convolutions.
|
||||||
|
If False, uses PyTorch convolutions. If None, either one will be chosen automatically
|
||||||
|
depending on the effective filter size.
|
||||||
|
|
||||||
|
|
||||||
|
..warning::
|
||||||
|
All the filters will use the same filter size, aligned on the lowest
|
||||||
|
frequency provided. If you combine a lot of filters with very diverse frequencies, it might
|
||||||
|
be more efficient to split them over multiple modules with similar frequencies.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
|
||||||
|
- Input: `[*, T]`
|
||||||
|
- Output: `[F, *, T']`, with `T'=T` if `pad` is True and `stride` is 1, and
|
||||||
|
`F` is the numer of cutoff frequencies.
|
||||||
|
|
||||||
|
>>> highpass = HighPassFilters([1/4])
|
||||||
|
>>> x = paddle.randn([4, 12, 21, 1024])
|
||||||
|
>>> list(highpass(x).shape)
|
||||||
|
[1, 4, 12, 21, 1024]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cutoffs: Sequence[float],
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None):
|
||||||
|
super().__init__()
|
||||||
|
self._lowpasses = LowPassFilters(cutoffs, stride, pad, zeros, fft)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cutoffs(self):
|
||||||
|
return self._lowpasses.cutoffs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stride(self):
|
||||||
|
return self._lowpasses.stride
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad(self):
|
||||||
|
return self._lowpasses.pad
|
||||||
|
|
||||||
|
@property
|
||||||
|
def zeros(self):
|
||||||
|
return self._lowpasses.zeros
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fft(self):
|
||||||
|
return self._lowpasses.fft
|
||||||
|
|
||||||
|
def forward(self, _input):
|
||||||
|
lows = self._lowpasses(_input)
|
||||||
|
|
||||||
|
# We need to extract the right portion of the _input in case
|
||||||
|
# pad is False or stride > 1
|
||||||
|
if self.pad:
|
||||||
|
start, end = 0, _input.shape[-1]
|
||||||
|
else:
|
||||||
|
start = self._lowpasses.half_size
|
||||||
|
end = -start
|
||||||
|
_input = _input[..., start:end:self.stride]
|
||||||
|
highs = _input - lows
|
||||||
|
return highs
|
||||||
|
|
||||||
|
|
||||||
|
class HighPassFilter(paddle.nn.Layer):
|
||||||
|
"""
|
||||||
|
Same as `HighPassFilters` but applies a single high pass filter.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
|
||||||
|
- Input: `[*, T]`
|
||||||
|
- Output: `[*, T']`, with `T'=T` if `pad` is True and `stride` is 1.
|
||||||
|
|
||||||
|
>>> highpass = HighPassFilter(1/4, stride=1)
|
||||||
|
>>> x = paddle.randn([4, 124])
|
||||||
|
>>> list(highpass(x).shape)
|
||||||
|
[4, 124]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
cutoff: float,
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None):
|
||||||
|
super().__init__()
|
||||||
|
self._highpasses = HighPassFilters([cutoff], stride, pad, zeros, fft)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cutoff(self):
|
||||||
|
return self._highpasses.cutoffs[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stride(self):
|
||||||
|
return self._highpasses.stride
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pad(self):
|
||||||
|
return self._highpasses.pad
|
||||||
|
|
||||||
|
@property
|
||||||
|
def zeros(self):
|
||||||
|
return self._highpasses.zeros
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fft(self):
|
||||||
|
return self._highpasses.fft
|
||||||
|
|
||||||
|
def forward(self, _input):
|
||||||
|
return self._highpasses(_input)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def highpass_filters(
|
||||||
|
_input: paddle.Tensor,
|
||||||
|
cutoffs: Sequence[float],
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None, ):
|
||||||
|
"""
|
||||||
|
Functional version of `HighPassFilters`, refer to this class for more information.
|
||||||
|
"""
|
||||||
|
return HighPassFilters(cutoffs, stride, pad, zeros, fft)(_input)
|
||||||
|
|
||||||
|
|
||||||
|
def highpass_filter(_input: paddle.Tensor,
|
||||||
|
cutoff: float,
|
||||||
|
stride: int=1,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None):
|
||||||
|
"""
|
||||||
|
Functional version of `HighPassFilter`, refer to this class for more information.
|
||||||
|
Output will not have a dimension inserted in the front.
|
||||||
|
"""
|
||||||
|
return highpass_filters(_input, [cutoff], stride, pad, zeros, fft)[0]
|
||||||
|
|
||||||
|
|
||||||
|
class SplitBands(paddle.nn.Layer):
|
||||||
|
"""
|
||||||
|
Decomposes a signal over the given frequency bands in the waveform domain using
|
||||||
|
a cascade of low pass filters as implemented by `julius.lowpass.LowPassFilters`.
|
||||||
|
You can either specify explicitly the frequency cutoffs, or just the number of bands,
|
||||||
|
in which case the frequency cutoffs will be spread out evenly in mel scale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate (float): Sample rate of the input signal in Hz.
|
||||||
|
n_bands (int or None): number of bands, when not giving them explicitly with `cutoffs`.
|
||||||
|
In that case, the cutoff frequencies will be evenly spaced in mel-space.
|
||||||
|
cutoffs (list[float] or None): list of frequency cutoffs in Hz.
|
||||||
|
pad (bool): if True, appropriately pad the input with zero over the edge. If `stride=1`,
|
||||||
|
the output will have the same length as the input.
|
||||||
|
zeros (float): Number of zero crossings to keep. See `LowPassFilters` for more informations.
|
||||||
|
fft (bool or None): See `LowPassFilters` for more info.
|
||||||
|
|
||||||
|
..note::
|
||||||
|
The sum of all the bands will always be the input signal.
|
||||||
|
|
||||||
|
..warning::
|
||||||
|
Unlike `julius.lowpass.LowPassFilters`, the cutoffs frequencies must be provided in Hz along
|
||||||
|
with the sample rate.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
|
||||||
|
- Input: `[*, T]`
|
||||||
|
- Output: `[B, *, T']`, with `T'=T` if `pad` is True.
|
||||||
|
If `n_bands` was provided, `B = n_bands` otherwise `B = len(cutoffs) + 1`
|
||||||
|
|
||||||
|
>>> bands = SplitBands(sample_rate=128, n_bands=10)
|
||||||
|
>>> x = paddle.randn(shape=[6, 4, 1024])
|
||||||
|
>>> list(bands(x).shape)
|
||||||
|
[10, 6, 4, 1024]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_rate: float,
|
||||||
|
n_bands: Optional[int]=None,
|
||||||
|
cutoffs: Optional[Sequence[float]]=None,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None, ):
|
||||||
|
super().__init__()
|
||||||
|
if (cutoffs is None) + (n_bands is None) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"You must provide either n_bands, or cutoffs, but not both.")
|
||||||
|
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.n_bands = n_bands
|
||||||
|
self._cutoffs = list(cutoffs) if cutoffs is not None else None
|
||||||
|
self.pad = pad
|
||||||
|
self.zeros = zeros
|
||||||
|
self.fft = fft
|
||||||
|
|
||||||
|
if cutoffs is None:
|
||||||
|
if n_bands is None:
|
||||||
|
raise ValueError("You must provide one of n_bands or cutoffs.")
|
||||||
|
if not n_bands >= 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"n_bands must be greater than one (got {n_bands})")
|
||||||
|
cutoffs = paddle.audio.functional.mel_frequencies(
|
||||||
|
n_bands + 1, 0, sample_rate / 2)[1:-1]
|
||||||
|
else:
|
||||||
|
if max(cutoffs) > 0.5 * sample_rate:
|
||||||
|
raise ValueError(
|
||||||
|
"A cutoff above sample_rate/2 does not make sense.")
|
||||||
|
if len(cutoffs) > 0:
|
||||||
|
self.lowpass = LowPassFilters(
|
||||||
|
[c / sample_rate for c in cutoffs],
|
||||||
|
pad=pad,
|
||||||
|
zeros=zeros,
|
||||||
|
fft=fft)
|
||||||
|
else:
|
||||||
|
self.lowpass = None # type: ignore
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
if self.lowpass is None:
|
||||||
|
return input[None]
|
||||||
|
lows = self.lowpass(input)
|
||||||
|
low = lows[0]
|
||||||
|
bands = [low]
|
||||||
|
for low_and_band in lows[1:]:
|
||||||
|
# Get a bandpass filter by subtracting lowpasses
|
||||||
|
band = low_and_band - low
|
||||||
|
bands.append(band)
|
||||||
|
low = low_and_band
|
||||||
|
# Last band is whatever is left in the signal
|
||||||
|
bands.append(input - low)
|
||||||
|
return paddle.stack(bands)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cutoffs(self):
|
||||||
|
if self._cutoffs is not None:
|
||||||
|
return self._cutoffs
|
||||||
|
elif self.lowpass is not None:
|
||||||
|
return [c * self.sample_rate for c in self.lowpass.cutoffs]
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def split_bands(
|
||||||
|
signal: paddle.Tensor,
|
||||||
|
sample_rate: float,
|
||||||
|
n_bands: Optional[int]=None,
|
||||||
|
cutoffs: Optional[Sequence[float]]=None,
|
||||||
|
pad: bool=True,
|
||||||
|
zeros: float=8,
|
||||||
|
fft: Optional[bool]=None, ):
|
||||||
|
"""
|
||||||
|
Functional version of `SplitBands`, refer to this class for more information.
|
||||||
|
|
||||||
|
>>> x = paddle.randn(shape=[6, 4, 1024])
|
||||||
|
>>> list(split_bands(x, sample_rate=64, cutoffs=[12, 24]).shape)
|
||||||
|
[3, 6, 4, 1024]
|
||||||
|
"""
|
||||||
|
return SplitBands(sample_rate, n_bands, cutoffs, pad, zeros, fft)(signal)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,195 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/display.py)
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from . import util
|
||||||
|
|
||||||
|
|
||||||
|
def format_figure(func):
|
||||||
|
"""Decorator for formatting figures produced by the code below.
|
||||||
|
See :py:func:`audiotools.core.util.format_figure` for more.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
func : Callable
|
||||||
|
Plotting function that is decorated by this function.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
f_keys = inspect.signature(util.format_figure).parameters.keys()
|
||||||
|
f_kwargs = {}
|
||||||
|
for k, v in list(kwargs.items()):
|
||||||
|
if k in f_keys:
|
||||||
|
kwargs.pop(k)
|
||||||
|
f_kwargs[k] = v
|
||||||
|
func(*args, **kwargs)
|
||||||
|
util.format_figure(**f_kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class DisplayMixin:
|
||||||
|
@format_figure
|
||||||
|
def specshow(
|
||||||
|
self,
|
||||||
|
preemphasis: bool=False,
|
||||||
|
x_axis: str="time",
|
||||||
|
y_axis: str="linear",
|
||||||
|
n_mels: int=128,
|
||||||
|
**kwargs, ):
|
||||||
|
"""Displays a spectrogram, using ``librosa.display.specshow``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
preemphasis : bool, optional
|
||||||
|
Whether or not to apply preemphasis, which makes high
|
||||||
|
frequency detail easier to see, by default False
|
||||||
|
x_axis : str, optional
|
||||||
|
How to label the x axis, by default "time"
|
||||||
|
y_axis : str, optional
|
||||||
|
How to label the y axis, by default "linear"
|
||||||
|
n_mels : int, optional
|
||||||
|
If displaying a mel spectrogram with ``y_axis = "mel"``,
|
||||||
|
this controls the number of mels, by default 128.
|
||||||
|
kwargs : dict, optional
|
||||||
|
Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
|
||||||
|
"""
|
||||||
|
import librosa
|
||||||
|
import librosa.display
|
||||||
|
|
||||||
|
# Always re-compute the STFT data before showing it, in case
|
||||||
|
# it changed.
|
||||||
|
signal = self.clone()
|
||||||
|
signal.stft_data = None
|
||||||
|
|
||||||
|
if preemphasis:
|
||||||
|
signal.preemphasis()
|
||||||
|
|
||||||
|
ref = signal.magnitude.max()
|
||||||
|
log_mag = signal.log_magnitude(ref_value=ref)
|
||||||
|
|
||||||
|
if y_axis == "mel":
|
||||||
|
log_mag = 20 * signal.mel_spectrogram(n_mels).clip(1e-5).log10()
|
||||||
|
log_mag -= log_mag.max()
|
||||||
|
|
||||||
|
librosa.display.specshow(
|
||||||
|
log_mag.numpy()[0].mean(axis=0),
|
||||||
|
x_axis=x_axis,
|
||||||
|
y_axis=y_axis,
|
||||||
|
sr=signal.sample_rate,
|
||||||
|
**kwargs, )
|
||||||
|
|
||||||
|
@format_figure
|
||||||
|
def waveplot(self, x_axis: str="time", **kwargs):
|
||||||
|
"""Displays a waveform plot, using ``librosa.display.waveshow``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x_axis : str, optional
|
||||||
|
How to label the x axis, by default "time"
|
||||||
|
kwargs : dict, optional
|
||||||
|
Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
|
||||||
|
"""
|
||||||
|
import librosa
|
||||||
|
import librosa.display
|
||||||
|
|
||||||
|
audio_data = self.audio_data[0].mean(axis=0)
|
||||||
|
audio_data = audio_data.cpu().numpy()
|
||||||
|
|
||||||
|
plot_fn = "waveshow" if hasattr(librosa.display,
|
||||||
|
"waveshow") else "waveplot"
|
||||||
|
wave_plot_fn = getattr(librosa.display, plot_fn)
|
||||||
|
wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
|
||||||
|
|
||||||
|
@format_figure
|
||||||
|
def wavespec(self, x_axis: str="time", **kwargs):
|
||||||
|
"""Displays a waveform plot, using ``librosa.display.waveshow``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x_axis : str, optional
|
||||||
|
How to label the x axis, by default "time"
|
||||||
|
kwargs : dict, optional
|
||||||
|
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.gridspec import GridSpec
|
||||||
|
|
||||||
|
gs = GridSpec(6, 1)
|
||||||
|
plt.subplot(gs[0, :])
|
||||||
|
self.waveplot(x_axis=x_axis)
|
||||||
|
plt.subplot(gs[1:, :])
|
||||||
|
self.specshow(x_axis=x_axis, **kwargs)
|
||||||
|
|
||||||
|
def write_audio_to_tb(
|
||||||
|
self,
|
||||||
|
tag: str,
|
||||||
|
writer,
|
||||||
|
step: int=None,
|
||||||
|
plot_fn: typing.Union[typing.Callable, str]="specshow",
|
||||||
|
**kwargs, ):
|
||||||
|
"""Writes a signal and its spectrogram to Tensorboard. Will show up
|
||||||
|
under the Audio and Images tab in Tensorboard.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tag : str
|
||||||
|
Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
|
||||||
|
written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
|
||||||
|
writer : SummaryWriter
|
||||||
|
A SummaryWriter object from PyTorch library.
|
||||||
|
step : int, optional
|
||||||
|
The step to write the signal to, by default None
|
||||||
|
plot_fn : typing.Union[typing.Callable, str], optional
|
||||||
|
How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
|
||||||
|
kwargs : dict, optional
|
||||||
|
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
|
||||||
|
whatever ``plot_fn`` is set to.
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
audio_data = self.audio_data[0, 0].detach().cpu().numpy()
|
||||||
|
sample_rate = self.sample_rate
|
||||||
|
writer.add_audio(tag, audio_data, step, sample_rate)
|
||||||
|
|
||||||
|
if plot_fn is not None:
|
||||||
|
if isinstance(plot_fn, str):
|
||||||
|
plot_fn = getattr(self, plot_fn)
|
||||||
|
fig = plt.figure()
|
||||||
|
plt.clf()
|
||||||
|
plot_fn(**kwargs)
|
||||||
|
writer.add_figure(tag.replace("wav", "png"), fig, step)
|
||||||
|
|
||||||
|
def save_image(
|
||||||
|
self,
|
||||||
|
image_path: str,
|
||||||
|
plot_fn: typing.Union[typing.Callable, str]="specshow",
|
||||||
|
**kwargs, ):
|
||||||
|
"""Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
|
||||||
|
a specified file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image_path : str
|
||||||
|
Where to save the file to.
|
||||||
|
plot_fn : typing.Union[typing.Callable, str], optional
|
||||||
|
How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
|
||||||
|
kwargs : dict, optional
|
||||||
|
Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
|
||||||
|
whatever ``plot_fn`` is set to.
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
if isinstance(plot_fn, str):
|
||||||
|
plot_fn = getattr(self, plot_fn)
|
||||||
|
|
||||||
|
plt.clf()
|
||||||
|
plot_fn(**kwargs)
|
||||||
|
plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
|
||||||
|
plt.close()
|
@ -0,0 +1,467 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/dsp.py)
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from . import _julius
|
||||||
|
from . import util
|
||||||
|
|
||||||
|
|
||||||
|
def _unfold(x, kernel_sizes, strides):
|
||||||
|
# https://github.com/PaddlePaddle/Paddle/pull/70102
|
||||||
|
|
||||||
|
if 1 == kernel_sizes[0]:
|
||||||
|
x_zeros = paddle.zeros_like(x)
|
||||||
|
x = paddle.concat([x, x_zeros], axis=2)
|
||||||
|
|
||||||
|
kernel_sizes = [2, kernel_sizes[1]]
|
||||||
|
strides = list(strides)
|
||||||
|
|
||||||
|
unfolded = paddle.nn.functional.unfold(
|
||||||
|
x,
|
||||||
|
kernel_sizes=kernel_sizes,
|
||||||
|
strides=strides, )
|
||||||
|
if 2 == kernel_sizes[0]:
|
||||||
|
unfolded = unfolded[:, :kernel_sizes[1]]
|
||||||
|
return unfolded
|
||||||
|
|
||||||
|
|
||||||
|
def _fold(x, output_sizes, kernel_sizes, strides):
|
||||||
|
# https://github.com/PaddlePaddle/Paddle/pull/70102
|
||||||
|
|
||||||
|
if 1 == output_sizes[0] and 1 == kernel_sizes[0]:
|
||||||
|
x_zeros = paddle.zeros_like(x)
|
||||||
|
x = paddle.concat([x, x_zeros], axis=1)
|
||||||
|
|
||||||
|
output_sizes = (2, output_sizes[1])
|
||||||
|
kernel_sizes = (2, kernel_sizes[1])
|
||||||
|
|
||||||
|
fold = paddle.nn.functional.fold(
|
||||||
|
x,
|
||||||
|
output_sizes=output_sizes,
|
||||||
|
kernel_sizes=kernel_sizes,
|
||||||
|
strides=strides, )
|
||||||
|
if 2 == kernel_sizes[0]:
|
||||||
|
fold = fold[:, :, :1]
|
||||||
|
return fold
|
||||||
|
|
||||||
|
|
||||||
|
class DSPMixin:
|
||||||
|
_original_batch_size = None
|
||||||
|
_original_num_channels = None
|
||||||
|
_padded_signal_length = None
|
||||||
|
|
||||||
|
def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
|
||||||
|
self._original_batch_size = self.batch_size
|
||||||
|
self._original_num_channels = self.num_channels
|
||||||
|
|
||||||
|
window_length = int(window_duration * self.sample_rate)
|
||||||
|
hop_length = int(hop_duration * self.sample_rate)
|
||||||
|
|
||||||
|
if window_length % hop_length != 0:
|
||||||
|
factor = window_length // hop_length
|
||||||
|
window_length = factor * hop_length
|
||||||
|
|
||||||
|
self.zero_pad(hop_length, hop_length)
|
||||||
|
self._padded_signal_length = self.signal_length
|
||||||
|
|
||||||
|
return window_length, hop_length
|
||||||
|
|
||||||
|
def windows(self,
|
||||||
|
window_duration: float,
|
||||||
|
hop_duration: float,
|
||||||
|
preprocess: bool=True):
|
||||||
|
"""Generator which yields windows of specified duration from signal with a specified
|
||||||
|
hop length.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
window_duration : float
|
||||||
|
Duration of every window in seconds.
|
||||||
|
hop_duration : float
|
||||||
|
Hop between windows in seconds.
|
||||||
|
preprocess : bool, optional
|
||||||
|
Whether to preprocess the signal, so that the first sample is in
|
||||||
|
the middle of the first window, by default True
|
||||||
|
|
||||||
|
Yields
|
||||||
|
------
|
||||||
|
AudioSignal
|
||||||
|
Each window is returned as an AudioSignal.
|
||||||
|
"""
|
||||||
|
if preprocess:
|
||||||
|
window_length, hop_length = self._preprocess_signal_for_windowing(
|
||||||
|
window_duration, hop_duration)
|
||||||
|
|
||||||
|
self.audio_data = self.audio_data.reshape([-1, 1, self.signal_length])
|
||||||
|
|
||||||
|
for b in range(self.batch_size):
|
||||||
|
i = 0
|
||||||
|
start_idx = i * hop_length
|
||||||
|
while True:
|
||||||
|
start_idx = i * hop_length
|
||||||
|
i += 1
|
||||||
|
end_idx = start_idx + window_length
|
||||||
|
if end_idx > self.signal_length:
|
||||||
|
break
|
||||||
|
yield self[b, ..., start_idx:end_idx]
|
||||||
|
|
||||||
|
def collect_windows(self,
|
||||||
|
window_duration: float,
|
||||||
|
hop_duration: float,
|
||||||
|
preprocess: bool=True):
|
||||||
|
"""Reshapes signal into windows of specified duration from signal with a specified
|
||||||
|
hop length. Window are placed along the batch dimension. Use with
|
||||||
|
:py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
|
||||||
|
original signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
window_duration : float
|
||||||
|
Duration of every window in seconds.
|
||||||
|
hop_duration : float
|
||||||
|
Hop between windows in seconds.
|
||||||
|
preprocess : bool, optional
|
||||||
|
Whether to preprocess the signal, so that the first sample is in
|
||||||
|
the middle of the first window, by default True
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
|
||||||
|
"""
|
||||||
|
if preprocess:
|
||||||
|
window_length, hop_length = self._preprocess_signal_for_windowing(
|
||||||
|
window_duration, hop_duration)
|
||||||
|
|
||||||
|
# self.audio_data: (nb, nch, nt).
|
||||||
|
# unfolded = paddle.nn.functional.unfold(
|
||||||
|
# self.audio_data.reshape([-1, 1, 1, self.signal_length]),
|
||||||
|
# kernel_sizes=(1, window_length),
|
||||||
|
# strides=(1, hop_length),
|
||||||
|
# )
|
||||||
|
unfolded = _unfold(
|
||||||
|
self.audio_data.reshape([-1, 1, 1, self.signal_length]),
|
||||||
|
kernel_sizes=(1, window_length),
|
||||||
|
strides=(1, hop_length), )
|
||||||
|
# unfolded: (nb * nch, window_length, num_windows).
|
||||||
|
# -> (nb * nch * num_windows, 1, window_length)
|
||||||
|
unfolded = unfolded.transpose([0, 2, 1]).reshape([-1, 1, window_length])
|
||||||
|
self.audio_data = unfolded
|
||||||
|
return self
|
||||||
|
|
||||||
|
def overlap_and_add(self, hop_duration: float):
|
||||||
|
"""Function which takes a list of windows and overlap adds them into a
|
||||||
|
signal the same length as ``audio_signal``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
hop_duration : float
|
||||||
|
How much to shift for each window
|
||||||
|
(overlap is window_duration - hop_duration) in seconds.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
overlap-and-added signal.
|
||||||
|
"""
|
||||||
|
hop_length = int(hop_duration * self.sample_rate)
|
||||||
|
window_length = self.signal_length
|
||||||
|
|
||||||
|
nb, nch = self._original_batch_size, self._original_num_channels
|
||||||
|
|
||||||
|
unfolded = self.audio_data.reshape(
|
||||||
|
[nb * nch, -1, window_length]).transpose([0, 2, 1])
|
||||||
|
# folded = paddle.nn.functional.fold(
|
||||||
|
# unfolded,
|
||||||
|
# output_sizes=(1, self._padded_signal_length),
|
||||||
|
# kernel_sizes=(1, window_length),
|
||||||
|
# strides=(1, hop_length),
|
||||||
|
# )
|
||||||
|
folded = _fold(
|
||||||
|
unfolded,
|
||||||
|
output_sizes=(1, self._padded_signal_length),
|
||||||
|
kernel_sizes=(1, window_length),
|
||||||
|
strides=(1, hop_length), )
|
||||||
|
|
||||||
|
norm = paddle.ones_like(unfolded)
|
||||||
|
# norm = paddle.nn.functional.fold(
|
||||||
|
# norm,
|
||||||
|
# output_sizes=(1, self._padded_signal_length),
|
||||||
|
# kernel_sizes=(1, window_length),
|
||||||
|
# strides=(1, hop_length),
|
||||||
|
# )
|
||||||
|
norm = _fold(
|
||||||
|
norm,
|
||||||
|
output_sizes=(1, self._padded_signal_length),
|
||||||
|
kernel_sizes=(1, window_length),
|
||||||
|
strides=(1, hop_length), )
|
||||||
|
|
||||||
|
folded = folded / norm
|
||||||
|
|
||||||
|
folded = folded.reshape([nb, nch, -1])
|
||||||
|
self.audio_data = folded
|
||||||
|
self.trim(hop_length, hop_length)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def low_pass(self,
|
||||||
|
cutoffs: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
zeros: int=51):
|
||||||
|
"""Low-passes the signal in-place. Each item in the batch
|
||||||
|
can have a different low-pass cutoff, if the input
|
||||||
|
to this signal is an array or tensor. If a float, all
|
||||||
|
items are given the same low-pass filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cutoffs : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Cutoff in Hz of low-pass filter.
|
||||||
|
zeros : int, optional
|
||||||
|
Number of taps to use in low-pass filter, by default 51
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Low-passed AudioSignal.
|
||||||
|
"""
|
||||||
|
cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
|
||||||
|
cutoffs = cutoffs / self.sample_rate
|
||||||
|
filtered = paddle.empty_like(self.audio_data)
|
||||||
|
|
||||||
|
for i, cutoff in enumerate(cutoffs):
|
||||||
|
lp_filter = _julius.LowPassFilter(cutoff.cpu(), zeros=zeros)
|
||||||
|
filtered[i] = lp_filter(self.audio_data[i])
|
||||||
|
|
||||||
|
self.audio_data = filtered
|
||||||
|
self.stft_data = None
|
||||||
|
return self
|
||||||
|
|
||||||
|
def high_pass(self,
|
||||||
|
cutoffs: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
zeros: int=51):
|
||||||
|
"""High-passes the signal in-place. Each item in the batch
|
||||||
|
can have a different high-pass cutoff, if the input
|
||||||
|
to this signal is an array or tensor. If a float, all
|
||||||
|
items are given the same high-pass filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
cutoffs : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Cutoff in Hz of high-pass filter.
|
||||||
|
zeros : int, optional
|
||||||
|
Number of taps to use in high-pass filter, by default 51
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
High-passed AudioSignal.
|
||||||
|
"""
|
||||||
|
cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
|
||||||
|
cutoffs = cutoffs / self.sample_rate
|
||||||
|
filtered = paddle.empty_like(self.audio_data)
|
||||||
|
|
||||||
|
for i, cutoff in enumerate(cutoffs):
|
||||||
|
hp_filter = _julius.HighPassFilter(cutoff.cpu(), zeros=zeros)
|
||||||
|
filtered[i] = hp_filter(self.audio_data[i])
|
||||||
|
|
||||||
|
self.audio_data = filtered
|
||||||
|
self.stft_data = None
|
||||||
|
return self
|
||||||
|
|
||||||
|
def mask_frequencies(
|
||||||
|
self,
|
||||||
|
fmin_hz: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
fmax_hz: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
val: float=0.0, ):
|
||||||
|
"""Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
|
||||||
|
with the value specified by ``val``. Useful for implementing SpecAug.
|
||||||
|
The min and max can be different for every item in the batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fmin_hz : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Lower end of band to mask out.
|
||||||
|
fmax_hz : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Upper end of band to mask out.
|
||||||
|
val : float, optional
|
||||||
|
Value to fill in, by default 0.0
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
||||||
|
masked audio data.
|
||||||
|
"""
|
||||||
|
# SpecAug
|
||||||
|
mag, phase = self.magnitude, self.phase
|
||||||
|
fmin_hz = util.ensure_tensor(
|
||||||
|
fmin_hz,
|
||||||
|
ndim=mag.ndim, )
|
||||||
|
fmax_hz = util.ensure_tensor(
|
||||||
|
fmax_hz,
|
||||||
|
ndim=mag.ndim, )
|
||||||
|
assert paddle.all(fmin_hz < fmax_hz)
|
||||||
|
|
||||||
|
# build mask
|
||||||
|
nbins = mag.shape[-2]
|
||||||
|
bins_hz = paddle.linspace(
|
||||||
|
0,
|
||||||
|
self.sample_rate / 2,
|
||||||
|
nbins, )
|
||||||
|
bins_hz = bins_hz[None, None, :, None].tile(
|
||||||
|
[self.batch_size, 1, 1, mag.shape[-1]])
|
||||||
|
|
||||||
|
fmin_hz, fmax_hz = fmin_hz.astype(bins_hz.dtype), fmax_hz.astype(
|
||||||
|
bins_hz.dtype)
|
||||||
|
mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
|
||||||
|
|
||||||
|
mag = paddle.where(mask, paddle.full_like(mag, val), mag)
|
||||||
|
phase = paddle.where(mask, paddle.full_like(phase, val), phase)
|
||||||
|
self.stft_data = mag * util.exp_compat(1j * phase)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def mask_timesteps(
|
||||||
|
self,
|
||||||
|
tmin_s: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
tmax_s: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
val: float=0.0, ):
|
||||||
|
"""Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
|
||||||
|
with the value specified by ``val``. Useful for implementing SpecAug.
|
||||||
|
The min and max can be different for every item in the batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tmin_s : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Lower end of timesteps to mask out.
|
||||||
|
tmax_s : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Upper end of timesteps to mask out.
|
||||||
|
val : float, optional
|
||||||
|
Value to fill in, by default 0.0
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
||||||
|
masked audio data.
|
||||||
|
"""
|
||||||
|
# SpecAug
|
||||||
|
mag, phase = self.magnitude, self.phase
|
||||||
|
tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
|
||||||
|
tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
|
||||||
|
|
||||||
|
assert paddle.all(tmin_s < tmax_s)
|
||||||
|
|
||||||
|
# build mask
|
||||||
|
nt = mag.shape[-1]
|
||||||
|
bins_t = paddle.linspace(
|
||||||
|
0,
|
||||||
|
self.signal_duration,
|
||||||
|
nt, )
|
||||||
|
bins_t = bins_t[None, None, None, :].tile(
|
||||||
|
[self.batch_size, 1, mag.shape[-2], 1])
|
||||||
|
mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
|
||||||
|
|
||||||
|
# mag = mag.masked_fill(mask, val)
|
||||||
|
# phase = phase.masked_fill(mask, val)
|
||||||
|
mag = paddle.where(mask, paddle.full_like(mag, val), mag)
|
||||||
|
phase = paddle.where(mask, paddle.full_like(phase, val), phase)
|
||||||
|
|
||||||
|
self.stft_data = mag * util.exp_compat(1j * phase)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def mask_low_magnitudes(
|
||||||
|
self,
|
||||||
|
db_cutoff: typing.Union[paddle.Tensor, np.ndarray, float],
|
||||||
|
val: float=0.0):
|
||||||
|
"""Mask away magnitudes below a specified threshold, which
|
||||||
|
can be different for every item in the batch.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
db_cutoff : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Decibel value for which things below it will be masked away.
|
||||||
|
val : float, optional
|
||||||
|
Value to fill in for masked portions, by default 0.0
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
||||||
|
masked audio data.
|
||||||
|
"""
|
||||||
|
mag = self.magnitude
|
||||||
|
log_mag = self.log_magnitude()
|
||||||
|
|
||||||
|
db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
|
||||||
|
db_cutoff = db_cutoff.astype(log_mag.dtype)
|
||||||
|
mask = log_mag < db_cutoff
|
||||||
|
# mag = mag.masked_fill(mask, val)
|
||||||
|
mag = paddle.where(mask, mag, val * paddle.ones_like(mag))
|
||||||
|
|
||||||
|
self.magnitude = mag
|
||||||
|
return self
|
||||||
|
|
||||||
|
def shift_phase(self,
|
||||||
|
shift: typing.Union[paddle.Tensor, np.ndarray, float]):
|
||||||
|
"""Shifts the phase by a constant value.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
shift : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
What to shift the phase by.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
||||||
|
masked audio data.
|
||||||
|
"""
|
||||||
|
shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
|
||||||
|
shift = shift.astype(self.phase.dtype)
|
||||||
|
self.phase = self.phase + shift
|
||||||
|
return self
|
||||||
|
|
||||||
|
def corrupt_phase(self,
|
||||||
|
scale: typing.Union[paddle.Tensor, np.ndarray, float]):
|
||||||
|
"""Corrupts the phase randomly by some scaled value.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
scale : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Standard deviation of noise to add to the phase.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
|
||||||
|
masked audio data.
|
||||||
|
"""
|
||||||
|
scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
|
||||||
|
self.phase = self.phase + scale * paddle.randn(
|
||||||
|
shape=self.phase.shape, dtype=self.phase.dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def preemphasis(self, coef: float=0.85):
|
||||||
|
"""Applies pre-emphasis to audio signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
coef : float, optional
|
||||||
|
How much pre-emphasis to apply, lower values do less. 0 does nothing.
|
||||||
|
by default 0.85
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Pre-emphasized signal.
|
||||||
|
"""
|
||||||
|
kernel = paddle.to_tensor([1, -coef, 0]).reshape([1, 1, -1])
|
||||||
|
x = self.audio_data.reshape([-1, 1, self.signal_length])
|
||||||
|
x = paddle.nn.functional.conv1d(
|
||||||
|
x.astype(kernel.dtype), kernel, padding=1)
|
||||||
|
self.audio_data = x.reshape(self.audio_data.shape)
|
||||||
|
return self
|
@ -0,0 +1,539 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/effects.py)
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from . import util
|
||||||
|
from ._julius import SplitBands
|
||||||
|
|
||||||
|
|
||||||
|
class EffectMixin:
|
||||||
|
GAIN_FACTOR = np.log(10) / 20
|
||||||
|
"""Gain factor for converting between amplitude and decibels."""
|
||||||
|
CODEC_PRESETS = {
|
||||||
|
"8-bit": {
|
||||||
|
"format": "wav",
|
||||||
|
"encoding": "ULAW",
|
||||||
|
"bits_per_sample": 8
|
||||||
|
},
|
||||||
|
"GSM-FR": {
|
||||||
|
"format": "gsm"
|
||||||
|
},
|
||||||
|
"MP3": {
|
||||||
|
"format": "mp3",
|
||||||
|
"compression": -9
|
||||||
|
},
|
||||||
|
"Vorbis": {
|
||||||
|
"format": "vorbis",
|
||||||
|
"compression": -1
|
||||||
|
},
|
||||||
|
"Ogg": {
|
||||||
|
"format": "ogg",
|
||||||
|
"compression": -1,
|
||||||
|
},
|
||||||
|
"Amr-nb": {
|
||||||
|
"format": "amr-nb"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
"""Presets for applying codecs via torchaudio."""
|
||||||
|
|
||||||
|
def mix(
|
||||||
|
self,
|
||||||
|
other,
|
||||||
|
snr: typing.Union[paddle.Tensor, np.ndarray, float]=10,
|
||||||
|
other_eq: typing.Union[paddle.Tensor, np.ndarray]=None, ):
|
||||||
|
"""Mixes noise with signal at specified
|
||||||
|
signal-to-noise ratio. Optionally, the
|
||||||
|
other signal can be equalized in-place.
|
||||||
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : AudioSignal
|
||||||
|
AudioSignal object to mix with.
|
||||||
|
snr : typing.Union[paddle.Tensor, np.ndarray, float], optional
|
||||||
|
Signal to noise ratio, by default 10
|
||||||
|
other_eq : typing.Union[paddle.Tensor, np.ndarray], optional
|
||||||
|
EQ curve to apply to other signal, if any, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
In-place modification of AudioSignal.
|
||||||
|
"""
|
||||||
|
snr = util.ensure_tensor(snr)
|
||||||
|
|
||||||
|
pad_len = max(0, self.signal_length - other.signal_length)
|
||||||
|
other.zero_pad(0, pad_len)
|
||||||
|
other.truncate_samples(self.signal_length)
|
||||||
|
if other_eq is not None:
|
||||||
|
other = other.equalizer(other_eq)
|
||||||
|
|
||||||
|
tgt_loudness = self.loudness() - snr
|
||||||
|
other = other.normalize(tgt_loudness)
|
||||||
|
|
||||||
|
self.audio_data = self.audio_data + other.audio_data
|
||||||
|
return self
|
||||||
|
|
||||||
|
def convolve(self, other, start_at_max: bool=True):
|
||||||
|
"""Convolves self with other.
|
||||||
|
This function uses FFTs to do the convolution.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
other : AudioSignal
|
||||||
|
Signal to convolve with.
|
||||||
|
start_at_max : bool, optional
|
||||||
|
Whether to start at the max value of other signal, to
|
||||||
|
avoid inducing delays, by default True
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Convolved signal, in-place.
|
||||||
|
"""
|
||||||
|
from . import AudioSignal
|
||||||
|
|
||||||
|
pad_len = self.signal_length - other.signal_length
|
||||||
|
|
||||||
|
if pad_len > 0:
|
||||||
|
other.zero_pad(0, pad_len)
|
||||||
|
else:
|
||||||
|
other.truncate_samples(self.signal_length)
|
||||||
|
|
||||||
|
if start_at_max:
|
||||||
|
# Use roll to rotate over the max for every item
|
||||||
|
# so that the impulse responses don't induce any
|
||||||
|
# delay.
|
||||||
|
idx = paddle.argmax(paddle.abs(other.audio_data), axis=-1)
|
||||||
|
irs = paddle.zeros_like(other.audio_data)
|
||||||
|
for i in range(other.batch_size):
|
||||||
|
irs[i] = paddle.roll(
|
||||||
|
other.audio_data[i], shifts=-idx[i].item(), axis=-1)
|
||||||
|
other = AudioSignal(irs, other.sample_rate)
|
||||||
|
|
||||||
|
delta = paddle.zeros_like(other.audio_data)
|
||||||
|
delta[..., 0] = 1
|
||||||
|
|
||||||
|
length = self.signal_length
|
||||||
|
delta_fft = paddle.fft.rfft(delta, n=length)
|
||||||
|
other_fft = paddle.fft.rfft(other.audio_data, n=length)
|
||||||
|
self_fft = paddle.fft.rfft(self.audio_data, n=length)
|
||||||
|
|
||||||
|
convolved_fft = other_fft * self_fft
|
||||||
|
convolved_audio = paddle.fft.irfft(convolved_fft, n=length)
|
||||||
|
|
||||||
|
delta_convolved_fft = other_fft * delta_fft
|
||||||
|
delta_audio = paddle.fft.irfft(delta_convolved_fft, n=length)
|
||||||
|
|
||||||
|
# Use the delta to rescale the audio exactly as needed.
|
||||||
|
delta_max = paddle.max(paddle.abs(delta_audio), axis=-1, keepdim=True)
|
||||||
|
scale = 1 / paddle.clip(delta_max, min=1e-5)
|
||||||
|
convolved_audio = convolved_audio * scale
|
||||||
|
|
||||||
|
self.audio_data = convolved_audio
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def apply_ir(
|
||||||
|
self,
|
||||||
|
ir,
|
||||||
|
drr: typing.Union[paddle.Tensor, np.ndarray, float]=None,
|
||||||
|
ir_eq: typing.Union[paddle.Tensor, np.ndarray]=None,
|
||||||
|
use_original_phase: bool=False, ):
|
||||||
|
"""Applies an impulse response to the signal. If ` is`ir_eq``
|
||||||
|
is specified, the impulse response is equalized before
|
||||||
|
it is applied, using the given curve.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
ir : AudioSignal
|
||||||
|
Impulse response to convolve with.
|
||||||
|
drr : typing.Union[paddle.Tensor, np.ndarray, float], optional
|
||||||
|
Direct-to-reverberant ratio that impulse response will be
|
||||||
|
altered to, if specified, by default None
|
||||||
|
ir_eq : typing.Union[paddle.Tensor, np.ndarray], optional
|
||||||
|
Equalization that will be applied to impulse response
|
||||||
|
if specified, by default None
|
||||||
|
use_original_phase : bool, optional
|
||||||
|
Whether to use the original phase, instead of the convolved
|
||||||
|
phase, by default False
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with impulse response applied to it
|
||||||
|
"""
|
||||||
|
if ir_eq is not None:
|
||||||
|
ir = ir.equalizer(ir_eq)
|
||||||
|
if drr is not None:
|
||||||
|
ir = ir.alter_drr(drr)
|
||||||
|
|
||||||
|
# Save the peak before
|
||||||
|
max_spk = self.audio_data.abs().max(axis=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Augment the impulse response to simulate microphone effects
|
||||||
|
# and with varying direct-to-reverberant ratio.
|
||||||
|
phase = self.phase
|
||||||
|
self.convolve(ir)
|
||||||
|
|
||||||
|
# Use the input phase
|
||||||
|
if use_original_phase:
|
||||||
|
self.stft()
|
||||||
|
self.stft_data = self.magnitude * util.exp_compat(1j * phase)
|
||||||
|
self.istft()
|
||||||
|
|
||||||
|
# Rescale to the input's amplitude
|
||||||
|
max_transformed = self.audio_data.abs().max(axis=-1, keepdim=True)
|
||||||
|
scale_factor = max_spk.clip(1e-8) / max_transformed.clip(1e-8)
|
||||||
|
self = self * scale_factor
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def ensure_max_of_audio(self, _max: float=1.0):
|
||||||
|
"""Ensures that ``abs(audio_data) <= max``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max : float, optional
|
||||||
|
Max absolute value of signal, by default 1.0
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal with values scaled between -max and max.
|
||||||
|
"""
|
||||||
|
peak = self.audio_data.abs().max(axis=-1, keepdim=True)
|
||||||
|
peak_gain = paddle.ones_like(peak)
|
||||||
|
# peak_gain[peak > _max] = _max / peak[peak > _max]
|
||||||
|
peak_gain = paddle.where(peak > _max, _max / peak, peak_gain)
|
||||||
|
self.audio_data = self.audio_data * peak_gain
|
||||||
|
return self
|
||||||
|
|
||||||
|
def normalize(self,
|
||||||
|
db: typing.Union[paddle.Tensor, np.ndarray, float]=-24.0):
|
||||||
|
"""Normalizes the signal's volume to the specified db, in LUFS.
|
||||||
|
This is GPU-compatible, making for very fast loudness normalization.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
db : typing.Union[paddle.Tensor, np.ndarray, float], optional
|
||||||
|
Loudness to normalize to, by default -24.0
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Normalized audio signal.
|
||||||
|
"""
|
||||||
|
db = util.ensure_tensor(db)
|
||||||
|
ref_db = self.loudness()
|
||||||
|
gain = db.astype(ref_db.dtype) - ref_db
|
||||||
|
gain = util.exp_compat(gain * self.GAIN_FACTOR)
|
||||||
|
|
||||||
|
self.audio_data = self.audio_data * gain[:, None, None]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def volume_change(self, db: typing.Union[paddle.Tensor, np.ndarray, float]):
|
||||||
|
"""Change volume of signal by some amount, in dB.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
db : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Amount to change volume by.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Signal at new volume.
|
||||||
|
"""
|
||||||
|
db = util.ensure_tensor(db, ndim=1)
|
||||||
|
gain = util.exp_compat(db * self.GAIN_FACTOR)
|
||||||
|
self.audio_data = self.audio_data * gain[:, None, None]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def mel_filterbank(self, n_bands: int):
|
||||||
|
"""Breaks signal into mel bands.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_bands : int
|
||||||
|
Number of mel bands to use.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Mel-filtered bands, with last axis being the band index.
|
||||||
|
"""
|
||||||
|
filterbank = SplitBands(self.sample_rate, n_bands)
|
||||||
|
filtered = filterbank(self.audio_data)
|
||||||
|
return filtered.transpose([1, 2, 3, 0])
|
||||||
|
|
||||||
|
def equalizer(self, db: typing.Union[paddle.Tensor, np.ndarray]):
|
||||||
|
"""Applies a mel-spaced equalizer to the audio signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
db : typing.Union[paddle.Tensor, np.ndarray]
|
||||||
|
EQ curve to apply.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
AudioSignal with equalization applied.
|
||||||
|
"""
|
||||||
|
db = util.ensure_tensor(db)
|
||||||
|
n_bands = db.shape[-1]
|
||||||
|
fbank = self.mel_filterbank(n_bands)
|
||||||
|
|
||||||
|
# If there's a batch dimension, make sure it's the same.
|
||||||
|
if db.ndim == 2:
|
||||||
|
if db.shape[0] != 1:
|
||||||
|
assert db.shape[0] == fbank.shape[0]
|
||||||
|
else:
|
||||||
|
db = db.unsqueeze(0)
|
||||||
|
|
||||||
|
weights = (10**db).astype("float32")
|
||||||
|
fbank = fbank * weights[:, None, None, :]
|
||||||
|
eq_audio_data = fbank.sum(-1)
|
||||||
|
self.audio_data = eq_audio_data
|
||||||
|
return self
|
||||||
|
|
||||||
|
def clip_distortion(
|
||||||
|
self,
|
||||||
|
clip_percentile: typing.Union[paddle.Tensor, np.ndarray, float]):
|
||||||
|
"""Clips the signal at a given percentile. The higher it is,
|
||||||
|
the lower the threshold for clipping.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
clip_percentile : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Values are between 0.0 to 1.0. Typical values are 0.1 or below.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Audio signal with clipped audio data.
|
||||||
|
"""
|
||||||
|
clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
|
||||||
|
clip_percentile = clip_percentile.cpu().numpy()
|
||||||
|
min_thresh = paddle.quantile(
|
||||||
|
self.audio_data, (clip_percentile / 2).tolist(), axis=-1)[None]
|
||||||
|
max_thresh = paddle.quantile(
|
||||||
|
self.audio_data, (1 - clip_percentile / 2).tolist(), axis=-1)[None]
|
||||||
|
|
||||||
|
nc = self.audio_data.shape[1]
|
||||||
|
min_thresh = min_thresh[:, :nc, :]
|
||||||
|
max_thresh = max_thresh[:, :nc, :]
|
||||||
|
|
||||||
|
self.audio_data = self.audio_data.clip(min_thresh, max_thresh)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def quantization(self,
|
||||||
|
quantization_channels: typing.Union[paddle.Tensor,
|
||||||
|
np.ndarray, int]):
|
||||||
|
"""Applies quantization to the input waveform.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
|
||||||
|
Number of evenly spaced quantization channels to quantize
|
||||||
|
to.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Quantized AudioSignal.
|
||||||
|
"""
|
||||||
|
quantization_channels = util.ensure_tensor(
|
||||||
|
quantization_channels, ndim=3)
|
||||||
|
|
||||||
|
x = self.audio_data
|
||||||
|
quantization_channels = quantization_channels.astype(x.dtype)
|
||||||
|
x = (x + 1) / 2
|
||||||
|
x = x * quantization_channels
|
||||||
|
x = x.floor()
|
||||||
|
x = x / quantization_channels
|
||||||
|
x = 2 * x - 1
|
||||||
|
|
||||||
|
residual = (self.audio_data - x).detach()
|
||||||
|
self.audio_data = self.audio_data - residual
|
||||||
|
return self
|
||||||
|
|
||||||
|
def mulaw_quantization(self,
|
||||||
|
quantization_channels: typing.Union[
|
||||||
|
paddle.Tensor, np.ndarray, int]):
|
||||||
|
"""Applies mu-law quantization to the input waveform.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int]
|
||||||
|
Number of mu-law spaced quantization channels to quantize
|
||||||
|
to.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Quantized AudioSignal.
|
||||||
|
"""
|
||||||
|
mu = quantization_channels - 1.0
|
||||||
|
mu = util.ensure_tensor(mu, ndim=3)
|
||||||
|
|
||||||
|
x = self.audio_data
|
||||||
|
|
||||||
|
# quantize
|
||||||
|
x = paddle.sign(x) * paddle.log1p(mu * paddle.abs(x)) / paddle.log1p(mu)
|
||||||
|
x = ((x + 1) / 2 * mu + 0.5).astype("int64")
|
||||||
|
|
||||||
|
# unquantize
|
||||||
|
x = (x.astype(mu.dtype) / mu) * 2 - 1.0
|
||||||
|
x = paddle.sign(x) * (
|
||||||
|
util.exp_compat(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu
|
||||||
|
|
||||||
|
residual = (self.audio_data - x).detach()
|
||||||
|
self.audio_data = self.audio_data - residual
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __matmul__(self, other):
|
||||||
|
return self.convolve(other)
|
||||||
|
|
||||||
|
|
||||||
|
class ImpulseResponseMixin:
|
||||||
|
"""These functions are generally only used with AudioSignals that are derived
|
||||||
|
from impulse responses, not other sources like music or speech. These methods
|
||||||
|
are used to replicate the data augmentation described in [1].
|
||||||
|
|
||||||
|
1. Bryan, Nicholas J. "Impulse response data augmentation and deep
|
||||||
|
neural networks for blind room acoustic parameter estimation."
|
||||||
|
ICASSP 2020-2020 IEEE International Conference on Acoustics,
|
||||||
|
Speech and Signal Processing (ICASSP). IEEE, 2020.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decompose_ir(self):
|
||||||
|
"""Decomposes an impulse response into early and late
|
||||||
|
field responses.
|
||||||
|
"""
|
||||||
|
# Equations 1 and 2
|
||||||
|
# -----------------
|
||||||
|
# Breaking up into early
|
||||||
|
# response + late field response.
|
||||||
|
|
||||||
|
td = paddle.argmax(self.audio_data, axis=-1, keepdim=True)
|
||||||
|
t0 = int(self.sample_rate * 0.0025)
|
||||||
|
|
||||||
|
idx = paddle.arange(self.audio_data.shape[-1])[None, None, :]
|
||||||
|
idx = idx.expand([self.batch_size, -1, -1])
|
||||||
|
early_idx = (idx >= td - t0) * (idx <= td + t0)
|
||||||
|
|
||||||
|
early_response = paddle.zeros_like(self.audio_data)
|
||||||
|
|
||||||
|
# early_response[early_idx] = self.audio_data[early_idx]
|
||||||
|
early_response = paddle.where(early_idx, self.audio_data,
|
||||||
|
early_response)
|
||||||
|
|
||||||
|
late_idx = ~early_idx
|
||||||
|
late_field = paddle.zeros_like(self.audio_data)
|
||||||
|
# late_field[late_idx] = self.audio_data[late_idx]
|
||||||
|
late_field = paddle.where(late_idx, self.audio_data, late_field)
|
||||||
|
|
||||||
|
# Equation 4
|
||||||
|
# ----------
|
||||||
|
# Decompose early response into windowed
|
||||||
|
# direct path and windowed residual.
|
||||||
|
|
||||||
|
window = paddle.zeros_like(self.audio_data)
|
||||||
|
window_idx = paddle.nonzero(early_idx)
|
||||||
|
for idx in range(self.batch_size):
|
||||||
|
# window_idx = early_idx[idx, 0]
|
||||||
|
|
||||||
|
# ----- Just for this -----
|
||||||
|
# window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item())
|
||||||
|
# indices = paddle.nonzero(window_idx).reshape(
|
||||||
|
# [-1]) # shape: [num_true], dtype: int64
|
||||||
|
indices = window_idx[window_idx[:, 0] == idx][:, -1]
|
||||||
|
|
||||||
|
temp_window = self.get_window("hann", indices.shape[0])
|
||||||
|
|
||||||
|
window_slice = window[idx, 0]
|
||||||
|
updated_window_slice = paddle.scatter(
|
||||||
|
window_slice, index=indices, updates=temp_window)
|
||||||
|
|
||||||
|
window[idx, 0] = updated_window_slice
|
||||||
|
# ----- Just for that -----
|
||||||
|
|
||||||
|
return early_response, late_field, window
|
||||||
|
|
||||||
|
def measure_drr(self):
|
||||||
|
"""Measures the direct-to-reverberant ratio of the impulse
|
||||||
|
response.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
Direct-to-reverberant ratio
|
||||||
|
"""
|
||||||
|
early_response, late_field, _ = self.decompose_ir()
|
||||||
|
num = (early_response**2).sum(axis=-1)
|
||||||
|
den = (late_field**2).sum(axis=-1)
|
||||||
|
drr = 10 * paddle.log10(num / den)
|
||||||
|
return drr
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def solve_alpha(early_response, late_field, wd, target_drr):
|
||||||
|
"""Used to solve for the alpha value, which is used
|
||||||
|
to alter the drr.
|
||||||
|
"""
|
||||||
|
# Equation 5
|
||||||
|
# ----------
|
||||||
|
# Apply the good ol' quadratic formula.
|
||||||
|
|
||||||
|
wd_sq = wd**2
|
||||||
|
wd_sq_1 = (1 - wd)**2
|
||||||
|
e_sq = early_response**2
|
||||||
|
l_sq = late_field**2
|
||||||
|
a = (wd_sq * e_sq).sum(axis=-1)
|
||||||
|
b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1)
|
||||||
|
c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(10 * paddle.ones_like(
|
||||||
|
target_drr, dtype="float32"), target_drr.cast("float32") /
|
||||||
|
10) * l_sq.sum(axis=-1)
|
||||||
|
|
||||||
|
expr = ((b**2) - 4 * a * c).sqrt()
|
||||||
|
alpha = paddle.maximum(
|
||||||
|
(-b - expr) / (2 * a),
|
||||||
|
(-b + expr) / (2 * a), )
|
||||||
|
return alpha
|
||||||
|
|
||||||
|
def alter_drr(self, drr: typing.Union[paddle.Tensor, np.ndarray, float]):
|
||||||
|
"""Alters the direct-to-reverberant ratio of the impulse response.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
drr : typing.Union[paddle.Tensor, np.ndarray, float]
|
||||||
|
Direct-to-reverberant ratio that impulse response will be
|
||||||
|
altered to, if specified, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
AudioSignal
|
||||||
|
Altered impulse response.
|
||||||
|
"""
|
||||||
|
drr = util.ensure_tensor(
|
||||||
|
drr, 2, self.batch_size
|
||||||
|
) # Assuming util.ensure_tensor is adapted or equivalent exists
|
||||||
|
|
||||||
|
early_response, late_field, window = self.decompose_ir()
|
||||||
|
alpha = self.solve_alpha(early_response, late_field, window, drr)
|
||||||
|
min_alpha = late_field.abs().max(axis=-1)[0] / early_response.abs().max(
|
||||||
|
axis=-1)[0]
|
||||||
|
alpha = paddle.maximum(alpha, min_alpha)[..., None]
|
||||||
|
|
||||||
|
aug_ir_data = alpha * window * early_response + (
|
||||||
|
(1 - window) * early_response) + late_field
|
||||||
|
self.audio_data = aug_ir_data
|
||||||
|
self.ensure_max_of_audio(
|
||||||
|
) # Assuming ensure_max_of_audio is a method defined elsewhere
|
||||||
|
return self
|
@ -0,0 +1,119 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/ffmpeg.py)
|
||||||
|
import json
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import ffmpy
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
|
||||||
|
def r128stats(filepath: str, quiet: bool):
|
||||||
|
"""Takes a path to an audio file, returns a dict with the loudness
|
||||||
|
stats computed by the ffmpeg ebur128 filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filepath : str
|
||||||
|
Path to compute loudness stats on.
|
||||||
|
quiet : bool
|
||||||
|
Whether to show FFMPEG output during computation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
Dictionary containing loudness stats.
|
||||||
|
"""
|
||||||
|
ffargs = [
|
||||||
|
"ffmpeg",
|
||||||
|
"-nostats",
|
||||||
|
"-i",
|
||||||
|
filepath,
|
||||||
|
"-filter_complex",
|
||||||
|
"ebur128",
|
||||||
|
"-f",
|
||||||
|
"null",
|
||||||
|
"-",
|
||||||
|
]
|
||||||
|
if quiet:
|
||||||
|
ffargs += ["-hide_banner"]
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
ffargs, stderr=subprocess.PIPE, universal_newlines=True)
|
||||||
|
stats = proc.communicate()[1]
|
||||||
|
summary_index = stats.rfind("Summary:")
|
||||||
|
|
||||||
|
summary_list = stats[summary_index:].split()
|
||||||
|
i_lufs = float(summary_list[summary_list.index("I:") + 1])
|
||||||
|
i_thresh = float(summary_list[summary_list.index("I:") + 4])
|
||||||
|
lra = float(summary_list[summary_list.index("LRA:") + 1])
|
||||||
|
lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
|
||||||
|
lra_low = float(summary_list[summary_list.index("low:") + 1])
|
||||||
|
lra_high = float(summary_list[summary_list.index("high:") + 1])
|
||||||
|
stats_dict = {
|
||||||
|
"I": i_lufs,
|
||||||
|
"I Threshold": i_thresh,
|
||||||
|
"LRA": lra,
|
||||||
|
"LRA Threshold": lra_thresh,
|
||||||
|
"LRA Low": lra_low,
|
||||||
|
"LRA High": lra_high,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats_dict
|
||||||
|
|
||||||
|
|
||||||
|
def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
|
||||||
|
"""Given a path to a file, returns the start time offset and codec of
|
||||||
|
the first audio stream.
|
||||||
|
"""
|
||||||
|
ff = ffmpy.FFprobe(
|
||||||
|
inputs={path: None},
|
||||||
|
global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
|
||||||
|
)
|
||||||
|
streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
|
||||||
|
seconds_offset = 0.0
|
||||||
|
codec = None
|
||||||
|
|
||||||
|
# Get the offset and codec of the first audio stream we find
|
||||||
|
# and return its start time, if it has one.
|
||||||
|
for stream in streams:
|
||||||
|
if stream["codec_type"] == "audio":
|
||||||
|
seconds_offset = stream.get("start_time", 0.0)
|
||||||
|
codec = stream.get("codec_name")
|
||||||
|
break
|
||||||
|
return float(seconds_offset), codec
|
||||||
|
|
||||||
|
|
||||||
|
class FFMPEGMixin:
|
||||||
|
_loudness = None
|
||||||
|
|
||||||
|
def ffmpeg_loudness(self, quiet: bool=True):
|
||||||
|
"""Computes loudness of audio file using FFMPEG.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
quiet : bool, optional
|
||||||
|
Whether to show FFMPEG output during computation,
|
||||||
|
by default True
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Loudness of every item in the batch, computed via
|
||||||
|
FFMPEG.
|
||||||
|
"""
|
||||||
|
loudness = []
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
||||||
|
for i in range(self.batch_size):
|
||||||
|
self[i].write(f.name)
|
||||||
|
loudness_stats = r128stats(f.name, quiet=quiet)
|
||||||
|
loudness.append(loudness_stats["I"])
|
||||||
|
|
||||||
|
self._loudness = paddle.to_tensor(np.array(loudness)).astype("float32")
|
||||||
|
return self.loudness()
|
@ -0,0 +1,387 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/loudness.py)
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
import scipy
|
||||||
|
|
||||||
|
from . import _julius
|
||||||
|
|
||||||
|
|
||||||
|
def _unfold1d(x, kernel_size, stride):
|
||||||
|
# https://github.com/PaddlePaddle/Paddle/pull/70102
|
||||||
|
"""1D only unfolding similar to the one from Paddlepaddle.
|
||||||
|
|
||||||
|
Given an _input tensor of size `[*, T]` this will return
|
||||||
|
a tensor `[*, F, K]` with `K` the kernel size, and `F` the number
|
||||||
|
of frames. The i-th frame is a view onto `i * stride: i * stride + kernel_size`.
|
||||||
|
This will automatically pad the _input to cover at least once all entries in `_input`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_input (Tensor): tensor for which to return the frames.
|
||||||
|
kernel_size (int): size of each frame.
|
||||||
|
stride (int): stride between each frame.
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
|
||||||
|
- Inputs: `_input` is `[*, T]`
|
||||||
|
- Output: `[*, F, kernel_size]` with `F = 1 + ceil((T - kernel_size) / stride)`
|
||||||
|
"""
|
||||||
|
|
||||||
|
if 3 != x.dim():
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
N, C, length = x.shape
|
||||||
|
x = x.reshape([N * C, 1, length])
|
||||||
|
|
||||||
|
n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
|
||||||
|
tgt_length = (n_frames - 1) * stride + kernel_size
|
||||||
|
x = F.pad(x, (0, tgt_length - length), data_format="NCL")
|
||||||
|
|
||||||
|
x = x.unsqueeze(-1)
|
||||||
|
|
||||||
|
unfolded = paddle.nn.functional.unfold(
|
||||||
|
x,
|
||||||
|
kernel_sizes=[kernel_size, 1],
|
||||||
|
strides=[stride, 1], )
|
||||||
|
|
||||||
|
unfolded = unfolded.transpose([0, 2, 1])
|
||||||
|
unfolded = unfolded.reshape([N, C, *unfolded.shape[1:]])
|
||||||
|
return unfolded
|
||||||
|
|
||||||
|
|
||||||
|
class Meter(paddle.nn.Layer):
|
||||||
|
"""Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
rate : int
|
||||||
|
Sample rate of audio.
|
||||||
|
filter_class : str, optional
|
||||||
|
Class of weighting filter used.
|
||||||
|
K-weighting' (default), 'Fenton/Lee 1'
|
||||||
|
'Fenton/Lee 2', 'Dash et al.'
|
||||||
|
by default "K-weighting"
|
||||||
|
block_size : float, optional
|
||||||
|
Gating block size in seconds, by default 0.400
|
||||||
|
zeros : int, optional
|
||||||
|
Number of zeros to use in FIR approximation of
|
||||||
|
IIR filters, by default 512
|
||||||
|
use_fir : bool, optional
|
||||||
|
Whether to use FIR approximation or exact IIR formulation.
|
||||||
|
If computing on GPU, ``use_fir=True`` will be used, as its
|
||||||
|
much faster, by default False
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rate: int,
|
||||||
|
filter_class: str="K-weighting",
|
||||||
|
block_size: float=0.400,
|
||||||
|
zeros: int=512,
|
||||||
|
use_fir: bool=False, ):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.rate = rate
|
||||||
|
self.filter_class = filter_class
|
||||||
|
self.block_size = block_size
|
||||||
|
self.use_fir = use_fir
|
||||||
|
|
||||||
|
G = paddle.to_tensor(
|
||||||
|
np.array([1.0, 1.0, 1.0, 1.41, 1.41]), stop_gradient=True)
|
||||||
|
self.register_buffer("G", G)
|
||||||
|
|
||||||
|
# Compute impulse responses so that filtering is fast via
|
||||||
|
# a convolution at runtime, on GPU, unlike lfilter.
|
||||||
|
impulse = np.zeros((zeros, ))
|
||||||
|
impulse[..., 0] = 1.0
|
||||||
|
|
||||||
|
firs = np.zeros((len(self._filters), 1, zeros))
|
||||||
|
# passband_gain = torch.zeros(len(self._filters))
|
||||||
|
passband_gain = paddle.zeros([len(self._filters)], dtype="float32")
|
||||||
|
|
||||||
|
for i, (_, filter_stage) in enumerate(self._filters.items()):
|
||||||
|
firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a,
|
||||||
|
impulse)
|
||||||
|
passband_gain[i] = filter_stage.passband_gain
|
||||||
|
|
||||||
|
firs = paddle.to_tensor(
|
||||||
|
firs[..., ::-1].copy(), dtype="float32", stop_gradient=True)
|
||||||
|
|
||||||
|
self.register_buffer("firs", firs)
|
||||||
|
self.register_buffer("passband_gain", passband_gain)
|
||||||
|
|
||||||
|
def apply_filter_gpu(self, data: paddle.Tensor):
|
||||||
|
"""Performs FIR approximation of loudness computation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : paddle.Tensor
|
||||||
|
Audio data of shape (nb, nch, nt).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Filtered audio data.
|
||||||
|
"""
|
||||||
|
# Data is of shape (nb, nch, nt)
|
||||||
|
# Reshape to (nb*nch, 1, nt)
|
||||||
|
nb, nt, nch = data.shape
|
||||||
|
data = data.transpose([0, 2, 1])
|
||||||
|
data = data.reshape([nb * nch, 1, nt])
|
||||||
|
|
||||||
|
# Apply padding
|
||||||
|
pad_length = self.firs.shape[-1]
|
||||||
|
|
||||||
|
# Apply filtering in sequence
|
||||||
|
for i in range(self.firs.shape[0]):
|
||||||
|
data = F.pad(data, (pad_length, pad_length), data_format="NCL")
|
||||||
|
data = _julius.fft_conv1d(data, self.firs[i, None, ...])
|
||||||
|
data = self.passband_gain[i] * data
|
||||||
|
data = data[..., 1:nt + 1]
|
||||||
|
|
||||||
|
data = data.transpose([0, 2, 1])
|
||||||
|
data = data[:, :nt, :]
|
||||||
|
return data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def scipy_lfilter(waveform, a_coeffs, b_coeffs, clamp: bool=True):
|
||||||
|
# 使用 scipy.signal.lfilter 进行滤波(处理三维数据)
|
||||||
|
output = np.zeros_like(waveform)
|
||||||
|
for batch_idx in range(waveform.shape[0]):
|
||||||
|
for channel_idx in range(waveform.shape[2]):
|
||||||
|
output[batch_idx, :, channel_idx] = scipy.signal.lfilter(
|
||||||
|
b_coeffs, a_coeffs, waveform[batch_idx, :, channel_idx])
|
||||||
|
return output
|
||||||
|
|
||||||
|
def apply_filter_cpu(self, data: paddle.Tensor):
|
||||||
|
"""Performs IIR formulation of loudness computation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : paddle.Tensor
|
||||||
|
Audio data of shape (nb, nch, nt).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Filtered audio data.
|
||||||
|
"""
|
||||||
|
_data = data.cpu().numpy().copy()
|
||||||
|
for _, filter_stage in self._filters.items():
|
||||||
|
passband_gain = filter_stage.passband_gain
|
||||||
|
|
||||||
|
a_coeffs = filter_stage.a
|
||||||
|
b_coeffs = filter_stage.b
|
||||||
|
|
||||||
|
filtered = self.scipy_lfilter(_data, a_coeffs, b_coeffs)
|
||||||
|
_data[:] = passband_gain * filtered
|
||||||
|
data = paddle.to_tensor(_data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def apply_filter(self, data: paddle.Tensor):
|
||||||
|
"""Applies filter on either CPU or GPU, depending
|
||||||
|
on if the audio is on GPU or is on CPU, or if
|
||||||
|
``self.use_fir`` is True.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : paddle.Tensor
|
||||||
|
Audio data of shape (nb, nch, nt).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Filtered audio data.
|
||||||
|
"""
|
||||||
|
# if data.place.is_gpu_place() or self.use_fir:
|
||||||
|
# data = self.apply_filter_gpu(data)
|
||||||
|
# else:
|
||||||
|
# data = self.apply_filter_cpu(data)
|
||||||
|
data = self.apply_filter_cpu(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def forward(self, data: paddle.Tensor):
|
||||||
|
"""Computes integrated loudness of data.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : paddle.Tensor
|
||||||
|
Audio data of shape (nb, nch, nt).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Filtered audio data.
|
||||||
|
"""
|
||||||
|
return self.integrated_loudness(data)
|
||||||
|
|
||||||
|
def _unfold(self, input_data):
|
||||||
|
T_g = self.block_size
|
||||||
|
overlap = 0.75 # overlap of 75% of the block duration
|
||||||
|
step = 1.0 - overlap # step size by percentage
|
||||||
|
|
||||||
|
kernel_size = int(T_g * self.rate)
|
||||||
|
stride = int(T_g * self.rate * step)
|
||||||
|
unfolded = _unfold1d(
|
||||||
|
input_data.transpose([0, 2, 1]), kernel_size, stride)
|
||||||
|
unfolded = unfolded.transpose([0, 1, 3, 2])
|
||||||
|
|
||||||
|
return unfolded
|
||||||
|
|
||||||
|
def integrated_loudness(self, data: paddle.Tensor):
|
||||||
|
"""Computes integrated loudness of data.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : paddle.Tensor
|
||||||
|
Audio data of shape (nb, nch, nt).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Filtered audio data.
|
||||||
|
"""
|
||||||
|
if not paddle.is_tensor(data):
|
||||||
|
data = paddle.to_tensor(data, dtype="float32")
|
||||||
|
else:
|
||||||
|
data = data.astype("float32")
|
||||||
|
|
||||||
|
input_data = data.clone()
|
||||||
|
# Data always has a batch and channel dimension.
|
||||||
|
# Is of shape (nb, nt, nch)
|
||||||
|
if input_data.ndim < 2:
|
||||||
|
input_data = input_data.unsqueeze(-1)
|
||||||
|
if input_data.ndim < 3:
|
||||||
|
input_data = input_data.unsqueeze(0)
|
||||||
|
|
||||||
|
nb, nt, nch = input_data.shape
|
||||||
|
|
||||||
|
# Apply frequency weighting filters - account
|
||||||
|
# for the acoustic respose of the head and auditory system
|
||||||
|
input_data = self.apply_filter(input_data)
|
||||||
|
|
||||||
|
G = self.G # channel gains
|
||||||
|
T_g = self.block_size # 400 ms gating block standard
|
||||||
|
Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
|
||||||
|
|
||||||
|
unfolded = self._unfold(input_data)
|
||||||
|
|
||||||
|
z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
|
||||||
|
l = -0.691 + 10.0 * paddle.log10(
|
||||||
|
(G[None, :nch, None] * z).sum(1, keepdim=True))
|
||||||
|
l = l.expand_as(z)
|
||||||
|
|
||||||
|
# find gating block indices above absolute threshold
|
||||||
|
z_avg_gated = z
|
||||||
|
z_avg_gated[l <= Gamma_a] = 0
|
||||||
|
masked = l > Gamma_a
|
||||||
|
z_avg_gated = z_avg_gated.sum(2) / masked.sum(2).astype("float32")
|
||||||
|
|
||||||
|
# calculate the relative threshold value (see eq. 6)
|
||||||
|
Gamma_r = -0.691 + 10.0 * paddle.log10(
|
||||||
|
(z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
|
||||||
|
Gamma_r = Gamma_r[:, None, None]
|
||||||
|
Gamma_r = Gamma_r.expand([nb, nch, l.shape[-1]])
|
||||||
|
|
||||||
|
# find gating block indices above relative and absolute thresholds (end of eq. 7)
|
||||||
|
z_avg_gated = z
|
||||||
|
z_avg_gated[l <= Gamma_a] = 0
|
||||||
|
z_avg_gated[l <= Gamma_r] = 0
|
||||||
|
masked = (l > Gamma_a) * (l > Gamma_r)
|
||||||
|
z_avg_gated = z_avg_gated.sum(2) / (masked.sum(2) + 10e-6)
|
||||||
|
|
||||||
|
# TODO Currently, paddle has a segmentation fault bug in this section of the code
|
||||||
|
# z_avg_gated = paddle.nan_to_num(z_avg_gated)
|
||||||
|
# z_avg_gated = paddle.where(
|
||||||
|
# paddle.isnan(z_avg_gated),
|
||||||
|
# paddle.zeros_like(z_avg_gated), z_avg_gated)
|
||||||
|
z_avg_gated[z_avg_gated == float("inf")] = float(
|
||||||
|
np.finfo(np.float32).max)
|
||||||
|
z_avg_gated[z_avg_gated == -float("inf")] = float(
|
||||||
|
np.finfo(np.float32).min)
|
||||||
|
|
||||||
|
LUFS = -0.691 + 10.0 * paddle.log10(
|
||||||
|
(G[None, :nch] * z_avg_gated).sum(1))
|
||||||
|
return LUFS.astype("float32")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filter_class(self):
|
||||||
|
return self._filter_class
|
||||||
|
|
||||||
|
@filter_class.setter
|
||||||
|
def filter_class(self, value):
|
||||||
|
from pyloudnorm import Meter
|
||||||
|
|
||||||
|
meter = Meter(self.rate)
|
||||||
|
meter.filter_class = value
|
||||||
|
self._filter_class = value
|
||||||
|
self._filters = meter._filters
|
||||||
|
|
||||||
|
|
||||||
|
class LoudnessMixin:
|
||||||
|
_loudness = None
|
||||||
|
MIN_LOUDNESS = -70
|
||||||
|
"""Minimum loudness possible."""
|
||||||
|
|
||||||
|
def loudness(self,
|
||||||
|
filter_class: str="K-weighting",
|
||||||
|
block_size: float=0.400,
|
||||||
|
**kwargs):
|
||||||
|
"""Calculates loudness using an implementation of ITU-R BS.1770-4.
|
||||||
|
Allows control over gating block size and frequency weighting filters for
|
||||||
|
additional control. Measure the integrated gated loudness of a signal.
|
||||||
|
|
||||||
|
API is derived from PyLoudnorm, but this implementation is ported to PyTorch
|
||||||
|
and is tensorized across batches. When on GPU, an FIR approximation of the IIR
|
||||||
|
filters is used to compute loudness for speed.
|
||||||
|
|
||||||
|
Uses the weighting filters and block size defined by the meter
|
||||||
|
the integrated loudness is measured based upon the gating algorithm
|
||||||
|
defined in the ITU-R BS.1770-4 specification.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filter_class : str, optional
|
||||||
|
Class of weighting filter used.
|
||||||
|
K-weighting' (default), 'Fenton/Lee 1'
|
||||||
|
'Fenton/Lee 2', 'Dash et al.'
|
||||||
|
by default "K-weighting"
|
||||||
|
block_size : float, optional
|
||||||
|
Gating block size in seconds, by default 0.400
|
||||||
|
kwargs : dict, optional
|
||||||
|
Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Loudness of audio data.
|
||||||
|
"""
|
||||||
|
if self._loudness is not None:
|
||||||
|
return self._loudness # .to(self.device)
|
||||||
|
original_length = self.signal_length
|
||||||
|
if self.signal_duration < 0.5:
|
||||||
|
pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
|
||||||
|
self.zero_pad(0, pad_len)
|
||||||
|
|
||||||
|
# create BS.1770 meter
|
||||||
|
meter = Meter(
|
||||||
|
self.sample_rate,
|
||||||
|
filter_class=filter_class,
|
||||||
|
block_size=block_size,
|
||||||
|
**kwargs)
|
||||||
|
# meter = meter.to(self.device)
|
||||||
|
# measure loudness
|
||||||
|
loudness = meter.integrated_loudness(
|
||||||
|
self.audio_data.transpose([0, 2, 1]))
|
||||||
|
self.truncate_samples(original_length)
|
||||||
|
min_loudness = paddle.ones_like(loudness) * self.MIN_LOUDNESS
|
||||||
|
self._loudness = paddle.maximum(loudness, min_loudness)
|
||||||
|
|
||||||
|
return self._loudness # .to(self.device)
|
@ -0,0 +1,921 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/core/util.py)
|
||||||
|
import collections
|
||||||
|
import csv
|
||||||
|
import glob
|
||||||
|
import math
|
||||||
|
import numbers
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import typing
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Iterable
|
||||||
|
from typing import List
|
||||||
|
from typing import NamedTuple
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Type
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import ffmpeg
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import soundfile
|
||||||
|
from flatten_dict import flatten
|
||||||
|
from flatten_dict import unflatten
|
||||||
|
|
||||||
|
from .audio_signal import AudioSignal
|
||||||
|
from paddlespeech.utils import satisfy_paddle_version
|
||||||
|
from paddlespeech.vector.training.seeding import seed_everything
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"exp_compat",
|
||||||
|
"bool_index_compat",
|
||||||
|
"bool_setitem_compat",
|
||||||
|
"Info",
|
||||||
|
"info",
|
||||||
|
"ensure_tensor",
|
||||||
|
"random_state",
|
||||||
|
"find_audio",
|
||||||
|
"read_sources",
|
||||||
|
"choose_from_list_of_lists",
|
||||||
|
"chdir",
|
||||||
|
"move_to_device",
|
||||||
|
"prepare_batch",
|
||||||
|
"sample_from_dist",
|
||||||
|
"format_figure",
|
||||||
|
"default_collate",
|
||||||
|
"collate",
|
||||||
|
"hz_to_bin",
|
||||||
|
"generate_chord_dataset",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def exp_compat(x):
|
||||||
|
"""
|
||||||
|
Compute the exponential of the input tensor `x`.
|
||||||
|
|
||||||
|
This function is designed to handle compatibility issues with PaddlePaddle versions below 2.6,
|
||||||
|
which do not support the `exp` operation for complex tensors. In such cases, the computation
|
||||||
|
is offloaded to NumPy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): The input tensor for which to compute the exponential.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: The result of the exponential operation, as a PaddlePaddle tensor.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- If the PaddlePaddle version is 2.6 or above, the function uses `paddle.exp` directly.
|
||||||
|
- For versions below 2.6, the tensor is first converted to a NumPy array, the exponential
|
||||||
|
is computed using `np.exp`, and the result is then converted back to a PaddlePaddle tensor.
|
||||||
|
"""
|
||||||
|
if satisfy_paddle_version("2.6"):
|
||||||
|
return paddle.exp(x)
|
||||||
|
else:
|
||||||
|
x_np = x.cpu().numpy()
|
||||||
|
return paddle.to_tensor(np.exp(x_np))
|
||||||
|
|
||||||
|
|
||||||
|
def bool_index_compat(x, mask):
|
||||||
|
"""
|
||||||
|
Perform boolean indexing on the input tensor `x` using the provided `mask`.
|
||||||
|
|
||||||
|
This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean indexing
|
||||||
|
may not be fully supported. For older versions, the operation is performed using NumPy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): The input tensor to be indexed.
|
||||||
|
mask (paddle.Tensor or int): The boolean mask or integer index used for indexing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: The result of the boolean indexing operation, as a PaddlePaddle tensor.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- If the PaddlePaddle version is 2.6 or above, or if `mask` is an integer, the function uses
|
||||||
|
Paddle's native indexing directly.
|
||||||
|
- For versions below 2.6, the tensor and mask are converted to NumPy arrays, the indexing
|
||||||
|
operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor.
|
||||||
|
"""
|
||||||
|
if satisfy_paddle_version("2.6") or isinstance(mask, (int, list, slice)):
|
||||||
|
return x[mask]
|
||||||
|
else:
|
||||||
|
x_np = x.cpu().numpy()[mask.cpu().numpy()]
|
||||||
|
return paddle.to_tensor(x_np)
|
||||||
|
|
||||||
|
|
||||||
|
def bool_setitem_compat(x, mask, y):
|
||||||
|
"""
|
||||||
|
Perform boolean assignment on the input tensor `x` using the provided `mask` and values `y`.
|
||||||
|
|
||||||
|
This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean assignment
|
||||||
|
may not be fully supported. For older versions, the operation is performed using NumPy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): The input tensor to be modified.
|
||||||
|
mask (paddle.Tensor): The boolean mask used for assignment.
|
||||||
|
y (paddle.Tensor): The values to assign to the selected elements of `x`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: The modified tensor after the assignment operation.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- If the PaddlePaddle version is 2.6 or above, the function uses Paddle's native assignment directly.
|
||||||
|
- For versions below 2.6, the tensor, mask, and values are converted to NumPy arrays, the assignment
|
||||||
|
operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor.
|
||||||
|
"""
|
||||||
|
if satisfy_paddle_version("2.6"):
|
||||||
|
|
||||||
|
x[mask] = y
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
x_np = x.cpu().numpy()
|
||||||
|
x_np[mask.cpu().numpy()] = y.cpu().numpy()
|
||||||
|
|
||||||
|
return paddle.to_tensor(x_np)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Info:
|
||||||
|
|
||||||
|
sample_rate: float
|
||||||
|
num_frames: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self) -> float:
|
||||||
|
return self.num_frames / self.sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def info_ffmpeg(audio_path: str):
|
||||||
|
"""
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_path : str
|
||||||
|
Path to audio file.
|
||||||
|
"""
|
||||||
|
probe = ffmpeg.probe(audio_path)
|
||||||
|
audio_streams = [
|
||||||
|
stream for stream in probe['streams'] if stream['codec_type'] == 'audio'
|
||||||
|
]
|
||||||
|
if not audio_streams:
|
||||||
|
raise ValueError("No audio stream found in the file.")
|
||||||
|
audio_stream = audio_streams[0]
|
||||||
|
|
||||||
|
sample_rate = int(audio_stream['sample_rate'])
|
||||||
|
duration = float(audio_stream['duration'])
|
||||||
|
|
||||||
|
num_frames = int(duration * sample_rate)
|
||||||
|
|
||||||
|
info = Info(sample_rate=sample_rate, num_frames=num_frames)
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def info(audio_path: str):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_path : str
|
||||||
|
Path to audio file.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
info = soundfile.info(str(audio_path))
|
||||||
|
info = Info(sample_rate=info.samplerate, num_frames=info.frames)
|
||||||
|
except:
|
||||||
|
info = info_ffmpeg(str(audio_path))
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_tensor(
|
||||||
|
x: typing.Union[np.ndarray, paddle.Tensor, float, int],
|
||||||
|
ndim: int=None,
|
||||||
|
batch_size: int=None, ):
|
||||||
|
"""Ensures that the input ``x`` is a tensor of specified
|
||||||
|
dimensions and batch size.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : typing.Union[np.ndarray, paddle.Tensor, float, int]
|
||||||
|
Data that will become a tensor on its way out.
|
||||||
|
ndim : int, optional
|
||||||
|
How many dimensions should be in the output, by default None
|
||||||
|
batch_size : int, optional
|
||||||
|
The batch size of the output, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Modified version of ``x`` as a tensor.
|
||||||
|
"""
|
||||||
|
if not paddle.is_tensor(x):
|
||||||
|
x = paddle.to_tensor(x)
|
||||||
|
if ndim is not None:
|
||||||
|
assert x.ndim <= ndim
|
||||||
|
while x.ndim < ndim:
|
||||||
|
x = x.unsqueeze(-1)
|
||||||
|
if batch_size is not None:
|
||||||
|
if x.shape[0] != batch_size:
|
||||||
|
shape = list(x.shape)
|
||||||
|
shape[0] = batch_size
|
||||||
|
x = paddle.expand(x, shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _get_value(other):
|
||||||
|
#
|
||||||
|
from . import AudioSignal
|
||||||
|
|
||||||
|
if isinstance(other, AudioSignal):
|
||||||
|
return other.audio_data
|
||||||
|
return other
|
||||||
|
|
||||||
|
|
||||||
|
def random_state(seed: typing.Union[int, np.random.RandomState]):
|
||||||
|
"""
|
||||||
|
Turn seed into a np.random.RandomState instance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
seed : typing.Union[int, np.random.RandomState] or None
|
||||||
|
If seed is None, return the RandomState singleton used by np.random.
|
||||||
|
If seed is an int, return a new RandomState instance seeded with seed.
|
||||||
|
If seed is already a RandomState instance, return it.
|
||||||
|
Otherwise raise ValueError.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.random.RandomState
|
||||||
|
Random state object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If seed is not valid, an error is thrown.
|
||||||
|
"""
|
||||||
|
if seed is None or seed is np.random:
|
||||||
|
return np.random.mtrand._rand
|
||||||
|
elif isinstance(seed, (numbers.Integral, np.integer, int)):
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
elif isinstance(seed, np.random.RandomState):
|
||||||
|
return seed
|
||||||
|
else:
|
||||||
|
raise ValueError("%r cannot be used to seed a numpy.random.RandomState"
|
||||||
|
" instance" % seed)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _close_temp_files(tmpfiles: list):
|
||||||
|
"""Utility function for creating a context and closing all temporary files
|
||||||
|
once the context is exited. For correct functionality, all temporary file
|
||||||
|
handles created inside the context must be appended to the ```tmpfiles```
|
||||||
|
list.
|
||||||
|
|
||||||
|
This function is taken wholesale from Scaper.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tmpfiles : list
|
||||||
|
List of temporary file handles
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _close():
|
||||||
|
for t in tmpfiles:
|
||||||
|
try:
|
||||||
|
t.close()
|
||||||
|
os.unlink(t.name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
except:
|
||||||
|
_close()
|
||||||
|
raise
|
||||||
|
_close()
|
||||||
|
|
||||||
|
|
||||||
|
AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3"]
|
||||||
|
|
||||||
|
|
||||||
|
def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS):
|
||||||
|
"""Finds all audio files in a directory recursively.
|
||||||
|
Returns a list.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
folder : str
|
||||||
|
Folder to look for audio files in, recursively.
|
||||||
|
ext : List[str], optional
|
||||||
|
Extensions to look for without the ., by default
|
||||||
|
``['.wav', '.flac', '.mp3', '.mp4']``.
|
||||||
|
"""
|
||||||
|
folder = Path(folder)
|
||||||
|
# Take care of case where user has passed in an audio file directly
|
||||||
|
# into one of the calling functions.
|
||||||
|
if str(folder).endswith(tuple(ext)):
|
||||||
|
# if, however, there's a glob in the path, we need to
|
||||||
|
# return the glob, not the file.
|
||||||
|
if "*" in str(folder):
|
||||||
|
return glob.glob(str(folder), recursive=("**" in str(folder)))
|
||||||
|
else:
|
||||||
|
return [folder]
|
||||||
|
|
||||||
|
files = []
|
||||||
|
for x in ext:
|
||||||
|
files += folder.glob(f"**/*{x}")
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def read_sources(
|
||||||
|
sources: List[str],
|
||||||
|
remove_empty: bool=True,
|
||||||
|
relative_path: str="",
|
||||||
|
ext: List[str]=AUDIO_EXTENSIONS, ):
|
||||||
|
"""Reads audio sources that can either be folders
|
||||||
|
full of audio files, or CSV files that contain paths
|
||||||
|
to audio files. CSV files that adhere to the expected
|
||||||
|
format can be generated by
|
||||||
|
:py:func:`audiotools.data.preprocess.create_csv`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sources : List[str]
|
||||||
|
List of audio sources to be converted into a
|
||||||
|
list of lists of audio files.
|
||||||
|
remove_empty : bool, optional
|
||||||
|
Whether or not to remove rows with an empty "path"
|
||||||
|
from each CSV file, by default True.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
List of lists of rows of CSV files.
|
||||||
|
"""
|
||||||
|
files = []
|
||||||
|
relative_path = Path(relative_path)
|
||||||
|
for source in sources:
|
||||||
|
source = str(source)
|
||||||
|
_files = []
|
||||||
|
if source.endswith(".csv"):
|
||||||
|
with open(source, "r") as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for x in reader:
|
||||||
|
if remove_empty and x["path"] == "":
|
||||||
|
continue
|
||||||
|
if x["path"] != "":
|
||||||
|
x["path"] = str(relative_path / x["path"])
|
||||||
|
_files.append(x)
|
||||||
|
else:
|
||||||
|
for x in find_audio(source, ext=ext):
|
||||||
|
x = str(relative_path / x)
|
||||||
|
_files.append({"path": x})
|
||||||
|
files.append(sorted(_files, key=lambda x: x["path"]))
|
||||||
|
return files
|
||||||
|
|
||||||
|
|
||||||
|
def choose_from_list_of_lists(state: np.random.RandomState,
|
||||||
|
list_of_lists: list,
|
||||||
|
p: float=None):
|
||||||
|
"""Choose a single item from a list of lists.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
state : np.random.RandomState
|
||||||
|
Random state to use when choosing an item.
|
||||||
|
list_of_lists : list
|
||||||
|
A list of lists from which items will be drawn.
|
||||||
|
p : float, optional
|
||||||
|
Probabilities of each list, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
typing.Any
|
||||||
|
An item from the list of lists.
|
||||||
|
"""
|
||||||
|
source_idx = state.choice(list(range(len(list_of_lists))), p=p)
|
||||||
|
item_idx = state.randint(len(list_of_lists[source_idx]))
|
||||||
|
return list_of_lists[source_idx][item_idx], source_idx, item_idx
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def chdir(newdir: typing.Union[Path, str]):
|
||||||
|
"""
|
||||||
|
Context manager for switching directories to run a
|
||||||
|
function. Useful for when you want to use relative
|
||||||
|
paths to different runs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
newdir : typing.Union[Path, str]
|
||||||
|
Directory to switch to.
|
||||||
|
"""
|
||||||
|
curdir = os.getcwd()
|
||||||
|
try:
|
||||||
|
os.chdir(newdir)
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.chdir(curdir)
|
||||||
|
|
||||||
|
|
||||||
|
def move_to_device(data, device):
|
||||||
|
if device is None or device == "":
|
||||||
|
return data
|
||||||
|
elif device == 'cpu':
|
||||||
|
return paddle.to_tensor(data, place=paddle.CPUPlace())
|
||||||
|
elif device in ('gpu', 'cuda'):
|
||||||
|
return paddle.to_tensor(data, place=paddle.CUDAPlace())
|
||||||
|
else:
|
||||||
|
device = device.replace("cuda", "gpu") if "cuda" in device else device
|
||||||
|
return data.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_batch(batch: typing.Union[dict, list, paddle.Tensor],
|
||||||
|
device: str="cpu"):
|
||||||
|
"""Moves items in a batch (typically generated by a DataLoader as a list
|
||||||
|
or a dict) to the specified device. This works even if dictionaries
|
||||||
|
are nested.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
batch : typing.Union[dict, list, paddle.Tensor]
|
||||||
|
Batch, typically generated by a dataloader, that will be moved to
|
||||||
|
the device.
|
||||||
|
device : str, optional
|
||||||
|
Device to move batch to, by default "cpu"
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
typing.Union[dict, list, paddle.Tensor]
|
||||||
|
Batch with all values moved to the specified device.
|
||||||
|
"""
|
||||||
|
device = device.replace("cuda", "gpu")
|
||||||
|
if isinstance(batch, dict):
|
||||||
|
batch = flatten(batch)
|
||||||
|
for key, val in batch.items():
|
||||||
|
try:
|
||||||
|
# batch[key] = val.to(device)
|
||||||
|
batch[key] = move_to_device(val, device)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
batch = unflatten(batch)
|
||||||
|
elif paddle.is_tensor(batch):
|
||||||
|
# batch = batch.to(device)
|
||||||
|
batch = move_to_device(batch, device)
|
||||||
|
elif isinstance(batch, list):
|
||||||
|
for i in range(len(batch)):
|
||||||
|
try:
|
||||||
|
batch[i] = batch[i].to(device)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState=None):
|
||||||
|
"""Samples from a distribution defined by a tuple. The first
|
||||||
|
item in the tuple is the distribution type, and the rest of the
|
||||||
|
items are arguments to that distribution. The distribution function
|
||||||
|
is gotten from the ``np.random.RandomState`` object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dist_tuple : tuple
|
||||||
|
Distribution tuple
|
||||||
|
state : np.random.RandomState, optional
|
||||||
|
Random state, or seed to use, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
typing.Union[float, int, str]
|
||||||
|
Draw from the distribution.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
Sample from a uniform distribution:
|
||||||
|
|
||||||
|
>>> dist_tuple = ("uniform", 0, 1)
|
||||||
|
>>> sample_from_dist(dist_tuple)
|
||||||
|
|
||||||
|
Sample from a constant distribution:
|
||||||
|
|
||||||
|
>>> dist_tuple = ("const", 0)
|
||||||
|
>>> sample_from_dist(dist_tuple)
|
||||||
|
|
||||||
|
Sample from a normal distribution:
|
||||||
|
|
||||||
|
>>> dist_tuple = ("normal", 0, 0.5)
|
||||||
|
>>> sample_from_dist(dist_tuple)
|
||||||
|
|
||||||
|
"""
|
||||||
|
if dist_tuple[0] == "const":
|
||||||
|
return dist_tuple[1]
|
||||||
|
state = random_state(state)
|
||||||
|
dist_fn = getattr(state, dist_tuple[0])
|
||||||
|
return dist_fn(*dist_tuple[1:])
|
||||||
|
|
||||||
|
|
||||||
|
BASE_SIZE = 864
|
||||||
|
DEFAULT_FIG_SIZE = (9, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def format_figure(
|
||||||
|
fig_size: tuple=None,
|
||||||
|
title: str=None,
|
||||||
|
fig=None,
|
||||||
|
format_axes: bool=True,
|
||||||
|
format: bool=True,
|
||||||
|
font_color: str="white", ):
|
||||||
|
"""Prettifies the spectrogram and waveform plots. A title
|
||||||
|
can be inset into the top right corner, and the axes can be
|
||||||
|
inset into the figure, allowing the data to take up the entire
|
||||||
|
image. Used in
|
||||||
|
|
||||||
|
- :py:func:`audiotools.core.display.DisplayMixin.specshow`
|
||||||
|
- :py:func:`audiotools.core.display.DisplayMixin.waveplot`
|
||||||
|
- :py:func:`audiotools.core.display.DisplayMixin.wavespec`
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fig_size : tuple, optional
|
||||||
|
Size of figure, by default (9, 3)
|
||||||
|
title : str, optional
|
||||||
|
Title to inset in top right, by default None
|
||||||
|
fig : matplotlib.figure.Figure, optional
|
||||||
|
Figure object, if None ``plt.gcf()`` will be used, by default None
|
||||||
|
format_axes : bool, optional
|
||||||
|
Format the axes to be inside the figure, by default True
|
||||||
|
format : bool, optional
|
||||||
|
This formatting can be skipped entirely by passing ``format=False``
|
||||||
|
to any of the plotting functions that use this formater, by default True
|
||||||
|
font_color : str, optional
|
||||||
|
Color of font of axes, by default "white"
|
||||||
|
"""
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
if fig_size is None:
|
||||||
|
fig_size = DEFAULT_FIG_SIZE
|
||||||
|
if not format:
|
||||||
|
return
|
||||||
|
if fig is None:
|
||||||
|
fig = plt.gcf()
|
||||||
|
fig.set_size_inches(*fig_size)
|
||||||
|
axs = fig.axes
|
||||||
|
|
||||||
|
pixels = (fig.get_size_inches() * fig.dpi)[0]
|
||||||
|
font_scale = pixels / BASE_SIZE
|
||||||
|
|
||||||
|
if format_axes:
|
||||||
|
axs = fig.axes
|
||||||
|
|
||||||
|
for ax in axs:
|
||||||
|
ymin, _ = ax.get_ylim()
|
||||||
|
xmin, _ = ax.get_xlim()
|
||||||
|
|
||||||
|
ticks = ax.get_yticks()
|
||||||
|
for t in ticks[2:-1]:
|
||||||
|
t = axs[0].annotate(
|
||||||
|
f"{(t / 1000):2.1f}k",
|
||||||
|
xy=(xmin, t),
|
||||||
|
xycoords="data",
|
||||||
|
xytext=(5, -5),
|
||||||
|
textcoords="offset points",
|
||||||
|
ha="left",
|
||||||
|
va="top",
|
||||||
|
color=font_color,
|
||||||
|
fontsize=12 * font_scale,
|
||||||
|
alpha=0.75, )
|
||||||
|
|
||||||
|
ticks = ax.get_xticks()[2:]
|
||||||
|
for t in ticks[:-1]:
|
||||||
|
t = axs[0].annotate(
|
||||||
|
f"{t:2.1f}s",
|
||||||
|
xy=(t, ymin),
|
||||||
|
xycoords="data",
|
||||||
|
xytext=(5, 5),
|
||||||
|
textcoords="offset points",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
color=font_color,
|
||||||
|
fontsize=12 * font_scale,
|
||||||
|
alpha=0.75, )
|
||||||
|
|
||||||
|
ax.margins(0, 0)
|
||||||
|
ax.set_axis_off()
|
||||||
|
ax.xaxis.set_major_locator(plt.NullLocator())
|
||||||
|
ax.yaxis.set_major_locator(plt.NullLocator())
|
||||||
|
|
||||||
|
plt.subplots_adjust(
|
||||||
|
top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||||
|
|
||||||
|
if title is not None:
|
||||||
|
t = axs[0].annotate(
|
||||||
|
title,
|
||||||
|
xy=(1, 1),
|
||||||
|
xycoords="axes fraction",
|
||||||
|
fontsize=20 * font_scale,
|
||||||
|
xytext=(-5, -5),
|
||||||
|
textcoords="offset points",
|
||||||
|
ha="right",
|
||||||
|
va="top",
|
||||||
|
color="white", )
|
||||||
|
t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
|
||||||
|
|
||||||
|
|
||||||
|
_default_collate_err_msg_format = (
|
||||||
|
"default_collate: batch must contain tensors, numpy arrays, numbers, "
|
||||||
|
"dicts or lists; found {}")
|
||||||
|
|
||||||
|
|
||||||
|
def collate_tensor_fn(
|
||||||
|
batch,
|
||||||
|
*,
|
||||||
|
collate_fn_map: Optional[Dict[Union[type, Tuple[type, ...]],
|
||||||
|
Callable]]=None, ):
|
||||||
|
out = paddle.stack(batch, axis=0)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def collate_float_fn(
|
||||||
|
batch,
|
||||||
|
*,
|
||||||
|
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
|
||||||
|
Callable]]=None, ):
|
||||||
|
return paddle.to_tensor(batch, dtype=paddle.float64)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_int_fn(
|
||||||
|
batch,
|
||||||
|
*,
|
||||||
|
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
|
||||||
|
Callable]]=None, ):
|
||||||
|
return paddle.to_tensor(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_str_fn(
|
||||||
|
batch,
|
||||||
|
*,
|
||||||
|
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
|
||||||
|
Callable]]=None, ):
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
|
||||||
|
paddle.Tensor: collate_tensor_fn
|
||||||
|
}
|
||||||
|
default_collate_fn_map[float] = collate_float_fn
|
||||||
|
default_collate_fn_map[int] = collate_int_fn
|
||||||
|
default_collate_fn_map[str] = collate_str_fn
|
||||||
|
default_collate_fn_map[bytes] = collate_str_fn
|
||||||
|
|
||||||
|
|
||||||
|
def default_collate(batch,
|
||||||
|
*,
|
||||||
|
collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]],
|
||||||
|
Callable]]=None):
|
||||||
|
r"""
|
||||||
|
General collate function that handles collection type of element within each batch.
|
||||||
|
|
||||||
|
The function also opens function registry to deal with specific element types. `default_collate_fn_map`
|
||||||
|
provides default collate functions for tensors, numpy arrays, numbers and strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: a single batch to be collated
|
||||||
|
collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
|
||||||
|
If the element type isn't present in this dictionary,
|
||||||
|
this function will go through each key of the dictionary in the insertion order to
|
||||||
|
invoke the corresponding collate function if the element type is a subclass of the key.
|
||||||
|
Note:
|
||||||
|
Each collate function requires a positional argument for batch and a keyword argument
|
||||||
|
for the dictionary of collate functions as `collate_fn_map`.
|
||||||
|
"""
|
||||||
|
elem = batch[0]
|
||||||
|
elem_type = type(elem)
|
||||||
|
|
||||||
|
if collate_fn_map is not None:
|
||||||
|
if elem_type in collate_fn_map:
|
||||||
|
return collate_fn_map[elem_type](
|
||||||
|
batch, collate_fn_map=collate_fn_map)
|
||||||
|
|
||||||
|
for collate_type in collate_fn_map:
|
||||||
|
if isinstance(elem, collate_type):
|
||||||
|
return collate_fn_map[collate_type](
|
||||||
|
batch, collate_fn_map=collate_fn_map)
|
||||||
|
|
||||||
|
if isinstance(elem, collections.abc.Mapping):
|
||||||
|
try:
|
||||||
|
return elem_type({
|
||||||
|
key: default_collate(
|
||||||
|
[d[key] for d in batch], collate_fn_map=collate_fn_map)
|
||||||
|
for key in elem
|
||||||
|
})
|
||||||
|
except TypeError:
|
||||||
|
# The mapping type may not support `__init__(iterable)`.
|
||||||
|
return {
|
||||||
|
key: default_collate(
|
||||||
|
[d[key] for d in batch], collate_fn_map=collate_fn_map)
|
||||||
|
for key in elem
|
||||||
|
}
|
||||||
|
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
|
||||||
|
return elem_type(*(default_collate(
|
||||||
|
samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
|
||||||
|
elif isinstance(elem, collections.abc.Sequence):
|
||||||
|
# check to make sure that the elements in batch have consistent size
|
||||||
|
it = iter(batch)
|
||||||
|
elem_size = len(next(it))
|
||||||
|
if not all(len(elem) == elem_size for elem in it):
|
||||||
|
raise RuntimeError(
|
||||||
|
"each element in list of batch should be of equal size")
|
||||||
|
transposed = list(zip(
|
||||||
|
*batch)) # It may be accessed twice, so we use a list.
|
||||||
|
|
||||||
|
if isinstance(elem, tuple):
|
||||||
|
return [
|
||||||
|
default_collate(samples, collate_fn_map=collate_fn_map)
|
||||||
|
for samples in transposed
|
||||||
|
] # Backwards compatibility.
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return elem_type([
|
||||||
|
default_collate(samples, collate_fn_map=collate_fn_map)
|
||||||
|
for samples in transposed
|
||||||
|
])
|
||||||
|
except TypeError:
|
||||||
|
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
|
||||||
|
return [
|
||||||
|
default_collate(samples, collate_fn_map=collate_fn_map)
|
||||||
|
for samples in transposed
|
||||||
|
]
|
||||||
|
|
||||||
|
raise TypeError(_default_collate_err_msg_format.format(elem_type))
|
||||||
|
|
||||||
|
|
||||||
|
def collate(list_of_dicts: list, n_splits: int=None):
|
||||||
|
"""Collates a list of dictionaries (e.g. as returned by a
|
||||||
|
dataloader) into a dictionary with batched values. This routine
|
||||||
|
uses the default torch collate function for everything
|
||||||
|
except AudioSignal objects, which are handled by the
|
||||||
|
:py:func:`audiotools.core.audio_signal.AudioSignal.batch`
|
||||||
|
function.
|
||||||
|
|
||||||
|
This function takes n_splits to enable splitting a batch
|
||||||
|
into multiple sub-batches for the purposes of gradient accumulation,
|
||||||
|
etc.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
list_of_dicts : list
|
||||||
|
List of dictionaries to be collated.
|
||||||
|
n_splits : int
|
||||||
|
Number of splits to make when creating the batches (split into
|
||||||
|
sub-batches). Useful for things like gradient accumulation.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
Dictionary containing batched data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batches = []
|
||||||
|
list_len = len(list_of_dicts)
|
||||||
|
|
||||||
|
return_list = False if n_splits is None else True
|
||||||
|
n_splits = 1 if n_splits is None else n_splits
|
||||||
|
n_items = int(math.ceil(list_len / n_splits))
|
||||||
|
|
||||||
|
for i in range(0, list_len, n_items):
|
||||||
|
# Flatten the dictionaries to avoid recursion.
|
||||||
|
list_of_dicts_ = [flatten(d) for d in list_of_dicts[i:i + n_items]]
|
||||||
|
dict_of_lists = {
|
||||||
|
k: [dic[k] for dic in list_of_dicts_]
|
||||||
|
for k in list_of_dicts_[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
batch = {}
|
||||||
|
for k, v in dict_of_lists.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
if all(isinstance(s, AudioSignal) for s in v):
|
||||||
|
batch[k] = AudioSignal.batch(v, pad_signals=True)
|
||||||
|
else:
|
||||||
|
batch[k] = default_collate(
|
||||||
|
v, collate_fn_map=default_collate_fn_map)
|
||||||
|
batches.append(unflatten(batch))
|
||||||
|
|
||||||
|
batches = batches[0] if not return_list else batches
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|
||||||
|
def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int):
|
||||||
|
"""Closest frequency bin given a frequency, number
|
||||||
|
of bins, and a sampling rate.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
hz : paddle.Tensor
|
||||||
|
Tensor of frequencies in Hz.
|
||||||
|
n_fft : int
|
||||||
|
Number of FFT bins.
|
||||||
|
sample_rate : int
|
||||||
|
Sample rate of audio.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.Tensor
|
||||||
|
Closest bins to the data.
|
||||||
|
"""
|
||||||
|
shape = hz.shape
|
||||||
|
hz = hz.reshape([-1])
|
||||||
|
freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2)
|
||||||
|
hz = paddle.clip(hz, max=sample_rate / 2).astype(freqs.dtype)
|
||||||
|
|
||||||
|
closest = (hz[None, :] - freqs[:, None]).abs()
|
||||||
|
closest_bins = closest.argmin(axis=0)
|
||||||
|
|
||||||
|
return closest_bins.reshape(shape)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_chord_dataset(
|
||||||
|
max_voices: int=8,
|
||||||
|
sample_rate: int=44100,
|
||||||
|
num_items: int=5,
|
||||||
|
duration: float=1.0,
|
||||||
|
min_note: str="C2",
|
||||||
|
max_note: str="C6",
|
||||||
|
output_dir: Path="chords", ):
|
||||||
|
"""
|
||||||
|
Generates a toy multitrack dataset of chords, synthesized from sine waves.
|
||||||
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_voices : int, optional
|
||||||
|
Maximum number of voices in a chord, by default 8
|
||||||
|
sample_rate : int, optional
|
||||||
|
Sample rate of audio, by default 44100
|
||||||
|
num_items : int, optional
|
||||||
|
Number of items to generate, by default 5
|
||||||
|
duration : float, optional
|
||||||
|
Duration of each item, by default 1.0
|
||||||
|
min_note : str, optional
|
||||||
|
Minimum note in the dataset, by default "C2"
|
||||||
|
max_note : str, optional
|
||||||
|
Maximum note in the dataset, by default "C6"
|
||||||
|
output_dir : Path, optional
|
||||||
|
Directory to save the dataset, by default "chords"
|
||||||
|
|
||||||
|
"""
|
||||||
|
import librosa
|
||||||
|
from . import AudioSignal
|
||||||
|
from ..data.preprocess import create_csv
|
||||||
|
|
||||||
|
min_midi = librosa.note_to_midi(min_note)
|
||||||
|
max_midi = librosa.note_to_midi(max_note)
|
||||||
|
|
||||||
|
tracks = []
|
||||||
|
for idx in range(num_items):
|
||||||
|
track = {}
|
||||||
|
# figure out how many voices to put in this track
|
||||||
|
num_voices = random.randint(1, max_voices)
|
||||||
|
for voice_idx in range(num_voices):
|
||||||
|
# choose some random params
|
||||||
|
midinote = random.randint(min_midi, max_midi)
|
||||||
|
dur = random.uniform(0.85 * duration, duration)
|
||||||
|
|
||||||
|
sig = AudioSignal.wave(
|
||||||
|
frequency=librosa.midi_to_hz(midinote),
|
||||||
|
duration=dur,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
shape="sine", )
|
||||||
|
track[f"voice_{voice_idx}"] = sig
|
||||||
|
tracks.append(track)
|
||||||
|
|
||||||
|
# save the tracks to disk
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
for idx, track in enumerate(tracks):
|
||||||
|
track_dir = output_dir / f"track_{idx}"
|
||||||
|
track_dir.mkdir(exist_ok=True)
|
||||||
|
for voice_name, sig in track.items():
|
||||||
|
sig.write(track_dir / f"{voice_name}.wav")
|
||||||
|
|
||||||
|
all_voices = list(set([k for track in tracks for k in track.keys()]))
|
||||||
|
voice_lists = {voice: [] for voice in all_voices}
|
||||||
|
for track in tracks:
|
||||||
|
for voice_name in all_voices:
|
||||||
|
if voice_name in track:
|
||||||
|
voice_lists[voice_name].append(track[voice_name].path_to_file)
|
||||||
|
else:
|
||||||
|
voice_lists[voice_name].append("")
|
||||||
|
|
||||||
|
for voice_name, paths in voice_lists.items():
|
||||||
|
create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
|
||||||
|
|
||||||
|
return output_dir
|
@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from . import datasets
|
||||||
|
from . import preprocess
|
||||||
|
from . import transforms
|
@ -0,0 +1,87 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/data/preprocess.py)
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ..core import AudioSignal
|
||||||
|
|
||||||
|
|
||||||
|
def create_csv(audio_files: list,
|
||||||
|
output_csv: Path,
|
||||||
|
loudness: bool=False,
|
||||||
|
data_path: str=None):
|
||||||
|
"""Converts a folder of audio files to a CSV file. If ``loudness = True``,
|
||||||
|
the output of this function will create a CSV file that looks something
|
||||||
|
like:
|
||||||
|
|
||||||
|
.. csv-table::
|
||||||
|
:header: path,loudness
|
||||||
|
|
||||||
|
daps/produced/f1_script1_produced.wav,-16.299999237060547
|
||||||
|
daps/produced/f1_script2_produced.wav,-16.600000381469727
|
||||||
|
daps/produced/f1_script3_produced.wav,-17.299999237060547
|
||||||
|
daps/produced/f1_script4_produced.wav,-16.100000381469727
|
||||||
|
daps/produced/f1_script5_produced.wav,-16.700000762939453
|
||||||
|
daps/produced/f3_script1_produced.wav,-16.5
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The paths above are written relative to the ``data_path`` argument
|
||||||
|
which defaults to the environment variable ``PATH_TO_DATA`` if
|
||||||
|
it isn't passed to this function, and defaults to the empty string
|
||||||
|
if that environment variable is not set.
|
||||||
|
|
||||||
|
You can produce a CSV file from a directory of audio files via:
|
||||||
|
|
||||||
|
>>> from audio import audiotools
|
||||||
|
>>> directory = ...
|
||||||
|
>>> audio_files = audiotools.util.find_audio(directory)
|
||||||
|
>>> output_path = "train.csv"
|
||||||
|
>>> audiotools.data.preprocess.create_csv(
|
||||||
|
>>> audio_files, output_csv, loudness=True
|
||||||
|
>>> )
|
||||||
|
|
||||||
|
Note that you can create empty rows in the CSV file by passing an empty
|
||||||
|
string or None in the ``audio_files`` list. This is useful if you want to
|
||||||
|
sync multiple CSV files in a multitrack setting. The loudness of these
|
||||||
|
empty rows will be set to -inf.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_files : list
|
||||||
|
List of audio files.
|
||||||
|
output_csv : Path
|
||||||
|
Output CSV, with each row containing the relative path of every file
|
||||||
|
to ``data_path``, if specified (defaults to None).
|
||||||
|
loudness : bool
|
||||||
|
Compute loudness of entire file and store alongside path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
info = []
|
||||||
|
pbar = tqdm(audio_files)
|
||||||
|
for af in pbar:
|
||||||
|
af = Path(af)
|
||||||
|
pbar.set_description(f"Processing {af.name}")
|
||||||
|
_info = {}
|
||||||
|
if af.name == "":
|
||||||
|
_info["path"] = ""
|
||||||
|
if loudness:
|
||||||
|
_info["loudness"] = -float("inf")
|
||||||
|
else:
|
||||||
|
_info["path"] = af.relative_to(
|
||||||
|
data_path) if data_path is not None else af
|
||||||
|
if loudness:
|
||||||
|
_info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
|
||||||
|
|
||||||
|
info.append(_info)
|
||||||
|
|
||||||
|
with open(output_csv, "w") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for item in info:
|
||||||
|
writer.writerow(item)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Functions for comparing AudioSignal objects to one another.
|
||||||
|
"""
|
||||||
|
from . import quality
|
@ -0,0 +1,74 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/quality.py)
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from ..core import AudioSignal
|
||||||
|
|
||||||
|
|
||||||
|
def visqol(
|
||||||
|
estimates: AudioSignal,
|
||||||
|
references: AudioSignal,
|
||||||
|
mode: str="audio", ):
|
||||||
|
"""ViSQOL score.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
estimates : AudioSignal
|
||||||
|
Degraded AudioSignal
|
||||||
|
references : AudioSignal
|
||||||
|
Reference AudioSignal
|
||||||
|
mode : str, optional
|
||||||
|
'audio' or 'speech', by default 'audio'
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor[float]
|
||||||
|
ViSQOL score (MOS-LQO)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from pyvisqol import visqol_lib_py
|
||||||
|
from pyvisqol.pb2 import visqol_config_pb2
|
||||||
|
from pyvisqol.pb2 import similarity_result_pb2
|
||||||
|
except ImportError:
|
||||||
|
from visqol import visqol_lib_py
|
||||||
|
from visqol.pb2 import visqol_config_pb2
|
||||||
|
from visqol.pb2 import similarity_result_pb2
|
||||||
|
|
||||||
|
config = visqol_config_pb2.VisqolConfig()
|
||||||
|
if mode == "audio":
|
||||||
|
target_sr = 48000
|
||||||
|
config.options.use_speech_scoring = False
|
||||||
|
svr_model_path = "libsvm_nu_svr_model.txt"
|
||||||
|
elif mode == "speech":
|
||||||
|
target_sr = 16000
|
||||||
|
config.options.use_speech_scoring = True
|
||||||
|
svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unrecognized mode: {mode}")
|
||||||
|
config.audio.sample_rate = target_sr
|
||||||
|
config.options.svr_model_path = os.path.join(
|
||||||
|
os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path)
|
||||||
|
|
||||||
|
api = visqol_lib_py.VisqolApi()
|
||||||
|
api.Create(config)
|
||||||
|
|
||||||
|
estimates = estimates.clone().to_mono().resample(target_sr)
|
||||||
|
references = references.clone().to_mono().resample(target_sr)
|
||||||
|
|
||||||
|
visqols = []
|
||||||
|
for i in range(estimates.batch_size):
|
||||||
|
_visqol = api.Measure(
|
||||||
|
references.audio_data[i, 0].detach().cpu().numpy().astype(float),
|
||||||
|
estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), )
|
||||||
|
visqols.append(_visqol.moslqo)
|
||||||
|
return paddle.to_tensor(np.array(visqols))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
signal = AudioSignal(paddle.randn([44100]), 44100)
|
||||||
|
print(visqol(signal, signal))
|
@ -0,0 +1,16 @@
|
|||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from . import decorators
|
||||||
|
from .accelerator import Accelerator
|
||||||
|
from .basemodel import BaseModel
|
@ -0,0 +1,199 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/accelerator.py)
|
||||||
|
import os
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.io import DistributedBatchSampler
|
||||||
|
from paddle.io import SequenceSampler
|
||||||
|
|
||||||
|
|
||||||
|
class ResumableDistributedSampler(DistributedBatchSampler):
|
||||||
|
"""Distributed sampler that can be resumed from a given start index."""
|
||||||
|
|
||||||
|
def __init__(self, dataset, start_idx: int=None, **kwargs):
|
||||||
|
super().__init__(dataset, **kwargs)
|
||||||
|
# Start index, allows to resume an experiment at the index it was
|
||||||
|
self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for i, idx in enumerate(super().__iter__()):
|
||||||
|
if i >= self.start_idx:
|
||||||
|
yield idx
|
||||||
|
self.start_idx = 0 # set the index back to 0 so for the next epoch
|
||||||
|
|
||||||
|
|
||||||
|
class ResumableSequentialSampler(SequenceSampler):
|
||||||
|
"""Sequential sampler that can be resumed from a given start index."""
|
||||||
|
|
||||||
|
def __init__(self, dataset, start_idx: int=None, **kwargs):
|
||||||
|
super().__init__(dataset, **kwargs)
|
||||||
|
# Start index, allows to resume an experiment at the index it was
|
||||||
|
self.start_idx = start_idx if start_idx is not None else 0
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
for i, idx in enumerate(super().__iter__()):
|
||||||
|
if i >= self.start_idx:
|
||||||
|
yield idx
|
||||||
|
self.start_idx = 0 # set the index back to 0 so for the next epoch
|
||||||
|
|
||||||
|
|
||||||
|
class Accelerator:
|
||||||
|
"""This class is used to prepare models and dataloaders for
|
||||||
|
usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
|
||||||
|
prepare the respective objects. In the case of models, they are moved to
|
||||||
|
the appropriate GPU. In the case of
|
||||||
|
dataloaders, a sampler is created and the dataloader is initialized with
|
||||||
|
that sampler.
|
||||||
|
|
||||||
|
If the world size is 1, prepare_model and prepare_dataloader are
|
||||||
|
no-ops. If the environment variable ``PADDLE_TRAINER_ID`` is not set, then the
|
||||||
|
script was launched without ``paddle.distributed.launch``, and ``DataParallel``
|
||||||
|
will be used instead of ``DistributedDataParallel`` (not recommended), if
|
||||||
|
the world size (number of GPUs) is greater than 1.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
amp : bool, optional
|
||||||
|
Whether or not to enable automatic mixed precision, by default False
|
||||||
|
(Note: This is a placeholder as PaddlePaddle doesn't have native support for AMP as of now)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, amp: bool=False):
|
||||||
|
trainer_id = os.getenv("PADDLE_TRAINER_ID", None)
|
||||||
|
self.world_size = paddle.distributed.get_world_size()
|
||||||
|
|
||||||
|
self.use_ddp = self.world_size > 1 and trainer_id is not None
|
||||||
|
self.use_dp = self.world_size > 1 and trainer_id is None
|
||||||
|
self.device = "cpu" if self.world_size == 0 else "cuda"
|
||||||
|
|
||||||
|
if self.use_ddp:
|
||||||
|
trainer_id = int(trainer_id)
|
||||||
|
dist.init_parallel_env()
|
||||||
|
|
||||||
|
self.local_rank = 0 if trainer_id is None else int(trainer_id)
|
||||||
|
self.amp = amp
|
||||||
|
|
||||||
|
class DummyScaler:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def step(self, optimizer):
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
def scale(self, loss):
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def unscale_(self, optimizer):
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.scaler = paddle.amp.GradScaler() if self.amp else DummyScaler()
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def prepare_model(self, model: paddle.nn.Layer, **kwargs):
|
||||||
|
"""Prepares model for DDP or DP. The model is moved to
|
||||||
|
the device of the correct rank.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model : paddle.nn.Layer
|
||||||
|
Model that is converted for DDP or DP.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
paddle.nn.Layer
|
||||||
|
Wrapped model, or original model if DDP and DP are turned off.
|
||||||
|
"""
|
||||||
|
if self.use_ddp:
|
||||||
|
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||||
|
model = paddle.DataParallel(model, **kwargs)
|
||||||
|
elif self.use_dp:
|
||||||
|
model = paddle.DataParallel(model, **kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def autocast(self, *args, **kwargs):
|
||||||
|
return paddle.amp.auto_cast(self.amp, *args, **kwargs)
|
||||||
|
|
||||||
|
def backward(self, loss: paddle.Tensor):
|
||||||
|
"""Backwards pass.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
loss : paddle.Tensor
|
||||||
|
Loss value.
|
||||||
|
"""
|
||||||
|
scaled = self.scaler.scale(loss) # scale the loss
|
||||||
|
scaled.backward()
|
||||||
|
|
||||||
|
def step(self, optimizer: paddle.optimizer.Optimizer):
|
||||||
|
"""Steps the optimizer.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
optimizer : paddle.optimizer.Optimizer
|
||||||
|
Optimizer to step forward.
|
||||||
|
"""
|
||||||
|
self.scaler.step(optimizer)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
# https://www.paddlepaddle.org.cn/documentation/docs/zh/2.6/api/paddle/amp/GradScaler_cn.html#step-optimizer
|
||||||
|
self.scaler.update()
|
||||||
|
|
||||||
|
def prepare_dataloader(self,
|
||||||
|
dataset: typing.Iterable,
|
||||||
|
start_idx: int=None,
|
||||||
|
**kwargs):
|
||||||
|
"""Wraps a dataset with a DataLoader, using the correct sampler if DDP is
|
||||||
|
enabled.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
dataset : typing.Iterable
|
||||||
|
Dataset to build Dataloader around.
|
||||||
|
start_idx : int, optional
|
||||||
|
Start index of sampler, useful if resuming from some epoch,
|
||||||
|
by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
DataLoader
|
||||||
|
Wrapped DataLoader.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.use_ddp:
|
||||||
|
sampler = ResumableDistributedSampler(
|
||||||
|
dataset,
|
||||||
|
start_idx,
|
||||||
|
batch_size=kwargs.get("batch_size", 1),
|
||||||
|
shuffle=kwargs.get("shuffle", True),
|
||||||
|
drop_last=kwargs.get("drop_last", False),
|
||||||
|
num_replicas=self.world_size,
|
||||||
|
rank=self.local_rank, )
|
||||||
|
if "num_workers" in kwargs:
|
||||||
|
kwargs["num_workers"] = max(kwargs["num_workers"] //
|
||||||
|
self.world_size, 1)
|
||||||
|
else:
|
||||||
|
sampler = ResumableSequentialSampler(dataset, start_idx)
|
||||||
|
|
||||||
|
dataloader = DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=sampler if self.use_ddp else None,
|
||||||
|
sampler=sampler if not self.use_ddp else None,
|
||||||
|
**kwargs, )
|
||||||
|
return dataloader
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def unwrap(model):
|
||||||
|
return model
|
@ -0,0 +1,272 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/layers/base.py)
|
||||||
|
import inspect
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import typing
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class BaseModel(nn.Layer):
|
||||||
|
"""This is a class that adds useful save/load functionality to a
|
||||||
|
``paddle.nn.Layer`` object. ``BaseModel`` objects can be saved
|
||||||
|
as ``package`` easily, making them super easy to port between
|
||||||
|
machines without requiring a ton of dependencies. Files can also be
|
||||||
|
saved as just weights, in the standard way.
|
||||||
|
|
||||||
|
>>> class Model(ml.BaseModel):
|
||||||
|
>>> def __init__(self, arg1: float = 1.0):
|
||||||
|
>>> super().__init__()
|
||||||
|
>>> self.arg1 = arg1
|
||||||
|
>>> self.linear = nn.Linear(1, 1)
|
||||||
|
>>>
|
||||||
|
>>> def forward(self, x):
|
||||||
|
>>> return self.linear(x)
|
||||||
|
>>>
|
||||||
|
>>> model1 = Model()
|
||||||
|
>>>
|
||||||
|
>>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
|
||||||
|
>>> model1.save(
|
||||||
|
>>> f.name,
|
||||||
|
>>> )
|
||||||
|
>>> model2 = Model.load(f.name)
|
||||||
|
>>> out2 = seed_and_run(model2, x)
|
||||||
|
>>> assert paddle.allclose(out1, out2)
|
||||||
|
>>>
|
||||||
|
>>> model1.save(f.name, package=True)
|
||||||
|
>>> model2 = Model.load(f.name)
|
||||||
|
>>> model2.save(f.name, package=False)
|
||||||
|
>>> model3 = Model.load(f.name)
|
||||||
|
>>> out3 = seed_and_run(model3, x)
|
||||||
|
>>>
|
||||||
|
>>> with tempfile.TemporaryDirectory() as d:
|
||||||
|
>>> model1.save_to_folder(d, {"data": 1.0})
|
||||||
|
>>> Model.load_from_folder(d)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
metadata: dict=None,
|
||||||
|
package: bool=False,
|
||||||
|
intern: list=[],
|
||||||
|
extern: list=[],
|
||||||
|
mock: list=[], ):
|
||||||
|
"""Saves the model, either as a package, or just as
|
||||||
|
weights, alongside some specified metadata.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : str
|
||||||
|
Path to save model to.
|
||||||
|
metadata : dict, optional
|
||||||
|
Any metadata to save alongside the model,
|
||||||
|
by default None
|
||||||
|
package : bool, optional
|
||||||
|
Whether to use ``package`` to save the model in
|
||||||
|
a format that is portable, by default True
|
||||||
|
intern : list, optional
|
||||||
|
List of additional libraries that are internal
|
||||||
|
to the model, used with package, by default []
|
||||||
|
extern : list, optional
|
||||||
|
List of additional libraries that are external to
|
||||||
|
the model, used with package, by default []
|
||||||
|
mock : list, optional
|
||||||
|
List of libraries to mock, used with package,
|
||||||
|
by default []
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Path to saved model.
|
||||||
|
"""
|
||||||
|
sig = inspect.signature(self.__class__)
|
||||||
|
args = {}
|
||||||
|
|
||||||
|
for key, val in sig.parameters.items():
|
||||||
|
arg_val = val.default
|
||||||
|
if arg_val is not inspect.Parameter.empty:
|
||||||
|
args[key] = arg_val
|
||||||
|
|
||||||
|
# Look up attibutes in self, and if any of them are in args,
|
||||||
|
# overwrite them in args.
|
||||||
|
for attribute in dir(self):
|
||||||
|
if attribute in args:
|
||||||
|
args[attribute] = getattr(self, attribute)
|
||||||
|
|
||||||
|
metadata = {} if metadata is None else metadata
|
||||||
|
metadata["kwargs"] = args
|
||||||
|
if not hasattr(self, "metadata"):
|
||||||
|
self.metadata = {}
|
||||||
|
self.metadata.update(metadata)
|
||||||
|
|
||||||
|
if not package:
|
||||||
|
state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
|
||||||
|
paddle.save(state_dict, str(path))
|
||||||
|
else:
|
||||||
|
self._save_package(path, intern=intern, extern=extern, mock=mock)
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
"""Gets the device the model is on by looking at the device of
|
||||||
|
the first parameter. May not be valid if model is split across
|
||||||
|
multiple devices.
|
||||||
|
"""
|
||||||
|
return list(self.parameters())[0].place
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(
|
||||||
|
cls,
|
||||||
|
location: str,
|
||||||
|
*args,
|
||||||
|
package_name: str=None,
|
||||||
|
strict: bool=False,
|
||||||
|
**kwargs, ):
|
||||||
|
"""Load model from a path. Tries first to load as a package, and if
|
||||||
|
that fails, tries to load as weights. The arguments to the class are
|
||||||
|
specified inside the model weights file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
location : str
|
||||||
|
Path to file.
|
||||||
|
package_name : str, optional
|
||||||
|
Name of package, by default ``cls.__name__``.
|
||||||
|
strict : bool, optional
|
||||||
|
Ignore unmatched keys, by default False
|
||||||
|
kwargs : dict
|
||||||
|
Additional keyword arguments to the model instantiation, if
|
||||||
|
not loading from package.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BaseModel
|
||||||
|
A model that inherits from BaseModel.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model = cls._load_package(location, package_name=package_name)
|
||||||
|
except:
|
||||||
|
model_dict = paddle.load(location)
|
||||||
|
metadata = model_dict["metadata"]
|
||||||
|
metadata["kwargs"].update(kwargs)
|
||||||
|
|
||||||
|
sig = inspect.signature(cls)
|
||||||
|
class_keys = list(sig.parameters.keys())
|
||||||
|
for k in list(metadata["kwargs"].keys()):
|
||||||
|
if k not in class_keys:
|
||||||
|
metadata["kwargs"].pop(k)
|
||||||
|
|
||||||
|
model = cls(*args, **metadata["kwargs"])
|
||||||
|
model.set_state_dict(model_dict["state_dict"])
|
||||||
|
model.metadata = metadata
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
|
||||||
|
raise NotImplementedError("Currently Paddle does not support packaging")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_package(cls, path, package_name=None):
|
||||||
|
raise NotImplementedError("Currently Paddle does not support packaging")
|
||||||
|
|
||||||
|
def save_to_folder(
|
||||||
|
self,
|
||||||
|
folder: typing.Union[str, Path],
|
||||||
|
extra_data: dict=None,
|
||||||
|
package: bool=False, ):
|
||||||
|
"""Dumps a model into a folder, as both a package
|
||||||
|
and as weights, as well as anything specified in
|
||||||
|
``extra_data``. ``extra_data`` is a dictionary of other
|
||||||
|
pickleable files, with the keys being the paths
|
||||||
|
to save them in. The model is saved under a subfolder
|
||||||
|
specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
|
||||||
|
if the model name was ``Generator``).
|
||||||
|
|
||||||
|
>>> with tempfile.TemporaryDirectory() as d:
|
||||||
|
>>> extra_data = {
|
||||||
|
>>> "optimizer.pth": optimizer.state_dict()
|
||||||
|
>>> }
|
||||||
|
>>> model.save_to_folder(d, extra_data)
|
||||||
|
>>> Model.load_from_folder(d)
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
folder : typing.Union[str, Path]
|
||||||
|
_description_
|
||||||
|
extra_data : dict, optional
|
||||||
|
_description_, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Path to folder
|
||||||
|
"""
|
||||||
|
extra_data = {} if extra_data is None else extra_data
|
||||||
|
model_name = type(self).__name__.lower()
|
||||||
|
target_base = Path(f"{folder}/{model_name}/")
|
||||||
|
target_base.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
if package:
|
||||||
|
package_path = target_base / f"package.pth"
|
||||||
|
self.save(package_path)
|
||||||
|
|
||||||
|
weights_path = target_base / f"weights.pth"
|
||||||
|
self.save(weights_path, package=False)
|
||||||
|
|
||||||
|
for path, obj in extra_data.items():
|
||||||
|
paddle.save(obj, str(target_base / path))
|
||||||
|
|
||||||
|
return target_base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_from_folder(
|
||||||
|
cls,
|
||||||
|
folder: typing.Union[str, Path],
|
||||||
|
package: bool=False,
|
||||||
|
strict: bool=False,
|
||||||
|
**kwargs, ):
|
||||||
|
"""Loads the model from a folder generated by
|
||||||
|
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
|
||||||
|
Like that function, this one looks for a subfolder that has
|
||||||
|
the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
|
||||||
|
model name was ``Generator``).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
folder : typing.Union[str, Path]
|
||||||
|
_description_
|
||||||
|
package : bool, optional
|
||||||
|
Whether to use ``package`` to load the model,
|
||||||
|
loading the model from ``package.pth``.
|
||||||
|
strict : bool, optional
|
||||||
|
Ignore unmatched keys, by default False
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple
|
||||||
|
tuple of model and extra data as saved by
|
||||||
|
:py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
|
||||||
|
"""
|
||||||
|
folder = Path(folder) / cls.__name__.lower()
|
||||||
|
model_pth = "package.pth" if package else "weights.pth"
|
||||||
|
model_pth = folder / model_pth
|
||||||
|
|
||||||
|
model = cls.load(str(model_pth))
|
||||||
|
extra_data = {}
|
||||||
|
excluded = ["package.pth", "weights.pth"]
|
||||||
|
files = [
|
||||||
|
x for x in folder.glob("*")
|
||||||
|
if x.is_file() and x.name not in excluded
|
||||||
|
]
|
||||||
|
for f in files:
|
||||||
|
extra_data[f.name] = paddle.load(str(f), **kwargs)
|
||||||
|
|
||||||
|
return model, extra_data
|
@ -0,0 +1,446 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/ml/decorators.py)
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.distributed as dist
|
||||||
|
from rich import box
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.console import Group
|
||||||
|
from rich.live import Live
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.padding import Padding
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.progress import BarColumn
|
||||||
|
from rich.progress import Progress
|
||||||
|
from rich.progress import SpinnerColumn
|
||||||
|
from rich.progress import TimeElapsedColumn
|
||||||
|
from rich.progress import TimeRemainingColumn
|
||||||
|
from rich.rule import Rule
|
||||||
|
from rich.table import Table
|
||||||
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
|
||||||
|
# This is here so that the history can be pickled.
|
||||||
|
def default_list():
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class Mean:
|
||||||
|
"""Keeps track of the running mean, along with the latest
|
||||||
|
value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def __call__(self):
|
||||||
|
mean = self.total / max(self.count, 1)
|
||||||
|
return mean
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.count = 0
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def update(self, val):
|
||||||
|
if math.isfinite(val):
|
||||||
|
self.count += 1
|
||||||
|
self.total += val
|
||||||
|
|
||||||
|
|
||||||
|
def when(condition):
|
||||||
|
"""Runs a function only when the condition is met. The condition is
|
||||||
|
a function that is run.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
condition : Callable
|
||||||
|
Function to run to check whether or not to run the decorated
|
||||||
|
function.
|
||||||
|
|
||||||
|
Example
|
||||||
|
-------
|
||||||
|
Checkpoint only runs every 100 iterations, and only if the
|
||||||
|
local rank is 0.
|
||||||
|
|
||||||
|
>>> i = 0
|
||||||
|
>>> rank = 0
|
||||||
|
>>>
|
||||||
|
>>> @when(lambda: i % 100 == 0 and rank == 0)
|
||||||
|
>>> def checkpoint():
|
||||||
|
>>> print("Saving to /runs/exp1")
|
||||||
|
>>>
|
||||||
|
>>> for i in range(1000):
|
||||||
|
>>> checkpoint()
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
if condition():
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def timer(prefix: str="time"):
|
||||||
|
"""Adds execution time to the output dictionary of the decorated
|
||||||
|
function. The function decorated by this must output a dictionary.
|
||||||
|
The key added will follow the form "[prefix]/[name_of_function]"
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
prefix : str, optional
|
||||||
|
The key added will follow the form "[prefix]/[name_of_function]",
|
||||||
|
by default "time".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
s = time.perf_counter()
|
||||||
|
output = fn(*args, **kwargs)
|
||||||
|
assert isinstance(output, dict)
|
||||||
|
e = time.perf_counter()
|
||||||
|
output[f"{prefix}/{fn.__name__}"] = e - s
|
||||||
|
return output
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class Tracker:
|
||||||
|
"""
|
||||||
|
A tracker class that helps to monitor the progress of training and logging the metrics.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
metrics : dict
|
||||||
|
A dictionary containing the metrics for each label.
|
||||||
|
history : dict
|
||||||
|
A dictionary containing the history of metrics for each label.
|
||||||
|
writer : LogWriter
|
||||||
|
A LogWriter object for logging the metrics.
|
||||||
|
rank : int
|
||||||
|
The rank of the current process.
|
||||||
|
step : int
|
||||||
|
The current step of the training.
|
||||||
|
tasks : dict
|
||||||
|
A dictionary containing the progress bars and tables for each label.
|
||||||
|
pbar : Progress
|
||||||
|
A progress bar object for displaying the progress.
|
||||||
|
consoles : list
|
||||||
|
A list of console objects for logging.
|
||||||
|
live : Live
|
||||||
|
A Live object for updating the display live.
|
||||||
|
|
||||||
|
Methods
|
||||||
|
-------
|
||||||
|
print(msg: str)
|
||||||
|
Prints the given message to all consoles.
|
||||||
|
update(label: str, fn_name: str)
|
||||||
|
Updates the progress bar and table for the given label.
|
||||||
|
done(label: str, title: str)
|
||||||
|
Resets the progress bar and table for the given label and prints the final result.
|
||||||
|
track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
|
||||||
|
A decorator for tracking the progress and metrics of a function.
|
||||||
|
log(label: str, value_type: str = "value", history: bool = True)
|
||||||
|
A decorator for logging the metrics of a function.
|
||||||
|
is_best(label: str, key: str) -> bool
|
||||||
|
Checks if the latest value of the given key in the label is the best so far.
|
||||||
|
state_dict() -> dict
|
||||||
|
Returns a dictionary containing the state of the tracker.
|
||||||
|
load_state_dict(state_dict: dict) -> Tracker
|
||||||
|
Loads the state of the tracker from the given state dictionary.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
writer: LogWriter=None,
|
||||||
|
log_file: str=None,
|
||||||
|
rank: int=0,
|
||||||
|
console_width: int=100,
|
||||||
|
step: int=0, ):
|
||||||
|
"""
|
||||||
|
Initializes the Tracker object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
writer : LogWriter, optional
|
||||||
|
A LogWriter object for logging the metrics, by default None.
|
||||||
|
log_file : str, optional
|
||||||
|
The path to the log file, by default None.
|
||||||
|
rank : int, optional
|
||||||
|
The rank of the current process, by default 0.
|
||||||
|
console_width : int, optional
|
||||||
|
The width of the console, by default 100.
|
||||||
|
step : int, optional
|
||||||
|
The current step of the training, by default 0.
|
||||||
|
"""
|
||||||
|
self.metrics = {}
|
||||||
|
self.history = {}
|
||||||
|
self.writer = writer
|
||||||
|
self.rank = rank
|
||||||
|
self.step = step
|
||||||
|
|
||||||
|
# Create progress bars etc.
|
||||||
|
self.tasks = {}
|
||||||
|
self.pbar = Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
"[progress.description]{task.description}",
|
||||||
|
"{task.completed}/{task.total}",
|
||||||
|
BarColumn(),
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
"/",
|
||||||
|
TimeRemainingColumn(), )
|
||||||
|
self.consoles = [Console(width=console_width)]
|
||||||
|
self.live = Live(console=self.consoles[0], refresh_per_second=10)
|
||||||
|
if log_file is not None:
|
||||||
|
self.consoles.append(
|
||||||
|
Console(width=console_width, file=open(log_file, "a")))
|
||||||
|
|
||||||
|
def print(self, msg):
|
||||||
|
"""
|
||||||
|
Prints the given message to all consoles.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
msg : str
|
||||||
|
The message to be printed.
|
||||||
|
"""
|
||||||
|
if self.rank == 0:
|
||||||
|
for c in self.consoles:
|
||||||
|
c.log(msg)
|
||||||
|
|
||||||
|
def update(self, label, fn_name):
|
||||||
|
"""
|
||||||
|
Updates the progress bar and table for the given label.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
label : str
|
||||||
|
The label of the progress bar and table to be updated.
|
||||||
|
fn_name : str
|
||||||
|
The name of the function associated with the label.
|
||||||
|
"""
|
||||||
|
if self.rank == 0:
|
||||||
|
self.pbar.advance(self.tasks[label]["pbar"])
|
||||||
|
|
||||||
|
# Create table
|
||||||
|
table = Table(title=label, expand=True, box=box.MINIMAL)
|
||||||
|
table.add_column("key", style="cyan")
|
||||||
|
table.add_column("value", style="bright_blue")
|
||||||
|
table.add_column("mean", style="bright_green")
|
||||||
|
|
||||||
|
keys = self.metrics[label]["value"].keys()
|
||||||
|
for k in keys:
|
||||||
|
value = self.metrics[label]["value"][k]
|
||||||
|
mean = self.metrics[label]["mean"][k]()
|
||||||
|
table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
|
||||||
|
|
||||||
|
self.tasks[label]["table"] = table
|
||||||
|
tables = [t["table"] for t in self.tasks.values()]
|
||||||
|
group = Group(*tables, self.pbar)
|
||||||
|
self.live.update(
|
||||||
|
Group(
|
||||||
|
Padding("", (0, 0)),
|
||||||
|
Rule(f"[italic]{fn_name}()", style="white"),
|
||||||
|
Padding("", (0, 0)),
|
||||||
|
Panel.fit(
|
||||||
|
group,
|
||||||
|
padding=(0, 5),
|
||||||
|
title="[b]Progress",
|
||||||
|
border_style="blue", ), ))
|
||||||
|
|
||||||
|
def done(self, label: str, title: str):
|
||||||
|
"""
|
||||||
|
Resets the progress bar and table for the given label and prints the final result.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
label : str
|
||||||
|
The label of the progress bar and table to be reset.
|
||||||
|
title : str
|
||||||
|
The title to be displayed when printing the final result.
|
||||||
|
"""
|
||||||
|
for label in self.metrics:
|
||||||
|
for v in self.metrics[label]["mean"].values():
|
||||||
|
v.reset()
|
||||||
|
|
||||||
|
if self.rank == 0:
|
||||||
|
self.pbar.reset(self.tasks[label]["pbar"])
|
||||||
|
tables = [t["table"] for t in self.tasks.values()]
|
||||||
|
group = Group(Markdown(f"# {title}"), *tables, self.pbar)
|
||||||
|
self.print(group)
|
||||||
|
|
||||||
|
def track(
|
||||||
|
self,
|
||||||
|
label: str,
|
||||||
|
length: int,
|
||||||
|
completed: int=0,
|
||||||
|
op: dist.ReduceOp=dist.ReduceOp.AVG,
|
||||||
|
ddp_active: bool="LOCAL_RANK" in os.environ, ):
|
||||||
|
"""
|
||||||
|
A decorator for tracking the progress and metrics of a function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
label : str
|
||||||
|
The label to be associated with the progress and metrics.
|
||||||
|
length : int
|
||||||
|
The total number of iterations to be completed.
|
||||||
|
completed : int, optional
|
||||||
|
The number of iterations already completed, by default 0.
|
||||||
|
op : dist.ReduceOp, optional
|
||||||
|
The reduce operation to be used, by default dist.ReduceOp.AVG.
|
||||||
|
ddp_active : bool, optional
|
||||||
|
Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
|
||||||
|
"""
|
||||||
|
self.tasks[label] = {
|
||||||
|
"pbar":
|
||||||
|
self.pbar.add_task(
|
||||||
|
f"[white]Iteration ({label})",
|
||||||
|
total=length,
|
||||||
|
completed=completed),
|
||||||
|
"table":
|
||||||
|
Table(),
|
||||||
|
}
|
||||||
|
self.metrics[label] = {
|
||||||
|
"value": defaultdict(),
|
||||||
|
"mean": defaultdict(lambda: Mean()),
|
||||||
|
}
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
output = fn(*args, **kwargs)
|
||||||
|
if not isinstance(output, dict):
|
||||||
|
self.update(label, fn.__name__)
|
||||||
|
return output
|
||||||
|
# Collect across all DDP processes
|
||||||
|
scalar_keys = []
|
||||||
|
for k, v in output.items():
|
||||||
|
if isinstance(v, (int, float)):
|
||||||
|
v = paddle.to_tensor([v])
|
||||||
|
if not paddle.is_tensor(v):
|
||||||
|
continue
|
||||||
|
if ddp_active and v.is_cuda:
|
||||||
|
dist.all_reduce(v, op=op)
|
||||||
|
output[k] = v.detach()
|
||||||
|
if paddle.numel(v) == 1:
|
||||||
|
scalar_keys.append(k)
|
||||||
|
output[k] = v.item()
|
||||||
|
|
||||||
|
# Save the outputs to tracker
|
||||||
|
for k, v in output.items():
|
||||||
|
if k not in scalar_keys:
|
||||||
|
continue
|
||||||
|
self.metrics[label]["value"][k] = v
|
||||||
|
# Update the running mean
|
||||||
|
self.metrics[label]["mean"][k].update(v)
|
||||||
|
|
||||||
|
self.update(label, fn.__name__)
|
||||||
|
return output
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def log(self, label: str, value_type: str="value", history: bool=True):
|
||||||
|
"""
|
||||||
|
A decorator for logging the metrics of a function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
label : str
|
||||||
|
The label to be associated with the logging.
|
||||||
|
value_type : str, optional
|
||||||
|
The type of value to be logged, by default "value".
|
||||||
|
history : bool, optional
|
||||||
|
Whether to save the history of the metrics, by default True.
|
||||||
|
"""
|
||||||
|
assert value_type in ["mean", "value"]
|
||||||
|
if history:
|
||||||
|
if label not in self.history:
|
||||||
|
self.history[label] = defaultdict(default_list)
|
||||||
|
|
||||||
|
def decorator(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
output = fn(*args, **kwargs)
|
||||||
|
if self.rank == 0:
|
||||||
|
nonlocal value_type, label
|
||||||
|
metrics = self.metrics[label][value_type]
|
||||||
|
for k, v in metrics.items():
|
||||||
|
v = v() if isinstance(v, Mean) else v
|
||||||
|
if self.writer is not None:
|
||||||
|
self.writer.add_scalar(
|
||||||
|
tag=f"{k}/{label}", value=v, step=self.step)
|
||||||
|
if label in self.history:
|
||||||
|
self.history[label][k].append(v)
|
||||||
|
|
||||||
|
if label in self.history:
|
||||||
|
self.history[label]["step"].append(self.step)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def is_best(self, label, key):
|
||||||
|
"""
|
||||||
|
Checks if the latest value of the given key in the label is the best so far.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
label : str
|
||||||
|
The label of the metrics to be checked.
|
||||||
|
key : str
|
||||||
|
The key of the metric to be checked.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the latest value is the best so far, otherwise False.
|
||||||
|
"""
|
||||||
|
return self.history[label][key][-1] == min(self.history[label][key])
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
"""
|
||||||
|
Returns a dictionary containing the state of the tracker.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A dictionary containing the history and step of the tracker.
|
||||||
|
"""
|
||||||
|
return {"history": self.history, "step": self.step}
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
"""
|
||||||
|
Loads the state of the tracker from the given state dictionary.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
state_dict : dict
|
||||||
|
A dictionary containing the history and step of the tracker.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tracker
|
||||||
|
The tracker object with the loaded state.
|
||||||
|
"""
|
||||||
|
self.history = state_dict["history"]
|
||||||
|
self.step = state_dict["step"]
|
||||||
|
return self
|
@ -0,0 +1,88 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/audiotools/post.py)
|
||||||
|
import typing
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from audio.audiotools.core import AudioSignal
|
||||||
|
|
||||||
|
|
||||||
|
def audio_table(
|
||||||
|
audio_dict: dict,
|
||||||
|
first_column: str=None,
|
||||||
|
format_fn: typing.Callable=None,
|
||||||
|
**kwargs, ):
|
||||||
|
"""Embeds an audio table into HTML, or as the output cell
|
||||||
|
in a notebook.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_dict : dict
|
||||||
|
Dictionary of data to embed.
|
||||||
|
first_column : str, optional
|
||||||
|
The label for the first column of the table, by default None
|
||||||
|
format_fn : typing.Callable, optional
|
||||||
|
How to format the data, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Table as a string
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
|
||||||
|
>>> audio_dict = {}
|
||||||
|
>>> for i in range(signal_batch.batch_size):
|
||||||
|
>>> audio_dict[i] = {
|
||||||
|
>>> "input": signal_batch[i],
|
||||||
|
>>> "output": output_batch[i]
|
||||||
|
>>> }
|
||||||
|
>>> audiotools.post.audio_zip(audio_dict)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
output = []
|
||||||
|
columns = None
|
||||||
|
|
||||||
|
def _default_format_fn(label, x, **kwargs):
|
||||||
|
if paddle.is_tensor(x):
|
||||||
|
x = x.tolist()
|
||||||
|
|
||||||
|
if x is None:
|
||||||
|
return "."
|
||||||
|
elif isinstance(x, AudioSignal):
|
||||||
|
return x.embed(display=False, return_html=True, **kwargs)
|
||||||
|
else:
|
||||||
|
return str(x)
|
||||||
|
|
||||||
|
if format_fn is None:
|
||||||
|
format_fn = _default_format_fn
|
||||||
|
|
||||||
|
if first_column is None:
|
||||||
|
first_column = "."
|
||||||
|
|
||||||
|
for k, v in audio_dict.items():
|
||||||
|
if not isinstance(v, dict):
|
||||||
|
v = {"Audio": v}
|
||||||
|
|
||||||
|
v_keys = list(v.keys())
|
||||||
|
if columns is None:
|
||||||
|
columns = [first_column] + v_keys
|
||||||
|
output.append(" | ".join(columns))
|
||||||
|
|
||||||
|
layout = "|---" + len(v_keys) * "|:-:"
|
||||||
|
output.append(layout)
|
||||||
|
|
||||||
|
formatted_audio = []
|
||||||
|
for col in columns[1:]:
|
||||||
|
formatted_audio.append(format_fn(col, v[col], **kwargs))
|
||||||
|
|
||||||
|
row = f"| {k} | "
|
||||||
|
row += " | ".join(formatted_audio)
|
||||||
|
output.append(row)
|
||||||
|
|
||||||
|
output = "\n" + "\n".join(output)
|
||||||
|
return output
|
@ -0,0 +1,6 @@
|
|||||||
|
ffmpeg-python
|
||||||
|
ffmpy
|
||||||
|
flatten_dict
|
||||||
|
pyloudnorm
|
||||||
|
pytest
|
||||||
|
rich
|
@ -0,0 +1,615 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_audio_signal.py)
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
import rich
|
||||||
|
|
||||||
|
from audio import audiotools
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
from audio.audiotools import util
|
||||||
|
|
||||||
|
|
||||||
|
def test_io():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(pathlib.Path(audio_path))
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
||||||
|
signal.write(f.name)
|
||||||
|
signal_from_file = AudioSignal(f.name)
|
||||||
|
|
||||||
|
mp3_signal = AudioSignal(audio_path.replace("wav", "mp3"))
|
||||||
|
print(mp3_signal)
|
||||||
|
|
||||||
|
assert signal == signal_from_file
|
||||||
|
print(signal)
|
||||||
|
print(signal.markdown())
|
||||||
|
|
||||||
|
mp3_signal = AudioSignal.excerpt(
|
||||||
|
audio_path.replace("wav", "mp3"), offset=5, duration=5)
|
||||||
|
assert mp3_signal.signal_duration == 5.0
|
||||||
|
assert mp3_signal.duration == 5.0
|
||||||
|
assert mp3_signal.length == mp3_signal.signal_length
|
||||||
|
|
||||||
|
rich.print(signal)
|
||||||
|
|
||||||
|
array = np.random.randn(2, 16000)
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
assert np.allclose(signal.numpy(), array)
|
||||||
|
|
||||||
|
signal = AudioSignal(array, 44100)
|
||||||
|
assert signal.sample_rate == 44100
|
||||||
|
signal.shape
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
signal = AudioSignal(5, sample_rate=16000)
|
||||||
|
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
assert np.allclose(signal.signal_duration, 10.0)
|
||||||
|
assert np.allclose(signal.duration, 10.0)
|
||||||
|
|
||||||
|
signal = AudioSignal.excerpt(audio_path, offset=5, duration=5)
|
||||||
|
assert signal.signal_duration == 5.0
|
||||||
|
assert signal.duration == 5.0
|
||||||
|
|
||||||
|
assert "offset" in signal.metadata
|
||||||
|
assert "duration" in signal.metadata
|
||||||
|
|
||||||
|
signal = AudioSignal(paddle.randn([1000]), 44100)
|
||||||
|
assert signal.audio_data.ndim == 3
|
||||||
|
assert paddle.all(signal.samples == signal.audio_data)
|
||||||
|
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
assert AudioSignal(audio_path).hash() == AudioSignal(audio_path).hash()
|
||||||
|
assert AudioSignal(audio_path).hash() != AudioSignal(audio_path).normalize(
|
||||||
|
-20).hash()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
AudioSignal(audio_path, offset=100000, duration=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_copy_and_clone():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path)
|
||||||
|
signal.stft()
|
||||||
|
signal.loudness()
|
||||||
|
|
||||||
|
copied = signal.copy()
|
||||||
|
deep_copied = signal.deepcopy()
|
||||||
|
cloned = signal.clone()
|
||||||
|
|
||||||
|
for a in ["audio_data", "stft_data", "_loudness"]:
|
||||||
|
a1 = getattr(signal, a)
|
||||||
|
a2 = getattr(cloned, a)
|
||||||
|
a3 = getattr(copied, a)
|
||||||
|
a4 = getattr(deep_copied, a)
|
||||||
|
|
||||||
|
assert id(a1) != id(a2)
|
||||||
|
assert id(a1) == id(a3)
|
||||||
|
assert id(a1) != id(a4)
|
||||||
|
|
||||||
|
assert np.allclose(a1, a2)
|
||||||
|
assert np.allclose(a1, a3)
|
||||||
|
assert np.allclose(a1, a4)
|
||||||
|
|
||||||
|
for a in ["path_to_file", "metadata"]:
|
||||||
|
a1 = getattr(signal, a)
|
||||||
|
a2 = getattr(cloned, a)
|
||||||
|
a3 = getattr(copied, a)
|
||||||
|
a4 = getattr(deep_copied, a)
|
||||||
|
|
||||||
|
assert id(a1) == id(a2) if isinstance(a1, str) else id(a1) != id(a2)
|
||||||
|
assert id(a1) == id(a3)
|
||||||
|
assert id(a1) == id(a4) if isinstance(a1, str) else id(a1) != id(a2)
|
||||||
|
|
||||||
|
# for clone, id should differ if path is list, and should differ always for metadata
|
||||||
|
# if path is string, id should remain same...
|
||||||
|
|
||||||
|
assert signal.original_signal_length == copied.original_signal_length
|
||||||
|
assert signal.original_signal_length == deep_copied.original_signal_length
|
||||||
|
assert signal.original_signal_length == cloned.original_signal_length
|
||||||
|
|
||||||
|
signal = signal.detach()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("loudness_cutoff", [-np.inf, -160, -80, -40, -20])
|
||||||
|
def test_salient_excerpt(loudness_cutoff):
|
||||||
|
MAP = {-np.inf: 0.0, -160: 0.0, -80: 0.001, -40: 0.01, -20: 0.1}
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
|
||||||
|
sr = 44100
|
||||||
|
signal = AudioSignal(paddle.zeros([sr * 60]), sr)
|
||||||
|
|
||||||
|
signal[..., sr * 20:sr * 21] = MAP[loudness_cutoff] * paddle.randn(
|
||||||
|
[44100])
|
||||||
|
|
||||||
|
signal.write(f.name)
|
||||||
|
signal = AudioSignal.salient_excerpt(
|
||||||
|
f.name, loudness_cutoff=loudness_cutoff, duration=1, num_tries=None)
|
||||||
|
|
||||||
|
assert "offset" in signal.metadata
|
||||||
|
assert "duration" in signal.metadata
|
||||||
|
assert signal.loudness() >= loudness_cutoff
|
||||||
|
|
||||||
|
signal = AudioSignal.salient_excerpt(
|
||||||
|
f.name, loudness_cutoff=np.inf, duration=1, num_tries=10)
|
||||||
|
signal = AudioSignal.salient_excerpt(
|
||||||
|
f.name,
|
||||||
|
loudness_cutoff=None,
|
||||||
|
duration=1, )
|
||||||
|
|
||||||
|
|
||||||
|
def test_arithmetic():
|
||||||
|
def _make_signals():
|
||||||
|
array = np.random.randn(2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
array = np.random.randn(2, 16000)
|
||||||
|
sig2 = AudioSignal(array, sample_rate=16000)
|
||||||
|
return sig1, sig2
|
||||||
|
|
||||||
|
# Addition (with a copy)
|
||||||
|
sig1, sig2 = _make_signals()
|
||||||
|
sig3 = sig1 + sig2
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data + sig2.audio_data)
|
||||||
|
|
||||||
|
# Addition (rmul)
|
||||||
|
sig1, _ = _make_signals()
|
||||||
|
sig3 = 5.0 + sig1
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data + 5.0)
|
||||||
|
|
||||||
|
# In place addition
|
||||||
|
sig3, sig2 = _make_signals()
|
||||||
|
sig1 = sig3.deepcopy()
|
||||||
|
sig3 += sig2
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data + sig2.audio_data)
|
||||||
|
|
||||||
|
# Subtraction (with a copy)
|
||||||
|
sig1, sig2 = _make_signals()
|
||||||
|
sig3 = sig1 - sig2
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data - sig2.audio_data)
|
||||||
|
|
||||||
|
# In place subtraction
|
||||||
|
sig3, sig2 = _make_signals()
|
||||||
|
sig1 = sig3.deepcopy()
|
||||||
|
sig3 -= sig2
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data - sig2.audio_data)
|
||||||
|
|
||||||
|
# Multiplication (element-wise)
|
||||||
|
sig1, sig2 = _make_signals()
|
||||||
|
sig3 = sig1 * sig2
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data * sig2.audio_data)
|
||||||
|
|
||||||
|
# Multiplication (gain)
|
||||||
|
sig1, _ = _make_signals()
|
||||||
|
sig3 = sig1 * 5.0
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data * 5.0)
|
||||||
|
|
||||||
|
# Multiplication (rmul)
|
||||||
|
sig1, _ = _make_signals()
|
||||||
|
sig3 = 5.0 * sig1
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data * 5.0)
|
||||||
|
|
||||||
|
# Multiplication (in-place)
|
||||||
|
sig3, sig2 = _make_signals()
|
||||||
|
sig1 = sig3.deepcopy()
|
||||||
|
sig3 *= sig2
|
||||||
|
assert paddle.allclose(sig3.audio_data, sig1.audio_data * sig2.audio_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_equality():
|
||||||
|
array = np.random.randn(2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig2 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
assert sig1 == sig2
|
||||||
|
|
||||||
|
array = np.random.randn(2, 16000)
|
||||||
|
sig3 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
assert sig1 != sig3
|
||||||
|
|
||||||
|
assert not np.allclose(sig1.numpy(), sig3.numpy())
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexing():
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
assert np.allclose(sig1[0].audio_data, array[0])
|
||||||
|
assert np.allclose(sig1[0, :, 8000].audio_data, array[0, :, 8000])
|
||||||
|
|
||||||
|
# Test with the associated STFT data.
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1.loudness()
|
||||||
|
sig1.stft()
|
||||||
|
|
||||||
|
indexed = sig1[0]
|
||||||
|
|
||||||
|
assert np.allclose(indexed.audio_data, array[0])
|
||||||
|
assert np.allclose(indexed.stft_data, sig1.stft_data[0])
|
||||||
|
assert np.allclose(indexed._loudness, sig1._loudness[0])
|
||||||
|
|
||||||
|
indexed = sig1[0:2]
|
||||||
|
|
||||||
|
assert np.allclose(indexed.audio_data, array[0:2])
|
||||||
|
assert np.allclose(indexed.stft_data, sig1.stft_data[0:2])
|
||||||
|
assert np.allclose(indexed._loudness, sig1._loudness[0:2])
|
||||||
|
|
||||||
|
# Test using a boolean tensor to index batch
|
||||||
|
mask = paddle.to_tensor([True, False, True, False])
|
||||||
|
indexed = sig1[mask]
|
||||||
|
|
||||||
|
assert np.allclose(indexed.audio_data, sig1.audio_data[mask])
|
||||||
|
# assert np.allclose(indexed.stft_data, sig1.stft_data[mask])
|
||||||
|
assert np.allclose(indexed.stft_data,
|
||||||
|
util.bool_index_compat(sig1.stft_data, mask))
|
||||||
|
assert np.allclose(indexed._loudness, sig1._loudness[mask])
|
||||||
|
|
||||||
|
# Set parts of signal using tensor
|
||||||
|
other_array = paddle.to_tensor(np.random.randn(4, 2, 16000))
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1[0, :, 6000:8000] = other_array[0, :, 6000:8000]
|
||||||
|
|
||||||
|
assert np.allclose(sig1[0, :, 6000:8000].audio_data,
|
||||||
|
other_array[0, :, 6000:8000])
|
||||||
|
|
||||||
|
# Set parts of signal using AudioSignal
|
||||||
|
sig2 = AudioSignal(other_array, sample_rate=16000)
|
||||||
|
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1[0, :, 6000:8000] = sig2[0, :, 6000:8000]
|
||||||
|
|
||||||
|
assert np.allclose(sig1[0, :, 6000:8000].audio_data,
|
||||||
|
sig2[0, :, 6000:8000].audio_data)
|
||||||
|
|
||||||
|
# Check that loudnesses and stft_data get set as well, if only the batch
|
||||||
|
# dim is indexed.
|
||||||
|
sig2 = AudioSignal(other_array, sample_rate=16000)
|
||||||
|
sig2.stft()
|
||||||
|
sig2.loudness()
|
||||||
|
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1.stft()
|
||||||
|
sig1.loudness()
|
||||||
|
|
||||||
|
# Test using a boolean tensor to index batch
|
||||||
|
mask = paddle.to_tensor([True, False, True, False])
|
||||||
|
sig1[mask] = sig2[mask]
|
||||||
|
|
||||||
|
for k in ["stft_data", "audio_data", "_loudness"]:
|
||||||
|
a1 = getattr(sig1, k)
|
||||||
|
a2 = getattr(sig2, k)
|
||||||
|
|
||||||
|
# assert np.allclose(a1[mask], a2[mask])
|
||||||
|
assert np.allclose(
|
||||||
|
util.bool_index_compat(a1, mask), util.bool_index_compat(a2, mask))
|
||||||
|
|
||||||
|
|
||||||
|
def test_zeros():
|
||||||
|
x = AudioSignal.zeros(0.5, 44100)
|
||||||
|
assert x.signal_duration == 0.5
|
||||||
|
assert x.duration == 0.5
|
||||||
|
assert x.sample_rate == 44100
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("shape",
|
||||||
|
["sine", "square", "sawtooth", "triangle", "beep"])
|
||||||
|
def test_waves(shape: str):
|
||||||
|
# error case
|
||||||
|
if shape == "beep":
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
AudioSignal.wave(440, 0.5, 44100, shape=shape)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
x = AudioSignal.wave(440, 0.5, 44100, shape=shape)
|
||||||
|
assert x.duration == 0.5
|
||||||
|
assert x.sample_rate == 44100
|
||||||
|
|
||||||
|
# test the default shape arg
|
||||||
|
x = AudioSignal.wave(440, 0.5, 44100)
|
||||||
|
assert x.duration == 0.5
|
||||||
|
assert x.sample_rate == 44100
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_pad():
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
sig1.zero_pad(100, 100)
|
||||||
|
zeros = paddle.zeros([4, 2, 100], dtype="float64")
|
||||||
|
assert paddle.allclose(sig1.audio_data[..., :100], zeros)
|
||||||
|
assert paddle.allclose(sig1.audio_data[..., -100:], zeros)
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_pad_to():
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
sig1.zero_pad_to(16100)
|
||||||
|
zeros = paddle.zeros([4, 2, 100], dtype="float64")
|
||||||
|
assert paddle.allclose(sig1.audio_data[..., -100:], zeros)
|
||||||
|
assert sig1.signal_length == 16100
|
||||||
|
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1.zero_pad_to(15000)
|
||||||
|
assert sig1.signal_length == 16000
|
||||||
|
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1.zero_pad_to(16100, mode="before")
|
||||||
|
zeros = paddle.zeros([4, 2, 100], dtype="float64")
|
||||||
|
assert paddle.allclose(sig1.audio_data[..., :100], zeros)
|
||||||
|
assert sig1.signal_length == 16100
|
||||||
|
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1.zero_pad_to(15000, mode="before")
|
||||||
|
assert sig1.signal_length == 16000
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate():
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
sig1.truncate_samples(100)
|
||||||
|
assert sig1.signal_length == 100
|
||||||
|
assert np.allclose(sig1.audio_data, array[..., :100])
|
||||||
|
|
||||||
|
|
||||||
|
def test_trim():
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
sig1.trim(100, 100)
|
||||||
|
assert sig1.signal_length == 16000 - 200
|
||||||
|
assert np.allclose(sig1.audio_data, array[..., 100:-100])
|
||||||
|
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sig1 = AudioSignal(array, sample_rate=16000)
|
||||||
|
sig1.trim(0, 0)
|
||||||
|
assert np.allclose(sig1.audio_data, array)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_from_ops():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path)
|
||||||
|
signal.stft()
|
||||||
|
signal.loudness()
|
||||||
|
signal = signal.to("cpu")
|
||||||
|
|
||||||
|
assert str(signal.audio_data.place) == "Place(cpu)"
|
||||||
|
assert isinstance(signal.numpy(), np.ndarray)
|
||||||
|
|
||||||
|
signal.cpu()
|
||||||
|
# signal.cuda()
|
||||||
|
signal.float()
|
||||||
|
|
||||||
|
|
||||||
|
def test_device():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path)
|
||||||
|
signal.to("cpu")
|
||||||
|
|
||||||
|
assert str(signal.device) == "Place(cpu)"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("window_length", [2048, 512])
|
||||||
|
@pytest.mark.parametrize("hop_length", [512, 128])
|
||||||
|
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None])
|
||||||
|
def test_stft(window_length, hop_length, window_type):
|
||||||
|
if hop_length >= window_length:
|
||||||
|
hop_length = window_length // 2
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
stft_params = audiotools.STFTParams(
|
||||||
|
window_length=window_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
window_type=window_type)
|
||||||
|
for _stft_params in [None, stft_params]:
|
||||||
|
signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
signal.istft()
|
||||||
|
|
||||||
|
stft_data = signal.stft()
|
||||||
|
|
||||||
|
# assert paddle.allclose(signal.stft_data, stft_data)
|
||||||
|
assert np.allclose(signal.stft_data.cpu().numpy(),
|
||||||
|
stft_data.cpu().numpy())
|
||||||
|
copied_signal = signal.deepcopy()
|
||||||
|
copied_signal.stft()
|
||||||
|
copied_signal = copied_signal.istft()
|
||||||
|
|
||||||
|
assert copied_signal == signal
|
||||||
|
|
||||||
|
mag = signal.magnitude
|
||||||
|
phase = signal.phase
|
||||||
|
|
||||||
|
recon_stft = mag * util.exp_compat(1j * phase)
|
||||||
|
# assert paddle.allclose(recon_stft, signal.stft_data)
|
||||||
|
assert np.allclose(recon_stft.cpu().numpy(),
|
||||||
|
signal.stft_data.cpu().numpy())
|
||||||
|
|
||||||
|
signal.stft_data = None
|
||||||
|
mag = signal.magnitude
|
||||||
|
signal.stft_data = None
|
||||||
|
phase = signal.phase
|
||||||
|
|
||||||
|
recon_stft = mag * util.exp_compat(1j * phase)
|
||||||
|
# assert paddle.allclose(recon_stft, signal.stft_data)
|
||||||
|
assert np.allclose(recon_stft.cpu().numpy(),
|
||||||
|
signal.stft_data.cpu().numpy())
|
||||||
|
|
||||||
|
# Test with match_stride=True, ignoring the beginning and end.
|
||||||
|
s = signal.stft_params
|
||||||
|
if s.hop_length == s.window_length // 4:
|
||||||
|
og_signal = signal.clone()
|
||||||
|
stft_data = signal.stft(match_stride=True)
|
||||||
|
recon_data = signal.istft(match_stride=True)
|
||||||
|
discard = window_length * 2
|
||||||
|
|
||||||
|
right_pad, _ = signal.compute_stft_padding(
|
||||||
|
s.window_length, s.hop_length, match_stride=True)
|
||||||
|
length = signal.signal_length + right_pad
|
||||||
|
assert stft_data.shape[-1] == length // s.hop_length
|
||||||
|
|
||||||
|
assert paddle.allclose(
|
||||||
|
recon_data.audio_data[..., discard:-discard],
|
||||||
|
og_signal.audio_data[..., discard:-discard],
|
||||||
|
atol=1e-6, )
|
||||||
|
|
||||||
|
|
||||||
|
def test_log_magnitude():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
for _ in range(10):
|
||||||
|
signal = AudioSignal.excerpt(audio_path, duration=5.0)
|
||||||
|
magnitude = signal.magnitude.numpy()[0, 0]
|
||||||
|
librosa_log_mag = librosa.amplitude_to_db(magnitude)
|
||||||
|
log_mag = signal.log_magnitude().numpy()[0, 0]
|
||||||
|
|
||||||
|
# print(abs((log_mag - librosa_log_mag)).max())
|
||||||
|
assert np.allclose(log_mag, librosa_log_mag, atol=10e-7)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_mels", [40, 80, 128])
|
||||||
|
@pytest.mark.parametrize("window_length", [2048, 512])
|
||||||
|
@pytest.mark.parametrize("hop_length", [512, 128])
|
||||||
|
@pytest.mark.parametrize("window_type", ["sqrt_hann", "hann", None])
|
||||||
|
def test_mel_spectrogram(n_mels, window_length, hop_length, window_type):
|
||||||
|
if hop_length >= window_length:
|
||||||
|
hop_length = window_length // 2
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
stft_params = audiotools.STFTParams(
|
||||||
|
window_length=window_length,
|
||||||
|
hop_length=hop_length,
|
||||||
|
window_type=window_type)
|
||||||
|
for _stft_params in [None, stft_params]:
|
||||||
|
signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params)
|
||||||
|
mel_spec = signal.mel_spectrogram(n_mels=n_mels)
|
||||||
|
assert mel_spec.shape[2] == n_mels
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_mfcc", [20, 40])
|
||||||
|
@pytest.mark.parametrize("n_mels", [40, 80, 128])
|
||||||
|
@pytest.mark.parametrize("window_length", [2048, 512])
|
||||||
|
@pytest.mark.parametrize("hop_length", [512, 128])
|
||||||
|
def test_mfcc(n_mfcc, n_mels, window_length, hop_length):
|
||||||
|
if hop_length >= window_length:
|
||||||
|
hop_length = window_length // 2
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
stft_params = audiotools.STFTParams(
|
||||||
|
window_length=window_length, hop_length=hop_length)
|
||||||
|
for _stft_params in [None, stft_params]:
|
||||||
|
signal = AudioSignal(audio_path, duration=10, stft_params=_stft_params)
|
||||||
|
mfcc = signal.mfcc(n_mfcc=n_mfcc, n_mels=n_mels)
|
||||||
|
assert mfcc.shape[2] == n_mfcc
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_mono():
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sr = 16000
|
||||||
|
|
||||||
|
signal = AudioSignal(array, sample_rate=sr)
|
||||||
|
assert signal.num_channels == 2
|
||||||
|
|
||||||
|
signal = signal.to_mono()
|
||||||
|
assert signal.num_channels == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_float():
|
||||||
|
array = np.random.randn(4, 1, 16000).astype("float64")
|
||||||
|
sr = 1600
|
||||||
|
signal = AudioSignal(array, sample_rate=sr)
|
||||||
|
|
||||||
|
signal = signal.float()
|
||||||
|
assert signal.audio_data.dtype == paddle.float32
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100, 48000])
|
||||||
|
def test_resample(sample_rate):
|
||||||
|
array = np.random.randn(4, 2, 16000)
|
||||||
|
sr = 16000
|
||||||
|
|
||||||
|
signal = AudioSignal(array, sample_rate=sr)
|
||||||
|
|
||||||
|
signal = signal.resample(sample_rate)
|
||||||
|
assert signal.sample_rate == sample_rate
|
||||||
|
assert signal.signal_length == sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def test_batching():
|
||||||
|
signals = []
|
||||||
|
batch_size = 16
|
||||||
|
|
||||||
|
# All same length, same sample rate.
|
||||||
|
for _ in range(batch_size):
|
||||||
|
array = np.random.randn(2, 16000)
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
signals.append(signal)
|
||||||
|
|
||||||
|
batched_signal = AudioSignal.batch(signals)
|
||||||
|
assert batched_signal.batch_size == batch_size
|
||||||
|
|
||||||
|
signals = []
|
||||||
|
# All different lengths, same sample rate, pad signals
|
||||||
|
for _ in range(batch_size):
|
||||||
|
L = np.random.randint(8000, 32000)
|
||||||
|
array = np.random.randn(2, L)
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
signals.append(signal)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
batched_signal = AudioSignal.batch(signals)
|
||||||
|
|
||||||
|
signal_lengths = [x.signal_length for x in signals]
|
||||||
|
max_length = max(signal_lengths)
|
||||||
|
batched_signal = AudioSignal.batch(signals, pad_signals=True)
|
||||||
|
|
||||||
|
assert batched_signal.signal_length == max_length
|
||||||
|
assert batched_signal.batch_size == batch_size
|
||||||
|
|
||||||
|
signals = []
|
||||||
|
# All different lengths, same sample rate, truncate signals
|
||||||
|
for _ in range(batch_size):
|
||||||
|
L = np.random.randint(8000, 32000)
|
||||||
|
array = np.random.randn(2, L)
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
signals.append(signal)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
batched_signal = AudioSignal.batch(signals)
|
||||||
|
|
||||||
|
signal_lengths = [x.signal_length for x in signals]
|
||||||
|
min_length = min(signal_lengths)
|
||||||
|
batched_signal = AudioSignal.batch(signals, truncate_signals=True)
|
||||||
|
|
||||||
|
assert batched_signal.signal_length == min_length
|
||||||
|
assert batched_signal.batch_size == batch_size
|
||||||
|
|
||||||
|
signals = []
|
||||||
|
# All different lengths, different sample rate, pad signals
|
||||||
|
for _ in range(batch_size):
|
||||||
|
L = np.random.randint(8000, 32000)
|
||||||
|
sr = np.random.choice([8000, 16000, 32000])
|
||||||
|
array = np.random.randn(2, L)
|
||||||
|
signal = AudioSignal(array, sample_rate=int(sr))
|
||||||
|
signals.append(signal)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
batched_signal = AudioSignal.batch(signals)
|
||||||
|
|
||||||
|
signal_lengths = [x.signal_length for x in signals]
|
||||||
|
max_length = max(signal_lengths)
|
||||||
|
for i, x in enumerate(signals):
|
||||||
|
x.path_to_file = i
|
||||||
|
batched_signal = AudioSignal.batch(signals, resample=True, pad_signals=True)
|
||||||
|
|
||||||
|
assert batched_signal.signal_length == max_length
|
||||||
|
assert batched_signal.batch_size == batch_size
|
||||||
|
assert batched_signal.path_to_file == list(range(len(signals)))
|
||||||
|
assert batched_signal.path_to_input_file == batched_signal.path_to_file
|
@ -0,0 +1,54 @@
|
|||||||
|
# MIT License, Copyright (c) 2020 Alexandre Défossez.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_bands.py)
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from audio.audiotools.core import pure_tone
|
||||||
|
from audio.audiotools.core import split_bands
|
||||||
|
from audio.audiotools.core import SplitBands
|
||||||
|
|
||||||
|
|
||||||
|
def delta(a, b, ref, fraction=0.9):
|
||||||
|
length = a.shape[-1]
|
||||||
|
compare_length = int(length * fraction)
|
||||||
|
offset = (length - compare_length) // 2
|
||||||
|
a = a[..., offset:offset + length]
|
||||||
|
b = b[..., offset:offset + length]
|
||||||
|
return 100 * paddle.abs(a - b).mean() / ref.std()
|
||||||
|
|
||||||
|
|
||||||
|
TOLERANCE = 0.5 # Tolerance to errors as percentage of the std of the input signal
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseTest(unittest.TestCase):
|
||||||
|
def assertSimilar(self, a, b, ref, msg=None, tol=TOLERANCE):
|
||||||
|
self.assertLessEqual(delta(a, b, ref), tol, msg)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLowPassFilters(_BaseTest):
|
||||||
|
def setUp(self):
|
||||||
|
paddle.seed(1234)
|
||||||
|
random.seed(1234)
|
||||||
|
|
||||||
|
def test_keep_or_kill(self):
|
||||||
|
sr = 256
|
||||||
|
low = pure_tone(10, sr)
|
||||||
|
mid = pure_tone(40, sr)
|
||||||
|
high = pure_tone(100, sr)
|
||||||
|
|
||||||
|
x = low + mid + high
|
||||||
|
|
||||||
|
decomp = split_bands(x, sr, cutoffs=[20, 70])
|
||||||
|
self.assertEqual(len(decomp), 3)
|
||||||
|
for est, gt, name in zip(decomp, [low, mid, high],
|
||||||
|
["low", "mid", "high"]):
|
||||||
|
self.assertSimilar(est, gt, gt, name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -0,0 +1,51 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_display.py)
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
|
||||||
|
|
||||||
|
def test_specshow():
|
||||||
|
array = np.zeros((1, 16000))
|
||||||
|
AudioSignal(array, sample_rate=16000).specshow()
|
||||||
|
AudioSignal(array, sample_rate=16000).specshow(preemphasis=True)
|
||||||
|
AudioSignal(
|
||||||
|
array, sample_rate=16000).specshow(
|
||||||
|
title="test", preemphasis=True)
|
||||||
|
AudioSignal(
|
||||||
|
array, sample_rate=16000).specshow(
|
||||||
|
format=False, preemphasis=True)
|
||||||
|
AudioSignal(
|
||||||
|
array, sample_rate=16000).specshow(
|
||||||
|
format=False, preemphasis=False, y_axis="mel")
|
||||||
|
|
||||||
|
|
||||||
|
def test_waveplot():
|
||||||
|
array = np.zeros((1, 16000))
|
||||||
|
AudioSignal(array, sample_rate=16000).waveplot()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wavespec():
|
||||||
|
array = np.zeros((1, 16000))
|
||||||
|
AudioSignal(array, sample_rate=16000).wavespec()
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_audio_to_tb():
|
||||||
|
signal = AudioSignal("./audio/spk/f10_script4_produced.mp3", duration=5)
|
||||||
|
|
||||||
|
Path("./scratch").mkdir(parents=True, exist_ok=True)
|
||||||
|
writer = LogWriter("./scratch/")
|
||||||
|
signal.write_audio_to_tb("tag", writer)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_image():
|
||||||
|
signal = AudioSignal(
|
||||||
|
"./audio/spk/f10_script4_produced.wav", duration=10, offset=10)
|
||||||
|
Path("./scratch").mkdir(parents=True, exist_ok=True)
|
||||||
|
signal.save_image("./scratch/image.png")
|
@ -0,0 +1,181 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_dsp.py)
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
from audio.audiotools.core.util import sample_from_dist
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0])
|
||||||
|
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100])
|
||||||
|
@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0])
|
||||||
|
def test_overlap_add(duration, sample_rate, window_duration):
|
||||||
|
np.random.seed(0)
|
||||||
|
if duration > window_duration:
|
||||||
|
spk_signal = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt(
|
||||||
|
"./audio/spk/f10_script4_produced.wav", duration=duration)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
spk_signal.resample(sample_rate)
|
||||||
|
|
||||||
|
noise = paddle.randn([16, 1, int(duration * sample_rate)])
|
||||||
|
nz_signal = AudioSignal(noise, sample_rate=sample_rate)
|
||||||
|
|
||||||
|
def _test(signal):
|
||||||
|
hop_duration = window_duration / 2
|
||||||
|
windowed_signal = signal.clone().collect_windows(window_duration,
|
||||||
|
hop_duration)
|
||||||
|
recombined = windowed_signal.overlap_and_add(hop_duration)
|
||||||
|
|
||||||
|
assert recombined == signal
|
||||||
|
assert np.allclose(recombined.audio_data, signal.audio_data, 1e-3)
|
||||||
|
|
||||||
|
_test(nz_signal)
|
||||||
|
_test(spk_signal)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("window_duration", [0.1, 0.25, 0.5, 1.0])
|
||||||
|
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 44100])
|
||||||
|
@pytest.mark.parametrize("duration", [0.5, 1.0, 2.0, 10.0])
|
||||||
|
def test_inplace_overlap_add(duration, sample_rate, window_duration):
|
||||||
|
np.random.seed(0)
|
||||||
|
if duration > window_duration:
|
||||||
|
spk_signal = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt(
|
||||||
|
"./audio/spk/f10_script4_produced.wav", duration=duration)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
spk_signal.resample(sample_rate)
|
||||||
|
|
||||||
|
noise = paddle.randn([16, 1, int(duration * sample_rate)])
|
||||||
|
nz_signal = AudioSignal(noise, sample_rate=sample_rate)
|
||||||
|
|
||||||
|
def _test(signal):
|
||||||
|
hop_duration = window_duration / 2
|
||||||
|
windowed_signal = signal.clone().collect_windows(window_duration,
|
||||||
|
hop_duration)
|
||||||
|
# Compare in-place with unfold results
|
||||||
|
for i, window in enumerate(
|
||||||
|
signal.clone().windows(window_duration, hop_duration)):
|
||||||
|
assert np.allclose(window.audio_data,
|
||||||
|
windowed_signal.audio_data[i])
|
||||||
|
|
||||||
|
_test(nz_signal)
|
||||||
|
_test(spk_signal)
|
||||||
|
|
||||||
|
|
||||||
|
def test_low_pass():
|
||||||
|
sample_rate = 44100
|
||||||
|
f = 440
|
||||||
|
t = paddle.arange(0, 1, 1 / sample_rate)
|
||||||
|
sine_wave = paddle.sin(2 * np.pi * f * t)
|
||||||
|
window = AudioSignal.get_window("hann", sine_wave.shape[-1])
|
||||||
|
sine_wave = sine_wave * window
|
||||||
|
signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
|
||||||
|
out = signal.clone().low_pass(220)
|
||||||
|
assert out.audio_data.abs().max() < 1e-4
|
||||||
|
|
||||||
|
out = signal.clone().low_pass(880)
|
||||||
|
assert (out - signal).audio_data.abs().max() < 1e-3
|
||||||
|
|
||||||
|
batch = AudioSignal.batch([signal.clone(), signal.clone(), signal.clone()])
|
||||||
|
|
||||||
|
cutoffs = [220, 880, 220]
|
||||||
|
out = batch.clone().low_pass(cutoffs)
|
||||||
|
|
||||||
|
assert out.audio_data[0].abs().max() < 1e-4
|
||||||
|
assert out.audio_data[2].abs().max() < 1e-4
|
||||||
|
assert (out - batch).audio_data[1].abs().max() < 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
def test_high_pass():
|
||||||
|
sample_rate = 44100
|
||||||
|
f = 440
|
||||||
|
t = paddle.arange(0, 1, 1 / sample_rate)
|
||||||
|
sine_wave = paddle.sin(2 * np.pi * f * t)
|
||||||
|
window = AudioSignal.get_window("hann", sine_wave.shape[-1])
|
||||||
|
sine_wave = sine_wave * window
|
||||||
|
signal = AudioSignal(sine_wave.unsqueeze(0), sample_rate=sample_rate)
|
||||||
|
out = signal.clone().high_pass(220)
|
||||||
|
assert (signal - out).audio_data.abs().max() < 1e-4
|
||||||
|
|
||||||
|
|
||||||
|
def test_mask_frequencies():
|
||||||
|
sample_rate = 44100
|
||||||
|
fs = paddle.to_tensor([500.0, 2000.0, 8000.0, 32000.0])[None]
|
||||||
|
t = paddle.arange(0, 1, 1 / sample_rate)[:, None]
|
||||||
|
sine_wave = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1)
|
||||||
|
sine_wave = AudioSignal(sine_wave, sample_rate)
|
||||||
|
masked_sine_wave = sine_wave.mask_frequencies(fmin_hz=1500, fmax_hz=10000)
|
||||||
|
|
||||||
|
fs2 = paddle.to_tensor([500.0, 32000.0])[None]
|
||||||
|
sine_wave2 = paddle.sin(2 * np.pi * t @ fs).sum(axis=-1)
|
||||||
|
sine_wave2 = AudioSignal(sine_wave2, sample_rate)
|
||||||
|
|
||||||
|
assert paddle.allclose(masked_sine_wave.audio_data, sine_wave2.audio_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mask_timesteps():
|
||||||
|
sample_rate = 44100
|
||||||
|
f = 440
|
||||||
|
t = paddle.linspace(0, 1, sample_rate)
|
||||||
|
sine_wave = paddle.sin(2 * np.pi * f * t)
|
||||||
|
sine_wave = AudioSignal(sine_wave, sample_rate)
|
||||||
|
|
||||||
|
masked_sine_wave = sine_wave.mask_timesteps(tmin_s=0.25, tmax_s=0.75)
|
||||||
|
masked_sine_wave.istft()
|
||||||
|
|
||||||
|
mask = ((0.3 < t) & (t < 0.7))[None, None]
|
||||||
|
assert paddle.allclose(
|
||||||
|
masked_sine_wave.audio_data[mask],
|
||||||
|
paddle.zeros_like(masked_sine_wave.audio_data[mask]), )
|
||||||
|
|
||||||
|
|
||||||
|
def test_shift_phase():
|
||||||
|
sample_rate = 44100
|
||||||
|
f = 440
|
||||||
|
t = paddle.linspace(0, 1, sample_rate)
|
||||||
|
sine_wave = paddle.sin(2 * np.pi * f * t)
|
||||||
|
sine_wave = AudioSignal(sine_wave, sample_rate)
|
||||||
|
sine_wave2 = sine_wave.clone()
|
||||||
|
|
||||||
|
shifted_sine_wave = sine_wave.shift_phase(np.pi)
|
||||||
|
shifted_sine_wave.istft()
|
||||||
|
|
||||||
|
sine_wave2.phase = sine_wave2.phase + np.pi
|
||||||
|
sine_wave2.istft()
|
||||||
|
|
||||||
|
assert paddle.allclose(shifted_sine_wave.audio_data, sine_wave2.audio_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_corrupt_phase():
|
||||||
|
sample_rate = 44100
|
||||||
|
f = 440
|
||||||
|
t = paddle.linspace(0, 1, sample_rate)
|
||||||
|
sine_wave = paddle.sin(2 * np.pi * f * t)
|
||||||
|
sine_wave = AudioSignal(sine_wave, sample_rate)
|
||||||
|
sine_wave2 = sine_wave.clone()
|
||||||
|
|
||||||
|
shifted_sine_wave = sine_wave.corrupt_phase(scale=np.pi)
|
||||||
|
shifted_sine_wave.istft()
|
||||||
|
|
||||||
|
assert (sine_wave2.phase - shifted_sine_wave.phase).abs().mean() > 0.0
|
||||||
|
assert ((sine_wave2.phase - shifted_sine_wave.phase).std() / np.pi) < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_preemphasis():
|
||||||
|
x = AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
x.specshow(preemphasis=False)
|
||||||
|
|
||||||
|
x.specshow(preemphasis=True)
|
||||||
|
|
||||||
|
x.preemphasis()
|
@ -0,0 +1,321 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_effects.py)
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
signal = signal.normalize()
|
||||||
|
assert np.allclose(signal.loudness(), -24, atol=1e-1)
|
||||||
|
|
||||||
|
array = np.random.randn(1, 2, 32000)
|
||||||
|
array = array / np.abs(array).max()
|
||||||
|
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
for db_incr in np.arange(10, 75, 5):
|
||||||
|
db = -80 + db_incr
|
||||||
|
signal = signal.normalize(db)
|
||||||
|
loudness = signal.loudness()
|
||||||
|
assert np.allclose(loudness, db, atol=1) # TODO, atol=1e-1
|
||||||
|
|
||||||
|
batch_size = 16
|
||||||
|
db = -60 + paddle.linspace(10, 30, batch_size)
|
||||||
|
|
||||||
|
array = np.random.randn(batch_size, 2, 32000)
|
||||||
|
array = array / np.abs(array).max()
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
|
||||||
|
signal = signal.normalize(db)
|
||||||
|
assert np.allclose(signal.loudness(), db, 1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_volume_change():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
boost = 3
|
||||||
|
before_db = signal.loudness().clone()
|
||||||
|
signal = signal.volume_change(boost)
|
||||||
|
after_db = signal.loudness()
|
||||||
|
assert np.allclose(before_db + boost, after_db)
|
||||||
|
|
||||||
|
signal._loudness = None
|
||||||
|
after_db = signal.loudness()
|
||||||
|
assert np.allclose(before_db + boost, after_db, 1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mix():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
|
||||||
|
nz = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
spk.deepcopy().mix(nz, snr=-10)
|
||||||
|
snr = spk.loudness() - nz.loudness()
|
||||||
|
assert np.allclose(snr, -10, atol=1)
|
||||||
|
|
||||||
|
# Test in batch
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
|
||||||
|
nz = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
batch_size = 4
|
||||||
|
tgt_snr = paddle.linspace(-10, 10, batch_size)
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
|
||||||
|
nz_batch = AudioSignal.batch([nz.deepcopy() for _ in range(batch_size)])
|
||||||
|
|
||||||
|
spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr)
|
||||||
|
snr = spk_batch.loudness() - nz_batch.loudness()
|
||||||
|
assert np.allclose(snr, tgt_snr, atol=1)
|
||||||
|
|
||||||
|
# Test with "EQing" the other signal
|
||||||
|
db = 0 + 0 * paddle.rand([10])
|
||||||
|
spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr, other_eq=db)
|
||||||
|
snr = spk_batch.loudness() - nz_batch.loudness()
|
||||||
|
assert np.allclose(snr, tgt_snr, atol=1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_convolve():
|
||||||
|
np.random.seed(6) # Found a failing seed
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
impulse = np.zeros((1, 16000), dtype="float32")
|
||||||
|
impulse[..., 0] = 1
|
||||||
|
ir = AudioSignal(impulse, 16000)
|
||||||
|
batch_size = 4
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
|
||||||
|
ir_batch = AudioSignal.batch(
|
||||||
|
[
|
||||||
|
ir.deepcopy().zero_pad(np.random.randint(1000), 0)
|
||||||
|
for _ in range(batch_size)
|
||||||
|
],
|
||||||
|
pad_signals=True, )
|
||||||
|
|
||||||
|
convolved = spk_batch.deepcopy().convolve(ir_batch)
|
||||||
|
assert convolved == spk_batch
|
||||||
|
|
||||||
|
# Short duration
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=0.1)
|
||||||
|
|
||||||
|
impulse = np.zeros((1, 16000), dtype="float32")
|
||||||
|
impulse[..., 0] = 1
|
||||||
|
ir = AudioSignal(impulse, 16000)
|
||||||
|
batch_size = 4
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)])
|
||||||
|
ir_batch = AudioSignal.batch(
|
||||||
|
[
|
||||||
|
ir.deepcopy().zero_pad(np.random.randint(1000), 0)
|
||||||
|
for _ in range(batch_size)
|
||||||
|
],
|
||||||
|
pad_signals=True, )
|
||||||
|
|
||||||
|
convolved = spk_batch.deepcopy().convolve(ir_batch)
|
||||||
|
assert convolved == spk_batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline():
|
||||||
|
# An actual IR, no batching
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=5)
|
||||||
|
|
||||||
|
audio_path = "./audio/ir/h179_Bar_1txts.wav"
|
||||||
|
ir = AudioSignal(audio_path)
|
||||||
|
spk.deepcopy().convolve(ir)
|
||||||
|
|
||||||
|
audio_path = "./audio/nz/f5_script2_ipad_balcony1_room_tone.wav"
|
||||||
|
nz = AudioSignal(audio_path, offset=10, duration=5)
|
||||||
|
|
||||||
|
batch_size = 16
|
||||||
|
tgt_snr = paddle.linspace(20, 30, batch_size)
|
||||||
|
|
||||||
|
(spk @ ir).mix(nz, snr=tgt_snr)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16])
|
||||||
|
def test_mel_filterbank(n_bands):
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=1)
|
||||||
|
fbank = spk.deepcopy().mel_filterbank(n_bands)
|
||||||
|
|
||||||
|
assert paddle.allclose(fbank.sum(-1), spk.audio_data, atol=1e-6)
|
||||||
|
|
||||||
|
# Check if it works in batches.
|
||||||
|
spk_batch = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
fbank = spk_batch.deepcopy().mel_filterbank(n_bands)
|
||||||
|
summed = fbank.sum(-1)
|
||||||
|
assert paddle.allclose(summed, spk_batch.audio_data, atol=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16])
|
||||||
|
def test_equalizer(n_bands):
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=10)
|
||||||
|
|
||||||
|
db = -3 + 1 * paddle.rand([n_bands])
|
||||||
|
spk.deepcopy().equalizer(db)
|
||||||
|
|
||||||
|
db = -3 + 1 * np.random.rand(n_bands)
|
||||||
|
spk.deepcopy().equalizer(db)
|
||||||
|
|
||||||
|
audio_path = "./audio/ir/h179_Bar_1txts.wav"
|
||||||
|
ir = AudioSignal(audio_path)
|
||||||
|
db = -3 + 1 * paddle.rand([n_bands])
|
||||||
|
|
||||||
|
spk.deepcopy().convolve(ir.equalizer(db))
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
|
||||||
|
db = paddle.zeros([spk_batch.batch_size, n_bands])
|
||||||
|
output = spk_batch.deepcopy().equalizer(db)
|
||||||
|
|
||||||
|
assert output == spk_batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_clip_distortion():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
clipped = spk.deepcopy().clip_distortion(0.05)
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
percs = paddle.to_tensor(np.random.uniform(size=(16, ))).astype("float32")
|
||||||
|
clipped_batch = spk_batch.deepcopy().clip_distortion(percs)
|
||||||
|
|
||||||
|
assert clipped.audio_data.abs().max() < 1.0
|
||||||
|
assert clipped_batch.audio_data.abs().max() < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128])
|
||||||
|
def test_quantization(quant_ch):
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
quantized = spk.deepcopy().quantization(quant_ch)
|
||||||
|
|
||||||
|
# Need to round audio_data off because torch ops with straight
|
||||||
|
# through estimator are sometimes a bit off past 3 decimal places.
|
||||||
|
found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3)))
|
||||||
|
assert found_quant_ch <= quant_ch
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
|
||||||
|
quant_ch = np.random.choice(
|
||||||
|
[2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True)
|
||||||
|
quantized = spk_batch.deepcopy().quantization(quant_ch)
|
||||||
|
|
||||||
|
for i, q_ch in enumerate(quant_ch):
|
||||||
|
found_quant_ch = len(
|
||||||
|
np.unique(np.around(quantized.audio_data[i], decimals=3)))
|
||||||
|
assert found_quant_ch <= q_ch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128])
|
||||||
|
def test_mulaw_quantization(quant_ch):
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
quantized = spk.deepcopy().mulaw_quantization(quant_ch)
|
||||||
|
|
||||||
|
# Need to round audio_data off because torch ops with straight
|
||||||
|
# through estimator are sometimes a bit off past 3 decimal places.
|
||||||
|
found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3)))
|
||||||
|
assert found_quant_ch <= quant_ch
|
||||||
|
|
||||||
|
spk_batch = AudioSignal.batch([
|
||||||
|
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=2)
|
||||||
|
for _ in range(16)
|
||||||
|
])
|
||||||
|
|
||||||
|
quant_ch = np.random.choice(
|
||||||
|
[2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True)
|
||||||
|
quantized = spk_batch.deepcopy().mulaw_quantization(quant_ch)
|
||||||
|
|
||||||
|
for i, q_ch in enumerate(quant_ch):
|
||||||
|
found_quant_ch = len(
|
||||||
|
np.unique(np.around(quantized.audio_data[i], decimals=3)))
|
||||||
|
assert found_quant_ch <= q_ch
|
||||||
|
|
||||||
|
|
||||||
|
def test_impulse_response_augmentation():
|
||||||
|
audio_path = "./audio/ir/h179_Bar_1txts.wav"
|
||||||
|
batch_size = 16
|
||||||
|
ir = AudioSignal(audio_path)
|
||||||
|
ir_batch = AudioSignal.batch([ir for _ in range(batch_size)])
|
||||||
|
early_response, late_field, window = ir_batch.decompose_ir()
|
||||||
|
|
||||||
|
assert early_response.shape == late_field.shape
|
||||||
|
assert late_field.shape == window.shape
|
||||||
|
|
||||||
|
drr = ir_batch.measure_drr()
|
||||||
|
|
||||||
|
alpha = AudioSignal.solve_alpha(early_response, late_field, window, drr)
|
||||||
|
assert np.allclose(alpha, np.ones_like(alpha), 1e-5)
|
||||||
|
|
||||||
|
target_drr = 5
|
||||||
|
out = ir_batch.deepcopy().alter_drr(target_drr)
|
||||||
|
drr = out.measure_drr()
|
||||||
|
assert np.allclose(drr, np.ones_like(drr) * target_drr)
|
||||||
|
|
||||||
|
target_drr = np.random.rand(batch_size).astype("float32") * 50
|
||||||
|
altered_ir = ir_batch.deepcopy().alter_drr(target_drr)
|
||||||
|
drr = altered_ir.measure_drr()
|
||||||
|
assert np.allclose(drr.flatten(), target_drr.flatten())
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_ir():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
ir_path = "./audio/ir/h179_Bar_1txts.wav"
|
||||||
|
|
||||||
|
spk = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
ir = AudioSignal(ir_path)
|
||||||
|
db = 0 + 0 * paddle.rand([10])
|
||||||
|
output = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=db)
|
||||||
|
|
||||||
|
assert np.allclose(ir.measure_drr().flatten(), 10)
|
||||||
|
|
||||||
|
output = spk.deepcopy().apply_ir(
|
||||||
|
ir, drr=10, ir_eq=db, use_original_phase=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ensure_max_of_audio():
|
||||||
|
spk = AudioSignal(paddle.randn([1, 1, 44100]), 44100)
|
||||||
|
|
||||||
|
max_vals = [1.0] + [np.random.rand() for _ in range(10)]
|
||||||
|
for val in max_vals:
|
||||||
|
after = spk.deepcopy().ensure_max_of_audio(val)
|
||||||
|
assert after.audio_data.abs().max() <= val + 1e-3
|
||||||
|
|
||||||
|
# Make sure it does nothing to a tiny signal
|
||||||
|
spk = AudioSignal(paddle.rand([1, 1, 44100]), 44100)
|
||||||
|
spk.audio_data = spk.audio_data * 0.5
|
||||||
|
after = spk.deepcopy().ensure_max_of_audio()
|
||||||
|
|
||||||
|
assert paddle.allclose(after.audio_data, spk.audio_data)
|
@ -0,0 +1,85 @@
|
|||||||
|
# MIT License, Copyright (c) 2020 Alexandre Défossez.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from julius(https://github.com/adefossez/julius/blob/main/tests/test_fftconv.py)
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
from audio.audiotools.core import fft_conv1d
|
||||||
|
from audio.audiotools.core import FFTConv1D
|
||||||
|
|
||||||
|
TOLERANCE = 1e-4 # as relative delta in percentage
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseTest(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
paddle.seed(1234)
|
||||||
|
random.seed(1234)
|
||||||
|
|
||||||
|
def assertSimilar(self, a, b, msg=None, tol=TOLERANCE):
|
||||||
|
delta = 100 * paddle.norm(a - b, p=2) / paddle.norm(b, p=2)
|
||||||
|
self.assertLessEqual(delta.numpy(), tol, msg)
|
||||||
|
|
||||||
|
def compare_paddle(self, *args, msg=None, tol=TOLERANCE, **kwargs):
|
||||||
|
y_ref = F.conv1d(*args, **kwargs)
|
||||||
|
y = fft_conv1d(*args, **kwargs)
|
||||||
|
self.assertEqual(list(y.shape), list(y_ref.shape), msg)
|
||||||
|
self.assertSimilar(y, y_ref, msg, tol)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFFTConv1d(_BaseTest):
|
||||||
|
def test_same_as_paddle(self):
|
||||||
|
for _ in range(5):
|
||||||
|
kernel_size = random.randrange(4, 128)
|
||||||
|
batch_size = random.randrange(1, 6)
|
||||||
|
length = random.randrange(kernel_size, 1024)
|
||||||
|
chin = random.randrange(1, 12)
|
||||||
|
chout = random.randrange(1, 12)
|
||||||
|
bias = random.random() < 0.5
|
||||||
|
if random.random() < 0.5:
|
||||||
|
padding = 0
|
||||||
|
else:
|
||||||
|
padding = random.randrange(kernel_size // 2, 2 * kernel_size)
|
||||||
|
x = paddle.randn([batch_size, chin, length])
|
||||||
|
w = paddle.randn([chout, chin, kernel_size])
|
||||||
|
keys = ["length", "kernel_size", "chin", "chout", "bias"]
|
||||||
|
loc = locals()
|
||||||
|
state = {key: loc[key] for key in keys}
|
||||||
|
if bias:
|
||||||
|
bias = paddle.randn([chout])
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
for stride in [1, 2, 5]:
|
||||||
|
state["stride"] = stride
|
||||||
|
self.compare_paddle(
|
||||||
|
x, w, bias, stride, padding, msg=repr(state))
|
||||||
|
|
||||||
|
def test_small_input(self):
|
||||||
|
x = paddle.randn([1, 5, 19])
|
||||||
|
w = paddle.randn([10, 5, 32])
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
fft_conv1d(x, w)
|
||||||
|
|
||||||
|
x = paddle.randn([1, 5, 19])
|
||||||
|
w = paddle.randn([10, 5, 19])
|
||||||
|
self.assertEqual(list(fft_conv1d(x, w).shape), [1, 10, 1])
|
||||||
|
|
||||||
|
def test_module(self):
|
||||||
|
x = paddle.randn([16, 4, 1024])
|
||||||
|
mod = FFTConv1D(4, 5, 8, bias_attr=True)
|
||||||
|
mod(x)
|
||||||
|
mod = FFTConv1D(4, 5, 8, bias_attr=False)
|
||||||
|
mod(x)
|
||||||
|
|
||||||
|
def test_dynamic_graph(self):
|
||||||
|
x = paddle.randn([16, 4, 1024])
|
||||||
|
mod = FFTConv1D(4, 5, 8, bias_attr=True)
|
||||||
|
self.assertEqual(list(mod(x).shape), [16, 5, 1024 - 8 + 1])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -0,0 +1,172 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_grad.py)
|
||||||
|
import sys
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_grad():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
ir_path = "./audio/ir/h179_Bar_1txts.wav"
|
||||||
|
|
||||||
|
def _test_audio_grad(attr: str, target=True, kwargs: dict={}):
|
||||||
|
signal = AudioSignal(audio_path)
|
||||||
|
signal.audio_data.stop_gradient = False
|
||||||
|
|
||||||
|
assert signal.audio_data.grad is None
|
||||||
|
|
||||||
|
# Avoid overwriting leaf tensor by cloning signal
|
||||||
|
attr = getattr(signal.clone(), attr)
|
||||||
|
result = attr(**kwargs) if isinstance(attr, Callable) else attr
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(result, AudioSignal):
|
||||||
|
# If necessary, propagate spectrogram changes to waveform
|
||||||
|
if result.stft_data is not None:
|
||||||
|
result.istft()
|
||||||
|
# if result.audio_data.dtype.is_complex:
|
||||||
|
if paddle.is_complex(result.audio_data):
|
||||||
|
result.audio_data.real.sum().backward()
|
||||||
|
else:
|
||||||
|
result.audio_data.sum().backward()
|
||||||
|
else:
|
||||||
|
# if result.dtype.is_complex:
|
||||||
|
if paddle.is_complex(result):
|
||||||
|
result.real().sum().backward()
|
||||||
|
else:
|
||||||
|
result.sum().backward()
|
||||||
|
|
||||||
|
assert signal.audio_data.grad is not None or not target
|
||||||
|
except RuntimeError:
|
||||||
|
assert not target
|
||||||
|
|
||||||
|
for a in [
|
||||||
|
["mix", True, {
|
||||||
|
"other": AudioSignal(audio_path),
|
||||||
|
"snr": 0
|
||||||
|
}],
|
||||||
|
["convolve", True, {
|
||||||
|
"other": AudioSignal(ir_path)
|
||||||
|
}],
|
||||||
|
[
|
||||||
|
"apply_ir",
|
||||||
|
True,
|
||||||
|
{
|
||||||
|
"ir": AudioSignal(ir_path),
|
||||||
|
"drr": 0.1,
|
||||||
|
"ir_eq": paddle.randn([6])
|
||||||
|
},
|
||||||
|
],
|
||||||
|
["ensure_max_of_audio", True],
|
||||||
|
["normalize", True],
|
||||||
|
["volume_change", True, {
|
||||||
|
"db": 1
|
||||||
|
}],
|
||||||
|
# ["pitch_shift", False, {"n_semitones": 1}],
|
||||||
|
# ["time_stretch", False, {"factor": 2}],
|
||||||
|
# ["apply_codec", False],
|
||||||
|
["equalizer", True, {
|
||||||
|
"db": paddle.randn([6])
|
||||||
|
}],
|
||||||
|
["clip_distortion", True, {
|
||||||
|
"clip_percentile": 0.5
|
||||||
|
}],
|
||||||
|
["quantization", True, {
|
||||||
|
"quantization_channels": 8
|
||||||
|
}],
|
||||||
|
["mulaw_quantization", True, {
|
||||||
|
"quantization_channels": 8
|
||||||
|
}],
|
||||||
|
["resample", True, {
|
||||||
|
"sample_rate": 16000
|
||||||
|
}],
|
||||||
|
["low_pass", True, {
|
||||||
|
"cutoffs": 1000
|
||||||
|
}],
|
||||||
|
["high_pass", True, {
|
||||||
|
"cutoffs": 1000
|
||||||
|
}],
|
||||||
|
["to_mono", True],
|
||||||
|
["zero_pad", True, {
|
||||||
|
"before": 10,
|
||||||
|
"after": 10
|
||||||
|
}],
|
||||||
|
["magnitude", True],
|
||||||
|
["phase", True],
|
||||||
|
["log_magnitude", True],
|
||||||
|
["loudness", False],
|
||||||
|
["stft", True],
|
||||||
|
["clone", True],
|
||||||
|
["mel_spectrogram", True],
|
||||||
|
["zero_pad_to", True, {
|
||||||
|
"length": 100000
|
||||||
|
}],
|
||||||
|
["truncate_samples", True, {
|
||||||
|
"length_in_samples": 1000
|
||||||
|
}],
|
||||||
|
["corrupt_phase", True, {
|
||||||
|
"scale": 0.5
|
||||||
|
}],
|
||||||
|
["shift_phase", True, {
|
||||||
|
"shift": 1
|
||||||
|
}],
|
||||||
|
["mask_low_magnitudes", True, {
|
||||||
|
"db_cutoff": 0
|
||||||
|
}],
|
||||||
|
["mask_frequencies", True, {
|
||||||
|
"fmin_hz": 100,
|
||||||
|
"fmax_hz": 1000
|
||||||
|
}],
|
||||||
|
["mask_timesteps", True, {
|
||||||
|
"tmin_s": 0.1,
|
||||||
|
"tmax_s": 0.5
|
||||||
|
}],
|
||||||
|
["__add__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__iadd__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__radd__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__sub__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__isub__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__mul__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__imul__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
["__rmul__", True, {
|
||||||
|
"other": AudioSignal(audio_path)
|
||||||
|
}],
|
||||||
|
]:
|
||||||
|
_test_audio_grad(*a)
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_grad():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
|
||||||
|
signal = AudioSignal(audio_path)
|
||||||
|
signal.audio_data.stop_gradient = False
|
||||||
|
|
||||||
|
assert signal.audio_data.grad is None
|
||||||
|
|
||||||
|
batch_size = 16
|
||||||
|
batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
|
||||||
|
|
||||||
|
batch.audio_data.sum().backward()
|
||||||
|
|
||||||
|
assert signal.audio_data.grad is not None
|
@ -0,0 +1,274 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_loudness.py)
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pyloudnorm
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
from audio.audiotools import datasets
|
||||||
|
from audio.audiotools import Meter
|
||||||
|
from audio.audiotools import transforms
|
||||||
|
|
||||||
|
ATOL = 1e-1
|
||||||
|
|
||||||
|
|
||||||
|
def test_loudness_against_pyln():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=5, duration=10)
|
||||||
|
signal_loudness = signal.loudness()
|
||||||
|
|
||||||
|
meter = pyloudnorm.Meter(
|
||||||
|
signal.sample_rate, filter_class="K-weighting", block_size=0.4)
|
||||||
|
py_loudness = meter.integrated_loudness(signal.numpy()[0].T)
|
||||||
|
assert np.allclose(signal_loudness, py_loudness)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loudness_short():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=0.25)
|
||||||
|
signal_loudness = signal.loudness()
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_loudness():
|
||||||
|
np.random.seed(0)
|
||||||
|
array = np.random.randn(16, 2, 16000)
|
||||||
|
array /= np.abs(array).max()
|
||||||
|
|
||||||
|
gains = np.random.rand(array.shape[0])[:, None, None]
|
||||||
|
array = array * gains
|
||||||
|
|
||||||
|
meter = pyloudnorm.Meter(16000)
|
||||||
|
py_loudness = [
|
||||||
|
meter.integrated_loudness(array[i].T) for i in range(array.shape[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
meter = Meter(16000)
|
||||||
|
meter.filter_class
|
||||||
|
at_loudness_iso = [
|
||||||
|
meter.integrated_loudness(array[i].T).item()
|
||||||
|
for i in range(array.shape[0])
|
||||||
|
]
|
||||||
|
|
||||||
|
assert np.allclose(py_loudness, at_loudness_iso, atol=1e-1)
|
||||||
|
|
||||||
|
signal = AudioSignal(array, sample_rate=16000)
|
||||||
|
at_loudness_batch = signal.loudness()
|
||||||
|
assert np.allclose(py_loudness, at_loudness_batch, atol=1e-1)
|
||||||
|
|
||||||
|
|
||||||
|
# Tests below are copied from pyloudnorm
|
||||||
|
def test_integrated_loudness():
|
||||||
|
data, rate = sf.read("./audio/loudness/sine_1000.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter(data)
|
||||||
|
|
||||||
|
targetLoudness = -3.0523438444331137
|
||||||
|
assert np.allclose(loudness, targetLoudness)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rel_gate_test():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_RelGateTest.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -10.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_abs_gate_test():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_AbsGateTest.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -69.5
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_24LKFS_25Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_25Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_24LKFS_100Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_100Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_24LKFS_500Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_500Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_24LKFS_1000Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_1000Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_24LKFS_2000Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_2000Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_24LKFS_10000Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_24LKFS_10000Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_23LKFS_25Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_25Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_23LKFS_100Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_100Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_23LKFS_500Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_500Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_23LKFS_1000Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_1000Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_23LKFS_2000Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_2000Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_23LKFS_10000Hz_2ch():
|
||||||
|
data, rate = sf.read("./audio/loudness/1770-2_Comp_23LKFS_10000Hz_2ch.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_18LKFS_frequency_sweep():
|
||||||
|
data, rate = sf.read(
|
||||||
|
"./audio/loudness/1770-2_Comp_18LKFS_FrequencySweep.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -18.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_stereo_vinL_R_23LKFS():
|
||||||
|
data, rate = sf.read(
|
||||||
|
"./audio/loudness/1770-2_Conf_Stereo_VinL+R-23LKFS.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_monovoice_music_24LKFS():
|
||||||
|
data, rate = sf.read(
|
||||||
|
"./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def conf_monovoice_music_24LKFS():
|
||||||
|
data, rate = sf.read(
|
||||||
|
"./audio/loudness/1770-2_Conf_Mono_Voice+Music-24LKFS.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -24.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conf_monovoice_music_23LKFS():
|
||||||
|
data, rate = sf.read(
|
||||||
|
"./audio/loudness/1770-2_Conf_Mono_Voice+Music-23LKFS.wav")
|
||||||
|
meter = Meter(rate)
|
||||||
|
loudness = meter.integrated_loudness(data)
|
||||||
|
|
||||||
|
targetLoudness = -23.0
|
||||||
|
assert np.allclose(loudness, targetLoudness, atol=ATOL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fir_accuracy():
|
||||||
|
transform = transforms.Compose(
|
||||||
|
transforms.ClippingDistortion(prob=0.5),
|
||||||
|
transforms.LowPass(prob=0.5),
|
||||||
|
transforms.HighPass(prob=0.5),
|
||||||
|
transforms.Equalizer(prob=0.5),
|
||||||
|
prob=0.5, )
|
||||||
|
loader = datasets.AudioLoader(sources=["./audio/spk.csv"])
|
||||||
|
dataset = datasets.AudioDataset(
|
||||||
|
loader,
|
||||||
|
44100,
|
||||||
|
10,
|
||||||
|
5.0,
|
||||||
|
transform=transform, )
|
||||||
|
|
||||||
|
for i in range(20):
|
||||||
|
item = dataset[i]
|
||||||
|
kwargs = item["transform_args"]
|
||||||
|
signal = item["signal"]
|
||||||
|
signal = transform(signal, **kwargs)
|
||||||
|
|
||||||
|
signal._loudness = None
|
||||||
|
iir_db = signal.clone().loudness()
|
||||||
|
fir_db = signal.clone().loudness(use_fir=True)
|
||||||
|
|
||||||
|
assert np.allclose(iir_db, fir_db, atol=1e-2)
|
@ -0,0 +1,157 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/core/test_util.py)
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from audio.audiotools import util
|
||||||
|
from audio.audiotools.core.audio_signal import AudioSignal
|
||||||
|
from paddlespeech.vector.training.seeding import seed_everything
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_random_state():
|
||||||
|
# seed is None
|
||||||
|
rng_type = type(np.random.RandomState(10))
|
||||||
|
rng = util.random_state(None)
|
||||||
|
assert type(rng) == rng_type
|
||||||
|
|
||||||
|
# seed is int
|
||||||
|
rng = util.random_state(10)
|
||||||
|
assert type(rng) == rng_type
|
||||||
|
|
||||||
|
# seed is RandomState
|
||||||
|
rng_test = np.random.RandomState(10)
|
||||||
|
rng = util.random_state(rng_test)
|
||||||
|
assert type(rng) == rng_type
|
||||||
|
|
||||||
|
# seed is none of the above : error
|
||||||
|
pytest.raises(ValueError, util.random_state, "random")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed():
|
||||||
|
seed_everything(0)
|
||||||
|
paddle_result_a = paddle.randn([1])
|
||||||
|
np_result_a = np.random.randn(1)
|
||||||
|
py_result_a = random.random()
|
||||||
|
|
||||||
|
seed_everything(0)
|
||||||
|
paddle_result_b = paddle.randn([1])
|
||||||
|
np_result_b = np.random.randn(1)
|
||||||
|
py_result_b = random.random()
|
||||||
|
|
||||||
|
assert paddle_result_a == paddle_result_b
|
||||||
|
assert np_result_a == np_result_b
|
||||||
|
assert py_result_a == py_result_b
|
||||||
|
|
||||||
|
|
||||||
|
def test_hz_to_bin():
|
||||||
|
hz = paddle.to_tensor(np.array([100, 200, 300]), dtype="float32")
|
||||||
|
sr = 1000
|
||||||
|
n_fft = 2048
|
||||||
|
|
||||||
|
bins = util.hz_to_bin(hz, n_fft, sr)
|
||||||
|
|
||||||
|
assert (((bins / n_fft) * sr) - hz).abs().max() < 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_audio():
|
||||||
|
wav_files = util.find_audio("tests/", ["wav"])
|
||||||
|
for a in wav_files:
|
||||||
|
assert "wav" in str(a)
|
||||||
|
|
||||||
|
audio_files = util.find_audio("tests/", ["flac"])
|
||||||
|
assert not audio_files
|
||||||
|
|
||||||
|
# Make sure it works with single audio files
|
||||||
|
audio_files = util.find_audio("./audio/spk//f10_script4_produced.wav")
|
||||||
|
|
||||||
|
# Make sure it works with globs
|
||||||
|
audio_files = util.find_audio("tests/**/*.wav")
|
||||||
|
assert len(audio_files) == len(wav_files)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chdir():
|
||||||
|
with tempfile.TemporaryDirectory(suffix="tmp") as d:
|
||||||
|
with util.chdir(d):
|
||||||
|
assert os.path.samefile(d, os.path.realpath("."))
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_batch():
|
||||||
|
batch = {"tensor": paddle.randn([1]), "non_tensor": np.random.randn(1)}
|
||||||
|
util.prepare_batch(batch)
|
||||||
|
|
||||||
|
batch = paddle.randn([1])
|
||||||
|
util.prepare_batch(batch)
|
||||||
|
|
||||||
|
batch = [paddle.randn([1]), np.random.randn(1)]
|
||||||
|
util.prepare_batch(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sample_dist():
|
||||||
|
state = util.random_state(0)
|
||||||
|
v1 = state.uniform(0.0, 1.0)
|
||||||
|
v2 = util.sample_from_dist(("uniform", 0.0, 1.0), 0)
|
||||||
|
assert v1 == v2
|
||||||
|
|
||||||
|
assert util.sample_from_dist(("const", 1.0)) == 1.0
|
||||||
|
|
||||||
|
dist_tuple = ("choice", [8, 16, 32])
|
||||||
|
assert util.sample_from_dist(dist_tuple) in [8, 16, 32]
|
||||||
|
|
||||||
|
|
||||||
|
def test_collate():
|
||||||
|
batch_size = 16
|
||||||
|
|
||||||
|
def _one_item():
|
||||||
|
return {
|
||||||
|
"signal": AudioSignal(paddle.randn([1, 1, 44100]), 44100),
|
||||||
|
"tensor": paddle.randn([1]),
|
||||||
|
"string": "Testing",
|
||||||
|
"dict": {
|
||||||
|
"nested_signal":
|
||||||
|
AudioSignal(paddle.randn([1, 1, 44100]), 44100),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
items = [_one_item() for _ in range(batch_size)]
|
||||||
|
collated = util.collate(items)
|
||||||
|
|
||||||
|
assert collated["signal"].batch_size == batch_size
|
||||||
|
assert collated["tensor"].shape[0] == batch_size
|
||||||
|
assert len(collated["string"]) == batch_size
|
||||||
|
assert collated["dict"]["nested_signal"].batch_size == batch_size
|
||||||
|
|
||||||
|
# test collate with splitting (evenly)
|
||||||
|
batch_size = 16
|
||||||
|
n_splits = 4
|
||||||
|
|
||||||
|
items = [_one_item() for _ in range(batch_size)]
|
||||||
|
collated = util.collate(items, n_splits=n_splits)
|
||||||
|
|
||||||
|
for x in collated:
|
||||||
|
assert x["signal"].batch_size == batch_size // n_splits
|
||||||
|
assert x["tensor"].shape[0] == batch_size // n_splits
|
||||||
|
assert len(x["string"]) == batch_size // n_splits
|
||||||
|
assert x["dict"]["nested_signal"].batch_size == batch_size // n_splits
|
||||||
|
|
||||||
|
# test collate with splitting (unevenly)
|
||||||
|
batch_size = 15
|
||||||
|
n_splits = 4
|
||||||
|
|
||||||
|
items = [_one_item() for _ in range(batch_size)]
|
||||||
|
collated = util.collate(items, n_splits=n_splits)
|
||||||
|
|
||||||
|
tlen = [4, 4, 4, 3]
|
||||||
|
|
||||||
|
for x, t in zip(collated, tlen):
|
||||||
|
assert x["signal"].batch_size == t
|
||||||
|
assert x["tensor"].shape[0] == t
|
||||||
|
assert len(x["string"]) == t
|
||||||
|
assert x["dict"]["nested_signal"].batch_size == t
|
@ -0,0 +1,208 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_datasets.py)
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from audio import audiotools
|
||||||
|
from audio.audiotools.data import transforms as tfm
|
||||||
|
|
||||||
|
|
||||||
|
def test_align_lists():
|
||||||
|
input_lists = [
|
||||||
|
["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"],
|
||||||
|
["a/2.wav", "c/2.wav"],
|
||||||
|
["c/3.wav"],
|
||||||
|
]
|
||||||
|
target_lists = [
|
||||||
|
["a/1.wav", "b/1.wav", "c/1.wav", "d/1.wav"],
|
||||||
|
["a/2.wav", "none", "c/2.wav", "none"],
|
||||||
|
["none", "none", "c/3.wav", "none"],
|
||||||
|
]
|
||||||
|
|
||||||
|
def _preprocess(lists):
|
||||||
|
output = []
|
||||||
|
for x in lists:
|
||||||
|
output.append([])
|
||||||
|
for y in x:
|
||||||
|
output[-1].append({"path": y})
|
||||||
|
return output
|
||||||
|
|
||||||
|
input_lists = _preprocess(input_lists)
|
||||||
|
target_lists = _preprocess(target_lists)
|
||||||
|
|
||||||
|
aligned_lists = audiotools.datasets.align_lists(input_lists)
|
||||||
|
assert target_lists == aligned_lists
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_dataset():
|
||||||
|
transform = tfm.Compose(
|
||||||
|
[
|
||||||
|
tfm.VolumeNorm(),
|
||||||
|
tfm.Silence(prob=0.5),
|
||||||
|
], )
|
||||||
|
loader = audiotools.data.datasets.AudioLoader(
|
||||||
|
sources=["./audio/spk.csv"],
|
||||||
|
transform=transform, )
|
||||||
|
dataset = audiotools.data.datasets.AudioDataset(
|
||||||
|
loader,
|
||||||
|
44100,
|
||||||
|
n_examples=100,
|
||||||
|
transform=transform, )
|
||||||
|
dataloader = paddle.io.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=dataset.collate, )
|
||||||
|
for batch in dataloader:
|
||||||
|
kwargs = batch["transform_args"]
|
||||||
|
signal = batch["signal"]
|
||||||
|
original = signal.clone()
|
||||||
|
|
||||||
|
signal = dataset.transform(signal, **kwargs)
|
||||||
|
original = dataset.transform(original, **kwargs)
|
||||||
|
|
||||||
|
mask = kwargs["Compose"]["1.Silence"]["mask"]
|
||||||
|
|
||||||
|
zeros_ = paddle.zeros_like(signal[mask].audio_data)
|
||||||
|
original_ = original[~mask].audio_data
|
||||||
|
|
||||||
|
assert paddle.allclose(signal[mask].audio_data, zeros_)
|
||||||
|
assert paddle.allclose(signal[~mask].audio_data, original_)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aligned_audio_dataset():
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
dataset_dir = Path(d)
|
||||||
|
audiotools.util.generate_chord_dataset(
|
||||||
|
max_voices=8, num_items=3, output_dir=dataset_dir)
|
||||||
|
loaders = [
|
||||||
|
audiotools.data.datasets.AudioLoader([dataset_dir / f"track_{i}"])
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
dataset = audiotools.data.datasets.AudioDataset(
|
||||||
|
loaders, 44100, n_examples=1000, aligned=True, shuffle_loaders=True)
|
||||||
|
dataloader = paddle.io.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=dataset.collate, )
|
||||||
|
|
||||||
|
# Make sure the voice tracks are aligned.
|
||||||
|
for batch in dataloader:
|
||||||
|
paths = []
|
||||||
|
for i in range(len(loaders)):
|
||||||
|
_paths = [p.split("/")[-1] for p in batch[i]["path"]]
|
||||||
|
paths.append(_paths)
|
||||||
|
paths = np.array(paths)
|
||||||
|
for i in range(paths.shape[1]):
|
||||||
|
col = paths[:, i]
|
||||||
|
col = col[col != "none"]
|
||||||
|
assert np.all(col == col[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_loader_without_replacement():
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
dataset_dir = Path(d)
|
||||||
|
num_items = 100
|
||||||
|
audiotools.util.generate_chord_dataset(
|
||||||
|
max_voices=1,
|
||||||
|
num_items=num_items,
|
||||||
|
output_dir=dataset_dir,
|
||||||
|
duration=0.01, )
|
||||||
|
loader = audiotools.data.datasets.AudioLoader(
|
||||||
|
[dataset_dir], shuffle=False)
|
||||||
|
dataset = audiotools.data.datasets.AudioDataset(loader, 44100)
|
||||||
|
|
||||||
|
for idx in range(num_items):
|
||||||
|
item = dataset[idx]
|
||||||
|
assert item["item_idx"] == idx
|
||||||
|
|
||||||
|
|
||||||
|
def test_loader_with_replacement():
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
dataset_dir = Path(d)
|
||||||
|
num_items = 100
|
||||||
|
audiotools.util.generate_chord_dataset(
|
||||||
|
max_voices=1,
|
||||||
|
num_items=num_items,
|
||||||
|
output_dir=dataset_dir,
|
||||||
|
duration=0.01, )
|
||||||
|
loader = audiotools.data.datasets.AudioLoader([dataset_dir])
|
||||||
|
dataset = audiotools.data.datasets.AudioDataset(
|
||||||
|
loader, 44100, without_replacement=False)
|
||||||
|
|
||||||
|
for idx in range(num_items):
|
||||||
|
item = dataset[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def test_loader_out_of_range():
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
dataset_dir = Path(d)
|
||||||
|
num_items = 100
|
||||||
|
audiotools.util.generate_chord_dataset(
|
||||||
|
max_voices=1,
|
||||||
|
num_items=num_items,
|
||||||
|
output_dir=dataset_dir,
|
||||||
|
duration=0.01, )
|
||||||
|
loader = audiotools.data.datasets.AudioLoader([dataset_dir])
|
||||||
|
|
||||||
|
item = loader(
|
||||||
|
sample_rate=44100,
|
||||||
|
duration=0.01,
|
||||||
|
state=audiotools.util.random_state(0),
|
||||||
|
source_idx=0,
|
||||||
|
item_idx=101, )
|
||||||
|
assert item["path"] == "none"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_pipeline():
|
||||||
|
transform = tfm.Compose([
|
||||||
|
tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]),
|
||||||
|
tfm.BackgroundNoise(sources=["./audio/noises.csv"]),
|
||||||
|
])
|
||||||
|
loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"])
|
||||||
|
dataset = audiotools.data.datasets.AudioDataset(
|
||||||
|
loader,
|
||||||
|
44100,
|
||||||
|
n_examples=10,
|
||||||
|
transform=transform, )
|
||||||
|
dataloader = paddle.io.DataLoader(
|
||||||
|
dataset, num_workers=0, batch_size=1, collate_fn=dataset.collate)
|
||||||
|
for batch in dataloader:
|
||||||
|
batch = audiotools.core.util.prepare_batch(batch, device="cpu")
|
||||||
|
kwargs = batch["transform_args"]
|
||||||
|
signal = batch["signal"]
|
||||||
|
batch = dataset.transform(signal, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class NumberDataset:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return 10
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return {"idx": idx}
|
||||||
|
|
||||||
|
|
||||||
|
def test_concat_dataset():
|
||||||
|
d1 = NumberDataset()
|
||||||
|
d2 = NumberDataset()
|
||||||
|
d3 = NumberDataset()
|
||||||
|
|
||||||
|
d = audiotools.datasets.ConcatDataset([d1, d2, d3])
|
||||||
|
x = d.collate([d[i] for i in range(len(d))])["idx"].tolist()
|
||||||
|
|
||||||
|
t = []
|
||||||
|
for i in range(10):
|
||||||
|
t += [i, i, i]
|
||||||
|
|
||||||
|
assert x == t
|
@ -0,0 +1,33 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_preprocess.py)
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
|
||||||
|
from audio.audiotools.core.util import find_audio
|
||||||
|
from audio.audiotools.core.util import read_sources
|
||||||
|
from audio.audiotools.data import preprocess
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_csv():
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".csv") as f:
|
||||||
|
preprocess.create_csv(
|
||||||
|
find_audio("././audio/spk", ext=["wav"]), f.name, loudness=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_csv_with_empty_rows():
|
||||||
|
audio_files = find_audio("././audio/spk", ext=["wav"])
|
||||||
|
audio_files.insert(0, "")
|
||||||
|
audio_files.insert(2, "")
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".csv") as f:
|
||||||
|
preprocess.create_csv(audio_files, f.name, loudness=True)
|
||||||
|
|
||||||
|
audio_files = read_sources([f.name], remove_empty=True)
|
||||||
|
assert len(audio_files[0]) == 1
|
||||||
|
audio_files = read_sources([f.name], remove_empty=False)
|
||||||
|
assert len(audio_files[0]) == 3
|
@ -0,0 +1,453 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/data/test_transforms.py)
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from audio import audiotools
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
from audio.audiotools import util
|
||||||
|
from audio.audiotools.data import transforms as tfm
|
||||||
|
from audio.audiotools.data.datasets import AudioDataset
|
||||||
|
from paddlespeech.vector.training.seeding import seed_everything
|
||||||
|
|
||||||
|
non_deterministic_transforms = ["TimeNoise", "FrequencyNoise"]
|
||||||
|
transforms_to_test = []
|
||||||
|
for x in dir(tfm):
|
||||||
|
if hasattr(getattr(tfm, x), "transform"):
|
||||||
|
if x not in [
|
||||||
|
"Compose",
|
||||||
|
"Choose",
|
||||||
|
"Repeat",
|
||||||
|
"RepeatUpTo",
|
||||||
|
# The above 4 transforms are currently excluded from testing at 1e-4 precision due to potential accuracy issues
|
||||||
|
"BackgroundNoise",
|
||||||
|
"Equalizer",
|
||||||
|
"FrequencyNoise",
|
||||||
|
"RoomImpulseResponse"
|
||||||
|
]:
|
||||||
|
transforms_to_test.append(x)
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_transform(transform_name, signal):
|
||||||
|
regression_data = Path(f"regression/transforms/{transform_name}.wav")
|
||||||
|
regression_data.parent.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
if regression_data.exists():
|
||||||
|
regression_signal = AudioSignal(regression_data)
|
||||||
|
|
||||||
|
assert paddle.allclose(
|
||||||
|
signal.audio_data, regression_signal.audio_data, atol=1e-4)
|
||||||
|
else:
|
||||||
|
signal.write(regression_data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("transform_name", transforms_to_test)
|
||||||
|
def test_transform(transform_name):
|
||||||
|
seed = 0
|
||||||
|
seed_everything(seed)
|
||||||
|
transform_cls = getattr(tfm, transform_name)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if transform_name == "BackgroundNoise":
|
||||||
|
kwargs["sources"] = ["./audio/noises.csv"]
|
||||||
|
if transform_name == "RoomImpulseResponse":
|
||||||
|
kwargs["sources"] = ["./audio/irs.csv"]
|
||||||
|
if transform_name == "CrossTalk":
|
||||||
|
kwargs["sources"] = ["./audio/spk.csv"]
|
||||||
|
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
signal.metadata["loudness"] = AudioSignal(
|
||||||
|
audio_path).ffmpeg_loudness().item()
|
||||||
|
transform = transform_cls(prob=1.0, **kwargs)
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
for k in kwargs[transform_name]:
|
||||||
|
assert k in transform.keys
|
||||||
|
|
||||||
|
output = transform(signal, **kwargs)
|
||||||
|
assert isinstance(output, AudioSignal)
|
||||||
|
|
||||||
|
_compare_transform(transform_name, output)
|
||||||
|
|
||||||
|
if transform_name in non_deterministic_transforms:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Test that if you make a batch of signals and call it,
|
||||||
|
# the first item in the batch is still the same as above.
|
||||||
|
batch_size = 4
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
signal_batch = AudioSignal.batch(
|
||||||
|
[signal.clone() for _ in range(batch_size)])
|
||||||
|
signal_batch.metadata["loudness"] = AudioSignal(
|
||||||
|
audio_path).ffmpeg_loudness().item()
|
||||||
|
|
||||||
|
states = [seed + idx for idx in list(range(batch_size))]
|
||||||
|
kwargs = transform.batch_instantiate(states, signal_batch)
|
||||||
|
batch_output = transform(signal_batch, **kwargs)
|
||||||
|
|
||||||
|
assert batch_output[0] == output
|
||||||
|
|
||||||
|
## Test that you can apply transform with the same args twice.
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
signal.metadata["loudness"] = AudioSignal(
|
||||||
|
audio_path).ffmpeg_loudness().item()
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output_a = transform(signal.clone(), **kwargs)
|
||||||
|
output_b = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
assert output_a == output_b
|
||||||
|
|
||||||
|
|
||||||
|
def test_compose_basic():
|
||||||
|
seed = 0
|
||||||
|
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
transform = tfm.Compose(
|
||||||
|
[
|
||||||
|
tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]),
|
||||||
|
tfm.BackgroundNoise(sources=["./audio/noises.csv"]),
|
||||||
|
], )
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output = transform(signal, **kwargs)
|
||||||
|
|
||||||
|
# Due to precision issues with RoomImpulseResponse and BackgroundNoise used in the Compose test,
|
||||||
|
# we only perform logical testing for Compose and skip precision testing of the final output
|
||||||
|
# _compare_transform("Compose", output)
|
||||||
|
|
||||||
|
assert isinstance(transform[0], tfm.RoomImpulseResponse)
|
||||||
|
assert isinstance(transform[1], tfm.BackgroundNoise)
|
||||||
|
assert len(transform) == 2
|
||||||
|
|
||||||
|
# Make sure __iter__ works
|
||||||
|
for _tfm in transform:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MulTransform(tfm.BaseTransform):
|
||||||
|
def __init__(self, num, name=None):
|
||||||
|
self.num = num
|
||||||
|
super().__init__(name=name, keys=["num"])
|
||||||
|
|
||||||
|
def _transform(self, signal, num):
|
||||||
|
|
||||||
|
if not num.dim():
|
||||||
|
num = num.unsqueeze(axis=0)
|
||||||
|
|
||||||
|
signal.audio_data = signal.audio_data * num[:, None, None]
|
||||||
|
return signal
|
||||||
|
|
||||||
|
def _instantiate(self, state):
|
||||||
|
return {"num": self.num}
|
||||||
|
|
||||||
|
|
||||||
|
def test_compose_with_duplicate_transforms():
|
||||||
|
muls = [0.5, 0.25, 0.125]
|
||||||
|
transform = tfm.Compose([MulTransform(x) for x in muls])
|
||||||
|
full_mul = np.prod(muls)
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(0)
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
expected_output = signal.audio_data * full_mul
|
||||||
|
|
||||||
|
assert paddle.allclose(output.audio_data, expected_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_compose():
|
||||||
|
muls = [0.5, 0.25, 0.125]
|
||||||
|
transform = tfm.Compose([
|
||||||
|
MulTransform(muls[0]),
|
||||||
|
tfm.Compose(
|
||||||
|
[MulTransform(muls[1]), tfm.Compose([MulTransform(muls[2])])]),
|
||||||
|
])
|
||||||
|
full_mul = np.prod(muls)
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(0)
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
expected_output = signal.audio_data * full_mul
|
||||||
|
|
||||||
|
assert paddle.allclose(output.audio_data, expected_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_compose_filtering():
|
||||||
|
muls = [0.5, 0.25, 0.125]
|
||||||
|
transform = tfm.Compose([MulTransform(x, name=str(x)) for x in muls])
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(0)
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
for s in range(len(muls)):
|
||||||
|
for _ in range(10):
|
||||||
|
_muls = np.random.choice(muls, size=s, replace=False).tolist()
|
||||||
|
full_mul = np.prod(_muls)
|
||||||
|
with transform.filter(*[str(x) for x in _muls]):
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
expected_output = signal.audio_data * full_mul
|
||||||
|
assert paddle.allclose(output.audio_data, expected_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_compose():
|
||||||
|
muls = [0.5, 0.25, 0.125]
|
||||||
|
transform = tfm.Compose([
|
||||||
|
tfm.Compose([MulTransform(muls[0])]),
|
||||||
|
tfm.Compose([MulTransform(muls[1]), MulTransform(muls[2])]),
|
||||||
|
])
|
||||||
|
full_mul = np.prod(muls)
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(0)
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
expected_output = signal.audio_data * full_mul
|
||||||
|
|
||||||
|
assert paddle.allclose(output.audio_data, expected_output)
|
||||||
|
|
||||||
|
|
||||||
|
def test_choose_basic():
|
||||||
|
seed = 0
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
transform = tfm.Choose([
|
||||||
|
tfm.RoomImpulseResponse(sources=["./audio/irs.csv"]),
|
||||||
|
tfm.BackgroundNoise(sources=["./audio/noises.csv"]),
|
||||||
|
])
|
||||||
|
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
# Due to precision issues with RoomImpulseResponse and BackgroundNoise used in the Choose test,
|
||||||
|
# we only perform logical testing for Choose and skip precision testing of the final output
|
||||||
|
# _compare_transform("Choose", output)
|
||||||
|
|
||||||
|
transform = tfm.Choose([
|
||||||
|
MulTransform(0.0),
|
||||||
|
MulTransform(2.0),
|
||||||
|
])
|
||||||
|
targets = [signal.clone() * 0.0, signal.clone() * 2.0]
|
||||||
|
|
||||||
|
for seed in range(10):
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
assert any([output == target for target in targets])
|
||||||
|
|
||||||
|
# Test that if you make a batch of signals and call it,
|
||||||
|
# the first item in the batch is still the same as above.
|
||||||
|
batch_size = 4
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
signal_batch = AudioSignal.batch(
|
||||||
|
[signal.clone() for _ in range(batch_size)])
|
||||||
|
|
||||||
|
states = [seed + idx for idx in list(range(batch_size))]
|
||||||
|
kwargs = transform.batch_instantiate(states, signal_batch)
|
||||||
|
batch_output = transform(signal_batch, **kwargs)
|
||||||
|
|
||||||
|
for nb in range(batch_size):
|
||||||
|
assert batch_output[nb] in targets
|
||||||
|
|
||||||
|
|
||||||
|
def test_choose_weighted():
|
||||||
|
seed = 0
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
transform = tfm.Choose(
|
||||||
|
[
|
||||||
|
MulTransform(0.0),
|
||||||
|
MulTransform(2.0),
|
||||||
|
],
|
||||||
|
weights=[0.0, 1.0], )
|
||||||
|
|
||||||
|
# Test that if you make a batch of signals and call it,
|
||||||
|
# the first item in the batch is still the same as above.
|
||||||
|
batch_size = 4
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
signal_batch = AudioSignal.batch(
|
||||||
|
[signal.clone() for _ in range(batch_size)])
|
||||||
|
|
||||||
|
targets = [signal.clone() * 0.0, signal.clone() * 2.0]
|
||||||
|
|
||||||
|
states = [seed + idx for idx in list(range(batch_size))]
|
||||||
|
kwargs = transform.batch_instantiate(states, signal_batch)
|
||||||
|
batch_output = transform(signal_batch, **kwargs)
|
||||||
|
|
||||||
|
for nb in range(batch_size):
|
||||||
|
assert batch_output[nb] == targets[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_choose_with_compose():
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
transform = tfm.Choose([
|
||||||
|
tfm.Compose([MulTransform(0.0)]),
|
||||||
|
tfm.Compose([MulTransform(2.0)]),
|
||||||
|
])
|
||||||
|
|
||||||
|
targets = [signal.clone() * 0.0, signal.clone() * 2.0]
|
||||||
|
|
||||||
|
for seed in range(10):
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output = transform(signal, **kwargs)
|
||||||
|
|
||||||
|
assert output in targets
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeat():
|
||||||
|
seed = 0
|
||||||
|
audio_path = "./audio/spk/f10_script4_produced.wav"
|
||||||
|
signal = AudioSignal(audio_path, offset=10, duration=2)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
kwargs["transform"] = tfm.Compose(
|
||||||
|
tfm.FrequencyMask(),
|
||||||
|
tfm.TimeMask(), )
|
||||||
|
kwargs["n_repeat"] = 5
|
||||||
|
|
||||||
|
transform = tfm.Repeat(**kwargs)
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
_compare_transform("Repeat", output)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
kwargs["transform"] = tfm.Compose(
|
||||||
|
tfm.FrequencyMask(),
|
||||||
|
tfm.TimeMask(), )
|
||||||
|
kwargs["max_repeat"] = 10
|
||||||
|
|
||||||
|
transform = tfm.RepeatUpTo(**kwargs)
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
_compare_transform("RepeatUpTo", output)
|
||||||
|
|
||||||
|
# Make sure repeat does what it says
|
||||||
|
transform = tfm.Repeat(MulTransform(0.5), n_repeat=3)
|
||||||
|
kwargs = transform.instantiate(seed, signal)
|
||||||
|
signal = AudioSignal(paddle.randn([1, 1, 100]).clip(1e-5), 44100)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
|
||||||
|
scale = (output.audio_data / signal.audio_data).mean()
|
||||||
|
assert scale == (0.5**3)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyData(paddle.io.Dataset):
|
||||||
|
def __init__(self, audio_path):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.audio_path = audio_path
|
||||||
|
self.length = 100
|
||||||
|
self.transform = tfm.Silence(prob=0.5)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
state = util.random_state(idx)
|
||||||
|
signal = AudioSignal.salient_excerpt(
|
||||||
|
self.audio_path, state=state, duration=1.0).resample(44100)
|
||||||
|
|
||||||
|
item = self.transform.instantiate(state, signal=signal)
|
||||||
|
item["signal"] = signal
|
||||||
|
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.length
|
||||||
|
|
||||||
|
|
||||||
|
def test_masking():
|
||||||
|
dataset = DummyData("./audio/spk/f10_script4_produced.wav")
|
||||||
|
dataloader = paddle.io.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=16,
|
||||||
|
num_workers=0,
|
||||||
|
collate_fn=util.collate, )
|
||||||
|
for batch in dataloader:
|
||||||
|
signal = batch.pop("signal")
|
||||||
|
original = signal.clone()
|
||||||
|
|
||||||
|
signal = dataset.transform(signal, **batch)
|
||||||
|
original = dataset.transform(original, **batch)
|
||||||
|
mask = batch["Silence"]["mask"]
|
||||||
|
|
||||||
|
zeros_ = paddle.zeros_like(signal[mask].audio_data)
|
||||||
|
original_ = original[~mask].audio_data
|
||||||
|
|
||||||
|
assert paddle.allclose(signal[mask].audio_data, zeros_)
|
||||||
|
assert paddle.allclose(original[~mask].audio_data, original_)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_masking():
|
||||||
|
transform = tfm.Compose(
|
||||||
|
[
|
||||||
|
tfm.VolumeNorm(prob=0.5),
|
||||||
|
tfm.Silence(prob=0.9),
|
||||||
|
],
|
||||||
|
prob=0.9, )
|
||||||
|
|
||||||
|
loader = audiotools.data.datasets.AudioLoader(sources=["./audio/spk.csv"])
|
||||||
|
dataset = audiotools.data.datasets.AudioDataset(
|
||||||
|
loader,
|
||||||
|
44100,
|
||||||
|
n_examples=100,
|
||||||
|
transform=transform, )
|
||||||
|
dataloader = paddle.io.DataLoader(
|
||||||
|
dataset, num_workers=0, batch_size=10, collate_fn=dataset.collate)
|
||||||
|
|
||||||
|
for batch in dataloader:
|
||||||
|
batch = util.prepare_batch(batch, device="cpu")
|
||||||
|
signal = batch["signal"]
|
||||||
|
kwargs = batch["transform_args"]
|
||||||
|
with paddle.no_grad():
|
||||||
|
output = dataset.transform(signal, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_smoothing_edge_case():
|
||||||
|
transform = tfm.Smoothing()
|
||||||
|
zeros = paddle.zeros([1, 1, 44100])
|
||||||
|
signal = AudioSignal(zeros, 44100)
|
||||||
|
kwargs = transform.instantiate(0, signal)
|
||||||
|
output = transform(signal, **kwargs)
|
||||||
|
|
||||||
|
assert paddle.allclose(output.audio_data, zeros)
|
||||||
|
|
||||||
|
|
||||||
|
def test_global_volume_norm():
|
||||||
|
signal = AudioSignal.wave(440, 1, 44100, 1)
|
||||||
|
|
||||||
|
# signal with -inf loudness should be unchanged
|
||||||
|
signal.metadata["loudness"] = float("-inf")
|
||||||
|
|
||||||
|
transform = tfm.GlobalVolumeNorm(db=("const", -100))
|
||||||
|
kwargs = transform.instantiate(0, signal)
|
||||||
|
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
assert paddle.allclose(output.samples, signal.samples)
|
||||||
|
|
||||||
|
# signal without a loudness key should be unchanged
|
||||||
|
signal.metadata.pop("loudness")
|
||||||
|
kwargs = transform.instantiate(0, signal)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
assert paddle.allclose(output.samples, signal.samples)
|
||||||
|
|
||||||
|
# signal with the actual loudness should be normalized
|
||||||
|
signal.metadata["loudness"] = signal.ffmpeg_loudness()
|
||||||
|
kwargs = transform.instantiate(0, signal)
|
||||||
|
output = transform(signal.clone(), **kwargs)
|
||||||
|
assert not paddle.allclose(output.samples, signal.samples)
|
@ -0,0 +1,110 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_decorators.py)
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from visualdl import LogWriter
|
||||||
|
|
||||||
|
from audio.audiotools import util
|
||||||
|
from audio.audiotools.ml.decorators import timer
|
||||||
|
from audio.audiotools.ml.decorators import Tracker
|
||||||
|
from audio.audiotools.ml.decorators import when
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_decorators():
|
||||||
|
rank = 0
|
||||||
|
max_iters = 100
|
||||||
|
|
||||||
|
writer = LogWriter("/tmp/logs")
|
||||||
|
tracker = Tracker(writer, log_file="/tmp/log.txt")
|
||||||
|
|
||||||
|
train_data = range(100)
|
||||||
|
val_data = range(100)
|
||||||
|
|
||||||
|
@tracker.log("train", "value", history=False)
|
||||||
|
@tracker.track("train", max_iters, tracker.step)
|
||||||
|
@timer()
|
||||||
|
def train_loop():
|
||||||
|
i = tracker.step
|
||||||
|
time.sleep(0.01)
|
||||||
|
return {
|
||||||
|
"loss":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"mel":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"stft":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"waveform":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"not_scalar":
|
||||||
|
paddle.arange(start=0, end=10, step=1, dtype="int64"),
|
||||||
|
}
|
||||||
|
|
||||||
|
@tracker.track("val", len(val_data))
|
||||||
|
@timer()
|
||||||
|
def val_loop():
|
||||||
|
i = tracker.step
|
||||||
|
time.sleep(0.01)
|
||||||
|
return {
|
||||||
|
"loss":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"mel":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"stft":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"waveform":
|
||||||
|
util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")),
|
||||||
|
"not_scalar":
|
||||||
|
paddle.arange(10, dtype="int64"),
|
||||||
|
"string":
|
||||||
|
"string",
|
||||||
|
}
|
||||||
|
|
||||||
|
@when(lambda: tracker.step % 1000 == 0 and rank == 0)
|
||||||
|
@paddle.no_grad()
|
||||||
|
def save_samples():
|
||||||
|
tracker.print("Saving samples to TensorBoard.")
|
||||||
|
|
||||||
|
@when(lambda: tracker.step % 100 == 0 and rank == 0)
|
||||||
|
def checkpoint():
|
||||||
|
save_samples()
|
||||||
|
if tracker.is_best("val", "mel"):
|
||||||
|
tracker.print("Best model so far.")
|
||||||
|
tracker.print("Saving to /runs/exp1")
|
||||||
|
tracker.done("val", f"Iteration {tracker.step}")
|
||||||
|
|
||||||
|
@when(lambda: tracker.step % 100 == 0)
|
||||||
|
@tracker.log("val", "mean")
|
||||||
|
@paddle.no_grad()
|
||||||
|
def validate():
|
||||||
|
for _ in range(len(val_data)):
|
||||||
|
output = val_loop()
|
||||||
|
return output
|
||||||
|
|
||||||
|
with tracker.live:
|
||||||
|
for tracker.step in range(max_iters):
|
||||||
|
validate()
|
||||||
|
checkpoint()
|
||||||
|
train_loop()
|
||||||
|
|
||||||
|
state_dict = tracker.state_dict()
|
||||||
|
tracker.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
# If train loop returned not a dict
|
||||||
|
@tracker.track("train", max_iters, tracker.step)
|
||||||
|
def train_loop_2():
|
||||||
|
i = tracker.step
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
with tracker.live:
|
||||||
|
for tracker.step in range(max_iters):
|
||||||
|
validate()
|
||||||
|
checkpoint()
|
||||||
|
train_loop_2()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_all_decorators()
|
@ -0,0 +1,89 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/ml/test_model.py)
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
from audio.audiotools import ml
|
||||||
|
from audio.audiotools import util
|
||||||
|
from paddlespeech.vector.training.seeding import seed_everything
|
||||||
|
SEED = 0
|
||||||
|
|
||||||
|
|
||||||
|
def seed_and_run(model, *args, **kwargs):
|
||||||
|
seed_everything(SEED)
|
||||||
|
return model(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(ml.BaseModel):
|
||||||
|
def __init__(self, arg1: float=1.0):
|
||||||
|
super().__init__()
|
||||||
|
self.arg1 = arg1
|
||||||
|
self.linear = nn.Linear(1, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(x)
|
||||||
|
|
||||||
|
|
||||||
|
class OtherModel(ml.BaseModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(1, 1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(x)
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_model():
|
||||||
|
# Save and load
|
||||||
|
# ml.BaseModel.EXTERN += ["test_model"]
|
||||||
|
|
||||||
|
x = paddle.randn([10, 1])
|
||||||
|
model1 = Model()
|
||||||
|
|
||||||
|
# assert str(model1.device) == 'Place(cpu)'
|
||||||
|
|
||||||
|
out1 = seed_and_run(model1, x)
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".pdparams") as f:
|
||||||
|
model1.save(
|
||||||
|
f.name, )
|
||||||
|
model2 = Model.load(f.name)
|
||||||
|
out2 = seed_and_run(model2, x)
|
||||||
|
assert paddle.allclose(out1, out2)
|
||||||
|
|
||||||
|
# test re-export
|
||||||
|
model2.save(f.name)
|
||||||
|
model3 = Model.load(f.name)
|
||||||
|
out3 = seed_and_run(model3, x)
|
||||||
|
assert paddle.allclose(out1, out3)
|
||||||
|
|
||||||
|
# make sure legacy/save load works
|
||||||
|
model1.save(f.name, package=False)
|
||||||
|
model2 = Model.load(f.name)
|
||||||
|
out2 = seed_and_run(model2, x)
|
||||||
|
assert paddle.allclose(out1, out2)
|
||||||
|
|
||||||
|
# make sure new way -> legacy save -> legacy load works
|
||||||
|
model1.save(f.name, package=False)
|
||||||
|
model2 = Model.load(f.name)
|
||||||
|
model2.save(f.name, package=False)
|
||||||
|
model3 = Model.load(f.name)
|
||||||
|
out3 = seed_and_run(model3, x)
|
||||||
|
|
||||||
|
# save/load without package, but with model2 being a model
|
||||||
|
# without an argument of arg1 to its instantiation.
|
||||||
|
model1.save(f.name, package=False)
|
||||||
|
model2 = OtherModel.load(f.name)
|
||||||
|
out2 = seed_and_run(model2, x)
|
||||||
|
assert paddle.allclose(out1, out2)
|
||||||
|
|
||||||
|
assert paddle.allclose(out1, out3)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
model1.save_to_folder(d, {"data": 1.0})
|
||||||
|
Model.load_from_folder(d)
|
@ -0,0 +1,7 @@
|
|||||||
|
python -m pip install -r ../../audiotools/requirements.txt
|
||||||
|
export PYTHONPATH=$PYTHONPATH:$(realpath ../../..) # this is root path of `PaddleSpeech`
|
||||||
|
wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/audio.tar.gz
|
||||||
|
wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/regression.tar.gz
|
||||||
|
tar -zxvf audio.tar.gz
|
||||||
|
tar -zxvf regression.tar.gz
|
||||||
|
python -m pytest
|
@ -0,0 +1,30 @@
|
|||||||
|
# MIT License, Copyright (c) 2023-Present, Descript.
|
||||||
|
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Modified from audiotools(https://github.com/descriptinc/audiotools/blob/master/tests/test_post.py)
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from audio.audiotools import AudioSignal
|
||||||
|
from audio.audiotools import post
|
||||||
|
from audio.audiotools import transforms
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_table():
|
||||||
|
tfm = transforms.LowPass()
|
||||||
|
|
||||||
|
audio_dict = {}
|
||||||
|
|
||||||
|
audio_dict["inputs"] = [
|
||||||
|
AudioSignal.excerpt("./audio/spk/f10_script4_produced.wav", duration=5)
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
audio_dict["outputs"] = []
|
||||||
|
for i in range(3):
|
||||||
|
x = audio_dict["inputs"][i]
|
||||||
|
|
||||||
|
kwargs = tfm.instantiate()
|
||||||
|
output = tfm(x.clone(), **kwargs)
|
||||||
|
audio_dict["outputs"].append(output)
|
||||||
|
|
||||||
|
post.audio_table(audio_dict)
|
Loading…
Reference in new issue