Fix dtype mismatch in the last expand_v2 op of VITS (#3041)

pull/3054/head
TianYuan 1 year ago committed by GitHub
parent 348064de0d
commit 706a68bde9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -334,11 +334,12 @@ class ConvFlow(nn.Layer):
unnorm_widths = h[..., :self.bins] / denom
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins:]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inputs=xb,
unnormalized_widths=unnorm_widths,
unnormalized_heights=unnorm_heights,
unnormalized_derivatives=unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound, )

@ -245,6 +245,6 @@ def rational_quadratic_spline(
def _searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin it falls into.

    The index is computed by counting, along the last axis of
    ``bin_locations``, how many edges are less than or equal to the input
    and subtracting one.

    Args:
        bin_locations: Tensor of bin edges along its last axis.
            Presumably sorted in ascending order -- confirm against caller.
        inputs: Tensor of values to locate; a trailing axis is added so it
            broadcasts against the last axis of ``bin_locations``.
        eps: Amount added to the last edge before comparing.

    Returns:
        An int64 tensor of bin indices (edge count minus one).

    Note:
        ``bin_locations`` is modified in place: its last edge is increased
        by ``eps``, and the caller's tensor keeps that change.
    """
    # Nudge the last edge upwards so an input equal to it is not counted as
    # having passed it, which would yield an index one past the final bin.
    bin_locations[..., -1] += eps
    mask = inputs[..., None] >= bin_locations
    # Single cast using the explicit ``dtype`` keyword; the earlier
    # positional form of the same cast was redundant and is dropped.
    mask_int = paddle.cast(mask, dtype='int64')
    out = paddle.sum(mask_int, axis=-1) - 1
    return out

@ -145,18 +145,18 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
bs = paddle.shape(lengths)[0]
if xs is None:
maxlen = lengths.max()
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
else:
maxlen = paddle.shape(xs)[length_dim]
seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
# VITS 最后一个 expand 的位置
seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen])
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand.cast(seq_range_expand.dtype)
if xs is not None:
assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
if length_dim < 0:
length_dim = len(paddle.shape(xs)) + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)

Loading…
Cancel
Save