add wavenet denoiser final conv initializer

pull/2868/head
HighCWu 3 years ago
parent 732f826871
commit e820352741

@ -40,7 +40,7 @@ class WaveNetDenoiser(nn.Layer):
layers (int, optional): layers (int, optional):
Number of residual blocks inside, by default 20 Number of residual blocks inside, by default 20
stacks (int, optional): stacks (int, optional):
The number of groups to split the residual blocks into, by default 4 The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially. Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional): residual_channels (int, optional):
Residual channel of the residual blocks, by default 256 Residual channel of the residual blocks, by default 256
@ -64,7 +64,7 @@ class WaveNetDenoiser(nn.Layer):
out_channels: int=80, out_channels: int=80,
kernel_size: int=3, kernel_size: int=3,
layers: int=20, layers: int=20,
stacks: int=4, stacks: int=5,
residual_channels: int=256, residual_channels: int=256,
gate_channels: int=512, gate_channels: int=512,
skip_channels: int=256, skip_channels: int=256,
@ -72,7 +72,7 @@ class WaveNetDenoiser(nn.Layer):
dropout: float=0., dropout: float=0.,
bias: bool=True, bias: bool=True,
use_weight_norm: bool=False, use_weight_norm: bool=False,
init_type: str="kaiming_uniform", ): init_type: str="kaiming_normal", ):
super().__init__() super().__init__()
# initialize parameters # initialize parameters
@ -118,18 +118,15 @@ class WaveNetDenoiser(nn.Layer):
bias=bias) bias=bias)
self.conv_layers.append(conv) self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(), self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D( nn.Conv1D(
skip_channels, skip_channels,
skip_channels, skip_channels,
1, 1,
bias_attr=True), bias_attr=True),
nn.ReLU(), nn.ReLU(), final_conv)
nn.Conv1D(
skip_channels,
out_channels,
1,
bias_attr=True))
if use_weight_norm: if use_weight_norm:
self.apply_weight_norm() self.apply_weight_norm()
@ -200,10 +197,6 @@ class GaussianDiffusion(nn.Layer):
Args: Args:
denoiser (Layer, optional): denoiser (Layer, optional):
The model used for denoising noises. The model used for denoising noises.
In fact, the denoiser model performs the operation
of producing a output with more noises from the noisy input.
Then we use the diffusion algorithm to calculate
the input with the output to get the denoised result.
num_train_timesteps (int, optional): num_train_timesteps (int, optional):
The number of timesteps between the noise and the real during training, by default 1000. The number of timesteps between the noise and the real during training, by default 1000.
beta_start (float, optional): beta_start (float, optional):
@ -233,7 +226,8 @@ class GaussianDiffusion(nn.Layer):
>>> def callback(index, timestep, num_timesteps, sample): >>> def callback(index, timestep, num_timesteps, sample):
>>> nonlocal pbar >>> nonlocal pbar
>>> if pbar is None: >>> if pbar is None:
>>> pbar = tqdm(total=num_timesteps-index) >>> pbar = tqdm(total=num_timesteps)
>>> pbar.update(index)
>>> pbar.update() >>> pbar.update()
>>> >>>
>>> return callback >>> return callback
@ -247,7 +241,7 @@ class GaussianDiffusion(nn.Layer):
>>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
>>> with paddle.no_grad(): >>> with paddle.no_grad():
>>> sample = diffusion.inference( >>> sample = diffusion.inference(
>>> paddle.randn(x.shape), c, x, >>> paddle.randn(x.shape), c, ref_x=x_in,
>>> num_inference_steps=infer_steps, >>> num_inference_steps=infer_steps,
>>> scheduler_type=scheduler_type, >>> scheduler_type=scheduler_type,
>>> callback=create_progress_callback()) >>> callback=create_progress_callback())
@ -262,7 +256,7 @@ class GaussianDiffusion(nn.Layer):
>>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
>>> with paddle.no_grad(): >>> with paddle.no_grad():
>>> sample = diffusion.inference( >>> sample = diffusion.inference(
>>> paddle.randn(x.shape), c, x_in, >>> paddle.randn(x.shape), c, ref_x=x_in,
>>> num_inference_steps=infer_steps, >>> num_inference_steps=infer_steps,
>>> scheduler_type=scheduler_type, >>> scheduler_type=scheduler_type,
>>> callback=create_progress_callback()) >>> callback=create_progress_callback())
@ -277,11 +271,11 @@ class GaussianDiffusion(nn.Layer):
>>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
>>> with paddle.no_grad(): >>> with paddle.no_grad():
>>> sample = diffusion.inference( >>> sample = diffusion.inference(
>>> paddle.randn(x.shape), c, None, >>> paddle.randn(x.shape), c, ref_x=x_in,
>>> num_inference_steps=infer_steps, >>> num_inference_steps=infer_steps,
>>> scheduler_type=scheduler_type, >>> scheduler_type=scheduler_type,
>>> callback=create_progress_callback()) >>> callback=create_progress_callback())
100%|| 25/25 [00:01<00:00, 19.75it/s] 100%|| 34/34 [00:01<00:00, 19.75it/s]
>>> >>>
>>> # ds=1000, K_step=100, scheduler=pndm, infer_step=50, from aux fs2 mel output >>> # ds=1000, K_step=100, scheduler=pndm, infer_step=50, from aux fs2 mel output
>>> ds = 1000 >>> ds = 1000
@ -292,11 +286,11 @@ class GaussianDiffusion(nn.Layer):
>>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step) >>> diffusion = GaussianDiffusion(denoiser, num_train_timesteps=ds, num_max_timesteps=K_step)
>>> with paddle.no_grad(): >>> with paddle.no_grad():
>>> sample = diffusion.inference( >>> sample = diffusion.inference(
>>> paddle.randn(x.shape), c, x, >>> paddle.randn(x.shape), c, ref_x=x_in,
>>> num_inference_steps=infer_steps, >>> num_inference_steps=infer_steps,
>>> scheduler_type=scheduler_type, >>> scheduler_type=scheduler_type,
>>> callback=create_progress_callback()) >>> callback=create_progress_callback())
100%|| 5/5 [00:00<00:00, 23.80it/s] 100%|| 14/14 [00:00<00:00, 23.80it/s]
""" """

Loading…
Cancel
Save