Fix DiffSinger diffusion loss target: compare against noisy_mel (the diffusion reference mel) instead of the ground-truth speech mel

pull/3005/head
HighCWu 3 years ago
parent 84a22ffb93
commit def9d64f79

@ -236,8 +236,9 @@ class DiffSinger(nn.Layer):
cond_fs2 = cond_fs2.transpose((0, 2, 1))
# get the output(final mel) from diffusion module
mel = self.diffusion(speech.transpose((0, 2, 1)), cond_fs2.detach())
return mel[0], mel_masks
mel, mel_ref = self.diffusion(
speech.transpose((0, 2, 1)), cond_fs2.detach())
return mel, mel_ref, mel_masks
def inference(
self,
@ -271,7 +272,8 @@ class DiffSinger(nn.Layer):
cond_fs2 = cond_fs2.transpose((0, 2, 1))
# mel, _ = self.diffusion(mel_fs2, cond_fs2)
noise = paddle.randn(mel_fs2.shape)
mel = self.diffusion.inference(noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel = self.diffusion.inference(
noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel = mel.transpose((0, 2, 1))
return mel[0]

@ -16,7 +16,6 @@ from pathlib import Path
from typing import Dict
import paddle
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
@ -110,8 +109,8 @@ class DiffSingerUpdater(StandardUpdater):
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss
loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss
self.optimizer_fs2.clear_grad()
loss_fs2.backward()
@ -140,7 +139,7 @@ class DiffSingerUpdater(StandardUpdater):
for param in self.model.fs2.parameters():
param.trainable = False
mel, mel_masks = self.model(
mel, mel_ref, mel_masks = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
@ -156,9 +155,10 @@ class DiffSingerUpdater(StandardUpdater):
train_fs2=False, )
mel = mel.transpose((0, 2, 1))
mel_ref = mel_ref.transpose((0, 2, 1))
mel_masks = mel_masks.transpose((0, 2, 1))
l1_loss_ds = self.criterion_ds(
ref_mels=batch["speech"],
ref_mels=mel_ref,
out_mels=mel,
mel_masks=mel_masks, )
@ -238,6 +238,6 @@ class DiffSingerEvaluator(StandardEvaluator):
losses_dict["l1_loss_ds"] = float(l1_loss_ds)
losses_dict["loss_ds"] = float(loss_ds)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
for k, v in losses_dict.items())
self.logger.info(self.msg)

Loading…
Cancel
Save