@@ -158,8 +158,7 @@ class VITS(nn.Layer):
                     "use_spectral_norm": False,
                 },
             },
-            cache_generator_outputs: bool=True,
-            init_type: str="xavier_uniform", ):
+            cache_generator_outputs: bool=True, ):
         """Initialize VITS module.
         Args:
             idim (int):
@@ -185,9 +184,6 @@ class VITS(nn.Layer):
         assert check_argument_types()
         super().__init__()
 
-        # initialize parameters
-        # initialize(self, init_type)
-
         # define modules
         generator_class = AVAILABLE_GENERATERS[generator_type]
         if generator_type == "vits_generator":
@@ -202,8 +198,6 @@ class VITS(nn.Layer):
         self.discriminator = discriminator_class(
             **discriminator_params, )
 
-        # nn.initializer.set_global_initializer(None)
-
         # cache
         self.cache_generator_outputs = cache_generator_outputs
         self._cache = None
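(Note, not part of the patch: the cache_generator_outputs flag and self._cache kept above follow the usual ESPnet-style VITS training trick of computing the generator outputs once in the generator step and reusing them in the discriminator step of the same batch, instead of forwarding the generator twice. The snippet below is a hypothetical, simplified illustration of that pattern with made-up names, not code from this file.)

# Hypothetical sketch of the cache_generator_outputs pattern; not code from this patch.
class CachedGeneratorRunner:
    def __init__(self, generator, cache_generator_outputs=True):
        self.generator = generator
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

    def run(self, *args, **kwargs):
        # Reuse the cached outputs if caching is enabled and something is cached.
        reuse_cache = self.cache_generator_outputs and self._cache is not None
        outs = self._cache if reuse_cache else self.generator(*args, **kwargs)
        # Store fresh outputs so the following discriminator pass can reuse them.
        if self.cache_generator_outputs and not reuse_cache:
            self._cache = outs
        return outs

    def clear(self):
        # Drop the cache once both the generator and discriminator passes have used it.
        self._cache = None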
@@ -223,7 +217,6 @@ class VITS(nn.Layer):
         self.reset_parameters()
         self.generator.decoder.reset_parameters()
         self.generator.text_encoder.reset_parameters()
-        # print("VITS===============")
 
     def forward(
             self,
@@ -254,7 +247,7 @@ class VITS(nn.Layer):
             forward_generator (bool):
                 Whether to forward generator.
         Returns:
 
         """
         if forward_generator:
             return self._forward_generator(
@@ -301,7 +294,7 @@ class VITS(nn.Layer):
             lids (Optional[Tensor]):
                 Language index tensor (B,) or (B, 1).
         Returns:
 
         """
         # setup
         feats = feats.transpose([0, 2, 1])
@@ -511,7 +504,7 @@ class VITS(nn.Layer):
 
     def reset_parameters(self):
         def _reset_parameters(module):
-            if isinstance(module,
+            if isinstance(module,
                           (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
                 kaiming_uniform_(module.weight, a=math.sqrt(5))
                 if module.bias is not None:
@@ -519,7 +512,7 @@ class VITS(nn.Layer):
                     if fan_in != 0:
                         bound = 1 / math.sqrt(fan_in)
                         uniform_(module.bias, -bound, bound)
 
             if isinstance(module,
                           (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
                 ones_(module.weight)
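(Note on the reset_parameters hunks above, outside the patch itself: assuming kaiming_uniform_ mirrors the torch-style Kaiming-uniform formula, a=math.sqrt(5) gives gain = sqrt(2 / (1 + a^2)) = sqrt(1/3), so the uniform weight bound sqrt(3) * gain / sqrt(fan_in) reduces to exactly 1 / sqrt(fan_in), the same bound the code applies to the bias. A small standalone check in plain Python follows; the conv weight shape is just an example.)

import math

def conv_fan_in(weight_shape):
    # fan_in of a conv weight laid out as [out_channels, in_channels, *kernel_dims]
    out_channels, in_channels, *kernel_dims = weight_shape
    receptive_field = 1
    for k in kernel_dims:
        receptive_field *= k
    return in_channels * receptive_field

def kaiming_uniform_bound(fan_in, a=math.sqrt(5)):
    # Kaiming-uniform weight bound: sqrt(3) * gain / sqrt(fan_in), gain = sqrt(2 / (1 + a^2))
    gain = math.sqrt(2.0 / (1.0 + a * a))
    return math.sqrt(3.0) * gain / math.sqrt(fan_in)

fan_in = conv_fan_in([192, 192, 5])        # example Conv1D-style weight: 192 * 5 = 960
weight_bound = kaiming_uniform_bound(fan_in)
bias_bound = 1.0 / math.sqrt(fan_in)
print(fan_in, weight_bound, bias_bound)    # both bounds come out to 1 / sqrt(960) ~= 0.0323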