Add linear spectrogram normalization (norm_spec / denorm_spec) to GaussianDiffusion

pull/3005/head
lym0302 3 years ago
parent 4ecc752e38
commit 9df1294935

@ -270,10 +270,10 @@ class DiffSinger(nn.Layer):
mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1)) mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1))
cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur) cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur)
cond_fs2 = cond_fs2.transpose((0, 2, 1)) cond_fs2 = cond_fs2.transpose((0, 2, 1))
# mel, _ = self.diffusion(mel_fs2, cond_fs2) mel, _ = self.diffusion(mel_fs2, cond_fs2)
noise = paddle.randn(mel_fs2.shape) # noise = paddle.randn(mel_fs2.shape)
mel = self.diffusion.inference( # mel = self.diffusion.inference(
noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100) # noise=noise, cond=cond_fs2, ref_x=mel_fs2, num_inference_steps=100)
mel = mel.transpose((0, 2, 1)) mel = mel.transpose((0, 2, 1))
return mel[0] return mel[0]

@ -17,6 +17,7 @@ from typing import Callable
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
import numpy as np
import paddle import paddle
import ppdiffusers import ppdiffusers
from paddle import nn from paddle import nn
@ -315,8 +316,46 @@ class GaussianDiffusion(nn.Layer):
beta_end=beta_end, beta_end=beta_end,
beta_schedule=beta_schedule) beta_schedule=beta_schedule)
self.num_max_timesteps = num_max_timesteps self.num_max_timesteps = num_max_timesteps
self.spec_min = paddle.to_tensor(
np.array([
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0,
-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0
]))
self.spec_max = paddle.to_tensor(
np.array([
-0.79453, -0.81116, -0.61631, -0.30679, -0.13863, -0.050652,
-0.11563, -0.10679, -0.091068, -0.062174, -0.075302, -0.072217,
-0.063815, -0.073299, 0.007361, -0.072508, -0.050234, -0.16534,
-0.26928, -0.20782, -0.20823, -0.11702, -0.070128, -0.065868,
-0.012675, 0.0015121, -0.089902, -0.21392, -0.23789, -0.28922,
-0.30405, -0.23029, -0.22088, -0.21542, -0.29367, -0.30137,
-0.38281, -0.4359, -0.28681, -0.46855, -0.57485, -0.47022,
-0.54266, -0.44848, -0.6412, -0.687, -0.6486, -0.76436,
-0.49971, -0.71068, -0.69724, -0.61487, -0.55843, -0.69773,
-0.57502, -0.70919, -0.82431, -0.84213, -0.90431, -0.8284,
-0.77945, -0.82758, -0.87699, -1.0532, -1.0766, -1.1198,
-1.0185, -0.98983, -1.0001, -1.0756, -1.0024, -1.0304, -1.0579,
-1.0188, -1.05, -1.0842, -1.0923, -1.1223, -1.2381, -1.6467
]))
def norm_spec(self, x):
    """Linearly rescale a spectrogram so it lies in [-1, 1].

    Values equal to ``self.spec_min`` map to -1 and values equal to
    ``self.spec_max`` map to +1; everything in between is interpolated
    linearly.

    Args:
        x: Spectrogram to normalize, documented by the author as
            [B, T, N]. Presumably the last axis has to broadcast
            against ``self.spec_min`` / ``self.spec_max`` -- confirm
            against the callers, which transpose before calling.

    Returns:
        The rescaled spectrogram.
    """
    lower = self.spec_min
    span = self.spec_max - lower
    unit = (x - lower) / span
    return unit * 2 - 1
def denorm_spec(self, x):
    """Map values from [-1, 1] back to the original spectrogram range.

    This undoes the linear mapping applied by ``norm_spec``: -1 maps to
    ``self.spec_min`` and +1 maps to ``self.spec_max``.

    Args:
        x: Normalized spectrogram. Presumably laid out like the input of
            ``norm_spec`` ([B, T, N]) so that the last axis broadcasts
            against ``self.spec_min`` / ``self.spec_max`` -- confirm
            against the callers.

    Returns:
        The spectrogram rescaled to the ``[spec_min, spec_max]`` range.
    """
    return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None, is_infer: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor]: ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Generate random timesteps noised x. """Generate random timesteps noised x.
@ -333,6 +372,11 @@ class GaussianDiffusion(nn.Layer):
The noises which is added to the input. The noises which is added to the input.
""" """
# print("xxxxxxxxxxxxxxxx1: ", x, x.shape)
x = x.transpose((0, 2, 1))
x = self.norm_spec(x)
x = x.transpose((0, 2, 1))
print("xxxxxxxxxxxxxxxx2: ", x, x.shape)
noise_scheduler = self.noise_scheduler noise_scheduler = self.noise_scheduler
# Sample noise that we'll add to the mel-spectrograms # Sample noise that we'll add to the mel-spectrograms
@ -350,6 +394,13 @@ class GaussianDiffusion(nn.Layer):
y = self.denoiser(noisy_images, timesteps, cond) y = self.denoiser(noisy_images, timesteps, cond)
if is_infer:
y = y.transpose((0, 2, 1))
y = self.denorm_spec(y)
y = y.transpose((0, 2, 1))
# y = self.denorm_spec(y)
# then compute loss use output y and noisy target for prediction_type == "epsilon" # then compute loss use output y and noisy target for prediction_type == "epsilon"
return y, target return y, target
@ -360,7 +411,7 @@ class GaussianDiffusion(nn.Layer):
num_inference_steps: Optional[int]=1000, num_inference_steps: Optional[int]=1000,
strength: Optional[float]=None, strength: Optional[float]=None,
scheduler_type: Optional[str]="ddpm", scheduler_type: Optional[str]="ddpm",
clip_noise: Optional[bool]=True, clip_noise: Optional[bool]=False,
clip_noise_range: Optional[Tuple[float, float]]=(-1, 1), clip_noise_range: Optional[Tuple[float, float]]=(-1, 1),
callback: Optional[Callable[[int, int, int, paddle.Tensor], callback: Optional[Callable[[int, int, int, paddle.Tensor],
None]]=None, None]]=None,
@ -426,10 +477,12 @@ class GaussianDiffusion(nn.Layer):
scheduler.set_timesteps(num_inference_steps) scheduler.set_timesteps(num_inference_steps)
# prepare first noise variables # prepare first noise variables
import pdb;pdb.set_trace()
noisy_input = noise noisy_input = noise
timesteps = scheduler.timesteps timesteps = scheduler.timesteps
if ref_x is not None: if ref_x is not None:
ref_x = ref_x.transpose((0, 2, 1))
ref_x = self.norm_spec(ref_x)
ref_x = ref_x.transpose((0, 2, 1))
init_timestep = None init_timestep = None
if strength is None or strength < 0. or strength > 1.: if strength is None or strength < 0. or strength > 1.:
strength = None strength = None
@ -445,8 +498,6 @@ class GaussianDiffusion(nn.Layer):
noisy_input = scheduler.add_noise( noisy_input = scheduler.add_noise(
ref_x, noise, timesteps[:1].tile([noise.shape[0]])) ref_x, noise, timesteps[:1].tile([noise.shape[0]]))
# denoising loop # denoising loop
denoised_output = noisy_input denoised_output = noisy_input
if clip_noise: if clip_noise:
@ -472,4 +523,10 @@ class GaussianDiffusion(nn.Layer):
if callback is not None and i % callback_steps == 0: if callback is not None and i % callback_steps == 0:
callback(i, t, len(timesteps), denoised_output) callback(i, t, len(timesteps), denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))
denoised_output = self.denorm_spec(denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))
return denoised_output return denoised_output
Loading…
Cancel
Save