[TTS] StarGANv2 VC: fix some trainer bugs, add reset_parameters (#3182)

pull/3187/head
TianYuan committed 1 year ago (via GitHub)
parent 9cf8c1985a
commit 3ad55a31e7

@@ -41,7 +41,7 @@ discriminator_params:
     dim_in: 64          # same as dim_in in generator_params
     num_domains: 20     # same as num_domains in mapping_network_params
     max_conv_dim: 512   # same as max_conv_dim in generator_params
-    n_repeat: 4
+    repeat_num: 4
 asr_params:
     input_dim: 80
     hidden_dim: 256
@@ -77,6 +77,7 @@ loss_params:
 ###########################################################
 batch_size: 5       # Batch size.
 num_workers: 2      # Number of workers in DataLoader.
+max_mel_length: 192
 ###########################################################
 #              OPTIMIZER & SCHEDULER SETTING              #
@@ -84,47 +85,47 @@ num_workers: 2      # Number of workers in DataLoader.
 generator_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 generator_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 style_encoder_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 style_encoder_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 mapping_network_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 mapping_network_scheduler_params:
-    max_learning_rate: 2e-6
+    max_learning_rate: 2.0e-6
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-6
+    end_learning_rate: 2.0e-6
 discriminator_optimizer_params:
     beta1: 0.0
     beta2: 0.99
-    weight_decay: 1e-4
-    epsilon: 1e-9
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
 discriminator_scheduler_params:
-    max_learning_rate: 2e-4
+    max_learning_rate: 2.0e-4
     phase_pct: 0.0
     divide_factor: 1
     total_steps: 200000     # train_max_steps
-    end_learning_rate: 2e-4
+    end_learning_rate: 2.0e-4
 ###########################################################
 #                     TRAINING SETTING                    #
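
Note on the exponent rewrites above: these configs are loaded with PyYAML (via yacs), and the YAML 1.1 resolver only tags a scalar as a float when the mantissa contains a dot, so a bare `1e-4` comes back as the string `'1e-4'` and misconfigures the optimizer. A quick standalone check:

    import yaml  # PyYAML, the loader behind yacs.config

    cfg = yaml.safe_load("weight_decay: 1e-4\nepsilon: 1.0e-9\n")

    # YAML 1.1 floats need a dot in the mantissa; without it the
    # scalar is resolved as a plain string, not a number.
    print(type(cfg["weight_decay"]))  # <class 'str'>
    print(type(cfg["epsilon"]))       # <class 'float'>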

@@ -8,6 +8,4 @@ python3 ${BIN_DIR}/train.py \
     --dev-metadata=dump/dev/norm/metadata.jsonl \
     --config=${config_path} \
     --output-dir=${train_output_path} \
-    --ngpu=1 \
-    --phones-dict=dump/phone_id_map.txt \
-    --speaker-dict=dump/speaker_id_map.txt
+    --ngpu=1

@@ -852,20 +852,21 @@ class StarGANv2VCCollateFn:
         # (B,)
         label = paddle.to_tensor(label)
         ref_label = paddle.to_tensor(ref_label)
-        # [B, 80, T] -> [B, 1, 80, T]
-        mel = paddle.to_tensor(mel)
-        ref_mel = paddle.to_tensor(ref_mel)
-        ref_mel_2 = paddle.to_tensor(ref_mel_2)
-        z_trg = paddle.randn(batch_size, self.latent_dim)
-        z_trg2 = paddle.randn(batch_size, self.latent_dim)
+        # [B, T, 80] -> [B, 1, 80, T]
+        mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1)
+        ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1)
+        ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose(
+            [0, 2, 1]).unsqueeze(1)
+        z_trg = paddle.randn([batch_size, self.latent_dim])
+        z_trg2 = paddle.randn([batch_size, self.latent_dim])
         batch = {
-            "x_real": mels,
-            "y_org": labels,
-            "x_ref": ref_mels,
-            "x_ref2": ref_mels_2,
-            "y_trg": ref_labels,
+            "x_real": mel,
+            "y_org": label,
+            "x_ref": ref_mel,
+            "x_ref2": ref_mel_2,
+            "y_trg": ref_label,
             "z_trg": z_trg,
             "z_trg2": z_trg2
         }
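
Three distinct bugs are fixed in this hunk: the batch dict referenced plural names (`mels`, `labels`, ...) that were never defined; the mels kept the dataset's [B, T, 80] layout instead of the image-like [B, 1, 80, T] the 2D-convolutional networks expect; and `paddle.randn` was called with varargs. A standalone sketch of the shape handling:

    import paddle

    batch_size, latent_dim = 2, 16

    # dataset layout: [B, T, 80] (batch, frames, mel bins)
    mel = paddle.randn([batch_size, 192, 80])

    # swap time/frequency axes, then add a channel axis:
    # [B, T, 80] -> [B, 80, T] -> [B, 1, 80, T]
    mel = mel.transpose([0, 2, 1]).unsqueeze(1)
    print(mel.shape)  # [2, 1, 80, 192]

    # paddle.randn(shape, dtype=None, ...) takes the shape as one
    # list/tuple argument; the torch.randn-style varargs call
    # paddle.randn(batch_size, latent_dim) fails.
    z_trg = paddle.randn([batch_size, latent_dim])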

@@ -29,9 +29,12 @@ from paddle.optimizer import AdamW
 from paddle.optimizer.lr import OneCycleLR
 from yacs.config import CfgNode

-from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
-from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.cli.utils import download_and_decompress
+from paddlespeech.resource.pretrained_models import StarGANv2VC_source
+from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn
+from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable
 from paddlespeech.t2s.models.starganv2_vc import ASRCNN
+from paddlespeech.t2s.models.starganv2_vc import Discriminator
 from paddlespeech.t2s.models.starganv2_vc import Generator
 from paddlespeech.t2s.models.starganv2_vc import JDCNet
 from paddlespeech.t2s.models.starganv2_vc import MappingNetwork

@@ -66,7 +69,9 @@ def train_sp(args, config):
     fields = ["speech", "speech_lengths"]
     converters = {"speech": np.load}
-    collate_fn = starganv2_vc_batch_fn
+    collate_fn = build_starganv2_vc_collate_fn(
+        latent_dim=config['mapping_network_params']['latent_dim'],
+        max_mel_length=config['max_mel_length'])

     # dataloader has been too verbose
     logging.getLogger("DataLoader").disabled = True

@@ -74,16 +79,10 @@ def train_sp(args, config):
     # construct dataset for training and validation
     with jsonlines.open(args.train_metadata, 'r') as reader:
         train_metadata = list(reader)
-    train_dataset = DataTable(
-        data=train_metadata,
-        fields=fields,
-        converters=converters, )
+    train_dataset = StarGANv2VCDataTable(data=train_metadata)
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
-    dev_dataset = DataTable(
-        data=dev_metadata,
-        fields=fields,
-        converters=converters, )
+    dev_dataset = StarGANv2VCDataTable(data=dev_metadata)

     # collate function and dataloader
     train_sampler = DistributedBatchSampler(

@@ -118,6 +117,7 @@ def train_sp(args, config):
     generator = Generator(**config['generator_params'])
     mapping_network = MappingNetwork(**config['mapping_network_params'])
     style_encoder = StyleEncoder(**config['style_encoder_params'])
+    discriminator = Discriminator(**config['discriminator_params'])

     # load pretrained model
     jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
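
The batch function becomes a factory so that `latent_dim` and the new `max_mel_length` come from the config instead of being hard-coded. A sketch of how the result plugs into the training DataLoader, mirroring the surrounding `train_sp` code (the exact keywords in the file may differ slightly):

    collate_fn = build_starganv2_vc_collate_fn(
        latent_dim=config['mapping_network_params']['latent_dim'],
        max_mel_length=config['max_mel_length'])

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,  # the DistributedBatchSampler above
        collate_fn=collate_fn,        # builds the x_real/x_ref/z_trg batch dict
        num_workers=config['num_workers'])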

@@ -22,6 +22,7 @@ from .layers import ConvBlock
 from .layers import ConvNorm
 from .layers import LinearNorm
 from .layers import MFCC
+from paddlespeech.t2s.modules.nets_utils import _reset_parameters
 from paddlespeech.utils.initialize import uniform_

@@ -59,6 +60,9 @@ class ASRCNN(nn.Layer):
             hidden_dim=hidden_dim // 2,
             n_token=n_token)

+        self.reset_parameters()
+        self.asr_s2s.reset_parameters()
+
     def forward(self,
                 x: paddle.Tensor,
                 src_key_padding_mask: paddle.Tensor=None,

@@ -108,6 +112,9 @@ class ASRCNN(nn.Layer):
                                  index_tensor.T + unmask_future_steps)
         return mask

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class ASRS2S(nn.Layer):
     def __init__(self,

@@ -118,8 +125,7 @@ class ASRS2S(nn.Layer):
                  n_token: int=40):
         super().__init__()
         self.embedding = nn.Embedding(n_token, embedding_dim)
-        val_range = math.sqrt(6 / hidden_dim)
-        uniform_(self.embedding.weight, -val_range, val_range)
+        self.val_range = math.sqrt(6 / hidden_dim)
         self.decoder_rnn_dim = hidden_dim
         self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)

@@ -236,3 +242,6 @@ class ASRS2S(nn.Layer):
         hidden = paddle.stack(hidden).transpose([1, 0, 2])
         return hidden, logit, alignments
+
+    def reset_parameters(self):
+        uniform_(self.embedding.weight, -self.val_range, self.val_range)
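
The split matters because of ordering: `ASRCNN.__init__` builds `self.asr_s2s` and then calls `self.apply(_reset_parameters)`, which recursively re-initializes every sub-layer, including the embedding that `ASRS2S` used to initialize in its own `__init__`. Moving the uniform init into `ASRS2S.reset_parameters()` and calling it after the blanket `apply` lets the custom range win. A toy illustration of the pitfall (hypothetical module, not the real classes):

    import paddle
    from paddle import nn

    def generic_init(layer):
        # stand-in for nets_utils._reset_parameters
        if isinstance(layer, nn.Embedding):
            paddle.nn.initializer.Normal()(layer.weight)

    class Toy(nn.Layer):
        def __init__(self):
            super().__init__()
            self.embedding = nn.Embedding(40, 256)
            # a custom init placed here would be clobbered by the next
            # line, because apply() visits every sub-layer
            self.apply(generic_init)
            # so the custom init must run afterwards:
            paddle.nn.initializer.Uniform(-0.1, 0.1)(self.embedding.weight)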

@@ -27,9 +27,9 @@ def compute_d_loss(nets: Dict[str, Any],
                    y_trg: paddle.Tensor,
                    z_trg: paddle.Tensor=None,
                    x_ref: paddle.Tensor=None,
-                   use_r1_reg=True,
-                   use_adv_cls=False,
-                   use_con_reg=False,
+                   use_r1_reg: bool=True,
+                   use_adv_cls: bool=False,
+                   use_con_reg: bool=False,
                    lambda_reg: float=1.,
                    lambda_adv_cls: float=0.1,
                    lambda_con_reg: float=10.):

@@ -37,7 +37,6 @@ def compute_d_loss(nets: Dict[str, Any],
     assert (z_trg is None) != (x_ref is None)

     # with real audios
     x_real.stop_gradient = False
-
     out = nets['discriminator'](x_real, y_org)
     loss_real = adv_loss(out, 1)

@@ -25,6 +25,8 @@ import paddle
 import paddle.nn.functional as F
 from paddle import nn

+from paddlespeech.t2s.modules.nets_utils import _reset_parameters
+

 class DownSample(nn.Layer):
     def __init__(self, layer_type: str):

@@ -355,6 +357,8 @@ class Generator(nn.Layer):
         if w_hpf > 0:
             self.hpf = HighPass(w_hpf)

+        self.reset_parameters()
+
     def forward(self,
                 x: paddle.Tensor,
                 s: paddle.Tensor,

@@ -399,6 +403,9 @@ class Generator(nn.Layer):
         out = self.to_out(x)
         return out

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class MappingNetwork(nn.Layer):
     def __init__(self,

@@ -427,6 +434,8 @@ class MappingNetwork(nn.Layer):
                 nn.ReLU(), nn.Linear(hidden_dim, style_dim))
         ])

+        self.reset_parameters()
+
     def forward(self, z: paddle.Tensor, y: paddle.Tensor):
         """Calculate forward propagation.
         Args:

@@ -449,6 +458,9 @@ class MappingNetwork(nn.Layer):
         s = out[idx, y]
         return s

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class StyleEncoder(nn.Layer):
     def __init__(self,

@@ -490,6 +502,8 @@ class StyleEncoder(nn.Layer):
         for _ in range(num_domains):
             self.unshared.append(nn.Linear(dim_out, style_dim))

+        self.reset_parameters()
+
     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
         """Calculate forward propagation.
         Args:

@@ -513,6 +527,9 @@ class StyleEncoder(nn.Layer):
         s = out[idx, y]
         return s

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class Discriminator(nn.Layer):
     def __init__(self,

@@ -535,7 +552,19 @@ class Discriminator(nn.Layer):
             repeat_num=repeat_num)
         self.num_domains = num_domains

+        self.reset_parameters()
+
     def forward(self, x: paddle.Tensor, y: paddle.Tensor):
+        """Calculate forward propagation.
+        Args:
+            x(Tensor(float32)):
+                Shape (B, 1, 80, T).
+            y(Tensor(float32)):
+                Shape (B, ).
+        Returns:
+            Tensor:
+                Shape (B, )
+        """
         out = self.dis(x, y)
         return out

@@ -543,6 +572,9 @@ class Discriminator(nn.Layer):
         out = self.cls.get_feature(x)
         return out

+    def reset_parameters(self):
+        self.apply(_reset_parameters)
+

 class Discriminator2D(nn.Layer):
     def __init__(self,
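
Every network above gains the same two-line `reset_parameters`. The repetition keeps each class self-contained, but if it spreads further a small mixin would factor it out; a refactoring sketch only (hypothetical class name, not what the PR does):

    from paddle import nn

    from paddlespeech.t2s.modules.nets_utils import _reset_parameters

    class TorchDefaultInit:
        """Hypothetical mixin: share the reset_parameters boilerplate
        across Generator, MappingNetwork, StyleEncoder, Discriminator."""

        def reset_parameters(self):
            self.apply(_reset_parameters)

    class Generator(TorchDefaultInit, nn.Layer):
        def __init__(self):
            super().__init__()
            # ... build sub-layers as before ...
            self.reset_parameters()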

@@ -21,10 +21,13 @@ from paddle.nn import Layer
 from paddle.optimizer import Optimizer
 from paddle.optimizer.lr import LRScheduler

+from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss
 from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
 from paddlespeech.t2s.training.reporter import report
 from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
 from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState

 logging.basicConfig(
     format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
     datefmt='[%Y-%m-%d %H:%M:%S]')

@@ -62,10 +65,10 @@ class StarGANv2VCUpdater(StandardUpdater):
         self.models = models

         self.optimizers = optimizers
-        self.optimizer_g = optimizers['optimizer_g']
-        self.optimizer_s = optimizers['optimizer_s']
-        self.optimizer_m = optimizers['optimizer_m']
-        self.optimizer_d = optimizers['optimizer_d']
+        self.optimizer_g = optimizers['generator']
+        self.optimizer_s = optimizers['style_encoder']
+        self.optimizer_m = optimizers['mapping_network']
+        self.optimizer_d = optimizers['discriminator']

         self.schedulers = schedulers
         self.scheduler_g = schedulers['generator']
@@ -20,6 +20,44 @@ import paddle
 from paddle import nn
 from typeguard import check_argument_types

+from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
+from paddlespeech.utils.initialize import kaiming_uniform_
+from paddlespeech.utils.initialize import normal_
+from paddlespeech.utils.initialize import ones_
+from paddlespeech.utils.initialize import uniform_
+from paddlespeech.utils.initialize import zeros_
+
+
+# default init method of torch
+# copy from https://github.com/PaddlePaddle/PaddleSpeech/blob/9cf8c1985a98bb380c183116123672976bdfe5c9/paddlespeech/t2s/models/vits/vits.py#L506
+def _reset_parameters(module):
+    if isinstance(module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D,
+                           nn.Conv2DTranspose)):
+        kaiming_uniform_(module.weight, a=math.sqrt(5))
+        if module.bias is not None:
+            fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
+            if fan_in != 0:
+                bound = 1 / math.sqrt(fan_in)
+                uniform_(module.bias, -bound, bound)
+
+    if isinstance(module,
+                  (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
+        ones_(module.weight)
+        zeros_(module.bias)
+
+    if isinstance(module, nn.Linear):
+        kaiming_uniform_(module.weight, a=math.sqrt(5))
+        if module.bias is not None:
+            fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            uniform_(module.bias, -bound, bound)
+
+    if isinstance(module, nn.Embedding):
+        normal_(module.weight)
+        if module._padding_idx is not None:
+            with paddle.no_grad():
+                module.weight[module._padding_idx] = 0
+
+
 def pad_list(xs, pad_value):
     """Perform padding for the list of tensors.
