From bcd8e309ec3fade62971067de6d5607027c254e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=89=BE=E6=A2=A6?= Date: Thu, 9 Feb 2023 14:58:34 +0800 Subject: [PATCH] [TTS]Add diffusion noise clip to optimize sample result (#2902) * add diffusion module for training diffsinger * add wavenet denoiser final conv initializer * add diffusion noise clip to optimize sample result --- paddlespeech/t2s/modules/diffusion.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/paddlespeech/t2s/modules/diffusion.py b/paddlespeech/t2s/modules/diffusion.py index eb67ffb0..be684ce3 100644 --- a/paddlespeech/t2s/modules/diffusion.py +++ b/paddlespeech/t2s/modules/diffusion.py @@ -360,6 +360,8 @@ class GaussianDiffusion(nn.Layer): num_inference_steps: Optional[int]=1000, strength: Optional[float]=None, scheduler_type: Optional[str]="ddpm", + clip_noise: Optional[bool]=True, + clip_noise_range: Optional[Tuple[float, float]]=(-1, 1), callback: Optional[Callable[[int, int, int, paddle.Tensor], None]]=None, callback_steps: Optional[int]=1): @@ -380,6 +382,10 @@ class GaussianDiffusion(nn.Layer): scheduler_type (str, optional): Noise scheduler for generate noises. Choose a great scheduler can skip many denoising step, by default 'ddpm'. + clip_noise (bool, optional): + Whether to clip each denoised output, by default True. + clip_noise_range (tuple, optional): + denoised output min and max value range after clip, by default (-1, 1). callback (Callable[[int,int,int,Tensor], None], optional): Callback function during denoising steps. @@ -440,6 +446,9 @@ class GaussianDiffusion(nn.Layer): # denoising loop denoised_output = noisy_input + if clip_noise: + n_min, n_max = clip_noise_range + denoised_output = paddle.clip(denoised_output, n_min, n_max) num_warmup_steps = len( timesteps) - num_inference_steps * scheduler.order for i, t in enumerate(timesteps): @@ -451,6 +460,8 @@ class GaussianDiffusion(nn.Layer): # compute the previous noisy sample x_t -> x_t-1 denoised_output = scheduler.step(noise_pred, t, denoised_output).prev_sample + if clip_noise: + denoised_output = paddle.clip(denoised_output, n_min, n_max) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and