diff --git a/paddlespeech/t2s/models/vits/text_encoder.py b/paddlespeech/t2s/models/vits/text_encoder.py
index 799e0c759..033472eb0 100644
--- a/paddlespeech/t2s/models/vits/text_encoder.py
+++ b/paddlespeech/t2s/models/vits/text_encoder.py
@@ -24,6 +24,7 @@ from paddle import nn
 
 from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
 from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
+from paddlespeech.utils.initialize import normal_
 
 
 class TextEncoder(nn.Layer):
@@ -166,3 +167,9 @@ class TextEncoder(nn.Layer):
         m, logs = paddle.split(stats, 2, axis=1)
 
         return x, m, logs, x_mask
+
+    def reset_parameters(self):
+        normal_(self.emb.weight)
+        if self.emb._padding_idx is not None:
+            with paddle.no_grad():
+                self.emb.weight[self.emb._padding_idx] = 0
\ No newline at end of file
diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py
index afe1e8dd5..ceac7c120 100644
--- a/paddlespeech/t2s/models/vits/vits.py
+++ b/paddlespeech/t2s/models/vits/vits.py
@@ -222,6 +222,7 @@ class VITS(nn.Layer):
 
         self.reset_parameters()
         self.generator.decoder.reset_parameters()
+        self.generator.text_encoder.reset_parameters()
         # print("VITS===============")
 
     def forward(
@@ -510,10 +511,8 @@ class VITS(nn.Layer):
 
     def reset_parameters(self):
         def _reset_parameters(module):
-            if isinstance(module, nn.Conv1D)
-                    or isinstance(module, nn.Conv1DTranspose)
-                    or isinstance(module, nn.Conv2D)
-                    or isinstance(module, nn.Conv2DTranspose):
+            if isinstance(module,
+                          (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
                 kaiming_uniform_(module.weight, a=math.sqrt(5))
                 if module.bias is not None:
                     fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
@@ -522,7 +521,7 @@
                     uniform_(module.bias, -bound, bound)
 
             if isinstance(module,
-                          (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
+                          (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
                 ones_(module.weight)
                 zeros_(module.bias)
 
@@ -539,8 +538,4 @@ class VITS(nn.Layer):
                 with paddle.no_grad():
                     module.weight[module._padding_idx] = 0
 
-            if isinstance(module, nn.LayerNorm):
-                ones_(module.weight)
-                zeros_(module.bias)
-
         self.apply(_reset_parameters)
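
For reference, a minimal standalone sketch of the embedding re-initialization this patch adds in TextEncoder.reset_parameters. The toy vocabulary/dimension sizes and the free-standing `emb` variable are illustrative only; `normal_` is the helper the diff imports from paddlespeech.utils.initialize.

    import paddle
    from paddle import nn
    from paddlespeech.utils.initialize import normal_

    # Toy embedding standing in for TextEncoder.emb: 5 symbols, dim 8,
    # with symbol 0 reserved as padding (sizes are illustrative, not
    # taken from the patch).
    emb = nn.Embedding(5, 8, padding_idx=0)

    # Same steps as the new reset_parameters(): draw all weights from a
    # standard normal, then zero the padding row so the padding symbol
    # always embeds to an all-zero vector.
    normal_(emb.weight)
    if emb._padding_idx is not None:
        with paddle.no_grad():
            emb.weight[emb._padding_idx] = 0

    print(emb.weight[0])  # row for the padding symbol is all zeros

The explicit re-zeroing is needed because `normal_` overwrites the whole weight matrix, including the padding row that `padding_idx` had zeroed at construction time.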