pull/3912/head
enkilee 10 months ago
parent 30af963c45
commit ed02bd551c

@ -182,7 +182,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
raise ValueError("length_dim cannot be 0: {}".format(length_dim)) raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if lengths.ndim == 0: if lengths.ndim == 0:
bs = paddle.shape(lengths) bs = paddle.shape(lengths.unsqueeze(0))
else: else:
bs = paddle.shape(lengths) bs = paddle.shape(lengths)

Loading…
Cancel
Save