pull/3912/head
enkilee 10 months ago
parent c5c3a8a9b5
commit 30af963c45

@ -181,11 +181,11 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
if length_dim == 0: if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim)) raise ValueError("length_dim cannot be 0: {}".format(length_dim))
# check if ilens is 0-dim tensor, if so, add a dimension
if lengths.ndim == 0: if lengths.ndim == 0:
lengths = lengths.unsqueeze(0) bs = paddle.shape(lengths)
else:
bs = paddle.shape(lengths)
bs = paddle.shape(lengths)
if xs is None: if xs is None:
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype) maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
else: else:

Loading…
Cancel
Save