VITS initialize method 2, test=tts

pull/2802/head
WongLaw 3 years ago
parent ca8432f69c
commit b684a320c2

@ -29,7 +29,7 @@ from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
from paddlespeech.t2s.models.vits.generator import VITSGenerator from paddlespeech.t2s.models.vits.generator import VITSGenerator
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_normal_ from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import normal_ from paddlespeech.utils.initialize import normal_
from paddlespeech.utils.initialize import ones_ from paddlespeech.utils.initialize import ones_
from paddlespeech.utils.initialize import uniform_ from paddlespeech.utils.initialize import uniform_
@ -510,14 +510,14 @@ class VITS(nn.Layer):
def _reset_parameters(module): def _reset_parameters(module):
if isinstance(module, nn.Conv1D) or isinstance(module, if isinstance(module, nn.Conv1D) or isinstance(module,
nn.Conv1DTranspose): nn.Conv1DTranspose):
kaiming_normal_(module.weight, mode="fan_out") kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out")
if module.bias is not None: if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(module.bias, -bound, bound) uniform_(module.bias, -bound, bound)
if isinstance(module, nn.Conv2D) or isinstance(module, if isinstance(module, nn.Conv2D) or isinstance(module,
nn.Conv2DTranspose): nn.Conv2DTranspose):
kaiming_normal_(module.weight, mode="fan_out") kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out")
if module.bias is not None: if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
@ -527,7 +527,7 @@ class VITS(nn.Layer):
ones_(module.weight) ones_(module.weight)
zeros_(module.bias) zeros_(module.bias)
if isinstance(module, nn.Linear): if isinstance(module, nn.Linear):
kaiming_normal_(module.weight, a=math.sqrt(5)) kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None: if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0

Loading…
Cancel
Save