# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""VITS module"""
from typing import Any
|
|
from typing import Dict
|
|
from typing import Optional
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
from typeguard import check_argument_types
|
|
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
|
|
from paddlespeech.t2s.models.vits.generator import VITSGenerator
|
|
from paddlespeech.t2s.modules.nets_utils import initialize
|
|
|
|
# Registry mapping a `generator_type` string to its generator class.
AVAILABLE_GENERATORS = {
    "vits_generator": VITSGenerator,
}
# Backward-compatible alias for the historical (misspelled) registry name.
# Existing code, e.g. `VITS.__init__`, looks generators up through this name,
# so it must keep pointing at the same dict object.
AVAILABLE_GENERATERS = AVAILABLE_GENERATORS

# Registry mapping a `discriminator_type` string to its discriminator class.
AVAILABLE_DISCRIMINATORS = {
    "hifigan_period_discriminator":
    HiFiGANPeriodDiscriminator,
    "hifigan_scale_discriminator":
    HiFiGANScaleDiscriminator,
    "hifigan_multi_period_discriminator":
    HiFiGANMultiPeriodDiscriminator,
    "hifigan_multi_scale_discriminator":
    HiFiGANMultiScaleDiscriminator,
    "hifigan_multi_scale_multi_period_discriminator":
    HiFiGANMultiScaleMultiPeriodDiscriminator,
}


class VITS(nn.Layer):
    """VITS module (generator + discriminator).

    This is a module of VITS described in `Conditional Variational Autoencoder
    with Adversarial Learning for End-to-End Text-to-Speech`_.

    .. _`Conditional Variational Autoencoder with Adversarial Learning for End-to-End
        Text-to-Speech`: https://arxiv.org/abs/2006.04558

    """

    def __init__(
            self,
            # generator related
            idim: int,
            odim: int,
            sampling_rate: int=22050,
            generator_type: str="vits_generator",
            generator_params: Dict[str, Any]={
                "hidden_channels": 192,
                "spks": None,
                "langs": None,
                "spk_embed_dim": None,
                "global_channels": -1,
                "segment_size": 32,
                "text_encoder_attention_heads": 2,
                "text_encoder_ffn_expand": 4,
                "text_encoder_blocks": 6,
                "text_encoder_positionwise_layer_type": "conv1d",
                "text_encoder_positionwise_conv_kernel_size": 1,
                "text_encoder_positional_encoding_layer_type": "rel_pos",
                "text_encoder_self_attention_layer_type": "rel_selfattn",
                "text_encoder_activation_type": "swish",
                "text_encoder_normalize_before": True,
                "text_encoder_dropout_rate": 0.1,
                "text_encoder_positional_dropout_rate": 0.0,
                "text_encoder_attention_dropout_rate": 0.0,
                "text_encoder_conformer_kernel_size": 7,
                "use_macaron_style_in_text_encoder": True,
                "use_conformer_conv_in_text_encoder": True,
                "decoder_kernel_size": 7,
                "decoder_channels": 512,
                "decoder_upsample_scales": [8, 8, 2, 2],
                "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
                "decoder_resblock_kernel_sizes": [3, 7, 11],
                "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "use_weight_norm_in_decoder": True,
                "posterior_encoder_kernel_size": 5,
                "posterior_encoder_layers": 16,
                "posterior_encoder_stacks": 1,
                "posterior_encoder_base_dilation": 1,
                "posterior_encoder_dropout_rate": 0.0,
                "use_weight_norm_in_posterior_encoder": True,
                "flow_flows": 4,
                "flow_kernel_size": 5,
                "flow_base_dilation": 1,
                "flow_layers": 4,
                "flow_dropout_rate": 0.0,
                "use_weight_norm_in_flow": True,
                "use_only_mean_in_flow": True,
                "stochastic_duration_predictor_kernel_size": 3,
                "stochastic_duration_predictor_dropout_rate": 0.5,
                "stochastic_duration_predictor_flows": 4,
                "stochastic_duration_predictor_dds_conv_layers": 3,
            },
            # discriminator related
            discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
            discriminator_params: Dict[str, Any]={
                "scales": 1,
                "scale_downsample_pooling": "AvgPool1D",
                "scale_downsample_pooling_params": {
                    "kernel_size": 4,
                    "stride": 2,
                    "padding": 2,
                },
                "scale_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [15, 41, 5, 3],
                    "channels": 128,
                    "max_downsample_channels": 1024,
                    "max_groups": 16,
                    "bias": True,
                    "downsample_scales": [2, 2, 4, 4, 1],
                    "nonlinear_activation": "leakyrelu",
                    "nonlinear_activation_params": {
                        "negative_slope": 0.1
                    },
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
                "follow_official_norm": False,
                "periods": [2, 3, 5, 7, 11],
                "period_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [5, 3],
                    "channels": 32,
                    "downsample_scales": [3, 3, 3, 3, 1],
                    "max_downsample_channels": 1024,
                    "bias": True,
                    "nonlinear_activation": "leakyrelu",
                    "nonlinear_activation_params": {
                        "negative_slope": 0.1
                    },
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
            },
            cache_generator_outputs: bool=True,
            init_type: str="xavier_uniform", ):
        """Initialize VITS module.

        Args:
            idim (int): Input vocabulary size.
            odim (int): Acoustic feature dimension. The actual output channels will
                be 1 since VITS is the end-to-end text-to-wave model but for the
                compatibility odim is used to indicate the acoustic feature dimension.
            sampling_rate (int): Sampling rate, not used for the training but it will
                be referred in saving waveform during the inference.
            generator_type (str): Generator type, a key of `AVAILABLE_GENERATERS`.
            generator_params (Dict[str, Any]): Parameter dict for generator.
                It is not modified; for "vits_generator" a copy is extended with
                `vocabs=idim` and `aux_channels=odim`.
            discriminator_type (str): Discriminator type, a key of
                `AVAILABLE_DISCRIMINATORS`.
            discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
            cache_generator_outputs (bool): Whether to cache generator outputs.
            init_type (str): Initialization scheme name forwarded to `initialize`.

        Raises:
            KeyError: If `generator_type` or `discriminator_type` is not registered.
        """
        assert check_argument_types()
        super().__init__()

        try:
            # initialize parameters
            # NOTE: this runs before any sub-layer exists and is paired with
            # `set_global_initializer(None)` in the `finally` below, so it
            # presumably installs a paddle *global* initializer used while the
            # sub-modules are built — confirm in nets_utils.initialize.
            initialize(self, init_type)

            # define modules
            generator_class = AVAILABLE_GENERATERS[generator_type]
            if generator_type == "vits_generator":
                # NOTE: Update parameters for the compatibility.
                # The idim and odim is automatically decided from input data,
                # where idim represents #vocabularies and odim represents
                # the input acoustic feature dimension.
                # Build a new dict rather than calling `.update()` so neither
                # the caller's config dict nor the shared default argument is
                # mutated across instantiations.
                generator_params = {
                    **generator_params,
                    "vocabs": idim,
                    "aux_channels": odim,
                }
            self.generator = generator_class(**generator_params, )
            discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
            self.discriminator = discriminator_class(**discriminator_params, )
        finally:
            # Always restore the default global initializer, even when building
            # a sub-module fails, so unrelated layers created later in the
            # process are not affected.
            nn.initializer.set_global_initializer(None)

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        # last generator outputs shared between the generator and
        # discriminator steps (see `_forward_generator` / `_forward_discrminator`)
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.fs = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator.spks
        self.langs = self.generator.langs
        self.spk_embed_dim = self.generator.spk_embed_dim

        # Whether the most recent generator / discriminator forward reused
        # `self._cache` instead of running the generator; callers can read these
        # to decide when to clear the cache.
        self.reuse_cache_gen = True
        self.reuse_cache_dis = True

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            forward_generator: bool=True, ) -> Any:
        """Perform generator or discriminator-step forward.

        Dispatches to `_forward_generator` when `forward_generator` is True and
        to `_forward_discrminator` otherwise.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            forward_generator (bool): Whether to run the generator step (True)
                or the discriminator step (False).

        Returns:
            Any: The outputs of `self.generator(...)`, possibly taken from the
                cache. No loss is computed here.
        """
        if forward_generator:
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )
        else:
            return self._forward_discrminator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )

    def _forward_generator(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None, ) -> Any:
        """Perform generator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Any: Outputs of `self.generator(...)`, or `self._cache` when caching
                is enabled and a cached result exists. Sets `self.reuse_cache_gen`
                to whether the cache was reused.
        """
        # setup: (B, T_feats, aux_channels) -> (B, aux_channels, T_feats)
        feats = feats.transpose([0, 2, 1])

        # calculate generator outputs
        self.reuse_cache_gen = True
        if not self.cache_generator_outputs or self._cache is None:
            self.reuse_cache_gen = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )
        else:
            outs = self._cache

        # store cache (only in training mode for the generator step)
        if self.training and self.cache_generator_outputs and not self.reuse_cache_gen:
            self._cache = outs

        return outs

    def _forward_discrminator(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None, ) -> Any:
        """Perform discriminator forward.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Any: Outputs of `self.generator(...)`, or `self._cache` when caching
                is enabled and a cached result exists. Sets `self.reuse_cache_dis`
                to whether the cache was reused.
        """
        # setup: (B, T_feats, aux_channels) -> (B, aux_channels, T_feats)
        feats = feats.transpose([0, 2, 1])

        # calculate generator outputs
        self.reuse_cache_dis = True
        if not self.cache_generator_outputs or self._cache is None:
            self.reuse_cache_dis = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids, )
        else:
            outs = self._cache

        # store cache
        # NOTE(review): unlike `_forward_generator`, this does not check
        # `self.training`, so the cache is also filled in eval mode; callers
        # are presumably expected to clear `_cache` — confirm this is intended.
        if self.cache_generator_outputs and not self.reuse_cache_dis:
            self._cache = outs

        return outs

    def inference(
            self,
            text: paddle.Tensor,
            feats: Optional[paddle.Tensor]=None,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            durations: Optional[paddle.Tensor]=None,
            noise_scale: float=0.667,
            noise_scale_dur: float=0.8,
            alpha: float=1.0,
            max_len: Optional[int]=None,
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        """Run inference for a single (unbatched) utterance.

        Args:
            text (Tensor): Input text index tensor (T_text,).
            feats (Optional[Tensor]): Feature tensor (T_feats, aux_channels);
                required when `use_teacher_forcing` is True.
            sids (Optional[Tensor]): Speaker index tensor (1,).
            spembs (Optional[Tensor]): Speaker embedding tensor (spk_embed_dim,).
            lids (Optional[Tensor]): Language index tensor (1,).
            durations (Optional[Tensor]): Ground-truth duration tensor (T_text,);
                only used when not teacher forcing.
            noise_scale (float): Noise scale value for flow.
            noise_scale_dur (float): Noise scale value for duration predictor.
            alpha (float): Alpha parameter to control the speed of generated speech.
            max_len (Optional[int]): Maximum length.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor): Generated waveform tensor (T_wav,).
                * att_w (Tensor): Monotonic attention weight tensor (T_feats, T_text).
                * duration (Tensor): Predicted duration tensor (T_text,).
        """
        # setup: add a batch dimension of size 1
        text = text[None]
        text_lengths = paddle.to_tensor(paddle.shape(text)[1])

        if durations is not None:
            durations = paddle.reshape(durations, [1, 1, -1])

        # inference
        if use_teacher_forcing:
            assert feats is not None
            # (T_feats, aux_channels) -> (1, aux_channels, T_feats)
            feats = feats[None].transpose([0, 2, 1])
            feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                max_len=max_len,
                use_teacher_forcing=use_teacher_forcing, )
        else:
            wav, att_w, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                sids=sids,
                spembs=spembs,
                lids=lids,
                dur=durations,
                noise_scale=noise_scale,
                noise_scale_dur=noise_scale_dur,
                alpha=alpha,
                max_len=max_len, )
        # drop the batch dimension from every output
        return dict(
            wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])