update VITS init method, test=tts

Branch: pull/2809/head
Author: WongLaw, 3 years ago
Parent: 2a8b62d7d2
Commit: 7f8fbbcd35

@@ -106,10 +106,6 @@ class TextEncoder(nn.Layer):
         # define modules
         self.emb = nn.Embedding(vocabs, attention_dim)
-        # dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
-        # w = dist.sample(self.emb.weight.shape)
-        # self.emb.weight.set_value(w)
         self.encoder = Encoder(
             idim=-1,
             input_layer=None,
@@ -132,7 +128,7 @@ class TextEncoder(nn.Layer):
         self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
         self.reset_parameters()
     def forward(
             self,
             x: paddle.Tensor,
@@ -174,4 +170,4 @@ class TextEncoder(nn.Layer):
         normal_(self.emb.weight, mean=0.0, std=self.attention_dim**-0.5)
         if self.emb._padding_idx is not None:
             with paddle.no_grad():
                 self.emb.weight[self.emb._padding_idx] = 0
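Note: the commented-out lines removed in the first hunk sampled fresh embedding weights through paddle.distribution.Normal; the retained path does the same job in reset_parameters() via normal_(self.emb.weight, ...) and then zeroes the padding row. A minimal standalone sketch of that pattern using only stock Paddle calls (the helper name re_init_embedding and the toy sizes are assumptions for illustration, not project code):

import paddle
import paddle.nn as nn

def re_init_embedding(emb: nn.Embedding, attention_dim: int, padding_idx: int=None):
    # Draw N(0, attention_dim ** -0.5) weights and write them into the table,
    # mirroring the normal_(...) call kept by this commit.
    std = attention_dim ** -0.5
    new_w = paddle.normal(mean=0.0, std=std, shape=emb.weight.shape)
    emb.weight.set_value(new_w)
    # Keep the padding row at zero, as in TextEncoder.reset_parameters().
    if padding_idx is not None:
        with paddle.no_grad():
            emb.weight[padding_idx] = 0

vocabs, attention_dim = 100, 192
emb = nn.Embedding(vocabs, attention_dim)
re_init_embedding(emb, attention_dim, padding_idx=0)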

@@ -158,8 +158,7 @@ class VITS(nn.Layer):
                 "use_spectral_norm": False,
             },
         },
-        cache_generator_outputs: bool=True,
-        init_type: str="xavier_uniform", ):
+        cache_generator_outputs: bool=True, ):
         """Initialize VITS module.
         Args:
             idim (int):
@@ -185,9 +184,6 @@ class VITS(nn.Layer):
         assert check_argument_types()
         super().__init__()
-        # initialize parameters
-        # initialize(self, init_type)
         # define modules
         generator_class = AVAILABLE_GENERATERS[generator_type]
         if generator_type == "vits_generator":
@@ -202,8 +198,6 @@ class VITS(nn.Layer):
         self.discriminator = discriminator_class(
             **discriminator_params, )
-        # nn.initializer.set_global_initializer(None)
         # cache
         self.cache_generator_outputs = cache_generator_outputs
         self._cache = None
@@ -223,7 +217,6 @@ class VITS(nn.Layer):
         self.reset_parameters()
         self.generator.decoder.reset_parameters()
         self.generator.text_encoder.reset_parameters()
-        # print("VITS===============")
     def forward(
             self,
@@ -254,7 +247,7 @@ class VITS(nn.Layer):
             forward_generator (bool):
                 Whether to forward generator.
         Returns:
         """
         if forward_generator:
             return self._forward_generator(
@@ -301,7 +294,7 @@ class VITS(nn.Layer):
             lids (Optional[Tensor]):
                 Language index tensor (B,) or (B, 1).
         Returns:
         """
         # setup
         feats = feats.transpose([0, 2, 1])
@@ -511,7 +504,7 @@ class VITS(nn.Layer):
     def reset_parameters(self):
         def _reset_parameters(module):
             if isinstance(module,
                           (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
                 kaiming_uniform_(module.weight, a=math.sqrt(5))
                 if module.bias is not None:
@@ -519,7 +512,7 @@ class VITS(nn.Layer):
                     if fan_in != 0:
                         bound = 1 / math.sqrt(fan_in)
                         uniform_(module.bias, -bound, bound)
             if isinstance(module,
                           (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
                 ones_(module.weight)
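For reference, a self-contained sketch of the fan-in based re-initialization that reset_parameters() performs, written against stock Paddle ops only. kaiming_uniform_, uniform_ and ones_ in the diff are project-side helpers; the reimplementation below is an assumed equivalent for illustration, not the repository's code:

import math
import numpy as np
import paddle
import paddle.nn as nn

def reset_parameters(model: nn.Layer):
    # Walk every sublayer: conv weights get a Kaiming-uniform draw (a=sqrt(5)),
    # conv biases a fan-in bound, and norm scales are reset to ones,
    # mirroring the _reset_parameters() shown in the diff.
    for m in model.sublayers(include_self=True):
        if isinstance(m, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
            fan_in = int(np.prod(m.weight.shape[1:]))  # fan-in ~ prod of non-output dims
            if fan_in > 0:
                gain = math.sqrt(2.0 / (1.0 + 5.0))    # Kaiming gain for a = sqrt(5)
                bound = gain * math.sqrt(3.0 / fan_in)
                m.weight.set_value(paddle.uniform(m.weight.shape, min=-bound, max=bound))
                if m.bias is not None:
                    b = 1.0 / math.sqrt(fan_in)
                    m.bias.set_value(paddle.uniform(m.bias.shape, min=-b, max=b))
        if isinstance(m, (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
            m.weight.set_value(paddle.ones(m.weight.shape))

net = nn.Sequential(nn.Conv1D(80, 192, 5, padding=2), nn.BatchNorm1D(192))
reset_parameters(net)

Because each submodule is re-initialized explicitly this way after construction, the commit can drop the global init_type argument and the commented-out initialize(self, init_type) call from VITS.__init__.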
