Merge pull request #3 from HighCWu/diffsinger_tmp

fix diffsinger loss target to noisy_mel
pull/3005/head
liangym 3 years ago committed by GitHub
commit 9e8bd9f4e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -236,8 +236,9 @@ class DiffSinger(nn.Layer):
cond_fs2 = cond_fs2.transpose((0, 2, 1)) cond_fs2 = cond_fs2.transpose((0, 2, 1))
# get the output(final mel) from diffusion module # get the output(final mel) from diffusion module
mel = self.diffusion(speech.transpose((0, 2, 1)), cond_fs2.detach()) mel, mel_ref = self.diffusion(
return mel[0], mel_masks speech.transpose((0, 2, 1)), cond_fs2.detach())
return mel, mel_ref, mel_masks
def inference( def inference(
self, self,
@ -271,7 +272,8 @@ class DiffSinger(nn.Layer):
cond_fs2 = cond_fs2.transpose((0, 2, 1)) cond_fs2 = cond_fs2.transpose((0, 2, 1))
# mel, _ = self.diffusion(mel_fs2, cond_fs2) # mel, _ = self.diffusion(mel_fs2, cond_fs2)
noise = paddle.randn(mel_fs2.shape) noise = paddle.randn(mel_fs2.shape)
mel = self.diffusion.inference(noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100) mel = self.diffusion.inference(
noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel = mel.transpose((0, 2, 1)) mel = mel.transpose((0, 2, 1))
return mel[0] return mel[0]

@ -16,7 +16,6 @@ from pathlib import Path
from typing import Dict from typing import Dict
import paddle import paddle
from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
@ -140,7 +139,7 @@ class DiffSingerUpdater(StandardUpdater):
for param in self.model.fs2.parameters(): for param in self.model.fs2.parameters():
param.trainable = False param.trainable = False
mel, mel_masks = self.model( mel, mel_ref, mel_masks = self.model(
text=batch["text"], text=batch["text"],
note=batch["note"], note=batch["note"],
note_dur=batch["note_dur"], note_dur=batch["note_dur"],
@ -156,9 +155,10 @@ class DiffSingerUpdater(StandardUpdater):
train_fs2=False, ) train_fs2=False, )
mel = mel.transpose((0, 2, 1)) mel = mel.transpose((0, 2, 1))
mel_ref = mel_ref.transpose((0, 2, 1))
mel_masks = mel_masks.transpose((0, 2, 1)) mel_masks = mel_masks.transpose((0, 2, 1))
l1_loss_ds = self.criterion_ds( l1_loss_ds = self.criterion_ds(
ref_mels=batch["speech"], ref_mels=mel_ref,
out_mels=mel, out_mels=mel,
mel_masks=mel_masks, ) mel_masks=mel_masks, )
@ -238,6 +238,6 @@ class DiffSingerEvaluator(StandardEvaluator):
losses_dict["l1_loss_ds"] = float(l1_loss_ds) losses_dict["l1_loss_ds"] = float(l1_loss_ds)
losses_dict["loss_ds"] = float(loss_ds) losses_dict["loss_ds"] = float(loss_ds)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v) self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items()) for k, v in losses_dict.items())
self.logger.info(self.msg) self.logger.info(self.msg)

Loading…
Cancel
Save