diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py
index 6e9339899..bc63bb995 100644
--- a/paddlespeech/t2s/models/vits/vits.py
+++ b/paddlespeech/t2s/models/vits/vits.py
@@ -510,37 +510,84 @@ class VITS(nn.Layer):
     def reset_parameters(self):
         def _reset_parameters(module):
-            if isinstance(module, nn.Conv1D) or isinstance(module,
-                                                           nn.Conv1DTranspose):
+            if isinstance(module, nn.Conv1D) \
+                    or isinstance(module, nn.Conv1DTranspose) \
+                    or isinstance(module, nn.Conv2D) \
+                    or isinstance(module, nn.Conv2DTranspose):
                 kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out")
                 if module.bias is not None:
                     fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
-                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-                    uniform_(module.bias, -bound, bound)
-            if isinstance(module, nn.Conv2D) or isinstance(module,
-                                                           nn.Conv2DTranspose):
-                kaiming_uniform_(module.weight, a=math.sqrt(5), mode="fan_out")
-                if module.bias is not None:
-                    fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
-                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-                    uniform_(module.bias, -bound, bound)
+                    if fan_in != 0:
+                        bound = 1 / math.sqrt(fan_in)
+                        uniform_(module.bias, -bound, bound)
+                """
+                PyTorch reference (torch.nn.modules.conv._ConvNd.reset_parameters):
+                init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+                if self.bias is not None:
+                    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+                    if fan_in != 0:
+                        bound = 1 / math.sqrt(fan_in)
+                        init.uniform_(self.bias, -bound, bound)
+                """
+
             if isinstance(module, (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)):
                 ones_(module.weight)
                 zeros_(module.bias)
+                """
+                PyTorch reference (torch.nn.modules.batchnorm._NormBase):
+                def reset_running_stats(self) -> None:
+                    if self.track_running_stats:
+                        # running_mean/running_var/num_batches... are registered at runtime depending
+                        # if self.track_running_stats is on
+                        self.running_mean.zero_()  # type: ignore[union-attr]
+                        self.running_var.fill_(1)  # type: ignore[union-attr]
+                        self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]
+                def reset_parameters(self) -> None:
+                    self.reset_running_stats()
+                    if self.affine:
+                        init.ones_(self.weight)
+                        init.zeros_(self.bias)
+                GroupNorm:
+                def reset_parameters(self) -> None:
+                    if self.affine:
+                        init.ones_(self.weight)
+                        init.zeros_(self.bias)
+                """
             if isinstance(module, nn.Linear):
                 kaiming_uniform_(module.weight, a=math.sqrt(5))
                 if module.bias is not None:
                     fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
                     bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                     uniform_(module.bias, -bound, bound)
+                """
+                PyTorch reference (torch.nn.Linear.reset_parameters):
+                init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+                if self.bias is not None:
+                    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
+                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+                    init.uniform_(self.bias, -bound, bound)
+                """
             if isinstance(module, nn.Embedding):
                 normal_(module.weight)
                 if module._padding_idx is not None:
                     with paddle.no_grad():
                         module.weight[module._padding_idx] = 0
+                """
+                PyTorch reference (torch.nn.Embedding.reset_parameters):
+                def reset_parameters(self) -> None:
+                    init.normal_(self.weight)
+                    self._fill_padding_idx_with_zero()
+
+                def _fill_padding_idx_with_zero(self) -> None:
+                    if self.padding_idx is not None:
+                        with torch.no_grad():
+                            self.weight[self.padding_idx].fill_(0)
+                """
             if isinstance(module, nn.LayerNorm):
                 ones_(module.weight)
                 zeros_(module.bias)
+                """
+                PyTorch reference (torch.nn.LayerNorm.reset_parameters):
+                def reset_parameters(self) -> None:
+                    if self.elementwise_affine:
+                        init.ones_(self.weight)
+                        init.zeros_(self.bias)
+                """
 
         self.apply(_reset_parameters)
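
Note: every patched branch follows the same recipe as PyTorch's default initializers: kaiming_uniform_ with a=math.sqrt(5) for weights, and a uniform bias in (-1/sqrt(fan_in), 1/sqrt(fan_in)). For that value of a, the Kaiming-uniform weight bound also collapses to 1/sqrt(fan_in). As a minimal standalone sketch using only plain Paddle ops (the helper name torch_default_linear_init and the layer sizes below are illustrative, not part of this patch, and it does not use the kaiming_uniform_/uniform_ helpers from the diff):

import math

import paddle
import paddle.nn as nn


def torch_default_linear_init(layer: nn.Linear) -> None:
    """Re-initialize a Linear layer following torch.nn.Linear defaults.

    With a=sqrt(5) and a fan_in basis, the Kaiming-uniform weight bound
    reduces to 1/sqrt(fan_in); the same bound is used for the bias.
    """
    # Paddle stores Linear weights as [in_features, out_features],
    # so fan_in is the first dimension.
    fan_in = layer.weight.shape[0]
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    with paddle.no_grad():
        layer.weight.set_value(
            paddle.uniform(layer.weight.shape, min=-bound, max=bound))
        if layer.bias is not None:
            layer.bias.set_value(
                paddle.uniform(layer.bias.shape, min=-bound, max=bound))


# Illustrative sizes only.
layer = nn.Linear(256, 192)
torch_default_linear_init(layer)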