add linear norm

pull/3005/head
lym0302 3 years ago
parent 4ecc752e38
commit 9df1294935

@ -270,10 +270,10 @@ class DiffSinger(nn.Layer):
mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1))
cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur)
cond_fs2 = cond_fs2.transpose((0, 2, 1))
# mel, _ = self.diffusion(mel_fs2, cond_fs2)
noise = paddle.randn(mel_fs2.shape)
mel = self.diffusion.inference(
noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel, _ = self.diffusion(mel_fs2, cond_fs2)
# noise = paddle.randn(mel_fs2.shape)
# mel = self.diffusion.inference(
# noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel = mel.transpose((0, 2, 1))
return mel[0]

@ -17,6 +17,7 @@ from typing import Callable
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import ppdiffusers
from paddle import nn
@ -315,8 +316,46 @@ class GaussianDiffusion(nn.Layer):
beta_end=beta_end,
beta_schedule=beta_schedule)
self.num_max_timesteps = num_max_timesteps
self.spec_min = paddle.to_tensor(
np.array([
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
]))
self.spec_max = paddle.to_tensor(
np.array([
-0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
-0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
-0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
-0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
-0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
-0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
-0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
-0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
-0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
-0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
-0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
-1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
-1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
]))
def norm_spec(self, x):
    """Linearly rescale a mel spectrogram into [-1, 1].

    Maps ``self.spec_min`` -> -1 and ``self.spec_max`` -> 1, per mel bin
    (broadcast over the last axis).

    Args:
        x: tensor of shape [B, T, N].
    Returns:
        Tensor of the same shape, rescaled into [-1, 1].
    """
    span = self.spec_max - self.spec_min
    return (x - self.spec_min) / span * 2 - 1
def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None
def denorm_spec(self, x):
    """Invert ``norm_spec``: map values from [-1, 1] back to the original
    spectrogram range [``self.spec_min``, ``self.spec_max``] per mel bin.

    Args:
        x: normalized tensor, typically of shape [B, T, N].
    Returns:
        Tensor of the same shape in the original spectrogram value range.
    """
    unit = (x + 1) / 2
    return unit * (self.spec_max - self.spec_min) + self.spec_min
def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None, is_infer: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Generate random timesteps noised x.
@ -333,6 +372,11 @@ class GaussianDiffusion(nn.Layer):
The noises which is added to the input.
"""
# print("xxxxxxxxxxxxxxxx1: ", x, x.shape)
x = x.transpose((0, 2, 1))
x = self.norm_spec(x)
x = x.transpose((0, 2, 1))
print("xxxxxxxxxxxxxxxx2: ", x, x.shape)
noise_scheduler = self.noise_scheduler
# Sample noise that we'll add to the mel-spectrograms
@ -349,6 +393,13 @@ class GaussianDiffusion(nn.Layer):
noisy_images = noise_scheduler.add_noise(x, noise, timesteps)
y = self.denoiser(noisy_images, timesteps, cond)
if is_infer:
y = y.transpose((0, 2, 1))
y = self.denorm_spec(y)
y = y.transpose((0, 2, 1))
# y = self.denorm_spec(y)
# then compute loss use output y and noisy target for prediction_type == "epsilon"
return y, target
@ -360,7 +411,7 @@ class GaussianDiffusion(nn.Layer):
num_inference_steps: Optional[int]=1000,
strength: Optional[float]=None,
scheduler_type: Optional[str]="ddpm",
clip_noise: Optional[bool]=True,
clip_noise: Optional[bool]=False,
clip_noise_range: Optional[Tuple[float, float]]=(-1, 1),
callback: Optional[Callable[[int, int, int, paddle.Tensor],
None]]=None,
@ -426,10 +477,12 @@ class GaussianDiffusion(nn.Layer):
scheduler.set_timesteps(num_inference_steps)
# prepare first noise variables
import pdb;pdb.set_trace()
noisy_input = noise
timesteps = scheduler.timesteps
if ref_x is not None:
if ref_x is not None:
ref_x = ref_x.transpose((0, 2, 1))
ref_x = self.norm_spec(ref_x)
ref_x = ref_x.transpose((0, 2, 1))
init_timestep = None
if strength is None or strength < 0. or strength > 1.:
strength = None
@ -445,8 +498,6 @@ class GaussianDiffusion(nn.Layer):
noisy_input = scheduler.add_noise(
ref_x, noise, timesteps[:1].tile([noise.shape[0]]))
# denoising loop
denoised_output = noisy_input
if clip_noise:
@ -471,5 +522,11 @@ class GaussianDiffusion(nn.Layer):
(i + 1) % scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
callback(i, t, len(timesteps), denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))
denoised_output = self.denorm_spec(denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))
return denoised_output
return denoised_output

Loading…
Cancel
Save