# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generator module in JETS.
This code is based on https://github.com/imdanboy/jets.
"""
import logging
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.modules.nets_utils import get_segments
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
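
# A minimal construction sketch for the updater below (illustrative only:
# the optimizer, criterion and scheduler objects are placeholders, but the
# dict keys are exactly the ones ``JETSUpdater.__init__`` looks up):
#
#     updater = JETSUpdater(
#         model=model,
#         optimizers={"generator": optim_g, "discriminator": optim_d},
#         criterions={
#             "mel": mel_criterion,
#             "feat_match": feat_match_criterion,
#             "gen_adv": gen_adv_criterion,
#             "dis_adv": dis_adv_criterion,
#             "var": var_criterion,
#             "forwardsum": forwardsum_criterion,
#         },
#         schedulers={"generator": scheduler_g, "discriminator": scheduler_d},
#         dataloader=train_dataloader,
#         output_dir=output_dir)  # a pathlib.Path; it is joined with '/' below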
class JETSUpdater(StandardUpdater):
def __init__(self,
model: Layer,
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
generator_train_start_steps: int=0,
discriminator_train_start_steps: int=100000,
lambda_adv: float=1.0,
lambda_mel: float=45.0,
lambda_feat_match: float=2.0,
lambda_var: float=1.0,
lambda_align: float=2.0,
generator_first: bool=False,
use_alignment_module: bool=False,
output_dir=None):
        # ``self.models`` is designed to hold multiple models; only a single
        # model is passed in here, and because the parent class's __init__()
        # is not called, this bookkeeping has to be redone.
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        # unwrap DataParallel so that submodule attributes stay accessible
        self.model = model._layers if isinstance(model,
                                                 paddle.DataParallel) else model
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
self.optimizer_d: Optimizer = optimizers['discriminator']
self.criterions = criterions
self.criterion_mel = criterions['mel']
self.criterion_feat_match = criterions['feat_match']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.criterion_var = criterions["var"]
self.criterion_forwardsum = criterions["forwardsum"]
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.generator_train_start_steps = generator_train_start_steps
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_feat_match = lambda_feat_match
self.lambda_var = lambda_var
self.lambda_align = lambda_align
self.use_alignment_module = use_alignment_module
if generator_first:
self.turns = ["generator", "discriminator"]
else:
self.turns = ["discriminator", "generator"]
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
for turn in self.turns:
speech = batch["speech"]
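            # add a channel axis: (B, T_wav) -> (B, 1, T_wav), the layout
            # expected by the discriminator and by get_segments below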
speech = speech.unsqueeze(1)
text_lengths = batch["text_lengths"]
feats_lengths = batch["feats_lengths"]
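            # a single shared forward pass; ``forward_generator`` decides
            # whose turn it is, and the layout of the returned tuple differs
            # between the generator and discriminator branches (see the
            # parsing below)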
outs = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
feats=batch["feats"],
feats_lengths=batch["feats_lengths"],
durations=batch["durations"],
durations_lengths=batch["durations_lengths"],
pitch=batch["pitch"],
energy=batch["energy"],
sids=batch.get("spk_id", None),
spembs=batch.get("spk_emb", None),
forward_generator=turn == "generator",
use_alignment_module=self.use_alignment_module)
# Generator
if turn == "generator":
# parse outputs
speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es = outs
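                # start_idxs are random window offsets in frame units chosen
                # inside the model; scaling by upsample_factor (frames ->
                # waveform samples) selects the matching raw-audio segment.
                # E.g. with an assumed upsample_factor of 256 and a
                # segment_size of 32, each segment covers 8192 samples.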
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_)
with paddle.no_grad():
# do not store discriminator gradient in generator turn
p = self.model.discriminator(speech_)
# calculate losses
mel_loss = self.criterion_mel(speech_hat_, speech_)
adv_loss = self.criterion_gen_adv(p_hat)
feat_match_loss = self.criterion_feat_match(p_hat, p)
dur_loss, pitch_loss, energy_loss = self.criterion_var(
d_outs, ds, p_outs, ps, e_outs, es, text_lengths)
mel_loss = mel_loss * self.lambda_mel
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
g_loss = mel_loss + adv_loss + feat_match_loss
                var_loss = (
                    dur_loss + pitch_loss + energy_loss) * self.lambda_var
                gen_loss = g_loss + var_loss
                if self.use_alignment_module:
                    forwardsum_loss = self.criterion_forwardsum(
                        log_p_attn, text_lengths, feats_lengths)
                    align_loss = (
                        forwardsum_loss + bin_loss) * self.lambda_align
                    # the alignment losses must enter the generator objective,
                    # otherwise the alignment module receives no gradient
                    gen_loss = gen_loss + align_loss
                    report("train/generator_alignment_loss", float(align_loss))
                    report("train/generator_alignment_forwardsum_loss",
                           float(forwardsum_loss))
                    report("train/generator_alignment_bin_loss",
                           float(bin_loss))
                    losses_dict["generator_alignment_loss"] = float(align_loss)
                    losses_dict["generator_alignment_forwardsum_loss"] = float(
                        forwardsum_loss)
                    losses_dict["generator_alignment_bin_loss"] = float(
                        bin_loss)
                report("train/generator_loss", float(gen_loss))
                report("train/generator_generator_loss", float(g_loss))
                report("train/generator_variance_loss", float(var_loss))
                report("train/generator_generator_mel_loss", float(mel_loss))
                report("train/generator_generator_adv_loss", float(adv_loss))
                report("train/generator_generator_feat_match_loss",
                       float(feat_match_loss))
                report("train/generator_variance_dur_loss", float(dur_loss))
                report("train/generator_variance_pitch_loss", float(pitch_loss))
                report("train/generator_variance_energy_loss",
                       float(energy_loss))
                losses_dict["generator_loss"] = float(gen_loss)
                losses_dict["generator_generator_loss"] = float(g_loss)
                losses_dict["generator_variance_loss"] = float(var_loss)
                losses_dict["generator_generator_mel_loss"] = float(mel_loss)
                losses_dict["generator_generator_adv_loss"] = float(adv_loss)
                losses_dict["generator_generator_feat_match_loss"] = float(
                    feat_match_loss)
                losses_dict["generator_variance_dur_loss"] = float(dur_loss)
                losses_dict["generator_variance_pitch_loss"] = float(pitch_loss)
                losses_dict["generator_variance_energy_loss"] = float(
                    energy_loss)
self.optimizer_g.clear_grad()
gen_loss.backward()
self.optimizer_g.step()
self.scheduler_g.step()
# reset cache
if self.model.reuse_cache_gen or not self.model.training:
self.model._cache = None
            # Discriminator
elif turn == "discriminator":
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
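                # ``.detach()`` cuts the graph so the discriminator loss does
                # not backpropagate into the generator on this turn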
p_hat = self.model.discriminator(speech_hat_.detach())
p = self.model.discriminator(speech_)
# calculate losses
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
dis_loss = real_loss + fake_loss
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
# reset cache
if self.model.reuse_cache_dis or not self.model.training:
self.model._cache = None
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
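
# For reference, a sketch of the behaviour assumed of ``get_segments``
# (imported above from paddlespeech.t2s.modules.nets_utils); this is an
# illustrative re-statement, not the library implementation:
#
#     def get_segments_sketch(x, start_idxs, segment_size):
#         """Slice x: (B, C, T) into segments: (B, C, segment_size)."""
#         b, c, _ = x.shape
#         segments = paddle.zeros([b, c, segment_size], dtype=x.dtype)
#         for i, start_idx in enumerate(start_idxs):
#             segments[i] = x[i, :, start_idx:start_idx + segment_size]
#         return segments
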
class JETSEvaluator(StandardEvaluator):
def __init__(self,
model,
criterions: Dict[str, Layer],
dataloader: DataLoader,
lambda_adv: float=1.0,
lambda_mel: float=45.0,
lambda_feat_match: float=2.0,
lambda_var: float=1.0,
lambda_align: float=2.0,
generator_first: bool=False,
use_alignment_module: bool=False,
output_dir=None):
        # only a single model is passed in here, and because the parent
        # class's __init__() is not called, this bookkeeping has to be
        # redone as well.
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        # unwrap DataParallel so that submodule attributes stay accessible
        self.model = model._layers if isinstance(model,
                                                 paddle.DataParallel) else model
self.criterions = criterions
self.criterion_mel = criterions['mel']
self.criterion_feat_match = criterions['feat_match']
self.criterion_gen_adv = criterions["gen_adv"]
self.criterion_dis_adv = criterions["dis_adv"]
self.criterion_var = criterions["var"]
self.criterion_forwardsum = criterions["forwardsum"]
self.dataloader = dataloader
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_feat_match = lambda_feat_match
self.lambda_var = lambda_var
self.lambda_align = lambda_align
self.use_alignment_module = use_alignment_module
if generator_first:
self.turns = ["generator", "discriminator"]
else:
self.turns = ["discriminator", "generator"]
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
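        # mirrors ``JETSUpdater.update_core``: the same forward passes and
        # loss computations, but no backward pass or optimizer/scheduler steps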
for turn in self.turns:
speech = batch["speech"]
speech = speech.unsqueeze(1)
text_lengths = batch["text_lengths"]
feats_lengths = batch["feats_lengths"]
outs = self.model(
text=batch["text"],
text_lengths=batch["text_lengths"],
feats=batch["feats"],
feats_lengths=batch["feats_lengths"],
durations=batch["durations"],
durations_lengths=batch["durations_lengths"],
pitch=batch["pitch"],
energy=batch["energy"],
sids=batch.get("spk_id", None),
spembs=batch.get("spk_emb", None),
forward_generator=turn == "generator",
use_alignment_module=self.use_alignment_module)
# Generator
if turn == "generator":
# parse outputs
speech_hat_, bin_loss, log_p_attn, start_idxs, d_outs, ds, p_outs, ps, e_outs, es = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_)
with paddle.no_grad():
# do not store discriminator gradient in generator turn
p = self.model.discriminator(speech_)
# calculate losses
mel_loss = self.criterion_mel(speech_hat_, speech_)
adv_loss = self.criterion_gen_adv(p_hat)
feat_match_loss = self.criterion_feat_match(p_hat, p)
dur_loss, pitch_loss, energy_loss = self.criterion_var(
d_outs, ds, p_outs, ps, e_outs, es, text_lengths)
mel_loss = mel_loss * self.lambda_mel
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
g_loss = mel_loss + adv_loss + feat_match_loss
                var_loss = (
                    dur_loss + pitch_loss + energy_loss) * self.lambda_var
                gen_loss = g_loss + var_loss
                if self.use_alignment_module:
                    forwardsum_loss = self.criterion_forwardsum(
                        log_p_attn, text_lengths, feats_lengths)
                    align_loss = (
                        forwardsum_loss + bin_loss) * self.lambda_align
                    # keep the reported eval loss consistent with training:
                    # the alignment losses are part of the generator objective
                    gen_loss = gen_loss + align_loss
                    report("eval/generator_alignment_loss", float(align_loss))
                    report("eval/generator_alignment_forwardsum_loss",
                           float(forwardsum_loss))
                    report("eval/generator_alignment_bin_loss", float(bin_loss))
                    losses_dict["generator_alignment_loss"] = float(align_loss)
                    losses_dict["generator_alignment_forwardsum_loss"] = float(
                        forwardsum_loss)
                    losses_dict["generator_alignment_bin_loss"] = float(
                        bin_loss)
                report("eval/generator_loss", float(gen_loss))
                report("eval/generator_generator_loss", float(g_loss))
                report("eval/generator_variance_loss", float(var_loss))
                report("eval/generator_generator_mel_loss", float(mel_loss))
                report("eval/generator_generator_adv_loss", float(adv_loss))
                report("eval/generator_generator_feat_match_loss",
                       float(feat_match_loss))
                report("eval/generator_variance_dur_loss", float(dur_loss))
                report("eval/generator_variance_pitch_loss", float(pitch_loss))
                report("eval/generator_variance_energy_loss",
                       float(energy_loss))
                losses_dict["generator_loss"] = float(gen_loss)
                losses_dict["generator_generator_loss"] = float(g_loss)
                losses_dict["generator_variance_loss"] = float(var_loss)
                losses_dict["generator_generator_mel_loss"] = float(mel_loss)
                losses_dict["generator_generator_adv_loss"] = float(adv_loss)
                losses_dict["generator_generator_feat_match_loss"] = float(
                    feat_match_loss)
                losses_dict["generator_variance_dur_loss"] = float(dur_loss)
                losses_dict["generator_variance_pitch_loss"] = float(pitch_loss)
                losses_dict["generator_variance_energy_loss"] = float(
                    energy_loss)
# reset cache
if self.model.reuse_cache_gen or not self.model.training:
self.model._cache = None
            # Discriminator
elif turn == "discriminator":
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs *
self.model.generator.upsample_factor,
segment_size=self.model.generator.segment_size *
self.model.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.model.discriminator(speech_hat_.detach())
p = self.model.discriminator(speech_)
# calculate losses
real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
dis_loss = real_loss + fake_loss
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
# reset cache
if self.model.reuse_cache_dis or not self.model.training:
self.model._cache = None
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)
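
# A sketch of how these classes are typically wired into a training loop in
# PaddleSpeech t2s training scripts (hedged: the Trainer arguments shown are
# the usual pattern and may differ from the JETS example script):
#
#     trainer = Trainer(
#         updater, stop_trigger=(max_steps, "iteration"), out=output_dir)
#     trainer.extend(evaluator, trigger=(eval_interval_steps, "iteration"))
#     trainer.run()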