fix paddle2.5 index bug

pull/3900/head
drryanhuang 9 months ago
parent 33eba3370b
commit df188b673e

@ -13,6 +13,7 @@ from .. import ml
from ..core import AudioSignal
from ..core import util
from .datasets import AudioLoader
from paddlespeech.utils import satisfy_paddle_version
class BaseTransform:
@ -130,11 +131,17 @@ class BaseTransform:
# `v` may be `Tensor` or `AudioSignal`
if 0 == len(v.shape) and 0 == mask.dim():
if mask: # 0d 的 True
masked_batch[k] = v[None]
masked_batch[k] = v.unsqueeze(0)
else:
masked_batch[k] = paddle.to_tensor([], dtype=v.dtype)
else:
masked_batch[k] = v[mask]
if not satisfy_paddle_version('2.6'):
if 0 == mask.dim() and bool(mask) and paddle.is_tensor(v):
masked_batch[k] = v.unsqueeze(0)
else:
masked_batch[k] = v[mask]
else:
masked_batch[k] = v[mask]
return unflatten(masked_batch)
def transform(self, signal: AudioSignal, **kwargs):

Loading…
Cancel
Save