adapt paddle2.5, test=tts

pull/3203/head
liangym 2 years ago
parent 7cab869d63
commit 45da88f0b8
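Review note: every hunk in this commit applies the same one-line pattern. Paddle 2.5 introduced full 0-D (scalar) tensor support, so integer indexing such as paddle.shape(x)[0] now yields a 0-D tensor where earlier releases returned a 1-D tensor of shape [1]. The call sites below feed these values into reshape shapes, zeros/randn shapes, and length arguments, often under dynamic-to-static export, which expects the old 1-D form; each site therefore switches from indexing to slicing ([0] -> [0:1], [1] -> [1:2], and so on). A minimal sketch of the difference, assuming Paddle >= 2.5:

import paddle

x = paddle.randn([4, 100, 80])
s = paddle.shape(x)    # 1-D int32 tensor holding [4, 100, 80]
dim0 = s[0]            # Paddle 2.5: 0-D tensor (shape []); pre-2.5: shape [1]
dim0_1d = s[0:1]       # 1-D tensor of shape [1] on all versions

print(dim0.shape, dim0_1d.shape)  # [] [1]
# Both forms work in eager mode; the slice keeps the 1-D form that
# shape lists like paddle.zeros([dim0_1d, 5]) relied on before 2.5.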

@@ -308,7 +308,7 @@ class FastSpeech2MIDI(FastSpeech2):
 zs, _ = self.decoder(hs, h_masks)
 # (B, Lmax, odim)
 before_outs = self.feat_out(zs).reshape(
-(paddle.shape(zs)[0], -1, self.odim))
+(paddle.shape(zs)[0:1], -1, self.odim))
 # postnet -> (B, Lmax//r * r, odim)
 if self.postnet is None:
@@ -334,7 +334,7 @@ class FastSpeech2MIDI(FastSpeech2):
 note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
 is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
 # setup batch axis
-ilens = paddle.shape(xs)[1]
+ilens = paddle.shape(xs)[1:2]
 if spk_emb is not None:
 spk_emb = spk_emb.unsqueeze(0)
@@ -449,7 +449,7 @@ class FastSpeech2MIDI(FastSpeech2):
 is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
 d, p, e = durations, pitch, energy
 # setup batch axis
-ilens = paddle.shape(xs)[1]
+ilens = paddle.shape(xs)[1:2]
 if spk_emb is not None:
 spk_emb = spk_emb.unsqueeze(0)

@@ -485,11 +485,11 @@ class MLMEncAsDecoder(MLM):
 zs, _ = self.decoder(encoder_out, h_masks)
 else:
 zs = encoder_out
-speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
+speech_hidden_states = zs[:, :paddle.shape(speech)[1:2], :]
 if self.sfc is not None:
 before_outs = paddle.reshape(
 self.sfc(speech_hidden_states),
-(paddle.shape(speech_hidden_states)[0], -1, self.odim))
+(paddle.shape(speech_hidden_states)[0:1], -1, self.odim))
 else:
 before_outs = speech_hidden_states
 if self.postnet is not None:
@@ -524,16 +524,16 @@ class MLMDualMaksing(MLM):
 zs, _ = self.decoder(encoder_out, h_masks)
 else:
 zs = encoder_out
-speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
+speech_hidden_states = zs[:, :paddle.shape(speech)[1:2], :]
 if self.text_sfc:
-text_hiddent_states = zs[:, paddle.shape(speech)[1]:, :]
+text_hiddent_states = zs[:, paddle.shape(speech)[1:2]:, :]
 text_outs = paddle.reshape(
 self.text_sfc(text_hiddent_states),
-(paddle.shape(text_hiddent_states)[0], -1, self.vocab_size))
+(paddle.shape(text_hiddent_states)[0:1], -1, self.vocab_size))
 if self.sfc is not None:
 before_outs = paddle.reshape(
 self.sfc(speech_hidden_states),
-(paddle.shape(speech_hidden_states)[0], -1, self.odim))
+(paddle.shape(speech_hidden_states)[0:1], -1, self.odim))
 else:
 before_outs = speech_hidden_states
 if self.postnet is not None:

@@ -696,7 +696,7 @@ class FastSpeech2(nn.Layer):
 zs, _ = self.decoder(hs, h_masks)
 # (B, Lmax, odim)
 before_outs = self.feat_out(zs).reshape(
-(paddle.shape(zs)[0], -1, self.odim))
+(paddle.shape(zs)[0:1], -1, self.odim))
 # postnet -> (B, Lmax//r * r, odim)
 if self.postnet is None:
@@ -718,7 +718,7 @@ class FastSpeech2(nn.Layer):
 # input of embedding must be int64
 x = paddle.cast(text, 'int64')
 # setup batch axis
-ilens = paddle.shape(x)[0]
+ilens = paddle.shape(x)[0:1]
 xs = x.unsqueeze(0)
@@ -783,7 +783,7 @@ class FastSpeech2(nn.Layer):
 x = paddle.cast(text, 'int64')
 d, p, e = durations, pitch, energy
 # setup batch axis
-ilens = paddle.shape(x)[0]
+ilens = paddle.shape(x)[0:1]
 xs = x.unsqueeze(0)
@@ -843,7 +843,7 @@ class FastSpeech2(nn.Layer):
 elif self.spk_embed_integration_type == "concat":
 # concat hidden states with spk embeds and then apply projection
 spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
-shape=[-1, paddle.shape(hs)[1], -1])
+shape=[-1, paddle.shape(hs)[1:2], -1])
 hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))
 else:
 raise NotImplementedError("support only add or concat.")
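The FastSpeech2 hunks above all sit on inference paths that paddle.jit.to_static traces for deployment, which is where a 0-D batch dim breaks. A hedged sketch of why the sliced form stays exportable; the TinyDecoder layer here is illustrative, not PaddleSpeech code:

import paddle
import paddle.nn as nn

class TinyDecoder(nn.Layer):
    def __init__(self, odim=80):
        super().__init__()
        self.odim = odim
        self.feat_out = nn.Linear(16, odim)

    def forward(self, zs):
        # Slicing keeps the batch dim as a 1-D tensor, so the reshape
        # stays symbolic after dynamic-to-static conversion.
        return self.feat_out(zs).reshape(
            (paddle.shape(zs)[0:1], -1, self.odim))

layer = paddle.jit.to_static(
    TinyDecoder(),
    input_spec=[paddle.static.InputSpec([None, None, 16], 'float32')])
print(layer(paddle.randn([2, 7, 16])).shape)  # [2, 7, 80]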

@@ -512,13 +512,13 @@ class JETS(nn.Layer):
 """
 # setup
 text = text[None]
-text_lengths = paddle.to_tensor(paddle.shape(text)[1])
+text_lengths = paddle.to_tensor(paddle.shape(text)[1:2])
 # inference
 if use_alignment_module:
 assert feats is not None
 feats = feats[None]
-feats_lengths = paddle.to_tensor(paddle.shape(feats)[1])
+feats_lengths = paddle.to_tensor(paddle.shape(feats)[1:2])
 pitch = pitch[None]
 energy = energy[None]
 wav, dur = self.generator.inference(

@@ -156,7 +156,7 @@ class StyleMelGANGenerator(nn.Layer):
 """
 # batch_max_steps(24000) == noise_upsample_factor(80) * upsample_factor(300)
 if z is None:
-z = paddle.randn([paddle.shape(c)[0], self.in_channels, 1])
+z = paddle.randn([paddle.shape(c)[0:1], self.in_channels, 1])
 # (B, in_channels, noise_upsample_factor).
 x = self.noise_upsample(z)
 for block in self.blocks:

@@ -223,7 +223,7 @@ class PWGGenerator(nn.Layer):
 """
 # when to static, can not input x, see https://github.com/PaddlePaddle/Parakeet/pull/132/files
 x = paddle.randn(
-[1, self.in_channels, paddle.shape(c)[0] * self.upsample_factor])
+[1, self.in_channels, paddle.shape(c)[0:1] * self.upsample_factor])
 c = paddle.transpose(c, [1, 0]).unsqueeze(0) # pseudo batch
 c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
 out = self(x, c).squeeze(0).transpose([1, 0])

@@ -286,7 +286,7 @@ class Tacotron2(nn.Layer):
 text = text[:, :text_lengths.max()]
 speech = speech[:, :speech_lengths.max()]
-batch_size = paddle.shape(text)[0]
+batch_size = paddle.shape(text)[0:1]
 # Add eos at the last of sequence
 xs = F.pad(text, [0, 0, 0, 1], "constant", self.padding_idx)
@@ -413,8 +413,8 @@ class Tacotron2(nn.Layer):
 xs, ys = x.unsqueeze(0), y.unsqueeze(0)
 spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0)
-ilens = paddle.shape(xs)[1]
-olens = paddle.shape(ys)[1]
+ilens = paddle.shape(xs)[1:2]
+olens = paddle.shape(ys)[1:2]
 outs, _, _, att_ws = self._forward(
 xs=xs,
 ilens=ilens,
@@ -470,7 +470,7 @@ class Tacotron2(nn.Layer):
 elif self.spk_embed_integration_type == "concat":
 # concat hidden states with spk embeds
 spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
-shape=[-1, paddle.shape(hs)[1], -1])
+shape=[-1, paddle.shape(hs)[1:2], -1])
 hs = paddle.concat([hs, spk_emb], axis=-1)
 else:
 raise NotImplementedError("support only add or concat.")

@@ -145,7 +145,7 @@ class StochasticDurationPredictor(nn.Layer):
 h_w = self.post_pre(w)
 h_w = self.post_dds(h_w, x_mask)
 h_w = self.post_proj(h_w) * x_mask
-e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) *
+e_q = (paddle.randn([paddle.shape(w)[0:1], 2, paddle.shape(w)[2]]) *
 x_mask)
 z_q = e_q
 logdet_tot_q = 0.0
@@ -174,7 +174,7 @@ class StochasticDurationPredictor(nn.Layer):
 flows = list(reversed(self.flows))
 # remove a useless vflow
 flows = flows[:-2] + [flows[-1]]
-z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) *
+z = (paddle.randn([paddle.shape(x)[0:1], 2, paddle.shape(x)[2]]) *
 noise_scale)
 for flow in flows:
 z = flow(z, x_mask, g=x, inverse=inverse)

@@ -46,7 +46,7 @@ class FlipFlow(nn.Layer):
 """
 x = paddle.flip(x, [1])
 if not inverse:
-logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype)
+logdet = paddle.zeros(paddle.shape(x)[0:1], dtype=x.dtype)
 return x, logdet
 else:
 return x
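In the FlipFlow hunk above, the slice is not one element of a shape list but the entire shape argument: paddle.zeros(paddle.shape(x)[0:1], ...) builds a 1-D tensor of length batch_size. A small sketch with a hypothetical input, assuming Paddle >= 2.5:

import paddle

x = paddle.randn([3, 2, 50])

# [0:1] passes a valid 1-D shape tensor meaning "a [3] tensor".
logdet = paddle.zeros(paddle.shape(x)[0:1], dtype=x.dtype)
print(logdet.shape)  # [3]
# With [0], Paddle 2.5 would hand paddle.zeros a 0-D tensor as the whole
# shape, which is presumably what this commit is working around.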

@@ -420,7 +420,7 @@ class VITS(nn.Layer):
 """
 # setup
 text = text[None]
-text_lengths = paddle.to_tensor(paddle.shape(text)[1])
+text_lengths = paddle.to_tensor(paddle.shape(text)[1:2])
 if durations is not None:
 durations = paddle.reshape(durations, [1, 1, -1])
@@ -429,7 +429,7 @@ class VITS(nn.Layer):
 if use_teacher_forcing:
 assert feats is not None
 feats = feats[None].transpose([0, 2, 1])
-feats_lengths = paddle.to_tensor(paddle.shape(feats)[2])
+feats_lengths = paddle.to_tensor(paddle.shape(feats)[2:3])
 wav, att_w, dur = self.generator.inference(
 text=text,
 text_lengths=text_lengths,
@@ -484,7 +484,7 @@ class VITS(nn.Layer):
 """
 assert feats is not None
 feats = feats[None].transpose([0, 2, 1])
-feats_lengths = paddle.to_tensor(paddle.shape(feats)[2])
+feats_lengths = paddle.to_tensor(paddle.shape(feats)[2:3])
 sids_none = sids_src is None and sids_tgt is None
 spembs_none = spembs_src is None and spembs_tgt is None

@@ -255,7 +255,7 @@ class WaveRNN(nn.Layer):
 # weights are contiguous in GPU memory. Hence, we must call it again
 self._flatten_parameters()
-bsize = paddle.shape(x)[0]
+bsize = paddle.shape(x)[0:1]
 h1 = paddle.zeros([1, bsize, self.rnn_dims])
 h2 = paddle.zeros([1, bsize, self.rnn_dims])
 # c: [B, T, C_aux]
@@ -339,8 +339,8 @@ class WaveRNN(nn.Layer):
 # will not get TensorArray
 # see https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/04_dygraph_to_static/case_analysis_cn.html#list-lodtensorarray
 # b_size, seq_len, _ = paddle.shape(c)
-b_size = paddle.shape(c)[0]
-seq_len = paddle.shape(c)[1]
+b_size = paddle.shape(c)[0:1]
+seq_len = paddle.shape(c)[1:2]
 h1 = paddle.zeros([b_size, self.rnn_dims])
 h2 = paddle.zeros([b_size, self.rnn_dims])

@@ -433,8 +433,8 @@ class Tacotron2Loss(nn.Layer):
 masks = make_non_pad_mask(olens).unsqueeze(-1)
 weights = masks.float() / masks.sum(axis=1, keepdim=True).float()
 out_weights = weights.divide(
-paddle.shape(ys)[0] * paddle.shape(ys)[2])
-logit_weights = weights.divide(paddle.shape(ys)[0])
+paddle.shape(ys)[0:1] * paddle.shape(ys)[2:3])
+logit_weights = weights.divide(paddle.shape(ys)[0:1])
 # apply weight
 l1_loss = l1_loss.multiply(out_weights)
@@ -907,7 +907,7 @@ class MelSpectrogram(nn.Layer):
 """
 if len(x.shape) == 3:
 # (B, C, T) -> (B*C, T)
-x = x.reshape([-1, paddle.shape(x)[2]])
+x = x.reshape([-1, paddle.shape(x)[2:3]])
 if self.window is not None:
 # calculate window

@@ -181,11 +181,11 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
 if length_dim == 0:
 raise ValueError("length_dim cannot be 0: {}".format(length_dim))
-bs = paddle.shape(lengths)[0]
+bs = paddle.shape(lengths)[0:1]
 if xs is None:
 maxlen = paddle.cast(lengths.max(), dtype=bs.dtype)
 else:
-maxlen = paddle.shape(xs)[length_dim]
+maxlen = paddle.shape(xs)[length_dim:length_dim+1]
 seq_range = paddle.arange(0, maxlen, dtype=paddle.int64)
 # position of the last expand in VITS
@@ -194,7 +194,7 @@ def make_pad_mask(lengths, xs=None, length_dim=-1):
 mask = seq_range_expand >= seq_length_expand.cast(seq_range_expand.dtype)
 if xs is not None:
-assert paddle.shape(xs)[0] == bs, (paddle.shape(xs)[0], bs)
+assert paddle.shape(xs)[0:1] == bs, (paddle.shape(xs)[0:1], bs)
 if length_dim < 0:
 length_dim = len(paddle.shape(xs)) + length_dim
 # ind = (:, None, ..., None, :, , None, ..., None)
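make_pad_mask generalizes the pattern to a runtime-chosen axis: [length_dim:length_dim+1] is the slice form of [length_dim], and the batch-size assert now compares two one-element tensors. A standalone sketch, assuming Paddle >= 2.5:

import paddle

xs = paddle.randn([4, 9, 80])
length_dim = 1

bs = paddle.shape(xs)[0:1]                            # 1-D tensor: [4]
maxlen = paddle.shape(xs)[length_dim:length_dim + 1]  # 1-D tensor: [9]

# Comparing two shape-[1] tensors gives a shape-[1] bool tensor, and
# Python's assert reads its single element as the truth value.
assert paddle.shape(xs)[0:1] == bs
print(bs, maxlen)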

@@ -51,8 +51,8 @@ class LengthRegulator(nn.Layer):
 durations: (B, T)
 """
 #batch_size, t_enc = durations.shape # linux
-batch_size = paddle.shape(durations)[0] # windows and mac
-t_enc = paddle.shape(durations)[1] # windows and mac
+batch_size = paddle.shape(durations)[0:1] # windows and mac
+t_enc = paddle.shape(durations)[1:2] # windows and mac
 durations = durations.numpy()
 slens = np.sum(durations, -1)
 t_dec = np.max(slens)

@@ -56,7 +56,7 @@ def _apply_attention_constraint(e,
 forward_idx = paddle.cast(last_attended_idx + forward_window, dtype='int64')
 if backward_idx > 0:
 e[:, :backward_idx] = -float("inf")
-if forward_idx < paddle.shape(e)[1]:
+if forward_idx < paddle.shape(e)[1:2]:
 e[:, forward_idx:] = -float("inf")
 return e
@@ -153,12 +153,12 @@ class AttLoc(nn.Layer):
 Tensor:
 previous attention weights (B, T_max)
 """
-batch = paddle.shape(enc_hs_pad)[0]
+batch = paddle.shape(enc_hs_pad)[0:1]
 # pre-compute all h outside the decoder loop
 if self.pre_compute_enc_h is None or self.han_mode:
 # (utt, frame, hdim)
 self.enc_h = enc_hs_pad
-self.h_length = paddle.shape(self.enc_h)[1]
+self.h_length = paddle.shape(self.enc_h)[1:2]
 # (utt, frame, att_dim)
 self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
@@ -294,7 +294,7 @@ class AttForward(nn.Layer):
 # pre-compute all h outside the decoder loop
 if self.pre_compute_enc_h is None:
 self.enc_h = enc_hs_pad # utt x frame x hdim
-self.h_length = paddle.shape(self.enc_h)[1]
+self.h_length = paddle.shape(self.enc_h)[1:2]
 # utt x frame x att_dim
 self.pre_compute_enc_h = self.mlp_enc(self.enc_h)
@@ -445,7 +445,7 @@ class AttForwardTA(nn.Layer):
 # pre-compute all h outside the decoder loop
 if self.pre_compute_enc_h is None:
 self.enc_h = enc_hs_pad # utt x frame x hdim
-self.h_length = paddle.shape(self.enc_h)[1]
+self.h_length = paddle.shape(self.enc_h)[1:2]
 # utt x frame x att_dim
 self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

@@ -375,7 +375,7 @@ class Decoder(nn.Layer):
 self.prob_out = nn.Linear(iunits, reduction_factor)
 def _zero_state(self, hs):
-init_hs = paddle.zeros([paddle.shape(hs)[0], self.lstm[0].hidden_size])
+init_hs = paddle.zeros([paddle.shape(hs)[0:1], self.lstm[0].hidden_size])
 return init_hs
 def forward(self, hs, hlens, ys):
@@ -415,7 +415,7 @@ class Decoder(nn.Layer):
 for _ in range(1, len(self.lstm)):
 c_list.append(self._zero_state(hs))
 z_list.append(self._zero_state(hs))
-prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])
+prev_out = paddle.zeros([paddle.shape(hs)[0:1], self.odim])
 # initialize attention
 prev_att_ws = []
@@ -445,7 +445,7 @@ class Decoder(nn.Layer):
 zcs = (paddle.concat([z_list[-1], att_c], axis=1)
 if self.use_concate else z_list[-1])
 outs.append(
-self.feat_out(zcs).reshape([paddle.shape(hs)[0], self.odim, -1
+self.feat_out(zcs).reshape([paddle.shape(hs)[0:1], self.odim, -1
 ]))
 logits.append(self.prob_out(zcs))
 att_ws.append(att_w)
@@ -466,7 +466,7 @@ class Decoder(nn.Layer):
 if self.reduction_factor > 1:
 # (B, odim, Lmax)
 before_outs = before_outs.reshape(
-[paddle.shape(before_outs)[0], self.odim, -1])
+[paddle.shape(before_outs)[0:1], self.odim, -1])
 if self.postnet is not None:
 # (B, odim, Lmax)
@@ -530,10 +530,10 @@ class Decoder(nn.Layer):
 assert len(paddle.shape(h)) == 2
 hs = h.unsqueeze(0)
-ilens = paddle.shape(h)[0]
+ilens = paddle.shape(h)[0:1]
 # maxlen and minlen were originally wrapped in int(); removed here to avoid dynamic-to-static issues
-maxlen = paddle.shape(h)[0] * maxlenratio
-minlen = paddle.shape(h)[0] * minlenratio
+maxlen = paddle.shape(h)[0:1] * maxlenratio
+minlen = paddle.shape(h)[0:1] * minlenratio
 # threshold was originally used directly; converted to a tensor here to avoid dynamic-to-static issues
 threshold = paddle.ones([1]) * threshold
@@ -690,7 +690,7 @@ class Decoder(nn.Layer):
 for _ in range(1, len(self.lstm)):
 c_list.append(self._zero_state(hs))
 z_list.append(self._zero_state(hs))
-prev_out = paddle.zeros([paddle.shape(hs)[0], self.odim])
+prev_out = paddle.zeros([paddle.shape(hs)[0:1], self.odim])
 # initialize attention
 prev_att_w = None

@@ -184,6 +184,6 @@ class Encoder(nn.Layer):
 """
 xs = x.unsqueeze(0)
-ilens = paddle.shape(x)[0]
+ilens = paddle.shape(x)[0:1]
 return self.forward(xs, ilens)[0][0]

@@ -66,7 +66,7 @@ class MultiHeadedAttention(nn.Layer):
 Tensor:
 Transformed value tensor (#batch, n_head, time2, d_k).
 """
-n_batch = paddle.shape(query)[0]
+n_batch = paddle.shape(query)[0:1]
 q = paddle.reshape(
 self.linear_q(query), [n_batch, -1, self.h, self.d_k])
@@ -96,7 +96,7 @@ class MultiHeadedAttention(nn.Layer):
 Returns:
 Tensor: Transformed value (#batch, time1, d_model) weighted by the attention score (#batch, time1, time2).
 """
-n_batch = paddle.shape(value)[0]
+n_batch = paddle.shape(value)[0:1]
 softmax = paddle.nn.Softmax(axis=-1)
 if mask is not None:
 mask = mask.unsqueeze(1)
@@ -220,7 +220,7 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
 q, k, v = self.forward_qkv(query, key, value)
 # (batch, time1, head, d_k)
 q = q.transpose([0, 2, 1, 3])
-n_batch_pos = paddle.shape(pos_emb)[0]
+n_batch_pos = paddle.shape(pos_emb)[0:1]
 p = self.linear_pos(pos_emb).reshape(
 [n_batch_pos, -1, self.h, self.d_k])
 # (batch, head, 2*time1-1, d_k)
@@ -318,7 +318,7 @@ class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
 # (batch, time1, head, d_k)
 q = paddle.transpose(q, [0, 2, 1, 3])
-n_batch_pos = paddle.shape(pos_emb)[0]
+n_batch_pos = paddle.shape(pos_emb)[0:1]
 p = paddle.reshape(
 self.linear_pos(pos_emb), [n_batch_pos, -1, self.h, self.d_k])
 # (batch, head, time1, d_k)

@@ -80,7 +80,7 @@ class PositionalEncoding(nn.Layer):
 Tensor: Encoded tensor (batch, time, `*`).
 """
 self.extend_pe(x)
-T = paddle.shape(x)[1]
+T = paddle.shape(x)[1:2]
 x = x * self.xscale + self.pe[:, :T]
 return self.dropout(x)
@@ -127,7 +127,7 @@ class ScaledPositionalEncoding(PositionalEncoding):
 Tensor: Encoded tensor (batch, time, `*`).
 """
 self.extend_pe(x)
-T = paddle.shape(x)[1]
+T = paddle.shape(x)[1:2]
 x = x + self.alpha * self.pe[:, :T]
 return self.dropout(x)
@@ -161,7 +161,7 @@ class RelPositionalEncoding(nn.Layer):
 if self.pe is not None:
 # self.pe contains both positive and negative parts
 # the length of self.pe is 2 * input_len - 1
-if paddle.shape(self.pe)[1] >= paddle.shape(x)[1] * 2 - 1:
+if paddle.shape(self.pe)[1:2] >= paddle.shape(x)[1:2] * 2 - 1:
 return
 # Suppose `i` means to the position of query vecotr and `j` means the
 # position of key vector. We use position relative positions when keys
@@ -196,7 +196,7 @@ class RelPositionalEncoding(nn.Layer):
 """
 self.extend_pe(x)
 x = x * self.xscale
-T = paddle.shape(x)[1]
+T = paddle.shape(x)[1:2]
 pe_size = paddle.shape(self.pe)
 tmp = paddle.cast(paddle.floor(pe_size[1] / 2), dtype='int32')
 pos_emb = self.pe[:, tmp - T + 1:tmp + T, ]
@@ -235,16 +235,16 @@ class LegacyRelPositionalEncoding(PositionalEncoding):
 def extend_pe(self, x):
 """Reset the positional encodings."""
 if self.pe is not None:
-if paddle.shape(self.pe)[1] >= paddle.shape(x)[1]:
+if paddle.shape(self.pe)[1:2] >= paddle.shape(x)[1:2]:
 return
-pe = paddle.zeros((paddle.shape(x)[1], self.d_model))
+pe = paddle.zeros((paddle.shape(x)[1:2], self.d_model))
 if self.reverse:
 position = paddle.arange(
-paddle.shape(x)[1] - 1, -1, -1.0,
+paddle.shape(x)[1:2] - 1, -1, -1.0,
 dtype=paddle.float32).unsqueeze(1)
 else:
 position = paddle.arange(
-0, paddle.shape(x)[1], dtype=paddle.float32).unsqueeze(1)
+0, paddle.shape(x)[1:2], dtype=paddle.float32).unsqueeze(1)
 div_term = paddle.exp(
 paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
 -(math.log(10000.0) / self.d_model))
@@ -266,5 +266,5 @@ class LegacyRelPositionalEncoding(PositionalEncoding):
 """
 self.extend_pe(x)
 x = x * self.xscale
-pos_emb = self.pe[:, :paddle.shape(x)[1]]
+pos_emb = self.pe[:, :paddle.shape(x)[1:2]]
 return self.dropout(x), self.dropout(pos_emb)
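The positional-encoding hunks use the sliced length both in arithmetic (paddle.shape(x)[1:2] * 2 - 1) and as a slice bound (self.pe[:, :T]). Paddle accepts a one-element tensor as a slice bound, which keeps these lines traceable. A brief sketch with a toy pe table, assuming Paddle >= 2.5:

import paddle

pe = paddle.randn([1, 100, 8])  # toy positional table
x = paddle.randn([2, 37, 8])

T = paddle.shape(x)[1:2]        # 1-D tensor: [37]
pos_emb = pe[:, :T]             # tensor as slice bound
print(pos_emb.shape)            # [1, 37, 8]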
