[Update] dim == 1

pull/3916/head
megemini 10 months ago
parent 3aa4e7bfa1
commit 075fcc529e

@ -841,9 +841,10 @@ class FastSpeech2(nn.Layer):
spk_emb = self.spk_projection(F.normalize(spk_emb))
hs = hs + spk_emb.unsqueeze(1)
elif self.spk_embed_integration_type == "concat":
# concat hidden states with spk embeds and then apply projection
if spk_emb.dim() < 2:
# when synthesizing a single waveform, `spk_emb` is 1-D, so add a leading axis
if spk_emb.dim() == 1:
spk_emb = spk_emb.unsqueeze(0)
# concat hidden states with spk embeds and then apply projection
spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
shape=[-1, paddle.shape(hs)[1], -1])
hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))

Loading…
Cancel
Save