258 lines
8.7 KiB

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
import paddle
import paddle.nn.functional as F
from .transforms import build_transforms
# 这些都写到 updater 里
def compute_d_loss(
nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
# TODO: should be True here, but r1_reg has some bug now
use_r1_reg: bool=False,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
if use_r1_reg:
loss_reg = r1_reg(out, x_real)
else:
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
loss_reg = paddle.zeros([1])
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
loss_con_reg = paddle.zeros([1])
if use_con_reg:
t = build_transforms()
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
loss_con_reg += F.smooth_l1_loss(out, out_aug)
# with fake audios
with paddle.no_grad():
if z_trg is not None:
s_trg = nets['mapping_network'](z_trg, y_trg)
else: # x_ref is not None
s_trg = nets['style_encoder'](x_ref, y_trg)
F0 = nets['F0_model'].get_feature_GAN(x_real)
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=F0)
out = nets['discriminator'](x_fake, y_trg)
loss_fake = adv_loss(out, 0)
if use_con_reg:
out_aug = nets['discriminator'](t(x_fake).detach(), y_trg)
loss_con_reg += F.smooth_l1_loss(out, out_aug)
# adversarial classifier loss
if use_adv_cls:
out_de = nets['discriminator'].classifier(x_fake)
loss_real_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
y_org[y_org != y_trg])
if use_con_reg:
out_de_aug = nets['discriminator'].classifier(t(x_fake).detach())
loss_con_reg += F.smooth_l1_loss(out_de, out_de_aug)
else:
loss_real_adv_cls = paddle.zeros([1]).mean()
loss = loss_real + loss_fake + lambda_reg * loss_reg + \
lambda_adv_cls * loss_real_adv_cls + \
lambda_con_reg * loss_con_reg
return loss
def compute_g_loss(nets: Dict[str, Any],
x_real: paddle.Tensor,
y_org: paddle.Tensor,
y_trg: paddle.Tensor,
z_trgs: paddle.Tensor=None,
x_refs: paddle.Tensor=None,
use_adv_cls: bool=False,
lambda_sty: float=1.,
lambda_cyc: float=5.,
lambda_ds: float=1.,
lambda_norm: float=1.,
lambda_asr: float=10.,
lambda_f0: float=5.,
lambda_f0_sty: float=0.1,
lambda_adv: float=2.,
lambda_adv_cls: float=0.5,
norm_bias: float=0.5):
assert (z_trgs is None) != (x_refs is None)
if z_trgs is not None:
z_trg, z_trg2 = z_trgs
if x_refs is not None:
x_ref, x_ref2 = x_refs
# compute style vectors
if z_trgs is not None:
s_trg = nets['mapping_network'](z_trg, y_trg)
else:
s_trg = nets['style_encoder'](x_ref, y_trg)
# compute ASR/F0 features (real)
# 源码没有用 .eval(), 使用了 no_grad()
# 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real)
# adversarial loss
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)
out = nets['discriminator'](x_fake, y_trg)
loss_adv = adv_loss(out, 1)
# compute ASR/F0 features (fake)
F0_fake, GAN_F0_fake, _ = nets['F0_model'](x_fake)
ASR_fake = nets['asr_model'].get_feature(x_fake)
# norm consistency loss
x_fake_norm = log_norm(x_fake)
x_real_norm = log_norm(x_real)
tmp = paddle.abs(x_fake_norm - x_real_norm) - norm_bias
loss_norm = ((paddle.nn.ReLU()(tmp))**2).mean()
# F0 loss
loss_f0 = f0_loss(F0_fake, F0_real)
# style F0 loss (style initialization)
if x_refs is not None and lambda_f0_sty > 0 and not use_adv_cls:
F0_sty, _, _ = nets['F0_model'](x_ref)
loss_f0_sty = F.l1_loss(
compute_mean_f0(F0_fake), compute_mean_f0(F0_sty))
else:
loss_f0_sty = paddle.zeros([1]).mean()
# ASR loss
loss_asr = F.smooth_l1_loss(ASR_fake, ASR_real)
# style reconstruction loss
s_pred = nets['style_encoder'](x_fake, y_trg)
loss_sty = paddle.mean(paddle.abs(s_pred - s_trg))
# diversity sensitive loss
if z_trgs is not None:
s_trg2 = nets['mapping_network'](z_trg2, y_trg)
else:
s_trg2 = nets['style_encoder'](x_ref2, y_trg)
x_fake2 = nets['generator'](x_real, s_trg2, masks=None, F0=GAN_F0_real)
x_fake2 = x_fake2.detach()
_, GAN_F0_fake2, _ = nets['F0_model'](x_fake2)
loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
loss_ds += F.smooth_l1_loss(GAN_F0_fake, GAN_F0_fake2.detach())
# cycle-consistency loss
s_org = nets['style_encoder'](x_real, y_org)
x_rec = nets['generator'](x_fake, s_org, masks=None, F0=GAN_F0_fake)
loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
# F0 loss in cycle-consistency loss
if lambda_f0 > 0:
_, _, cyc_F0_rec = nets['F0_model'](x_rec)
loss_cyc += F.smooth_l1_loss(cyc_F0_rec, cyc_F0_real)
if lambda_asr > 0:
ASR_recon = nets['asr_model'].get_feature(x_rec)
loss_cyc += F.smooth_l1_loss(ASR_recon, ASR_real)
# adversarial classifier loss
if use_adv_cls:
out_de = nets['discriminator'].classifier(x_fake)
loss_adv_cls = F.cross_entropy(out_de[y_org != y_trg],
y_trg[y_org != y_trg])
else:
loss_adv_cls = paddle.zeros([1]).mean()
loss = lambda_adv * loss_adv + lambda_sty * loss_sty \
- lambda_ds * loss_ds + lambda_cyc * loss_cyc \
+ lambda_norm * loss_norm \
+ lambda_asr * loss_asr \
+ lambda_f0 * loss_f0 \
+ lambda_f0_sty * loss_f0_sty \
+ lambda_adv_cls * loss_adv_cls
return loss
# for norm consistency loss
def log_norm(x: paddle.Tensor, mean: float=-4, std: float=4, axis: int=2):
"""
normalized log mel -> mel -> norm -> log(norm)
"""
x = paddle.log(paddle.exp(x * std + mean).norm(axis=axis))
return x
# for adversarial loss
def adv_loss(logits: paddle.Tensor, target: float):
assert target in [1, 0]
if len(logits.shape) > 1:
logits = logits.reshape([-1])
targets = paddle.full_like(logits, fill_value=target)
logits = logits.clip(min=-10, max=10) # prevent nan
loss = F.binary_cross_entropy_with_logits(logits, targets)
return loss
# for R1 regularization loss
def r1_reg(d_out: paddle.Tensor, x_in: paddle.Tensor):
# zero-centered gradient penalty for real images
batch_size = x_in.shape[0]
grad_dout = paddle.grad(
outputs=d_out.sum(),
inputs=x_in,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
grad_dout2 = grad_dout.pow(2)
assert (grad_dout2.shape == x_in.shape)
reg = 0.5 * grad_dout2.reshape((batch_size, -1)).sum(1).mean(0)
return reg
# for F0 consistency loss
def compute_mean_f0(f0: paddle.Tensor):
f0_mean = f0.mean(-1)
f0_mean = f0_mean.expand((f0.shape[-1], f0_mean.shape[0])).transpose(
(1, 0)) # (B, M)
return f0_mean
def f0_loss(x_f0: paddle.Tensor, y_f0: paddle.Tensor):
"""
x.shape = (B, 1, M, L): predict
y.shape = (B, 1, M, L): target
"""
# compute the mean
x_mean = compute_mean_f0(x_f0)
y_mean = compute_mean_f0(y_f0)
loss = F.l1_loss(x_f0 / x_mean, y_f0 / y_mean)
return loss