|
|
|
@ -158,8 +158,7 @@ class VITS(nn.Layer):
|
|
|
|
|
"use_spectral_norm": False,
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
cache_generator_outputs: bool=True,
|
|
|
|
|
init_type: str="xavier_uniform", ):
|
|
|
|
|
cache_generator_outputs: bool=True, ):
|
|
|
|
|
"""Initialize VITS module.
|
|
|
|
|
Args:
|
|
|
|
|
idim (int):
|
|
|
|
@ -185,9 +184,6 @@ class VITS(nn.Layer):
|
|
|
|
|
assert check_argument_types()
|
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
# initialize parameters
|
|
|
|
|
# initialize(self, init_type)
|
|
|
|
|
|
|
|
|
|
# define modules
|
|
|
|
|
generator_class = AVAILABLE_GENERATERS[generator_type]
|
|
|
|
|
if generator_type == "vits_generator":
|
|
|
|
@ -202,8 +198,6 @@ class VITS(nn.Layer):
|
|
|
|
|
self.discriminator = discriminator_class(
|
|
|
|
|
**discriminator_params, )
|
|
|
|
|
|
|
|
|
|
# nn.initializer.set_global_initializer(None)
|
|
|
|
|
|
|
|
|
|
# cache
|
|
|
|
|
self.cache_generator_outputs = cache_generator_outputs
|
|
|
|
|
self._cache = None
|
|
|
|
@ -223,7 +217,6 @@ class VITS(nn.Layer):
|
|
|
|
|
self.reset_parameters()
|
|
|
|
|
self.generator.decoder.reset_parameters()
|
|
|
|
|
self.generator.text_encoder.reset_parameters()
|
|
|
|
|
# print("VITS===============")
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
|
|
|
self,
|
|
|
|
|