You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
583 lines
22 KiB
583 lines
22 KiB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""JETS module (generator + discriminator wrapper).

This code is based on https://github.com/imdanboy/jets.
"""
|
|
import math
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import Optional
|
|
|
|
import paddle
|
|
from paddle import nn
|
|
from typeguard import check_argument_types
|
|
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
|
|
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
|
|
from paddlespeech.t2s.models.jets.generator import JETSGenerator
|
|
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
|
|
from paddlespeech.utils.initialize import kaiming_uniform_
|
|
from paddlespeech.utils.initialize import normal_
|
|
from paddlespeech.utils.initialize import ones_
|
|
from paddlespeech.utils.initialize import uniform_
|
|
from paddlespeech.utils.initialize import zeros_
|
|
|
|
# Registry of generator classes, keyed by the ``generator_type`` string
# accepted by ``JETS.__init__``.
AVAILABLE_GENERATERS = {"jets_generator": JETSGenerator}

# Registry of discriminator classes, keyed by the ``discriminator_type``
# string accepted by ``JETS.__init__``.
AVAILABLE_DISCRIMINATORS = {
    "hifigan_period_discriminator": HiFiGANPeriodDiscriminator,
    "hifigan_scale_discriminator": HiFiGANScaleDiscriminator,
    "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator,
    "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator,
    "hifigan_multi_scale_multi_period_discriminator":
    HiFiGANMultiScaleMultiPeriodDiscriminator,
}
|
|
|
|
|
|
class JETS(nn.Layer):
    """JETS module (generator + discriminator).

    This is a module of JETS described in `JETS: Jointly Training FastSpeech2
    and HiFi-GAN for End to End Text to Speech`_.

    .. _`JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech`: https://arxiv.org/abs/2203.16852v1

    """

    def __init__(
            self,
            # generator related
            idim: int,
            odim: int,
            sampling_rate: int=22050,
            generator_type: str="jets_generator",
            generator_params: Dict[str, Any]={
                "adim": 256,
                "aheads": 2,
                "elayers": 4,
                "eunits": 1024,
                "dlayers": 4,
                "dunits": 1024,
                "positionwise_layer_type": "conv1d",
                "positionwise_conv_kernel_size": 1,
                "use_scaled_pos_enc": True,
                "use_batch_norm": True,
                "encoder_normalize_before": True,
                "decoder_normalize_before": True,
                "encoder_concat_after": False,
                "decoder_concat_after": False,
                "reduction_factor": 1,
                "encoder_type": "transformer",
                "decoder_type": "transformer",
                "transformer_enc_dropout_rate": 0.1,
                "transformer_enc_positional_dropout_rate": 0.1,
                "transformer_enc_attn_dropout_rate": 0.1,
                "transformer_dec_dropout_rate": 0.1,
                "transformer_dec_positional_dropout_rate": 0.1,
                "transformer_dec_attn_dropout_rate": 0.1,
                "conformer_rel_pos_type": "latest",
                "conformer_pos_enc_layer_type": "rel_pos",
                "conformer_self_attn_layer_type": "rel_selfattn",
                "conformer_activation_type": "swish",
                "use_macaron_style_in_conformer": True,
                "use_cnn_in_conformer": True,
                "zero_triu": False,
                "conformer_enc_kernel_size": 7,
                "conformer_dec_kernel_size": 31,
                "duration_predictor_layers": 2,
                "duration_predictor_chans": 384,
                "duration_predictor_kernel_size": 3,
                "duration_predictor_dropout_rate": 0.1,
                "energy_predictor_layers": 2,
                "energy_predictor_chans": 384,
                "energy_predictor_kernel_size": 3,
                "energy_predictor_dropout": 0.5,
                "energy_embed_kernel_size": 1,
                "energy_embed_dropout": 0.5,
                "stop_gradient_from_energy_predictor": False,
                "pitch_predictor_layers": 5,
                "pitch_predictor_chans": 384,
                "pitch_predictor_kernel_size": 5,
                "pitch_predictor_dropout": 0.5,
                "pitch_embed_kernel_size": 1,
                "pitch_embed_dropout": 0.5,
                "stop_gradient_from_pitch_predictor": True,
                "generator_out_channels": 1,
                "generator_channels": 512,
                "generator_global_channels": -1,
                "generator_kernel_size": 7,
                "generator_upsample_scales": [8, 8, 2, 2],
                "generator_upsample_kernel_sizes": [16, 16, 4, 4],
                "generator_resblock_kernel_sizes": [3, 7, 11],
                "generator_resblock_dilations":
                [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                "generator_use_additional_convs": True,
                "generator_bias": True,
                "generator_nonlinear_activation": "LeakyReLU",
                "generator_nonlinear_activation_params": {
                    "negative_slope": 0.1
                },
                "generator_use_weight_norm": True,
                "segment_size": 64,
                "spks": -1,
                "langs": -1,
                "spk_embed_dim": None,
                "spk_embed_integration_type": "add",
                "use_gst": False,
                "gst_tokens": 10,
                "gst_heads": 4,
                "gst_conv_layers": 6,
                "gst_conv_chans_list": [32, 32, 64, 64, 128, 128],
                "gst_conv_kernel_size": 3,
                "gst_conv_stride": 2,
                "gst_gru_layers": 1,
                "gst_gru_units": 128,
                "init_type": "xavier_uniform",
                "init_enc_alpha": 1.0,
                "init_dec_alpha": 1.0,
                "use_masking": False,
                "use_weighted_masking": False,
            },
            # discriminator related
            discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
            discriminator_params: Dict[str, Any]={
                "scales": 1,
                "scale_downsample_pooling": "AvgPool1D",
                "scale_downsample_pooling_params": {
                    "kernel_size": 4,
                    "stride": 2,
                    "padding": 2,
                },
                "scale_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [15, 41, 5, 3],
                    "channels": 128,
                    "max_downsample_channels": 1024,
                    "max_groups": 16,
                    "bias": True,
                    "downsample_scales": [2, 2, 4, 4, 1],
                    "nonlinear_activation": "leakyrelu",
                    "nonlinear_activation_params": {
                        "negative_slope": 0.1
                    },
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
                "follow_official_norm": False,
                "periods": [2, 3, 5, 7, 11],
                "period_discriminator_params": {
                    "in_channels": 1,
                    "out_channels": 1,
                    "kernel_sizes": [5, 3],
                    "channels": 32,
                    "downsample_scales": [3, 3, 3, 3, 1],
                    "max_downsample_channels": 1024,
                    "bias": True,
                    "nonlinear_activation": "leakyrelu",
                    "nonlinear_activation_params": {
                        "negative_slope": 0.1
                    },
                    "use_weight_norm": True,
                    "use_spectral_norm": False,
                },
            },
            cache_generator_outputs: bool=True, ):
        """Initialize JETS module.

        Args:
            idim (int):
                Input vocabulary size.
            odim (int):
                Acoustic feature dimension. The actual output channels will
                be 1 since JETS is the end-to-end text-to-wave model but for the
                compatibility odim is used to indicate the acoustic feature dimension.
            sampling_rate (int):
                Sampling rate, not used for the training but it will
                be referred in saving waveform during the inference.
            generator_type (str):
                Generator type; a key of ``AVAILABLE_GENERATERS``.
            generator_params (Dict[str, Any]):
                Parameter dict for generator. It is copied before use, so the
                caller's dict (and the shared default) is never modified.
            discriminator_type (str):
                Discriminator type; a key of ``AVAILABLE_DISCRIMINATORS``.
            discriminator_params (Dict[str, Any]):
                Parameter dict for discriminator.
            cache_generator_outputs (bool):
                Whether to cache generator outputs.
        """
        assert check_argument_types()
        super().__init__()

        # Work on a shallow copy: the ``update`` below would otherwise mutate
        # the shared mutable default dict or the caller's config dict.
        generator_params = dict(generator_params)

        # define modules
        generator_class = AVAILABLE_GENERATERS[generator_type]
        if generator_type == "jets_generator":
            # NOTE: Update parameters for the compatibility.
            # The idim and odim is automatically decided from input data,
            # where idim represents #vocabularies and odim represents
            # the input acoustic feature dimension.
            generator_params.update(idim=idim, odim=odim)
        self.generator = generator_class(**generator_params, )
        discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
        self.discriminator = discriminator_class(**discriminator_params, )

        # cache
        self.cache_generator_outputs = cache_generator_outputs
        self._cache = None

        # store sampling rate for saving wav file
        # (not used for the training)
        self.fs = sampling_rate

        # store parameters for test compatibility
        self.spks = self.generator.spks
        self.langs = self.generator.langs
        self.spk_embed_dim = self.generator.spk_embed_dim

        # Whether the last generator / discriminator forward reused
        # ``self._cache`` instead of recomputing; these are public so code
        # outside this class can read them.
        self.reuse_cache_gen = True
        self.reuse_cache_dis = True

        # Generic init for every sub-layer first, then the generator's own
        # init, which runs second and so overrides the generic one.
        self.reset_parameters()
        self.generator._reset_parameters(
            init_type=generator_params["init_type"],
            init_enc_alpha=generator_params["init_enc_alpha"],
            init_dec_alpha=generator_params["init_dec_alpha"], )

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            durations_lengths: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            forward_generator: bool=True,
            use_alignment_module: bool=False,
            **kwargs, ) -> Dict[str, Any]:
        """Perform the generator-step or discriminator-step forward.

        Args:
            text (Tensor):
                Text index tensor (B, T_text).
            text_lengths (Tensor):
                Text length tensor (B,).
            feats (Tensor):
                Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor):
                Feature length tensor (B,).
            durations(Tensor(int64)):
                Batch of padded durations (B, Tmax).
            durations_lengths (Tensor):
                durations length tensor (B,).
            pitch(Tensor):
                Batch of padded token-averaged pitch (B, Tmax, 1).
            energy(Tensor):
                Batch of padded token-averaged energy (B, Tmax, 1).
            sids (Optional[Tensor]):
                Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]):
                Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]):
                Language index tensor (B,) or (B, 1).
            forward_generator (bool):
                If True run ``_forward_generator``, otherwise
                ``_forward_discrminator``.
            use_alignment_module (bool):
                Whether to use alignment module.
            **kwargs:
                Accepted for interface compatibility; not forwarded.

        Returns:
            The generator outputs returned by the selected helper.
        """
        if forward_generator:
            return self._forward_generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
                pitch=pitch,
                energy=energy,
                sids=sids,
                spembs=spembs,
                lids=lids,
                use_alignment_module=use_alignment_module, )
        else:
            return self._forward_discrminator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
                pitch=pitch,
                energy=energy,
                sids=sids,
                spembs=spembs,
                lids=lids,
                use_alignment_module=use_alignment_module, )

    def _forward_generator(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            durations_lengths: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            use_alignment_module: bool=False,
            **kwargs, ) -> Dict[str, Any]:
        """Perform generator forward (generator-update step).

        Args:
            text (Tensor):
                Text index tensor (B, T_text).
            text_lengths (Tensor):
                Text length tensor (B,).
            feats (Tensor):
                Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor):
                Feature length tensor (B,).
            durations(Tensor(int64)):
                Batch of padded durations (B, Tmax).
            durations_lengths (Tensor):
                durations length tensor (B,).
            pitch(Tensor):
                Batch of padded token-averaged pitch (B, Tmax, 1).
            energy(Tensor):
                Batch of padded token-averaged energy (B, Tmax, 1).
            sids (Optional[Tensor]):
                Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]):
                Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]):
                Language index tensor (B,) or (B, 1).
            use_alignment_module (bool):
                Whether to use alignment module.
            **kwargs:
                Not forwarded to the generator.

        Returns:
            The outputs of ``self.generator``. They are freshly computed, or
            they are ``self._cache`` when caching is enabled and a cache exists.
        """
        # calculate generator outputs (or reuse the cached ones)
        self.reuse_cache_gen = True
        if not self.cache_generator_outputs or self._cache is None:
            self.reuse_cache_gen = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
                pitch=pitch,
                energy=energy,
                sids=sids,
                spembs=spembs,
                lids=lids,
                use_alignment_module=use_alignment_module, )
        else:
            outs = self._cache

        # store cache (training only, and only for freshly computed outputs)
        if self.training and self.cache_generator_outputs and not self.reuse_cache_gen:
            self._cache = outs

        return outs

    def _forward_discrminator(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            feats: paddle.Tensor,
            feats_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            durations_lengths: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            sids: Optional[paddle.Tensor]=None,
            spembs: Optional[paddle.Tensor]=None,
            lids: Optional[paddle.Tensor]=None,
            use_alignment_module: bool=False,
            **kwargs, ) -> Dict[str, Any]:
        """Perform discriminator forward (discriminator-update step).

        This returns generator outputs; scoring them with
        ``self.discriminator`` is left to the caller.

        Args:
            text (Tensor):
                Text index tensor (B, T_text).
            text_lengths (Tensor):
                Text length tensor (B,).
            feats (Tensor):
                Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor):
                Feature length tensor (B,).
            durations(Tensor(int64)):
                Batch of padded durations (B, Tmax).
            durations_lengths (Tensor):
                durations length tensor (B,).
            pitch(Tensor):
                Batch of padded token-averaged pitch (B, Tmax, 1).
            energy(Tensor):
                Batch of padded token-averaged energy (B, Tmax, 1).
            sids (Optional[Tensor]):
                Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]):
                Speaker embedding tensor (B, spk_embed_dim).
            lids (Optional[Tensor]):
                Language index tensor (B,) or (B, 1).
            use_alignment_module (bool):
                Whether to use alignment module.
            **kwargs:
                Forwarded to the generator.

        Returns:
            The outputs of ``self.generator``. They are freshly computed, or
            they are ``self._cache`` when caching is enabled and a cache exists.
        """
        # calculate generator outputs (or reuse the cached ones)
        self.reuse_cache_dis = True
        if not self.cache_generator_outputs or self._cache is None:
            self.reuse_cache_dis = False
            outs = self.generator(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                durations=durations,
                durations_lengths=durations_lengths,
                pitch=pitch,
                energy=energy,
                sids=sids,
                spembs=spembs,
                lids=lids,
                use_alignment_module=use_alignment_module,
                **kwargs, )
        else:
            outs = self._cache

        # store cache
        # NOTE(review): unlike _forward_generator there is no ``self.training``
        # check here, and the cache is never cleared in this class. Presumably
        # the training loop resets ``self._cache`` based on
        # ``reuse_cache_gen`` / ``reuse_cache_dis`` — confirm.
        if self.cache_generator_outputs and not self.reuse_cache_dis:
            self._cache = outs

        return outs

    def inference(self,
                  text: paddle.Tensor,
                  feats: Optional[paddle.Tensor]=None,
                  pitch: Optional[paddle.Tensor]=None,
                  energy: Optional[paddle.Tensor]=None,
                  use_alignment_module: bool=False,
                  **kwargs) -> Dict[str, paddle.Tensor]:
        """Run inference on a single (unbatched) utterance.

        Args:
            text (Tensor):
                Input text index tensor (T_text,).
            feats (Tensor):
                Feature tensor (T_feats, aux_channels). Required when
                ``use_alignment_module`` is True.
            pitch (Tensor):
                Pitch tensor (T_feats, 1). Required when
                ``use_alignment_module`` is True.
            energy (Tensor):
                Energy tensor (T_feats, 1). Required when
                ``use_alignment_module`` is True.
            use_alignment_module (bool):
                Whether to use alignment module.
            **kwargs:
                Forwarded to ``self.generator.inference``.

        Returns:
            Dict[str, Tensor]:
                * wav (Tensor):
                    Generated waveform tensor (T_wav,).
                * duration (Tensor):
                    Predicted duration tensor (T_text,).
        """
        # setup: add a batch axis of size 1
        text = text[None]
        text_lengths = paddle.to_tensor(paddle.shape(text)[1])

        # inference
        if use_alignment_module:
            # All three inputs are indexed below; fail early with a clear
            # message instead of a ``NoneType`` subscript error.
            assert feats is not None, "feats is required when use_alignment_module=True"
            assert pitch is not None, "pitch is required when use_alignment_module=True"
            assert energy is not None, "energy is required when use_alignment_module=True"
            feats = feats[None]
            feats_lengths = paddle.to_tensor(paddle.shape(feats)[1])
            pitch = pitch[None]
            energy = energy[None]
            wav, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                feats=feats,
                feats_lengths=feats_lengths,
                pitch=pitch,
                energy=energy,
                use_alignment_module=use_alignment_module,
                **kwargs)
        else:
            wav, dur = self.generator.inference(
                text=text,
                text_lengths=text_lengths,
                **kwargs, )
        return dict(wav=paddle.reshape(wav, [-1]), duration=dur[0])

    def reset_parameters(self):
        """Re-initialize conv, norm, linear and embedding sub-layers in place."""

        def _reset_parameters(module):
            if isinstance(
                    module,
                (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
                kaiming_uniform_(module.weight, a=math.sqrt(5))
                if module.bias is not None:
                    fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
                    if fan_in != 0:
                        bound = 1 / math.sqrt(fan_in)
                        uniform_(module.bias, -bound, bound)

            if isinstance(
                    module,
                (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
                # Norm layers may be built without affine parameters, in which
                # case weight/bias are None and cannot be initialized.
                if module.weight is not None:
                    ones_(module.weight)
                if module.bias is not None:
                    zeros_(module.bias)

            if isinstance(module, nn.Linear):
                kaiming_uniform_(module.weight, a=math.sqrt(5))
                if module.bias is not None:
                    fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
                    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                    uniform_(module.bias, -bound, bound)

            if isinstance(module, nn.Embedding):
                normal_(module.weight)
                if module._padding_idx is not None:
                    # keep the padding row at zero after random init
                    with paddle.no_grad():
                        module.weight[module._padding_idx] = 0

        self.apply(_reset_parameters)
|
|
|
|
|
|
class JETSInference(nn.Layer):
    """Thin wrapper that maps a text index tensor to a waveform via ``model.inference``."""

    def __init__(self, model):
        super().__init__()
        self.acoustic_model = model

    def forward(self, text, sids=None):
        """Synthesize a waveform.

        Args:
            text (Tensor):
                Input text index tensor (T_text,).
            sids (Optional[Tensor]):
                Speaker index. Previously this was accepted but silently
                ignored. It is now forwarded to ``inference`` (and on to the
                generator) when given.

        Returns:
            Tensor: The ``'wav'`` entry of ``model.inference``'s output.
        """
        if sids is None:
            # identical call to the original for the common single-speaker case
            out = self.acoustic_model.inference(text)
        else:
            out = self.acoustic_model.inference(text, sids=sids)
        wav = out['wav']
        return wav
|