From 643f1c6071e72e9001ba51872e5843f59fb819fd Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Tue, 10 Dec 2024 09:57:53 +0000 Subject: [PATCH] add readme && __all__ --- audio/audiotools/README.md | 24 +- audio/audiotools/core/__init__.py | 2 +- audio/audiotools/core/_julius.py | 199 +++++++++- audio/audiotools/core/audio_signal.py | 4 +- audio/audiotools/core/effects.py | 128 +++--- audio/audiotools/core/resample.py | 231 ----------- .../tests/audiotools/core/test_effects✅.py | 363 ++++++++++++++++++ audio/tests/audiotools/core/test_grad✅.py | 168 ++++++++ 8 files changed, 807 insertions(+), 312 deletions(-) delete mode 100644 audio/audiotools/core/resample.py create mode 100644 audio/tests/audiotools/core/test_effects✅.py create mode 100644 audio/tests/audiotools/core/test_grad✅.py diff --git a/audio/audiotools/README.md b/audio/audiotools/README.md index a8c47efe8..b28776ebf 100644 --- a/audio/audiotools/README.md +++ b/audio/audiotools/README.md @@ -1,23 +1,13 @@ -# PaddleAudio +Audiotools is a comprehensive toolkit designed for audio processing and analysis, providing robust solutions for audio signal processing, data management, model training, and evaluation. -安装方式: pip install paddleaudio +### Directory Structure -目前支持的平台:Linux, Mac, Windows +- **core directory**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals. -## Environment +- **data directory**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data. -## Build wheel -cmd: python setup.py bdist_wheel +- **metrics directory**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms. -Linux test build whl environment: -* os - Ubuntu 16.04.7 LTS -* gcc/g++ - 8.2.0 -* cmake - 3.18.0 (need install) +- **ml directory**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio. -MAC:test build whl environment: -* os -* gcc/g++ 12.2.0 -* cpu Intel Xeon E5 x86_64 - -Windows: -not support paddleaudio C++ extension lib (sox io, kaldi native fbank) +This project aims to provide developers and researchers with an efficient and flexible framework to foster innovation and exploration across various domains of audio technology. diff --git a/audio/audiotools/core/__init__.py b/audio/audiotools/core/__init__.py index a4038c4ed..0e4d916d0 100644 --- a/audio/audiotools/core/__init__.py +++ b/audio/audiotools/core/__init__.py @@ -7,9 +7,9 @@ from ._julius import lowpass_filter from ._julius import LowPassFilter from ._julius import LowPassFilters from ._julius import pure_tone +from ._julius import resample_frac from ._julius import split_bands from ._julius import SplitBands from .audio_signal import AudioSignal from .audio_signal import STFTParams from .loudness import Meter -from .resample import resample_frac diff --git a/audio/audiotools/core/_julius.py b/audio/audiotools/core/_julius.py index 36ac88529..fc137c569 100644 --- a/audio/audiotools/core/_julius.py +++ b/audio/audiotools/core/_julius.py @@ -7,6 +7,7 @@ This module implements efficient FFT based convolutions for such cases. A typica application is for evaluating FIR filters with a long receptive field, typically evaluated with a stride of 1. """ +import inspect import math import typing from typing import Optional @@ -16,7 +17,203 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from .resample import sinc +__all__ = [ + 'fft_conv1d', 'FFTConv1d', 'highpass_filter', 'highpass_filters', + 'lowpass_filter', 'LowPassFilter', 'LowPassFilters', 'pure_tone', + 'resample_frac', 'split_bands', 'SplitBands' +] + + +def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}): + """ + Return a simple representation string for `obj`. + If `attrs` is not None, it should be a list of attributes to include. + """ + params = inspect.signature(obj.__class__).parameters + attrs_repr = [] + if attrs is None: + attrs = list(params.keys()) + for attr in attrs: + display = False + if attr in overrides: + value = overrides[attr] + elif hasattr(obj, attr): + value = getattr(obj, attr) + else: + continue + if attr in params: + param = params[attr] + if param.default is inspect._empty or value != param.default: # type: ignore + display = True + else: + display = True + + if display: + attrs_repr.append(f"{attr}={value}") + return f"{obj.__class__.__name__}({','.join(attrs_repr)})" + + +def sinc(x: paddle.Tensor): + """ + Implementation of sinc, i.e. sin(x) / x + + __Warning__: the input is not multiplied by `pi`! + """ + return paddle.where( + x == 0, + paddle.to_tensor(1.0, dtype=x.dtype, place=x.place), + paddle.sin(x) / x, ) + + +class ResampleFrac(paddle.nn.Layer): + """ + Resampling from the sample rate `old_sr` to `new_sr`. + """ + + def __init__(self, + old_sr: int, + new_sr: int, + zeros: int=24, + rolloff: float=0.945): + """ + Args: + old_sr (int): sample rate of the input signal x. + new_sr (int): sample rate of the output. + zeros (int): number of zero crossing to keep in the sinc filter. + rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`, + to ensure sufficient margin due to the imperfection of the FIR filter used. + Lowering this value will reduce anti-aliasing, but will reduce some of the + highest frequencies. + + Shape: + + - Input: `[*, T]` + - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)` + + + .. caution:: + After dividing `old_sr` and `new_sr` by their GCD, both should be small + for this implementation to be fast. + + >>> import paddle + >>> resample = ResampleFrac(4, 5) + >>> x = paddle.randn([1000]) + >>> print(len(resample(x))) + 1250 + """ + super(ResampleFrac, self).__init__() + if not isinstance(old_sr, int) or not isinstance(new_sr, int): + raise ValueError("old_sr and new_sr should be integers") + gcd = math.gcd(old_sr, new_sr) + self.old_sr = old_sr // gcd + self.new_sr = new_sr // gcd + self.zeros = zeros + self.rolloff = rolloff + + self._init_kernels() + + def _init_kernels(self): + if self.old_sr == self.new_sr: + return + + kernels = [] + sr = min(self.new_sr, self.old_sr) + sr *= self.rolloff + + self._width = math.ceil(self.zeros * self.old_sr / sr) + idx = paddle.arange( + -self._width, self._width + self.old_sr, dtype="float32") + for i in range(self.new_sr): + t = (-i / self.new_sr + idx / self.old_sr) * sr + t = paddle.clip(t, -self.zeros, self.zeros) + t *= math.pi + window = paddle.cos(t / self.zeros / 2)**2 + kernel = sinc(t) * window + # Renormalize kernel to ensure a constant signal is preserved. + kernel = kernel / kernel.sum() + kernels.append(kernel) + + _kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1]) + self.kernel = self.create_parameter( + shape=_kernel.shape, + dtype=_kernel.dtype, ) + self.kernel.set_value(_kernel) + + def forward( + self, + x: paddle.Tensor, + output_length: Optional[int]=None, + full: bool=False, ): + """ + Resample x. + Args: + x (Tensor): signal to resample, time should be the last dimension + output_length (None or int): This can be set to the desired output length + (last dimension). Allowed values are between 0 and + ceil(length * new_sr / old_sr). When None (default) is specified, the + floored output length will be used. In order to select the largest possible + size, use the `full` argument. + full (bool): return the longest possible output from the input. This can be useful + if you chain resampling operations, and want to give the `output_length` only + for the last one, while passing `full=True` to all the other ones. + """ + if self.old_sr == self.new_sr: + return x + shape = x.shape + _dtype = x.dtype + length = x.shape[-1] + x = x.reshape([-1, length]) + x = F.pad( + x.unsqueeze(1), + [self._width, self._width + self.old_sr], + mode="replicate", + data_format="NCL", ).astype(self.kernel.dtype) + ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL") + y = ys.transpose( + [0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype) + + float_output_length = paddle.to_tensor( + self.new_sr * length / self.old_sr, dtype="float32") + max_output_length = paddle.ceil(float_output_length).astype("int64") + default_output_length = paddle.floor(float_output_length).astype( + "int64") + + if output_length is None: + applied_output_length = (max_output_length + if full else default_output_length) + elif output_length < 0 or output_length > max_output_length: + raise ValueError( + f"output_length must be between 0 and {max_output_length.numpy()}" + ) + else: + applied_output_length = paddle.to_tensor( + output_length, dtype="int64") + if full: + raise ValueError( + "You cannot pass both full=True and output_length") + return y[..., :applied_output_length] + + def __repr__(self): + return simple_repr(self) + + +def resample_frac( + x: paddle.Tensor, + old_sr: int, + new_sr: int, + zeros: int=24, + rolloff: float=0.945, + output_length: Optional[int]=None, + full: bool=False, ): + """ + Functional version of `ResampleFrac`, refer to its documentation for more information. + + ..warning:: + If you call repeatidly this functions with the same sample rates, then the + resampling kernel will be recomputed everytime. For best performance, you should use + and cache an instance of `ResampleFrac`. + """ + return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full) def pad_to(tensor: paddle.Tensor, diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 717959050..85fbec1c0 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -16,18 +16,20 @@ import paddle import soundfile from . import util +from ._julius import resample_frac from .dsp import DSPMixin from .effects import EffectMixin from .effects import ImpulseResponseMixin from .ffmpeg import FFMPEGMixin from .loudness import LoudnessMixin -from .resample import resample_frac # from .display import DisplayMixin # from .playback import PlayMixin # from .whisper import WhisperMixin +__all__ = ['STFTParams', 'AudioSignal'] + def create_dct(n_mfcc: int, n_mels: int, norm: Optional[str]) -> paddle.Tensor: r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), diff --git a/audio/audiotools/core/effects.py b/audio/audiotools/core/effects.py index c658f75a7..ff08ab31b 100644 --- a/audio/audiotools/core/effects.py +++ b/audio/audiotools/core/effects.py @@ -207,7 +207,8 @@ class EffectMixin: """ peak = self.audio_data.abs().max(axis=-1, keepdim=True) peak_gain = paddle.ones_like(peak) - peak_gain[peak > _max] = _max / peak[peak > _max] + # peak_gain[peak > _max] = _max / peak[peak > _max] + peak_gain = paddle.where(peak > _max, _max / peak, peak_gain) self.audio_data = self.audio_data * peak_gain return self @@ -476,70 +477,72 @@ class EffectMixin: return self - # def quantization( - # self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int] - # ): - # """Applies quantization to the input waveform. + def quantization(self, + quantization_channels: typing.Union[paddle.Tensor, + np.ndarray, int]): + """Applies quantization to the input waveform. - # Parameters - # ---------- - # quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] - # Number of evenly spaced quantization channels to quantize - # to. - - # Returns - # ------- - # AudioSignal - # Quantized AudioSignal. - # """ - # quantization_channels = util.ensure_tensor(quantization_channels, ndim=3) - - # x = self.audio_data - # x = (x + 1) / 2 - # x = x * quantization_channels - # x = x.floor() - # x = x / quantization_channels - # x = 2 * x - 1 + Parameters + ---------- + quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] + Number of evenly spaced quantization channels to quantize + to. - # residual = (self.audio_data - x).detach() - # self.audio_data = self.audio_data - residual - # return self + Returns + ------- + AudioSignal + Quantized AudioSignal. + """ + quantization_channels = util.ensure_tensor( + quantization_channels, ndim=3) + + x = self.audio_data + x = (x + 1) / 2 + x = x * quantization_channels + x = x.floor() + x = x / quantization_channels + x = 2 * x - 1 + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self - # def mulaw_quantization( - # self, quantization_channels: typing.Union[paddle.Tensor, np.ndarray, int] - # ): - # """Applies mu-law quantization to the input waveform. + def mulaw_quantization(self, + quantization_channels: typing.Union[ + paddle.Tensor, np.ndarray, int]): + """Applies mu-law quantization to the input waveform. - # Parameters - # ---------- - # quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] - # Number of mu-law spaced quantization channels to quantize - # to. + Parameters + ---------- + quantization_channels : typing.Union[paddle.Tensor, np.ndarray, int] + Number of mu-law spaced quantization channels to quantize + to. - # Returns - # ------- - # AudioSignal - # Quantized AudioSignal. - # """ - # mu = quantization_channels - 1.0 - # mu = util.ensure_tensor(mu, ndim=3) + Returns + ------- + AudioSignal + Quantized AudioSignal. + """ + mu = quantization_channels - 1.0 + mu = util.ensure_tensor(mu, ndim=3) - # x = self.audio_data + x = self.audio_data - # # quantize - # x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) - # x = ((x + 1) / 2 * mu + 0.5).to(torch.int64) + # quantize + x = paddle.sign(x) * paddle.log1p(mu * paddle.abs(x)) / paddle.log1p(mu) + x = ((x + 1) / 2 * mu + 0.5).astype("int64") - # # unquantize - # x = (x / mu) * 2 - 1.0 - # x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu + # unquantize + x = (x / mu) * 2 - 1.0 + x = paddle.sign(x) * ( + paddle.exp(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu - # residual = (self.audio_data - x).detach() - # self.audio_data = self.audio_data - residual - # return self + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self - # def __matmul__(self, other): - # return self.convolve(other) + def __matmul__(self, other): + return self.convolve(other) import paddle @@ -591,13 +594,16 @@ class ImpulseResponseMixin: # direct path and windowed residual. window = paddle.zeros_like(self.audio_data) + window_idx = paddle.nonzero(early_idx) for idx in range(self.batch_size): - window_idx = early_idx[idx, 0] + # window_idx = early_idx[idx, 0] # ----- Just for this ----- # window[idx, ..., window_idx] = self.get_window("hann", window_idx.sum().item()) - indices = paddle.nonzero(window_idx).reshape( - [-1]) # shape: [num_true], dtype: int64 + # indices = paddle.nonzero(window_idx).reshape( + # [-1]) # shape: [num_true], dtype: int64 + indices = window_idx[window_idx[:, 0] == idx][:, -1] + temp_window = self.get_window("hann", indices.shape[0]) window_slice = window[idx, 0] @@ -639,9 +645,9 @@ class ImpulseResponseMixin: l_sq = late_field**2 a = (wd_sq * e_sq).sum(axis=-1) b = (2 * (1 - wd) * wd * e_sq).sum(axis=-1) - c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow( - 10 * paddle.ones_like(target_drr), target_drr / 10) * l_sq.sum( - axis=-1) + c = (wd_sq_1 * e_sq).sum(axis=-1) - paddle.pow(10 * paddle.ones_like( + target_drr, dtype="float32"), target_drr.cast("float32") / + 10) * l_sq.sum(axis=-1) expr = ((b**2) - 4 * a * c).sqrt() alpha = paddle.maximum( diff --git a/audio/audiotools/core/resample.py b/audio/audiotools/core/resample.py deleted file mode 100644 index 46ee5b222..000000000 --- a/audio/audiotools/core/resample.py +++ /dev/null @@ -1,231 +0,0 @@ -import inspect -import math -from typing import Optional -from typing import Sequence - -import paddle -import paddle.nn.functional as F - - -def simple_repr(obj, attrs: Optional[Sequence[str]]=None, overrides: dict={}): - """ - Return a simple representation string for `obj`. - If `attrs` is not None, it should be a list of attributes to include. - """ - params = inspect.signature(obj.__class__).parameters - attrs_repr = [] - if attrs is None: - attrs = list(params.keys()) - for attr in attrs: - display = False - if attr in overrides: - value = overrides[attr] - elif hasattr(obj, attr): - value = getattr(obj, attr) - else: - continue - if attr in params: - param = params[attr] - if param.default is inspect._empty or value != param.default: # type: ignore - display = True - else: - display = True - - if display: - attrs_repr.append(f"{attr}={value}") - return f"{obj.__class__.__name__}({','.join(attrs_repr)})" - - -def sinc(x: paddle.Tensor): - """ - Implementation of sinc, i.e. sin(x) / x - - __Warning__: the input is not multiplied by `pi`! - """ - return paddle.where( - x == 0, - paddle.to_tensor(1.0, dtype=x.dtype, place=x.place), - paddle.sin(x) / x, ) - - -class ResampleFrac(paddle.nn.Layer): - """ - Resampling from the sample rate `old_sr` to `new_sr`. - """ - - def __init__(self, - old_sr: int, - new_sr: int, - zeros: int=24, - rolloff: float=0.945): - """ - Args: - old_sr (int): sample rate of the input signal x. - new_sr (int): sample rate of the output. - zeros (int): number of zero crossing to keep in the sinc filter. - rolloff (float): use a lowpass filter that is `rolloff * new_sr / 2`, - to ensure sufficient margin due to the imperfection of the FIR filter used. - Lowering this value will reduce anti-aliasing, but will reduce some of the - highest frequencies. - - Shape: - - - Input: `[*, T]` - - Output: `[*, T']` with `T' = int(new_sr * T / old_sr)` - - - .. caution:: - After dividing `old_sr` and `new_sr` by their GCD, both should be small - for this implementation to be fast. - - >>> import paddle - >>> resample = ResampleFrac(4, 5) - >>> x = paddle.randn([1000]) - >>> print(len(resample(x))) - 1250 - """ - super(ResampleFrac, self).__init__() - if not isinstance(old_sr, int) or not isinstance(new_sr, int): - raise ValueError("old_sr and new_sr should be integers") - gcd = math.gcd(old_sr, new_sr) - self.old_sr = old_sr // gcd - self.new_sr = new_sr // gcd - self.zeros = zeros - self.rolloff = rolloff - - self._init_kernels() - - def _init_kernels(self): - if self.old_sr == self.new_sr: - return - - kernels = [] - sr = min(self.new_sr, self.old_sr) - # rolloff will perform antialiasing filtering by removing the highest frequencies. - # At first I thought I only needed this when downsampling, but when upsampling - # you will get edge artifacts without this, the edge is equivalent to zero padding, - # which will add high freq artifacts. - sr *= self.rolloff - - # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor) - # using the sinc interpolation formula: - # x(t) = sum_i x[i] sinc(pi * old_sr * (i / old_sr - t)) - # We can then sample the function x(t) with a different sample rate: - # y[j] = x(j / new_sr) - # or, - # y[j] = sum_i x[i] sinc(pi * old_sr * (i / old_sr - j / new_sr)) - - # We see here that y[j] is the convolution of x[i] with a specific filter, for which - # we take an FIR approximation, stopping when we see at least `zeros` zeros crossing. - # But y[j+1] is going to have a different set of weights and so on, until y[j + new_sr]. - # Indeed: - # y[j + new_sr] = sum_i x[i] sinc(pi * old_sr * ((i / old_sr - (j + new_sr) / new_sr)) - # = sum_i x[i] sinc(pi * old_sr * ((i - old_sr) / old_sr - j / new_sr)) - # = sum_i x[i + old_sr] sinc(pi * old_sr * (i / old_sr - j / new_sr)) - # so y[j+new_sr] uses the same filter as y[j], but on a shifted version of x by `old_sr`. - # This will explain the F.conv1d after, with a stride of old_sr. - self._width = math.ceil(self.zeros * self.old_sr / sr) - # If old_sr is still big after GCD reduction, most filters will be very unbalanced, i.e., - # they will have a lot of almost zero values to the left or to the right... - # There is probably a way to evaluate those filters more efficiently, but this is kept for - # future work. - idx = paddle.arange( - -self._width, self._width + self.old_sr, dtype="float32") - for i in range(self.new_sr): - t = (-i / self.new_sr + idx / self.old_sr) * sr - t = paddle.clip(t, -self.zeros, self.zeros) - t *= math.pi - window = paddle.cos(t / self.zeros / 2)**2 - kernel = sinc(t) * window - # Renormalize kernel to ensure a constant signal is preserved. - kernel = kernel / kernel.sum() - kernels.append(kernel) - - _kernel = paddle.stack(kernels).reshape([self.new_sr, 1, -1]) - self.kernel = self.create_parameter( - shape=_kernel.shape, - dtype=_kernel.dtype, ) - self.kernel.set_value(_kernel) - - def forward( - self, - x: paddle.Tensor, - output_length: Optional[int]=None, - full: bool=False, ): - """ - Resample x. - Args: - x (Tensor): signal to resample, time should be the last dimension - output_length (None or int): This can be set to the desired output length - (last dimension). Allowed values are between 0 and - ceil(length * new_sr / old_sr). When None (default) is specified, the - floored output length will be used. In order to select the largest possible - size, use the `full` argument. - full (bool): return the longest possible output from the input. This can be useful - if you chain resampling operations, and want to give the `output_length` only - for the last one, while passing `full=True` to all the other ones. - """ - if self.old_sr == self.new_sr: - return x - shape = x.shape - _dtype = x.dtype - length = x.shape[-1] - x = x.reshape([-1, length]) - x = F.pad( - x.unsqueeze(1), - [self._width, self._width + self.old_sr], - mode="replicate", - data_format="NCL", ).astype(self.kernel.dtype) - ys = F.conv1d(x, self.kernel, stride=self.old_sr, data_format="NCL") - y = ys.transpose( - [0, 2, 1]).reshape(list(shape[:-1]) + [-1]).astype(_dtype) - - float_output_length = paddle.to_tensor( - self.new_sr * length / self.old_sr, dtype="float32") - max_output_length = paddle.ceil(float_output_length).astype("int64") - default_output_length = paddle.floor(float_output_length).astype( - "int64") - - if output_length is None: - applied_output_length = (max_output_length - if full else default_output_length) - elif output_length < 0 or output_length > max_output_length: - raise ValueError( - f"output_length must be between 0 and {max_output_length.numpy()}" - ) - else: - applied_output_length = paddle.to_tensor( - output_length, dtype="int64") - if full: - raise ValueError( - "You cannot pass both full=True and output_length") - return y[..., :applied_output_length] - - def __repr__(self): - return simple_repr(self) - - -def resample_frac( - x: paddle.Tensor, - old_sr: int, - new_sr: int, - zeros: int=24, - rolloff: float=0.945, - output_length: Optional[int]=None, - full: bool=False, ): - """ - Functional version of `ResampleFrac`, refer to its documentation for more information. - - ..warning:: - If you call repeatidly this functions with the same sample rates, then the - resampling kernel will be recomputed everytime. For best performance, you should use - and cache an instance of `ResampleFrac`. - """ - return ResampleFrac(old_sr, new_sr, zeros, rolloff)(x, output_length, full) - - -if __name__ == "__main__": - - resample = ResampleFrac(4, 5) - x = paddle.randn([1000]) - print(len(resample(x))) diff --git a/audio/tests/audiotools/core/test_effects✅.py b/audio/tests/audiotools/core/test_effects✅.py new file mode 100644 index 000000000..e798f06a6 --- /dev/null +++ b/audio/tests/audiotools/core/test_effects✅.py @@ -0,0 +1,363 @@ +import sys + +import numpy as np +import paddle +import pytest +sys.path.append("/home/aistudio/PaddleSpeech/audio") +from audiotools import AudioSignal + + +def test_normalize(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=10) + signal = signal.normalize() + assert np.allclose(signal.loudness(), -24, atol=1e-1) + + array = np.random.randn(1, 2, 32000) + array = array / np.abs(array).max() + + signal = AudioSignal(array, sample_rate=16000) + for db_incr in np.arange(10, 75, 5): + db = -80 + db_incr + signal = signal.normalize(db) + loudness = signal.loudness() + assert np.allclose(loudness, db, atol=1) # TODO, atol=1e-1 + + batch_size = 16 + db = -60 + paddle.linspace(10, 30, batch_size) + + array = np.random.randn(batch_size, 2, 32000) + array = array / np.abs(array).max() + signal = AudioSignal(array, sample_rate=16000) + + signal = signal.normalize(db) + assert np.allclose(signal.loudness(), db, 1e-1) + + +def test_volume_change(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + signal = AudioSignal(audio_path, offset=10, duration=10) + + boost = 3 + before_db = signal.loudness().clone() + signal = signal.volume_change(boost) + after_db = signal.loudness() + assert np.allclose(before_db + boost, after_db) + + signal._loudness = None + after_db = signal.loudness() + assert np.allclose(before_db + boost, after_db, 1e-1) + + +def test_mix(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + nz = AudioSignal(audio_path, offset=10, duration=10) + + spk.deepcopy().mix(nz, snr=-10) + snr = spk.loudness() - nz.loudness() + assert np.allclose(snr, -10, atol=1) + + # Test in batch + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + nz = AudioSignal(audio_path, offset=10, duration=10) + + batch_size = 4 + tgt_snr = paddle.linspace(-10, 10, batch_size) + + spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + nz_batch = AudioSignal.batch([nz.deepcopy() for _ in range(batch_size)]) + + spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr) + snr = spk_batch.loudness() - nz_batch.loudness() + assert np.allclose(snr, tgt_snr, atol=1) + + # Test with "EQing" the other signal + db = 0 + 0 * paddle.rand([10]) + spk_batch.deepcopy().mix(nz_batch, snr=tgt_snr, other_eq=db) + snr = spk_batch.loudness() - nz_batch.loudness() + assert np.allclose(snr, tgt_snr, atol=1) + + +def test_convolve(): + np.random.seed(6) # Found a failing seed + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + impulse = np.zeros((1, 16000), dtype="float32") + impulse[..., 0] = 1 + ir = AudioSignal(impulse, 16000) + batch_size = 4 + + spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + ir_batch = AudioSignal.batch( + [ + ir.deepcopy().zero_pad(np.random.randint(1000), 0) + for _ in range(batch_size) + ], + pad_signals=True, ) + + convolved = spk_batch.deepcopy().convolve(ir_batch) + assert convolved == spk_batch + + # Short duration + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=0.1) + + impulse = np.zeros((1, 16000), dtype="float32") + impulse[..., 0] = 1 + ir = AudioSignal(impulse, 16000) + batch_size = 4 + + spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + ir_batch = AudioSignal.batch( + [ + ir.deepcopy().zero_pad(np.random.randint(1000), 0) + for _ in range(batch_size) + ], + pad_signals=True, ) + + convolved = spk_batch.deepcopy().convolve(ir_batch) + assert convolved == spk_batch + + +def test_pipeline(): + # An actual IR, no batching + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=5) + + audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + ir = AudioSignal(audio_path) + spk.deepcopy().convolve(ir) + + audio_path = "tests/audiotools/audio/nz/f5_script2_ipad_balcony1_room_tone.wav" + nz = AudioSignal(audio_path, offset=10, duration=5) + + batch_size = 16 + tgt_snr = paddle.linspace(20, 30, batch_size) + + (spk @ ir).mix(nz, snr=tgt_snr) + + +# def test_codec(): + +# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" +# spk = AudioSignal(audio_path, offset=10, duration=10) + +# with pytest.raises(ValueError): +# spk.apply_codec("unknown preset") + +# out = spk.deepcopy().apply_codec("Ogg") +# out = spk.deepcopy().apply_codec("8-bit") + +# def test_pitch_shift(): +# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" +# spk = AudioSignal(audio_path, offset=10, duration=1) + +# single = spk.deepcopy().pitch_shift(5) + +# batch_size = 4 +# spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + +# batched = spk_batch.deepcopy().pitch_shift(5) + +# assert np.allclose(batched[0].audio_data, single[0].audio_data) + +# def test_time_stretch(): +# audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" +# spk = AudioSignal(audio_path, offset=10, duration=1) + +# single = spk.deepcopy().time_stretch(0.8) + +# batch_size = 4 +# spk_batch = AudioSignal.batch([spk.deepcopy() for _ in range(batch_size)]) + +# batched = spk_batch.deepcopy().time_stretch(0.8) + +# assert np.allclose(batched[0].audio_data, single[0].audio_data) + + +@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16]) +def test_mel_filterbank(n_bands): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=1) + fbank = spk.deepcopy().mel_filterbank(n_bands) + + assert paddle.allclose(fbank.sum(-1), spk.audio_data, atol=1e-6) + + # Check if it works in batches. + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt( + "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + fbank = spk_batch.deepcopy().mel_filterbank(n_bands) + summed = fbank.sum(-1) + assert paddle.allclose(summed, spk_batch.audio_data, atol=1e-6) + + +@pytest.mark.parametrize("n_bands", [1, 2, 4, 8, 12, 16]) +def test_equalizer(n_bands): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=10) + + db = -3 + 1 * paddle.rand([n_bands]) + spk.deepcopy().equalizer(db) + + db = -3 + 1 * np.random.rand(n_bands) + spk.deepcopy().equalizer(db) + + audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + ir = AudioSignal(audio_path) + db = -3 + 1 * paddle.rand([n_bands]) + + spk.deepcopy().convolve(ir.equalizer(db)) + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt( + "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + + db = paddle.zeros([spk_batch.batch_size, n_bands]) + output = spk_batch.deepcopy().equalizer(db) + + assert output == spk_batch + + +def test_clip_distortion(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=2) + clipped = spk.deepcopy().clip_distortion(0.05) + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt( + "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + percs = paddle.to_tensor(np.random.uniform(size=(16, ))).astype("float32") + clipped_batch = spk_batch.deepcopy().clip_distortion(percs) + + assert clipped.audio_data.abs().max() < 1.0 + assert clipped_batch.audio_data.abs().max() < 1.0 + + +@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128]) +def test_quantization(quant_ch): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=2) + + quantized = spk.deepcopy().quantization(quant_ch) + + # Need to round audio_data off because torch ops with straight + # through estimator are sometimes a bit off past 3 decimal places. + found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3))) + assert found_quant_ch <= quant_ch + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt( + "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + + quant_ch = np.random.choice( + [2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True) + quantized = spk_batch.deepcopy().quantization(quant_ch) + + for i, q_ch in enumerate(quant_ch): + found_quant_ch = len( + np.unique(np.around(quantized.audio_data[i], decimals=3))) + assert found_quant_ch <= q_ch + + +@pytest.mark.parametrize("quant_ch", [2, 4, 8, 16, 32, 64, 128]) +def test_mulaw_quantization(quant_ch): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + spk = AudioSignal(audio_path, offset=10, duration=2) + + quantized = spk.deepcopy().mulaw_quantization(quant_ch) + + # Need to round audio_data off because torch ops with straight + # through estimator are sometimes a bit off past 3 decimal places. + found_quant_ch = len(np.unique(np.around(quantized.audio_data, decimals=3))) + assert found_quant_ch <= quant_ch + + spk_batch = AudioSignal.batch([ + AudioSignal.excerpt( + "tests/audiotools/audio/spk/f10_script4_produced.wav", duration=2) + for _ in range(16) + ]) + + quant_ch = np.random.choice( + [2, 4, 8, 16, 32, 64, 128], size=(16, ), replace=True) + quantized = spk_batch.deepcopy().mulaw_quantization(quant_ch) + + for i, q_ch in enumerate(quant_ch): + found_quant_ch = len( + np.unique(np.around(quantized.audio_data[i], decimals=3))) + assert found_quant_ch <= q_ch + + +def test_impulse_response_augmentation(): + audio_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + batch_size = 16 + ir = AudioSignal(audio_path) + ir_batch = AudioSignal.batch([ir for _ in range(batch_size)]) + early_response, late_field, window = ir_batch.decompose_ir() + + assert early_response.shape == late_field.shape + assert late_field.shape == window.shape + + drr = ir_batch.measure_drr() + + alpha = AudioSignal.solve_alpha(early_response, late_field, window, drr) + assert np.allclose(alpha, np.ones_like(alpha), 1e-5) + + target_drr = 5 + out = ir_batch.deepcopy().alter_drr(target_drr) + drr = out.measure_drr() + assert np.allclose(drr, np.ones_like(drr) * target_drr) + + target_drr = np.random.rand(batch_size).astype("float32") * 50 + altered_ir = ir_batch.deepcopy().alter_drr(target_drr) + drr = altered_ir.measure_drr() + assert np.allclose(drr.flatten(), target_drr.flatten()) + + +def test_apply_ir(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + ir_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + + spk = AudioSignal(audio_path, offset=10, duration=2) + ir = AudioSignal(ir_path) + db = 0 + 0 * paddle.rand([10]) + output = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=db) + + assert np.allclose(ir.measure_drr().flatten(), 10) + + output = spk.deepcopy().apply_ir( + ir, drr=10, ir_eq=db, use_original_phase=True) + + +def test_ensure_max_of_audio(): + spk = AudioSignal(paddle.randn([1, 1, 44100]), 44100) + + max_vals = [1.0] + [np.random.rand() for _ in range(10)] + for val in max_vals: + after = spk.deepcopy().ensure_max_of_audio(val) + assert after.audio_data.abs().max() <= val + 1e-3 + + # Make sure it does nothing to a tiny signal + spk = AudioSignal(paddle.rand([1, 1, 44100]), 44100) + spk.audio_data = spk.audio_data * 0.5 + after = spk.deepcopy().ensure_max_of_audio() + + assert paddle.allclose(after.audio_data, spk.audio_data) + + +test_normalize() diff --git a/audio/tests/audiotools/core/test_grad✅.py b/audio/tests/audiotools/core/test_grad✅.py new file mode 100644 index 000000000..d5ef3f307 --- /dev/null +++ b/audio/tests/audiotools/core/test_grad✅.py @@ -0,0 +1,168 @@ +import sys +from typing import Callable + +import numpy as np +import paddle +import pytest +sys.path.append("/home/aistudio/PaddleSpeech/audio") +from audiotools import AudioSignal + + +def test_audio_grad(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + ir_path = "tests/audiotools/audio/ir/h179_Bar_1txts.wav" + + def _test_audio_grad(attr: str, target=True, kwargs: dict={}): + signal = AudioSignal(audio_path) + signal.audio_data.stop_gradient = False + + assert signal.audio_data.grad is None + + # Avoid overwriting leaf tensor by cloning signal + attr = getattr(signal.clone(), attr) + result = attr(**kwargs) if isinstance(attr, Callable) else attr + + try: + if isinstance(result, AudioSignal): + # If necessary, propagate spectrogram changes to waveform + if result.stft_data is not None: + result.istft() + # if result.audio_data.dtype.is_complex: + if paddle.is_complex(result.audio_data): + result.audio_data.real.sum().backward() + else: + result.audio_data.sum().backward() + else: + # if result.dtype.is_complex: + if paddle.is_complex(result): + result.real().sum().backward() + else: + result.sum().backward() + + assert signal.audio_data.grad is not None or not target + except RuntimeError: + assert not target + + for a in [ + ["mix", True, { + "other": AudioSignal(audio_path), + "snr": 0 + }], + ["convolve", True, { + "other": AudioSignal(ir_path) + }], + [ + "apply_ir", + True, + { + "ir": AudioSignal(ir_path), + "drr": 0.1, + "ir_eq": paddle.randn([6]) + }, + ], + ["ensure_max_of_audio", True], + ["normalize", True], + ["volume_change", True, { + "db": 1 + }], + # ["pitch_shift", False, {"n_semitones": 1}], + # ["time_stretch", False, {"factor": 2}], + # ["apply_codec", False], + ["equalizer", True, { + "db": paddle.randn([6]) + }], + ["clip_distortion", True, { + "clip_percentile": 0.5 + }], + ["quantization", True, { + "quantization_channels": 8 + }], + ["mulaw_quantization", True, { + "quantization_channels": 8 + }], + ["resample", True, { + "sample_rate": 16000 + }], + ["low_pass", True, { + "cutoffs": 1000 + }], + ["high_pass", True, { + "cutoffs": 1000 + }], + ["to_mono", True], + ["zero_pad", True, { + "before": 10, + "after": 10 + }], + ["magnitude", True], + ["phase", True], + ["log_magnitude", True], + ["loudness", False], + ["stft", True], + ["clone", True], + ["mel_spectrogram", True], + ["zero_pad_to", True, { + "length": 100000 + }], + ["truncate_samples", True, { + "length_in_samples": 1000 + }], + ["corrupt_phase", True, { + "scale": 0.5 + }], + ["shift_phase", True, { + "shift": 1 + }], + ["mask_low_magnitudes", True, { + "db_cutoff": 0 + }], + ["mask_frequencies", True, { + "fmin_hz": 100, + "fmax_hz": 1000 + }], + ["mask_timesteps", True, { + "tmin_s": 0.1, + "tmax_s": 0.5 + }], + ["__add__", True, { + "other": AudioSignal(audio_path) + }], + ["__iadd__", True, { + "other": AudioSignal(audio_path) + }], + ["__radd__", True, { + "other": AudioSignal(audio_path) + }], + ["__sub__", True, { + "other": AudioSignal(audio_path) + }], + ["__isub__", True, { + "other": AudioSignal(audio_path) + }], + ["__mul__", True, { + "other": AudioSignal(audio_path) + }], + ["__imul__", True, { + "other": AudioSignal(audio_path) + }], + ["__rmul__", True, { + "other": AudioSignal(audio_path) + }], + ]: + _test_audio_grad(*a) + + +def test_batch_grad(): + audio_path = "tests/audiotools/audio/spk/f10_script4_produced.wav" + + signal = AudioSignal(audio_path) + signal.audio_data.stop_gradient = False + + assert signal.audio_data.grad is None + + batch_size = 16 + batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)]) + + batch.audio_data.sum().backward() + + assert signal.audio_data.grad is not None