You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
438 lines
19 KiB
438 lines
19 KiB
2 years ago
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""Generator module in JETS.
|
||
|
|
||
|
This code is based on https://github.com/imdanboy/jets.
|
||
|
|
||
|
"""
|
||
|
import logging
|
||
|
from typing import Dict
|
||
|
|
||
|
import paddle
|
||
|
from paddle import distributed as dist
|
||
|
from paddle.io import DataLoader
|
||
|
from paddle.nn import Layer
|
||
|
from paddle.optimizer import Optimizer
|
||
|
from paddle.optimizer.lr import LRScheduler
|
||
|
|
||
|
from paddlespeech.t2s.modules.nets_utils import get_segments
|
||
|
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
|
||
|
from paddlespeech.t2s.training.reporter import report
|
||
|
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
|
||
|
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
|
||
|
|
||
|
logging.basicConfig(
|
||
|
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||
|
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||
|
logger = logging.getLogger(__name__)
|
||
|
logger.setLevel(logging.INFO)
|
||
|
|
||
|
|
||
|
class JETSUpdater(StandardUpdater):
|
||
|
def __init__(self,
|
||
|
model: Layer,
|
||
|
optimizers: Dict[str, Optimizer],
|
||
|
criterions: Dict[str, Layer],
|
||
|
schedulers: Dict[str, LRScheduler],
|
||
|
dataloader: DataLoader,
|
||
|
generator_train_start_steps: int=0,
|
||
|
discriminator_train_start_steps: int=100000,
|
||
|
lambda_adv: float=1.0,
|
||
|
lambda_mel: float=45.0,
|
||
|
lambda_feat_match: float=2.0,
|
||
|
lambda_var: float=1.0,
|
||
|
lambda_align: float=2.0,
|
||
|
generator_first: bool=False,
|
||
|
use_alignment_module: bool=False,
|
||
|
output_dir=None):
|
||
|
# it is designed to hold multiple models
|
||
|
# 因为输入的是单模型,但是没有用到父类的 init(), 所以需要重新写这部分
|
||
|
models = {"main": model}
|
||
|
self.models: Dict[str, Layer] = models
|
||
|
# self.model = model
|
||
|
|
||
|
self.model = model._layers if isinstance(model,
|
||
|
paddle.DataParallel) else model
|
||
|
|
||
|
self.optimizers = optimizers
|
||
|
self.optimizer_g: Optimizer = optimizers['generator']
|
||
|
self.optimizer_d: Optimizer = optimizers['discriminator']
|
||
|
|
||
|
self.criterions = criterions
|
||
|
self.criterion_mel = criterions['mel']
|
||
|
self.criterion_feat_match = criterions['feat_match']
|
||
|
self.criterion_gen_adv = criterions["gen_adv"]
|
||
|
self.criterion_dis_adv = criterions["dis_adv"]
|
||
|
self.criterion_var = criterions["var"]
|
||
|
self.criterion_forwardsum = criterions["forwardsum"]
|
||
|
|
||
|
self.schedulers = schedulers
|
||
|
self.scheduler_g = schedulers['generator']
|
||
|
self.scheduler_d = schedulers['discriminator']
|
||
|
|
||
|
self.dataloader = dataloader
|
||
|
|
||
|
self.generator_train_start_steps = generator_train_start_steps
|
||
|
self.discriminator_train_start_steps = discriminator_train_start_steps
|
||
|
|
||
|
self.lambda_adv = lambda_adv
|
||
|
self.lambda_mel = lambda_mel
|
||
|
self.lambda_feat_match = lambda_feat_match
|
||
|
self.lambda_var = lambda_var
|
||
|
self.lambda_align = lambda_align
|
||
|
|
||
|
self.use_alignment_module = use_alignment_module
|
||
|
|
||
|
if generator_first:
|
||
|
self.turns = ["generator", "discriminator"]
|
||
|
else:
|
||
|
self.turns = ["discriminator", "generator"]
|
||
|
|
||
|
self.state = UpdaterState(iteration=0, epoch=0)
|
||
|
self.train_iterator = iter(self.dataloader)
|
||
|
|
||
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||
|
self.filehandler = logging.FileHandler(str(log_file))
|
||
|
logger.addHandler(self.filehandler)
|
||
|
self.logger = logger
|
||
|
self.msg = ""
|
||
|
|
||
|
def update_core(self, batch):
|
||
|
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||
|
losses_dict = {}
|
||
|
|
||
|
for turn in self.turns:
|
||
|
speech = batch["speech"]
|
||
|
speech = speech.unsqueeze(1)
|
||
|
text_lengths = batch["text_lengths"]
|
||
|
feats_lengths = batch["feats_lengths"]
|
||
|
outs = self.model(
|
||
|
text=batch["text"],
|
||
|
text_lengths=batch["text_lengths"],
|
||
|
feats=batch["feats"],
|
||
|
feats_lengths=batch["feats_lengths"],
|
||
|
durations=batch["durations"],
|
||
|
durations_lengths=batch["durations_lengths"],
|
||
|
pitch=batch["pitch"],
|
||
|
energy=batch["energy"],
|
||
|
sids=batch.get("spk_id", None),
|
||
|
spembs=batch.get("spk_emb", None),
|
||
|
forward_generator=turn == "generator",
|
||
|
use_alignment_module=self.use_alignment_module)
|
||
|
# Generator
|
||
|
if turn == "generator":
|
||
|
# parse outputs
|
||
|
speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es = outs
|
||
|
speech_ = get_segments(
|
||
|
x=speech,
|
||
|
start_idxs=start_idxs *
|
||
|
self.model.generator.upsample_factor,
|
||
|
segment_size=self.model.generator.segment_size *
|
||
|
self.model.generator.upsample_factor, )
|
||
|
|
||
|
# calculate discriminator outputs
|
||
|
p_hat = self.model.discriminator(speech_hat_)
|
||
|
with paddle.no_grad():
|
||
|
# do not store discriminator gradient in generator turn
|
||
|
p = self.model.discriminator(speech_)
|
||
|
|
||
|
# calculate losses
|
||
|
mel_loss = self.criterion_mel(speech_hat_, speech_)
|
||
|
|
||
|
adv_loss = self.criterion_gen_adv(p_hat)
|
||
|
feat_match_loss = self.criterion_feat_match(p_hat, p)
|
||
|
dur_loss, pitch_loss, energy_loss = self.criterion_var(
|
||
|
d_outs, ds, p_outs, ps, e_outs, es, text_lengths)
|
||
|
|
||
|
mel_loss = mel_loss * self.lambda_mel
|
||
|
adv_loss = adv_loss * self.lambda_adv
|
||
|
feat_match_loss = feat_match_loss * self.lambda_feat_match
|
||
|
g_loss = mel_loss + adv_loss + feat_match_loss
|
||
|
var_loss = (
|
||
|
dur_loss + pitch_loss + energy_loss) * self.lambda_var
|
||
|
|
||
|
gen_loss = g_loss + var_loss #+ align_loss
|
||
|
|
||
|
report("train/generator_loss", float(gen_loss))
|
||
|
report("train/generator_generator_loss", float(g_loss))
|
||
|
report("train/generator_variance_loss", float(var_loss))
|
||
|
report("train/generator_generator_mel_loss", float(mel_loss))
|
||
|
report("train/generator_generator_adv_loss", float(adv_loss))
|
||
|
report("train/generator_generator_feat_match_loss",
|
||
|
float(feat_match_loss))
|
||
|
report("train/generator_variance_dur_loss", float(dur_loss))
|
||
|
report("train/generator_variance_pitch_loss", float(pitch_loss))
|
||
|
report("train/generator_variance_energy_loss",
|
||
|
float(energy_loss))
|
||
|
|
||
|
losses_dict["generator_loss"] = float(gen_loss)
|
||
|
losses_dict["generator_generator_loss"] = float(g_loss)
|
||
|
losses_dict["generator_variance_loss"] = float(var_loss)
|
||
|
losses_dict["generator_generator_mel_loss"] = float(mel_loss)
|
||
|
losses_dict["generator_generator_adv_loss"] = float(adv_loss)
|
||
|
losses_dict["generator_generator_feat_match_loss"] = float(
|
||
|
feat_match_loss)
|
||
|
losses_dict["generator_variance_dur_loss"] = float(dur_loss)
|
||
|
losses_dict["generator_variance_pitch_loss"] = float(pitch_loss)
|
||
|
losses_dict["generator_variance_energy_loss"] = float(
|
||
|
energy_loss)
|
||
|
|
||
|
if self.use_alignment_module == True:
|
||
|
forwardsum_loss = self.criterion_forwardsum(
|
||
|
log_p_attn, text_lengths, feats_lengths)
|
||
|
align_loss = (
|
||
|
forwardsum_loss + bin_loss) * self.lambda_align
|
||
|
report("train/generator_alignment_loss", float(align_loss))
|
||
|
report("train/generator_alignment_forwardsum_loss",
|
||
|
float(forwardsum_loss))
|
||
|
report("train/generator_alignment_bin_loss",
|
||
|
float(bin_loss))
|
||
|
losses_dict["generator_alignment_loss"] = float(align_loss)
|
||
|
losses_dict["generator_alignment_forwardsum_loss"] = float(
|
||
|
forwardsum_loss)
|
||
|
losses_dict["generator_alignment_bin_loss"] = float(
|
||
|
bin_loss)
|
||
|
|
||
|
self.optimizer_g.clear_grad()
|
||
|
gen_loss.backward()
|
||
|
|
||
|
self.optimizer_g.step()
|
||
|
self.scheduler_g.step()
|
||
|
|
||
|
# reset cache
|
||
|
if self.model.reuse_cache_gen or not self.model.training:
|
||
|
self.model._cache = None
|
||
|
|
||
|
# Disctiminator
|
||
|
elif turn == "discriminator":
|
||
|
# parse outputs
|
||
|
speech_hat_, _, _, start_idxs, *_ = outs
|
||
|
speech_ = get_segments(
|
||
|
x=speech,
|
||
|
start_idxs=start_idxs *
|
||
|
self.model.generator.upsample_factor,
|
||
|
segment_size=self.model.generator.segment_size *
|
||
|
self.model.generator.upsample_factor, )
|
||
|
|
||
|
# calculate discriminator outputs
|
||
|
p_hat = self.model.discriminator(speech_hat_.detach())
|
||
|
p = self.model.discriminator(speech_)
|
||
|
|
||
|
# calculate losses
|
||
|
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
|
||
|
dis_loss = real_loss + fake_loss
|
||
|
|
||
|
report("train/real_loss", float(real_loss))
|
||
|
report("train/fake_loss", float(fake_loss))
|
||
|
report("train/discriminator_loss", float(dis_loss))
|
||
|
losses_dict["real_loss"] = float(real_loss)
|
||
|
losses_dict["fake_loss"] = float(fake_loss)
|
||
|
losses_dict["discriminator_loss"] = float(dis_loss)
|
||
|
|
||
|
self.optimizer_d.clear_grad()
|
||
|
dis_loss.backward()
|
||
|
|
||
|
self.optimizer_d.step()
|
||
|
self.scheduler_d.step()
|
||
|
|
||
|
# reset cache
|
||
|
if self.model.reuse_cache_dis or not self.model.training:
|
||
|
self.model._cache = None
|
||
|
|
||
|
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||
|
for k, v in losses_dict.items())
|
||
|
|
||
|
|
||
|
class JETSEvaluator(StandardEvaluator):
|
||
|
def __init__(self,
|
||
|
model,
|
||
|
criterions: Dict[str, Layer],
|
||
|
dataloader: DataLoader,
|
||
|
lambda_adv: float=1.0,
|
||
|
lambda_mel: float=45.0,
|
||
|
lambda_feat_match: float=2.0,
|
||
|
lambda_var: float=1.0,
|
||
|
lambda_align: float=2.0,
|
||
|
generator_first: bool=False,
|
||
|
use_alignment_module: bool=False,
|
||
|
output_dir=None):
|
||
|
# 因为输入的是单模型,但是没有用到父类的 init(), 所以需要重新写这部分
|
||
|
models = {"main": model}
|
||
|
self.models: Dict[str, Layer] = models
|
||
|
# self.model = model
|
||
|
self.model = model._layers if isinstance(model,
|
||
|
paddle.DataParallel) else model
|
||
|
|
||
|
self.criterions = criterions
|
||
|
self.criterion_mel = criterions['mel']
|
||
|
self.criterion_feat_match = criterions['feat_match']
|
||
|
self.criterion_gen_adv = criterions["gen_adv"]
|
||
|
self.criterion_dis_adv = criterions["dis_adv"]
|
||
|
self.criterion_var = criterions["var"]
|
||
|
self.criterion_forwardsum = criterions["forwardsum"]
|
||
|
|
||
|
self.dataloader = dataloader
|
||
|
|
||
|
self.lambda_adv = lambda_adv
|
||
|
self.lambda_mel = lambda_mel
|
||
|
self.lambda_feat_match = lambda_feat_match
|
||
|
self.lambda_var = lambda_var
|
||
|
self.lambda_align = lambda_align
|
||
|
self.use_alignment_module = use_alignment_module
|
||
|
|
||
|
if generator_first:
|
||
|
self.turns = ["generator", "discriminator"]
|
||
|
else:
|
||
|
self.turns = ["discriminator", "generator"]
|
||
|
|
||
|
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||
|
self.filehandler = logging.FileHandler(str(log_file))
|
||
|
logger.addHandler(self.filehandler)
|
||
|
self.logger = logger
|
||
|
self.msg = ""
|
||
|
|
||
|
def evaluate_core(self, batch):
|
||
|
# logging.debug("Evaluate: ")
|
||
|
self.msg = "Evaluate: "
|
||
|
losses_dict = {}
|
||
|
|
||
|
for turn in self.turns:
|
||
|
speech = batch["speech"]
|
||
|
speech = speech.unsqueeze(1)
|
||
|
text_lengths = batch["text_lengths"]
|
||
|
feats_lengths = batch["feats_lengths"]
|
||
|
outs = self.model(
|
||
|
text=batch["text"],
|
||
|
text_lengths=batch["text_lengths"],
|
||
|
feats=batch["feats"],
|
||
|
feats_lengths=batch["feats_lengths"],
|
||
|
durations=batch["durations"],
|
||
|
durations_lengths=batch["durations_lengths"],
|
||
|
pitch=batch["pitch"],
|
||
|
energy=batch["energy"],
|
||
|
sids=batch.get("spk_id", None),
|
||
|
spembs=batch.get("spk_emb", None),
|
||
|
forward_generator=turn == "generator",
|
||
|
use_alignment_module=self.use_alignment_module)
|
||
|
# Generator
|
||
|
if turn == "generator":
|
||
|
# parse outputs
|
||
|
speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es = outs
|
||
|
speech_ = get_segments(
|
||
|
x=speech,
|
||
|
start_idxs=start_idxs *
|
||
|
self.model.generator.upsample_factor,
|
||
|
segment_size=self.model.generator.segment_size *
|
||
|
self.model.generator.upsample_factor, )
|
||
|
|
||
|
# calculate discriminator outputs
|
||
|
p_hat = self.model.discriminator(speech_hat_)
|
||
|
with paddle.no_grad():
|
||
|
# do not store discriminator gradient in generator turn
|
||
|
p = self.model.discriminator(speech_)
|
||
|
|
||
|
# calculate losses
|
||
|
mel_loss = self.criterion_mel(speech_hat_, speech_)
|
||
|
|
||
|
adv_loss = self.criterion_gen_adv(p_hat)
|
||
|
feat_match_loss = self.criterion_feat_match(p_hat, p)
|
||
|
dur_loss, pitch_loss, energy_loss = self.criterion_var(
|
||
|
d_outs, ds, p_outs, ps, e_outs, es, text_lengths)
|
||
|
|
||
|
mel_loss = mel_loss * self.lambda_mel
|
||
|
adv_loss = adv_loss * self.lambda_adv
|
||
|
feat_match_loss = feat_match_loss * self.lambda_feat_match
|
||
|
g_loss = mel_loss + adv_loss + feat_match_loss
|
||
|
var_loss = (
|
||
|
dur_loss + pitch_loss + energy_loss) * self.lambda_var
|
||
|
|
||
|
gen_loss = g_loss + var_loss #+ align_loss
|
||
|
|
||
|
report("eval/generator_loss", float(gen_loss))
|
||
|
report("eval/generator_generator_loss", float(g_loss))
|
||
|
report("eval/generator_variance_loss", float(var_loss))
|
||
|
report("eval/generator_generator_mel_loss", float(mel_loss))
|
||
|
report("eval/generator_generator_adv_loss", float(adv_loss))
|
||
|
report("eval/generator_generator_feat_match_loss",
|
||
|
float(feat_match_loss))
|
||
|
report("eval/generator_variance_dur_loss", float(dur_loss))
|
||
|
report("eval/generator_variance_pitch_loss", float(pitch_loss))
|
||
|
report("eval/generator_variance_energy_loss",
|
||
|
float(energy_loss))
|
||
|
|
||
|
losses_dict["generator_loss"] = float(gen_loss)
|
||
|
losses_dict["generator_generator_loss"] = float(g_loss)
|
||
|
losses_dict["generator_variance_loss"] = float(var_loss)
|
||
|
losses_dict["generator_generator_mel_loss"] = float(mel_loss)
|
||
|
losses_dict["generator_generator_adv_loss"] = float(adv_loss)
|
||
|
losses_dict["generator_generator_feat_match_loss"] = float(
|
||
|
feat_match_loss)
|
||
|
losses_dict["generator_variance_dur_loss"] = float(dur_loss)
|
||
|
losses_dict["generator_variance_pitch_loss"] = float(pitch_loss)
|
||
|
losses_dict["generator_variance_energy_loss"] = float(
|
||
|
energy_loss)
|
||
|
|
||
|
if self.use_alignment_module == True:
|
||
|
forwardsum_loss = self.criterion_forwardsum(
|
||
|
log_p_attn, text_lengths, feats_lengths)
|
||
|
align_loss = (
|
||
|
forwardsum_loss + bin_loss) * self.lambda_align
|
||
|
report("eval/generator_alignment_loss", float(align_loss))
|
||
|
report("eval/generator_alignment_forwardsum_loss",
|
||
|
float(forwardsum_loss))
|
||
|
report("eval/generator_alignment_bin_loss", float(bin_loss))
|
||
|
losses_dict["generator_alignment_loss"] = float(align_loss)
|
||
|
losses_dict["generator_alignment_forwardsum_loss"] = float(
|
||
|
forwardsum_loss)
|
||
|
losses_dict["generator_alignment_bin_loss"] = float(
|
||
|
bin_loss)
|
||
|
|
||
|
# reset cache
|
||
|
if self.model.reuse_cache_gen or not self.model.training:
|
||
|
self.model._cache = None
|
||
|
|
||
|
# Disctiminator
|
||
|
elif turn == "discriminator":
|
||
|
# parse outputs
|
||
|
speech_hat_, _, _, start_idxs, *_ = outs
|
||
|
speech_ = get_segments(
|
||
|
x=speech,
|
||
|
start_idxs=start_idxs *
|
||
|
self.model.generator.upsample_factor,
|
||
|
segment_size=self.model.generator.segment_size *
|
||
|
self.model.generator.upsample_factor, )
|
||
|
|
||
|
# calculate discriminator outputs
|
||
|
p_hat = self.model.discriminator(speech_hat_.detach())
|
||
|
p = self.model.discriminator(speech_)
|
||
|
|
||
|
# calculate losses
|
||
|
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
|
||
|
dis_loss = real_loss + fake_loss
|
||
|
|
||
|
report("eval/real_loss", float(real_loss))
|
||
|
report("eval/fake_loss", float(fake_loss))
|
||
|
report("eval/discriminator_loss", float(dis_loss))
|
||
|
losses_dict["real_loss"] = float(real_loss)
|
||
|
losses_dict["fake_loss"] = float(fake_loss)
|
||
|
losses_dict["discriminator_loss"] = float(dis_loss)
|
||
|
|
||
|
# reset cache
|
||
|
if self.model.reuse_cache_dis or not self.model.training:
|
||
|
self.model._cache = None
|
||
|
|
||
|
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||
|
for k, v in losses_dict.items())
|
||
|
self.logger.info(self.msg)
|