|
|
@ -19,9 +19,9 @@ import paddle.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
from .transforms import build_transforms
|
|
|
|
from .transforms import build_transforms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 这些都写到 updater 里
|
|
|
|
# 这些都写到 updater 里
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_d_loss(
|
|
|
|
def compute_d_loss(
|
|
|
|
nets: Dict[str, Any],
|
|
|
|
nets: Dict[str, Any],
|
|
|
|
x_real: paddle.Tensor,
|
|
|
|
x_real: paddle.Tensor,
|
|
|
@ -121,8 +121,8 @@ def compute_g_loss(nets: Dict[str, Any],
|
|
|
|
s_trg = nets['style_encoder'](x_ref, y_trg)
|
|
|
|
s_trg = nets['style_encoder'](x_ref, y_trg)
|
|
|
|
|
|
|
|
|
|
|
|
# compute ASR/F0 features (real)
|
|
|
|
# compute ASR/F0 features (real)
|
|
|
|
with paddle.no_grad():
|
|
|
|
# 源码没有用 .eval(), 使用了 no_grad()
|
|
|
|
print("x_real.shape:", x_real.shape)
|
|
|
|
# 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
|
|
|
|
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
|
|
|
|
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
|
|
|
|
ASR_real = nets['asr_model'].get_feature(x_real)
|
|
|
|
ASR_real = nets['asr_model'].get_feature(x_real)
|
|
|
|
|
|
|
|
|
|
|
|