|
|
|
@ -222,6 +222,7 @@ class VITS(nn.Layer):
|
|
|
|
|
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
self.generator.decoder.reset_parameters()
|
|
|
|
|
self.generator.text_encoder.reset_parameters()
|
|
|
|
|
# print("VITS===============")
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
@ -510,10 +511,8 @@ class VITS(nn.Layer):
|
|
|
|
|
|
|
|
|
|
def reset_parameters(self):
|
|
|
|
|
def _reset_parameters(module):
|
|
|
|
|
if isinstance(module, nn.Conv1D)
|
|
|
|
|
or isinstance(module, nn.Conv1DTranspose)
|
|
|
|
|
or isinstance(module, nn.Conv2D)
|
|
|
|
|
or isinstance(module, nn.Conv2DTranspose):
|
|
|
|
|
if isinstance(module,
|
|
|
|
|
(nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
|
|
|
|
|
kaiming_uniform_(module.weight, a=math.sqrt(5))
|
|
|
|
|
if module.bias is not None:
|
|
|
|
|
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
|
|
|
|
@ -522,7 +521,7 @@ class VITS(nn.Layer):
|
|
|
|
|
uniform_(module.bias, -bound, bound)
|
|
|
|
|
|
|
|
|
|
if isinstance(module,
|
|
|
|
|
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
|
|
|
|
|
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
|
|
|
|
|
ones_(module.weight)
|
|
|
|
|
zeros_(module.bias)
|
|
|
|
|
|
|
|
|
@ -539,8 +538,4 @@ class VITS(nn.Layer):
|
|
|
|
|
with paddle.no_grad():
|
|
|
|
|
module.weight[module._padding_idx] = 0
|
|
|
|
|
|
|
|
|
|
if isinstance(module, nn.LayerNorm):
|
|
|
|
|
ones_(module.weight)
|
|
|
|
|
zeros_(module.bias)
|
|
|
|
|
|
|
|
|
|
self.apply(_reset_parameters)
|
|
|
|
|