|
|
|
@ -181,11 +181,11 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
|
|
|
|
|
if length_dim == 0:
|
|
|
|
|
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
|
|
|
|
|
|
|
|
|
|
# check if ilens is 0-dim tensor, if so, add a dimension
|
|
|
|
|
if lengths.ndim == 0:
|
|
|
|
|
lengths = lengths.unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
bs = paddle.shape(lengths)
|
|
|
|
|
else:
|
|
|
|
|
bs = paddle.shape(lengths)
|
|
|
|
|
|
|
|
|
|
if xs is None:
|
|
|
|
|
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
|
|
|
|
|
else:
|
|
|
|
|