From 8955e5efd399a2719f57adc23f43726c2c4cfb67 Mon Sep 17 00:00:00 2001
From: WongLaw
Date: Mon, 9 Jan 2023 04:03:17 +0000
Subject: [PATCH] vits_init, test=tts

---
 paddlespeech/t2s/models/vits/vits.py | 53 ++--------
 1 file changed, 3 insertions(+), 50 deletions(-)

diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py
index 05ec93e80..afe1e8dd5 100644
--- a/paddlespeech/t2s/models/vits/vits.py
+++ b/paddlespeech/t2s/models/vits/vits.py
@@ -520,74 +520,27 @@ class VITS(nn.Layer):
                 if fan_in != 0:
                     bound = 1 / math.sqrt(fan_in)
                     uniform_(module.bias, -bound, bound)
-                """
-                init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-                if self.bias is not None:
-                    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
-                    if fan_in != 0:
-                        bound = 1 / math.sqrt(fan_in)
-                        init.uniform_(self.bias, -bound, bound)
-                """
 
             if isinstance(module,
                           (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
                 ones_(module.weight)
                 zeros_(module.bias)
-                """
-                def reset_running_stats(self) -> None:
-                    if self.track_running_stats:
-                        # running_mean/running_var/num_batches... are registered at runtime depending
-                        # if self.track_running_stats is on
-                        self.running_mean.zero_()  # type: ignore[union-attr]
-                        self.running_var.fill_(1)  # type: ignore[union-attr]
-                        self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]
-                def reset_parameters(self) -> None:
-                    self.reset_running_stats()
-                    if self.affine:
-                        init.ones_(self.weight)
-                        init.zeros_(self.bias)
-                GroupNorm:
-                def reset_parameters(self) -> None:
-                    if self.affine:
-                        init.ones_(self.weight)
-                        init.zeros_(self.bias)
-                """
+
             if isinstance(module, nn.Linear):
                 kaiming_uniform_(module.weight, a=math.sqrt(5))
                 if module.bias is not None:
                     fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
                     bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                     uniform_(module.bias, -bound, bound)
-                """
-                init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-                if self.bias is not None:
-                    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
-                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-                    init.uniform_(self.bias, -bound, bound)
-                """
+
             if isinstance(module, nn.Embedding):
                 normal_(module.weight)
                 if module._padding_idx is not None:
                     with paddle.no_grad():
                         module.weight[module._padding_idx] = 0
-                """
-                def reset_parameters(self) -> None:
-                    init.normal_(self.weight)
-                    self._fill_padding_idx_with_zero()
-
-                def _fill_padding_idx_with_zero(self) -> None:
-                    if self.padding_idx is not None:
-                        with torch.no_grad():
-                            self.weight[self.padding_idx].fill_(0)
-                """
+
             if isinstance(module, nn.LayerNorm):
                 ones_(module.weight)
                 zeros_(module.bias)
-                """
-                def reset_parameters(self) -> None:
-                    if self.elementwise_affine:
-                        init.ones_(self.weight)
-                        init.zeros_(self.bias)
-                """
 
         self.apply(_reset_parameters)
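
Note (not part of the patch): the code the patch keeps applies torch-style default
initialization to Paddle layers, and the deleted triple-quoted blocks were the
PyTorch reference implementations left in as comments. For readers who want to see
the mechanism in isolation, here is a minimal standalone sketch of the conv branch.
The helper bodies (kaiming_uniform_, uniform_, _calculate_fan_in_and_fan_out) are
illustrative reimplementations assumed to behave like the ones vits.py imports, not
the library's own definitions:

    import math

    import paddle
    import paddle.nn as nn


    def _calculate_fan_in_and_fan_out(tensor):
        # torch semantics: weight is [out, in, *kernel]; fan_in counts
        # input channels times the receptive field size.
        receptive_field_size = 1
        if tensor.ndim > 2:
            for s in tensor.shape[2:]:
                receptive_field_size *= s
        fan_in = tensor.shape[1] * receptive_field_size
        fan_out = tensor.shape[0] * receptive_field_size
        return fan_in, fan_out


    def uniform_(tensor, low, high):
        # In-place uniform fill, analogous to torch.nn.init.uniform_.
        with paddle.no_grad():
            tensor.set_value(paddle.uniform(tensor.shape, min=low, max=high))


    def kaiming_uniform_(tensor, a=0.0):
        # He-uniform with leaky_relu gain; a=sqrt(5) reproduces torch's
        # default conv/linear weight init, i.e. bound = 1/sqrt(fan_in).
        fan_in, _ = _calculate_fan_in_and_fan_out(tensor)
        gain = math.sqrt(2.0 / (1.0 + a * a))
        bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
        uniform_(tensor, -bound, bound)


    def _reset_parameters(module):
        # Same shape as the conv branch the patch keeps in vits.py.
        if isinstance(module, nn.Conv1D):
            kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
                if fan_in != 0:
                    bound = 1 / math.sqrt(fan_in)
                    uniform_(module.bias, -bound, bound)


    conv = nn.Conv1D(4, 8, kernel_size=3)
    conv.apply(_reset_parameters)  # Layer.apply visits self and all sublayers
    print(conv.weight.abs().max().item())  # <= 1/sqrt(4*3) approx 0.289

The same apply-a-visitor pattern is what VITS.reset_parameters uses: each
isinstance branch overrides Paddle's default initializer for one layer family,
so the deleted PyTorch snippets became redundant documentation.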