Fix DiffSinger loss target: use the noisy mel (mel_ref) returned by the diffusion module instead of batch["speech"]

pull/3005/head
HighCWu 3 years ago
parent 84a22ffb93
commit def9d64f79

@@ -236,8 +236,9 @@ class DiffSinger(nn.Layer):
cond_fs2 = cond_fs2.transpose((0, 2, 1)) cond_fs2 = cond_fs2.transpose((0, 2, 1))
# get the output(final mel) from diffusion module # get the output(final mel) from diffusion module
mel = self.diffusion(speech.transpose((0, 2, 1)), cond_fs2.detach()) mel, mel_ref = self.diffusion(
return mel[0], mel_masks speech.transpose((0, 2, 1)), cond_fs2.detach())
return mel, mel_ref, mel_masks
def inference( def inference(
self, self,
@@ -271,7 +272,8 @@ class DiffSinger(nn.Layer):
cond_fs2 = cond_fs2.transpose((0, 2, 1)) cond_fs2 = cond_fs2.transpose((0, 2, 1))
# mel, _ = self.diffusion(mel_fs2, cond_fs2) # mel, _ = self.diffusion(mel_fs2, cond_fs2)
noise = paddle.randn(mel_fs2.shape) noise = paddle.randn(mel_fs2.shape)
mel = self.diffusion.inference(noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100) mel = self.diffusion.inference(
noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel = mel.transpose((0, 2, 1)) mel = mel.transpose((0, 2, 1))
return mel[0] return mel[0]

@@ -16,7 +16,6 @@ from pathlib import Path
from typing import Dict from typing import Dict
import paddle import paddle
from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
@@ -140,7 +139,7 @@ class DiffSingerUpdater(StandardUpdater):
for param in self.model.fs2.parameters(): for param in self.model.fs2.parameters():
param.trainable = False param.trainable = False
mel, mel_masks = self.model( mel, mel_ref, mel_masks = self.model(
text=batch["text"], text=batch["text"],
note=batch["note"], note=batch["note"],
note_dur=batch["note_dur"], note_dur=batch["note_dur"],
@@ -156,9 +155,10 @@ class DiffSingerUpdater(StandardUpdater):
train_fs2=False, ) train_fs2=False, )
mel = mel.transpose((0, 2, 1)) mel = mel.transpose((0, 2, 1))
mel_ref = mel_ref.transpose((0, 2, 1))
mel_masks = mel_masks.transpose((0, 2, 1)) mel_masks = mel_masks.transpose((0, 2, 1))
l1_loss_ds = self.criterion_ds( l1_loss_ds = self.criterion_ds(
ref_mels=batch["speech"], ref_mels=mel_ref,
out_mels=mel, out_mels=mel,
mel_masks=mel_masks, ) mel_masks=mel_masks, )

Loading…
Cancel
Save