# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
# File: PaddleSpeech/paddlespeech/t2s/modules/wavenet_denoiser.py
#
# 186 lines
# 6.6 KiB

# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle import nn
from ppdiffusers.models.embeddings import Timesteps
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
    """A Mel-Spectrogram Denoiser modified from WaveNet

    The network projects the input mel-spectrogram to ``residual_channels``
    with a 1x1 convolution, runs it through ``layers`` residual blocks (each
    one receiving its own projection of the timestep embedding), sums the skip
    outputs of all blocks and maps the result to ``out_channels`` with a small
    stack of 1x1 convolutions.

    Args:
        in_channels (int, optional):
            Number of channels of the input mel-spectrogram, by default 80
        out_channels (int, optional):
            Number of channels of the output mel-spectrogram, by default 80
        kernel_size (int, optional):
            Kernel size of the residual blocks inside, by default 3
        layers (int, optional):
            Number of residual blocks inside, by default 20
            Must be divisible by ``stacks``.
        stacks (int, optional):
            The number of groups to split the residual blocks into, by default 5
            Within each group, the dilation of the residual block grows exponentially.
        residual_channels (int, optional):
            Residual channel of the residual blocks, by default 256
        gate_channels (int, optional):
            Gate channel of the residual blocks, by default 512
        skip_channels (int, optional):
            Skip channel of the residual blocks, by default 256
        aux_channels (int, optional):
            Auxiliary channel of the residual blocks, by default 256
        dropout (float, optional):
            Dropout of the residual blocks, by default 0.
        bias (bool, optional):
            Whether to use bias in residual blocks, by default True
        use_weight_norm (bool, optional):
            Whether to use weight norm in all convolutions, by default False
        init_type (str, optional):
            Name of the initialization method passed to ``initialize``,
            by default "kaiming_normal"
    """

    def __init__(
            self,
            in_channels: int=80,
            out_channels: int=80,
            kernel_size: int=3,
            layers: int=20,
            stacks: int=5,
            residual_channels: int=256,
            gate_channels: int=512,
            skip_channels: int=256,
            aux_channels: int=256,
            dropout: float=0.,
            bias: bool=True,
            use_weight_norm: bool=False,
            init_type: str="kaiming_normal", ):
        super().__init__()

        # initialize parameters
        # NOTE(review): this runs before any sublayer exists; presumably
        # `initialize` configures the default initializer picked up by the
        # layers created below -- confirm against nets_utils.
        initialize(self, init_type)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        assert layers % stacks == 0, (
            f"layers ({layers}) must be divisible by stacks ({stacks})")
        layers_per_stack = layers // stacks

        # Timestep embedding followed by a two-layer MLP (Linear-Mish-Linear)
        # that ends at `residual_channels` features.
        self.first_t_emb = nn.Sequential(
            Timesteps(
                residual_channels,
                flip_sin_to_cos=False,
                downscale_freq_shift=1),
            nn.Linear(residual_channels, residual_channels * 4),
            nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
        # One extra projection of the timestep embedding per residual block.
        self.t_emb_layers = nn.LayerList([
            nn.Linear(residual_channels, residual_channels)
            for _ in range(layers)
        ])

        self.first_conv = nn.Conv1D(
            in_channels, residual_channels, 1, bias_attr=True)
        self.first_act = nn.ReLU()

        self.conv_layers = nn.LayerList()
        for layer in range(layers):
            # The dilation doubles inside a stack and restarts at 1 at the
            # beginning of every stack.
            dilation = 2**(layer % layers_per_stack)
            conv = WaveNetResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias)
            self.conv_layers.append(conv)

        # The weight of the last projection is set to zero explicitly,
        # overriding whatever initializer was used when it was created.
        final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
        nn.initializer.Constant(0.0)(final_conv.weight)

        self.last_conv_layers = nn.Sequential(nn.ReLU(),
                                              nn.Conv1D(
                                                  skip_channels,
                                                  skip_channels,
                                                  1,
                                                  bias_attr=True),
                                              nn.ReLU(), final_conv)

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x: paddle.Tensor, t: paddle.Tensor, c: paddle.Tensor):
        """Denoise mel-spectrogram.

        Args:
            x(Tensor):
                Shape (B, C_in, T), The input mel-spectrogram.
            t(Tensor):
                Shape (B), The timestep input. When its first dimension
                differs from the batch size of ``x`` it is tiled ``B`` times.
            c(Tensor):
                Shape (B, C_aux, T). The auxiliary input (e.g. fastspeech2 encoder output).
                Its last dimension must equal the last dimension of ``x``.

        Returns:
            Tensor: Shape (B, C_out, T), the pred noise.
        """
        assert c.shape[-1] == x.shape[-1], (
            f"length of c ({c.shape[-1]}) must match length of x "
            f"({x.shape[-1]})")

        if t.shape[0] != x.shape[0]:
            # presumably `t` holds a single timestep shared by the whole
            # batch here -- TODO confirm with callers
            t = t.tile([x.shape[0]])
        t_emb = self.first_t_emb(t)
        # A trailing axis is appended so that each per-block embedding can be
        # broadcast-added to `x` along its last dimension.
        t_embs = [
            t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
        ]

        x = self.first_conv(x)
        x = self.first_act(x)
        skips = 0
        # Distinct loop names keep the `t` argument from being rebound.
        for conv_layer, layer_t_emb in zip(self.conv_layers, t_embs):
            x = x + layer_t_emb
            x, s = conv_layer(x, c)
            skips += s
        # Scale the summed skip outputs by 1 / sqrt(number of blocks).
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        x = self.last_conv_layers(skips)
        return x

    def apply_weight_norm(self):
        """Recursively apply weight normalization to all the Convolution layers
        in the sublayers.
        """

        def _apply_weight_norm(layer):
            # Only Conv1D / Conv2D layers are wrapped; other layers are left
            # untouched.
            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
                nn.utils.weight_norm(layer)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Recursively remove weight normalization from all the Convolution
        layers in the sublayers.
        """

        def _remove_weight_norm(layer):
            # Removal is attempted on every sublayer; a ValueError (presumably
            # raised for layers without weight norm) is deliberately ignored.
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                pass

        self.apply(_remove_weight_norm)