add typehint

pull/3182/head
TianYuan 2 years ago
parent ec13243ff4
commit b523701867

@ -27,9 +27,9 @@ def compute_d_loss(nets: Dict[str, Any],
y_trg: paddle.Tensor,
z_trg: paddle.Tensor=None,
x_ref: paddle.Tensor=None,
use_r1_reg=True,
use_adv_cls=False,
use_con_reg=False,
use_r1_reg: bool=True,
use_adv_cls: bool=False,
use_con_reg: bool=False,
lambda_reg: float=1.,
lambda_adv_cls: float=0.1,
lambda_con_reg: float=10.):
@ -37,7 +37,6 @@ def compute_d_loss(nets: Dict[str, Any],
assert (z_trg is None) != (x_ref is None)
# with real audios
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)

Loading…
Cancel
Save