vits, test=tts

pull/2809/head
WongLaw 3 years ago
parent ca8432f69c
commit 7890164818

@ -29,7 +29,7 @@ from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
from paddlespeech.t2s.models.vits.generator import VITSGenerator
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_normal_
from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import normal_
from paddlespeech.utils.initialize import ones_
from paddlespeech.utils.initialize import uniform_
@ -221,6 +221,8 @@ class VITS(nn.Layer):
self.reuse_cache_dis = True
self.reset_parameters()
self.generator.decoder.reset_parameters()
# print("VITS===============")
def forward(
self,
@ -510,14 +512,14 @@ class VITS(nn.Layer):
def _reset_parameters(module):
if isinstance(module, nn.Conv1D) or isinstance(module,
nn.Conv1DTranspose):
kaiming_normal_(module.weight, mode="fan_out")
kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out")
if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(module.bias, -bound, bound)
if isinstance(module, nn.Conv2D) or isinstance(module,
nn.Conv2DTranspose):
kaiming_normal_(module.weight, mode="fan_out")
kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out")
if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
@ -527,7 +529,7 @@ class VITS(nn.Layer):
ones_(module.weight)
zeros_(module.bias)
if isinstance(module, nn.Linear):
kaiming_normal_(module.weight, a=math.sqrt(5))
kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0

Loading…
Cancel
Save