PaddleSpeech/paddlespeech/t2s/models/starganv2_vc/starganv2_vc_updater.py

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any
from typing import Dict
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.models.starganv2_vc.losses import compute_d_loss
from paddlespeech.t2s.models.starganv2_vc.losses import compute_g_loss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class StarGANv2VCUpdater(StandardUpdater):
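    """Training updater for StarGANv2-VC voice conversion.

    Wraps the StarGANv2-VC sub-networks together with one optimizer and one
    LR scheduler per sub-network; the ``optimizers`` and ``schedulers`` dicts
    must be keyed by 'generator', 'style_encoder', 'mapping_network' and
    'discriminator', as read in ``__init__`` below.
    """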
    def __init__(self,
                 models: Dict[str, Layer],
                 optimizers: Dict[str, Optimizer],
                 schedulers: Dict[str, LRScheduler],
                 dataloader: DataLoader,
                 g_loss_params: Dict[str, Any]={
                     'lambda_sty': 1.,
                     'lambda_cyc': 5.,
                     'lambda_ds': 1.,
                     'lambda_norm': 1.,
                     'lambda_asr': 10.,
                     'lambda_f0': 5.,
                     'lambda_f0_sty': 0.1,
                     'lambda_adv': 2.,
                     'lambda_adv_cls': 0.5,
                     'norm_bias': 0.5,
                 },
                 d_loss_params: Dict[str, Any]={
                     'lambda_reg': 1.,
                     'lambda_adv_cls': 0.1,
                     'lambda_con_reg': 10.,
                 },
                 adv_cls_epoch: int=50,
                 con_reg_epoch: int=30,
                 use_r1_reg: bool=False,
                 output_dir=None):
        self.models = models
        self.optimizers = optimizers
        self.optimizer_g = optimizers['generator']
        self.optimizer_s = optimizers['style_encoder']
        self.optimizer_m = optimizers['mapping_network']
        self.optimizer_d = optimizers['discriminator']
        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_s = schedulers['style_encoder']
        self.scheduler_m = schedulers['mapping_network']
        self.scheduler_d = schedulers['discriminator']
        self.dataloader = dataloader
        self.g_loss_params = g_loss_params
        self.d_loss_params = d_loss_params
        self.use_r1_reg = use_r1_reg
        self.con_reg_epoch = con_reg_epoch
        self.adv_cls_epoch = adv_cls_epoch

        self.state = UpdaterState(iteration=0, epoch=0)
        self.train_iterator = iter(self.dataloader)

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def zero_grad(self):
        self.optimizer_d.clear_grad()
        self.optimizer_g.clear_grad()
        self.optimizer_m.clear_grad()
        self.optimizer_s.clear_grad()

    def scheduler(self):
        self.scheduler_d.step()
        self.scheduler_g.step()
        self.scheduler_m.step()
        self.scheduler_s.step()

    def update_core(self, batch):
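        """Run one training step on ``batch``.

        The discriminator is updated twice, once with a random style latent
        (``z_trg``) and once with a reference mel (``x_ref``), and the
        generator side is then updated twice in the same way. The adversarial
        source-classifier loss and the consistency regularization are only
        switched on once ``self.state.epoch`` reaches ``adv_cls_epoch`` and
        ``con_reg_epoch`` respectively.
        """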
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
x_real = batch['x_real']
y_org = batch['y_org']
x_ref = batch['x_ref']
x_ref2 = batch['x_ref2']
y_trg = batch['y_trg']
z_trg = batch['z_trg']
z_trg2 = batch['z_trg2']
use_con_reg = (self.state.epoch >= self.con_reg_epoch)
use_adv_cls = (self.state.epoch >= self.adv_cls_epoch)
        # Discriminator loss
        # train the discriminator (by random reference)
        self.zero_grad()
        random_d_loss = compute_d_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            z_trg=z_trg,
            use_adv_cls=use_adv_cls,
            use_con_reg=use_con_reg,
            **self.d_loss_params)
        random_d_loss.backward()
        self.optimizer_d.step()
        # train the discriminator (by target reference)
        self.zero_grad()
        target_d_loss = compute_d_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            x_ref=x_ref,
            use_adv_cls=use_adv_cls,
            use_con_reg=use_con_reg,
            **self.d_loss_params)
        target_d_loss.backward()
        self.optimizer_d.step()
        report("train/random_d_loss", float(random_d_loss))
        report("train/target_d_loss", float(target_d_loss))
        losses_dict["random_d_loss"] = float(random_d_loss)
        losses_dict["target_d_loss"] = float(target_d_loss)

        # Generator loss
        # train the generator (by random reference)
        self.zero_grad()
        random_g_loss = compute_g_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            z_trgs=[z_trg, z_trg2],
            use_adv_cls=use_adv_cls,
            **self.g_loss_params)
        random_g_loss.backward()
        self.optimizer_g.step()
        self.optimizer_m.step()
        self.optimizer_s.step()
        # train the generator (by target reference)
        self.zero_grad()
        target_g_loss = compute_g_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            x_refs=[x_ref, x_ref2],
            use_adv_cls=use_adv_cls,
            **self.g_loss_params)
        target_g_loss.backward()
        # Should optimizer_g, optimizer_m and optimizer_s all be stepped here?
        # The reference implementation only steps optimizer_g; it is unclear
        # whether leaving out the other two is intentional or an oversight.
        self.optimizer_g.step()
        # self.optimizer_m.step()
        # self.optimizer_s.step()
        report("train/random_g_loss", float(random_g_loss))
        report("train/target_g_loss", float(target_g_loss))
        losses_dict["random_g_loss"] = float(random_g_loss)
        losses_dict["target_g_loss"] = float(target_g_loss)

        self.scheduler()

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())


class StarGANv2VCEvaluator(StandardEvaluator):
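    """Evaluation counterpart of ``StarGANv2VCUpdater``.

    Computes the same discriminator and generator losses on evaluation
    batches, without any backward pass or optimizer step, and reports them
    under the ``eval/`` prefix.
    """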
    def __init__(self,
                 models: Dict[str, Layer],
                 dataloader: DataLoader,
                 g_loss_params: Dict[str, Any]={
                     'lambda_sty': 1.,
                     'lambda_cyc': 5.,
                     'lambda_ds': 1.,
                     'lambda_norm': 1.,
                     'lambda_asr': 10.,
                     'lambda_f0': 5.,
                     'lambda_f0_sty': 0.1,
                     'lambda_adv': 2.,
                     'lambda_adv_cls': 0.5,
                     'norm_bias': 0.5,
                 },
                 d_loss_params: Dict[str, Any]={
                     'lambda_reg': 1.,
                     'lambda_adv_cls': 0.1,
                     'lambda_con_reg': 10.,
                 },
                 adv_cls_epoch: int=50,
                 con_reg_epoch: int=30,
                 use_r1_reg: bool=False,
                 output_dir=None):
        self.models = models
        self.dataloader = dataloader
        self.g_loss_params = g_loss_params
        self.d_loss_params = d_loss_params
        self.use_r1_reg = use_r1_reg
        self.con_reg_epoch = con_reg_epoch
        self.adv_cls_epoch = adv_cls_epoch

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
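        """Compute and report the discriminator and generator losses for one batch."""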
        # logging.debug("Evaluate: ")
        self.msg = "Evaluate: "
        losses_dict = {}
        # parse batch
        x_real = batch['x_real']
        y_org = batch['y_org']
        x_ref = batch['x_ref']
        x_ref2 = batch['x_ref2']
        y_trg = batch['y_trg']
        z_trg = batch['z_trg']
        z_trg2 = batch['z_trg2']
        # Bug fix: `use_adv_cls` is referenced below but was never defined in
        # this method, which would raise a NameError. As an assumption, the
        # adversarial source-classifier term is disabled during evaluation,
        # mirroring how `use_r1_reg` is explicitly forced off below.
        use_adv_cls = False
        # eval the discriminator
        random_d_loss = compute_d_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            z_trg=z_trg,
            use_r1_reg=False,
            use_adv_cls=use_adv_cls,
            **self.d_loss_params)
        target_d_loss = compute_d_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            x_ref=x_ref,
            use_r1_reg=False,
            use_adv_cls=use_adv_cls,
            **self.d_loss_params)
        report("eval/random_d_loss", float(random_d_loss))
        report("eval/target_d_loss", float(target_d_loss))
        losses_dict["random_d_loss"] = float(random_d_loss)
        losses_dict["target_d_loss"] = float(target_d_loss)
        # eval the generator
        random_g_loss = compute_g_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            z_trgs=[z_trg, z_trg2],
            use_adv_cls=use_adv_cls,
            **self.g_loss_params)
        target_g_loss = compute_g_loss(
            nets=self.models,
            x_real=x_real,
            y_org=y_org,
            y_trg=y_trg,
            x_refs=[x_ref, x_ref2],
            use_adv_cls=use_adv_cls,
            **self.g_loss_params)
        report("eval/random_g_loss", float(random_g_loss))
        report("eval/target_g_loss", float(target_g_loss))
        losses_dict["random_g_loss"] = float(random_g_loss)
        losses_dict["target_g_loss"] = float(target_g_loss)

        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
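

# Illustrative usage sketch (assumptions: names such as `opt_g`, `sched_g`
# and `train_dataloader` are placeholders; `output_dir` must be a
# `pathlib.Path`, since `__init__` joins it with `/`; the keys expected in
# `models` are whatever `compute_d_loss` / `compute_g_loss` look up):
#
#   updater = StarGANv2VCUpdater(
#       models=models,
#       optimizers={'generator': opt_g, 'style_encoder': opt_s,
#                   'mapping_network': opt_m, 'discriminator': opt_d},
#       schedulers={'generator': sched_g, 'style_encoder': sched_s,
#                   'mapping_network': sched_m, 'discriminator': sched_d},
#       dataloader=train_dataloader,
#       output_dir=output_dir)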