diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py index 6c85a44ee..6e9339899 100644 --- a/paddlespeech/t2s/models/vits/vits.py +++ b/paddlespeech/t2s/models/vits/vits.py @@ -29,7 +29,7 @@ from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator from paddlespeech.t2s.models.vits.generator import VITSGenerator from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out -from paddlespeech.utils.initialize import kaiming_normal_ +from paddlespeech.utils.initialize import kaiming_uniform_ from paddlespeech.utils.initialize import normal_ from paddlespeech.utils.initialize import ones_ from paddlespeech.utils.initialize import uniform_ @@ -221,6 +221,8 @@ class VITS(nn.Layer): self.reuse_cache_dis = True self.reset_parameters() + self.generator.decoder.reset_parameters() + # print("VITS===============") def forward( self, @@ -510,14 +512,14 @@ class VITS(nn.Layer): def _reset_parameters(module): if isinstance(module, nn.Conv1D) or isinstance(module, nn.Conv1DTranspose): - kaiming_normal_(module.weight, mode="fan_out") + kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out") if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 uniform_(module.bias, -bound, bound) if isinstance(module, nn.Conv2D) or isinstance(module, nn.Conv2DTranspose): - kaiming_normal_(module.weight, mode="fan_out") + kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out") if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 @@ -527,7 +529,7 @@ class VITS(nn.Layer): ones_(module.weight) zeros_(module.bias) if isinstance(module, nn.Linear): - kaiming_normal_(module.weight, a=math.sqrt(5)) + kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0