From b684a320c27c4b281cbf45ad3701242b957f1ea5 Mon Sep 17 00:00:00 2001 From: WongLaw Date: Mon, 9 Jan 2023 02:08:59 +0000 Subject: [PATCH] VITS initialize method 2, test=tts --- paddlespeech/t2s/models/vits/vits.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py index 6c85a44ee..3898a02b5 100644 --- a/paddlespeech/t2s/models/vits/vits.py +++ b/paddlespeech/t2s/models/vits/vits.py @@ -29,7 +29,7 @@ from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator from paddlespeech.t2s.models.vits.generator import VITSGenerator from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out -from paddlespeech.utils.initialize import kaiming_normal_ +from paddlespeech.utils.initialize import kaiming_uniform_ from paddlespeech.utils.initialize import normal_ from paddlespeech.utils.initialize import ones_ from paddlespeech.utils.initialize import uniform_ @@ -510,14 +510,14 @@ class VITS(nn.Layer): def _reset_parameters(module): if isinstance(module, nn.Conv1D) or isinstance(module, nn.Conv1DTranspose): - kaiming_normal_(module.weight, mode="fan_out") + kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out") if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 uniform_(module.bias, -bound, bound) if isinstance(module, nn.Conv2D) or isinstance(module, nn.Conv2DTranspose): - kaiming_normal_(module.weight, mode="fan_out") + kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out") if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 @@ -527,7 +527,7 @@ class VITS(nn.Layer): ones_(module.weight) zeros_(module.bias) if isinstance(module, nn.Linear): - kaiming_normal_(module.weight, a=math.sqrt(5)) + kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0