fix dim error

pull/3900/head
drryanhuang 9 months ago
parent 838596a2be
commit 1d2c078529

@ -1695,8 +1695,10 @@ class AudioSignal(
audio_data = self.audio_data[key] audio_data = self.audio_data[key]
_loudness = self._loudness[ _loudness = self._loudness[
key] if self._loudness is not None else None key] if self._loudness is not None else None
stft_data = self.stft_data[ # stft_data = self.stft_data[
key] if self.stft_data is not None else None # key] if self.stft_data is not None else None
stft_data = util.bool_index_compat(
self.stft_data, key) if self.stft_data is not None else None
sources = None sources = None
@ -1732,7 +1734,9 @@ class AudioSignal(
else: else:
self._loudness[key] = value._loudness self._loudness[key] = value._loudness
if self.stft_data is not None and value.stft_data is not None: if self.stft_data is not None and value.stft_data is not None:
self.stft_data[key] = value.stft_data # self.stft_data[key] = value.stft_data
self.stft_data = util.bool_setitem_compat(self.stft_data, key,
value.stft_data)
return return
def __ne__(self, other): def __ne__(self, other):

@ -391,7 +391,8 @@ class DSPMixin:
db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
mask = log_mag < db_cutoff mask = log_mag < db_cutoff
mag = mag.masked_fill(mask, val) # mag = mag.masked_fill(mask, val)
mag = paddle.where(mask, mag, val * paddle.ones_like(mag))
self.magnitude = mag self.magnitude = mag
return self return self

@ -59,6 +59,64 @@ def exp_compat(x):
return paddle.to_tensor(np.exp(x_np)) return paddle.to_tensor(np.exp(x_np))
def bool_index_compat(x, mask):
"""
Perform boolean indexing on the input tensor `x` using the provided `mask`.
This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean indexing
may not be fully supported. For older versions, the operation is performed using NumPy.
Args:
x (paddle.Tensor): The input tensor to be indexed.
mask (paddle.Tensor or int): The boolean mask or integer index used for indexing.
Returns:
paddle.Tensor: The result of the boolean indexing operation, as a PaddlePaddle tensor.
Notes:
- If the PaddlePaddle version is 2.6 or above, or if `mask` is an integer, the function uses
Paddle's native indexing directly.
- For versions below 2.6, the tensor and mask are converted to NumPy arrays, the indexing
operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor.
"""
if satisfy_paddle_version("2.6") or isinstance(mask, (int, list)):
return x[mask]
else:
x_np = x.cpu().numpy()[mask.cpu().numpy()]
return paddle.to_tensor(x_np)
def bool_setitem_compat(x, mask, y):
"""
Perform boolean assignment on the input tensor `x` using the provided `mask` and values `y`.
This function ensures compatibility with PaddlePaddle versions below 2.6, where boolean assignment
may not be fully supported. For older versions, the operation is performed using NumPy.
Args:
x (paddle.Tensor): The input tensor to be modified.
mask (paddle.Tensor): The boolean mask used for assignment.
y (paddle.Tensor): The values to assign to the selected elements of `x`.
Returns:
paddle.Tensor: The modified tensor after the assignment operation.
Notes:
- If the PaddlePaddle version is 2.6 or above, the function uses Paddle's native assignment directly.
- For versions below 2.6, the tensor, mask, and values are converted to NumPy arrays, the assignment
operation is performed using NumPy, and the result is converted back to a PaddlePaddle tensor.
"""
if satisfy_paddle_version("2.6"):
x[mask] = y
return x
else:
x_np = x.cpu().numpy()
x_np[mask.cpu().numpy()] = y.cpu().numpy()
return paddle.to_tensor(x_np)
@dataclass @dataclass
class Info: class Info:

@ -130,6 +130,10 @@ class MulTransform(tfm.BaseTransform):
super().__init__(name=name, keys=["num"]) super().__init__(name=name, keys=["num"])
def _transform(self, signal, num): def _transform(self, signal, num):
if not num.dim():
num = num.unsqueeze(axis=0)
signal.audio_data = signal.audio_data * num[:, None, None] signal.audio_data = signal.audio_data * num[:, None, None]
return signal return signal

Loading…
Cancel
Save