|
|
@ -520,74 +520,27 @@ class VITS(nn.Layer):
|
|
|
|
if fan_in != 0:
|
|
|
|
if fan_in != 0:
|
|
|
|
bound = 1 / math.sqrt(fan_in)
|
|
|
|
bound = 1 / math.sqrt(fan_in)
|
|
|
|
uniform_(module.bias, -bound, bound)
|
|
|
|
uniform_(module.bias, -bound, bound)
|
|
|
|
"""
|
|
|
|
|
|
|
|
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
|
|
|
|
|
|
|
if fan_in != 0:
|
|
|
|
|
|
|
|
bound = 1 / math.sqrt(fan_in)
|
|
|
|
|
|
|
|
init.uniform_(self.bias, -bound, bound)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(module,
|
|
|
|
if isinstance(module,
|
|
|
|
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
|
|
|
|
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
|
|
|
|
ones_(module.weight)
|
|
|
|
ones_(module.weight)
|
|
|
|
zeros_(module.bias)
|
|
|
|
zeros_(module.bias)
|
|
|
|
"""
|
|
|
|
|
|
|
|
def reset_running_stats(self) -> None:
|
|
|
|
|
|
|
|
if self.track_running_stats:
|
|
|
|
|
|
|
|
# running_mean/running_var/num_batches... are registered at runtime depending
|
|
|
|
|
|
|
|
# if self.track_running_stats is on
|
|
|
|
|
|
|
|
self.running_mean.zero_() # type: ignore[union-attr]
|
|
|
|
|
|
|
|
self.running_var.fill_(1) # type: ignore[union-attr]
|
|
|
|
|
|
|
|
self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
|
|
|
|
|
|
|
|
def reset_parameters(self) -> None:
|
|
|
|
|
|
|
|
self.reset_running_stats()
|
|
|
|
|
|
|
|
if self.affine:
|
|
|
|
|
|
|
|
init.ones_(self.weight)
|
|
|
|
|
|
|
|
init.zeros_(self.bias)
|
|
|
|
|
|
|
|
GroupNorm:
|
|
|
|
|
|
|
|
def reset_parameters(self) -> None:
|
|
|
|
|
|
|
|
if self.affine:
|
|
|
|
|
|
|
|
init.ones_(self.weight)
|
|
|
|
|
|
|
|
init.zeros_(self.bias)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if isinstance(module, nn.Linear):
|
|
|
|
if isinstance(module, nn.Linear):
|
|
|
|
kaiming_uniform_(module.weight, a=math.sqrt(5))
|
|
|
|
kaiming_uniform_(module.weight, a=math.sqrt(5))
|
|
|
|
if module.bias is not None:
|
|
|
|
if module.bias is not None:
|
|
|
|
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
|
|
|
|
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
|
|
|
|
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
|
|
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
|
|
uniform_(module.bias, -bound, bound)
|
|
|
|
uniform_(module.bias, -bound, bound)
|
|
|
|
"""
|
|
|
|
|
|
|
|
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
|
|
|
|
|
|
|
if self.bias is not None:
|
|
|
|
|
|
|
|
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
|
|
|
|
|
|
|
|
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
|
|
|
|
|
|
|
|
init.uniform_(self.bias, -bound, bound)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if isinstance(module, nn.Embedding):
|
|
|
|
if isinstance(module, nn.Embedding):
|
|
|
|
normal_(module.weight)
|
|
|
|
normal_(module.weight)
|
|
|
|
if module._padding_idx is not None:
|
|
|
|
if module._padding_idx is not None:
|
|
|
|
with paddle.no_grad():
|
|
|
|
with paddle.no_grad():
|
|
|
|
module.weight[module._padding_idx] = 0
|
|
|
|
module.weight[module._padding_idx] = 0
|
|
|
|
"""
|
|
|
|
|
|
|
|
def reset_parameters(self) -> None:
|
|
|
|
|
|
|
|
init.normal_(self.weight)
|
|
|
|
|
|
|
|
self._fill_padding_idx_with_zero()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fill_padding_idx_with_zero(self) -> None:
|
|
|
|
|
|
|
|
if self.padding_idx is not None:
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
|
|
|
self.weight[self.padding_idx].fill_(0)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
if isinstance(module, nn.LayerNorm):
|
|
|
|
if isinstance(module, nn.LayerNorm):
|
|
|
|
ones_(module.weight)
|
|
|
|
ones_(module.weight)
|
|
|
|
zeros_(module.bias)
|
|
|
|
zeros_(module.bias)
|
|
|
|
"""
|
|
|
|
|
|
|
|
def reset_parameters(self) -> None:
|
|
|
|
|
|
|
|
if self.elementwise_affine:
|
|
|
|
|
|
|
|
init.ones_(self.weight)
|
|
|
|
|
|
|
|
init.zeros_(self.bias)
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.apply(_reset_parameters)
|
|
|
|
self.apply(_reset_parameters)
|
|
|
|