From e16999b3c3b7110d591f72c177740a79469892d4 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Wed, 8 Jan 2025 07:48:20 +0000 Subject: [PATCH] Adapt to paddle3.0 && update readme --- audio/audiotools/README.md | 63 ++++++++++++++++++++++-- audio/audiotools/core/dsp.py | 5 ++ audio/audiotools/core/effects.py | 5 +- audio/audiotools/core/loudness.py | 2 +- audio/audiotools/core/util.py | 4 +- audio/tests/audiotools/core/test_util.py | 2 +- 6 files changed, 71 insertions(+), 10 deletions(-) diff --git a/audio/audiotools/README.md b/audio/audiotools/README.md index b28776ebf..a0eac3675 100644 --- a/audio/audiotools/README.md +++ b/audio/audiotools/README.md @@ -2,12 +2,67 @@ Audiotools is a comprehensive toolkit designed for audio processing and analysis ### Directory Structure -- **core directory**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals. +``` +. +├── audiotools +│ ├── README.md +│ ├── __init__.py +│ ├── core +│ │ ├── __init__.py +│ │ ├── _julius.py +│ │ ├── audio_signal.py +│ │ ├── display.py +│ │ ├── dsp.py +│ │ ├── effects.py +│ │ ├── ffmpeg.py +│ │ ├── loudness.py +│ │ └── util.py +│ ├── data +│ │ ├── __init__.py +│ │ ├── datasets.py +│ │ ├── preprocess.py +│ │ └── transforms.py +│ ├── metrics +│ │ ├── __init__.py +│ │ └── quality.py +│ ├── ml +│ │ ├── __init__.py +│ │ ├── accelerator.py +│ │ ├── basemodel.py +│ │ └── decorators.py +│ ├── requirements.txt +│ └── post.py +├── tests +│ └── audiotools +│ ├── core +│ │ ├── test_audio_signal.py +│ │ ├── test_bands.py +│ │ ├── test_display.py +│ │ ├── test_dsp.py +│ │ ├── test_effects.py +│ │ ├── test_fftconv.py +│ │ ├── test_grad.py +│ │ ├── test_highpass.py +│ │ ├── test_loudness.py +│ │ ├── test_lowpass.py +│ │ └── test_util.py +│ ├── data +│ │ ├── test_datasets.py +│ │ ├── test_preprocess.py +│ │ └── test_transforms.py +│ ├── ml +│ │ ├── test_decorators.py +│ │ └── test_model.py +│ └── test_post.py -- **data directory**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data. +``` -- **metrics directory**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms. +- **core**: Contains the core class AudioSignal, which is responsible for the fundamental representation and manipulation of audio signals. -- **ml directory**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio. +- **data**: Primarily dedicated to storing and processing datasets, including classes and functions for data preprocessing, ensuring efficient loading and transformation of audio data. + +- **metrics**: Implements functions for various audio evaluation metrics, enabling precise assessment of the performance of audio models and processing algorithms. + +- **ml**: Comprises classes and methods related to model training, supporting the construction, training, and optimization of machine learning models in the context of audio. This project aims to provide developers and researchers with an efficient and flexible framework to foster innovation and exploration across various domains of audio technology. diff --git a/audio/audiotools/core/dsp.py b/audio/audiotools/core/dsp.py index c1f1e9efb..c434e30da 100644 --- a/audio/audiotools/core/dsp.py +++ b/audio/audiotools/core/dsp.py @@ -349,6 +349,9 @@ class DSPMixin: nbins, ) bins_hz = bins_hz[None, None, :, None].tile( [self.batch_size, 1, 1, mag.shape[-1]]) + + fmin_hz, fmax_hz = fmin_hz.astype(bins_hz.dtype), fmax_hz.astype( + bins_hz.dtype) mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz) mag = paddle.where(mask, paddle.full_like(mag, val), mag) @@ -429,6 +432,7 @@ class DSPMixin: log_mag = self.log_magnitude() db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) + db_cutoff = db_cutoff.astype(log_mag.dtype) mask = log_mag < db_cutoff # mag = mag.masked_fill(mask, val) mag = paddle.where(mask, mag, val * paddle.ones_like(mag)) @@ -452,6 +456,7 @@ class DSPMixin: masked audio data. """ shift = util.ensure_tensor(shift, ndim=self.phase.ndim) + shift = shift.astype(self.phase.dtype) self.phase = self.phase + shift return self diff --git a/audio/audiotools/core/effects.py b/audio/audiotools/core/effects.py index 9efbe681a..ed26df6c9 100644 --- a/audio/audiotools/core/effects.py +++ b/audio/audiotools/core/effects.py @@ -266,7 +266,7 @@ class EffectMixin: """ db = util.ensure_tensor(db) ref_db = self.loudness() - gain = db - ref_db + gain = db.astype(ref_db.dtype) - ref_db gain = util.exp_compat(gain * self.GAIN_FACTOR) self.audio_data = self.audio_data * gain[:, None, None] @@ -388,6 +388,7 @@ class EffectMixin: quantization_channels, ndim=3) x = self.audio_data + quantization_channels = quantization_channels.astype(x.dtype) x = (x + 1) / 2 x = x * quantization_channels x = x.floor() @@ -424,7 +425,7 @@ class EffectMixin: x = ((x + 1) / 2 * mu + 0.5).astype("int64") # unquantize - x = (x / mu) * 2 - 1.0 + x = (x.astype(mu.dtype) / mu) * 2 - 1.0 x = paddle.sign(x) * ( util.exp_compat(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu diff --git a/audio/audiotools/core/loudness.py b/audio/audiotools/core/loudness.py index 27d2ca98f..ea24b09d8 100644 --- a/audio/audiotools/core/loudness.py +++ b/audio/audiotools/core/loudness.py @@ -317,7 +317,7 @@ class Meter(paddle.nn.Layer): z_avg_gated = z z_avg_gated[l <= Gamma_a] = 0 masked = l > Gamma_a - z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) + z_avg_gated = z_avg_gated.sum(2) / masked.sum(2).astype("float32") # calculate the relative threshold value (see eq. 6) Gamma_r = -0.691 + 10.0 * paddle.log10( diff --git a/audio/audiotools/core/util.py b/audio/audiotools/core/util.py index 78718e2d7..3e723d879 100644 --- a/audio/audiotools/core/util.py +++ b/audio/audiotools/core/util.py @@ -338,7 +338,7 @@ def _close_temp_files(tmpfiles: list): _close() -AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] +AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3"] def find_audio(folder: str, ext: List[str]=AUDIO_EXTENSIONS): @@ -869,7 +869,7 @@ def hz_to_bin(hz: paddle.Tensor, n_fft: int, sample_rate: int): shape = hz.shape hz = hz.reshape([-1]) freqs = paddle.linspace(0, sample_rate / 2, 2 + n_fft // 2) - hz = paddle.clip(hz, max=sample_rate / 2) + hz = paddle.clip(hz, max=sample_rate / 2).astype(freqs.dtype) closest = (hz[None, :] - freqs[:, None]).abs() closest_bins = closest.argmin(axis=0) diff --git a/audio/tests/audiotools/core/test_util.py b/audio/tests/audiotools/core/test_util.py index fd9fdc241..b350732d6 100644 --- a/audio/tests/audiotools/core/test_util.py +++ b/audio/tests/audiotools/core/test_util.py @@ -88,7 +88,7 @@ def test_seed(): def test_hz_to_bin(): - hz = paddle.to_tensor(np.array([100, 200, 300])) + hz = paddle.to_tensor(np.array([100, 200, 300]), dtype="float32") sr = 1000 n_fft = 2048