fix vits with CSMSC (#3920)

pull/3923/head
张春乔 3 weeks ago committed by GitHub
parent 890c87ea93
commit c33d9bfb50
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -577,8 +577,9 @@ class VITSGenerator(nn.Layer):
# decoder
z_p = m_p + paddle.randn(
paddle.shape(m_p)) * paddle.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, inverse=True)
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
z = self.flow(z_p, y_mask.astype(z_p.dtype), g=g, inverse=True)
wav = self.decoder(
(z * y_mask.astype(z.dtype))[:, :, :max_len], g=g)
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
@ -695,4 +696,5 @@ class VITSGenerator(nn.Layer):
path = paddle.cast(path, dtype='float32')
pad_tmp = self.pad1d(path)[:, :-1]
path = path - pad_tmp
return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask
return path.unsqueeze(1).transpose(
[0, 1, 3, 2]) * mask.astype(path.dtype)

@ -129,6 +129,7 @@ class PosteriorEncoder(nn.Layer):
"""
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
x_mask = x_mask.astype(x.dtype)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask

@ -155,6 +155,7 @@ class TextEncoder(nn.Layer):
"""
x = self.emb(x) * math.sqrt(self.attention_dim)
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
x_mask = x_mask.astype(x.dtype)
# encoder assume the channel last (B, T_text, attention_dim)
# but mask shape shoud be (B, 1, T_text)
x, _ = self.encoder(x, x_mask)

@ -181,6 +181,10 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
# check if ilens is 0-dim tensor, if so, add a dimension
if lengths.ndim == 0:
lengths = lengths.unsqueeze(0)
bs = paddle.shape(lengths)
if xs is None:
maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
@ -348,7 +352,9 @@ def get_random_segments(
"""
b, c, t = paddle.shape(x)
max_start_idx = x_lengths - segment_size
start_idxs = paddle.cast(paddle.rand([b]) * max_start_idx, 'int64')
rand_number = paddle.rand([b])
start_idxs = paddle.cast(rand_number *
max_start_idx.astype(rand_number.dtype), 'int64')
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs

Loading…
Cancel
Save