update VITS init method, test=tts

pull/2809/head
WongLaw 3 years ago
parent 2a8b62d7d2
commit 7f8fbbcd35

@ -106,10 +106,6 @@ class TextEncoder(nn.Layer):
# define modules
self.emb = nn.Embedding(vocabs, attention_dim)
# dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
# w = dist.sample(self.emb.weight.shape)
# self.emb.weight.set_value(w)
self.encoder = Encoder(
idim=-1,
input_layer=None,

@ -158,8 +158,7 @@ class VITS(nn.Layer):
"use_spectral_norm": False,
},
},
cache_generator_outputs: bool=True,
init_type: str="xavier_uniform", ):
cache_generator_outputs: bool=True, ):
"""Initialize VITS module.
Args:
idim (int):
@ -185,9 +184,6 @@ class VITS(nn.Layer):
assert check_argument_types()
super().__init__()
# initialize parameters
# initialize(self, init_type)
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
@ -202,8 +198,6 @@ class VITS(nn.Layer):
self.discriminator = discriminator_class(
**discriminator_params, )
# nn.initializer.set_global_initializer(None)
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
@ -223,7 +217,6 @@ class VITS(nn.Layer):
self.reset_parameters()
self.generator.decoder.reset_parameters()
self.generator.text_encoder.reset_parameters()
# print("VITS===============")
def forward(
self,

Loading…
Cancel
Save