update VITS init method, test=tts

pull/2809/head
WongLaw 3 years ago
parent 2a8b62d7d2
commit 7f8fbbcd35

@ -106,10 +106,6 @@ class TextEncoder(nn.Layer):
# define modules
self.emb = nn.Embedding(vocabs, attention_dim)
# dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
# w = dist.sample(self.emb.weight.shape)
# self.emb.weight.set_value(w)
self.encoder = Encoder(
idim=-1,
input_layer=None,
@ -132,7 +128,7 @@ class TextEncoder(nn.Layer):
self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
self.reset_parameters()
def forward(
self,
x: paddle.Tensor,
@ -174,4 +170,4 @@ class TextEncoder(nn.Layer):
normal_(self.emb.weight, mean=0.0, std=self.attention_dim**-0.5)
if self.emb._padding_idx is not None:
with paddle.no_grad():
self.emb.weight[self.emb._padding_idx] = 0
self.emb.weight[self.emb._padding_idx] = 0

@ -158,8 +158,7 @@ class VITS(nn.Layer):
"use_spectral_norm": False,
},
},
cache_generator_outputs: bool=True,
init_type: str="xavier_uniform", ):
cache_generator_outputs: bool=True, ):
"""Initialize VITS module.
Args:
idim (int):
@ -185,9 +184,6 @@ class VITS(nn.Layer):
assert check_argument_types()
super().__init__()
# initialize parameters
# initialize(self, init_type)
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
@ -202,8 +198,6 @@ class VITS(nn.Layer):
self.discriminator = discriminator_class(
**discriminator_params, )
# nn.initializer.set_global_initializer(None)
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
@ -223,7 +217,6 @@ class VITS(nn.Layer):
self.reset_parameters()
self.generator.decoder.reset_parameters()
self.generator.text_encoder.reset_parameters()
# print("VITS===============")
def forward(
self,
@ -254,7 +247,7 @@ class VITS(nn.Layer):
forward_generator (bool):
Whether to forward generator.
Returns:
"""
if forward_generator:
return self._forward_generator(
@ -301,7 +294,7 @@ class VITS(nn.Layer):
lids (Optional[Tensor]):
Language index tensor (B,) or (B, 1).
Returns:
"""
# setup
feats = feats.transpose([0, 2, 1])
@ -511,7 +504,7 @@ class VITS(nn.Layer):
def reset_parameters(self):
def _reset_parameters(module):
if isinstance(module,
if isinstance(module,
(nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
@ -519,7 +512,7 @@ class VITS(nn.Layer):
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
uniform_(module.bias, -bound, bound)
if isinstance(module,
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
ones_(module.weight)

Loading…
Cancel
Save