|
|
|
@ -21,33 +21,35 @@ from .transforms import build_transforms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 这些都写到 updater 里
|
|
|
|
|
def compute_d_loss(nets: Dict[str, Any],
|
|
|
|
|
x_real: paddle.Tensor,
|
|
|
|
|
y_org: paddle.Tensor,
|
|
|
|
|
y_trg: paddle.Tensor,
|
|
|
|
|
z_trg: paddle.Tensor=None,
|
|
|
|
|
x_ref: paddle.Tensor=None,
|
|
|
|
|
use_r1_reg: bool=True,
|
|
|
|
|
use_adv_cls: bool=False,
|
|
|
|
|
use_con_reg: bool=False,
|
|
|
|
|
lambda_reg: float=1.,
|
|
|
|
|
lambda_adv_cls: float=0.1,
|
|
|
|
|
lambda_con_reg: float=10.):
|
|
|
|
|
def compute_d_loss(
|
|
|
|
|
nets: Dict[str, Any],
|
|
|
|
|
x_real: paddle.Tensor,
|
|
|
|
|
y_org: paddle.Tensor,
|
|
|
|
|
y_trg: paddle.Tensor,
|
|
|
|
|
z_trg: paddle.Tensor=None,
|
|
|
|
|
x_ref: paddle.Tensor=None,
|
|
|
|
|
# TODO: should be True here, but r1_reg has some bug now
|
|
|
|
|
use_r1_reg: bool=False,
|
|
|
|
|
use_adv_cls: bool=False,
|
|
|
|
|
use_con_reg: bool=False,
|
|
|
|
|
lambda_reg: float=1.,
|
|
|
|
|
lambda_adv_cls: float=0.1,
|
|
|
|
|
lambda_con_reg: float=10.):
|
|
|
|
|
|
|
|
|
|
assert (z_trg is None) != (x_ref is None)
|
|
|
|
|
# with real audios
|
|
|
|
|
x_real.stop_gradient = False
|
|
|
|
|
out = nets['discriminator'](x_real, y_org)
|
|
|
|
|
loss_real = adv_loss(out, 1)
|
|
|
|
|
|
|
|
|
|
# R1 regularizaition (https://arxiv.org/abs/1801.04406v4)
|
|
|
|
|
if use_r1_reg:
|
|
|
|
|
loss_reg = r1_reg(out, x_real)
|
|
|
|
|
else:
|
|
|
|
|
loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
|
|
|
|
|
# loss_reg = paddle.to_tensor([0.], dtype=paddle.float32)
|
|
|
|
|
loss_reg = paddle.zeros([1])
|
|
|
|
|
|
|
|
|
|
# consistency regularization (bCR-GAN: https://arxiv.org/abs/2002.04724)
|
|
|
|
|
loss_con_reg = paddle.to_tensor([0.], dtype=paddle.float32)
|
|
|
|
|
loss_con_reg = paddle.zeros([1])
|
|
|
|
|
if use_con_reg:
|
|
|
|
|
t = build_transforms()
|
|
|
|
|
out_aug = nets['discriminator'](t(x_real).detach(), y_org)
|
|
|
|
@ -119,6 +121,7 @@ def compute_g_loss(nets: Dict[str, Any],
|
|
|
|
|
|
|
|
|
|
# compute ASR/F0 features (real)
|
|
|
|
|
with paddle.no_grad():
|
|
|
|
|
print("x_real.shape:", x_real.shape)
|
|
|
|
|
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
|
|
|
|
|
ASR_real = nets['asr_model'].get_feature(x_real)
|
|
|
|
|
|
|
|
|
|