fix paddle2.5 index bug

pull/3900/head
drryanhuang 9 months ago
parent 33eba3370b
commit df188b673e

@ -13,6 +13,7 @@ from .. import ml
from ..core import AudioSignal from ..core import AudioSignal
from ..core import util from ..core import util
from .datasets import AudioLoader from .datasets import AudioLoader
from paddlespeech.utils import satisfy_paddle_version
class BaseTransform: class BaseTransform:
@ -130,9 +131,15 @@ class BaseTransform:
# `v` may be `Tensor` or `AudioSignal` # `v` may be `Tensor` or `AudioSignal`
if 0 == len(v.shape) and 0 == mask.dim(): if 0 == len(v.shape) and 0 == mask.dim():
if mask: # 0d 的 True if mask: # 0d 的 True
masked_batch[k] = v[None] masked_batch[k] = v.unsqueeze(0)
else: else:
masked_batch[k] = paddle.to_tensor([], dtype=v.dtype) masked_batch[k] = paddle.to_tensor([], dtype=v.dtype)
else:
if not satisfy_paddle_version('2.6'):
if 0 == mask.dim() and bool(mask) and paddle.is_tensor(v):
masked_batch[k] = v.unsqueeze(0)
else:
masked_batch[k] = v[mask]
else: else:
masked_batch[k] = v[mask] masked_batch[k] = v[mask]
return unflatten(masked_batch) return unflatten(masked_batch)

Loading…
Cancel
Save