|
|
|
@ -145,18 +145,18 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
|
|
|
|
|
|
|
|
|
|
bs = paddle.shape(lengths)[0]
|
|
|
|
|
if xs is None:
|
|
|
|
|
maxlen = lengths.max()
|
|
|
|
|
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
|
|
|
|
|
else:
|
|
|
|
|
maxlen = paddle.shape(xs)[length_dim]
|
|
|
|
|
|
|
|
|
|
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
|
|
|
|
|
# VITS 最后一个 expand 的位置
|
|
|
|
|
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
|
|
|
|
|
seq_length_expand = lengths.unsqueeze(-1)
|
|
|
|
|
mask = seq_range_expand >= seq_length_expand.cast(seq_range_expand.dtype)
|
|
|
|
|
|
|
|
|
|
if xs is not None:
|
|
|
|
|
assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
|
|
|
|
|
|
|
|
|
|
if length_dim < 0:
|
|
|
|
|
length_dim = len(paddle.shape(xs)) + length_dim
|
|
|
|
|
# ind = (:, None, ..., None, :, , None, ..., None)
|
|
|
|
|