diff --git a/audio/audiotools/data/transforms.py b/audio/audiotools/data/transforms.py index df4382862..868fb724b 100644 --- a/audio/audiotools/data/transforms.py +++ b/audio/audiotools/data/transforms.py @@ -13,6 +13,7 @@ from .. import ml from ..core import AudioSignal from ..core import util from .datasets import AudioLoader +from paddlespeech.utils import satisfy_paddle_version class BaseTransform: @@ -130,11 +131,17 @@ class BaseTransform: # `v` may be `Tensor` or `AudioSignal` if 0 == len(v.shape) and 0 == mask.dim(): if mask: # 0d 的 True - masked_batch[k] = v[None] + masked_batch[k] = v.unsqueeze(0) else: masked_batch[k] = paddle.to_tensor([], dtype=v.dtype) else: - masked_batch[k] = v[mask] + if not satisfy_paddle_version('2.6'): + if 0 == mask.dim() and bool(mask) and paddle.is_tensor(v): + masked_batch[k] = v.unsqueeze(0) + else: + masked_batch[k] = v[mask] + else: + masked_batch[k] = v[mask] return unflatten(masked_batch) def transform(self, signal: AudioSignal, **kwargs):