diff --git a/examples/vctk/vc3/conf/default.yaml b/examples/vctk/vc3/conf/default.yaml
index 0acc2a56..46198902 100644
--- a/examples/vctk/vc3/conf/default.yaml
+++ b/examples/vctk/vc3/conf/default.yaml
@@ -1,22 +1,123 @@
- generator_params:
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+# not actually used here; the real sample rate is 16000
+sr: 24000
+n_fft: 2048
+win_length: 1200
+hop_length: 300
+n_mels: 80
+###########################################################
+#                       MODEL SETTING                     #
+###########################################################
+generator_params:
     dim_in: 64
     style_dim: 64
     max_conv_dim: 512
     w_hpf: 0
     F0_channel: 256
- mapping_network_params:
+mapping_network_params:
     num_domains: 20 # num of speakers in StarGANv2
     latent_dim: 16
     style_dim: 64 # same as style_dim in generator_params
    hidden_dim: 512 # same as max_conv_dim in generator_params
- style_encoder_params:
+style_encoder_params:
     dim_in: 64 # same as dim_in in generator_params
     style_dim: 64 # same as style_dim in generator_params
     num_domains: 20 # same as num_domains in generator_params
     max_conv_dim: 512 # same as max_conv_dim in generator_params
- discriminator_params:
+discriminator_params:
     dim_in: 64 # same as dim_in in generator_params
     num_domains: 20 # same as num_domains in mapping_network_params
     max_conv_dim: 512 # same as max_conv_dim in generator_params
     n_repeat: 4
-
\ No newline at end of file
+asr_params:
+    input_dim: 80
+    hidden_dim: 256
+    n_token: 80
+    token_embedding_dim: 256
+
+###########################################################
+#                 ADVERSARIAL LOSS SETTING                #
+###########################################################
+loss_params:
+    g_loss:
+        lambda_sty: 1.
+        lambda_cyc: 5.
+        lambda_ds: 1.
+        lambda_norm: 1.
+        lambda_asr: 10.
+        lambda_f0: 5.
+        lambda_f0_sty: 0.1
+        lambda_adv: 2.
+        lambda_adv_cls: 0.5
+        norm_bias: 0.5
+    d_loss:
+        lambda_reg: 1.
+        lambda_adv_cls: 0.1
+        lambda_con_reg: 10.
+
+    adv_cls_epoch: 50
+    con_reg_epoch: 30
+
+
+###########################################################
+#                   DATA LOADER SETTING                   #
+###########################################################
+batch_size: 5 # Batch size.
+num_workers: 2 # Number of workers in DataLoader.
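+
+# NOTE (explanatory comment, not part of the original config): each
+# *_scheduler_params block below is unpacked straight into
+# paddle.optimizer.lr.OneCycleLR. With phase_pct: 0.0, divide_factor: 1 and
+# end_learning_rate equal to max_learning_rate, the one-cycle schedule
+# degenerates to a constant learning rate, presumably to mimic the constant
+# learning rates of the original StarGANv2-VC implementation.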
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+    beta1: 0.0
+    beta2: 0.99
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
+generator_scheduler_params:
+    max_learning_rate: 2.0e-4
+    phase_pct: 0.0
+    divide_factor: 1
+    total_steps: 200000 # train_max_steps
+    end_learning_rate: 2.0e-4
+style_encoder_optimizer_params:
+    beta1: 0.0
+    beta2: 0.99
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
+style_encoder_scheduler_params:
+    max_learning_rate: 2.0e-4
+    phase_pct: 0.0
+    divide_factor: 1
+    total_steps: 200000 # train_max_steps
+    end_learning_rate: 2.0e-4
+mapping_network_optimizer_params:
+    beta1: 0.0
+    beta2: 0.99
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
+mapping_network_scheduler_params:
+    max_learning_rate: 2.0e-6
+    phase_pct: 0.0
+    divide_factor: 1
+    total_steps: 200000 # train_max_steps
+    end_learning_rate: 2.0e-6
+discriminator_optimizer_params:
+    beta1: 0.0
+    beta2: 0.99
+    weight_decay: 1.0e-4
+    epsilon: 1.0e-9
+discriminator_scheduler_params:
+    max_learning_rate: 2.0e-4
+    phase_pct: 0.0
+    divide_factor: 1
+    total_steps: 200000 # train_max_steps
+    end_learning_rate: 2.0e-4
+
+###########################################################
+#                     TRAINING SETTING                    #
+###########################################################
+max_epoch: 150
+num_snapshots: 5
+seed: 1
\ No newline at end of file
diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index 9ae791b4..8ec15f5d 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -114,7 +114,7 @@ def erniesat_batch_fn(examples,
     ]
     span_bdy = paddle.to_tensor(span_bdy)
 
-    # dual_mask 的是混合中英时候同时 mask 语音和文本 
+    # dual_mask: mask both speech and text at the same time for mixed Chinese-English input
     # ERNIE-SAT masks both when implementing cross-lingual synthesis
     if text_masking:
         masked_pos, text_masked_pos = phones_text_masking(
@@ -153,7 +153,7 @@ def erniesat_batch_fn(examples,
     batch = {
         "text": text,
         "speech": speech,
-        # need to generate 
+        # need to generate
         "masked_pos": masked_pos,
         "speech_mask": speech_mask,
         "text_mask": text_mask,
@@ -415,10 +415,13 @@ def fastspeech2_multi_spk_batch_fn(examples):
 
 
 def diffsinger_single_spk_batch_fn(examples):
-    # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
+    # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", \
+    #           "speech", "speech_lengths", "durations", "pitch", "energy"]
     text = [np.array(item["text"], dtype=np.int64) for item in examples]
     note = [np.array(item["note"], dtype=np.int64) for item in examples]
-    note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
+    note_dur = [
+        np.array(item["note_dur"], dtype=np.float32) for item in examples
+    ]
     is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
     speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
     pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
@@ -471,10 +474,13 @@ def diffsinger_single_spk_batch_fn(examples):
 
 
 def diffsinger_multi_spk_batch_fn(examples):
-    # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
+    # fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", \
+    #           "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
     text = [np.array(item["text"], dtype=np.int64) for item in examples]
     note = [np.array(item["note"], dtype=np.int64) for item in examples]
-    note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
+    note_dur = [
+        np.array(item["note_dur"], dtype=np.float32) for item in examples
+    ]
     is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
     speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
     pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
@@ -663,6 +669,20 @@ def vits_multi_spk_batch_fn(examples):
     return batch
 
 
+# TODO: not implemented yet (a hypothetical sketch follows below)
+def starganv2_vc_batch_fn(examples):
+    batch = {
+        "x_real": None,
+        "y_org": None,
+        "x_ref": None,
+        "x_ref2": None,
+        "y_trg": None,
+        "z_trg": None,
+        "z_trg2": None,
+    }
+    return batch
+
+
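+# NOTE (illustration only, not part of this PR): a minimal sketch of what the
+# finished collate function might compute. It assumes every example carries a
+# fixed-length "speech" mel spectrogram of shape (T, n_mels) plus an integer
+# "spk_id", and that latent_dim matches mapping_network_params.latent_dim in
+# the config. Reference pairs are drawn here by rolling the batch, which is
+# one possible sampling strategy, not necessarily the final one.
+def starganv2_vc_batch_fn_sketch(examples, latent_dim: int=16):
+    # (B, 1, n_mels, T) real mels and their speaker labels
+    x_real = paddle.to_tensor(
+        np.stack([
+            np.array(item["speech"], dtype=np.float32).T for item in examples
+        ])).unsqueeze(1)
+    y_org = paddle.to_tensor(
+        [item["spk_id"] for item in examples], dtype='int64')
+    # rolled copies of the batch serve as reference samples / target speakers
+    x_ref = paddle.roll(x_real, shifts=1, axis=0)
+    x_ref2 = paddle.roll(x_real, shifts=2, axis=0)
+    y_trg = paddle.roll(y_org, shifts=1, axis=0)
+    # random latent codes for the mapping network
+    z_trg = paddle.randn([len(examples), latent_dim])
+    z_trg2 = paddle.randn([len(examples), latent_dim])
+    return {
+        "x_real": x_real,
+        "y_org": y_org,
+        "x_ref": x_ref,
+        "x_ref2": x_ref2,
+        "y_trg": y_trg,
+        "z_trg": z_trg,
+        "z_trg2": z_trg2,
+    }
+
+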
"speech_lengths"] + converters = {"speech": np.load} + + collate_fn = starganv2_vc_batch_fn + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + + print("samplers done!") + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=collate_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + shuffle=False, + drop_last=False, + batch_size=config.batch_size, + collate_fn=collate_fn, + num_workers=config.num_workers) + + print("dataloaders done!") + + # load model + model_version = '1.0' + uncompress_path = download_and_decompress(StarGANv2VC_source[model_version], + MODEL_HOME) + + generator = Generator(**config['generator_params']) + mapping_network = MappingNetwork(**config['mapping_network_params']) + style_encoder = StyleEncoder(**config['style_encoder_params']) + + # load pretrained model + jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz') + asr_model_dir = os.path.join(uncompress_path, 'asr.pdz') + + F0_model = JDCNet(num_class=1, seq_len=192) + F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params']) + F0_model.eval() + + asr_model = ASRCNN(**config['asr_params']) + asr_model.set_state_dict(paddle.load(asr_model_dir)['main_params']) + asr_model.eval() + + if world_size > 1: + generator = DataParallel(generator) + discriminator = DataParallel(discriminator) + print("models done!") + + lr_schedule_g = OneCycleLR(**config["generator_scheduler_params"]) + optimizer_g = AdamW( + learning_rate=lr_schedule_g, + parameters=generator.parameters(), + **config["generator_optimizer_params"]) + + lr_schedule_s = OneCycleLR(**config["style_encoder_scheduler_params"]) + optimizer_s = AdamW( + learning_rate=lr_schedule_s, + parameters=style_encoder.parameters(), + **config["style_encoder_optimizer_params"]) + + lr_schedule_m = OneCycleLR(**config["mapping_network_scheduler_params"]) + optimizer_m = AdamW( + learning_rate=lr_schedule_m, + parameters=mapping_network.parameters(), + **config["mapping_network_optimizer_params"]) + + lr_schedule_d = OneCycleLR(**config["discriminator_scheduler_params"]) + optimizer_d = AdamW( + learning_rate=lr_schedule_d, + parameters=discriminator.parameters(), + **config["discriminator_optimizer_params"]) + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = StarGANv2VCUpdater( + models={ + "generator": generator, + "style_encoder": style_encoder, + "mapping_network": mapping_network, + "discriminator": discriminator, + "F0_model": F0_model, + "asr_model": asr_model, + }, + optimizers={ + "generator": optimizer_g, + "style_encoder": optimizer_s, + "mapping_network": optimizer_m, + "discriminator": optimizer_d, + }, + schedulers={ + "generator": 
+            "style_encoder": lr_schedule_s,
+            "mapping_network": lr_schedule_m,
+            "discriminator": lr_schedule_d,
+        },
+        dataloader=train_dataloader,
+        g_loss_params=config.loss_params.g_loss,
+        d_loss_params=config.loss_params.d_loss,
+        adv_cls_epoch=config.loss_params.adv_cls_epoch,
+        con_reg_epoch=config.loss_params.con_reg_epoch,
+        output_dir=output_dir)
+
+    evaluator = StarGANv2VCEvaluator(
+        models={
+            "generator": generator,
+            "style_encoder": style_encoder,
+            "mapping_network": mapping_network,
+            "discriminator": discriminator,
+            "F0_model": F0_model,
+            "asr_model": asr_model,
+        },
+        dataloader=dev_dataloader,
+        g_loss_params=config.loss_params.g_loss,
+        d_loss_params=config.loss_params.d_loss,
+        adv_cls_epoch=config.loss_params.adv_cls_epoch,
+        con_reg_epoch=config.loss_params.con_reg_epoch,
+        output_dir=output_dir)
+
+    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+    print("Trainer Done!")
+
+    trainer.run()
+
+
+def main():
+    # parse args and config and redirect to train_sp
+
+    parser = argparse.ArgumentParser(
+        description="Train a StarGANv2-VC model.")
+    parser.add_argument(
+        "--config", type=str, help="StarGANv2-VC config file.")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+
+    args = parser.parse_args()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.ngpu > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py
index 8086a595..f9ff3927 100644
--- a/paddlespeech/t2s/models/starganv2_vc/losses.py
+++ b/paddlespeech/t2s/models/starganv2_vc/losses.py
@@ -11,29 +11,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
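+"""Loss functions for StarGANv2-VC.
+
+Ported from the original StarGANv2-VC implementation; the loss weights are
+now passed as explicit keyword arguments instead of a Munch ``args`` object,
+and the per-term loss report is dropped in favor of a single scalar loss.
+"""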
+from typing import Any
+from typing import Dict
+
 import paddle
 import paddle.nn.functional as F
-from munch import Munch
-from starganv2vc_paddle.transforms import build_transforms
+
+from .transforms import build_transforms
 
 
-# 这些都写到 updater 里
+# all of these are wired into the updater
 def compute_d_loss(nets,
-                   args,
-                   x_real,
-                   y_org,
-                   y_trg,
-                   z_trg=None,
-                   x_ref=None,
+def compute_d_loss(nets: Dict[str, Any],
+                   x_real: paddle.Tensor,
+                   y_org: paddle.Tensor,
+                   y_trg: paddle.Tensor,
+                   z_trg: paddle.Tensor=None,
+                   x_ref: paddle.Tensor=None,
                    use_r1_reg=True,
                    use_adv_cls=False,
-                   use_con_reg=False):
-    args = Munch(args)
+                   use_con_reg=False,
+                   lambda_reg: float=1.,
+                   lambda_adv_cls: float=0.1,
+                   lambda_con_reg: float=10.):
     assert (z_trg is None) != (x_ref is None)
 
     # with real audios
     x_real.stop_gradient = False
-    out = nets.discriminator(x_real, y_org)
+
+    out = nets['discriminator'](x_real, y_org)
     loss_real = adv_loss(out, 1)
 
     # R1 regularization (https://arxiv.org/abs/1801.04406v4)
@@ -46,57 +51,60 @@ def compute_d_loss(nets,
     loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
     if use_con_reg:
         t = build_transforms()
-        out_aug = nets.discriminator(t(x_real).detach(), y_org)
+        out_aug = nets['discriminator'](t(x_real).detach(), y_org)
         loss_con_reg += F.smooth_l1_loss(out, out_aug)
 
     # with fake audios
     with paddle.no_grad():
         if z_trg is not None:
-            s_trg = nets.mapping_network(z_trg, y_trg)
+            s_trg = nets['mapping_network'](z_trg, y_trg)
         else:  # x_ref is not None
-            s_trg = nets.style_encoder(x_ref, y_trg)
+            s_trg = nets['style_encoder'](x_ref, y_trg)
 
-        F0 = nets.f0_model.get_feature_GAN(x_real)
-        x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)
-    out = nets.discriminator(x_fake, y_trg)
+        F0 = nets['F0_model'].get_feature_GAN(x_real)
+        x_fake = nets['generator'](x_real, s_trg, masks=None, F0=F0)
+    out = nets['discriminator'](x_fake, y_trg)
     loss_fake = adv_loss(out, 0)
     if use_con_reg:
-        out_aug = nets.discriminator(t(x_fake).detach(), y_trg)
+        out_aug = nets['discriminator'](t(x_fake).detach(), y_trg)
         loss_con_reg += F.smooth_l1_loss(out, out_aug)
 
     # adversarial classifier loss
     if use_adv_cls:
-        out_de = nets.discriminator.classifier(x_fake)
+        out_de = nets['discriminator'].classifier(x_fake)
         loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
                                             y_org[y_org != y_trg])
         if use_con_reg:
-            out_de_aug = nets.discriminator.classifier(t(x_fake).detach())
+            out_de_aug = nets['discriminator'].classifier(t(x_fake).detach())
             loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
     else:
         loss_real_adv_cls = paddle.zeros([1]).mean()
 
-    loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
-        args.lambda_adv_cls * loss_real_adv_cls + \
-        args.lambda_con_reg * loss_con_reg
+    loss = loss_real + loss_fake + lambda_reg * loss_reg + \
+        lambda_adv_cls * loss_real_adv_cls + \
+        lambda_con_reg * loss_con_reg
 
-    return loss, Munch(
-        real=loss_real.item(),
-        fake=loss_fake.item(),
-        reg=loss_reg.item(),
-        real_adv_cls=loss_real_adv_cls.item(),
-        con_reg=loss_con_reg.item())
+    return loss
 
 
-def compute_g_loss(nets,
-                   args,
-                   x_real,
-                   y_org,
-                   y_trg,
-                   z_trgs=None,
-                   x_refs=None,
-                   use_adv_cls=False):
-    args = Munch(args)
+def compute_g_loss(nets: Dict[str, Any],
+                   x_real: paddle.Tensor,
+                   y_org: paddle.Tensor,
+                   y_trg: paddle.Tensor,
+                   z_trgs: paddle.Tensor=None,
+                   x_refs: paddle.Tensor=None,
+                   use_adv_cls: bool=False,
+                   lambda_sty: float=1.,
+                   lambda_cyc: float=5.,
+                   lambda_ds: float=1.,
+                   lambda_norm: float=1.,
+                   lambda_asr: float=10.,
+                   lambda_f0: float=5.,
+                   lambda_f0_sty: float=0.1,
+                   lambda_adv: float=2.,
+                   lambda_adv_cls: float=0.5,
+                   norm_bias: float=0.5):
     assert (z_trgs is None) != (x_refs is None)
 
     if z_trgs is not None:
@@ -106,37 +114,36 @@ def compute_g_loss(nets,
 
     # compute style vectors
     if z_trgs is not None:
-        s_trg = nets.mapping_network(z_trg, y_trg)
+        s_trg = nets['mapping_network'](z_trg, y_trg)
     else:
-        s_trg = nets.style_encoder(x_ref, y_trg)
+        s_trg = nets['style_encoder'](x_ref, y_trg)
 
     # compute ASR/F0 features (real)
     with paddle.no_grad():
-        F0_real, GAN_F0_real, cyc_F0_real = nets.f0_model(x_real)
-        ASR_real = nets.asr_model.get_feature(x_real)
+        F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
+        ASR_real = nets['asr_model'].get_feature(x_real)
 
     # adversarial loss
-    x_fake = nets.generator(x_real, s_trg, masks=None, F0=GAN_F0_real)
-    out = nets.discriminator(x_fake, y_trg)
+    x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
+    out = nets['discriminator'](x_fake, y_trg)
     loss_adv = adv_loss(out, 1)
 
     # compute ASR/F0 features (fake)
-    F0_fake, GAN_F0_fake, _ = nets.f0_model(x_fake)
-    ASR_fake = nets.asr_model.get_feature(x_fake)
+    F0_fake, GAN_F0_fake, _ = nets['F0_model'](x_fake)
+    ASR_fake = nets['asr_model'].get_feature(x_fake)
 
     # norm consistency loss
     x_fake_norm = log_norm(x_fake)
     x_real_norm = log_norm(x_real)
-    loss_norm = ((
-        paddle.nn.ReLU()(paddle.abs(x_fake_norm - x_real_norm) - args.norm_bias)
-    )**2).mean()
+    tmp = paddle.abs(x_fake_norm - x_real_norm) - norm_bias
+    loss_norm = ((paddle.nn.ReLU()(tmp))**2).mean()
 
     # F0 loss
     loss_f0 = f0_loss(F0_fake, F0_real)
 
     # style F0 loss (style initialization)
-    if x_refs is not None and args.lambda_f0_sty > 0 and not use_adv_cls:
-        F0_sty, _, _ = nets.f0_model(x_ref)
+    if x_refs is not None and lambda_f0_sty > 0 and not use_adv_cls:
+        F0_sty, _, _ = nets['F0_model'](x_ref)
         loss_f0_sty = F.l1_loss(
             compute_mean_f0(F0_fake), compute_mean_f0(F0_sty))
     else:
@@ -146,61 +153,53 @@ def compute_g_loss(nets,
     loss_asr = F.smooth_l1_loss(ASR_fake, ASR_real)
 
     # style reconstruction loss
-    s_pred = nets.style_encoder(x_fake, y_trg)
+    s_pred = nets['style_encoder'](x_fake, y_trg)
     loss_sty = paddle.mean(paddle.abs(s_pred - s_trg))
 
     # diversity sensitive loss
     if z_trgs is not None:
-        s_trg2 = nets.mapping_network(z_trg2, y_trg)
+        s_trg2 = nets['mapping_network'](z_trg2, y_trg)
     else:
-        s_trg2 = nets.style_encoder(x_ref2, y_trg)
-    x_fake2 = nets.generator(x_real, s_trg2, masks=None, F0=GAN_F0_real)
+        s_trg2 = nets['style_encoder'](x_ref2, y_trg)
+    x_fake2 = nets['generator'](x_real, s_trg2, masks=None, F0=GAN_F0_real)
     x_fake2 = x_fake2.detach()
-    _, GAN_F0_fake2, _ = nets.f0_model(x_fake2)
+    _, GAN_F0_fake2, _ = nets['F0_model'](x_fake2)
     loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
     loss_ds += F.smooth_l1_loss(GAN_F0_fake, GAN_F0_fake2.detach())
 
     # cycle-consistency loss
-    s_org = nets.style_encoder(x_real, y_org)
-    x_rec = nets.generator(x_fake, s_org, masks=None, F0=GAN_F0_fake)
+    s_org = nets['style_encoder'](x_real, y_org)
+    x_rec = nets['generator'](x_fake, s_org, masks=None, F0=GAN_F0_fake)
     loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
     # F0 loss in cycle-consistency loss
-    if args.lambda_f0 > 0:
-        _, _, cyc_F0_rec = nets.f0_model(x_rec)
+    if lambda_f0 > 0:
+        _, _, cyc_F0_rec = nets['F0_model'](x_rec)
         loss_cyc += F.smooth_l1_loss(cyc_F0_rec, cyc_F0_real)
-    if args.lambda_asr > 0:
-        ASR_recon = nets.asr_model.get_feature(x_rec)
+    if lambda_asr > 0:
+        ASR_recon = nets['asr_model'].get_feature(x_rec)
         loss_cyc += F.smooth_l1_loss(ASR_recon, ASR_real)
 
     # adversarial classifier loss
     if use_adv_cls:
-        out_de = nets.discriminator.classifier(x_fake)
+        out_de = nets['discriminator'].classifier(x_fake)
         loss_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
                                        y_trg[y_org != y_trg])
     else:
         loss_adv_cls = paddle.zeros([1]).mean()
 
-    loss = args.lambda_adv * loss_adv + args.lambda_sty * loss_sty \
-        - args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc\
-        + args.lambda_norm * loss_norm \
-        + args.lambda_asr * loss_asr \
-        + args.lambda_f0 * loss_f0 \
-        + args.lambda_f0_sty * loss_f0_sty \
-        + args.lambda_adv_cls * loss_adv_cls
-
-    return loss, Munch(
-        adv=loss_adv.item(),
-        sty=loss_sty.item(),
-        ds=loss_ds.item(),
-        cyc=loss_cyc.item(),
-        norm=loss_norm.item(),
-        asr=loss_asr.item(),
-        f0=loss_f0.item(),
-        adv_cls=loss_adv_cls.item())
+    loss = lambda_adv * loss_adv + lambda_sty * loss_sty \
+        - lambda_ds * loss_ds + lambda_cyc * loss_cyc \
+        + lambda_norm * loss_norm \
+        + lambda_asr * loss_asr \
+        + lambda_f0 * loss_f0 \
+        + lambda_f0_sty * loss_f0_sty \
+        + lambda_adv_cls * loss_adv_cls
+
+    return loss
 
 
 # for norm consistency loss
-def log_norm(x, mean=-4, std=4, axis=2):
+def log_norm(x: paddle.Tensor, mean: float=-4, std: float=4, axis: int=2):
     """
     normalized log mel -> mel -> norm -> log(norm)
     """
@@ -209,7 +208,7 @@ def log_norm(x, mean=-4, std=4, axis=2):
 
 
 # for adversarial loss
-def adv_loss(logits, target):
+def adv_loss(logits: paddle.Tensor, target: float):
     assert target in [1, 0]
     if len(logits.shape) > 1:
         logits = logits.reshape([-1])
@@ -220,7 +219,7 @@ def adv_loss(logits, target):
 
 
 # for R1 regularization loss
-def r1_reg(d_out, x_in):
+def r1_reg(d_out: paddle.Tensor, x_in: paddle.Tensor):
     # zero-centered gradient penalty for real images
     batch_size = x_in.shape[0]
     grad_dout = paddle.grad(
@@ -236,14 +235,14 @@ def r1_reg(d_out, x_in):
 
 
 # for F0 consistency loss
-def compute_mean_f0(f0):
+def compute_mean_f0(f0: paddle.Tensor):
     f0_mean = f0.mean(-1)
     f0_mean = f0_mean.expand((f0.shape[-1], f0_mean.shape[0])).transpose(
         (1, 0))  # (B, M)
     return f0_mean
 
 
-def f0_loss(x_f0, y_f0):
+def f0_loss(x_f0: paddle.Tensor, y_f0: paddle.Tensor):
     """
     x.shape = (B, 1, M, L): predict
     y.shape = (B, 1, M, L): target
diff --git a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
index 595add0a..09d4780e 100644
--- a/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
+++ b/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py
@@ -11,3 +11,295 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
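+"""Training updater and evaluator for StarGANv2-VC.
+
+Each training step runs the discriminator update twice (once with a random
+latent code, once with a reference mel) and then the generator update twice
+in the same pattern, mirroring the original StarGANv2-VC training loop.
+"""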
+import logging
+from typing import Any
+from typing import Dict
+
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss
+from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class StarGANv2VCUpdater(StandardUpdater):
+    def __init__(self,
+                 models: Dict[str, Layer],
+                 optimizers: Dict[str, Optimizer],
+                 schedulers: Dict[str, LRScheduler],
+                 dataloader: DataLoader,
+                 g_loss_params: Dict[str, Any]={
+                     'lambda_sty': 1.,
+                     'lambda_cyc': 5.,
+                     'lambda_ds': 1.,
+                     'lambda_norm': 1.,
+                     'lambda_asr': 10.,
+                     'lambda_f0': 5.,
+                     'lambda_f0_sty': 0.1,
+                     'lambda_adv': 2.,
+                     'lambda_adv_cls': 0.5,
+                     'norm_bias': 0.5,
+                 },
+                 d_loss_params: Dict[str, Any]={
+                     'lambda_reg': 1.,
+                     'lambda_adv_cls': 0.1,
+                     'lambda_con_reg': 10.,
+                 },
+                 adv_cls_epoch: int=50,
+                 con_reg_epoch: int=30,
+                 use_r1_reg: bool=False,
+                 output_dir=None):
+        self.models = models
+
+        self.optimizers = optimizers
+        # keys match the dicts built in paddlespeech/t2s/exps/starganv2_vc/train.py
+        self.optimizer_g = optimizers['generator']
+        self.optimizer_s = optimizers['style_encoder']
+        self.optimizer_m = optimizers['mapping_network']
+        self.optimizer_d = optimizers['discriminator']
+
+        self.schedulers = schedulers
+        self.scheduler_g = schedulers['generator']
+        self.scheduler_s = schedulers['style_encoder']
+        self.scheduler_m = schedulers['mapping_network']
+        self.scheduler_d = schedulers['discriminator']
+
+        self.dataloader = dataloader
+
+        self.g_loss_params = g_loss_params
+        self.d_loss_params = d_loss_params
+
+        self.use_r1_reg = use_r1_reg
+        self.con_reg_epoch = con_reg_epoch
+        self.adv_cls_epoch = adv_cls_epoch
+
+        self.state = UpdaterState(iteration=0, epoch=0)
+        self.train_iterator = iter(self.dataloader)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def zero_grad(self):
+        self.optimizer_d.clear_grad()
+        self.optimizer_g.clear_grad()
+        self.optimizer_m.clear_grad()
+        self.optimizer_s.clear_grad()
+
+    def scheduler(self):
+        self.scheduler_d.step()
+        self.scheduler_g.step()
+        self.scheduler_m.step()
+        self.scheduler_s.step()
+
+    def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+        # parse batch
+        x_real = batch['x_real']
+        y_org = batch['y_org']
+        x_ref = batch['x_ref']
+        x_ref2 = batch['x_ref2']
+        y_trg = batch['y_trg']
+        z_trg = batch['z_trg']
+        z_trg2 = batch['z_trg2']
+
+        use_con_reg = (self.state.epoch >= self.con_reg_epoch)
+        use_adv_cls = (self.state.epoch >= self.adv_cls_epoch)
+
+        # Discriminator loss
+        # train the discriminator (by random reference)
+        self.zero_grad()
+        random_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trg=z_trg,
+            use_r1_reg=self.use_r1_reg,
+            use_adv_cls=use_adv_cls,
+            use_con_reg=use_con_reg,
+            **self.d_loss_params)
+        random_d_loss.backward()
+        self.optimizer_d.step()
+        # train the discriminator (by target reference)
+        self.zero_grad()
+        target_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_ref=x_ref,
+            use_r1_reg=self.use_r1_reg,
+            use_adv_cls=use_adv_cls,
+            use_con_reg=use_con_reg,
+            **self.d_loss_params)
+        target_d_loss.backward()
+        self.optimizer_d.step()
+        report("train/random_d_loss", float(random_d_loss))
+        report("train/target_d_loss", float(target_d_loss))
+        losses_dict["random_d_loss"] = float(random_d_loss)
+        losses_dict["target_d_loss"] = float(target_d_loss)
+
+        # Generator
+        # train the generator (by random reference)
+        self.zero_grad()
+        random_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trgs=[z_trg, z_trg2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+        random_g_loss.backward()
+        self.optimizer_g.step()
+        self.optimizer_m.step()
+        self.optimizer_s.step()
+
+        # train the generator (by target reference)
+        self.zero_grad()
+        target_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_refs=[x_ref, x_ref2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+        target_g_loss.backward()
+        # Open question: should optimizer_m and optimizer_s also step here?
+        # The reference implementation only steps the generator optimizer for
+        # this reference-based pass; we follow it until that is clarified.
+        self.optimizer_g.step()
+        # self.optimizer_m.step()
+        # self.optimizer_s.step()
+        report("train/random_g_loss", float(random_g_loss))
+        report("train/target_g_loss", float(target_g_loss))
+        losses_dict["random_g_loss"] = float(random_g_loss)
+        losses_dict["target_g_loss"] = float(target_g_loss)
+
+        self.scheduler()
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class StarGANv2VCEvaluator(StandardEvaluator):
+    def __init__(self,
+                 models: Dict[str, Layer],
+                 dataloader: DataLoader,
+                 g_loss_params: Dict[str, Any]={
+                     'lambda_sty': 1.,
+                     'lambda_cyc': 5.,
+                     'lambda_ds': 1.,
+                     'lambda_norm': 1.,
+                     'lambda_asr': 10.,
+                     'lambda_f0': 5.,
+                     'lambda_f0_sty': 0.1,
+                     'lambda_adv': 2.,
+                     'lambda_adv_cls': 0.5,
+                     'norm_bias': 0.5,
+                 },
+                 d_loss_params: Dict[str, Any]={
+                     'lambda_reg': 1.,
+                     'lambda_adv_cls': 0.1,
+                     'lambda_con_reg': 10.,
+                 },
+                 adv_cls_epoch: int=50,
+                 con_reg_epoch: int=30,
+                 use_r1_reg: bool=False,
+                 output_dir=None):
+        self.models = models
+
+        self.dataloader = dataloader
+
+        self.g_loss_params = g_loss_params
+        self.d_loss_params = d_loss_params
+
+        self.use_r1_reg = use_r1_reg
+        self.con_reg_epoch = con_reg_epoch
+        self.adv_cls_epoch = adv_cls_epoch
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        # logging.debug("Evaluate: ")
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        x_real = batch['x_real']
+        y_org = batch['y_org']
+        x_ref = batch['x_ref']
+        x_ref2 = batch['x_ref2']
+        y_trg = batch['y_trg']
+        z_trg = batch['z_trg']
+        z_trg2 = batch['z_trg2']
+
+        # the evaluator does not track the training epoch, so the adversarial
+        # classifier term is skipped here (simplifying assumption)
+        use_adv_cls = False
+
+        # eval the discriminator
+
+        random_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trg=z_trg,
+            use_r1_reg=False,
+            use_adv_cls=use_adv_cls,
+            **self.d_loss_params)
+
+        target_d_loss = compute_d_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_ref=x_ref,
+            use_r1_reg=False,
+            use_adv_cls=use_adv_cls,
+            **self.d_loss_params)
+
+        report("eval/random_d_loss", float(random_d_loss))
+        report("eval/target_d_loss", float(target_d_loss))
+        losses_dict["random_d_loss"] = float(random_d_loss)
+        losses_dict["target_d_loss"] = float(target_d_loss)
+
+        # eval the generator
+
+        random_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            z_trgs=[z_trg, z_trg2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+
+        target_g_loss = compute_g_loss(
+            nets=self.models,
+            x_real=x_real,
+            y_org=y_org,
+            y_trg=y_trg,
+            x_refs=[x_ref, x_ref2],
+            use_adv_cls=use_adv_cls,
+            **self.g_loss_params)
+
+        report("eval/random_g_loss", float(random_g_loss))
+        report("eval/target_g_loss", float(target_g_loss))
+        losses_dict["random_g_loss"] = float(random_g_loss)
+        losses_dict["target_g_loss"] = float(target_g_loss)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
diff --git a/paddlespeech/t2s/models/starganv2_vc/transforms.py b/paddlespeech/t2s/models/starganv2_vc/transforms.py
new file mode 100644
index 00000000..d7586147
--- /dev/null
+++ b/paddlespeech/t2s/models/starganv2_vc/transforms.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+
+## 1. RandomTimeStrech
+class TimeStrech(nn.Layer):
+    def __init__(self, scale):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x: paddle.Tensor):
+        mel_size = x.shape[-1]
+
+        x = F.interpolate(
+            x,
+            scale_factor=(1, self.scale),
+            align_corners=False,
+            mode='bilinear').squeeze()
+
+        if x.shape[-1] < mel_size:
+            noise_length = (mel_size - x.shape[-1])
+            random_pos = random.randint(0, x.shape[-1]) - noise_length
+            if random_pos < 0:
+                random_pos = 0
+            noise = x[..., random_pos:random_pos + noise_length]
+            x = paddle.concat([x, noise], axis=-1)
+        else:
+            x = x[..., :mel_size]
+
+        return x.unsqueeze(1)
+
+
+## 2. PitchShift
+class PitchShift(nn.Layer):
+    def __init__(self, shift):
+        super().__init__()
+        self.shift = shift
+
+    def forward(self, x: paddle.Tensor):
+        if len(x.shape) == 2:
+            x = x.unsqueeze(0)
+        x = x.squeeze()
+        mel_size = x.shape[1]
+        shift_scale = (mel_size + self.shift) / mel_size
+        x = F.interpolate(
+            x.unsqueeze(1),
+            scale_factor=(shift_scale, 1.),
+            align_corners=False,
+            mode='bilinear').squeeze(1)
+
+        x = x[:, :mel_size]
+        if x.shape[1] < mel_size:
+            pad_size = mel_size - x.shape[1]
+            x = paddle.concat(
+                [x, paddle.zeros([x.shape[0], pad_size, x.shape[2]])], axis=1)
+        x = x.squeeze()
+        return x.unsqueeze(1)
+
+
+## 3. ShiftBias
+class ShiftBias(nn.Layer):
+    def __init__(self, bias):
+        super().__init__()
+        self.bias = bias
+
+    def forward(self, x: paddle.Tensor):
+        return x + self.bias
+
+
+## 4. Scaling
+class SpectScaling(nn.Layer):
+    def __init__(self, scale):
+        super().__init__()
+        self.scale = scale
+
+    def forward(self, x: paddle.Tensor):
+        return x * self.scale
+
+
+## 5. Time Flip
+class TimeFlip(nn.Layer):
+    def __init__(self, length):
+        super().__init__()
+        self.length = round(length)
+
+    def forward(self, x: paddle.Tensor):
+        if self.length > 1:
+            start = np.random.randint(0, x.shape[-1] - self.length)
+            x_ret = x.clone()
+            x_ret[..., start:start + self.length] = paddle.flip(
+                x[..., start:start + self.length], axis=[-1])
+            x = x_ret
+        return x
+
+
+class PhaseShuffle2D(nn.Layer):
+    def __init__(self, n: int=2):
+        super().__init__()
+        self.n = n
+        self.random = random.Random(1)
+
+    def forward(self, x: paddle.Tensor, move=None):
+        # x.size = (B, C, M, L)
+        if move is None:
+            move = self.random.randint(-self.n, self.n)
+
+        if move == 0:
+            return x
+        else:
+            left = x[:, :, :, :move]
+            right = x[:, :, :, move:]
+            shuffled = paddle.concat([right, left], axis=3)
+
+        return shuffled
+
+
+def build_transforms():
+    transforms = [
+        lambda M: TimeStrech(1 + (np.random.random() - 0.5) * M * 0.2),
+        lambda M: SpectScaling(1 + (np.random.random() - 1) * M * 0.1),
+        lambda M: PhaseShuffle2D(192),
+    ]
+    N, M = len(transforms), np.random.random()
+    composed = nn.Sequential(
+        * [trans(M) for trans in np.random.choice(transforms, N)])
+    return composed
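+
+
+if __name__ == '__main__':
+    # Quick sanity check (illustration only, not part of the original PR):
+    # run a randomly composed augmentation chain over a fake mel batch of
+    # shape (B, 1, n_mels, frames) and verify the shape is preserved, since
+    # every transform above crops/pads back to the input length.
+    mels = paddle.randn([4, 1, 80, 192])
+    augment = build_transforms()
+    out = augment(mels)
+    print(out.shape)  # expected: [4, 1, 80, 192]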