From 706a68bde9bbbec4688506c917caf84b575291b1 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Tue, 14 Mar 2023 20:41:23 +0800 Subject: [PATCH] fix dtype diff of last expand_v2 op of VITS (#3041) --- paddlespeech/t2s/models/vits/flow.py | 9 +++++---- paddlespeech/t2s/models/vits/transform.py | 2 +- paddlespeech/t2s/modules/nets_utils.py | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/paddlespeech/t2s/models/vits/flow.py b/paddlespeech/t2s/models/vits/flow.py index 7593eb72..94df968a 100644 --- a/paddlespeech/t2s/models/vits/flow.py +++ b/paddlespeech/t2s/models/vits/flow.py @@ -334,11 +334,12 @@ class ConvFlow(nn.Layer): unnorm_widths = h[..., :self.bins] / denom unnorm_heights = h[..., self.bins:2 * self.bins] / denom unnorm_derivatives = h[..., 2 * self.bins:] + xb, logdet_abs = piecewise_rational_quadratic_transform( - xb, - unnorm_widths, - unnorm_heights, - unnorm_derivatives, + inputs=xb, + unnormalized_widths=unnorm_widths, + unnormalized_heights=unnorm_heights, + unnormalized_derivatives=unnorm_derivatives, inverse=inverse, tails="linear", tail_bound=self.tail_bound, ) diff --git a/paddlespeech/t2s/models/vits/transform.py b/paddlespeech/t2s/models/vits/transform.py index 0edc1d09..917f2843 100644 --- a/paddlespeech/t2s/models/vits/transform.py +++ b/paddlespeech/t2s/models/vits/transform.py @@ -245,6 +245,6 @@ def rational_quadratic_spline( def _searchsorted(bin_locations, inputs, eps=1e-6): bin_locations[..., -1] += eps mask = inputs[..., None] >= bin_locations - mask_int = paddle.cast(mask, 'int64') + mask_int = paddle.cast(mask, dtype='int64') out = paddle.sum(mask_int, axis=-1) - 1 return out diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 798e4dee..99130acc 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -145,18 +145,18 @@ def make_pad_mask(lengths, xs=None, length_dim=-1): bs = paddle.shape(lengths)[0] if xs is None: - maxlen = lengths.max() + maxlen = paddle.cast(lengths.max(), dtype=bs.dtype) else: maxlen = paddle.shape(xs)[length_dim] seq_range = paddle.arange(0, maxlen, dtype=paddle.int64) + # VITS 最后一个 expand 的位置 seq_range_expand = seq_range.unsqueeze(0).expand([bs, maxlen]) seq_length_expand = lengths.unsqueeze(-1) mask = seq_range_expand >= seq_length_expand.cast(seq_range_expand.dtype) if xs is not None: assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs) - if length_dim < 0: length_dim = len(paddle.shape(xs)) + length_dim # ind = (:, None, ..., None, :, , None, ..., None)