vits_init, test=tts

pull/2809/head
WongLaw 3 years ago
parent 8955e5efd3
commit da33520691

@ -24,6 +24,7 @@ from paddle import nn
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
from paddlespeech.utils.initialize import normal_
class TextEncoder(nn.Layer): class TextEncoder(nn.Layer):
@ -166,3 +167,9 @@ class TextEncoder(nn.Layer):
m, logs = paddle.split(stats, 2, axis=1) m, logs = paddle.split(stats, 2, axis=1)
return x, m, logs, x_mask return x, m, logs, x_mask
def reset_parameters(self):
    """Re-initialize the text-embedding parameters.

    Draws the token-embedding weights from a normal distribution via the
    project's ``normal_`` initializer, then — when the embedding was built
    with a padding index — zeroes that row (under ``paddle.no_grad`` so the
    write is not tracked by autograd), keeping padding tokens inert.
    """
    normal_(self.emb.weight)
    pad_idx = self.emb._padding_idx
    if pad_idx is not None:
        with paddle.no_grad():
            self.emb.weight[pad_idx] = 0

@ -222,6 +222,7 @@ class VITS(nn.Layer):
self.reset_parameters() self.reset_parameters()
self.generator.decoder.reset_parameters() self.generator.decoder.reset_parameters()
self.generator.text_encoder.reset_parameters()
# print("VITS===============") # print("VITS===============")
def forward( def forward(
@ -510,10 +511,8 @@ class VITS(nn.Layer):
def reset_parameters(self): def reset_parameters(self):
def _reset_parameters(module): def _reset_parameters(module):
if isinstance(module, nn.Conv1D) if isinstance(module,
or isinstance(module, nn.Conv1DTranspose) (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
or isinstance(module, nn.Conv2D)
or isinstance(module, nn.Conv2DTranspose):
kaiming_uniform_(module.weight, a=math.sqrt(5)) kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None: if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
@ -522,7 +521,7 @@ class VITS(nn.Layer):
uniform_(module.bias, -bound, bound) uniform_(module.bias, -bound, bound)
if isinstance(module, if isinstance(module,
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm)): (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
ones_(module.weight) ones_(module.weight)
zeros_(module.bias) zeros_(module.bias)
@ -539,8 +538,4 @@ class VITS(nn.Layer):
with paddle.no_grad(): with paddle.no_grad():
module.weight[module._padding_idx] = 0 module.weight[module._padding_idx] = 0
if isinstance(module, nn.LayerNorm):
ones_(module.weight)
zeros_(module.bias)
self.apply(_reset_parameters) self.apply(_reset_parameters)

Loading…
Cancel
Save