vits_init, test=tts

Branch: pull/2809/head
Author: WongLaw, 3 years ago
Parent: a4c80be48b
Commit: 8955e5efd3

@@ -520,74 +520,27 @@ class VITS(nn.Layer):
                     if fan_in != 0:
                         bound = 1 / math.sqrt(fan_in)
                         uniform_(module.bias, -bound, bound)
-            """
-            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-            if self.bias is not None:
-                fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
-                if fan_in != 0:
-                    bound = 1 / math.sqrt(fan_in)
-                    init.uniform_(self.bias, -bound, bound)
-            """
             if isinstance(module,
                           (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
                 ones_(module.weight)
                 zeros_(module.bias)
-            """
-            def reset_running_stats(self) -> None:
-                if self.track_running_stats:
-                    # running_mean/running_var/num_batches... are registered at runtime depending
-                    # on if self.track_running_stats is on
-                    self.running_mean.zero_()  # type: ignore[union-attr]
-                    self.running_var.fill_(1)  # type: ignore[union-attr]
-                    self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]
-            def reset_parameters(self) -> None:
-                self.reset_running_stats()
-                if self.affine:
-                    init.ones_(self.weight)
-                    init.zeros_(self.bias)
-            GroupNorm:
-            def reset_parameters(self) -> None:
-                if self.affine:
-                    init.ones_(self.weight)
-                    init.zeros_(self.bias)
-            """
             if isinstance(module, nn.Linear):
                 kaiming_uniform_(module.weight, a=math.sqrt(5))
                 if module.bias is not None:
                     fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
                     bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                     uniform_(module.bias, -bound, bound)
-            """
-            init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-            if self.bias is not None:
-                fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
-                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-                init.uniform_(self.bias, -bound, bound)
-            """
             if isinstance(module, nn.Embedding):
                 normal_(module.weight)
                 if module._padding_idx is not None:
                     with paddle.no_grad():
                         module.weight[module._padding_idx] = 0
-            """
-            def reset_parameters(self) -> None:
-                init.normal_(self.weight)
-                self._fill_padding_idx_with_zero()
-            def _fill_padding_idx_with_zero(self) -> None:
-                if self.padding_idx is not None:
-                    with torch.no_grad():
-                        self.weight[self.padding_idx].fill_(0)
-            """
             if isinstance(module, nn.LayerNorm):
                 ones_(module.weight)
                 zeros_(module.bias)
-            """
-            def reset_parameters(self) -> None:
-                if self.elementwise_affine:
-                    init.ones_(self.weight)
-                    init.zeros_(self.bias)
-            """
 
         self.apply(_reset_parameters)
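The bias bound kept in the hunk matches PyTorch's Linear/Conv defaults: with a=math.sqrt(5), kaiming_uniform_ draws weights from U(-b, b) with b = g * sqrt(3 / fan_in) and g = sqrt(2 / (1 + a^2)) = 1/sqrt(3), which also reduces to 1/sqrt(fan_in). A minimal sanity check, not part of this commit (the layer sizes are made up), reproducing the bias rule with plain Paddle ops:

import math

import paddle

# Hypothetical sizes; fan_in of an nn.Linear is its input width.
fan_in, fan_out = 80, 192
layer = paddle.nn.Linear(fan_in, fan_out)

# Same bias rule as in the diff: U(-bound, bound) with bound = 1/sqrt(fan_in).
bound = 1 / math.sqrt(fan_in)
layer.bias.set_value(paddle.uniform(layer.bias.shape, min=-bound, max=bound))

# Every bias entry now lies inside the bound.
assert float(layer.bias.abs().max()) <= bound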

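The Embedding branch needs the manual zeroing because normal_ re-randomizes the whole table, padding row included. A small standalone sketch of that step (toy sizes; paddle.normal stands in for the repo's normal_ helper, and _padding_idx is the same private attribute the hunk reads):

import paddle

# Toy table; row 0 is the padding entry.
emb = paddle.nn.Embedding(10, 4, padding_idx=0)

# Re-randomize everything, then zero the padding row by hand, as the diff does.
emb.weight.set_value(paddle.normal(shape=emb.weight.shape))
with paddle.no_grad():
    emb.weight[emb._padding_idx] = 0

assert float(emb.weight[0].abs().sum()) == 0.0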
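The single isinstance ladder is enough because paddle.nn.Layer.apply(fn), the call on the last kept line, visits the root layer and every sublayer. A toy illustration, with a hypothetical three-layer net and a print in place of the re-init:

import paddle.nn as nn

# Hypothetical stand-in model, just to show the traversal.
net = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))

def _visit(module):
    # apply() calls this once per sublayer (and once for the root).
    if isinstance(module, nn.Linear):
        print("would re-init:", module)

net.apply(_visit)  # fires for both Linear layers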