From c93cdea39fa34f9cbd3ff4295dccb34852201510 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Tue, 31 Dec 2024 05:48:44 +0000 Subject: [PATCH] fix exp --- audio/audiotools/core/audio_signal.py | 4 +-- audio/audiotools/core/dsp.py | 4 +-- audio/audiotools/core/effects.py | 8 ++--- audio/audiotools/core/util.py | 27 ++++++++++++++++ .../audiotools/core/test_audio_signal✅.py | 5 +-- .../tests/audiotools/ml/test_decorators✅.py | 32 ++++++++++++------- 6 files changed, 59 insertions(+), 21 deletions(-) diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 0d64ab7b8..ef9f35fe3 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -1479,7 +1479,7 @@ class AudioSignal( @magnitude.setter def magnitude(self, value): - self.stft_data = value * paddle.exp(1j * self.phase) + self.stft_data = value * util.exp_compat(1j * self.phase) return def log_magnitude(self, @@ -1544,7 +1544,7 @@ class AudioSignal( @phase.setter def phase(self, value): # - self.stft_data = self.magnitude * paddle.exp(1j * value) + self.stft_data = self.magnitude * util.exp_compat(1j * value) return # Operator overloading diff --git a/audio/audiotools/core/dsp.py b/audio/audiotools/core/dsp.py index 20d7708a9..d6d2a4f94 100644 --- a/audio/audiotools/core/dsp.py +++ b/audio/audiotools/core/dsp.py @@ -313,7 +313,7 @@ class DSPMixin: mag = paddle.where(mask, paddle.full_like(mag, val), mag) phase = paddle.where(mask, paddle.full_like(phase, val), phase) - self.stft_data = mag * paddle.exp(1j * phase) + self.stft_data = mag * util.exp_compat(1j * phase) return self def mask_timesteps( @@ -362,7 +362,7 @@ class DSPMixin: mag = paddle.where(mask, paddle.full_like(mag, val), mag) phase = paddle.where(mask, paddle.full_like(phase, val), phase) - self.stft_data = mag * paddle.exp(1j * phase) + self.stft_data = mag * util.exp_compat(1j * phase) return self def mask_low_magnitudes( diff --git a/audio/audiotools/core/effects.py b/audio/audiotools/core/effects.py index ff08ab31b..edaf35969 100644 --- a/audio/audiotools/core/effects.py +++ b/audio/audiotools/core/effects.py @@ -182,7 +182,7 @@ class EffectMixin: # Use the input phase if use_original_phase: self.stft() - self.stft_data = self.magnitude * paddle.exp(1j * phase) + self.stft_data = self.magnitude * util.exp_compat(1j * phase) self.istft() # Rescale to the input's amplitude @@ -230,7 +230,7 @@ class EffectMixin: db = util.ensure_tensor(db) ref_db = self.loudness() gain = db - ref_db - gain = paddle.exp(gain * self.GAIN_FACTOR) + gain = util.exp_compat(gain * self.GAIN_FACTOR) self.audio_data = self.audio_data * gain[:, None, None] return self @@ -249,7 +249,7 @@ class EffectMixin: Signal at new volume. """ db = util.ensure_tensor(db, ndim=1) - gain = paddle.exp(db * self.GAIN_FACTOR) + gain = util.exp_compat(db * self.GAIN_FACTOR) self.audio_data = self.audio_data * gain[:, None, None] return self @@ -535,7 +535,7 @@ class EffectMixin: # unquantize x = (x / mu) * 2 - 1.0 x = paddle.sign(x) * ( - paddle.exp(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu + util.exp_compat(paddle.abs(x) * paddle.log1p(mu)) - 1.0) / mu residual = (self.audio_data - x).detach() self.audio_data = self.audio_data - residual diff --git a/audio/audiotools/core/util.py b/audio/audiotools/core/util.py index d08ade1e4..ad2bbc721 100644 --- a/audio/audiotools/core/util.py +++ b/audio/audiotools/core/util.py @@ -28,10 +28,37 @@ from flatten_dict import flatten from flatten_dict import unflatten from .audio_signal import AudioSignal +from paddlespeech.utils import satisfy_paddle_version # from ..data.preprocess import create_csv +def exp_compat(x): + """ + Compute the exponential of the input tensor `x`. + + This function is designed to handle compatibility issues with PaddlePaddle versions below 2.6, + which do not support the `exp` operation for complex tensors. In such cases, the computation + is offloaded to NumPy. + + Args: + x (paddle.Tensor): The input tensor for which to compute the exponential. + + Returns: + paddle.Tensor: The result of the exponential operation, as a PaddlePaddle tensor. + + Notes: + - If the PaddlePaddle version is 2.6 or above, the function uses `paddle.exp` directly. + - For versions below 2.6, the tensor is first converted to a NumPy array, the exponential + is computed using `np.exp`, and the result is then converted back to a PaddlePaddle tensor. + """ + if satisfy_paddle_version("2.6"): + return paddle.exp(x) + else: + x_np = x.cpu().numpy() + return paddle.to_tensor(np.exp(x_np)) + + @dataclass class Info: diff --git a/audio/tests/audiotools/core/test_audio_signal✅.py b/audio/tests/audiotools/core/test_audio_signal✅.py index fd470d7e8..3bcb4a166 100644 --- a/audio/tests/audiotools/core/test_audio_signal✅.py +++ b/audio/tests/audiotools/core/test_audio_signal✅.py @@ -10,6 +10,7 @@ import rich sys.path.append("../..") import audiotools from audiotools import AudioSignal +from audiotools import util def test_io(): @@ -421,7 +422,7 @@ def test_stft(window_length, hop_length, window_type): mag = signal.magnitude phase = signal.phase - recon_stft = mag * paddle.exp(1j * phase) + recon_stft = mag * util.exp_compat(1j * phase) # assert paddle.allclose(recon_stft, signal.stft_data) assert np.allclose(recon_stft.cpu().numpy(), signal.stft_data.cpu().numpy()) @@ -431,7 +432,7 @@ def test_stft(window_length, hop_length, window_type): signal.stft_data = None phase = signal.phase - recon_stft = mag * paddle.exp(1j * phase) + recon_stft = mag * util.exp_compat(1j * phase) # assert paddle.allclose(recon_stft, signal.stft_data) assert np.allclose(recon_stft.cpu().numpy(), signal.stft_data.cpu().numpy()) diff --git a/audio/tests/audiotools/ml/test_decorators✅.py b/audio/tests/audiotools/ml/test_decorators✅.py index 806949af7..dbd2ff6e2 100644 --- a/audio/tests/audiotools/ml/test_decorators✅.py +++ b/audio/tests/audiotools/ml/test_decorators✅.py @@ -7,6 +7,7 @@ from visualdl import LogWriter from audiotools.ml.decorators import timer from audiotools.ml.decorators import Tracker from audiotools.ml.decorators import when +from audiotools import util def test_all_decorators(): @@ -26,12 +27,16 @@ def test_all_decorators(): i = tracker.step time.sleep(0.01) return { - "loss": paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), - "mel": paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), - "stft": paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), + "loss": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "mel": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "stft": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), "waveform": - paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), - "not_scalar": paddle.arange(start=0, end=10, step=1, dtype="int64"), + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "not_scalar": + paddle.arange(start=0, end=10, step=1, dtype="int64"), } @tracker.track("val", len(val_data)) @@ -40,13 +45,18 @@ def test_all_decorators(): i = tracker.step time.sleep(0.01) return { - "loss": paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), - "mel": paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), - "stft": paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), + "loss": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "mel": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "stft": + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), "waveform": - paddle.exp(paddle.to_tensor([-i / 100], dtype="float32")), - "not_scalar": paddle.arange(10, dtype="int64"), - "string": "string", + util.exp_compat(paddle.to_tensor([-i / 100], dtype="float32")), + "not_scalar": + paddle.arange(10, dtype="int64"), + "string": + "string", } @when(lambda: tracker.step % 1000 == 0 and rank == 0)