diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index a95a9b288..3dd90b588 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -842,6 +842,8 @@ class FastSpeech2(nn.Layer): hs = hs + spk_emb.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection + if spk_emb.dim() < 2: + spk_emb = spk_emb.unsqueeze(0) spk_emb = F.normalize(spk_emb).unsqueeze(1).expand( shape=[-1, paddle.shape(hs)[1], -1]) hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1))