add reset_parameters

pull/3182/head
TianYuan 2 years ago
parent 90acc1b435
commit dee6630945

@@ -22,6 +22,7 @@ from .layers import ConvBlock
 from .layers import ConvNorm
 from .layers import LinearNorm
 from .layers import MFCC
+from paddlespeech.t2s.modules.nets_utils import _reset_parameters
 from paddlespeech.utils.initialize import uniform_
@@ -59,6 +60,9 @@ class ASRCNN(nn.Layer):
             hidden_dim=hidden_dim // 2,
             n_token=n_token)
+        self.reset_parameters()
+        self.asr_s2s.reset_parameters()

     def forward(self,
                 x: paddle.Tensor,
                 src_key_padding_mask: paddle.Tensor=None,
@@ -108,6 +112,9 @@ class ASRCNN(nn.Layer):
                 index_tensor.T + unmask_future_steps)
         return mask

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class ASRS2S(nn.Layer):
     def __init__(self,
@@ -118,8 +125,7 @@ class ASRS2S(nn.Layer):
                  n_token: int=40):
         super().__init__()
         self.embedding = nn.Embedding(n_token, embedding_dim)
-        val_range = math.sqrt(6 / hidden_dim)
-        uniform_(self.embedding.weight, -val_range, val_range)
+        self.val_range = math.sqrt(6 / hidden_dim)

         self.decoder_rnn_dim = hidden_dim
         self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
@@ -236,3 +242,6 @@ class ASRS2S(nn.Layer):
         hidden = paddle.stack(hidden).transpose([1, 0, 2])
         return hidden, logit, alignments

+    def reset_parameters(self):
+        uniform_(self.embedding.weight, -self.val_range, self.val_range)
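Taken together, the ASRS2S hunks move the embedding initialisation out of __init__ and into reset_parameters: the bound math.sqrt(6 / hidden_dim), roughly a Xavier-uniform-style range keyed to hidden_dim, is kept on the instance as self.val_range so the weights can be re-drawn after construction. A minimal sketch of the same initialisation in isolation (the sizes below are made up for illustration):

    import math

    from paddle import nn
    from paddlespeech.utils.initialize import uniform_

    # Hypothetical sizes, only to show the bound used above.
    hidden_dim, n_token, embedding_dim = 512, 40, 256
    embedding = nn.Embedding(n_token, embedding_dim)
    val_range = math.sqrt(6 / hidden_dim)  # Xavier-uniform-style bound
    uniform_(embedding.weight, -val_range, val_range)  # weights drawn from U(-val_range, val_range)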

@@ -25,6 +25,8 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn

+from paddlespeech.t2s.modules.nets_utils import _reset_parameters
+

 class DownSample(nn.Layer):
     def __init__(self, layer_type: str):
@@ -355,6 +357,8 @@ class Generator(nn.Layer):
         if w_hpf > 0:
             self.hpf = HighPass(w_hpf)

+        self.reset_parameters()
+
     def forward(self,
                 x: paddle.Tensor,
                 s: paddle.Tensor,
@@ -399,6 +403,9 @@ class Generator(nn.Layer):
         out = self.to_out(x)
         return out

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class MappingNetwork(nn.Layer):
     def __init__(self,
@@ -427,6 +434,8 @@ class MappingNetwork(nn.Layer):
                 nn.ReLU(), nn.Linear(hidden_dim, style_dim))
         ])

+        self.reset_parameters()
+
     def forward(self, z: paddle.Tensor, y: paddle.Tensor):
         """Calculate forward propagation.
         Args:
@@ -449,6 +458,9 @@ class MappingNetwork(nn.Layer):
         s = out[idx, y]
         return s

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class StyleEncoder(nn.Layer):
     def __init__(self,
@@ -490,6 +502,8 @@ class StyleEncoder(nn.Layer):
         for _ in range(num_domains):
             self.unshared.append(nn.Linear(dim_out, style_dim))

+        self.reset_parameters()
+
     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
         """Calculate forward propagation.
         Args:
@@ -513,6 +527,9 @@ class StyleEncoder(nn.Layer):
         s = out[idx, y]
         return s

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class Discriminator(nn.Layer):
     def __init__(self,
@@ -535,6 +552,8 @@ class Discriminator(nn.Layer):
             repeat_num=repeat_num)
         self.num_domains = num_domains

+        self.reset_parameters()
+
     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
         out = self.dis(x, y)
         return out
@@ -543,6 +562,9 @@ class Discriminator(nn.Layer):
         out = self.cls.get_feature(x)
         return out

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class Discriminator2D(nn.Layer):
     def __init__(self,
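The Generator, MappingNetwork, StyleEncoder and Discriminator hunks above all follow the same pattern: __init__ ends with self.reset_parameters(), and reset_parameters() hands the new _reset_parameters helper (added to nets_utils in the hunk below) to nn.Layer.apply, which visits the layer and every sublayer recursively. A minimal sketch of that pattern, using a made-up TinyBlock layer:

    from paddle import nn

    from paddlespeech.t2s.modules.nets_utils import _reset_parameters


    class TinyBlock(nn.Layer):
        """Made-up layer, only to illustrate the reset_parameters pattern."""

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv1D(8, 8, kernel_size=3)
            self.norm = nn.LayerNorm(8)
            self.proj = nn.Linear(8, 4)
            # Re-initialise all parameters with the torch-style defaults.
            self.reset_parameters()

        def reset_parameters(self):
            # nn.Layer.apply calls the hook on self and on every sublayer.
            self.apply(_reset_parameters)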

@@ -20,6 +20,44 @@ import paddle
 from paddle import nn
 from typeguard import check_argument_types

+from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
+from paddlespeech.utils.initialize import kaiming_uniform_
+from paddlespeech.utils.initialize import normal_
+from paddlespeech.utils.initialize import ones_
+from paddlespeech.utils.initialize import uniform_
+from paddlespeech.utils.initialize import zeros_
+
+
+# default init method of torch
+# copy from https://github.com/PaddlePaddle/PaddleSpeech/blob/9cf8c1985a98bb380c183116123672976bdfe5c9/paddlespeech/t2s/models/vits/vits.py#L506
+def _reset_parameters(module):
+    if isinstance(module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D,
+                           nn.Conv2DTranspose)):
+        kaiming_uniform_(module.weight, a=math.sqrt(5))
+        if module.bias is not None:
+            fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
+            if fan_in != 0:
+                bound = 1 / math.sqrt(fan_in)
+                uniform_(module.bias, -bound, bound)
+
+    if isinstance(module,
+                  (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
+        ones_(module.weight)
+        zeros_(module.bias)
+
+    if isinstance(module, nn.Linear):
+        kaiming_uniform_(module.weight, a=math.sqrt(5))
+        if module.bias is not None:
+            fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            uniform_(module.bias, -bound, bound)
+
+    if isinstance(module, nn.Embedding):
+        normal_(module.weight)
+        if module._padding_idx is not None:
+            with paddle.no_grad():
+                module.weight[module._padding_idx] = 0
+
+
 def pad_list(xs, pad_value):
     """Perform padding for the list of tensors.

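For reference, _reset_parameters mirrors PyTorch's default initialisation on Paddle layers: Kaiming-uniform weights with a fan_in-based uniform bias for convolution and linear layers, ones/zeros for the normalisation layers, and a standard normal for embeddings with the padding row zeroed afterwards. A small usage sketch (layer sizes are arbitrary):

    import paddle
    from paddle import nn

    from paddlespeech.t2s.modules.nets_utils import _reset_parameters

    # The hook can be applied to a single layer directly...
    linear = nn.Linear(128, 64)
    _reset_parameters(linear)

    # ...or to an embedding, whose padding row is zeroed after re-initialisation.
    emb = nn.Embedding(10, 4, padding_idx=0)
    _reset_parameters(emb)
    print(float(paddle.abs(emb.weight[0]).sum()))  # -> 0.0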