fix loss bug

pull/3184/head
TianYuan 2 years ago
parent c54822936c
commit cb8c994058

@ -844,9 +844,7 @@ class StarGANv2VCCollateFn:
mel = [self.random_clip(item["mel"]) for item in examples] mel = [self.random_clip(item["mel"]) for item in examples]
ref_mel = [self.random_clip(item["ref_mel"]) for item in examples] ref_mel = [self.random_clip(item["ref_mel"]) for item in examples]
ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples] ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples]
print("mel[0].shape after batch_sequences:", mel[0].shape)
mel = batch_sequences(mel) mel = batch_sequences(mel)
print("mel.shape after batch_sequences:", mel.shape)
ref_mel = batch_sequences(ref_mel) ref_mel = batch_sequences(ref_mel)
ref_mel_2 = batch_sequences(ref_mel_2) ref_mel_2 = batch_sequences(ref_mel_2)

@ -19,9 +19,9 @@ import paddle.nn.functional as F
from .transforms import build_transforms from .transforms import build_transforms
# 这些都写到 updater 里 # 这些都写到 updater 里
def compute_d_loss( def compute_d_loss(
nets: Dict[str, Any], nets: Dict[str, Any],
x_real: paddle.Tensor, x_real: paddle.Tensor,
@ -121,10 +121,10 @@ def compute_g_loss(nets: Dict[str, Any],
s_trg = nets['style_encoder'](x_ref, y_trg) s_trg = nets['style_encoder'](x_ref, y_trg)
# compute ASR/F0 features (real) # compute ASR/F0 features (real)
with paddle.no_grad(): # 源码没有用 .eval(), 使用了 no_grad()
print("x_real.shape:", x_real.shape) # 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错
F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real)
ASR_real = nets['asr_model'].get_feature(x_real) ASR_real = nets['asr_model'].get_feature(x_real)
# adversarial loss # adversarial loss
x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real) x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)

Loading…
Cancel
Save