From cb8c99405803bdcc2a7f91a27f3aeb567e313378 Mon Sep 17 00:00:00 2001 From: TianYuan Date: Sun, 23 Apr 2023 07:35:07 +0000 Subject: [PATCH] fix loss bug --- paddlespeech/t2s/datasets/am_batch_fn.py | 2 -- paddlespeech/t2s/models/starganv2_vc/losses.py | 10 +++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 85959aa25..fe5d977a5 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -844,9 +844,7 @@ class StarGANv2VCCollateFn: mel = [self.random_clip(item["mel"]) for item in examples] ref_mel = [self.random_clip(item["ref_mel"]) for item in examples] ref_mel_2 = [self.random_clip(item["ref_mel_2"]) for item in examples] - print("mel[0].shape after batch_sequences:", mel[0].shape) mel = batch_sequences(mel) - print("mel.shape after batch_sequences:", mel.shape) ref_mel = batch_sequences(ref_mel) ref_mel_2 = batch_sequences(ref_mel_2) diff --git a/paddlespeech/t2s/models/starganv2_vc/losses.py b/paddlespeech/t2s/models/starganv2_vc/losses.py index 145344676..d94c9342a 100644 --- a/paddlespeech/t2s/models/starganv2_vc/losses.py +++ b/paddlespeech/t2s/models/starganv2_vc/losses.py @@ -19,9 +19,9 @@ import paddle.nn.functional as F from .transforms import build_transforms - # 这些都写到 updater 里 + def compute_d_loss( nets: Dict[str, Any], x_real: paddle.Tensor, @@ -121,10 +121,10 @@ def compute_g_loss(nets: Dict[str, Any], s_trg = nets['style_encoder'](x_ref, y_trg) # compute ASR/F0 features (real) - with paddle.no_grad(): - print("x_real.shape:", x_real.shape) - F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) - ASR_real = nets['asr_model'].get_feature(x_real) + # 源码没有用 .eval(), 使用了 no_grad() + # 我们使用了 .eval(), 开启 with paddle.no_grad() 会报错 + F0_real, GAN_F0_real, cyc_F0_real = nets['F0_model'](x_real) + ASR_real = nets['asr_model'].get_feature(x_real) # adversarial loss x_fake = nets['generator'](x_real, s_trg, masks=None, F0=GAN_F0_real)