[Hackathon 7th] 修复 vctk 中 `spk_emb` 维度问题 (#3916)

* [Fix] vctk spk_emb dim

* [Update] dim == 1
pull/3894/merge
megemini 4 weeks ago committed by GitHub
parent 77dfdc439f
commit 3e53497a28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -841,6 +841,9 @@ class FastSpeech2(nn.Layer):
spk_emb = self.spk_projection(F.normalize(spk_emb)) spk_emb = self.spk_projection(F.normalize(spk_emb))
hs = hs + spk_emb.unsqueeze(1) hs = hs + spk_emb.unsqueeze(1)
elif self.spk_embed_integration_type == "concat": elif self.spk_embed_integration_type == "concat":
# when synthesizing a single wave, `spk_emb` is 1-D (`dim() == 1`), so add a leading axis
if spk_emb.dim() == 1:
spk_emb = spk_emb.unsqueeze(0)
# concat hidden states with spk embeds and then apply projection # concat hidden states with spk embeds and then apply projection
spk_emb = F.normalize(spk_emb).unsqueeze(1).expand( spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
shape=[-1, paddle.shape(hs)[1], -1]) shape=[-1, paddle.shape(hs)[1], -1])

Loading…
Cancel
Save