add vits network scripts, test=tts

pull/1855/head
TianYuan 2 years ago
parent 93964960c6
commit 4b7786f2ed

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=vits
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -20,15 +20,14 @@ from scipy.interpolate import interp1d
class LogMelFBank():
def __init__(self,
sr=24000,
n_fft=2048,
hop_length=300,
win_length=None,
window="hann",
n_mels=80,
fmin=80,
fmax=7600,
eps=1e-10):
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
window: str="hann",
n_mels: int=80,
fmin: int=80,
fmax: int=7600):
self.sr = sr
# stft
self.n_fft = n_fft
@ -54,7 +53,7 @@ class LogMelFBank():
fmax=self.fmax)
return mel_filter
def _stft(self, wav):
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
@ -65,11 +64,11 @@ class LogMelFBank():
pad_mode=self.pad_mode)
return D
def _spectrogram(self, wav):
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D)
def _mel_spectrogram(self, wav):
def _mel_spectrogram(self, wav: np.ndarray):
S = self._spectrogram(wav)
mel = np.dot(self.mel_filter, S)
return mel
@ -90,14 +89,18 @@ class LogMelFBank():
class Pitch():
def __init__(self, sr=24000, hop_length=300, f0min=80, f0max=7600):
def __init__(self,
sr: int=24000,
hop_length: int=300,
f0min: int=80,
f0max: int=7600):
self.sr = sr
self.hop_length = hop_length
self.f0min = f0min
self.f0max = f0max
def _convert_to_continuous_f0(self, f0: np.array) -> np.array:
def _convert_to_continuous_f0(self, f0: np.ndarray) -> np.ndarray:
if (f0 == 0).all():
print("All frames seems to be unvoiced.")
return f0
@ -120,9 +123,9 @@ class Pitch():
return f0
def _calculate_f0(self,
input: np.array,
use_continuous_f0=True,
use_log_f0=True) -> np.array:
input: np.ndarray,
use_continuous_f0: bool=True,
use_log_f0: bool=True) -> np.ndarray:
input = input.astype(np.float)
frame_period = 1000 * self.hop_length / self.sr
f0, timeaxis = pyworld.dio(
@ -139,7 +142,8 @@ class Pitch():
f0[nonzero_idxs] = np.log(f0[nonzero_idxs])
return f0.reshape(-1)
def _average_by_duration(self, input: np.array, d: np.array) -> np.array:
def _average_by_duration(self, input: np.ndarray,
d: np.ndarray) -> np.ndarray:
d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
arr_list = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
@ -154,11 +158,11 @@ class Pitch():
return arr_list
def get_pitch(self,
wav,
use_continuous_f0=True,
use_log_f0=True,
use_token_averaged_f0=True,
duration=None):
wav: np.ndarray,
use_continuous_f0: bool=True,
use_log_f0: bool=True,
use_token_averaged_f0: bool=True,
duration: np.ndarray=None):
f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0)
if use_token_averaged_f0 and duration is not None:
f0 = self._average_by_duration(f0, duration)
@ -167,13 +171,13 @@ class Pitch():
class Energy():
def __init__(self,
sr=24000,
n_fft=2048,
hop_length=300,
win_length=None,
window="hann",
center=True,
pad_mode="reflect"):
sr: int=24000,
n_fft: int=2048,
hop_length: int=300,
win_length: int=None,
window: str="hann",
center: bool=True,
pad_mode: str="reflect"):
self.sr = sr
self.n_fft = n_fft
@ -183,7 +187,7 @@ class Energy():
self.center = center
self.pad_mode = pad_mode
def _stft(self, wav):
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
@ -194,7 +198,7 @@ class Energy():
pad_mode=self.pad_mode)
return D
def _calculate_energy(self, input):
def _calculate_energy(self, input: np.ndarray):
input = input.astype(np.float32)
input_stft = self._stft(input)
input_power = np.abs(input_stft)**2
@ -203,7 +207,8 @@ class Energy():
np.sum(input_power, axis=0), a_min=1.0e-10, a_max=float('inf')))
return energy
def _average_by_duration(self, input: np.array, d: np.array) -> np.array:
def _average_by_duration(self, input: np.ndarray,
d: np.ndarray) -> np.ndarray:
d_cumsum = np.pad(d.cumsum(0), (1, 0), 'constant')
arr_list = []
for start, end in zip(d_cumsum[:-1], d_cumsum[1:]):
@ -214,8 +219,49 @@ class Energy():
arr_list = np.expand_dims(np.array(arr_list), 0).T
return arr_list
def get_energy(self, wav, use_token_averaged_energy=True, duration=None):
def get_energy(self,
wav: np.ndarray,
use_token_averaged_energy: bool=True,
duration: np.ndarray=None):
energy = self._calculate_energy(wav)
if use_token_averaged_energy and duration is not None:
energy = self._average_by_duration(energy, duration)
return energy
class LinearSpectrogram():
def __init__(
self,
n_fft: int=1024,
win_length: int=None,
hop_length: int=256,
window: str="hann",
center: bool=True, ):
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
self.window = window
self.center = center
self.n_fft = n_fft
self.pad_mode = "reflect"
def _stft(self, wav: np.ndarray):
D = librosa.core.stft(
wav,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode=self.pad_mode)
return D
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D)
def get_linear_spectrogram(self, wav: np.ndarray):
linear_spectrogram = self._spectrogram(wav)
linear_spectrogram = np.clip(
linear_spectrogram, a_min=1e-10, a_max=float("inf"))
return linear_spectrogram.T

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -195,7 +195,7 @@ class Frontend():
new_initials.append(initials[i])
return new_initials, new_finals
def _p2id(self, phonemes: List[str]) -> np.array:
def _p2id(self, phonemes: List[str]) -> np.ndarray:
# replace unk phone with sp
phonemes = [
phn if phn in self.vocab_phones else "sp" for phn in phonemes
@ -203,7 +203,7 @@ class Frontend():
phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)
def _t2id(self, tones: List[str]) -> np.array:
def _t2id(self, tones: List[str]) -> np.ndarray:
# replace unk phone with sp
tones = [tone if tone in self.vocab_tones else "0" for tone in tones]
tone_ids = [self.vocab_tones[item] for item in tones]

@ -16,6 +16,7 @@ import copy
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import paddle
import paddle.nn.functional as F
@ -34,6 +35,7 @@ class HiFiGANGenerator(nn.Layer):
in_channels: int=80,
out_channels: int=1,
channels: int=512,
global_channels: int=-1,
kernel_size: int=7,
upsample_scales: List[int]=(8, 8, 2, 2),
upsample_kernel_sizes: List[int]=(16, 16, 4, 4),
@ -51,6 +53,7 @@ class HiFiGANGenerator(nn.Layer):
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
channels (int): Number of hidden representation channels.
global_channels (int): Number of global conditioning channels.
kernel_size (int): Kernel size of initial and final conv layer.
upsample_scales (list): List of upsampling scales.
upsample_kernel_sizes (list): List of kernel sizes for upsampling layers.
@ -119,6 +122,9 @@ class HiFiGANGenerator(nn.Layer):
padding=(kernel_size - 1) // 2, ),
nn.Tanh(), )
if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1)
nn.initializer.set_global_initializer(None)
# apply weight norm
@ -128,15 +134,18 @@ class HiFiGANGenerator(nn.Layer):
# reset parameters
self.reset_parameters()
def forward(self, c):
def forward(self, c, g: Optional[paddle.Tensor]=None):
"""Calculate forward propagation.
Args:
c (Tensor): Input tensor (B, in_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T).
"""
c = self.input_conv(c)
if g is not None:
c = c + self.global_conv(g)
for i in range(self.num_upsamples):
c = self.upsamples[i](c)
# initialize
@ -187,16 +196,19 @@ class HiFiGANGenerator(nn.Layer):
self.apply(_remove_weight_norm)
def inference(self, c):
def inference(self, c, g: Optional[paddle.Tensor]=None):
"""Perform inference.
Args:
c (Tensor): Input tensor (T, in_channels).
normalize_before (bool): Whether to perform normalization.
g (Optional[Tensor]): Global conditioning tensor (global_channels, 1).
Returns:
Tensor:
Output tensor (T ** prod(upsample_scales), out_channels).
"""
c = self.forward(c.transpose([1, 0]).unsqueeze(0))
if g is not None:
g = g.unsqueeze(0)
c = self.forward(c.transpose([1, 0]).unsqueeze(0), g=g)
return c.squeeze(0).transpose([1, 0])

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,172 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stochastic duration predictor modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.models.vits.flow import ConvFlow
from paddlespeech.t2s.models.vits.flow import DilatedDepthSeparableConv
from paddlespeech.t2s.models.vits.flow import ElementwiseAffineFlow
from paddlespeech.t2s.models.vits.flow import FlipFlow
from paddlespeech.t2s.models.vits.flow import LogFlow
class StochasticDurationPredictor(nn.Layer):
"""Stochastic duration predictor module.
This is a module of stochastic duration predictor described in `Conditional
Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2106.06103
"""
def __init__(
self,
channels: int=192,
kernel_size: int=3,
dropout_rate: float=0.5,
flows: int=4,
dds_conv_layers: int=3,
global_channels: int=-1, ):
"""Initialize StochasticDurationPredictor module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
dropout_rate (float): Dropout rate.
flows (int): Number of flows.
dds_conv_layers (int): Number of conv layers in DDS conv.
global_channels (int): Number of global conditioning channels.
"""
super().__init__()
self.pre = nn.Conv1D(channels, channels, 1)
self.dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate, )
self.proj = nn.Conv1D(channels, channels, 1)
self.log_flow = LogFlow()
self.flows = nn.LayerList()
self.flows.append(ElementwiseAffineFlow(2))
for i in range(flows):
self.flows.append(
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers, ))
self.flows.append(FlipFlow())
self.post_pre = nn.Conv1D(1, channels, 1)
self.post_dds = DilatedDepthSeparableConv(
channels,
kernel_size,
layers=dds_conv_layers,
dropout_rate=dropout_rate, )
self.post_proj = nn.Conv1D(channels, channels, 1)
self.post_flows = nn.LayerList()
self.post_flows.append(ElementwiseAffineFlow(2))
for i in range(flows):
self.post_flows.append(
ConvFlow(
2,
channels,
kernel_size,
layers=dds_conv_layers, ))
self.post_flows.append(FlipFlow())
if global_channels > 0:
self.global_conv = nn.Conv1D(global_channels, channels, 1)
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
w: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
noise_scale: float=1.0, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T_text).
x_mask (Tensor): Mask tensor (B, 1, T_text).
w (Optional[Tensor]): Duration tensor (B, 1, T_text).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1)
inverse (bool): Whether to inverse the flow.
noise_scale (float): Noise scale value.
Returns:
Tensor: If not inverse, negative log-likelihood (NLL) tensor (B,).
If inverse, log-duration tensor (B, 1, T_text).
"""
# stop gradient
# x = x.detach()
x = self.pre(x)
if g is not None:
# stop gradient
x = x + self.global_conv(g.detach())
x = self.dds(x, x_mask)
x = self.proj(x) * x_mask
if not inverse:
assert w is not None, "w must be provided."
h_w = self.post_pre(w)
h_w = self.post_dds(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (paddle.randn([paddle.shape(w)[0], 2, paddle.shape(w)[2]]) *
x_mask)
z_q = e_q
logdet_tot_q = 0.0
for i, flow in enumerate(self.post_flows):
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
logdet_tot_q += logdet_q
z_u, z1 = paddle.split(z_q, [1, 1], 1)
u = F.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += paddle.sum(
(F.log_sigmoid(z_u) + F.log_sigmoid(-z_u)) * x_mask, [1, 2])
logq = (paddle.sum(-0.5 *
(math.log(2 * math.pi) +
(e_q**2)) * x_mask, [1, 2]) - logdet_tot_q)
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
logdet_tot += logdet
z = paddle.concat([z0, z1], 1)
for flow in self.flows:
z, logdet = flow(z, x_mask, g=x, inverse=inverse)
logdet_tot = logdet_tot + logdet
nll = (paddle.sum(0.5 * (math.log(2 * math.pi) +
(z**2)) * x_mask, [1, 2]) - logdet_tot)
# (B,)
return nll + logq
else:
flows = list(reversed(self.flows))
# remove a useless vflow
flows = flows[:-2] + [flows[-1]]
z = (paddle.randn([paddle.shape(x)[0], 2, paddle.shape(x)[2]]) *
noise_scale)
for flow in flows:
z = flow(z, x_mask, g=x, inverse=inverse)
z0, z1 = paddle.split(z, 2, axis=1)
logw = z0
return logw

@ -0,0 +1,316 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Basic Flow modules used in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.transform import piecewise_rational_quadratic_transform
class FlipFlow(nn.Layer):
"""Flip flow module."""
def forward(self, x: paddle.Tensor, *args, inverse: bool=False, **kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Flipped tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
x = paddle.flip(x, [1])
if not inverse:
logdet = paddle.zeros(paddle.shape(x)[0], dtype=x.dtype)
return x, logdet
else:
return x
class LogFlow(nn.Layer):
"""Log flow module."""
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
inverse: bool=False,
eps: float=1e-5,
**kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
eps (float): Epsilon for log.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = paddle.log(paddle.clip(x, min=eps)) * x_mask
logdet = paddle.sum(-y, [1, 2])
return y, logdet
else:
x = paddle.exp(x) * x_mask
return x
class ElementwiseAffineFlow(nn.Layer):
"""Elementwise affine flow module."""
def __init__(self, channels: int):
"""Initialize ElementwiseAffineFlow module.
Args:
channels (int): Number of channels.
"""
super().__init__()
self.channels = channels
m = paddle.zeros([channels, 1])
self.m = paddle.create_parameter(
shape=m.shape,
dtype=str(m.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(m))
logs = paddle.zeros([channels, 1])
self.logs = paddle.create_parameter(
shape=logs.shape,
dtype=str(logs.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(logs))
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
inverse: bool=False,
**kwargs
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
if not inverse:
y = self.m + paddle.exp(self.logs) * x
y = y * x_mask
logdet = paddle.sum(self.logs * x_mask, [1, 2])
return y, logdet
else:
x = (x - self.m) * paddle.exp(-self.logs) * x_mask
return x
class Transpose(nn.Layer):
"""Transpose module for paddle.nn.Sequential()."""
def __init__(self, dim1: int, dim2: int):
"""Initialize Transpose module."""
super().__init__()
self.dim1 = dim1
self.dim2 = dim2
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
"""Transpose."""
len_dim = len(x.shape)
orig_perm = list(range(len_dim))
new_perm = orig_perm[:]
temp = new_perm[self.dim1]
new_perm[self.dim1] = new_perm[self.dim2]
new_perm[self.dim2] = temp
return paddle.transpose(x, new_perm)
class DilatedDepthSeparableConv(nn.Layer):
"""Dilated depth-separable conv module."""
def __init__(
self,
channels: int,
kernel_size: int,
layers: int,
dropout_rate: float=0.0,
eps: float=1e-5, ):
"""Initialize DilatedDepthSeparableConv module.
Args:
channels (int): Number of channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
dropout_rate (float): Dropout rate.
eps (float): Epsilon for layer norm.
"""
super().__init__()
self.convs = nn.LayerList()
for i in range(layers):
dilation = kernel_size**i
padding = (kernel_size * dilation - dilation) // 2
self.convs.append(
nn.Sequential(
nn.Conv1D(
channels,
channels,
kernel_size,
groups=channels,
dilation=dilation,
padding=padding, ),
Transpose(1, 2),
nn.LayerNorm(channels, epsilon=eps),
Transpose(1, 2),
nn.GELU(),
nn.Conv1D(
channels,
channels,
1, ),
Transpose(1, 2),
nn.LayerNorm(channels, epsilon=eps),
Transpose(1, 2),
nn.GELU(),
nn.Dropout(dropout_rate), ))
def forward(self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, channels, T).
"""
if g is not None:
x = x + g
for f in self.convs:
y = f(x * x_mask)
x = x + y
return x * x_mask
class ConvFlow(nn.Layer):
"""Convolutional flow module."""
def __init__(
self,
in_channels: int,
hidden_channels: int,
kernel_size: int,
layers: int,
bins: int=10,
tail_bound: float=5.0, ):
"""Initialize ConvFlow module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size.
layers (int): Number of layers.
bins (int): Number of bins.
tail_bound (float): Tail bound value.
"""
super().__init__()
self.half_channels = in_channels // 2
self.hidden_channels = hidden_channels
self.bins = bins
self.tail_bound = tail_bound
self.input_conv = nn.Conv1D(
self.half_channels,
hidden_channels,
1, )
self.dds_conv = DilatedDepthSeparableConv(
hidden_channels,
kernel_size,
layers,
dropout_rate=0.0, )
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels * (bins * 3 - 1),
1, )
# self.proj.weight.data.zero_()
# self.proj.bias.data.zero_()
weight = paddle.zeros(paddle.shape(self.proj.weight))
self.proj.weight = paddle.create_parameter(
shape=weight.shape,
dtype=str(weight.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(weight))
bias = paddle.zeros(paddle.shape(self.proj.bias))
self.proj.bias = paddle.create_parameter(
shape=bias.shape,
dtype=str(bias.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(bias))
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, channels, T).
x_mask (Tensor): Mask tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = x.split(2, 1)
h = self.input_conv(xa)
h = self.dds_conv(h, x_mask, g=g)
# (B, half_channels * (bins * 3 - 1), T)
h = self.proj(h) * x_mask
b, c, t = xa.shape
# (B, half_channels, bins * 3 - 1, T) -> (B, half_channels, T, bins * 3 - 1)
h = h.reshape([b, c, -1, t]).transpose([0, 1, 3, 2])
denom = math.sqrt(self.hidden_channels)
unnorm_widths = h[..., :self.bins] / denom
unnorm_heights = h[..., self.bins:2 * self.bins] / denom
unnorm_derivatives = h[..., 2 * self.bins:]
xb, logdet_abs = piecewise_rational_quadratic_transform(
xb,
unnorm_widths,
unnorm_heights,
unnorm_derivatives,
inverse=inverse,
tails="linear",
tail_bound=self.tail_bound, )
x = paddle.concat([xa, xb], 1) * x_mask
logdet = paddle.sum(logdet_abs * x_mask, [1, 2])
if not inverse:
return x, logdet
else:
return x

@ -0,0 +1,551 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generator module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.models.hifigan import HiFiGANGenerator
from paddlespeech.t2s.models.vits.duration_predictor import StochasticDurationPredictor
from paddlespeech.t2s.models.vits.posterior_encoder import PosteriorEncoder
from paddlespeech.t2s.models.vits.residual_coupling import ResidualAffineCouplingBlock
from paddlespeech.t2s.models.vits.text_encoder import TextEncoder
from paddlespeech.t2s.modules.nets_utils import get_random_segments
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
class VITSGenerator(nn.Layer):
"""Generator module in VITS.
This is a module of VITS described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
As text encoder, we use conformer architecture instead of the relative positional
Transformer, which contains additional convolution layers.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
vocabs: int,
aux_channels: int=513,
hidden_channels: int=192,
spks: Optional[int]=None,
langs: Optional[int]=None,
spk_embed_dim: Optional[int]=None,
global_channels: int=-1,
segment_size: int=32,
text_encoder_attention_heads: int=2,
text_encoder_ffn_expand: int=4,
text_encoder_blocks: int=6,
text_encoder_positionwise_layer_type: str="conv1d",
text_encoder_positionwise_conv_kernel_size: int=1,
text_encoder_positional_encoding_layer_type: str="rel_pos",
text_encoder_self_attention_layer_type: str="rel_selfattn",
text_encoder_activation_type: str="swish",
text_encoder_normalize_before: bool=True,
text_encoder_dropout_rate: float=0.1,
text_encoder_positional_dropout_rate: float=0.0,
text_encoder_attention_dropout_rate: float=0.0,
text_encoder_conformer_kernel_size: int=7,
use_macaron_style_in_text_encoder: bool=True,
use_conformer_conv_in_text_encoder: bool=True,
decoder_kernel_size: int=7,
decoder_channels: int=512,
decoder_upsample_scales: List[int]=[8, 8, 2, 2],
decoder_upsample_kernel_sizes: List[int]=[16, 16, 4, 4],
decoder_resblock_kernel_sizes: List[int]=[3, 7, 11],
decoder_resblock_dilations: List[List[int]]=[[1, 3, 5], [1, 3, 5],
[1, 3, 5]],
use_weight_norm_in_decoder: bool=True,
posterior_encoder_kernel_size: int=5,
posterior_encoder_layers: int=16,
posterior_encoder_stacks: int=1,
posterior_encoder_base_dilation: int=1,
posterior_encoder_dropout_rate: float=0.0,
use_weight_norm_in_posterior_encoder: bool=True,
flow_flows: int=4,
flow_kernel_size: int=5,
flow_base_dilation: int=1,
flow_layers: int=4,
flow_dropout_rate: float=0.0,
use_weight_norm_in_flow: bool=True,
use_only_mean_in_flow: bool=True,
stochastic_duration_predictor_kernel_size: int=3,
stochastic_duration_predictor_dropout_rate: float=0.5,
stochastic_duration_predictor_flows: int=4,
stochastic_duration_predictor_dds_conv_layers: int=3, ):
"""Initialize VITS generator module.
Args:
vocabs (int): Input vocabulary size.
aux_channels (int): Number of acoustic feature channels.
hidden_channels (int): Number of hidden channels.
spks (Optional[int]): Number of speakers. If set to > 1, assume that the
sids will be provided as the input and use sid embedding layer.
langs (Optional[int]): Number of languages. If set to > 1, assume that the
lids will be provided as the input and use sid embedding layer.
spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
assume that spembs will be provided as the input.
global_channels (int): Number of global conditioning channels.
segment_size (int): Segment size for decoder.
text_encoder_attention_heads (int): Number of heads in conformer block
of text encoder.
text_encoder_ffn_expand (int): Expansion ratio of FFN in conformer block
of text encoder.
text_encoder_blocks (int): Number of conformer blocks in text encoder.
text_encoder_positionwise_layer_type (str): Position-wise layer type in
conformer block of text encoder.
text_encoder_positionwise_conv_kernel_size (int): Position-wise convolution
kernel size in conformer block of text encoder. Only used when the
above layer type is conv1d or conv1d-linear.
text_encoder_positional_encoding_layer_type (str): Positional encoding layer
type in conformer block of text encoder.
text_encoder_self_attention_layer_type (str): Self-attention layer type in
conformer block of text encoder.
text_encoder_activation_type (str): Activation function type in conformer
block of text encoder.
text_encoder_normalize_before (bool): Whether to apply layer norm before
self-attention in conformer block of text encoder.
text_encoder_dropout_rate (float): Dropout rate in conformer block of
text encoder.
text_encoder_positional_dropout_rate (float): Dropout rate for positional
encoding in conformer block of text encoder.
text_encoder_attention_dropout_rate (float): Dropout rate for attention in
conformer block of text encoder.
text_encoder_conformer_kernel_size (int): Conformer conv kernel size. It
will be used when only use_conformer_conv_in_text_encoder = True.
use_macaron_style_in_text_encoder (bool): Whether to use macaron style FFN
in conformer block of text encoder.
use_conformer_conv_in_text_encoder (bool): Whether to use covolution in
conformer block of text encoder.
decoder_kernel_size (int): Decoder kernel size.
decoder_channels (int): Number of decoder initial channels.
decoder_upsample_scales (List[int]): List of upsampling scales in decoder.
decoder_upsample_kernel_sizes (List[int]): List of kernel size for
upsampling layers in decoder.
decoder_resblock_kernel_sizes (List[int]): List of kernel size for resblocks
in decoder.
decoder_resblock_dilations (List[List[int]]): List of list of dilations for
resblocks in decoder.
use_weight_norm_in_decoder (bool): Whether to apply weight normalization in
decoder.
posterior_encoder_kernel_size (int): Posterior encoder kernel size.
posterior_encoder_layers (int): Number of layers of posterior encoder.
posterior_encoder_stacks (int): Number of stacks of posterior encoder.
posterior_encoder_base_dilation (int): Base dilation of posterior encoder.
posterior_encoder_dropout_rate (float): Dropout rate for posterior encoder.
use_weight_norm_in_posterior_encoder (bool): Whether to apply weight
normalization in posterior encoder.
flow_flows (int): Number of flows in flow.
flow_kernel_size (int): Kernel size in flow.
flow_base_dilation (int): Base dilation in flow.
flow_layers (int): Number of layers in flow.
flow_dropout_rate (float): Dropout rate in flow
use_weight_norm_in_flow (bool): Whether to apply weight normalization in
flow.
use_only_mean_in_flow (bool): Whether to use only mean in flow.
stochastic_duration_predictor_kernel_size (int): Kernel size in stochastic
duration predictor.
stochastic_duration_predictor_dropout_rate (float): Dropout rate in
stochastic duration predictor.
stochastic_duration_predictor_flows (int): Number of flows in stochastic
duration predictor.
stochastic_duration_predictor_dds_conv_layers (int): Number of DDS conv
layers in stochastic duration predictor.
"""
super().__init__()
self.segment_size = segment_size
self.text_encoder = TextEncoder(
vocabs=vocabs,
attention_dim=hidden_channels,
attention_heads=text_encoder_attention_heads,
linear_units=hidden_channels * text_encoder_ffn_expand,
blocks=text_encoder_blocks,
positionwise_layer_type=text_encoder_positionwise_layer_type,
positionwise_conv_kernel_size=text_encoder_positionwise_conv_kernel_size,
positional_encoding_layer_type=text_encoder_positional_encoding_layer_type,
self_attention_layer_type=text_encoder_self_attention_layer_type,
activation_type=text_encoder_activation_type,
normalize_before=text_encoder_normalize_before,
dropout_rate=text_encoder_dropout_rate,
positional_dropout_rate=text_encoder_positional_dropout_rate,
attention_dropout_rate=text_encoder_attention_dropout_rate,
conformer_kernel_size=text_encoder_conformer_kernel_size,
use_macaron_style=use_macaron_style_in_text_encoder,
use_conformer_conv=use_conformer_conv_in_text_encoder, )
self.decoder = HiFiGANGenerator(
in_channels=hidden_channels,
out_channels=1,
channels=decoder_channels,
global_channels=global_channels,
kernel_size=decoder_kernel_size,
upsample_scales=decoder_upsample_scales,
upsample_kernel_sizes=decoder_upsample_kernel_sizes,
resblock_kernel_sizes=decoder_resblock_kernel_sizes,
resblock_dilations=decoder_resblock_dilations,
use_weight_norm=use_weight_norm_in_decoder, )
self.posterior_encoder = PosteriorEncoder(
in_channels=aux_channels,
out_channels=hidden_channels,
hidden_channels=hidden_channels,
kernel_size=posterior_encoder_kernel_size,
layers=posterior_encoder_layers,
stacks=posterior_encoder_stacks,
base_dilation=posterior_encoder_base_dilation,
global_channels=global_channels,
dropout_rate=posterior_encoder_dropout_rate,
use_weight_norm=use_weight_norm_in_posterior_encoder, )
self.flow = ResidualAffineCouplingBlock(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
flows=flow_flows,
kernel_size=flow_kernel_size,
base_dilation=flow_base_dilation,
layers=flow_layers,
global_channels=global_channels,
dropout_rate=flow_dropout_rate,
use_weight_norm=use_weight_norm_in_flow,
use_only_mean=use_only_mean_in_flow, )
# TODO: Add deterministic version as an option
self.duration_predictor = StochasticDurationPredictor(
channels=hidden_channels,
kernel_size=stochastic_duration_predictor_kernel_size,
dropout_rate=stochastic_duration_predictor_dropout_rate,
flows=stochastic_duration_predictor_flows,
dds_conv_layers=stochastic_duration_predictor_dds_conv_layers,
global_channels=global_channels, )
self.upsample_factor = int(np.prod(decoder_upsample_scales))
self.spks = None
if spks is not None and spks > 1:
assert global_channels > 0
self.spks = spks
self.global_emb = nn.Embedding(spks, global_channels)
self.spk_embed_dim = None
if spk_embed_dim is not None and spk_embed_dim > 0:
assert global_channels > 0
self.spk_embed_dim = spk_embed_dim
self.spemb_proj = nn.Linear(spk_embed_dim, global_channels)
self.langs = None
if langs is not None and langs > 1:
assert global_channels > 0
self.langs = langs
self.lang_emb = nn.Embedding(langs, global_channels)
# delayed import
from paddlespeech.t2s.models.vits.monotonic_align import maximum_path
self.maximum_path = maximum_path
def forward(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, paddle.Tensor,
Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, paddle.Tensor, ], ]:
"""Calculate forward propagation.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
Tensor: Duration negative log-likelihood (NLL) tensor (B,).
Tensor: Monotonic attention weight tensor (B, 1, T_feats, T_text).
Tensor: Segments start index tensor (B,).
Tensor: Text mask tensor (B, 1, T_text).
Tensor: Feature mask tensor (B, 1, T_feats).
tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
- Tensor: Posterior encoder hidden representation (B, H, T_feats).
- Tensor: Flow hidden representation (B, H, T_feats).
- Tensor: Expanded text encoder projected mean (B, H, T_feats).
- Tensor: Expanded text encoder projected scale (B, H, T_feats).
- Tensor: Posterior encoder projected mean (B, H, T_feats).
- Tensor: Posterior encoder projected scale (B, H, T_feats).
"""
# forward text encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
# calculate global conditioning
g = None
if self.spks is not None:
# speaker one-hot vector embedding: (B, global_channels, 1)
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
if self.spk_embed_dim is not None:
# pretreined speaker embedding, e.g., X-vector (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# language one-hot vector embedding: (B, global_channels, 1)
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(
feats, feats_lengths, g=g)
# forward flow
# (B, H, T_feats)
z_p = self.flow(z, y_mask, g=g)
# monotonic alignment search
with paddle.no_grad():
# negative cross-entropy
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = paddle.matmul(
-0.5 * (z_p**2).transpose([0, 2, 1]),
s_p_sq_r, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = paddle.matmul(
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True, )
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = (self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1), ).unsqueeze(1).detach())
# forward duration predictor
# (B, 1, T_text)
w = attn.sum(2)
dur_nll = self.duration_predictor(x, x_mask, w=w, g=g)
dur_nll = dur_nll / paddle.sum(x_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = paddle.matmul(attn.squeeze(1),
m_p.transpose([0, 2, 1])).transpose([0, 2, 1])
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = paddle.matmul(attn.squeeze(1),
logs_p.transpose([0, 2, 1])).transpose([0, 2, 1])
# get random segments
z_segments, z_start_idxs = get_random_segments(
z,
feats_lengths,
self.segment_size, )
# forward decoder with random segments
wav = self.decoder(z_segments, g=g)
return (wav, dur_nll, attn, z_start_idxs, x_mask, y_mask,
(z, z_p, m_p, logs_p, m_q, logs_q), )
def inference(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: Optional[paddle.Tensor]=None,
feats_lengths: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
dur: Optional[paddle.Tensor]=None,
noise_scale: float=0.667,
noise_scale_dur: float=0.8,
alpha: float=1.0,
max_len: Optional[int]=None,
use_teacher_forcing: bool=False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (B, T_text,).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
dur (Optional[Tensor]): Ground-truth duration (B, T_text,). If provided,
skip the prediction of durations (i.e., teacher forcing).
noise_scale (float): Noise scale parameter for flow.
noise_scale_dur (float): Noise scale parameter for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length of acoustic feature sequence.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Tensor: Generated waveform tensor (B, T_wav).
Tensor: Monotonic attention weight tensor (B, T_feats, T_text).
Tensor: Duration tensor (B, T_text).
"""
# encoder
x, m_p, logs_p, x_mask = self.text_encoder(text, text_lengths)
g = None
if self.spks is not None:
# (B, global_channels, 1)
g = self.global_emb(paddle.reshape(sids, [-1])).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
if g is None:
g = g_
else:
g = g + g_
if use_teacher_forcing:
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(
feats, feats_lengths, g=g)
# forward flow
# (B, H, T_feats)
z_p = self.flow(z, y_mask, g=g)
# monotonic alignment search
# (B, H, T_text)
s_p_sq_r = paddle.exp(-2 * logs_p)
# (B, 1, T_text)
neg_x_ent_1 = paddle.sum(
-0.5 * math.log(2 * math.pi) - logs_p,
[1],
keepdim=True, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_2 = paddle.matmul(
-0.5 * (z_p**2).transpose([0, 2, 1]),
s_p_sq_r, )
# (B, T_feats, H) x (B, H, T_text) = (B, T_feats, T_text)
neg_x_ent_3 = paddle.matmul(
z_p.transpose([0, 2, 1]),
(m_p * s_p_sq_r), )
# (B, 1, T_text)
neg_x_ent_4 = paddle.sum(
-0.5 * (m_p**2) * s_p_sq_r,
[1],
keepdim=True, )
# (B, T_feats, T_text)
neg_x_ent = neg_x_ent_1 + neg_x_ent_2 + neg_x_ent_3 + neg_x_ent_4
# (B, 1, T_feats, T_text)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
# monotonic attention weight: (B, 1, T_feats, T_text)
attn = self.maximum_path(
neg_x_ent,
attn_mask.squeeze(1), ).unsqueeze(1)
# (B, 1, T_text)
dur = attn.sum(2)
# forward decoder with random segments
wav = self.decoder(z * y_mask, g=g)
else:
# duration
if dur is None:
logw = self.duration_predictor(
x,
x_mask,
g=g,
inverse=True,
noise_scale=noise_scale_dur, )
w = paddle.exp(logw) * x_mask * alpha
dur = paddle.ceil(w)
y_lengths = paddle.cast(
paddle.clip(paddle.sum(dur, [1, 2]), min=1), dtype='int64')
y_mask = make_non_pad_mask(y_lengths).unsqueeze(1)
attn_mask = paddle.unsqueeze(x_mask, 2) * paddle.unsqueeze(y_mask,
-1)
attn = self._generate_path(dur, attn_mask)
# expand the length to match with the feature sequence
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
m_p = paddle.matmul(
attn.squeeze(1),
m_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
# (B, T_feats, T_text) x (B, T_text, H) -> (B, H, T_feats)
logs_p = paddle.matmul(
attn.squeeze(1),
logs_p.transpose([0, 2, 1]), ).transpose([0, 2, 1])
# decoder
z_p = m_p + paddle.randn(
paddle.shape(m_p)) * paddle.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=g, inverse=True)
wav = self.decoder((z * y_mask)[:, :, :max_len], g=g)
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def _generate_path(self, dur: paddle.Tensor,
mask: paddle.Tensor) -> paddle.Tensor:
"""Generate path a.k.a. monotonic attention.
Args:
dur (Tensor): Duration tensor (B, 1, T_text).
mask (Tensor): Attention mask tensor (B, 1, T_feats, T_text).
Returns:
Tensor: Path tensor (B, 1, T_feats, T_text).
"""
b, _, t_y, t_x = paddle.shape(mask)
cum_dur = paddle.cumsum(dur, -1)
cum_dur_flat = paddle.reshape(cum_dur, [b * t_x])
path = paddle.arange(t_y, dtype=dur.dtype)
path = path.unsqueeze(0) < cum_dur_flat.unsqueeze(1)
path = paddle.reshape(path, [b, t_x, t_y])
'''
path will be like (t_x = 3, t_y = 5):
[[[1., 1., 0., 0., 0.], [[[1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0.], --> [0., 0., 1., 1., 0.],
[1., 1., 1., 1., 1.]]] [0., 0., 0., 0., 1.]]]
'''
path = paddle.cast(path, dtype='float32')
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask

@ -0,0 +1,94 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum path calculation module.
This code is based on https://github.com/jaywalnut310/vits.
"""
import warnings
import numpy as np
import paddle
from numba import njit
from numba import prange
try:
from .core import maximum_path_c
is_cython_avalable = True
except ImportError:
is_cython_avalable = False
warnings.warn(
"Cython version is not available. Fallback to 'EXPERIMETAL' numba version. "
"If you want to use the cython version, please build it as follows: "
"`cd paddlespeech/t2s/models/vits/monotonic_align; python setup.py build_ext --inplace`"
)
def maximum_path(neg_x_ent: paddle.Tensor,
attn_mask: paddle.Tensor) -> paddle.Tensor:
"""Calculate maximum path.
Args:
neg_x_ent (Tensor): Negative X entropy tensor (B, T_feats, T_text).
attn_mask (Tensor): Attention mask (B, T_feats, T_text).
Returns:
Tensor: Maximum path tensor (B, T_feats, T_text).
"""
dtype = neg_x_ent.dtype
neg_x_ent = neg_x_ent.numpy().astype(np.float32)
path = np.zeros(neg_x_ent.shape, dtype=np.int32)
t_t_max = attn_mask.sum(1)[:, 0].cpu().numpy().astype(np.int32)
t_s_max = attn_mask.sum(2)[:, 0].cpu().numpy().astype(np.int32)
if is_cython_avalable:
maximum_path_c(path, neg_x_ent, t_t_max, t_s_max)
else:
maximum_path_numba(path, neg_x_ent, t_t_max, t_s_max)
return paddle.cast(paddle.to_tensor(path), dtype=dtype)
@njit
def maximum_path_each_numba(path, value, t_y, t_x, max_neg_val=-np.inf):
"""Calculate a single maximum path with numba."""
index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or
value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@njit(parallel=True)
def maximum_path_numba(paths, values, t_ys, t_xs):
"""Calculate batch maximum path with numba."""
for i in prange(paths.shape[0]):
maximum_path_each_numba(paths[i], values[i], t_ys[i], t_xs[i])

@ -0,0 +1,62 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Maximum path calculation module with cython optimization.
This code is copied from https://github.com/jaywalnut310/vits and modifed code format.
"""
cimport cython
from cython.parallel import prange
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:, ::1] path, float[:, ::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
cdef int x
cdef int y
cdef float v_prev
cdef float v_cur
cdef float tmp
cdef int index = t_x - 1
for y in range(t_y):
for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
if x == y:
v_cur = max_neg_val
else:
v_cur = value[y - 1, x]
if x == 0:
if y == 0:
v_prev = 0.0
else:
v_prev = max_neg_val
else:
v_prev = value[y - 1, x - 1]
value[y, x] += max(v_prev, v_cur)
for y in range(t_y - 1, -1, -1):
path[y, index] = 1
if index != 0 and (index == y or value[y - 1, index] < value[y - 1, index - 1]):
index = index - 1
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:, :, ::1] paths, float[:, :, ::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
cdef int b = paths.shape[0]
cdef int i
for i in prange(b, nogil=True):
maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])

@ -0,0 +1,39 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup cython code."""
from Cython.Build import cythonize
from setuptools import Extension
from setuptools import setup
from setuptools.command.build_ext import build_ext as _build_ext
class build_ext(_build_ext):
"""Overwrite build_ext."""
def finalize_options(self):
"""Prevent numpy from thinking it is still in its setup process."""
_build_ext.finalize_options(self)
__builtins__.__NUMPY_SETUP__ = False
import numpy
self.include_dirs.append(numpy.get_include())
exts = [Extension(
name="core",
sources=["core.pyx"], )]
setup(
name="monotonic_align",
ext_modules=cythonize(exts, language_level=3),
cmdclass={"build_ext": build_ext}, )

@ -0,0 +1,120 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional
from typing import Tuple
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
class PosteriorEncoder(nn.Layer):
"""Posterior encoder module in VITS.
This is a module of posterior encoder described in `Conditional Variational
Autoencoder with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int=513,
out_channels: int=192,
hidden_channels: int=192,
kernel_size: int=5,
layers: int=16,
stacks: int=1,
base_dilation: int=1,
global_channels: int=-1,
dropout_rate: float=0.0,
bias: bool=True,
use_weight_norm: bool=True, ):
"""Initilialize PosteriorEncoder module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size in WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of repeat stacking of WaveNet.
base_dilation (int): Base dilation factor.
global_channels (int): Number of global conditioning channels.
dropout_rate (float): Dropout rate.
bias (bool): Whether to use bias parameters in conv.
use_weight_norm (bool): Whether to apply weight norm.
"""
super().__init__()
# define modules
self.input_conv = nn.Conv1D(in_channels, hidden_channels, 1)
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True, )
self.proj = nn.Conv1D(hidden_channels, out_channels * 2, 1)
def forward(
self,
x: paddle.Tensor,
x_lengths: paddle.Tensor,
g: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_feats).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Encoded hidden representation tensor (B, out_channels, T_feats).
Tensor: Projected mean tensor (B, out_channels, T_feats).
Tensor: Projected scale tensor (B, out_channels, T_feats).
Tensor: Mask tensor for input tensor (B, 1, T_feats).
"""
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
x = self.input_conv(x) * x_mask
x = self.encoder(x, x_mask, g=g)
stats = self.proj(x) * x_mask
m, logs = paddle.split(stats, 2, axis=1)
z = (m + paddle.randn(paddle.shape(m)) * paddle.exp(logs)) * x_mask
return z, m, logs, x_mask

@ -0,0 +1,244 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Residual affine coupling modules in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
from typing import Optional
from typing import Tuple
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.flow import FlipFlow
from paddlespeech.t2s.models.vits.wavenet.wavenet import WaveNet
class ResidualAffineCouplingBlock(nn.Layer):
"""Residual affine coupling block module.
This is a module of residual affine coupling block, which used as "Flow" in
`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
in_channels: int=192,
hidden_channels: int=192,
flows: int=4,
kernel_size: int=5,
base_dilation: int=1,
layers: int=4,
global_channels: int=-1,
dropout_rate: float=0.0,
use_weight_norm: bool=True,
bias: bool=True,
use_only_mean: bool=True, ):
"""Initilize ResidualAffineCouplingBlock module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
flows (int): Number of flows.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias paramters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
super().__init__()
self.flows = nn.LayerList()
for i in range(flows):
self.flows.append(
ResidualAffineCouplingLayer(
in_channels=in_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
base_dilation=base_dilation,
layers=layers,
stacks=1,
global_channels=global_channels,
dropout_rate=dropout_rate,
use_weight_norm=use_weight_norm,
bias=bias,
use_only_mean=use_only_mean, ))
self.flows.append(FlipFlow())
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_mask (Tensor): Length tensor (B, 1, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
"""
if not inverse:
for flow in self.flows:
x, _ = flow(x, x_mask, g=g, inverse=inverse)
else:
for flow in reversed(self.flows):
x = flow(x, x_mask, g=g, inverse=inverse)
return x
class ResidualAffineCouplingLayer(nn.Layer):
"""Residual affine coupling layer."""
def __init__(
self,
in_channels: int=192,
hidden_channels: int=192,
kernel_size: int=5,
base_dilation: int=1,
layers: int=5,
stacks: int=1,
global_channels: int=-1,
dropout_rate: float=0.0,
use_weight_norm: bool=True,
bias: bool=True,
use_only_mean: bool=True, ):
"""Initialzie ResidualAffineCouplingLayer module.
Args:
in_channels (int): Number of input channels.
hidden_channels (int): Number of hidden channels.
kernel_size (int): Kernel size for WaveNet.
base_dilation (int): Base dilation factor for WaveNet.
layers (int): Number of layers of WaveNet.
stacks (int): Number of stacks of WaveNet.
global_channels (int): Number of global channels.
dropout_rate (float): Dropout rate.
use_weight_norm (bool): Whether to use weight normalization in WaveNet.
bias (bool): Whether to use bias paramters in WaveNet.
use_only_mean (bool): Whether to estimate only mean.
"""
assert in_channels % 2 == 0, "in_channels should be divisible by 2"
super().__init__()
self.half_channels = in_channels // 2
self.use_only_mean = use_only_mean
# define modules
self.input_conv = nn.Conv1D(
self.half_channels,
hidden_channels,
1, )
self.encoder = WaveNet(
in_channels=-1,
out_channels=-1,
kernel_size=kernel_size,
layers=layers,
stacks=stacks,
base_dilation=base_dilation,
residual_channels=hidden_channels,
aux_channels=-1,
gate_channels=hidden_channels * 2,
skip_channels=hidden_channels,
global_channels=global_channels,
dropout_rate=dropout_rate,
bias=bias,
use_weight_norm=use_weight_norm,
use_first_conv=False,
use_last_conv=False,
scale_residual=False,
scale_skip_connect=True, )
if use_only_mean:
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels,
1, )
else:
self.proj = nn.Conv1D(
hidden_channels,
self.half_channels * 2,
1, )
# self.proj.weight.data.zero_()
# self.proj.bias.data.zero_()
weight = paddle.zeros(paddle.shape(self.proj.weight))
self.proj.weight = paddle.create_parameter(
shape=weight.shape,
dtype=str(weight.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(weight))
bias = paddle.zeros(paddle.shape(self.proj.bias))
self.proj.bias = paddle.create_parameter(
shape=bias.shape,
dtype=str(bias.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(bias))
def forward(
self,
x: paddle.Tensor,
x_mask: paddle.Tensor,
g: Optional[paddle.Tensor]=None,
inverse: bool=False,
) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T).
x_lengths (Tensor): Length tensor (B,).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
inverse (bool): Whether to inverse the flow.
Returns:
Tensor: Output tensor (B, in_channels, T).
Tensor: Log-determinant tensor for NLL (B,) if not inverse.
"""
xa, xb = paddle.split(x, 2, axis=1)
h = self.input_conv(xa) * x_mask
h = self.encoder(h, x_mask, g=g)
stats = self.proj(h) * x_mask
if not self.use_only_mean:
m, logs = paddle.split(stats, 2, axis=1)
else:
m = stats
logs = paddle.zeros(paddle.shape(m))
if not inverse:
xb = m + xb * paddle.exp(logs) * x_mask
x = paddle.concat([xa, xb], 1)
logdet = paddle.sum(logs, [1, 2])
return x, logdet
else:
xb = (xb - m) * paddle.exp(-logs) * x_mask
x = paddle.concat([xa, xb], 1)
return x

@ -0,0 +1,145 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text encoder module in VITS.
This code is based on https://github.com/jaywalnut310/vits.
"""
import math
from typing import Tuple
import paddle
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.transformer.encoder import ConformerEncoder as Encoder
class TextEncoder(nn.Layer):
"""Text encoder module in VITS.
This is a module of text encoder described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
Instead of the relative positional Transformer, we use conformer architecture as
the encoder module, which contains additional convolution layers.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
vocabs: int,
attention_dim: int=192,
attention_heads: int=2,
linear_units: int=768,
blocks: int=6,
positionwise_layer_type: str="conv1d",
positionwise_conv_kernel_size: int=3,
positional_encoding_layer_type: str="rel_pos",
self_attention_layer_type: str="rel_selfattn",
activation_type: str="swish",
normalize_before: bool=True,
use_macaron_style: bool=False,
use_conformer_conv: bool=False,
conformer_kernel_size: int=7,
dropout_rate: float=0.1,
positional_dropout_rate: float=0.0,
attention_dropout_rate: float=0.0, ):
"""Initialize TextEncoder module.
Args:
vocabs (int): Vocabulary size.
attention_dim (int): Attention dimension.
attention_heads (int): Number of attention heads.
linear_units (int): Number of linear units of positionwise layers.
blocks (int): Number of encoder blocks.
positionwise_layer_type (str): Positionwise layer type.
positionwise_conv_kernel_size (int): Positionwise layer's kernel size.
positional_encoding_layer_type (str): Positional encoding layer type.
self_attention_layer_type (str): Self-attention layer type.
activation_type (str): Activation function type.
normalize_before (bool): Whether to apply LayerNorm before attention.
use_macaron_style (bool): Whether to use macaron style components.
use_conformer_conv (bool): Whether to use conformer conv layers.
conformer_kernel_size (int): Conformer's conv kernel size.
dropout_rate (float): Dropout rate.
positional_dropout_rate (float): Dropout rate for positional encoding.
attention_dropout_rate (float): Dropout rate for attention.
"""
super().__init__()
# store for forward
self.attention_dim = attention_dim
# define modules
self.emb = nn.Embedding(vocabs, attention_dim)
dist = paddle.distribution.Normal(loc=0.0, scale=attention_dim**-0.5)
w = dist.sample(self.emb.weight.shape)
self.emb.weight.set_value(w)
self.encoder = Encoder(
idim=-1,
input_layer=None,
attention_dim=attention_dim,
attention_heads=attention_heads,
linear_units=linear_units,
num_blocks=blocks,
dropout_rate=dropout_rate,
positional_dropout_rate=positional_dropout_rate,
attention_dropout_rate=attention_dropout_rate,
normalize_before=normalize_before,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
macaron_style=use_macaron_style,
pos_enc_layer_type=positional_encoding_layer_type,
selfattention_layer_type=self_attention_layer_type,
activation_type=activation_type,
use_cnn_module=use_conformer_conv,
cnn_module_kernel=conformer_kernel_size, )
self.proj = nn.Conv1D(attention_dim, attention_dim * 2, 1)
def forward(
self,
x: paddle.Tensor,
x_lengths: paddle.Tensor,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input index tensor (B, T_text).
x_lengths (Tensor): Length tensor (B,).
Returns:
Tensor: Encoded hidden representation (B, attention_dim, T_text).
Tensor: Projected mean tensor (B, attention_dim, T_text).
Tensor: Projected scale tensor (B, attention_dim, T_text).
Tensor: Mask tensor for input tensor (B, 1, T_text).
"""
x = self.emb(x) * math.sqrt(self.attention_dim)
x_mask = make_non_pad_mask(x_lengths).unsqueeze(1)
# encoder assume the channel last (B, T_text, attention_dim)
# but mask shape shoud be (B, 1, T_text)
x, _ = self.encoder(x, x_mask)
# convert the channel first (B, attention_dim, T_text)
x = paddle.transpose(x, [0, 2, 1])
stats = self.proj(x) * x_mask
m, logs = paddle.split(stats, 2, axis=1)
return x, m, logs, x_mask

@ -0,0 +1,238 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flow-related transformation.
This code is based on https://github.com/bayesiains/nflows.
"""
import numpy as np
import paddle
from paddle.nn import functional as F
from paddlespeech.t2s.modules.nets_utils import paddle_gather
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
def piecewise_rational_quadratic_transform(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails=None,
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
if tails is None:
spline_fn = rational_quadratic_spline
spline_kwargs = {}
else:
spline_fn = unconstrained_rational_quadratic_spline
spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
outputs, logabsdet = spline_fn(
inputs=inputs,
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs)
return outputs, logabsdet
def mask_preprocess(x, mask):
B, C, T, bins = paddle.shape(x)
new_x = paddle.zeros([mask.sum(), bins])
for i in range(bins):
new_x[:, i] = x[:, :, :, i][mask]
return new_x
def unconstrained_rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
tails="linear",
tail_bound=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = paddle.zeros(paddle.shape(inputs))
logabsdet = paddle.zeros(paddle.shape(inputs))
if tails == "linear":
unnormalized_derivatives = F.pad(
unnormalized_derivatives,
pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
outputs[outside_interval_mask] = inputs[outside_interval_mask]
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
unnormalized_widths = mask_preprocess(unnormalized_widths,
inside_interval_mask)
unnormalized_heights = mask_preprocess(unnormalized_heights,
inside_interval_mask)
unnormalized_derivatives = mask_preprocess(unnormalized_derivatives,
inside_interval_mask)
(outputs[inside_interval_mask],
logabsdet[inside_interval_mask], ) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths,
unnormalized_heights=unnormalized_heights,
unnormalized_derivatives=unnormalized_derivatives,
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative, )
return outputs, logabsdet
def rational_quadratic_spline(
inputs,
unnormalized_widths,
unnormalized_heights,
unnormalized_derivatives,
inverse=False,
left=0.0,
right=1.0,
bottom=0.0,
top=1.0,
min_bin_width=DEFAULT_MIN_BIN_WIDTH,
min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
min_derivative=DEFAULT_MIN_DERIVATIVE, ):
if paddle.min(inputs) < left or paddle.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = paddle.cumsum(widths, axis=-1)
cumwidths = F.pad(
cumwidths,
pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
widths = cumwidths[..., 1:] - cumwidths[..., :-1]
derivatives = min_derivative + F.softplus(unnormalized_derivatives)
heights = F.softmax(unnormalized_heights, axis=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = paddle.cumsum(heights, axis=-1)
cumheights = F.pad(
cumheights,
pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top
heights = cumheights[..., 1:] - cumheights[..., :-1]
if inverse:
bin_idx = _searchsorted(cumheights, inputs)[..., None]
else:
bin_idx = _searchsorted(cumwidths, inputs)[..., None]
input_cumwidths = paddle_gather(cumwidths, -1, bin_idx)[..., 0]
input_bin_widths = paddle_gather(widths, -1, bin_idx)[..., 0]
input_cumheights = paddle_gather(cumheights, -1, bin_idx)[..., 0]
delta = heights / widths
input_delta = paddle_gather(delta, -1, bin_idx)[..., 0]
input_derivatives = paddle_gather(derivatives, -1, bin_idx)[..., 0]
input_derivatives_plus_one = paddle_gather(derivatives[..., 1:], -1,
bin_idx)[..., 0]
input_heights = paddle_gather(heights, -1, bin_idx)[..., 0]
if inverse:
a = (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta
) + input_heights * (input_delta - input_derivatives)
b = input_heights * input_derivatives - (inputs - input_cumheights) * (
input_derivatives + input_derivatives_plus_one - 2 * input_delta)
c = -input_delta * (inputs - input_cumheights)
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
root = (2 * c) / (-b - paddle.sqrt(discriminant))
outputs = root * input_bin_widths + input_cumwidths
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
) * theta_one_minus_theta)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2) + 2 * input_delta *
theta_one_minus_theta + input_derivatives * (1 - root).pow(2))
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
denominator)
return outputs, -logabsdet
else:
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (input_delta * theta.pow(2) +
input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta
) * theta_one_minus_theta)
outputs = input_cumheights + numerator / denominator
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * theta.pow(2) + 2 * input_delta *
theta_one_minus_theta + input_derivatives * (1 - theta).pow(2))
logabsdet = paddle.log(derivative_numerator) - 2 * paddle.log(
denominator)
return outputs, logabsdet
def _searchsorted(bin_locations, inputs, eps=1e-6):
bin_locations[..., -1] += eps
return paddle.sum(inputs[..., None] >= bin_locations, axis=-1) - 1

@ -0,0 +1,573 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""VITS module"""
from typing import Any
from typing import Dict
from typing import Optional
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
from paddlespeech.t2s.models.vits.generator import VITSGenerator
from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
from paddlespeech.t2s.modules.losses import FeatureMatchLoss
from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
from paddlespeech.t2s.modules.losses import KLDivergenceLoss
from paddlespeech.t2s.modules.losses import MelSpectrogramLoss
from paddlespeech.t2s.modules.nets_utils import get_segments
AVAILABLE_GENERATERS = {
"vits_generator": VITSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
"hifigan_period_discriminator":
HiFiGANPeriodDiscriminator,
"hifigan_scale_discriminator":
HiFiGANScaleDiscriminator,
"hifigan_multi_period_discriminator":
HiFiGANMultiPeriodDiscriminator,
"hifigan_multi_scale_discriminator":
HiFiGANMultiScaleDiscriminator,
"hifigan_multi_scale_multi_period_discriminator":
HiFiGANMultiScaleMultiPeriodDiscriminator,
}
class VITS(nn.Layer):
"""VITS module (generator + discriminator).
This is a module of VITS described in `Conditional Variational Autoencoder
with Adversarial Learning for End-to-End Text-to-Speech`_.
.. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
Text-to-Speech`: https://arxiv.org/abs/2006.04558
"""
def __init__(
self,
# generator related
idim: int,
odim: int,
sampling_rate: int=22050,
generator_type: str="vits_generator",
generator_params: Dict[str, Any]={
"hidden_channels": 192,
"spks": None,
"langs": None,
"spk_embed_dim": None,
"global_channels": -1,
"segment_size": 32,
"text_encoder_attention_heads": 2,
"text_encoder_ffn_expand": 4,
"text_encoder_blocks": 6,
"text_encoder_positionwise_layer_type": "conv1d",
"text_encoder_positionwise_conv_kernel_size": 1,
"text_encoder_positional_encoding_layer_type": "rel_pos",
"text_encoder_self_attention_layer_type": "rel_selfattn",
"text_encoder_activation_type": "swish",
"text_encoder_normalize_before": True,
"text_encoder_dropout_rate": 0.1,
"text_encoder_positional_dropout_rate": 0.0,
"text_encoder_attention_dropout_rate": 0.0,
"text_encoder_conformer_kernel_size": 7,
"use_macaron_style_in_text_encoder": True,
"use_conformer_conv_in_text_encoder": True,
"decoder_kernel_size": 7,
"decoder_channels": 512,
"decoder_upsample_scales": [8, 8, 2, 2],
"decoder_upsample_kernel_sizes": [16, 16, 4, 4],
"decoder_resblock_kernel_sizes": [3, 7, 11],
"decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"use_weight_norm_in_decoder": True,
"posterior_encoder_kernel_size": 5,
"posterior_encoder_layers": 16,
"posterior_encoder_stacks": 1,
"posterior_encoder_base_dilation": 1,
"posterior_encoder_dropout_rate": 0.0,
"use_weight_norm_in_posterior_encoder": True,
"flow_flows": 4,
"flow_kernel_size": 5,
"flow_base_dilation": 1,
"flow_layers": 4,
"flow_dropout_rate": 0.0,
"use_weight_norm_in_flow": True,
"use_only_mean_in_flow": True,
"stochastic_duration_predictor_kernel_size": 3,
"stochastic_duration_predictor_dropout_rate": 0.5,
"stochastic_duration_predictor_flows": 4,
"stochastic_duration_predictor_dds_conv_layers": 3,
},
# discriminator related
discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
discriminator_params: Dict[str, Any]={
"scales": 1,
"scale_downsample_pooling": "AvgPool1D",
"scale_downsample_pooling_params": {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
"scale_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.1
},
"use_weight_norm": True,
"use_spectral_norm": False,
},
"follow_official_norm": False,
"periods": [2, 3, 5, 7, 11],
"period_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.1
},
"use_weight_norm": True,
"use_spectral_norm": False,
},
},
# loss related
generator_adv_loss_params: Dict[str, Any]={
"average_by_discriminators": False,
"loss_type": "mse",
},
discriminator_adv_loss_params: Dict[str, Any]={
"average_by_discriminators": False,
"loss_type": "mse",
},
feat_match_loss_params: Dict[str, Any]={
"average_by_discriminators": False,
"average_by_layers": False,
"include_final_outputs": True,
},
mel_loss_params: Dict[str, Any]={
"fs": 22050,
"fft_size": 1024,
"hop_size": 256,
"win_length": None,
"window": "hann",
"num_mels": 80,
"fmin": 0,
"fmax": None,
"log_base": None,
},
lambda_adv: float=1.0,
lambda_mel: float=45.0,
lambda_feat_match: float=2.0,
lambda_dur: float=1.0,
lambda_kl: float=1.0,
cache_generator_outputs: bool=True, ):
"""Initialize VITS module.
Args:
idim (int): Input vocabrary size.
odim (int): Acoustic feature dimension. The actual output channels will
be 1 since VITS is the end-to-end text-to-wave model but for the
compatibility odim is used to indicate the acoustic feature dimension.
sampling_rate (int): Sampling rate, not used for the training but it will
be referred in saving waveform during the inference.
generator_type (str): Generator type.
generator_params (Dict[str, Any]): Parameter dict for generator.
discriminator_type (str): Discriminator type.
discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
adversarial loss.
discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
discriminator adversarial loss.
feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
lambda_adv (float): Loss scaling coefficient for adversarial loss.
lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
lambda_feat_match (float): Loss scaling coefficient for feat match loss.
lambda_dur (float): Loss scaling coefficient for duration loss.
lambda_kl (float): Loss scaling coefficient for KL divergence loss.
cache_generator_outputs (bool): Whether to cache generator outputs.
"""
assert check_argument_types()
super().__init__()
# define modules
generator_class = AVAILABLE_GENERATERS[generator_type]
if generator_type == "vits_generator":
# NOTE: Update parameters for the compatibility.
# The idim and odim is automatically decided from input data,
# where idim represents #vocabularies and odim represents
# the input acoustic feature dimension.
generator_params.update(vocabs=idim, aux_channels=odim)
self.generator = generator_class(
**generator_params, )
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
self.discriminator = discriminator_class(
**discriminator_params, )
self.generator_adv_loss = GeneratorAdversarialLoss(
**generator_adv_loss_params, )
self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
**discriminator_adv_loss_params, )
self.feat_match_loss = FeatureMatchLoss(
**feat_match_loss_params, )
self.mel_loss = MelSpectrogramLoss(
**mel_loss_params, )
self.kl_loss = KLDivergenceLoss()
# coefficients
self.lambda_adv = lambda_adv
self.lambda_mel = lambda_mel
self.lambda_kl = lambda_kl
self.lambda_feat_match = lambda_feat_match
self.lambda_dur = lambda_dur
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# store sampling rate for saving wav file
# (not used for the training)
self.fs = sampling_rate
# store parameters for test compatibility
self.spks = self.generator.spks
self.langs = self.generator.langs
self.spk_embed_dim = self.generator.spk_embed_dim
@property
def require_raw_speech(self):
"""Return whether or not speech is required."""
return True
@property
def require_vocoder(self):
"""Return whether or not vocoder is required."""
return False
def forward(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
forward_generator: bool=True, ) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
forward_generator (bool): Whether to forward generator.
Returns:
Dict[str, Any]:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
"""
if forward_generator:
return self._forward_generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
else:
return self._forward_discrminator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
def _forward_generator(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Dict[str, Any]:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
* weight (Tensor): Weight tensor to summarize losses.
* optim_idx (int): Optimizer index (0 for G and 1 for D).
"""
# setup
batch_size = paddle.shape(text)[0]
feats = feats.transpose([0, 2, 1])
# speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
else:
outs = self._cache
# store cache
if self.training and self.cache_generator_outputs and not reuse_cache:
self._cache = outs
return outs
"""
# parse outputs
speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
_, z_p, m_p, logs_p, _, logs_q = outs_
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size *
self.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_)
with paddle.no_grad():
# do not store discriminator gradient in generator turn
p = self.discriminator(speech_)
# calculate losses
mel_loss = self.mel_loss(speech_hat_, speech_)
kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
dur_loss = paddle.sum(dur_nll.float())
adv_loss = self.generator_adv_loss(p_hat)
feat_match_loss = self.feat_match_loss(p_hat, p)
mel_loss = mel_loss * self.lambda_mel
kl_loss = kl_loss * self.lambda_kl
dur_loss = dur_loss * self.lambda_dur
adv_loss = adv_loss * self.lambda_adv
feat_match_loss = feat_match_loss * self.lambda_feat_match
loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
stats = dict(
generator_loss=loss.item(),
generator_mel_loss=mel_loss.item(),
generator_kl_loss=kl_loss.item(),
generator_dur_loss=dur_loss.item(),
generator_adv_loss=adv_loss.item(),
generator_feat_match_loss=feat_match_loss.item(), )
# reset cache
if reuse_cache or not self.training:
self._cache = None
return {
"loss": loss,
"stats": stats,
# "weight": weight,
"optim_idx": 0, # needed for trainer
}
"""
def _forward_discrminator(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> Dict[str, Any]:
"""Perform discriminator forward.
Args:
text (Tensor): Text index tensor (B, T_text).
text_lengths (Tensor): Text length tensor (B,).
feats (Tensor): Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor): Feature length tensor (B,).
sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Dict[str, Any]:
* loss (Tensor): Loss scalar tensor.
* stats (Dict[str, float]): Statistics to be monitored.
* weight (Tensor): Weight tensor to summarize losses.
* optim_idx (int): Optimizer index (0 for G and 1 for D).
"""
# setup
batch_size = paddle.shape(text)[0]
feats = feats.transpose([0, 2, 1])
# speech = speech.unsqueeze(1)
# calculate generator outputs
reuse_cache = True
if not self.cache_generator_outputs or self._cache is None:
reuse_cache = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids, )
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not reuse_cache:
self._cache = outs
return outs
"""
# parse outputs
speech_hat_, _, _, start_idxs, *_ = outs
speech_ = get_segments(
x=speech,
start_idxs=start_idxs * self.generator.upsample_factor,
segment_size=self.generator.segment_size *
self.generator.upsample_factor, )
# calculate discriminator outputs
p_hat = self.discriminator(speech_hat_.detach())
p = self.discriminator(speech_)
# calculate losses
real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p)
loss = real_loss + fake_loss
stats = dict(
discriminator_loss=loss.item(),
discriminator_real_loss=real_loss.item(),
discriminator_fake_loss=fake_loss.item(), )
# reset cache
if reuse_cache or not self.training:
self._cache = None
return {
"loss": loss,
"stats": stats,
# "weight": weight,
"optim_idx": 1, # needed for trainer
}
"""
def inference(
self,
text: paddle.Tensor,
feats: Optional[paddle.Tensor]=None,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
durations: Optional[paddle.Tensor]=None,
noise_scale: float=0.667,
noise_scale_dur: float=0.8,
alpha: float=1.0,
max_len: Optional[int]=None,
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
"""Run inference.
Args:
text (Tensor): Input text index tensor (T_text,).
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids (Tensor): Speaker index tensor (1,).
spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
lids (Tensor): Language index tensor (1,).
durations (Tensor): Ground-truth duration tensor (T_text,).
noise_scale (float): Noise scale value for flow.
noise_scale_dur (float): Noise scale value for duration predictor.
alpha (float): Alpha parameter to control the speed of generated speech.
max_len (Optional[int]): Maximum length.
use_teacher_forcing (bool): Whether to use teacher forcing.
Returns:
Dict[str, Tensor]:
* wav (Tensor): Generated waveform tensor (T_wav,).
* att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
* duration (Tensor): Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
text_lengths = paddle.to_tensor(paddle.shape(text)[1])
# if sids is not None:
# sids = sids.view(1)
# if lids is not None:
# lids = lids.view(1)
if durations is not None:
durations = paddle.reshape(durations, [1, 1, -1])
# inference
if use_teacher_forcing:
assert feats is not None
feats = feats[None].transpose([0, 2, 1])
feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
sids=sids,
spembs=spembs,
lids=lids,
max_len=max_len,
use_teacher_forcing=use_teacher_forcing, )
else:
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
sids=sids,
spembs=spembs,
lids=lids,
dur=durations,
noise_scale=noise_scale,
noise_scale_dur=noise_scale_dur,
alpha=alpha,
max_len=max_len, )
return dict(
wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])

@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,154 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import math
from typing import Optional
from typing import Tuple
import paddle
import paddle.nn.functional as F
from paddle import nn
class ResidualBlock(nn.Layer):
"""Residual block module in WaveNet."""
def __init__(
self,
kernel_size: int=3,
residual_channels: int=64,
gate_channels: int=128,
skip_channels: int=64,
aux_channels: int=80,
global_channels: int=-1,
dropout_rate: float=0.0,
dilation: int=1,
bias: bool=True,
scale_residual: bool=False, ):
"""Initialize ResidualBlock module.
Args:
kernel_size (int): Kernel size of dilation convolution layer.
residual_channels (int): Number of channels for residual connection.
skip_channels (int): Number of channels for skip connection.
aux_channels (int): Number of local conditioning channels.
dropout (float): Dropout probability.
dilation (int): Dilation factor.
bias (bool): Whether to add bias parameter in convolution layers.
scale_residual (bool): Whether to scale the residual outputs.
"""
super().__init__()
self.dropout_rate = dropout_rate
self.residual_channels = residual_channels
self.skip_channels = skip_channels
self.scale_residual = scale_residual
# check
assert (
kernel_size - 1) % 2 == 0, "Not support even number kernel size."
assert gate_channels % 2 == 0
# dilation conv
padding = (kernel_size - 1) // 2 * dilation
self.conv = nn.Conv1D(
residual_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias_attr=bias, )
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = nn.Conv1D(
aux_channels, gate_channels, kernel_size=1, bias_attr=False)
else:
self.conv1x1_aux = None
# global conditioning
if global_channels > 0:
self.conv1x1_glo = nn.Conv1D(
global_channels, gate_channels, kernel_size=1, bias_attr=False)
else:
self.conv1x1_glo = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
# NOTE: concat two convs into a single conv for the efficiency
# (integrate res 1x1 + skip 1x1 convs)
self.conv1x1_out = nn.Conv1D(
gate_out_channels,
residual_channels + skip_channels,
kernel_size=1,
bias_attr=bias)
def forward(
self,
x: paddle.Tensor,
x_mask: Optional[paddle.Tensor]=None,
c: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, residual_channels, T).
x_mask Optional[paddle.Tensor]: Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning tensor (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning tensor (B, global_channels, 1).
Returns:
Tensor: Output tensor for residual connection (B, residual_channels, T).
Tensor: Output tensor for skip connection (B, skip_channels, T).
"""
residual = x
x = F.dropout(x, p=self.dropout_rate, training=self.training)
x = self.conv(x)
# split into two part for gated activation
splitdim = 1
xa, xb = paddle.split(x, 2, axis=splitdim)
# local conditioning
if c is not None:
c = self.conv1x1_aux(c)
ca, cb = paddle.split(c, 2, axis=splitdim)
xa, xb = xa + ca, xb + cb
# global conditioning
if g is not None:
g = self.conv1x1_glo(g)
ga, gb = paddle.split(g, 2, axis=splitdim)
xa, xb = xa + ga, xb + gb
x = paddle.tanh(xa) * F.sigmoid(xb)
# residual + skip 1x1 conv
x = self.conv1x1_out(x)
if x_mask is not None:
x = x * x_mask
# split integrated conv results
x, s = paddle.split(
x, [self.residual_channels, self.skip_channels], axis=1)
# for residual connection
x = x + residual
if self.scale_residual:
x = x * math.sqrt(0.5)
return x, s

@ -0,0 +1,175 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
import math
from typing import Optional
import paddle
from paddle import nn
from paddlespeech.t2s.models.vits.wavenet.residual_block import ResidualBlock
class WaveNet(nn.Layer):
"""WaveNet with global conditioning."""
def __init__(
self,
in_channels: int=1,
out_channels: int=1,
kernel_size: int=3,
layers: int=30,
stacks: int=3,
base_dilation: int=2,
residual_channels: int=64,
aux_channels: int=-1,
gate_channels: int=128,
skip_channels: int=64,
global_channels: int=-1,
dropout_rate: float=0.0,
bias: bool=True,
use_weight_norm: bool=True,
use_first_conv: bool=False,
use_last_conv: bool=False,
scale_residual: bool=False,
scale_skip_connect: bool=False, ):
"""Initialize WaveNet module.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
kernel_size (int): Kernel size of dilated convolution.
layers (int): Number of residual block layers.
stacks (int): Number of stacks i.e., dilation cycles.
base_dilation (int): Base dilation factor.
residual_channels (int): Number of channels in residual conv.
gate_channels (int): Number of channels in gated conv.
skip_channels (int): Number of channels in skip conv.
aux_channels (int): Number of channels for local conditioning feature.
global_channels (int): Number of channels for global conditioning feature.
dropout_rate (float): Dropout rate. 0.0 means no dropout applied.
bias (bool): Whether to use bias parameter in conv layer.
use_weight_norm (bool): Whether to use weight norm. If set to true, it will
be applied to all of the conv layers.
use_first_conv (bool): Whether to use the first conv layers.
use_last_conv (bool): Whether to use the last conv layers.
scale_residual (bool): Whether to scale the residual outputs.
scale_skip_connect (bool): Whether to scale the skip connection outputs.
"""
super().__init__()
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
self.base_dilation = base_dilation
self.use_first_conv = use_first_conv
self.use_last_conv = use_last_conv
self.scale_skip_connect = scale_skip_connect
# check the number of layers and stacks
assert layers % stacks == 0
layers_per_stack = layers // stacks
# define first convolution
if self.use_first_conv:
self.first_conv = nn.Conv1D(
in_channels, residual_channels, kernel_size=1, bias_attr=True)
# define residual blocks
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = base_dilation**(layer % layers_per_stack)
conv = ResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
global_channels=global_channels,
dilation=dilation,
dropout_rate=dropout_rate,
bias=bias,
scale_residual=scale_residual, )
self.conv_layers.append(conv)
# define output layers
if self.use_last_conv:
self.last_conv = nn.Sequential(
nn.ReLU(),
nn.Conv1D(
skip_channels, skip_channels, kernel_size=1,
bias_attr=True),
nn.ReLU(),
nn.Conv1D(
skip_channels, out_channels, kernel_size=1, bias_attr=True),
)
# apply weight norm
if use_weight_norm:
self.apply_weight_norm()
def forward(
self,
x: paddle.Tensor,
x_mask: Optional[paddle.Tensor]=None,
c: Optional[paddle.Tensor]=None,
g: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
x (Tensor): Input noise signal (B, 1, T) if use_first_conv else
(B, residual_channels, T).
x_mask (Optional[Tensor]): Mask tensor (B, 1, T).
c (Optional[Tensor]): Local conditioning features (B, aux_channels, T).
g (Optional[Tensor]): Global conditioning features (B, global_channels, 1).
Returns:
Tensor: Output tensor (B, out_channels, T) if use_last_conv else
(B, residual_channels, T).
"""
# encode to hidden representation
if self.use_first_conv:
x = self.first_conv(x)
# residual block
skips = 0.0
for f in self.conv_layers:
x, h = f(x, x_mask=x_mask, c=c, g=g)
skips = skips + h
x = skips
if self.scale_skip_connect:
x = x * math.sqrt(1.0 / len(self.conv_layers))
# apply final layers
if self.use_last_conv:
x = self.last_conv(x)
return x
def apply_weight_norm(self):
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)

@ -1006,3 +1006,40 @@ class FeatureMatchLoss(nn.Layer):
feat_match_loss /= i + 1
return feat_match_loss
# loss for VITS
class KLDivergenceLoss(nn.Layer):
"""KL divergence loss."""
def forward(
self,
z_p: paddle.Tensor,
logs_q: paddle.Tensor,
m_p: paddle.Tensor,
logs_p: paddle.Tensor,
z_mask: paddle.Tensor,
) -> paddle.Tensor:
"""Calculate KL divergence loss.
Args:
z_p (Tensor): Flow hidden representation (B, H, T_feats).
logs_q (Tensor): Posterior encoder projected scale (B, H, T_feats).
m_p (Tensor): Expanded text encoder projected mean (B, H, T_feats).
logs_p (Tensor): Expanded text encoder projected scale (B, H, T_feats).
z_mask (Tensor): Mask tensor (B, 1, T_feats).
Returns:
Tensor: KL divergence loss.
"""
z_p = paddle.cast(z_p, 'float32')
logs_q = paddle.cast(logs_q, 'float32')
m_p = paddle.cast(m_p, 'float32')
logs_p = paddle.cast(logs_p, 'float32')
z_mask = paddle.cast(z_mask, 'float32')
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p) ** 2) * paddle.exp(-2.0 * logs_p)
kl = paddle.sum(kl * z_mask)
loss = kl / paddle.sum(z_mask)
return loss

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
@ -129,3 +131,66 @@ def initialize(model: nn.Layer, init: str):
nn.initializer.Constant())
else:
raise ValueError("Unknown initialization: " + init)
# for VITS
def get_random_segments(
x: paddle.paddle,
x_lengths: paddle.Tensor,
segment_size: int, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Get random segments.
Args:
x (Tensor): Input tensor (B, C, T).
x_lengths (Tensor): Length tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
Tensor: Start index tensor (B,).
"""
b, c, t = paddle.shape(x)
max_start_idx = x_lengths - segment_size
start_idxs = paddle.cast(paddle.rand([b]) * max_start_idx, 'int64')
segments = get_segments(x, start_idxs, segment_size)
return segments, start_idxs
def get_segments(
x: paddle.Tensor,
start_idxs: paddle.Tensor,
segment_size: int, ) -> paddle.Tensor:
"""Get segments.
Args:
x (Tensor): Input tensor (B, C, T).
start_idxs (Tensor): Start index tensor (B,).
segment_size (int): Segment size.
Returns:
Tensor: Segmented tensor (B, C, segment_size).
"""
b, c, t = paddle.shape(x)
segments = paddle.zeros([b, c, segment_size], dtype=x.dtype)
for i, start_idx in enumerate(start_idxs):
segments[i] = x[i, :, start_idx:start_idx + segment_size]
return segments
# see https://github.com/PaddlePaddle/X2Paddle/blob/develop/docs/pytorch_project_convertor/API_docs/ops/torch.gather.md
def paddle_gather(x, dim, index):
index_shape = index.shape
index_flatten = index.flatten()
if dim < 0:
dim = len(x.shape) + dim
nd_index = []
for k in range(len(x.shape)):
if k == dim:
nd_index.append(index_flatten)
else:
reshape_shape = [1] * len(x.shape)
reshape_shape[k] = x.shape[k]
x_arange = paddle.arange(x.shape[k], dtype=index.dtype)
x_arange = x_arange.reshape(reshape_shape)
dim_index = paddle.expand(x_arange, index_shape).flatten()
nd_index.append(dim_index)
ind2 = paddle.transpose(paddle.stack(nd_index), [1, 0]).astype("int64")
paddle_out = paddle.gather_nd(x, ind2).reshape(index_shape)
return paddle_out

Loading…
Cancel
Save