[TTS]add starganv2 vc trainer (#3143)

* add starganv2 vc trainer

* fix StarGANv2VCUpdater and losses

* fix StarGANv2VCEvaluator

* add some typehint
pull/3155/head
TianYuan 1 year ago committed by GitHub
parent 54ef90fcec
commit 72aa19c32c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,22 +1,123 @@
generator_params:
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
# 其实没用上,其实用的是 16000
sr: 24000
n_fft: 2048
win_length: 1200
hop_length: 300
n_mels: 80
###########################################################
# MODEL SETTING #
###########################################################
generator_params:
dim_in: 64
style_dim: 64
max_conv_dim: 512
w_hpf: 0
F0_channel: 256
mapping_network_params:
mapping_network_params:
num_domains: 20 # num of speakers in StarGANv2
latent_dim: 16
style_dim: 64 # same as style_dim in generator_params
hidden_dim: 512 # same as max_conv_dim in generator_params
style_encoder_params:
style_encoder_params:
dim_in: 64 # same as dim_in in generator_params
style_dim: 64 # same as style_dim in generator_params
num_domains: 20 # same as num_domains in generator_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
discriminator_params:
discriminator_params:
dim_in: 64 # same as dim_in in generator_params
num_domains: 20 # same as num_domains in mapping_network_params
max_conv_dim: 512 # same as max_conv_dim in generator_params
n_repeat: 4
asr_params:
input_dim: 80
hidden_dim: 256
n_token: 80
token_embedding_dim: 256
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
loss_params:
g_loss:
lambda_sty: 1.
lambda_cyc: 5.
lambda_ds: 1.
lambda_norm: 1.
lambda_asr: 10.
lambda_f0: 5.
lambda_f0_sty: 0.1
lambda_adv: 2.
lambda_adv_cls: 0.5
norm_bias: 0.5
d_loss:
lambda_reg: 1.
lambda_adv_cls: 0.1
lambda_con_reg: 10.
adv_cls_epoch: 50
con_reg_epoch: 30
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 5 # Batch size.
num_workers: 2 # Number of workers in DataLoader.
###########################################################
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
generator_scheduler_params:
max_learning_rate: 2e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
style_encoder_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
style_encoder_scheduler_params:
max_learning_rate: 2e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
mapping_network_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
mapping_network_scheduler_params:
max_learning_rate: 2e-6
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-6
discriminator_optimizer_params:
beta1: 0.0
beta2: 0.99
weight_decay: 1e-4
epsilon: 1e-9
discriminator_scheduler_params:
max_learning_rate: 2e-4
phase_pct: 0.0
divide_factor: 1
total_steps: 200000 # train_max_steps
end_learning_rate: 2e-4
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 150
num_snapshots: 5
seed: 1

@ -114,7 +114,7 @@ def erniesat_batch_fn(examples,
]
span_bdy = paddle.to_tensor(span_bdy)
# dual_mask 的是混合中英时候同时 mask 语音和文本
# dual_mask 的是混合中英时候同时 mask 语音和文本
# ernie sat 在实现跨语言的时候都 mask 了
if text_masking:
masked_pos, text_masked_pos = phones_text_masking(
@ -153,7 +153,7 @@ def erniesat_batch_fn(examples,
batch = {
"text": text,
"speech": speech,
# need to generate
# need to generate
"masked_pos": masked_pos,
"speech_mask": speech_mask,
"text_mask": text_mask,
@ -415,10 +415,13 @@ def fastspeech2_multi_spk_batch_fn(examples):
def diffsinger_single_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", \
# "speech", "speech_lengths", "durations", "pitch", "energy"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
note_dur = [
np.array(item["note_dur"], dtype=np.float32) for item in examples
]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
@ -471,10 +474,13 @@ def diffsinger_single_spk_batch_fn(examples):
def diffsinger_multi_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", \
# "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
note_dur = [
np.array(item["note_dur"], dtype=np.float32) for item in examples
]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
@ -663,6 +669,20 @@ def vits_multi_spk_batch_fn(examples):
return batch
# 未完成
def starganv2_vc_batch_fn(examples):
batch = {
"x_real": None,
"y_org": None,
"x_ref": None,
"x_ref2": None,
"y_trg": None,
"z_trg": None,
"z_trg2": None,
}
return batch
# for PaddleSlim
def fastspeech2_single_spk_batch_fn_static(examples):
text = [np.array(item["text"], dtype=np.int64) for item in examples]

@ -0,0 +1,259 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import AdamW
from paddle.optimizer.lr import OneCycleLR
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import starganv2_vc_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.starganv2_vc import ASRCNN
from paddlespeech.t2s.models.starganv2_vc import Generator
from paddlespeech.t2s.models.starganv2_vc import JDCNet
from paddlespeech.t2s.models.starganv2_vc import MappingNetwork
from paddlespeech.t2s.models.starganv2_vc import StarGANv2VCEvaluator
from paddlespeech.t2s.models.starganv2_vc import StarGANv2VCUpdater
from paddlespeech.t2s.models.starganv2_vc import StyleEncoder
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
from paddlespeech.utils.env import MODEL_HOME
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
world_size = paddle.distributed.get_world_size()
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
if world_size > 1:
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
# to edit
fields = ["speech", "speech_lengths"]
converters = {"speech": np.load}
collate_fn = starganv2_vc_batch_fn
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters=converters, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters=converters, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
print("samplers done!")
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=collate_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
shuffle=False,
drop_last=False,
batch_size=config.batch_size,
collate_fn=collate_fn,
num_workers=config.num_workers)
print("dataloaders done!")
# load model
model_version = '1.0'
uncompress_path = download_and_decompress(StarGANv2VC_source[model_version],
MODEL_HOME)
generator = Generator(**config['generator_params'])
mapping_network = MappingNetwork(**config['mapping_network_params'])
style_encoder = StyleEncoder(**config['style_encoder_params'])
# load pretrained model
jdc_model_dir = os.path.join(uncompress_path, 'jdcnet.pdz')
asr_model_dir = os.path.join(uncompress_path, 'asr.pdz')
F0_model = JDCNet(num_class=1, seq_len=192)
F0_model.set_state_dict(paddle.load(jdc_model_dir)['main_params'])
F0_model.eval()
asr_model = ASRCNN(**config['asr_params'])
asr_model.set_state_dict(paddle.load(asr_model_dir)['main_params'])
asr_model.eval()
if world_size > 1:
generator = DataParallel(generator)
discriminator = DataParallel(discriminator)
print("models done!")
lr_schedule_g = OneCycleLR(**config["generator_scheduler_params"])
optimizer_g = AdamW(
learning_rate=lr_schedule_g,
parameters=generator.parameters(),
**config["generator_optimizer_params"])
lr_schedule_s = OneCycleLR(**config["style_encoder_scheduler_params"])
optimizer_s = AdamW(
learning_rate=lr_schedule_s,
parameters=style_encoder.parameters(),
**config["style_encoder_optimizer_params"])
lr_schedule_m = OneCycleLR(**config["mapping_network_scheduler_params"])
optimizer_m = AdamW(
learning_rate=lr_schedule_m,
parameters=mapping_network.parameters(),
**config["mapping_network_optimizer_params"])
lr_schedule_d = OneCycleLR(**config["discriminator_scheduler_params"])
optimizer_d = AdamW(
learning_rate=lr_schedule_d,
parameters=discriminator.parameters(),
**config["discriminator_optimizer_params"])
print("optimizers done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
updater = StarGANv2VCUpdater(
models={
"generator": generator,
"style_encoder": style_encoder,
"mapping_network": mapping_network,
"discriminator": discriminator,
"F0_model": F0_model,
"asr_model": asr_model,
},
optimizers={
"generator": optimizer_g,
"style_encoder": optimizer_s,
"mapping_network": optimizer_m,
"discriminator": optimizer_d,
},
schedulers={
"generator": lr_schedule_g,
"style_encoder": lr_schedule_s,
"mapping_network": lr_schedule_m,
"discriminator": lr_schedule_d,
},
dataloader=train_dataloader,
g_loss_params=config.loss_params.g_loss,
d_loss_params=config.loss_params.d_loss,
adv_cls_epoch=config.loss_params.adv_cls_epoch,
con_reg_epoch=config.loss_params.con_reg_epoch,
output_dir=output_dir)
evaluator = StarGANv2VCEvaluator(
models={
"generator": generator,
"style_encoder": style_encoder,
"mapping_network": mapping_network,
"discriminator": discriminator,
"F0_model": F0_model,
"asr_model": asr_model,
},
dataloader=dev_dataloader,
g_loss_params=config.loss_params.g_loss,
d_loss_params=config.loss_params.d_loss,
adv_cls_epoch=config.loss_params.adv_cls_epoch,
con_reg_epoch=config.loss_params.con_reg_epoch,
output_dir=output_dir)
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
print("Trainer Done!")
trainer.run()
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a HiFiGAN model.")
parser.add_argument("--config", type=str, help="HiFiGAN config file.")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
# dispatch
if args.ngpu > 1:
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
else:
train_sp(args, config)
if __name__ == "__main__":
main()

@ -11,29 +11,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
import paddle
import paddle.nn.functional as F
from munch import Munch
from starganv2vc_paddle.transforms import build_transforms
from .transforms import build_transforms
# 这些都写到 updater 里
def compute_d_loss(nets,
args,
x_real,
y_org,
y_trg,
z_trg=None,
x_ref=None,
def compute_d_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg=True,
use_adv_cls=False,
use_con_reg=False):
args = Munch(args)
use_con_reg=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets.discriminator(x_real, y_org)
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
@ -46,57 +51,60 @@ def compute_d_loss(nets,
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
if use_con_reg:
t = build_transforms()
out_aug = nets.discriminator(t(x_real).detach(), y_org)
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
loss_con_reg += F.smooth_l1_loss(out, out_aug)
# with fake audios
with paddle.no_grad():
if z_trg is not None:
s_trg = nets.mapping_network(z_trg, y_trg)
s_trg = nets['mapping_network'](z_trg, y_trg)
else: # x_ref is not None
s_trg = nets.style_encoder(x_ref, y_trg)
s_trg = nets['style_encoder'](x_ref, y_trg)
F0 = nets.f0_model.get_feature_GAN(x_real)
x_fake = nets.generator(x_real, s_trg, masks=None, F0=F0)
out = nets.discriminator(x_fake, y_trg)
F0 = nets['F0_model'].get_feature_GAN(x_real)
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=F0)
out = nets['discriminator'](x_fake, y_trg)
loss_fake = adv_loss(out, 0)
if use_con_reg:
out_aug = nets.discriminator(t(x_fake).detach(), y_trg)
out_aug = nets['discriminator'](t(x_fake).detach(), y_trg)
loss_con_reg += F.smooth_l1_loss(out, out_aug)
# adversarial classifier loss
if use_adv_cls:
out_de = nets.discriminator.classifier(x_fake)
out_de = nets['discriminator'].classifier(x_fake)
loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
y_org[y_org != y_trg])
if use_con_reg:
out_de_aug = nets.discriminator.classifier(t(x_fake).detach())
out_de_aug = nets['discriminator'].classifier(t(x_fake).detach())
loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
else:
loss_real_adv_cls = paddle.zeros([1]).mean()
loss = loss_real + loss_fake + args.lambda_reg * loss_reg + \
args.lambda_adv_cls * loss_real_adv_cls + \
args.lambda_con_reg * loss_con_reg
loss = loss_real + loss_fake + lambda_reg * loss_reg + \
lambda_adv_cls * loss_real_adv_cls + \
lambda_con_reg * loss_con_reg
return loss, Munch(
real=loss_real.item(),
fake=loss_fake.item(),
reg=loss_reg.item(),
real_adv_cls=loss_real_adv_cls.item(),
con_reg=loss_con_reg.item())
return loss
def compute_g_loss(nets,
args,
x_real,
y_org,
y_trg,
z_trgs=None,
x_refs=None,
use_adv_cls=False):
args = Munch(args)
def compute_g_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trgs: paddle.Tensor=None,
x_refs: paddle.Tensor=None,
use_adv_cls: bool=False,
lambda_sty: float=1.,
lambda_cyc: float=5.,
lambda_ds: float=1.,
lambda_norm: float=1.,
lambda_asr: float=10.,
lambda_f0: float=5.,
lambda_f0_sty: float=0.1,
lambda_adv: float=2.,
lambda_adv_cls: float=0.5,
norm_bias: float=0.5):
assert (z_trgs is None) != (x_refs is None)
if z_trgs is not None:
@ -106,37 +114,36 @@ def compute_g_loss(nets,
# compute style vectors
if z_trgs is not None:
s_trg = nets.mapping_network(z_trg, y_trg)
s_trg = nets['mapping_network'](z_trg, y_trg)
else:
s_trg = nets.style_encoder(x_ref, y_trg)
s_trg = nets['style_encoder'](x_ref, y_trg)
# compute ASR/F0 features (real)
with paddle.no_grad():
F0_real, GAN_F0_real, cyc_F0_real = nets.f0_model(x_real)
ASR_real = nets.asr_model.get_feature(x_real)
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# adversarial loss
x_fake = nets.generator(x_real, s_trg, masks=None, F0=GAN_F0_real)
out = nets.discriminator(x_fake, y_trg)
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
out = nets['discriminator'](x_fake, y_trg)
loss_adv = adv_loss(out, 1)
# compute ASR/F0 features (fake)
F0_fake, GAN_F0_fake, _ = nets.f0_model(x_fake)
ASR_fake = nets.asr_model.get_feature(x_fake)
F0_fake, GAN_F0_fake, _ = nets['F0_model'](x_fake)
ASR_fake = nets['asr_model'].get_feature(x_fake)
# norm consistency loss
x_fake_norm = log_norm(x_fake)
x_real_norm = log_norm(x_real)
loss_norm = ((
paddle.nn.ReLU()(paddle.abs(x_fake_norm - x_real_norm) - args.norm_bias)
)**2).mean()
tmp = paddle.abs(x_fake_norm - x_real_norm) - norm_bias
loss_norm = ((paddle.nn.ReLU()(tmp))**2).mean()
# F0 loss
loss_f0 = f0_loss(F0_fake, F0_real)
# style F0 loss (style initialization)
if x_refs is not None and args.lambda_f0_sty > 0 and not use_adv_cls:
F0_sty, _, _ = nets.f0_model(x_ref)
if x_refs is not None and lambda_f0_sty > 0 and not use_adv_cls:
F0_sty, _, _ = nets['F0_model'](x_ref)
loss_f0_sty = F.l1_loss(
compute_mean_f0(F0_fake), compute_mean_f0(F0_sty))
else:
@ -146,61 +153,53 @@ def compute_g_loss(nets,
loss_asr = F.smooth_l1_loss(ASR_fake, ASR_real)
# style reconstruction loss
s_pred = nets.style_encoder(x_fake, y_trg)
s_pred = nets['style_encoder'](x_fake, y_trg)
loss_sty = paddle.mean(paddle.abs(s_pred - s_trg))
# diversity sensitive loss
if z_trgs is not None:
s_trg2 = nets.mapping_network(z_trg2, y_trg)
s_trg2 = nets['mapping_network'](z_trg2, y_trg)
else:
s_trg2 = nets.style_encoder(x_ref2, y_trg)
x_fake2 = nets.generator(x_real, s_trg2, masks=None, F0=GAN_F0_real)
s_trg2 = nets['style_encoder'](x_ref2, y_trg)
x_fake2 = nets['generator'](x_real, s_trg2, masks=None, F0=GAN_F0_real)
x_fake2 = x_fake2.detach()
_, GAN_F0_fake2, _ = nets.f0_model(x_fake2)
_, GAN_F0_fake2, _ = nets['F0_model'](x_fake2)
loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
loss_ds += F.smooth_l1_loss(GAN_F0_fake, GAN_F0_fake2.detach())
# cycle-consistency loss
s_org = nets.style_encoder(x_real, y_org)
x_rec = nets.generator(x_fake, s_org, masks=None, F0=GAN_F0_fake)
s_org = nets['style_encoder'](x_real, y_org)
x_rec = nets['generator'](x_fake, s_org, masks=None, F0=GAN_F0_fake)
loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
# F0 loss in cycle-consistency loss
if args.lambda_f0 > 0:
_, _, cyc_F0_rec = nets.f0_model(x_rec)
if lambda_f0 > 0:
_, _, cyc_F0_rec = nets['F0_model'](x_rec)
loss_cyc += F.smooth_l1_loss(cyc_F0_rec, cyc_F0_real)
if args.lambda_asr > 0:
ASR_recon = nets.asr_model.get_feature(x_rec)
if lambda_asr > 0:
ASR_recon = nets['asr_model'].get_feature(x_rec)
loss_cyc += F.smooth_l1_loss(ASR_recon, ASR_real)
# adversarial classifier loss
if use_adv_cls:
out_de = nets.discriminator.classifier(x_fake)
out_de = nets['discriminator'].classifier(x_fake)
loss_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
y_trg[y_org != y_trg])
else:
loss_adv_cls = paddle.zeros([1]).mean()
loss = args.lambda_adv * loss_adv + args.lambda_sty * loss_sty \
- args.lambda_ds * loss_ds + args.lambda_cyc * loss_cyc\
+ args.lambda_norm * loss_norm \
+ args.lambda_asr * loss_asr \
+ args.lambda_f0 * loss_f0 \
+ args.lambda_f0_sty * loss_f0_sty \
+ args.lambda_adv_cls * loss_adv_cls
return loss, Munch(
adv=loss_adv.item(),
sty=loss_sty.item(),
ds=loss_ds.item(),
cyc=loss_cyc.item(),
norm=loss_norm.item(),
asr=loss_asr.item(),
f0=loss_f0.item(),
adv_cls=loss_adv_cls.item())
loss = lambda_adv * loss_adv + lambda_sty * loss_sty \
- lambda_ds * loss_ds + lambda_cyc * loss_cyc \
+ lambda_norm * loss_norm \
+ lambda_asr * loss_asr \
+ lambda_f0 * loss_f0 \
+ lambda_f0_sty * loss_f0_sty \
+ lambda_adv_cls * loss_adv_cls
return loss
# for norm consistency loss
def log_norm(x, mean=-4, std=4, axis=2):
def log_norm(x: paddle.Tensor, mean: float=-4, std: float=4, axis: int=2):
"""
normalized log mel -> mel -> norm -> log(norm)
"""
@ -209,7 +208,7 @@ def log_norm(x, mean=-4, std=4, axis=2):
# for adversarial loss
def adv_loss(logits, target):
def adv_loss(logits: paddle.Tensor, target: float):
assert target in [1, 0]
if len(logits.shape) > 1:
logits = logits.reshape([-1])
@ -220,7 +219,7 @@ def adv_loss(logits, target):
# for R1 regularization loss
def r1_reg(d_out, x_in):
def r1_reg(d_out: paddle.Tensor, x_in: paddle.Tensor):
# zero-centered gradient penalty for real images
batch_size = x_in.shape[0]
grad_dout = paddle.grad(
@ -236,14 +235,14 @@ def r1_reg(d_out, x_in):
# for F0 consistency loss
def compute_mean_f0(f0):
def compute_mean_f0(f0: paddle.Tensor):
f0_mean = f0.mean(-1)
f0_mean = f0_mean.expand((f0.shape[-1], f0_mean.shape[0])).transpose(
(1, 0)) # (B, M)
return f0_mean
def f0_loss(x_f0, y_f0):
def f0_loss(x_f0: paddle.Tensor, y_f0: paddle.Tensor):
"""
x.shape = (B, 1, M, L): predict
y.shape = (B, 1, M, L): target

@ -11,3 +11,295 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class StarGANv2VCUpdater(StandardUpdater):
def __init__(self,
models: Dict[str, Layer],
optimizers: Dict[str, Optimizer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
g_loss_params: Dict[str, Any]={
'lambda_sty': 1.,
'lambda_cyc': 5.,
'lambda_ds': 1.,
'lambda_norm': 1.,
'lambda_asr': 10.,
'lambda_f0': 5.,
'lambda_f0_sty': 0.1,
'lambda_adv': 2.,
'lambda_adv_cls': 0.5,
'norm_bias': 0.5,
},
d_loss_params: Dict[str, Any]={
'lambda_reg': 1.,
'lambda_adv_cls': 0.1,
'lambda_con_reg': 10.,
},
adv_cls_epoch: int=50,
con_reg_epoch: int=30,
use_r1_reg: bool=False,
output_dir=None):
self.models = models
self.optimizers = optimizers
self.optimizer_g = optimizers['optimizer_g']
self.optimizer_s = optimizers['optimizer_s']
self.optimizer_m = optimizers['optimizer_m']
self.optimizer_d = optimizers['optimizer_d']
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_s = schedulers['style_encoder']
self.scheduler_m = schedulers['mapping_network']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.g_loss_params = g_loss_params
self.d_loss_params = d_loss_params
self.use_r1_reg = use_r1_reg
self.con_reg_epoch = con_reg_epoch
self.adv_cls_epoch = adv_cls_epoch
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def zero_grad(self):
self.optimizer_d.clear_grad()
self.optimizer_g.clear_grad()
self.optimizer_m.clear_grad()
self.optimizer_s.clear_grad()
def scheduler(self):
self.scheduler_d.step()
self.scheduler_g.step()
self.scheduler_m.step()
self.scheduler_s.step()
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
x_real = batch['x_real']
y_org = batch['y_org']
x_ref = batch['x_ref']
x_ref2 = batch['x_ref2']
y_trg = batch['y_trg']
z_trg = batch['z_trg']
z_trg2 = batch['z_trg2']
use_con_reg = (self.state.epoch >= self.con_reg_epoch)
use_adv_cls = (self.state.epoch >= self.adv_cls_epoch)
# Discriminator loss
# train the discriminator (by random reference)
self.zero_grad()
random_d_loss = compute_d_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
z_trg=z_trg,
use_adv_cls=use_adv_cls,
use_con_reg=use_con_reg,
**self.d_loss_params)
random_d_loss.backward()
self.optimizer_d.step()
# train the discriminator (by target reference)
self.zero_grad()
target_d_loss = compute_d_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
x_ref=x_ref,
use_adv_cls=use_adv_cls,
use_con_reg=use_con_reg,
**self.d_loss_params)
target_d_loss.backward()
self.optimizer_d.step()
report("train/random_d_loss", float(random_d_loss))
report("train/target_d_loss", float(target_d_loss))
losses_dict["random_d_loss"] = float(random_d_loss)
losses_dict["target_d_loss"] = float(target_d_loss)
# Generator
# train the generator (by random reference)
self.zero_grad()
random_g_loss = compute_g_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
z_trgs=[z_trg, z_trg2],
use_adv_cls=use_adv_cls,
**self.g_loss_params)
random_g_loss.backward()
self.optimizer_g.step()
self.optimizer_m.step()
self.optimizer_s.step()
# train the generator (by target reference)
self.zero_grad()
target_g_loss = compute_g_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
x_refs=[x_ref, x_ref2],
use_adv_cls=use_adv_cls,
**self.g_loss_params)
target_g_loss.backward()
# 此处是否要 optimizer_g optimizer_m optimizer_s 都写上?
# 源码没写上后两个是否是疏忽?
self.optimizer_g.step()
# self.optimizer_m.step()
# self.optimizer_s.step()
report("train/random_g_loss", float(random_g_loss))
report("train/target_g_loss", float(target_g_loss))
losses_dict["random_g_loss"] = float(random_g_loss)
losses_dict["target_g_loss"] = float(target_g_loss)
self.scheduler()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class StarGANv2VCEvaluator(StandardEvaluator):
def __init__(self,
models: Dict[str, Layer],
dataloader: DataLoader,
g_loss_params: Dict[str, Any]={
'lambda_sty': 1.,
'lambda_cyc': 5.,
'lambda_ds': 1.,
'lambda_norm': 1.,
'lambda_asr': 10.,
'lambda_f0': 5.,
'lambda_f0_sty': 0.1,
'lambda_adv': 2.,
'lambda_adv_cls': 0.5,
'norm_bias': 0.5,
},
d_loss_params: Dict[str, Any]={
'lambda_reg': 1.,
'lambda_adv_cls': 0.1,
'lambda_con_reg': 10.,
},
adv_cls_epoch: int=50,
con_reg_epoch: int=30,
use_r1_reg: bool=False,
output_dir=None):
self.models = models
self.dataloader = dataloader
self.g_loss_params = g_loss_params
self.d_loss_params = d_loss_params
self.use_r1_reg = use_r1_reg
self.con_reg_epoch = con_reg_epoch
self.adv_cls_epoch = adv_cls_epoch
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
x_real = batch['x_real']
y_org = batch['y_org']
x_ref = batch['x_ref']
x_ref2 = batch['x_ref2']
y_trg = batch['y_trg']
z_trg = batch['z_trg']
z_trg2 = batch['z_trg2']
# eval the discriminator
random_d_loss = compute_d_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
z_trg=z_trg,
use_r1_reg=False,
use_adv_cls=use_adv_cls,
**self.d_loss_params)
target_d_loss = compute_d_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
x_ref=x_ref,
use_r1_reg=False,
use_adv_cls=use_adv_cls,
**self.d_loss_params)
report("eval/random_d_loss", float(random_d_loss))
report("eval/target_d_loss", float(target_d_loss))
losses_dict["random_d_loss"] = float(random_d_loss)
losses_dict["target_d_loss"] = float(target_d_loss)
# eval the generator
random_g_loss = compute_g_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
z_trgs=[z_trg, z_trg2],
use_adv_cls=use_adv_cls,
**self.g_loss_params)
target_g_loss = compute_g_loss(
nets=self.models,
x_real=x_real,
y_org=y_org,
y_trg=y_trg,
x_refs=[x_ref, x_ref2],
use_adv_cls=use_adv_cls,
**self.g_loss_params)
report("eval/random_g_loss", float(random_g_loss))
report("eval/target_g_loss", float(target_g_loss))
losses_dict["random_g_loss"] = float(random_g_loss)
losses_dict["target_g_loss"] = float(target_g_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)

@ -0,0 +1,143 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
## 1. RandomTimeStrech
class TimeStrech(nn.Layer):
def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, x: paddle.Tensor):
mel_size = x.shape[-1]
x = F.interpolate(
x,
scale_factor=(1, self.scale),
align_corners=False,
mode='bilinear').squeeze()
if x.shape[-1] < mel_size:
noise_length = (mel_size - x.shape[-1])
random_pos = random.randint(0, x.shape[-1]) - noise_length
if random_pos < 0:
random_pos = 0
noise = x[..., random_pos:random_pos + noise_length]
x = paddle.concat([x, noise], axis=-1)
else:
x = x[..., :mel_size]
return x.unsqueeze(1)
## 2. PitchShift
class PitchShift(nn.Layer):
def __init__(self, shift):
super().__init__()
self.shift = shift
def forward(self, x: paddle.Tensor):
if len(x.shape) == 2:
x = x.unsqueeze(0)
x = x.squeeze()
mel_size = x.shape[1]
shift_scale = (mel_size + self.shift) / mel_size
x = F.interpolate(
x.unsqueeze(1),
scale_factor=(shift_scale, 1.),
align_corners=False,
mode='bilinear').squeeze(1)
x = x[:, :mel_size]
if x.shape[1] < mel_size:
pad_size = mel_size - x.shape[1]
x = paddle.cat(
[x, paddle.zeros(x.shape[0], pad_size, x.shape[2])], axis=1)
x = x.squeeze()
return x.unsqueeze(1)
## 3. ShiftBias
class ShiftBias(nn.Layer):
def __init__(self, bias):
super().__init__()
self.bias = bias
def forward(self, x: paddle.Tensor):
return x + self.bias
## 4. Scaling
class SpectScaling(nn.Layer):
def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, x: paddle.Tensor):
return x * self.scale
## 5. Time Flip
class TimeFlip(nn.Layer):
def __init__(self, length):
super().__init__()
self.length = round(length)
def forward(self, x: paddle.Tensor):
if self.length > 1:
start = np.random.randint(0, x.shape[-1] - self.length)
x_ret = x.clone()
x_ret[..., start:start + self.length] = paddle.flip(
x[..., start:start + self.length], axis=[-1])
x = x_ret
return x
class PhaseShuffle2D(nn.Layer):
def __init__(self, n: int=2):
super().__init__()
self.n = n
self.random = random.Random(1)
def forward(self, x: paddle.Tensor, move=None):
# x.size = (B, C, M, L)
if move is None:
move = self.random.randint(-self.n, self.n)
if move == 0:
return x
else:
left = x[:, :, :, :move]
right = x[:, :, :, move:]
shuffled = paddle.concat([right, left], axis=3)
return shuffled
def build_transforms():
transforms = [
lambda M: TimeStrech(1 + (np.random.random() - 0.5) * M * 0.2),
lambda M: SpectScaling(1 + (np.random.random() - 1) * M * 0.1),
lambda M: PhaseShuffle2D(192),
]
N, M = len(transforms), np.random.random()
composed = nn.Sequential(
* [trans(M) for trans in np.random.choice(transforms, N)])
return composed
Loading…
Cancel
Save