You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
276 lines
8.9 KiB
276 lines
8.9 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Residual affine coupling modules in VITS.
|
|
|
|
This code is based on https://github.com/jaywalnut310/vits.
|
|
|
|
"""
|
|
from typing import Optional
|
|
from typing import Tuple
|
|
from typing import Union
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
|
|
from paddlespeech.t2s.models.vits.flow import FlipFlow
|
|
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
|
|
|
|
|
|
class ResidualAffineCouplingBlock(nn.Layer):
    """Residual affine coupling block module.

    This is a module of residual affine coupling block, which is used as
    "Flow" in `Conditional Variational Autoencoder with Adversarial Learning
    for End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2106.06103

    """

    def __init__(
            self,
            in_channels: int=192,
            hidden_channels: int=192,
            flows: int=4,
            kernel_size: int=5,
            base_dilation: int=1,
            layers: int=4,
            global_channels: int=-1,
            dropout_rate: float=0.0,
            use_weight_norm: bool=True,
            bias: bool=True,
            use_only_mean: bool=True, ):
        """Initialize ResidualAffineCouplingBlock module.

        Args:
            in_channels (int):
                Number of input channels.
            hidden_channels (int):
                Number of hidden channels.
            flows (int):
                Number of flows.
            kernel_size (int):
                Kernel size for WaveNet.
            base_dilation (int):
                Base dilation factor for WaveNet.
            layers (int):
                Number of layers of WaveNet.
            global_channels (int):
                Number of global channels.
            dropout_rate (float):
                Dropout rate.
            use_weight_norm (bool):
                Whether to use weight normalization in WaveNet.
            bias (bool):
                Whether to use bias parameters in WaveNet.
            use_only_mean (bool):
                Whether to estimate only mean.

        """
        super().__init__()

        self.flows = nn.LayerList()
        for _ in range(flows):
            # Each flow step is one coupling layer followed by a FlipFlow.
            # The coupling layer only transforms the second half of the
            # channels, so a FlipFlow is interleaved (presumably to swap
            # which half is transformed next -- see FlipFlow for details).
            self.flows.append(
                ResidualAffineCouplingLayer(
                    in_channels=in_channels,
                    hidden_channels=hidden_channels,
                    kernel_size=kernel_size,
                    base_dilation=base_dilation,
                    layers=layers,
                    # The number of WaveNet stacks is not configurable here.
                    stacks=1,
                    global_channels=global_channels,
                    dropout_rate=dropout_rate,
                    use_weight_norm=use_weight_norm,
                    bias=bias,
                    use_only_mean=use_only_mean, ))
            self.flows.append(FlipFlow())

    def forward(
            self,
            x: paddle.Tensor,
            x_mask: paddle.Tensor,
            g: Optional[paddle.Tensor]=None,
            inverse: bool=False, ) -> paddle.Tensor:
        """Calculate forward propagation.

        Args:
            x (Tensor):
                Input tensor (B, in_channels, T).
            x_mask (Tensor):
                Mask tensor (B, 1, T).
            g (Optional[Tensor]):
                Global conditioning tensor (B, global_channels, 1).
            inverse (bool):
                Whether to inverse the flow.

        Returns:
            Tensor: Output tensor (B, in_channels, T).

        """
        if not inverse:
            for flow in self.flows:
                # In the forward direction every flow returns
                # (output, logdet); the log-determinant is discarded here.
                x, _ = flow(x, x_mask, g=g, inverse=inverse)
        else:
            # The inverse pass applies the flows in the opposite order and
            # each flow returns only the transformed tensor.
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, inverse=inverse)
        return x
|
|
|
|
|
|
class ResidualAffineCouplingLayer(nn.Layer):
    """Residual affine coupling layer.

    The input is split in two halves along the channel axis. The first half
    is passed through unchanged and is used (via a 1x1 convolution, a WaveNet
    encoder and a 1x1 projection) to predict the affine parameters that
    transform the second half.
    """

    def __init__(
            self,
            in_channels: int=192,
            hidden_channels: int=192,
            kernel_size: int=5,
            base_dilation: int=1,
            layers: int=5,
            stacks: int=1,
            global_channels: int=-1,
            dropout_rate: float=0.0,
            use_weight_norm: bool=True,
            bias: bool=True,
            use_only_mean: bool=True, ):
        """Initialize ResidualAffineCouplingLayer module.

        Args:
            in_channels (int):
                Number of input channels. Must be divisible by 2.
            hidden_channels (int):
                Number of hidden channels.
            kernel_size (int):
                Kernel size for WaveNet.
            base_dilation (int):
                Base dilation factor for WaveNet.
            layers (int):
                Number of layers of WaveNet.
            stacks (int):
                Number of stacks of WaveNet.
            global_channels (int):
                Number of global channels.
            dropout_rate (float):
                Dropout rate.
            use_weight_norm (bool):
                Whether to use weight normalization in WaveNet.
            bias (bool):
                Whether to use bias parameters in WaveNet.
            use_only_mean (bool):
                Whether to estimate only mean.

        Raises:
            AssertionError: If ``in_channels`` is not divisible by 2.

        """
        assert in_channels % 2 == 0, "in_channels should be divisible by 2"
        super().__init__()
        self.half_channels = in_channels // 2
        self.use_only_mean = use_only_mean

        # define modules
        self.input_conv = nn.Conv1D(
            self.half_channels,
            hidden_channels,
            1, )
        self.encoder = WaveNet(
            in_channels=-1,
            out_channels=-1,
            kernel_size=kernel_size,
            layers=layers,
            stacks=stacks,
            base_dilation=base_dilation,
            residual_channels=hidden_channels,
            aux_channels=-1,
            gate_channels=hidden_channels * 2,
            skip_channels=hidden_channels,
            global_channels=global_channels,
            dropout_rate=dropout_rate,
            bias=bias,
            use_weight_norm=use_weight_norm,
            use_first_conv=False,
            use_last_conv=False,
            scale_residual=False,
            scale_skip_connect=True, )

        # The projection predicts only the mean when ``use_only_mean`` is
        # set, otherwise both the mean and the log-scale (twice the channels).
        if use_only_mean:
            proj_channels = self.half_channels
        else:
            proj_channels = self.half_channels * 2
        self.proj = nn.Conv1D(
            hidden_channels,
            proj_channels,
            1, )

        # Zero-initialize the projection so that the predicted statistics
        # are all zero at the start: with m = 0 and logs = 0 the coupling
        # leaves the (masked) input unchanged.
        zero_weight = paddle.zeros(paddle.shape(self.proj.weight))
        self.proj.weight = paddle.create_parameter(
            shape=zero_weight.shape,
            dtype=str(zero_weight.numpy().dtype),
            default_initializer=paddle.nn.initializer.Assign(zero_weight))

        # Named ``zero_bias`` so it does not shadow the ``bias`` flag above.
        zero_bias = paddle.zeros(paddle.shape(self.proj.bias))
        self.proj.bias = paddle.create_parameter(
            shape=zero_bias.shape,
            dtype=str(zero_bias.numpy().dtype),
            default_initializer=paddle.nn.initializer.Assign(zero_bias))

    def forward(
            self,
            x: paddle.Tensor,
            x_mask: paddle.Tensor,
            g: Optional[paddle.Tensor]=None,
            inverse: bool=False,
    ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
        """Calculate forward propagation.

        Args:
            x (Tensor):
                Input tensor (B, in_channels, T).
            x_mask (Tensor):
                Mask tensor (B, 1, T).
            g (Optional[Tensor]):
                Global conditioning tensor (B, global_channels, 1).
            inverse (bool):
                Whether to inverse the flow.

        Returns:
            Tensor:
                Output tensor (B, in_channels, T).
            Tensor:
                Log-determinant tensor for NLL (B,) if not inverse.

        """
        # xa conditions the transform and is passed through untouched;
        # xb is the half that gets transformed.
        xa, xb = paddle.split(x, 2, axis=1)
        h = self.input_conv(xa) * x_mask
        h = self.encoder(h, x_mask, g=g)
        stats = self.proj(h) * x_mask
        if not self.use_only_mean:
            m, logs = paddle.split(stats, 2, axis=1)
        else:
            # Mean-only coupling: the log-scale is fixed to zero.
            m = stats
            logs = paddle.zeros(paddle.shape(m))

        if not inverse:
            xb = m + xb * paddle.exp(logs) * x_mask
            x = paddle.concat([xa, xb], 1)
            # Sum the log-scales over the channel and time axes.
            logdet = paddle.sum(logs, [1, 2])
            return x, logdet
        else:
            xb = (xb - m) * paddle.exp(-logs) * x_mask
            x = paddle.concat([xa, xb], 1)
            return x
|