From 1d2c078529583a4c936e5becd288642cf6f49768 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Tue, 31 Dec 2024 07:02:35 +0000 Subject: [PATCH] fix dim error --- audio/audiotools/core/audio_signal.py | 10 +++- audio/audiotools/core/dsp.py | 3 +- audio/audiotools/core/util.py | 58 +++++++++++++++++++ .../audiotools/data/test_transforms✅.py | 4 ++ 4 files changed, 71 insertions(+), 4 deletions(-) diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 71504c34a..68de04731 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -1695,8 +1695,10 @@ class AudioSignal( audio_data = self.audio_data[key] _loudness = self._loudness[ key] if self._loudness is not None else None - stft_data = self.stft_data[ - key] if self.stft_data is not None else None + # stft_data = self.stft_data[ + # key] if self.stft_data is not None else None + stft_data = util.bool_index_compat( + self.stft_data, key) if self.stft_data is not None else None sources = None @@ -1732,7 +1734,9 @@ class AudioSignal( else: self._loudness[key] = value._loudness if self.stft_data is not None and value.stft_data is not None: - self.stft_data[key] = value.stft_data + # self.stft_data[key] = value.stft_data + self.stft_data = util.bool_setitem_compat(self.stft_data, key, + value.stft_data) return def __ne__(self, other): diff --git a/audio/audiotools/core/dsp.py b/audio/audiotools/core/dsp.py index e89196094..b190af0cc 100644 --- a/audio/audiotools/core/dsp.py +++ b/audio/audiotools/core/dsp.py @@ -391,7 +391,8 @@ class DSPMixin: db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) mask = log_mag < db_cutoff - mag = mag.masked_fill(mask, val) + # mag = mag.masked_fill(mask, val) + mag = paddle.where(mask, mag, val * paddle.ones_like(mag)) self.magnitude = mag return self diff --git a/audio/audiotools/core/util.py b/audio/audiotools/core/util.py index ad2bbc721..cf9f99636 100644 --- a/audio/audiotools/core/util.py +++ 
def bool_index_compat(x, mask):
    """
    Index `x` with `mask`, falling back to NumPy for tensor masks on old Paddle.

    Older PaddlePaddle releases (below 2.6) may not fully support indexing a
    tensor with another tensor used as a boolean mask. Only that combination
    needs the NumPy round trip; every other kind of index is handed to the
    native ``x[mask]``.

    Args:
        x (paddle.Tensor): The tensor to index.
        mask: The index. Either a tensor-like object exposing ``.cpu()``
            (e.g. a boolean ``paddle.Tensor``) or a plain Python index such
            as ``int``, ``slice``, ``list`` or ``tuple``.

    Returns:
        paddle.Tensor: The selected elements, i.e. ``x[mask]``.

    Notes:
        - Plain Python indices never take the fallback: they have no
          ``.cpu()`` method, so the fallback below could not run on them.
        - When the fallback is taken, the result is rebuilt with
          ``paddle.to_tensor`` from a NumPy array.
    """
    # Check the mask type first: it is cheap, and it keeps non-tensor keys
    # (slices, tuples, ...) away from the `.cpu()` call in the fallback.
    if not hasattr(mask, "cpu") or satisfy_paddle_version("2.6"):
        return x[mask]

    selected = x.cpu().numpy()[mask.cpu().numpy()]
    return paddle.to_tensor(selected)


def bool_setitem_compat(x, mask, y):
    """
    Assign `y` into `x[mask]`, falling back to NumPy for tensor masks on old Paddle.

    Older PaddlePaddle releases (below 2.6) may not fully support assignment
    through a tensor mask. Only that combination needs the NumPy round trip;
    every other kind of index uses the native ``x[mask] = y``.

    Args:
        x (paddle.Tensor): The tensor to modify.
        mask: The index. Either a tensor-like object exposing ``.cpu()``
            (e.g. a boolean ``paddle.Tensor``) or a plain Python index such
            as ``int``, ``slice``, ``list`` or ``tuple``.
        y: The values to assign. A tensor-like object exposing ``.cpu()`` is
            converted to NumPy in the fallback; anything else (scalar, list,
            ndarray) is assigned as-is.

    Returns:
        paddle.Tensor: The tensor holding the result. On the native path this
        is `x` itself, modified in place; on the fallback path it is a new
        tensor built with ``paddle.to_tensor``. Callers must therefore use
        the return value rather than rely on `x` having been mutated.
    """
    if not hasattr(mask, "cpu") or satisfy_paddle_version("2.6"):
        x[mask] = y
        return x

    x_np = x.cpu().numpy()
    # `y` may be a plain scalar or array; only tensors need converting.
    y_np = y.cpu().numpy() if hasattr(y, "cpu") else y
    x_np[mask.cpu().numpy()] = y_np
    return paddle.to_tensor(x_np)