parent
93964960c6
commit
4b7786f2ed
@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||
|
||||
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||
export LC_ALL=C
|
||||
|
||||
export PYTHONDONTWRITEBYTECODE=1
|
||||
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||
|
||||
MODEL=vits
|
||||
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,172 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Stochastic duration predictor modules in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.models.vits.flow import ConvFlow
|
||||
from paddlespeech.t2s.models.vits.flow import DilatedDepthSeparableConv
|
||||
from paddlespeech.t2s.models.vits.flow import ElementwiseAffineFlow
|
||||
from paddlespeech.t2s.models.vits.flow import FlipFlow
|
||||
from paddlespeech.t2s.models.vits.flow import LogFlow
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Layer):
|
||||
"""Stochastic duration predictor module.
|
||||
This is a module of stochastic duration predictor described in `Conditional
|
||||
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2106.06103
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int=192,
|
||||
kernel_size: int=3,
|
||||
dropout_rate: float=0.5,
|
||||
flows: int=4,
|
||||
dds_conv_layers: int=3,
|
||||
global_channels: int=-1, ):
|
||||
"""Initialize StochasticDurationPredictor module.
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
kernel_size (int): Kernel size.
|
||||
dropout_rate (float): Dropout rate.
|
||||
flows (int): Number of flows.
|
||||
dds_conv_layers (int): Number of conv layers in DDS conv.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.pre = nn.Conv1D(channels, channels, 1)
|
||||
self.dds = DilatedDepthSeparableConv(
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers,
|
||||
dropout_rate=dropout_rate, )
|
||||
self.proj = nn.Conv1D(channels, channels, 1)
|
||||
|
||||
self.log_flow = LogFlow()
|
||||
self.flows = nn.LayerList()
|
||||
self.flows.append(ElementwiseAffineFlow(2))
|
||||
for i in range(flows):
|
||||
self.flows.append(
|
||||
ConvFlow(
|
||||
2,
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers, ))
|
||||
self.flows.append(FlipFlow())
|
||||
|
||||
self.post_pre = nn.Conv1D(1, channels, 1)
|
||||
self.post_dds = DilatedDepthSeparableConv(
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers,
|
||||
dropout_rate=dropout_rate, )
|
||||
self.post_proj = nn.Conv1D(channels, channels, 1)
|
||||
self.post_flows = nn.LayerList()
|
||||
self.post_flows.append(ElementwiseAffineFlow(2))
|
||||
for i in range(flows):
|
||||
self.post_flows.append(
|
||||
ConvFlow(
|
||||
2,
|
||||
channels,
|
||||
kernel_size,
|
||||
layers=dds_conv_layers, ))
|
||||
self.post_flows.append(FlipFlow())
|
||||
|
||||
if global_channels > 0:
|
||||
self.global_conv = nn.Conv1D(global_channels, channels, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
w: Optional[paddle.Tensor]=None,
|
||||
g: Optional[paddle.Tensor]=None,
|
||||
inverse: bool=False,
|
||||
noise_scale: float=1.0, ) -> paddle.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T_text).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T_text).
|
||||
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
noise_scale (float): Noise scale value.
|
||||
Returns:
|
||||
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
|
||||
If inverse, log-duration tensor (B, 1, T_text).
|
||||
"""
|
||||
# stop gradient
|
||||
# x = x.detach()
|
||||
x = self.pre(x)
|
||||
if g is not None:
|
||||
# stop gradient
|
||||
x = x + self.global_conv(g.detach())
|
||||
x = self.dds(x, x_mask)
|
||||
x = self.proj(x) * x_mask
|
||||
|
||||
if not inverse:
|
||||
assert w is not None, "w must be provided."
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_dds(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) *
|
||||
x_mask)
|
||||
z_q = e_q
|
||||
logdet_tot_q = 0.0
|
||||
for i, flow in enumerate(self.post_flows):
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
logdet_tot_q += logdet_q
|
||||
z_u, z1 = paddle.split(z_q, [1, 1], 1)
|
||||
u = F.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += paddle.sum(
|
||||
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
|
||||
logq = (paddle.sum(-0.5 *
|
||||
(math.log(2 * math.pi) +
|
||||
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
logdet_tot += logdet
|
||||
z = paddle.concat([z0, z1], 1)
|
||||
for flow in self.flows:
|
||||
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
|
||||
(z**2)) * x_mask, [1, 2]) - logdet_tot)
|
||||
# (B,)
|
||||
return nll + logq
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
# remove a useless vflow
|
||||
flows = flows[:-2] + [flows[-1]]
|
||||
z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) *
|
||||
noise_scale)
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, inverse=inverse)
|
||||
z0, z1 = paddle.split(z, 2, axis=1)
|
||||
logw = z0
|
||||
return logw
|
@ -0,0 +1,316 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Basic Flow modules used in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.models.vits.transform import piecewise_rational_quadratic_transform
|
||||
|
||||
|
||||
class FlipFlow(nn.Layer):
|
||||
"""Flip flow module."""
|
||||
|
||||
def forward(self, x: paddle.Tensor, *args, inverse: bool=False, **kwargs
|
||||
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
Returns:
|
||||
Tensor: Flipped tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
"""
|
||||
x = paddle.flip(x, [1])
|
||||
if not inverse:
|
||||
logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype)
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class LogFlow(nn.Layer):
|
||||
"""Log flow module."""
|
||||
|
||||
def forward(self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
inverse: bool=False,
|
||||
eps: float=1e-5,
|
||||
**kwargs
|
||||
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
eps (float): Epsilon for log.
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
"""
|
||||
if not inverse:
|
||||
y = paddle.log(paddle.clip(x, min=eps)) * x_mask
|
||||
logdet = paddle.sum(-y, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = paddle.exp(x) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class ElementwiseAffineFlow(nn.Layer):
|
||||
"""Elementwise affine flow module."""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
"""Initialize ElementwiseAffineFlow module.
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
"""
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
|
||||
m = paddle.zeros([channels, 1])
|
||||
self.m = paddle.create_parameter(
|
||||
shape=m.shape,
|
||||
dtype=str(m.numpy().dtype),
|
||||
default_initializer=paddle.nn.initializer.Assign(m))
|
||||
logs = paddle.zeros([channels, 1])
|
||||
self.logs = paddle.create_parameter(
|
||||
shape=logs.shape,
|
||||
dtype=str(logs.numpy().dtype),
|
||||
default_initializer=paddle.nn.initializer.Assign(logs))
|
||||
|
||||
def forward(self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
inverse: bool=False,
|
||||
**kwargs
|
||||
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
"""
|
||||
if not inverse:
|
||||
y = self.m + paddle.exp(self.logs) * x
|
||||
y = y * x_mask
|
||||
logdet = paddle.sum(self.logs * x_mask, [1, 2])
|
||||
return y, logdet
|
||||
else:
|
||||
x = (x - self.m) * paddle.exp(-self.logs) * x_mask
|
||||
return x
|
||||
|
||||
|
||||
class Transpose(nn.Layer):
|
||||
"""Transpose module for paddle.nn.Sequential()."""
|
||||
|
||||
def __init__(self, dim1: int, dim2: int):
|
||||
"""Initialize Transpose module."""
|
||||
super().__init__()
|
||||
self.dim1 = dim1
|
||||
self.dim2 = dim2
|
||||
|
||||
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
|
||||
"""Transpose."""
|
||||
len_dim = len(x.shape)
|
||||
orig_perm = list(range(len_dim))
|
||||
new_perm = orig_perm[:]
|
||||
temp = new_perm[self.dim1]
|
||||
new_perm[self.dim1] = new_perm[self.dim2]
|
||||
new_perm[self.dim2] = temp
|
||||
|
||||
return paddle.transpose(x, new_perm)
|
||||
|
||||
|
||||
class DilatedDepthSeparableConv(nn.Layer):
|
||||
"""Dilated depth-separable conv module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
layers: int,
|
||||
dropout_rate: float=0.0,
|
||||
eps: float=1e-5, ):
|
||||
"""Initialize DilatedDepthSeparableConv module.
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
kernel_size (int): Kernel size.
|
||||
layers (int): Number of layers.
|
||||
dropout_rate (float): Dropout rate.
|
||||
eps (float): Epsilon for layer norm.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.convs = nn.LayerList()
|
||||
for i in range(layers):
|
||||
dilation = kernel_size**i
|
||||
padding = (kernel_size * dilation - dilation) // 2
|
||||
self.convs.append(
|
||||
nn.Sequential(
|
||||
nn.Conv1D(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
groups=channels,
|
||||
dilation=dilation,
|
||||
padding=padding, ),
|
||||
Transpose(1, 2),
|
||||
nn.LayerNorm(channels, epsilon=eps),
|
||||
Transpose(1, 2),
|
||||
nn.GELU(),
|
||||
nn.Conv1D(
|
||||
channels,
|
||||
channels,
|
||||
1, ),
|
||||
Transpose(1, 2),
|
||||
nn.LayerNorm(channels, epsilon=eps),
|
||||
Transpose(1, 2),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout_rate), ))
|
||||
|
||||
def forward(self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
g: Optional[paddle.Tensor]=None) -> paddle.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
"""
|
||||
if g is not None:
|
||||
x = x + g
|
||||
for f in self.convs:
|
||||
y = f(x * x_mask)
|
||||
x = x + y
|
||||
return x * x_mask
|
||||
|
||||
|
||||
class ConvFlow(nn.Layer):
|
||||
"""Convolutional flow module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
hidden_channels: int,
|
||||
kernel_size: int,
|
||||
layers: int,
|
||||
bins: int=10,
|
||||
tail_bound: float=5.0, ):
|
||||
"""Initialize ConvFlow module.
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size.
|
||||
layers (int): Number of layers.
|
||||
bins (int): Number of bins.
|
||||
tail_bound (float): Tail bound value.
|
||||
"""
|
||||
super().__init__()
|
||||
self.half_channels = in_channels // 2
|
||||
self.hidden_channels = hidden_channels
|
||||
self.bins = bins
|
||||
self.tail_bound = tail_bound
|
||||
|
||||
self.input_conv = nn.Conv1D(
|
||||
self.half_channels,
|
||||
hidden_channels,
|
||||
1, )
|
||||
self.dds_conv = DilatedDepthSeparableConv(
|
||||
hidden_channels,
|
||||
kernel_size,
|
||||
layers,
|
||||
dropout_rate=0.0, )
|
||||
self.proj = nn.Conv1D(
|
||||
hidden_channels,
|
||||
self.half_channels * (bins * 3 - 1),
|
||||
1, )
|
||||
|
||||
# self.proj.weight.data.zero_()
|
||||
# self.proj.bias.data.zero_()
|
||||
|
||||
weight = paddle.zeros(paddle.shape(self.proj.weight))
|
||||
|
||||
self.proj.weight = paddle.create_parameter(
|
||||
shape=weight.shape,
|
||||
dtype=str(weight.numpy().dtype),
|
||||
default_initializer=paddle.nn.initializer.Assign(weight))
|
||||
|
||||
bias = paddle.zeros(paddle.shape(self.proj.bias))
|
||||
|
||||
self.proj.bias = paddle.create_parameter(
|
||||
shape=bias.shape,
|
||||
dtype=str(bias.numpy().dtype),
|
||||
default_initializer=paddle.nn.initializer.Assign(bias))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
g: Optional[paddle.Tensor]=None,
|
||||
inverse: bool=False,
|
||||
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, channels, T).
|
||||
x_mask (Tensor): Mask tensor (B, 1, T).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
Returns:
|
||||
Tensor: Output tensor (B, channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
"""
|
||||
xa, xb = x.split(2, 1)
|
||||
h = self.input_conv(xa)
|
||||
h = self.dds_conv(h, x_mask, g=g)
|
||||
# (B, half_channels * (bins * 3 - 1), T)
|
||||
h = self.proj(h) * x_mask
|
||||
|
||||
b, c, t = xa.shape
|
||||
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
|
||||
h = h.reshape([b, c, -1, t]).transpose([0, 1, 3, 2])
|
||||
|
||||
denom = math.sqrt(self.hidden_channels)
|
||||
unnorm_widths = h[..., :self.bins] / denom
|
||||
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
|
||||
unnorm_derivatives = h[..., 2 * self.bins:]
|
||||
xb, logdet_abs = piecewise_rational_quadratic_transform(
|
||||
xb,
|
||||
unnorm_widths,
|
||||
unnorm_heights,
|
||||
unnorm_derivatives,
|
||||
inverse=inverse,
|
||||
tails="linear",
|
||||
tail_bound=self.tail_bound, )
|
||||
x = paddle.concat([xa, xb], 1) * x_mask
|
||||
logdet = paddle.sum(logdet_abs * x_mask, [1, 2])
|
||||
if not inverse:
|
||||
return x, logdet
|
||||
else:
|
||||
return x
|
@ -0,0 +1,551 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Generator module in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.models.hifigan import HiFiGANGenerator
|
||||
from paddlespeech.t2s.models.vits.duration_predictor import StochasticDurationPredictor
|
||||
from paddlespeech.t2s.models.vits.posterior_encoder import PosteriorEncoder
|
||||
from paddlespeech.t2s.models.vits.residual_coupling import ResidualAffineCouplingBlock
|
||||
from paddlespeech.t2s.models.vits.text_encoder import TextEncoder
|
||||
from paddlespeech.t2s.modules.nets_utils import get_random_segments
|
||||
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
||||
|
||||
|
||||
class VITSGenerator(nn.Layer):
|
||||
"""Generator module in VITS.
|
||||
This is a module of VITS described in `Conditional Variational Autoencoder
|
||||
with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
As text encoder, we use conformer architecture instead of the relative positional
|
||||
Transformer, which contains additional convolution layers.
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocabs: int,
|
||||
aux_channels: int=513,
|
||||
hidden_channels: int=192,
|
||||
spks: Optional[int]=None,
|
||||
langs: Optional[int]=None,
|
||||
spk_embed_dim: Optional[int]=None,
|
||||
global_channels: int=-1,
|
||||
segment_size: int=32,
|
||||
text_encoder_attention_heads: int=2,
|
||||
text_encoder_ffn_expand: int=4,
|
||||
text_encoder_blocks: int=6,
|
||||
text_encoder_positionwise_layer_type: str="conv1d",
|
||||
text_encoder_positionwise_conv_kernel_size: int=1,
|
||||
text_encoder_positional_encoding_layer_type: str="rel_pos",
|
||||
text_encoder_self_attention_layer_type: str="rel_selfattn",
|
||||
text_encoder_activation_type: str="swish",
|
||||
text_encoder_normalize_before: bool=True,
|
||||
text_encoder_dropout_rate: float=0.1,
|
||||
text_encoder_positional_dropout_rate: float=0.0,
|
||||
text_encoder_attention_dropout_rate: float=0.0,
|
||||
text_encoder_conformer_kernel_size: int=7,
|
||||
use_macaron_style_in_text_encoder: bool=True,
|
||||
use_conformer_conv_in_text_encoder: bool=True,
|
||||
decoder_kernel_size: int=7,
|
||||
decoder_channels: int=512,
|
||||
decoder_upsample_scales: List[int]=[8, 8, 2, 2],
|
||||
decoder_upsample_kernel_sizes: List[int]=[16, 16, 4, 4],
|
||||
decoder_resblock_kernel_sizes: List[int]=[3, 7, 11],
|
||||
decoder_resblock_dilations: List[List[int]]=[[1, 3, 5], [1, 3, 5],
|
||||
[1, 3, 5]],
|
||||
use_weight_norm_in_decoder: bool=True,
|
||||
posterior_encoder_kernel_size: int=5,
|
||||
posterior_encoder_layers: int=16,
|
||||
posterior_encoder_stacks: int=1,
|
||||
posterior_encoder_base_dilation: int=1,
|
||||
posterior_encoder_dropout_rate: float=0.0,
|
||||
use_weight_norm_in_posterior_encoder: bool=True,
|
||||
flow_flows: int=4,
|
||||
flow_kernel_size: int=5,
|
||||
flow_base_dilation: int=1,
|
||||
flow_layers: int=4,
|
||||
flow_dropout_rate: float=0.0,
|
||||
use_weight_norm_in_flow: bool=True,
|
||||
use_only_mean_in_flow: bool=True,
|
||||
stochastic_duration_predictor_kernel_size: int=3,
|
||||
stochastic_duration_predictor_dropout_rate: float=0.5,
|
||||
stochastic_duration_predictor_flows: int=4,
|
||||
stochastic_duration_predictor_dds_conv_layers: int=3, ):
|
||||
"""Initialize VITS generator module.
|
||||
Args:
|
||||
vocabs (int): Input vocabulary size.
|
||||
aux_channels (int): Number of acoustic feature channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
|
||||
sids will be provided as the input and use sid embedding layer.
|
||||
langs (Optional[int]): Number of languages. If set to > 1, assume that the
|
||||
lids will be provided as the input and use sid embedding layer.
|
||||
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
|
||||
assume that spembs will be provided as the input.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
segment_size (int): Segment size for decoder.
|
||||
text_encoder_attention_heads (int): Number of heads in conformer block
|
||||
of text encoder.
|
||||
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
|
||||
of text encoder.
|
||||
text_encoder_blocks (int): Number of conformer blocks in text encoder.
|
||||
text_encoder_positionwise_layer_type (str): Position-wise layer type in
|
||||
conformer block of text encoder.
|
||||
text_encoder_positionwise_conv_kernel_size (int): Position-wise convolution
|
||||
kernel size in conformer block of text encoder. Only used when the
|
||||
above layer type is conv1d or conv1d-linear.
|
||||
text_encoder_positional_encoding_layer_type (str): Positional encoding layer
|
||||
type in conformer block of text encoder.
|
||||
text_encoder_self_attention_layer_type (str): Self-attention layer type in
|
||||
conformer block of text encoder.
|
||||
text_encoder_activation_type (str): Activation function type in conformer
|
||||
block of text encoder.
|
||||
text_encoder_normalize_before (bool): Whether to apply layer norm before
|
||||
self-attention in conformer block of text encoder.
|
||||
text_encoder_dropout_rate (float): Dropout rate in conformer block of
|
||||
text encoder.
|
||||
text_encoder_positional_dropout_rate (float): Dropout rate for positional
|
||||
encoding in conformer block of text encoder.
|
||||
text_encoder_attention_dropout_rate (float): Dropout rate for attention in
|
||||
conformer block of text encoder.
|
||||
text_encoder_conformer_kernel_size (int): Conformer conv kernel size. It
|
||||
will be used when only use_conformer_conv_in_text_encoder = True.
|
||||
use_macaron_style_in_text_encoder (bool): Whether to use macaron style FFN
|
||||
in conformer block of text encoder.
|
||||
use_conformer_conv_in_text_encoder (bool): Whether to use covolution in
|
||||
conformer block of text encoder.
|
||||
decoder_kernel_size (int): Decoder kernel size.
|
||||
decoder_channels (int): Number of decoder initial channels.
|
||||
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
|
||||
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
|
||||
upsampling layers in decoder.
|
||||
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
|
||||
in decoder.
|
||||
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
|
||||
resblocks in decoder.
|
||||
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
|
||||
decoder.
|
||||
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
|
||||
posterior_encoder_layers (int): Number of layers of posterior encoder.
|
||||
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
|
||||
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
|
||||
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
|
||||
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
|
||||
normalization in posterior encoder.
|
||||
flow_flows (int): Number of flows in flow.
|
||||
flow_kernel_size (int): Kernel size in flow.
|
||||
flow_base_dilation (int): Base dilation in flow.
|
||||
flow_layers (int): Number of layers in flow.
|
||||
flow_dropout_rate (float): Dropout rate in flow
|
||||
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
|
||||
flow.
|
||||
use_only_mean_in_flow (bool): Whether to use only mean in flow.
|
||||
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
|
||||
duration predictor.
|
||||
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
|
||||
stochastic duration predictor.
|
||||
stochastic_duration_predictor_flows (int): Number of flows in stochastic
|
||||
duration predictor.
|
||||
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
|
||||
layers in stochastic duration predictor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.segment_size = segment_size
|
||||
self.text_encoder = TextEncoder(
|
||||
vocabs=vocabs,
|
||||
attention_dim=hidden_channels,
|
||||
attention_heads=text_encoder_attention_heads,
|
||||
linear_units=hidden_channels * text_encoder_ffn_expand,
|
||||
blocks=text_encoder_blocks,
|
||||
positionwise_layer_type=text_encoder_positionwise_layer_type,
|
||||
positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size,
|
||||
positional_encoding_layer_type=text_encoder_positional_encoding_layer_type,
|
||||
self_attention_layer_type=text_encoder_self_attention_layer_type,
|
||||
activation_type=text_encoder_activation_type,
|
||||
normalize_before=text_encoder_normalize_before,
|
||||
dropout_rate=text_encoder_dropout_rate,
|
||||
positional_dropout_rate=text_encoder_positional_dropout_rate,
|
||||
attention_dropout_rate=text_encoder_attention_dropout_rate,
|
||||
conformer_kernel_size=text_encoder_conformer_kernel_size,
|
||||
use_macaron_style=use_macaron_style_in_text_encoder,
|
||||
use_conformer_conv=use_conformer_conv_in_text_encoder, )
|
||||
self.decoder = HiFiGANGenerator(
|
||||
in_channels=hidden_channels,
|
||||
out_channels=1,
|
||||
channels=decoder_channels,
|
||||
global_channels=global_channels,
|
||||
kernel_size=decoder_kernel_size,
|
||||
upsample_scales=decoder_upsample_scales,
|
||||
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
|
||||
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
|
||||
resblock_dilations=decoder_resblock_dilations,
|
||||
use_weight_norm=use_weight_norm_in_decoder, )
|
||||
self.posterior_encoder = PosteriorEncoder(
|
||||
in_channels=aux_channels,
|
||||
out_channels=hidden_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
kernel_size=posterior_encoder_kernel_size,
|
||||
layers=posterior_encoder_layers,
|
||||
stacks=posterior_encoder_stacks,
|
||||
base_dilation=posterior_encoder_base_dilation,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=posterior_encoder_dropout_rate,
|
||||
use_weight_norm=use_weight_norm_in_posterior_encoder, )
|
||||
self.flow = ResidualAffineCouplingBlock(
|
||||
in_channels=hidden_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
flows=flow_flows,
|
||||
kernel_size=flow_kernel_size,
|
||||
base_dilation=flow_base_dilation,
|
||||
layers=flow_layers,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=flow_dropout_rate,
|
||||
use_weight_norm=use_weight_norm_in_flow,
|
||||
use_only_mean=use_only_mean_in_flow, )
|
||||
# TODO: Add deterministic version as an option
|
||||
self.duration_predictor = StochasticDurationPredictor(
|
||||
channels=hidden_channels,
|
||||
kernel_size=stochastic_duration_predictor_kernel_size,
|
||||
dropout_rate=stochastic_duration_predictor_dropout_rate,
|
||||
flows=stochastic_duration_predictor_flows,
|
||||
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
|
||||
global_channels=global_channels, )
|
||||
|
||||
self.upsample_factor = int(np.prod(decoder_upsample_scales))
|
||||
self.spks = None
|
||||
if spks is not None and spks > 1:
|
||||
assert global_channels > 0
|
||||
self.spks = spks
|
||||
self.global_emb = nn.Embedding(spks, global_channels)
|
||||
self.spk_embed_dim = None
|
||||
if spk_embed_dim is not None and spk_embed_dim > 0:
|
||||
assert global_channels > 0
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
self.spemb_proj = nn.Linear(spk_embed_dim, global_channels)
|
||||
self.langs = None
|
||||
if langs is not None and langs > 1:
|
||||
assert global_channels > 0
|
||||
self.langs = langs
|
||||
self.lang_emb = nn.Embedding(langs, global_channels)
|
||||
|
||||
# delayed import
|
||||
from paddlespeech.t2s.models.vits.monotonic_align import maximum_path
|
||||
|
||||
self.maximum_path = maximum_path
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_lengths: paddle.Tensor,
|
||||
feats: paddle.Tensor,
|
||||
feats_lengths: paddle.Tensor,
|
||||
sids: Optional[paddle.Tensor]=None,
|
||||
spembs: Optional[paddle.Tensor]=None,
|
||||
lids: Optional[paddle.Tensor]=None,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
|
||||
paddle.Tensor, paddle.Tensor,
|
||||
Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
|
||||
paddle.Tensor, paddle.Tensor, ], ]:
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
text (Tensor): Text index tensor (B, T_text).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
Returns:
|
||||
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
|
||||
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
|
||||
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
|
||||
Tensor: Segments start index tensor (B,).
|
||||
Tensor: Text mask tensor (B, 1, T_text).
|
||||
Tensor: Feature mask tensor (B, 1, T_feats).
|
||||
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
|
||||
- Tensor: Flow hidden representation (B, H, T_feats).
|
||||
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
|
||||
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
|
||||
- Tensor: Posterior encoder projected mean (B, H, T_feats).
|
||||
- Tensor: Posterior encoder projected scale (B, H, T_feats).
|
||||
"""
|
||||
# forward text encoder
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||
|
||||
# calculate global conditioning
|
||||
g = None
|
||||
if self.spks is not None:
|
||||
# speaker one-hot vector embedding: (B, global_channels, 1)
|
||||
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
|
||||
if self.spk_embed_dim is not None:
|
||||
# pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
|
||||
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
if self.langs is not None:
|
||||
# language one-hot vector embedding: (B, global_channels, 1)
|
||||
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
|
||||
# forward posterior encoder
|
||||
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(
|
||||
feats, feats_lengths, g=g)
|
||||
|
||||
# forward flow
|
||||
# (B, H, T_feats)
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
# monotonic alignment search
|
||||
with paddle.no_grad():
|
||||
# negative cross-entropy
|
||||
# (B, H, T_text)
|
||||
s_p_sq_r = paddle.exp(-2 * logs_p)
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_1 = paddle.sum(
|
||||
-0.5 * math.log(2 * math.pi) - logs_p,
|
||||
[1],
|
||||
keepdim=True, )
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_2 = paddle.matmul(
|
||||
-0.5 * (z_p**2).transpose([0, 2, 1]),
|
||||
s_p_sq_r, )
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_3 = paddle.matmul(
|
||||
z_p.transpose([0, 2, 1]),
|
||||
(m_p * s_p_sq_r), )
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_4 = paddle.sum(
|
||||
-0.5 * (m_p**2) * s_p_sq_r,
|
||||
[1],
|
||||
keepdim=True, )
|
||||
# (B, T_feats, T_text)
|
||||
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
|
||||
# (B, 1, T_feats, T_text)
|
||||
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
|
||||
-1)
|
||||
# monotonic attention weight: (B, 1, T_feats, T_text)
|
||||
attn = (self.maximum_path(
|
||||
neg_x_ent,
|
||||
attn_mask.squeeze(1), ).unsqueeze(1).detach())
|
||||
|
||||
# forward duration predictor
|
||||
# (B, 1, T_text)
|
||||
w = attn.sum(2)
|
||||
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
|
||||
dur_nll = dur_nll / paddle.sum(x_mask)
|
||||
|
||||
# expand the length to match with the feature sequence
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
m_p = paddle.matmul(attn.squeeze(1),
|
||||
m_p.transpose([0, 2, 1])).transpose([0, 2, 1])
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
logs_p = paddle.matmul(attn.squeeze(1),
|
||||
logs_p.transpose([0, 2, 1])).transpose([0, 2, 1])
|
||||
|
||||
# get random segments
|
||||
z_segments, z_start_idxs = get_random_segments(
|
||||
z,
|
||||
feats_lengths,
|
||||
self.segment_size, )
|
||||
|
||||
# forward decoder with random segments
|
||||
wav = self.decoder(z_segments, g=g)
|
||||
|
||||
return (wav, dur_nll, attn, z_start_idxs, x_mask, y_mask,
|
||||
(z, z_p, m_p, logs_p, m_q, logs_q), )
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_lengths: paddle.Tensor,
|
||||
feats: Optional[paddle.Tensor]=None,
|
||||
feats_lengths: Optional[paddle.Tensor]=None,
|
||||
sids: Optional[paddle.Tensor]=None,
|
||||
spembs: Optional[paddle.Tensor]=None,
|
||||
lids: Optional[paddle.Tensor]=None,
|
||||
dur: Optional[paddle.Tensor]=None,
|
||||
noise_scale: float=0.667,
|
||||
noise_scale_dur: float=0.8,
|
||||
alpha: float=1.0,
|
||||
max_len: Optional[int]=None,
|
||||
use_teacher_forcing: bool=False,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||
"""Run inference.
|
||||
Args:
|
||||
text (Tensor): Input text index tensor (B, T_text,).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
|
||||
skip the prediction of durations (i.e., teacher forcing).
|
||||
noise_scale (float): Noise scale parameter for flow.
|
||||
noise_scale_dur (float): Noise scale parameter for duration predictor.
|
||||
alpha (float): Alpha parameter to control the speed of generated speech.
|
||||
max_len (Optional[int]): Maximum length of acoustic feature sequence.
|
||||
use_teacher_forcing (bool): Whether to use teacher forcing.
|
||||
Returns:
|
||||
Tensor: Generated waveform tensor (B, T_wav).
|
||||
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
|
||||
Tensor: Duration tensor (B, T_text).
|
||||
"""
|
||||
# encoder
|
||||
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
|
||||
g = None
|
||||
if self.spks is not None:
|
||||
# (B, global_channels, 1)
|
||||
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
|
||||
if self.spk_embed_dim is not None:
|
||||
# (B, global_channels, 1)
|
||||
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
if self.langs is not None:
|
||||
# (B, global_channels, 1)
|
||||
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
|
||||
if g is None:
|
||||
g = g_
|
||||
else:
|
||||
g = g + g_
|
||||
|
||||
if use_teacher_forcing:
|
||||
# forward posterior encoder
|
||||
z, m_q, logs_q, y_mask = self.posterior_encoder(
|
||||
feats, feats_lengths, g=g)
|
||||
|
||||
# forward flow
|
||||
# (B, H, T_feats)
|
||||
z_p = self.flow(z, y_mask, g=g)
|
||||
|
||||
# monotonic alignment search
|
||||
# (B, H, T_text)
|
||||
s_p_sq_r = paddle.exp(-2 * logs_p)
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_1 = paddle.sum(
|
||||
-0.5 * math.log(2 * math.pi) - logs_p,
|
||||
[1],
|
||||
keepdim=True, )
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_2 = paddle.matmul(
|
||||
-0.5 * (z_p**2).transpose([0, 2, 1]),
|
||||
s_p_sq_r, )
|
||||
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
|
||||
neg_x_ent_3 = paddle.matmul(
|
||||
z_p.transpose([0, 2, 1]),
|
||||
(m_p * s_p_sq_r), )
|
||||
# (B, 1, T_text)
|
||||
neg_x_ent_4 = paddle.sum(
|
||||
-0.5 * (m_p**2) * s_p_sq_r,
|
||||
[1],
|
||||
keepdim=True, )
|
||||
# (B, T_feats, T_text)
|
||||
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
|
||||
# (B, 1, T_feats, T_text)
|
||||
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
|
||||
-1)
|
||||
# monotonic attention weight: (B, 1, T_feats, T_text)
|
||||
attn = self.maximum_path(
|
||||
neg_x_ent,
|
||||
attn_mask.squeeze(1), ).unsqueeze(1)
|
||||
# (B, 1, T_text)
|
||||
dur = attn.sum(2)
|
||||
|
||||
# forward decoder with random segments
|
||||
wav = self.decoder(z * y_mask, g=g)
|
||||
else:
|
||||
# duration
|
||||
if dur is None:
|
||||
logw = self.duration_predictor(
|
||||
x,
|
||||
x_mask,
|
||||
g=g,
|
||||
inverse=True,
|
||||
noise_scale=noise_scale_dur, )
|
||||
w = paddle.exp(logw) * x_mask * alpha
|
||||
dur = paddle.ceil(w)
|
||||
y_lengths = paddle.cast(
|
||||
paddle.clip(paddle.sum(dur, [1, 2]), min=1), dtype='int64')
|
||||
y_mask = make_non_pad_mask(y_lengths).unsqueeze(1)
|
||||
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
|
||||
-1)
|
||||
attn = self._generate_path(dur, attn_mask)
|
||||
|
||||
# expand the length to match with the feature sequence
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
m_p = paddle.matmul(
|
||||
attn.squeeze(1),
|
||||
m_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
|
||||
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
|
||||
logs_p = paddle.matmul(
|
||||
attn.squeeze(1),
|
||||
logs_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
|
||||
|
||||
# decoder
|
||||
z_p = m_p + paddle.randn(
|
||||
paddle.shape(m_p)) * paddle.exp(logs_p) * noise_scale
|
||||
z = self.flow(z_p, y_mask, g=g, inverse=True)
|
||||
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
|
||||
|
||||
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
|
||||
|
||||
def _generate_path(self, dur: paddle.Tensor,
|
||||
mask: paddle.Tensor) -> paddle.Tensor:
|
||||
"""Generate path a.k.a. monotonic attention.
|
||||
Args:
|
||||
dur (Tensor): Duration tensor (B, 1, T_text).
|
||||
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
|
||||
Returns:
|
||||
Tensor: Path tensor (B, 1, T_feats, T_text).
|
||||
"""
|
||||
b, _, t_y, t_x = paddle.shape(mask)
|
||||
cum_dur = paddle.cumsum(dur, -1)
|
||||
cum_dur_flat = paddle.reshape(cum_dur, [b * t_x])
|
||||
|
||||
path = paddle.arange(t_y, dtype=dur.dtype)
|
||||
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
|
||||
path = paddle.reshape(path, [b, t_x, t_y])
|
||||
'''
|
||||
path will be like (t_x = 3, t_y = 5):
|
||||
[[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
|
||||
[1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
|
||||
[1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
|
||||
'''
|
||||
|
||||
path = paddle.cast(path, dtype='float32')
|
||||
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
|
||||
return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask
|
@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Maximum path calculation module.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from numba import njit
|
||||
from numba import prange
|
||||
|
||||
try:
|
||||
from .core import maximum_path_c
|
||||
|
||||
is_cython_avalable = True
|
||||
except ImportError:
|
||||
is_cython_avalable = False
|
||||
warnings.warn(
|
||||
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
|
||||
"If you want to use the cython version, please build it as follows: "
|
||||
"`cd paddlespeech/t2s/models/vits/monotonic_align; python setup.py build_ext --inplace`"
|
||||
)
|
||||
|
||||
|
||||
def maximum_path(neg_x_ent: paddle.Tensor,
|
||||
attn_mask: paddle.Tensor) -> paddle.Tensor:
|
||||
"""Calculate maximum path.
|
||||
|
||||
Args:
|
||||
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
|
||||
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
|
||||
|
||||
Returns:
|
||||
Tensor: Maximum path tensor (B, T_feats, T_text).
|
||||
|
||||
"""
|
||||
dtype = neg_x_ent.dtype
|
||||
neg_x_ent = neg_x_ent.numpy().astype(np.float32)
|
||||
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
|
||||
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
|
||||
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
|
||||
if is_cython_avalable:
|
||||
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
|
||||
else:
|
||||
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
|
||||
|
||||
return paddle.cast(paddle.to_tensor(path), dtype=dtype)
|
||||
|
||||
|
||||
@njit
|
||||
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
|
||||
"""Calculate a single maximum path with numba."""
|
||||
index = t_x - 1
|
||||
for y in range(t_y):
|
||||
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
||||
if x == y:
|
||||
v_cur = max_neg_val
|
||||
else:
|
||||
v_cur = value[y - 1, x]
|
||||
if x == 0:
|
||||
if y == 0:
|
||||
v_prev = 0.0
|
||||
else:
|
||||
v_prev = max_neg_val
|
||||
else:
|
||||
v_prev = value[y - 1, x - 1]
|
||||
value[y, x] += max(v_prev, v_cur)
|
||||
|
||||
for y in range(t_y - 1, -1, -1):
|
||||
path[y, index] = 1
|
||||
if index != 0 and (index == y or
|
||||
value[y - 1, index] < value[y - 1, index - 1]):
|
||||
index = index - 1
|
||||
|
||||
|
||||
@njit(parallel=True)
|
||||
def maximum_path_numba(paths, values, t_ys, t_xs):
|
||||
"""Calculate batch maximum path with numba."""
|
||||
for i in prange(paths.shape[0]):
|
||||
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])
|
@ -0,0 +1,62 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Maximum path calculation module with cython optimization.
|
||||
|
||||
This code is copied from https://github.com/jaywalnut310/vits and modifed code format.
|
||||
|
||||
"""
|
||||
|
||||
cimport cython
|
||||
|
||||
from cython.parallel import prange
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
|
||||
cdef int x
|
||||
cdef int y
|
||||
cdef float v_prev
|
||||
cdef float v_cur
|
||||
cdef float tmp
|
||||
cdef int index = t_x - 1
|
||||
|
||||
for y in range(t_y):
|
||||
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
|
||||
if x == y:
|
||||
v_cur = max_neg_val
|
||||
else:
|
||||
v_cur = value[y - 1, x]
|
||||
if x == 0:
|
||||
if y == 0:
|
||||
v_prev = 0.0
|
||||
else:
|
||||
v_prev = max_neg_val
|
||||
else:
|
||||
v_prev = value[y - 1, x - 1]
|
||||
value[y, x] += max(v_prev, v_cur)
|
||||
|
||||
for y in range(t_y - 1, -1, -1):
|
||||
path[y, index] = 1
|
||||
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
|
||||
index = index - 1
|
||||
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
|
||||
cdef int b = paths.shape[0]
|
||||
cdef int i
|
||||
for i in prange(b, nogil=True):
|
||||
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
|
@ -0,0 +1,39 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Setup cython code."""
|
||||
from Cython.Build import cythonize
|
||||
from setuptools import Extension
|
||||
from setuptools import setup
|
||||
from setuptools.command.build_ext import build_ext as _build_ext
|
||||
|
||||
|
||||
class build_ext(_build_ext):
|
||||
"""Overwrite build_ext."""
|
||||
|
||||
def finalize_options(self):
|
||||
"""Prevent numpy from thinking it is still in its setup process."""
|
||||
_build_ext.finalize_options(self)
|
||||
__builtins__.__NUMPY_SETUP__ = False
|
||||
import numpy
|
||||
|
||||
self.include_dirs.append(numpy.get_include())
|
||||
|
||||
|
||||
exts = [Extension(
|
||||
name="core",
|
||||
sources=["core.pyx"], )]
|
||||
setup(
|
||||
name="monotonic_align",
|
||||
ext_modules=cythonize(exts, language_level=3),
|
||||
cmdclass={"build_ext": build_ext}, )
|
@ -0,0 +1,120 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Text encoder module in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
|
||||
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
||||
|
||||
|
||||
class PosteriorEncoder(nn.Layer):
|
||||
"""Posterior encoder module in VITS.
|
||||
|
||||
This is a module of posterior encoder described in `Conditional Variational
|
||||
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int=513,
|
||||
out_channels: int=192,
|
||||
hidden_channels: int=192,
|
||||
kernel_size: int=5,
|
||||
layers: int=16,
|
||||
stacks: int=1,
|
||||
base_dilation: int=1,
|
||||
global_channels: int=-1,
|
||||
dropout_rate: float=0.0,
|
||||
bias: bool=True,
|
||||
use_weight_norm: bool=True, ):
|
||||
"""Initilialize PosteriorEncoder module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size in WaveNet.
|
||||
layers (int): Number of layers of WaveNet.
|
||||
stacks (int): Number of repeat stacking of WaveNet.
|
||||
base_dilation (int): Base dilation factor.
|
||||
global_channels (int): Number of global conditioning channels.
|
||||
dropout_rate (float): Dropout rate.
|
||||
bias (bool): Whether to use bias parameters in conv.
|
||||
use_weight_norm (bool): Whether to apply weight norm.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# define modules
|
||||
self.input_conv = nn.Conv1D(in_channels, hidden_channels, 1)
|
||||
self.encoder = WaveNet(
|
||||
in_channels=-1,
|
||||
out_channels=-1,
|
||||
kernel_size=kernel_size,
|
||||
layers=layers,
|
||||
stacks=stacks,
|
||||
base_dilation=base_dilation,
|
||||
residual_channels=hidden_channels,
|
||||
aux_channels=-1,
|
||||
gate_channels=hidden_channels * 2,
|
||||
skip_channels=hidden_channels,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=dropout_rate,
|
||||
bias=bias,
|
||||
use_weight_norm=use_weight_norm,
|
||||
use_first_conv=False,
|
||||
use_last_conv=False,
|
||||
scale_residual=False,
|
||||
scale_skip_connect=True, )
|
||||
self.proj = nn.Conv1D(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_lengths: paddle.Tensor,
|
||||
g: Optional[paddle.Tensor]=None
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T_feats).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
|
||||
Tensor: Projected mean tensor (B, out_channels, T_feats).
|
||||
Tensor: Projected scale tensor (B, out_channels, T_feats).
|
||||
Tensor: Mask tensor for input tensor (B, 1, T_feats).
|
||||
|
||||
"""
|
||||
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
|
||||
x = self.input_conv(x) * x_mask
|
||||
x = self.encoder(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
m, logs = paddle.split(stats, 2, axis=1)
|
||||
z = (m + paddle.randn(paddle.shape(m)) * paddle.exp(logs)) * x_mask
|
||||
|
||||
return z, m, logs, x_mask
|
@ -0,0 +1,244 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Residual affine coupling modules in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.models.vits.flow import FlipFlow
|
||||
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
|
||||
|
||||
|
||||
class ResidualAffineCouplingBlock(nn.Layer):
|
||||
"""Residual affine coupling block module.
|
||||
|
||||
This is a module of residual affine coupling block, which used as "Flow" in
|
||||
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`_.
|
||||
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int=192,
|
||||
hidden_channels: int=192,
|
||||
flows: int=4,
|
||||
kernel_size: int=5,
|
||||
base_dilation: int=1,
|
||||
layers: int=4,
|
||||
global_channels: int=-1,
|
||||
dropout_rate: float=0.0,
|
||||
use_weight_norm: bool=True,
|
||||
bias: bool=True,
|
||||
use_only_mean: bool=True, ):
|
||||
"""Initilize ResidualAffineCouplingBlock module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
flows (int): Number of flows.
|
||||
kernel_size (int): Kernel size for WaveNet.
|
||||
base_dilation (int): Base dilation factor for WaveNet.
|
||||
layers (int): Number of layers of WaveNet.
|
||||
stacks (int): Number of stacks of WaveNet.
|
||||
global_channels (int): Number of global channels.
|
||||
dropout_rate (float): Dropout rate.
|
||||
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
|
||||
bias (bool): Whether to use bias paramters in WaveNet.
|
||||
use_only_mean (bool): Whether to estimate only mean.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.flows = nn.LayerList()
|
||||
for i in range(flows):
|
||||
self.flows.append(
|
||||
ResidualAffineCouplingLayer(
|
||||
in_channels=in_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
kernel_size=kernel_size,
|
||||
base_dilation=base_dilation,
|
||||
layers=layers,
|
||||
stacks=1,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=dropout_rate,
|
||||
use_weight_norm=use_weight_norm,
|
||||
bias=bias,
|
||||
use_only_mean=use_only_mean, ))
|
||||
self.flows.append(FlipFlow())
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
g: Optional[paddle.Tensor]=None,
|
||||
inverse: bool=False, ) -> paddle.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T).
|
||||
x_mask (Tensor): Length tensor (B, 1, T).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, in_channels, T).
|
||||
|
||||
"""
|
||||
if not inverse:
|
||||
for flow in self.flows:
|
||||
x, _ = flow(x, x_mask, g=g, inverse=inverse)
|
||||
else:
|
||||
for flow in reversed(self.flows):
|
||||
x = flow(x, x_mask, g=g, inverse=inverse)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAffineCouplingLayer(nn.Layer):
|
||||
"""Residual affine coupling layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int=192,
|
||||
hidden_channels: int=192,
|
||||
kernel_size: int=5,
|
||||
base_dilation: int=1,
|
||||
layers: int=5,
|
||||
stacks: int=1,
|
||||
global_channels: int=-1,
|
||||
dropout_rate: float=0.0,
|
||||
use_weight_norm: bool=True,
|
||||
bias: bool=True,
|
||||
use_only_mean: bool=True, ):
|
||||
"""Initialzie ResidualAffineCouplingLayer module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
hidden_channels (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size for WaveNet.
|
||||
base_dilation (int): Base dilation factor for WaveNet.
|
||||
layers (int): Number of layers of WaveNet.
|
||||
stacks (int): Number of stacks of WaveNet.
|
||||
global_channels (int): Number of global channels.
|
||||
dropout_rate (float): Dropout rate.
|
||||
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
|
||||
bias (bool): Whether to use bias paramters in WaveNet.
|
||||
use_only_mean (bool): Whether to estimate only mean.
|
||||
|
||||
"""
|
||||
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
|
||||
super().__init__()
|
||||
self.half_channels = in_channels // 2
|
||||
self.use_only_mean = use_only_mean
|
||||
|
||||
# define modules
|
||||
self.input_conv = nn.Conv1D(
|
||||
self.half_channels,
|
||||
hidden_channels,
|
||||
1, )
|
||||
self.encoder = WaveNet(
|
||||
in_channels=-1,
|
||||
out_channels=-1,
|
||||
kernel_size=kernel_size,
|
||||
layers=layers,
|
||||
stacks=stacks,
|
||||
base_dilation=base_dilation,
|
||||
residual_channels=hidden_channels,
|
||||
aux_channels=-1,
|
||||
gate_channels=hidden_channels * 2,
|
||||
skip_channels=hidden_channels,
|
||||
global_channels=global_channels,
|
||||
dropout_rate=dropout_rate,
|
||||
bias=bias,
|
||||
use_weight_norm=use_weight_norm,
|
||||
use_first_conv=False,
|
||||
use_last_conv=False,
|
||||
scale_residual=False,
|
||||
scale_skip_connect=True, )
|
||||
if use_only_mean:
|
||||
self.proj = nn.Conv1D(
|
||||
hidden_channels,
|
||||
self.half_channels,
|
||||
1, )
|
||||
else:
|
||||
self.proj = nn.Conv1D(
|
||||
hidden_channels,
|
||||
self.half_channels * 2,
|
||||
1, )
|
||||
# self.proj.weight.data.zero_()
|
||||
# self.proj.bias.data.zero_()
|
||||
|
||||
weight = paddle.zeros(paddle.shape(self.proj.weight))
|
||||
|
||||
self.proj.weight = paddle.create_parameter(
|
||||
shape=weight.shape,
|
||||
dtype=str(weight.numpy().dtype),
|
||||
default_initializer=paddle.nn.initializer.Assign(weight))
|
||||
|
||||
bias = paddle.zeros(paddle.shape(self.proj.bias))
|
||||
|
||||
self.proj.bias = paddle.create_parameter(
|
||||
shape=bias.shape,
|
||||
dtype=str(bias.numpy().dtype),
|
||||
default_initializer=paddle.nn.initializer.Assign(bias))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: paddle.Tensor,
|
||||
g: Optional[paddle.Tensor]=None,
|
||||
inverse: bool=False,
|
||||
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
inverse (bool): Whether to inverse the flow.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, in_channels, T).
|
||||
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
|
||||
|
||||
"""
|
||||
xa, xb = paddle.split(x, 2, axis=1)
|
||||
h = self.input_conv(xa) * x_mask
|
||||
h = self.encoder(h, x_mask, g=g)
|
||||
stats = self.proj(h) * x_mask
|
||||
if not self.use_only_mean:
|
||||
m, logs = paddle.split(stats, 2, axis=1)
|
||||
else:
|
||||
m = stats
|
||||
logs = paddle.zeros(paddle.shape(m))
|
||||
|
||||
if not inverse:
|
||||
xb = m + xb * paddle.exp(logs) * x_mask
|
||||
x = paddle.concat([xa, xb], 1)
|
||||
logdet = paddle.sum(logs, [1, 2])
|
||||
return x, logdet
|
||||
else:
|
||||
xb = (xb - m) * paddle.exp(-logs) * x_mask
|
||||
x = paddle.concat([xa, xb], 1)
|
||||
return x
|
@ -0,0 +1,145 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Text encoder module in VITS.
|
||||
|
||||
This code is based on https://github.com/jaywalnut310/vits.
|
||||
|
||||
"""
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
||||
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
|
||||
|
||||
|
||||
class TextEncoder(nn.Layer):
|
||||
"""Text encoder module in VITS.
|
||||
|
||||
This is a module of text encoder described in `Conditional Variational Autoencoder
|
||||
with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
|
||||
Instead of the relative positional Transformer, we use conformer architecture as
|
||||
the encoder module, which contains additional convolution layers.
|
||||
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocabs: int,
|
||||
attention_dim: int=192,
|
||||
attention_heads: int=2,
|
||||
linear_units: int=768,
|
||||
blocks: int=6,
|
||||
positionwise_layer_type: str="conv1d",
|
||||
positionwise_conv_kernel_size: int=3,
|
||||
positional_encoding_layer_type: str="rel_pos",
|
||||
self_attention_layer_type: str="rel_selfattn",
|
||||
activation_type: str="swish",
|
||||
normalize_before: bool=True,
|
||||
use_macaron_style: bool=False,
|
||||
use_conformer_conv: bool=False,
|
||||
conformer_kernel_size: int=7,
|
||||
dropout_rate: float=0.1,
|
||||
positional_dropout_rate: float=0.0,
|
||||
attention_dropout_rate: float=0.0, ):
|
||||
"""Initialize TextEncoder module.
|
||||
|
||||
Args:
|
||||
vocabs (int): Vocabulary size.
|
||||
attention_dim (int): Attention dimension.
|
||||
attention_heads (int): Number of attention heads.
|
||||
linear_units (int): Number of linear units of positionwise layers.
|
||||
blocks (int): Number of encoder blocks.
|
||||
positionwise_layer_type (str): Positionwise layer type.
|
||||
positionwise_conv_kernel_size (int): Positionwise layer's kernel size.
|
||||
positional_encoding_layer_type (str): Positional encoding layer type.
|
||||
self_attention_layer_type (str): Self-attention layer type.
|
||||
activation_type (str): Activation function type.
|
||||
normalize_before (bool): Whether to apply LayerNorm before attention.
|
||||
use_macaron_style (bool): Whether to use macaron style components.
|
||||
use_conformer_conv (bool): Whether to use conformer conv layers.
|
||||
conformer_kernel_size (int): Conformer's conv kernel size.
|
||||
dropout_rate (float): Dropout rate.
|
||||
positional_dropout_rate (float): Dropout rate for positional encoding.
|
||||
attention_dropout_rate (float): Dropout rate for attention.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
# store for forward
|
||||
self.attention_dim = attention_dim
|
||||
|
||||
# define modules
|
||||
self.emb = nn.Embedding(vocabs, attention_dim)
|
||||
|
||||
dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
|
||||
w = dist.sample(self.emb.weight.shape)
|
||||
self.emb.weight.set_value(w)
|
||||
|
||||
self.encoder = Encoder(
|
||||
idim=-1,
|
||||
input_layer=None,
|
||||
attention_dim=attention_dim,
|
||||
attention_heads=attention_heads,
|
||||
linear_units=linear_units,
|
||||
num_blocks=blocks,
|
||||
dropout_rate=dropout_rate,
|
||||
positional_dropout_rate=positional_dropout_rate,
|
||||
attention_dropout_rate=attention_dropout_rate,
|
||||
normalize_before=normalize_before,
|
||||
positionwise_layer_type=positionwise_layer_type,
|
||||
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
|
||||
macaron_style=use_macaron_style,
|
||||
pos_enc_layer_type=positional_encoding_layer_type,
|
||||
selfattention_layer_type=self_attention_layer_type,
|
||||
activation_type=activation_type,
|
||||
use_cnn_module=use_conformer_conv,
|
||||
cnn_module_kernel=conformer_kernel_size, )
|
||||
self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_lengths: paddle.Tensor,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input index tensor (B, T_text).
|
||||
x_lengths (Tensor): Length tensor (B,).
|
||||
|
||||
Returns:
|
||||
Tensor: Encoded hidden representation (B, attention_dim, T_text).
|
||||
Tensor: Projected mean tensor (B, attention_dim, T_text).
|
||||
Tensor: Projected scale tensor (B, attention_dim, T_text).
|
||||
Tensor: Mask tensor for input tensor (B, 1, T_text).
|
||||
|
||||
"""
|
||||
x = self.emb(x) * math.sqrt(self.attention_dim)
|
||||
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
|
||||
# encoder assume the channel last (B, T_text, attention_dim)
|
||||
# but mask shape shoud be (B, 1, T_text)
|
||||
x, _ = self.encoder(x, x_mask)
|
||||
|
||||
# convert the channel first (B, attention_dim, T_text)
|
||||
x = paddle.transpose(x, [0, 2, 1])
|
||||
stats = self.proj(x) * x_mask
|
||||
m, logs = paddle.split(stats, 2, axis=1)
|
||||
|
||||
return x, m, logs, x_mask
|
@ -0,0 +1,238 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Flow-related transformation.
|
||||
|
||||
This code is based on https://github.com/bayesiains/nflows.
|
||||
|
||||
"""
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
|
||||
from paddlespeech.t2s.modules.nets_utils import paddle_gather
|
||||
|
||||
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
||||
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
||||
DEFAULT_MIN_DERIVATIVE = 1e-3
|
||||
|
||||
|
||||
def piecewise_rational_quadratic_transform(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails=None,
|
||||
tail_bound=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
|
||||
if tails is None:
|
||||
spline_fn = rational_quadratic_spline
|
||||
spline_kwargs = {}
|
||||
else:
|
||||
spline_fn = unconstrained_rational_quadratic_spline
|
||||
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
|
||||
|
||||
outputs, logabsdet = spline_fn(
|
||||
inputs=inputs,
|
||||
unnormalized_widths=unnormalized_widths,
|
||||
unnormalized_heights=unnormalized_heights,
|
||||
unnormalized_derivatives=unnormalized_derivatives,
|
||||
inverse=inverse,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs)
|
||||
return outputs, logabsdet
|
||||
|
||||
|
||||
def mask_preprocess(x, mask):
|
||||
B, C, T, bins = paddle.shape(x)
|
||||
new_x = paddle.zeros([mask.sum(), bins])
|
||||
for i in range(bins):
|
||||
new_x[:, i] = x[:, :, :, i][mask]
|
||||
return new_x
|
||||
|
||||
|
||||
def unconstrained_rational_quadratic_spline(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
tails="linear",
|
||||
tail_bound=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
|
||||
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
||||
outside_interval_mask = ~inside_interval_mask
|
||||
|
||||
outputs = paddle.zeros(paddle.shape(inputs))
|
||||
logabsdet = paddle.zeros(paddle.shape(inputs))
|
||||
if tails == "linear":
|
||||
unnormalized_derivatives = F.pad(
|
||||
unnormalized_derivatives,
|
||||
pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
|
||||
constant = np.log(np.exp(1 - min_derivative) - 1)
|
||||
unnormalized_derivatives[..., 0] = constant
|
||||
unnormalized_derivatives[..., -1] = constant
|
||||
|
||||
outputs[outside_interval_mask] = inputs[outside_interval_mask]
|
||||
logabsdet[outside_interval_mask] = 0
|
||||
else:
|
||||
raise RuntimeError("{} tails are not implemented.".format(tails))
|
||||
|
||||
unnormalized_widths = mask_preprocess(unnormalized_widths,
|
||||
inside_interval_mask)
|
||||
unnormalized_heights = mask_preprocess(unnormalized_heights,
|
||||
inside_interval_mask)
|
||||
unnormalized_derivatives = mask_preprocess(unnormalized_derivatives,
|
||||
inside_interval_mask)
|
||||
|
||||
(outputs[inside_interval_mask],
|
||||
logabsdet[inside_interval_mask], ) = rational_quadratic_spline(
|
||||
inputs=inputs[inside_interval_mask],
|
||||
unnormalized_widths=unnormalized_widths,
|
||||
unnormalized_heights=unnormalized_heights,
|
||||
unnormalized_derivatives=unnormalized_derivatives,
|
||||
inverse=inverse,
|
||||
left=-tail_bound,
|
||||
right=tail_bound,
|
||||
bottom=-tail_bound,
|
||||
top=tail_bound,
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative, )
|
||||
|
||||
return outputs, logabsdet
|
||||
|
||||
|
||||
def rational_quadratic_spline(
|
||||
inputs,
|
||||
unnormalized_widths,
|
||||
unnormalized_heights,
|
||||
unnormalized_derivatives,
|
||||
inverse=False,
|
||||
left=0.0,
|
||||
right=1.0,
|
||||
bottom=0.0,
|
||||
top=1.0,
|
||||
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
|
||||
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
|
||||
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
|
||||
if paddle.min(inputs) < left or paddle.max(inputs) > right:
|
||||
raise ValueError("Input to a transform is not within its domain")
|
||||
|
||||
num_bins = unnormalized_widths.shape[-1]
|
||||
|
||||
if min_bin_width * num_bins > 1.0:
|
||||
raise ValueError("Minimal bin width too large for the number of bins")
|
||||
if min_bin_height * num_bins > 1.0:
|
||||
raise ValueError("Minimal bin height too large for the number of bins")
|
||||
|
||||
widths = F.softmax(unnormalized_widths, axis=-1)
|
||||
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
||||
cumwidths = paddle.cumsum(widths, axis=-1)
|
||||
cumwidths = F.pad(
|
||||
cumwidths,
|
||||
pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
|
||||
mode="constant",
|
||||
value=0.0)
|
||||
cumwidths = (right - left) * cumwidths + left
|
||||
cumwidths[..., 0] = left
|
||||
cumwidths[..., -1] = right
|
||||
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
|
||||
|
||||
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
|
||||
|
||||
heights = F.softmax(unnormalized_heights, axis=-1)
|
||||
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
|
||||
cumheights = paddle.cumsum(heights, axis=-1)
|
||||
cumheights = F.pad(
|
||||
cumheights,
|
||||
pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
|
||||
mode="constant",
|
||||
value=0.0)
|
||||
cumheights = (top - bottom) * cumheights + bottom
|
||||
cumheights[..., 0] = bottom
|
||||
cumheights[..., -1] = top
|
||||
heights = cumheights[..., 1:] - cumheights[..., :-1]
|
||||
|
||||
if inverse:
|
||||
bin_idx = _searchsorted(cumheights, inputs)[..., None]
|
||||
else:
|
||||
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
|
||||
input_cumwidths = paddle_gather(cumwidths, -1, bin_idx)[..., 0]
|
||||
input_bin_widths = paddle_gather(widths, -1, bin_idx)[..., 0]
|
||||
|
||||
input_cumheights = paddle_gather(cumheights, -1, bin_idx)[..., 0]
|
||||
delta = heights / widths
|
||||
input_delta = paddle_gather(delta, -1, bin_idx)[..., 0]
|
||||
|
||||
input_derivatives = paddle_gather(derivatives, -1, bin_idx)[..., 0]
|
||||
input_derivatives_plus_one = paddle_gather(derivatives[..., 1:], -1,
|
||||
bin_idx)[..., 0]
|
||||
|
||||
input_heights = paddle_gather(heights, -1, bin_idx)[..., 0]
|
||||
|
||||
if inverse:
|
||||
a = (inputs - input_cumheights) * (
|
||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
) + input_heights * (input_delta - input_derivatives)
|
||||
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
|
||||
input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
c = -input_delta * (inputs - input_cumheights)
|
||||
|
||||
discriminant = b.pow(2) - 4 * a * c
|
||||
assert (discriminant >= 0).all()
|
||||
|
||||
root = (2 * c) / (-b - paddle.sqrt(discriminant))
|
||||
outputs = root * input_bin_widths + input_cumwidths
|
||||
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
) * theta_one_minus_theta)
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * root.pow(2) + 2 * input_delta *
|
||||
theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
|
||||
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
|
||||
denominator)
|
||||
|
||||
return outputs, -logabsdet
|
||||
else:
|
||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
|
||||
numerator = input_heights * (input_delta * theta.pow(2) +
|
||||
input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
|
||||
) * theta_one_minus_theta)
|
||||
outputs = input_cumheights + numerator / denominator
|
||||
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * theta.pow(2) + 2 * input_delta *
|
||||
theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
|
||||
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
|
||||
denominator)
|
||||
|
||||
return outputs, logabsdet
|
||||
|
||||
|
||||
def _searchsorted(bin_locations, inputs, eps=1e-6):
|
||||
bin_locations[..., -1] += eps
|
||||
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1
|
@ -0,0 +1,573 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
"""VITS module"""
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
|
||||
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
|
||||
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
|
||||
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
|
||||
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
|
||||
from paddlespeech.t2s.models.vits.generator import VITSGenerator
|
||||
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
|
||||
from paddlespeech.t2s.modules.losses import FeatureMatchLoss
|
||||
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
|
||||
from paddlespeech.t2s.modules.losses import KLDivergenceLoss
|
||||
from paddlespeech.t2s.modules.losses import MelSpectrogramLoss
|
||||
from paddlespeech.t2s.modules.nets_utils import get_segments
|
||||
|
||||
AVAILABLE_GENERATERS = {
|
||||
"vits_generator": VITSGenerator,
|
||||
}
|
||||
AVAILABLE_DISCRIMINATORS = {
|
||||
"hifigan_period_discriminator":
|
||||
HiFiGANPeriodDiscriminator,
|
||||
"hifigan_scale_discriminator":
|
||||
HiFiGANScaleDiscriminator,
|
||||
"hifigan_multi_period_discriminator":
|
||||
HiFiGANMultiPeriodDiscriminator,
|
||||
"hifigan_multi_scale_discriminator":
|
||||
HiFiGANMultiScaleDiscriminator,
|
||||
"hifigan_multi_scale_multi_period_discriminator":
|
||||
HiFiGANMultiScaleMultiPeriodDiscriminator,
|
||||
}
|
||||
|
||||
|
||||
class VITS(nn.Layer):
|
||||
"""VITS module (generator + discriminator).
|
||||
This is a module of VITS described in `Conditional Variational Autoencoder
|
||||
with Adversarial Learning for End-to-End Text-to-Speech`_.
|
||||
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
|
||||
Text-to-Speech`: https://arxiv.org/abs/2006.04558
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# generator related
|
||||
idim: int,
|
||||
odim: int,
|
||||
sampling_rate: int=22050,
|
||||
generator_type: str="vits_generator",
|
||||
generator_params: Dict[str, Any]={
|
||||
"hidden_channels": 192,
|
||||
"spks": None,
|
||||
"langs": None,
|
||||
"spk_embed_dim": None,
|
||||
"global_channels": -1,
|
||||
"segment_size": 32,
|
||||
"text_encoder_attention_heads": 2,
|
||||
"text_encoder_ffn_expand": 4,
|
||||
"text_encoder_blocks": 6,
|
||||
"text_encoder_positionwise_layer_type": "conv1d",
|
||||
"text_encoder_positionwise_conv_kernel_size": 1,
|
||||
"text_encoder_positional_encoding_layer_type": "rel_pos",
|
||||
"text_encoder_self_attention_layer_type": "rel_selfattn",
|
||||
"text_encoder_activation_type": "swish",
|
||||
"text_encoder_normalize_before": True,
|
||||
"text_encoder_dropout_rate": 0.1,
|
||||
"text_encoder_positional_dropout_rate": 0.0,
|
||||
"text_encoder_attention_dropout_rate": 0.0,
|
||||
"text_encoder_conformer_kernel_size": 7,
|
||||
"use_macaron_style_in_text_encoder": True,
|
||||
"use_conformer_conv_in_text_encoder": True,
|
||||
"decoder_kernel_size": 7,
|
||||
"decoder_channels": 512,
|
||||
"decoder_upsample_scales": [8, 8, 2, 2],
|
||||
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
|
||||
"decoder_resblock_kernel_sizes": [3, 7, 11],
|
||||
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"use_weight_norm_in_decoder": True,
|
||||
"posterior_encoder_kernel_size": 5,
|
||||
"posterior_encoder_layers": 16,
|
||||
"posterior_encoder_stacks": 1,
|
||||
"posterior_encoder_base_dilation": 1,
|
||||
"posterior_encoder_dropout_rate": 0.0,
|
||||
"use_weight_norm_in_posterior_encoder": True,
|
||||
"flow_flows": 4,
|
||||
"flow_kernel_size": 5,
|
||||
"flow_base_dilation": 1,
|
||||
"flow_layers": 4,
|
||||
"flow_dropout_rate": 0.0,
|
||||
"use_weight_norm_in_flow": True,
|
||||
"use_only_mean_in_flow": True,
|
||||
"stochastic_duration_predictor_kernel_size": 3,
|
||||
"stochastic_duration_predictor_dropout_rate": 0.5,
|
||||
"stochastic_duration_predictor_flows": 4,
|
||||
"stochastic_duration_predictor_dds_conv_layers": 3,
|
||||
},
|
||||
# discriminator related
|
||||
discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
|
||||
discriminator_params: Dict[str, Any]={
|
||||
"scales": 1,
|
||||
"scale_downsample_pooling": "AvgPool1D",
|
||||
"scale_downsample_pooling_params": {
|
||||
"kernel_size": 4,
|
||||
"stride": 2,
|
||||
"padding": 2,
|
||||
},
|
||||
"scale_discriminator_params": {
|
||||
"in_channels": 1,
|
||||
"out_channels": 1,
|
||||
"kernel_sizes": [15, 41, 5, 3],
|
||||
"channels": 128,
|
||||
"max_downsample_channels": 1024,
|
||||
"max_groups": 16,
|
||||
"bias": True,
|
||||
"downsample_scales": [2, 2, 4, 4, 1],
|
||||
"nonlinear_activation": "leakyrelu",
|
||||
"nonlinear_activation_params": {
|
||||
"negative_slope": 0.1
|
||||
},
|
||||
"use_weight_norm": True,
|
||||
"use_spectral_norm": False,
|
||||
},
|
||||
"follow_official_norm": False,
|
||||
"periods": [2, 3, 5, 7, 11],
|
||||
"period_discriminator_params": {
|
||||
"in_channels": 1,
|
||||
"out_channels": 1,
|
||||
"kernel_sizes": [5, 3],
|
||||
"channels": 32,
|
||||
"downsample_scales": [3, 3, 3, 3, 1],
|
||||
"max_downsample_channels": 1024,
|
||||
"bias": True,
|
||||
"nonlinear_activation": "leakyrelu",
|
||||
"nonlinear_activation_params": {
|
||||
"negative_slope": 0.1
|
||||
},
|
||||
"use_weight_norm": True,
|
||||
"use_spectral_norm": False,
|
||||
},
|
||||
},
|
||||
# loss related
|
||||
generator_adv_loss_params: Dict[str, Any]={
|
||||
"average_by_discriminators": False,
|
||||
"loss_type": "mse",
|
||||
},
|
||||
discriminator_adv_loss_params: Dict[str, Any]={
|
||||
"average_by_discriminators": False,
|
||||
"loss_type": "mse",
|
||||
},
|
||||
feat_match_loss_params: Dict[str, Any]={
|
||||
"average_by_discriminators": False,
|
||||
"average_by_layers": False,
|
||||
"include_final_outputs": True,
|
||||
},
|
||||
mel_loss_params: Dict[str, Any]={
|
||||
"fs": 22050,
|
||||
"fft_size": 1024,
|
||||
"hop_size": 256,
|
||||
"win_length": None,
|
||||
"window": "hann",
|
||||
"num_mels": 80,
|
||||
"fmin": 0,
|
||||
"fmax": None,
|
||||
"log_base": None,
|
||||
},
|
||||
lambda_adv: float=1.0,
|
||||
lambda_mel: float=45.0,
|
||||
lambda_feat_match: float=2.0,
|
||||
lambda_dur: float=1.0,
|
||||
lambda_kl: float=1.0,
|
||||
cache_generator_outputs: bool=True, ):
|
||||
"""Initialize VITS module.
|
||||
Args:
|
||||
idim (int): Input vocabrary size.
|
||||
odim (int): Acoustic feature dimension. The actual output channels will
|
||||
be 1 since VITS is the end-to-end text-to-wave model but for the
|
||||
compatibility odim is used to indicate the acoustic feature dimension.
|
||||
sampling_rate (int): Sampling rate, not used for the training but it will
|
||||
be referred in saving waveform during the inference.
|
||||
generator_type (str): Generator type.
|
||||
generator_params (Dict[str, Any]): Parameter dict for generator.
|
||||
discriminator_type (str): Discriminator type.
|
||||
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
|
||||
generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
|
||||
adversarial loss.
|
||||
discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
|
||||
discriminator adversarial loss.
|
||||
feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
|
||||
mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
|
||||
lambda_adv (float): Loss scaling coefficient for adversarial loss.
|
||||
lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
|
||||
lambda_feat_match (float): Loss scaling coefficient for feat match loss.
|
||||
lambda_dur (float): Loss scaling coefficient for duration loss.
|
||||
lambda_kl (float): Loss scaling coefficient for KL divergence loss.
|
||||
cache_generator_outputs (bool): Whether to cache generator outputs.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
# define modules
|
||||
generator_class = AVAILABLE_GENERATERS[generator_type]
|
||||
if generator_type == "vits_generator":
|
||||
# NOTE: Update parameters for the compatibility.
|
||||
# The idim and odim is automatically decided from input data,
|
||||
# where idim represents #vocabularies and odim represents
|
||||
# the input acoustic feature dimension.
|
||||
generator_params.update(vocabs=idim, aux_channels=odim)
|
||||
self.generator = generator_class(
|
||||
**generator_params, )
|
||||
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
|
||||
self.discriminator = discriminator_class(
|
||||
**discriminator_params, )
|
||||
self.generator_adv_loss = GeneratorAdversarialLoss(
|
||||
**generator_adv_loss_params, )
|
||||
self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
|
||||
**discriminator_adv_loss_params, )
|
||||
self.feat_match_loss = FeatureMatchLoss(
|
||||
**feat_match_loss_params, )
|
||||
self.mel_loss = MelSpectrogramLoss(
|
||||
**mel_loss_params, )
|
||||
self.kl_loss = KLDivergenceLoss()
|
||||
|
||||
# coefficients
|
||||
self.lambda_adv = lambda_adv
|
||||
self.lambda_mel = lambda_mel
|
||||
self.lambda_kl = lambda_kl
|
||||
self.lambda_feat_match = lambda_feat_match
|
||||
self.lambda_dur = lambda_dur
|
||||
|
||||
# cache
|
||||
self.cache_generator_outputs = cache_generator_outputs
|
||||
self._cache = None
|
||||
|
||||
# store sampling rate for saving wav file
|
||||
# (not used for the training)
|
||||
self.fs = sampling_rate
|
||||
|
||||
# store parameters for test compatibility
|
||||
self.spks = self.generator.spks
|
||||
self.langs = self.generator.langs
|
||||
self.spk_embed_dim = self.generator.spk_embed_dim
|
||||
|
||||
@property
|
||||
def require_raw_speech(self):
|
||||
"""Return whether or not speech is required."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def require_vocoder(self):
|
||||
"""Return whether or not vocoder is required."""
|
||||
return False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_lengths: paddle.Tensor,
|
||||
feats: paddle.Tensor,
|
||||
feats_lengths: paddle.Tensor,
|
||||
sids: Optional[paddle.Tensor]=None,
|
||||
spembs: Optional[paddle.Tensor]=None,
|
||||
lids: Optional[paddle.Tensor]=None,
|
||||
forward_generator: bool=True, ) -> Dict[str, Any]:
|
||||
"""Perform generator forward.
|
||||
Args:
|
||||
text (Tensor): Text index tensor (B, T_text).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
forward_generator (bool): Whether to forward generator.
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
- loss (Tensor): Loss scalar tensor.
|
||||
- stats (Dict[str, float]): Statistics to be monitored.
|
||||
- weight (Tensor): Weight tensor to summarize losses.
|
||||
- optim_idx (int): Optimizer index (0 for G and 1 for D).
|
||||
"""
|
||||
if forward_generator:
|
||||
return self._forward_generator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids, )
|
||||
else:
|
||||
return self._forward_discrminator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids, )
|
||||
|
||||
def _forward_generator(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_lengths: paddle.Tensor,
|
||||
feats: paddle.Tensor,
|
||||
feats_lengths: paddle.Tensor,
|
||||
sids: Optional[paddle.Tensor]=None,
|
||||
spembs: Optional[paddle.Tensor]=None,
|
||||
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
|
||||
"""Perform generator forward.
|
||||
Args:
|
||||
text (Tensor): Text index tensor (B, T_text).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
* loss (Tensor): Loss scalar tensor.
|
||||
* stats (Dict[str, float]): Statistics to be monitored.
|
||||
* weight (Tensor): Weight tensor to summarize losses.
|
||||
* optim_idx (int): Optimizer index (0 for G and 1 for D).
|
||||
"""
|
||||
# setup
|
||||
batch_size = paddle.shape(text)[0]
|
||||
feats = feats.transpose([0, 2, 1])
|
||||
# speech = speech.unsqueeze(1)
|
||||
|
||||
# calculate generator outputs
|
||||
reuse_cache = True
|
||||
if not self.cache_generator_outputs or self._cache is None:
|
||||
reuse_cache = False
|
||||
outs = self.generator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids, )
|
||||
else:
|
||||
outs = self._cache
|
||||
|
||||
# store cache
|
||||
if self.training and self.cache_generator_outputs and not reuse_cache:
|
||||
self._cache = outs
|
||||
|
||||
return outs
|
||||
"""
|
||||
# parse outputs
|
||||
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
|
||||
_, z_p, m_p, logs_p, _, logs_q = outs_
|
||||
speech_ = get_segments(
|
||||
x=speech,
|
||||
start_idxs=start_idxs * self.generator.upsample_factor,
|
||||
segment_size=self.generator.segment_size *
|
||||
self.generator.upsample_factor, )
|
||||
|
||||
# calculate discriminator outputs
|
||||
p_hat = self.discriminator(speech_hat_)
|
||||
with paddle.no_grad():
|
||||
# do not store discriminator gradient in generator turn
|
||||
p = self.discriminator(speech_)
|
||||
|
||||
# calculate losses
|
||||
mel_loss = self.mel_loss(speech_hat_, speech_)
|
||||
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
|
||||
dur_loss = paddle.sum(dur_nll.float())
|
||||
adv_loss = self.generator_adv_loss(p_hat)
|
||||
feat_match_loss = self.feat_match_loss(p_hat, p)
|
||||
|
||||
mel_loss = mel_loss * self.lambda_mel
|
||||
kl_loss = kl_loss * self.lambda_kl
|
||||
dur_loss = dur_loss * self.lambda_dur
|
||||
adv_loss = adv_loss * self.lambda_adv
|
||||
feat_match_loss = feat_match_loss * self.lambda_feat_match
|
||||
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
|
||||
|
||||
stats = dict(
|
||||
generator_loss=loss.item(),
|
||||
generator_mel_loss=mel_loss.item(),
|
||||
generator_kl_loss=kl_loss.item(),
|
||||
generator_dur_loss=dur_loss.item(),
|
||||
generator_adv_loss=adv_loss.item(),
|
||||
generator_feat_match_loss=feat_match_loss.item(), )
|
||||
|
||||
# reset cache
|
||||
if reuse_cache or not self.training:
|
||||
self._cache = None
|
||||
|
||||
return {
|
||||
"loss": loss,
|
||||
"stats": stats,
|
||||
# "weight": weight,
|
||||
"optim_idx": 0, # needed for trainer
|
||||
}
|
||||
"""
|
||||
|
||||
def _forward_discrminator(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
text_lengths: paddle.Tensor,
|
||||
feats: paddle.Tensor,
|
||||
feats_lengths: paddle.Tensor,
|
||||
sids: Optional[paddle.Tensor]=None,
|
||||
spembs: Optional[paddle.Tensor]=None,
|
||||
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
|
||||
"""Perform discriminator forward.
|
||||
Args:
|
||||
text (Tensor): Text index tensor (B, T_text).
|
||||
text_lengths (Tensor): Text length tensor (B,).
|
||||
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
|
||||
feats_lengths (Tensor): Feature length tensor (B,).
|
||||
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
|
||||
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
|
||||
Returns:
|
||||
Dict[str, Any]:
|
||||
* loss (Tensor): Loss scalar tensor.
|
||||
* stats (Dict[str, float]): Statistics to be monitored.
|
||||
* weight (Tensor): Weight tensor to summarize losses.
|
||||
* optim_idx (int): Optimizer index (0 for G and 1 for D).
|
||||
"""
|
||||
# setup
|
||||
batch_size = paddle.shape(text)[0]
|
||||
feats = feats.transpose([0, 2, 1])
|
||||
# speech = speech.unsqueeze(1)
|
||||
|
||||
# calculate generator outputs
|
||||
reuse_cache = True
|
||||
if not self.cache_generator_outputs or self._cache is None:
|
||||
reuse_cache = False
|
||||
outs = self.generator(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids, )
|
||||
else:
|
||||
outs = self._cache
|
||||
|
||||
# store cache
|
||||
if self.cache_generator_outputs and not reuse_cache:
|
||||
self._cache = outs
|
||||
|
||||
return outs
|
||||
"""
|
||||
|
||||
# parse outputs
|
||||
speech_hat_, _, _, start_idxs, *_ = outs
|
||||
speech_ = get_segments(
|
||||
x=speech,
|
||||
start_idxs=start_idxs * self.generator.upsample_factor,
|
||||
segment_size=self.generator.segment_size *
|
||||
self.generator.upsample_factor, )
|
||||
|
||||
# calculate discriminator outputs
|
||||
p_hat = self.discriminator(speech_hat_.detach())
|
||||
p = self.discriminator(speech_)
|
||||
|
||||
# calculate losses
|
||||
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
|
||||
loss = real_loss + fake_loss
|
||||
|
||||
stats = dict(
|
||||
discriminator_loss=loss.item(),
|
||||
discriminator_real_loss=real_loss.item(),
|
||||
discriminator_fake_loss=fake_loss.item(), )
|
||||
|
||||
# reset cache
|
||||
if reuse_cache or not self.training:
|
||||
self._cache = None
|
||||
|
||||
return {
|
||||
"loss": loss,
|
||||
"stats": stats,
|
||||
# "weight": weight,
|
||||
"optim_idx": 1, # needed for trainer
|
||||
}
|
||||
"""
|
||||
|
||||
def inference(
|
||||
self,
|
||||
text: paddle.Tensor,
|
||||
feats: Optional[paddle.Tensor]=None,
|
||||
sids: Optional[paddle.Tensor]=None,
|
||||
spembs: Optional[paddle.Tensor]=None,
|
||||
lids: Optional[paddle.Tensor]=None,
|
||||
durations: Optional[paddle.Tensor]=None,
|
||||
noise_scale: float=0.667,
|
||||
noise_scale_dur: float=0.8,
|
||||
alpha: float=1.0,
|
||||
max_len: Optional[int]=None,
|
||||
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
|
||||
"""Run inference.
|
||||
Args:
|
||||
text (Tensor): Input text index tensor (T_text,).
|
||||
feats (Tensor): Feature tensor (T_feats, aux_channels).
|
||||
sids (Tensor): Speaker index tensor (1,).
|
||||
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
|
||||
lids (Tensor): Language index tensor (1,).
|
||||
durations (Tensor): Ground-truth duration tensor (T_text,).
|
||||
noise_scale (float): Noise scale value for flow.
|
||||
noise_scale_dur (float): Noise scale value for duration predictor.
|
||||
alpha (float): Alpha parameter to control the speed of generated speech.
|
||||
max_len (Optional[int]): Maximum length.
|
||||
use_teacher_forcing (bool): Whether to use teacher forcing.
|
||||
Returns:
|
||||
Dict[str, Tensor]:
|
||||
* wav (Tensor): Generated waveform tensor (T_wav,).
|
||||
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
|
||||
* duration (Tensor): Predicted duration tensor (T_text,).
|
||||
"""
|
||||
# setup
|
||||
text = text[None]
|
||||
text_lengths = paddle.to_tensor(paddle.shape(text)[1])
|
||||
# if sids is not None:
|
||||
# sids = sids.view(1)
|
||||
# if lids is not None:
|
||||
# lids = lids.view(1)
|
||||
if durations is not None:
|
||||
durations = paddle.reshape(durations, [1, 1, -1])
|
||||
|
||||
# inference
|
||||
if use_teacher_forcing:
|
||||
assert feats is not None
|
||||
feats = feats[None].transpose([0, 2, 1])
|
||||
feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
|
||||
wav, att_w, dur = self.generator.inference(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
feats=feats,
|
||||
feats_lengths=feats_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids,
|
||||
max_len=max_len,
|
||||
use_teacher_forcing=use_teacher_forcing, )
|
||||
else:
|
||||
wav, att_w, dur = self.generator.inference(
|
||||
text=text,
|
||||
text_lengths=text_lengths,
|
||||
sids=sids,
|
||||
spembs=spembs,
|
||||
lids=lids,
|
||||
dur=durations,
|
||||
noise_scale=noise_scale,
|
||||
noise_scale_dur=noise_scale_dur,
|
||||
alpha=alpha,
|
||||
max_len=max_len, )
|
||||
return dict(
|
||||
wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,154 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""Residual block module in WaveNet."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: int=3,
|
||||
residual_channels: int=64,
|
||||
gate_channels: int=128,
|
||||
skip_channels: int=64,
|
||||
aux_channels: int=80,
|
||||
global_channels: int=-1,
|
||||
dropout_rate: float=0.0,
|
||||
dilation: int=1,
|
||||
bias: bool=True,
|
||||
scale_residual: bool=False, ):
|
||||
"""Initialize ResidualBlock module.
|
||||
|
||||
Args:
|
||||
kernel_size (int): Kernel size of dilation convolution layer.
|
||||
residual_channels (int): Number of channels for residual connection.
|
||||
skip_channels (int): Number of channels for skip connection.
|
||||
aux_channels (int): Number of local conditioning channels.
|
||||
dropout (float): Dropout probability.
|
||||
dilation (int): Dilation factor.
|
||||
bias (bool): Whether to add bias parameter in convolution layers.
|
||||
scale_residual (bool): Whether to scale the residual outputs.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.dropout_rate = dropout_rate
|
||||
self.residual_channels = residual_channels
|
||||
self.skip_channels = skip_channels
|
||||
self.scale_residual = scale_residual
|
||||
|
||||
# check
|
||||
assert (
|
||||
kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
||||
assert gate_channels % 2 == 0
|
||||
|
||||
# dilation conv
|
||||
padding = (kernel_size - 1) // 2 * dilation
|
||||
self.conv = nn.Conv1D(
|
||||
residual_channels,
|
||||
gate_channels,
|
||||
kernel_size,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias_attr=bias, )
|
||||
|
||||
# local conditioning
|
||||
if aux_channels > 0:
|
||||
self.conv1x1_aux = nn.Conv1D(
|
||||
aux_channels, gate_channels, kernel_size=1, bias_attr=False)
|
||||
else:
|
||||
self.conv1x1_aux = None
|
||||
|
||||
# global conditioning
|
||||
if global_channels > 0:
|
||||
self.conv1x1_glo = nn.Conv1D(
|
||||
global_channels, gate_channels, kernel_size=1, bias_attr=False)
|
||||
else:
|
||||
self.conv1x1_glo = None
|
||||
|
||||
# conv output is split into two groups
|
||||
gate_out_channels = gate_channels // 2
|
||||
|
||||
# NOTE: concat two convs into a single conv for the efficiency
|
||||
# (integrate res 1x1 + skip 1x1 convs)
|
||||
self.conv1x1_out = nn.Conv1D(
|
||||
gate_out_channels,
|
||||
residual_channels + skip_channels,
|
||||
kernel_size=1,
|
||||
bias_attr=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: Optional[paddle.Tensor]=None,
|
||||
c: Optional[paddle.Tensor]=None,
|
||||
g: Optional[paddle.Tensor]=None,
|
||||
) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, residual_channels, T).
|
||||
x_mask Optional[paddle.Tensor]: Mask tensor (B, 1, T).
|
||||
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
|
||||
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor for residual connection (B, residual_channels, T).
|
||||
Tensor: Output tensor for skip connection (B, skip_channels, T).
|
||||
|
||||
"""
|
||||
residual = x
|
||||
x = F.dropout(x, p=self.dropout_rate, training=self.training)
|
||||
x = self.conv(x)
|
||||
|
||||
# split into two part for gated activation
|
||||
splitdim = 1
|
||||
xa, xb = paddle.split(x, 2, axis=splitdim)
|
||||
|
||||
# local conditioning
|
||||
if c is not None:
|
||||
c = self.conv1x1_aux(c)
|
||||
ca, cb = paddle.split(c, 2, axis=splitdim)
|
||||
xa, xb = xa + ca, xb + cb
|
||||
|
||||
# global conditioning
|
||||
if g is not None:
|
||||
g = self.conv1x1_glo(g)
|
||||
ga, gb = paddle.split(g, 2, axis=splitdim)
|
||||
xa, xb = xa + ga, xb + gb
|
||||
|
||||
x = paddle.tanh(xa) * F.sigmoid(xb)
|
||||
|
||||
# residual + skip 1x1 conv
|
||||
x = self.conv1x1_out(x)
|
||||
if x_mask is not None:
|
||||
x = x * x_mask
|
||||
|
||||
# split integrated conv results
|
||||
x, s = paddle.split(
|
||||
x, [self.residual_channels, self.skip_channels], axis=1)
|
||||
|
||||
# for residual connection
|
||||
x = x + residual
|
||||
if self.scale_residual:
|
||||
x = x * math.sqrt(0.5)
|
||||
|
||||
return x, s
|
@ -0,0 +1,175 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from espnet(https://github.com/espnet/espnet)
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.models.vits.wavenet.residual_block import ResidualBlock
|
||||
|
||||
|
||||
class WaveNet(nn.Layer):
|
||||
"""WaveNet with global conditioning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int=1,
|
||||
out_channels: int=1,
|
||||
kernel_size: int=3,
|
||||
layers: int=30,
|
||||
stacks: int=3,
|
||||
base_dilation: int=2,
|
||||
residual_channels: int=64,
|
||||
aux_channels: int=-1,
|
||||
gate_channels: int=128,
|
||||
skip_channels: int=64,
|
||||
global_channels: int=-1,
|
||||
dropout_rate: float=0.0,
|
||||
bias: bool=True,
|
||||
use_weight_norm: bool=True,
|
||||
use_first_conv: bool=False,
|
||||
use_last_conv: bool=False,
|
||||
scale_residual: bool=False,
|
||||
scale_skip_connect: bool=False, ):
|
||||
"""Initialize WaveNet module.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
kernel_size (int): Kernel size of dilated convolution.
|
||||
layers (int): Number of residual block layers.
|
||||
stacks (int): Number of stacks i.e., dilation cycles.
|
||||
base_dilation (int): Base dilation factor.
|
||||
residual_channels (int): Number of channels in residual conv.
|
||||
gate_channels (int): Number of channels in gated conv.
|
||||
skip_channels (int): Number of channels in skip conv.
|
||||
aux_channels (int): Number of channels for local conditioning feature.
|
||||
global_channels (int): Number of channels for global conditioning feature.
|
||||
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
|
||||
bias (bool): Whether to use bias parameter in conv layer.
|
||||
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
|
||||
be applied to all of the conv layers.
|
||||
use_first_conv (bool): Whether to use the first conv layers.
|
||||
use_last_conv (bool): Whether to use the last conv layers.
|
||||
scale_residual (bool): Whether to scale the residual outputs.
|
||||
scale_skip_connect (bool): Whether to scale the skip connection outputs.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.layers = layers
|
||||
self.stacks = stacks
|
||||
self.kernel_size = kernel_size
|
||||
self.base_dilation = base_dilation
|
||||
self.use_first_conv = use_first_conv
|
||||
self.use_last_conv = use_last_conv
|
||||
self.scale_skip_connect = scale_skip_connect
|
||||
|
||||
# check the number of layers and stacks
|
||||
assert layers % stacks == 0
|
||||
layers_per_stack = layers // stacks
|
||||
|
||||
# define first convolution
|
||||
if self.use_first_conv:
|
||||
self.first_conv = nn.Conv1D(
|
||||
in_channels, residual_channels, kernel_size=1, bias_attr=True)
|
||||
|
||||
# define residual blocks
|
||||
self.conv_layers = nn.LayerList()
|
||||
for layer in range(layers):
|
||||
dilation = base_dilation**(layer % layers_per_stack)
|
||||
conv = ResidualBlock(
|
||||
kernel_size=kernel_size,
|
||||
residual_channels=residual_channels,
|
||||
gate_channels=gate_channels,
|
||||
skip_channels=skip_channels,
|
||||
aux_channels=aux_channels,
|
||||
global_channels=global_channels,
|
||||
dilation=dilation,
|
||||
dropout_rate=dropout_rate,
|
||||
bias=bias,
|
||||
scale_residual=scale_residual, )
|
||||
self.conv_layers.append(conv)
|
||||
|
||||
# define output layers
|
||||
if self.use_last_conv:
|
||||
self.last_conv = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Conv1D(
|
||||
skip_channels, skip_channels, kernel_size=1,
|
||||
bias_attr=True),
|
||||
nn.ReLU(),
|
||||
nn.Conv1D(
|
||||
skip_channels, out_channels, kernel_size=1, bias_attr=True),
|
||||
)
|
||||
|
||||
# apply weight norm
|
||||
if use_weight_norm:
|
||||
self.apply_weight_norm()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: paddle.Tensor,
|
||||
x_mask: Optional[paddle.Tensor]=None,
|
||||
c: Optional[paddle.Tensor]=None,
|
||||
g: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
|
||||
(B, residual_channels, T).
|
||||
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
|
||||
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
|
||||
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
|
||||
(B, residual_channels, T).
|
||||
|
||||
"""
|
||||
# encode to hidden representation
|
||||
if self.use_first_conv:
|
||||
x = self.first_conv(x)
|
||||
|
||||
# residual block
|
||||
skips = 0.0
|
||||
for f in self.conv_layers:
|
||||
x, h = f(x, x_mask=x_mask, c=c, g=g)
|
||||
skips = skips + h
|
||||
x = skips
|
||||
if self.scale_skip_connect:
|
||||
x = x * math.sqrt(1.0 / len(self.conv_layers))
|
||||
|
||||
# apply final layers
|
||||
if self.use_last_conv:
|
||||
x = self.last_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
def apply_weight_norm(self):
|
||||
def _apply_weight_norm(layer):
|
||||
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
|
||||
nn.utils.weight_norm(layer)
|
||||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
def _remove_weight_norm(layer):
|
||||
try:
|
||||
nn.utils.remove_weight_norm(layer)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self.apply(_remove_weight_norm)
|
Loading…
Reference in new issue