From 3ad55a31e715561e0b2343f9aa5b4812608fc5b3 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Thu, 20 Apr 2023 18:43:22 +0800 Subject: [PATCH] [TTS]StarGANv2 VC fix some trainer bugs, add add reset_parameters (#3182) --- examples/vctk/vc3/conf/default.yaml | 35 ++++++++--------- examples/vctk/vc3/local/train.sh | 4 +- paddlespeech/t2s/datasets/am_batch_fn.py | 23 +++++------ paddlespeech/t2s/exps/starganv2_vc/train.py | 22 +++++------ .../models/starganv2_vc/AuxiliaryASR/model.py | 13 ++++++- .../t2s/models/starganv2_vc/losses.py | 7 ++-- .../t2s/models/starganv2_vc/starganv2_vc.py | 32 ++++++++++++++++ .../starganv2_vc/starganv2_vc_updater.py | 11 ++++-- paddlespeech/t2s/modules/nets_utils.py | 38 +++++++++++++++++++ 9 files changed, 133 insertions(+), 52 deletions(-) diff --git a/examples/vctk/vc3/conf/default.yaml b/examples/vctk/vc3/conf/default.yaml index b1168a40..eb98515a 100644 --- a/examples/vctk/vc3/conf/default.yaml +++ b/examples/vctk/vc3/conf/default.yaml @@ -41,7 +41,7 @@ discriminator_params: dim_in: 64 # same as dim_in in generator_params num_domains: 20 # same as num_domains in mapping_network_params max_conv_dim: 512 # same as max_conv_dim in generator_params - n_repeat: 4 + repeat_num: 4 asr_params: input_dim: 80 hidden_dim: 256 @@ -77,6 +77,7 @@ loss_params: ########################################################### batch_size: 5 # Batch size. num_workers: 2 # Number of workers in DataLoader. +max_mel_length: 192 ########################################################### # OPTIMIZER & SCHEDULER SETTING # @@ -84,47 +85,47 @@ num_workers: 2 # Number of workers in DataLoader. generator_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 generator_scheduler_params: - max_learning_rate: 2e-4 + max_learning_rate: 2.0e-4 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-4 + end_learning_rate: 2.0e-4 style_encoder_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 style_encoder_scheduler_params: - max_learning_rate: 2e-4 + max_learning_rate: 2.0e-4 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-4 + end_learning_rate: 2.0e-4 mapping_network_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 mapping_network_scheduler_params: - max_learning_rate: 2e-6 + max_learning_rate: 2.0e-6 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-6 + end_learning_rate: 2.0e-6 discriminator_optimizer_params: beta1: 0.0 beta2: 0.99 - weight_decay: 1e-4 - epsilon: 1e-9 + weight_decay: 1.0e-4 + epsilon: 1.0e-9 discriminator_scheduler_params: - max_learning_rate: 2e-4 + max_learning_rate: 2.0e-4 phase_pct: 0.0 divide_factor: 1 total_steps: 200000 # train_max_steps - end_learning_rate: 2e-4 + end_learning_rate: 2.0e-4 ########################################################### # TRAINING SETTING # diff --git a/examples/vctk/vc3/local/train.sh b/examples/vctk/vc3/local/train.sh index 3a507650..bdd8deae 100755 --- a/examples/vctk/vc3/local/train.sh +++ b/examples/vctk/vc3/local/train.sh @@ -8,6 +8,4 @@ python3 ${BIN_DIR}/train.py \ --dev-metadata=dump/dev/norm/metadata.jsonl \ --config=${config_path} \ --output-dir=${train_output_path} \ - --ngpu=1 \ - --phones-dict=dump/phone_id_map.txt \ - --speaker-dict=dump/speaker_id_map.txt + --ngpu=1 diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 4cd5bccc..ae46f1e1 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -852,20 +852,21 @@ class StarGANv2VCCollateFn: # (B,) label = paddle.to_tensor(label) ref_label = paddle.to_tensor(ref_label) - # [B, 80, T] -> [B, 1, 80, T] - mel = paddle.to_tensor(mel) - ref_mel = paddle.to_tensor(ref_mel) - ref_mel_2 = paddle.to_tensor(ref_mel_2) + # [B, T, 80] -> [B, 1, 80, T] + mel = paddle.to_tensor(mel).transpose([0, 2, 1]).unsqueeze(1) + ref_mel = paddle.to_tensor(ref_mel).transpose([0, 2, 1]).unsqueeze(1) + ref_mel_2 = paddle.to_tensor(ref_mel_2).transpose( + [0, 2, 1]).unsqueeze(1) - z_trg = paddle.randn(batch_size, self.latent_dim) - z_trg2 = paddle.randn(batch_size, self.latent_dim) + z_trg = paddle.randn([batch_size, self.latent_dim]) + z_trg2 = paddle.randn([batch_size, self.latent_dim]) batch = { - "x_real": mels, - "y_org": labels, - "x_ref": ref_mels, - "x_ref2": ref_mels_2, - "y_trg": ref_labels, + "x_real": mel, + "y_org": label, + "x_ref": ref_mel, + "x_ref2": ref_mel_2, + "y_trg": ref_label, "z_trg": z_trg, "z_trg2": z_trg2 } diff --git a/paddlespeech/t2s/exps/starganv2_vc/train.py b/paddlespeech/t2s/exps/starganv2_vc/train.py index 529f1f3d..616591e7 100644 --- a/paddlespeech/t2s/exps/starganv2_vc/train.py +++ b/paddlespeech/t2s/exps/starganv2_vc/train.py @@ -29,9 +29,12 @@ from paddle.optimizer import AdamW from paddle.optimizer.lr import OneCycleLR from yacs.config import CfgNode -from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn -from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.cli.utils import download_and_decompress +from paddlespeech.resource.pretrained_models import StarGANv2VC_source +from paddlespeech.t2s.datasets.am_batch_fn import build_starganv2_vc_collate_fn +from paddlespeech.t2s.datasets.data_table import StarGANv2VCDataTable from paddlespeech.t2s.models.starganv2_vc import ASRCNN +from paddlespeech.t2s.models.starganv2_vc import Discriminator from paddlespeech.t2s.models.starganv2_vc import Generator from paddlespeech.t2s.models.starganv2_vc import JDCNet from paddlespeech.t2s.models.starganv2_vc import MappingNetwork @@ -66,7 +69,9 @@ def train_sp(args, config): fields = ["speech", "speech_lengths"] converters = {"speech": np.load} - collate_fn = starganv2_vc_batch_fn + collate_fn = build_starganv2_vc_collate_fn( + latent_dim=config['mapping_network_params']['latent_dim'], + max_mel_length=config['max_mel_length']) # dataloader has been too verbose logging.getLogger("DataLoader").disabled = True @@ -74,16 +79,10 @@ def train_sp(args, config): # construct dataset for training and validation with jsonlines.open(args.train_metadata, 'r') as reader: train_metadata = list(reader) - train_dataset = DataTable( - data=train_metadata, - fields=fields, - converters=converters, ) + train_dataset = StarGANv2VCDataTable(data=train_metadata) with jsonlines.open(args.dev_metadata, 'r') as reader: dev_metadata = list(reader) - dev_dataset = DataTable( - data=dev_metadata, - fields=fields, - converters=converters, ) + dev_dataset = StarGANv2VCDataTable(data=dev_metadata) # collate function and dataloader train_sampler = DistributedBatchSampler( @@ -118,6 +117,7 @@ def train_sp(args, config): generator = Generator(**config['generator_params']) mapping_network = MappingNetwork(**config['mapping_network_params']) style_encoder = StyleEncoder(**config['style_encoder_params']) + discriminator = Discriminator(**config['discriminator_params']) # load pretrained model jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz') diff --git a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py index 25197457..85b3453d 100644 --- a/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py +++ b/paddlespeech/t2s/models/starganv2_vc/AuxiliaryASR/model.py @@ -22,6 +22,7 @@ from .layers import ConvBlock from .layers import ConvNorm from .layers import LinearNorm from .layers import MFCC +from paddlespeech.t2s.modules.nets_utils import _reset_parameters from paddlespeech.utils.initialize import uniform_ @@ -59,6 +60,9 @@ class ASRCNN(nn.Layer): hidden_dim=hidden_dim // 2, n_token=n_token) + self.reset_parameters() + self.asr_s2s.reset_parameters() + def forward(self, x: paddle.Tensor, src_key_padding_mask: paddle.Tensor=None, @@ -108,6 +112,9 @@ class ASRCNN(nn.Layer): index_tensor.T + unmask_future_steps) return mask + def reset_parameters(self): + self.apply(_reset_parameters) + class ASRS2S(nn.Layer): def __init__(self, @@ -118,8 +125,7 @@ class ASRS2S(nn.Layer): n_token: int=40): super().__init__() self.embedding = nn.Embedding(n_token, embedding_dim) - val_range = math.sqrt(6 / hidden_dim) - uniform_(self.embedding.weight, -val_range, val_range) + self.val_range = math.sqrt(6 / hidden_dim) self.decoder_rnn_dim = hidden_dim self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token) @@ -236,3 +242,6 @@ class ASRS2S(nn.Layer): hidden = paddle.stack(hidden).transpose([1, 0, 2]) return hidden, logit, alignments + + def reset_parameters(self): + uniform_(self.embedding.weight, -self.val_range, self.val_range) diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py index f9ff3927..aef7559f 100644 --- a/paddlespeech/t2s/models/starganv2_vc/losses.py +++ b/paddlespeech/t2s/models/starganv2_vc/losses.py @@ -27,9 +27,9 @@ def compute_d_loss(nets: Dict[str, Any], y_trg: paddle.Tensor, z_trg: paddle.Tensor=None, x_ref: paddle.Tensor=None, - use_r1_reg=True, - use_adv_cls=False, - use_con_reg=False, + use_r1_reg: bool=True, + use_adv_cls: bool=False, + use_con_reg: bool=False, lambda_reg: float=1., lambda_adv_cls: float=0.1, lambda_con_reg: float=10.): @@ -37,7 +37,6 @@ def compute_d_loss(nets: Dict[str, Any], assert (z_trg is None) != (x_ref is None) # with real audios x_real.stop_gradient = False - out = nets['discriminator'](x_real, y_org) loss_real = adv_loss(out, 1) diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py index 2b6775c4..99aeb73b 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc.py @@ -25,6 +25,8 @@ import paddle import paddle.nn.functional as F from paddle import nn +from paddlespeech.t2s.modules.nets_utils import _reset_parameters + class DownSample(nn.Layer): def __init__(self, layer_type: str): @@ -355,6 +357,8 @@ class Generator(nn.Layer): if w_hpf > 0: self.hpf = HighPass(w_hpf) + self.reset_parameters() + def forward(self, x: paddle.Tensor, s: paddle.Tensor, @@ -399,6 +403,9 @@ class Generator(nn.Layer): out = self.to_out(x) return out + def reset_parameters(self): + self.apply(_reset_parameters) + class MappingNetwork(nn.Layer): def __init__(self, @@ -427,6 +434,8 @@ class MappingNetwork(nn.Layer): nn.ReLU(), nn.Linear(hidden_dim, style_dim)) ]) + self.reset_parameters() + def forward(self, z: paddle.Tensor, y: paddle.Tensor): """Calculate forward propagation. Args: @@ -449,6 +458,9 @@ class MappingNetwork(nn.Layer): s = out[idx, y] return s + def reset_parameters(self): + self.apply(_reset_parameters) + class StyleEncoder(nn.Layer): def __init__(self, @@ -490,6 +502,8 @@ class StyleEncoder(nn.Layer): for _ in range(num_domains): self.unshared.append(nn.Linear(dim_out, style_dim)) + self.reset_parameters() + def forward(self, x: paddle.Tensor, y: paddle.Tensor): """Calculate forward propagation. Args: @@ -513,6 +527,9 @@ class StyleEncoder(nn.Layer): s = out[idx, y] return s + def reset_parameters(self): + self.apply(_reset_parameters) + class Discriminator(nn.Layer): def __init__(self, @@ -535,7 +552,19 @@ class Discriminator(nn.Layer): repeat_num=repeat_num) self.num_domains = num_domains + self.reset_parameters() + def forward(self, x: paddle.Tensor, y: paddle.Tensor): + """Calculate forward propagation. + Args: + x(Tensor(float32)): + Shape (B, 1, 80, T). + y(Tensor(float32)): + Shape (B, ). + Returns: + Tensor: + Shape (B, ) + """ out = self.dis(x, y) return out @@ -543,6 +572,9 @@ class Discriminator(nn.Layer): out = self.cls.get_feature(x) return out + def reset_parameters(self): + self.apply(_reset_parameters) + class Discriminator2D(nn.Layer): def __init__(self, diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py index 09d4780e..6a77fbb2 100644 --- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py +++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py @@ -21,10 +21,13 @@ from paddle.nn import Layer from paddle.optimizer import Optimizer from paddle.optimizer.lr import LRScheduler +from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss +from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator from paddlespeech.t2s.training.reporter import report from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState + logging.basicConfig( format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]') @@ -62,10 +65,10 @@ class StarGANv2VCUpdater(StandardUpdater): self.models = models self.optimizers = optimizers - self.optimizer_g = optimizers['optimizer_g'] - self.optimizer_s = optimizers['optimizer_s'] - self.optimizer_m = optimizers['optimizer_m'] - self.optimizer_d = optimizers['optimizer_d'] + self.optimizer_g = optimizers['generator'] + self.optimizer_s = optimizers['style_encoder'] + self.optimizer_m = optimizers['mapping_network'] + self.optimizer_d = optimizers['discriminator'] self.schedulers = schedulers self.scheduler_g = schedulers['generator'] diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py index 99130acc..3d1b48de 100644 --- a/paddlespeech/t2s/modules/nets_utils.py +++ b/paddlespeech/t2s/modules/nets_utils.py @@ -20,6 +20,44 @@ import paddle from paddle import nn from typeguard import check_argument_types +from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out +from paddlespeech.utils.initialize import kaiming_uniform_ +from paddlespeech.utils.initialize import normal_ +from paddlespeech.utils.initialize import ones_ +from paddlespeech.utils.initialize import uniform_ +from paddlespeech.utils.initialize import zeros_ + + +# default init method of torch +# copy from https://github.com/PaddlePaddle/PaddleSpeech/blob/9cf8c1985a98bb380c183116123672976bdfe5c9/paddlespeech/t2s/models/vits/vits.py#L506 +def _reset_parameters(module): + if isinstance(module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, + nn.Conv2DTranspose)): + kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + uniform_(module.bias, -bound, bound) + + if isinstance(module, + (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): + ones_(module.weight) + zeros_(module.bias) + + if isinstance(module, nn.Linear): + kaiming_uniform_(module.weight, a=math.sqrt(5)) + if module.bias is not None: + fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + uniform_(module.bias, -bound, bound) + + if isinstance(module, nn.Embedding): + normal_(module.weight) + if module._padding_idx is not None: + with paddle.no_grad(): + module.weight[module._padding_idx] = 0 + def pad_list(xs, pad_value): """Perform padding for the list of tensors.