|
|
@ -19,15 +19,18 @@ import paddle.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
from .transforms import build_transforms
|
|
|
|
from .transforms import build_transforms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 这些都写到 updater 里
|
|
|
|
# 这些都写到 updater 里
|
|
|
|
def compute_d_loss(nets: Dict[str, Any],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_d_loss(
|
|
|
|
|
|
|
|
nets: Dict[str, Any],
|
|
|
|
x_real: paddle.Tensor,
|
|
|
|
x_real: paddle.Tensor,
|
|
|
|
y_org: paddle.Tensor,
|
|
|
|
y_org: paddle.Tensor,
|
|
|
|
y_trg: paddle.Tensor,
|
|
|
|
y_trg: paddle.Tensor,
|
|
|
|
z_trg: paddle.Tensor=None,
|
|
|
|
z_trg: paddle.Tensor=None,
|
|
|
|
x_ref: paddle.Tensor=None,
|
|
|
|
x_ref: paddle.Tensor=None,
|
|
|
|
use_r1_reg: bool=True,
|
|
|
|
# TODO: should be True here, but r1_reg has some bug now
|
|
|
|
|
|
|
|
use_r1_reg: bool=False,
|
|
|
|
use_adv_cls: bool=False,
|
|
|
|
use_adv_cls: bool=False,
|
|
|
|
use_con_reg: bool=False,
|
|
|
|
use_con_reg: bool=False,
|
|
|
|
lambda_reg: float=1.,
|
|
|
|
lambda_reg: float=1.,
|
|
|
@ -39,15 +42,15 @@ def compute_d_loss(nets: Dict[str, Any],
|
|
|
|
x_real.stop_gradient = False
|
|
|
|
x_real.stop_gradient = False
|
|
|
|
out = nets['discriminator'](x_real, y_org)
|
|
|
|
out = nets['discriminator'](x_real, y_org)
|
|
|
|
loss_real = adv_loss(out, 1)
|
|
|
|
loss_real = adv_loss(out, 1)
|
|
|
|
|
|
|
|
|
|
|
|
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
|
|
|
|
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
|
|
|
|
if use_r1_reg:
|
|
|
|
if use_r1_reg:
|
|
|
|
loss_reg = r1_reg(out, x_real)
|
|
|
|
loss_reg = r1_reg(out, x_real)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
|
|
|
|
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
|
|
|
|
|
|
|
|
loss_reg = paddle.zeros([1])
|
|
|
|
|
|
|
|
|
|
|
|
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
|
|
|
|
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
|
|
|
|
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
|
|
|
|
loss_con_reg = paddle.zeros([1])
|
|
|
|
if use_con_reg:
|
|
|
|
if use_con_reg:
|
|
|
|
t = build_transforms()
|
|
|
|
t = build_transforms()
|
|
|
|
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
|
|
|
|
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
|
|
|
@ -118,7 +121,8 @@ def compute_g_loss(nets: Dict[str, Any],
|
|
|
|
s_trg = nets['style_encoder'](x_ref, y_trg)
|
|
|
|
s_trg = nets['style_encoder'](x_ref, y_trg)
|
|
|
|
|
|
|
|
|
|
|
|
# compute ASR/F0 features (real)
|
|
|
|
# compute ASR/F0 features (real)
|
|
|
|
with paddle.no_grad():
|
|
|
|
# 源码没有用 .eval(), 使用了 no_grad()
|
|
|
|
|
|
|
|
# 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
|
|
|
|
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
|
|
|
|
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
|
|
|
|
ASR_real = nets['asr_model'].get_feature(x_real)
|
|
|
|
ASR_real = nets['asr_model'].get_feature(x_real)
|
|
|
|
|
|
|
|
|
|
|
|