# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generator module in JETS.
This code is based on https://github.com/imdanboy/jets.
"""
"""JETS module"""
import math
from typing import Any
from typing import Dict
from typing import Optional
import paddle
from paddle import nn
from typeguard import typechecked
from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
from paddlespeech.t2s.models.jets.generator import JETSGenerator
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import normal_
from paddlespeech.utils.initialize import ones_
from paddlespeech.utils.initialize import uniform_
from paddlespeech.utils.initialize import zeros_
AVAILABLE_GENERATORS = {
"jets_generator": JETSGenerator,
}
AVAILABLE_DISCRIMINATORS = {
"hifigan_period_discriminator":
HiFiGANPeriodDiscriminator,
"hifigan_scale_discriminator":
HiFiGANScaleDiscriminator,
"hifigan_multi_period_discriminator":
HiFiGANMultiPeriodDiscriminator,
"hifigan_multi_scale_discriminator":
HiFiGANMultiScaleDiscriminator,
"hifigan_multi_scale_multi_period_discriminator":
HiFiGANMultiScaleMultiPeriodDiscriminator,
}
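# Example of the registry pattern above (a minimal, hypothetical sketch): the
# ``generator_type`` / ``discriminator_type`` strings from a config select the
# class to instantiate, e.g.:
#
#     discriminator_cls = AVAILABLE_DISCRIMINATORS[
#         "hifigan_multi_scale_multi_period_discriminator"]
#     discriminator = discriminator_cls(**discriminator_params)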
class JETS(nn.Layer):
"""JETS module (generator + discriminator).
    This is a module of JETS described in `JETS: Jointly Training FastSpeech2
    and HiFi-GAN for End to End Text to Speech`_.

    .. _`JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech`:
        https://arxiv.org/abs/2203.16852v1

    """
@typechecked
def __init__(
self,
# generator related
idim: int,
odim: int,
sampling_rate: int=22050,
generator_type: str="jets_generator",
generator_params: Dict[str, Any]={
"adim": 256,
"aheads": 2,
"elayers": 4,
"eunits": 1024,
"dlayers": 4,
"dunits": 1024,
"positionwise_layer_type": "conv1d",
"positionwise_conv_kernel_size": 1,
"use_scaled_pos_enc": True,
"use_batch_norm": True,
"encoder_normalize_before": True,
"decoder_normalize_before": True,
"encoder_concat_after": False,
"decoder_concat_after": False,
"reduction_factor": 1,
"encoder_type": "transformer",
"decoder_type": "transformer",
"transformer_enc_dropout_rate": 0.1,
"transformer_enc_positional_dropout_rate": 0.1,
"transformer_enc_attn_dropout_rate": 0.1,
"transformer_dec_dropout_rate": 0.1,
"transformer_dec_positional_dropout_rate": 0.1,
"transformer_dec_attn_dropout_rate": 0.1,
"conformer_rel_pos_type": "latest",
"conformer_pos_enc_layer_type": "rel_pos",
"conformer_self_attn_layer_type": "rel_selfattn",
"conformer_activation_type": "swish",
"use_macaron_style_in_conformer": True,
"use_cnn_in_conformer": True,
"zero_triu": False,
"conformer_enc_kernel_size": 7,
"conformer_dec_kernel_size": 31,
"duration_predictor_layers": 2,
"duration_predictor_chans": 384,
"duration_predictor_kernel_size": 3,
"duration_predictor_dropout_rate": 0.1,
"energy_predictor_layers": 2,
"energy_predictor_chans": 384,
"energy_predictor_kernel_size": 3,
"energy_predictor_dropout": 0.5,
"energy_embed_kernel_size": 1,
"energy_embed_dropout": 0.5,
"stop_gradient_from_energy_predictor": False,
"pitch_predictor_layers": 5,
"pitch_predictor_chans": 384,
"pitch_predictor_kernel_size": 5,
"pitch_predictor_dropout": 0.5,
"pitch_embed_kernel_size": 1,
"pitch_embed_dropout": 0.5,
"stop_gradient_from_pitch_predictor": True,
"generator_out_channels": 1,
"generator_channels": 512,
"generator_global_channels": -1,
"generator_kernel_size": 7,
"generator_upsample_scales": [8, 8, 2, 2],
"generator_upsample_kernel_sizes": [16, 16, 4, 4],
"generator_resblock_kernel_sizes": [3, 7, 11],
"generator_resblock_dilations":
[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"generator_use_additional_convs": True,
"generator_bias": True,
"generator_nonlinear_activation": "LeakyReLU",
"generator_nonlinear_activation_params": {
"negative_slope": 0.1
},
"generator_use_weight_norm": True,
"segment_size": 64,
"spks": -1,
"langs": -1,
"spk_embed_dim": None,
"spk_embed_integration_type": "add",
"use_gst": False,
"gst_tokens": 10,
"gst_heads": 4,
"gst_conv_layers": 6,
"gst_conv_chans_list": [32, 32, 64, 64, 128, 128],
"gst_conv_kernel_size": 3,
"gst_conv_stride": 2,
"gst_gru_layers": 1,
"gst_gru_units": 128,
"init_type": "xavier_uniform",
"init_enc_alpha": 1.0,
"init_dec_alpha": 1.0,
"use_masking": False,
"use_weighted_masking": False,
},
# discriminator related
discriminator_type: str="hifigan_multi_scale_multi_period_discriminator",
discriminator_params: Dict[str, Any]={
"scales": 1,
"scale_downsample_pooling": "AvgPool1D",
"scale_downsample_pooling_params": {
"kernel_size": 4,
"stride": 2,
"padding": 2,
},
"scale_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [15, 41, 5, 3],
"channels": 128,
"max_downsample_channels": 1024,
"max_groups": 16,
"bias": True,
"downsample_scales": [2, 2, 4, 4, 1],
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.1
},
"use_weight_norm": True,
"use_spectral_norm": False,
},
"follow_official_norm": False,
"periods": [2, 3, 5, 7, 11],
"period_discriminator_params": {
"in_channels": 1,
"out_channels": 1,
"kernel_sizes": [5, 3],
"channels": 32,
"downsample_scales": [3, 3, 3, 3, 1],
"max_downsample_channels": 1024,
"bias": True,
"nonlinear_activation": "leakyrelu",
"nonlinear_activation_params": {
"negative_slope": 0.1
},
"use_weight_norm": True,
"use_spectral_norm": False,
},
},
cache_generator_outputs: bool=True, ):
"""Initialize JETS module.
Args:
            idim (int):
                Input vocabulary size.
            odim (int):
                Acoustic feature dimension. The actual number of output
                channels is 1 because JETS is an end-to-end text-to-waveform
                model, but odim is kept to indicate the acoustic feature
                dimension for compatibility.
            sampling_rate (int):
                Sampling rate. Not used during training, but referred to when
                saving the waveform during inference.
generator_type (str):
Generator type.
generator_params (Dict[str, Any]):
Parameter dict for generator.
discriminator_type (str):
Discriminator type.
discriminator_params (Dict[str, Any]):
Parameter dict for discriminator.
cache_generator_outputs (bool):
Whether to cache generator outputs.
"""
super().__init__()
# define modules
        generator_class = AVAILABLE_GENERATORS[generator_type]
        if generator_type == "jets_generator":
            # NOTE: update parameters for compatibility.
            # idim and odim are decided automatically from the input data,
            # where idim represents the vocabulary size and odim represents
            # the input acoustic feature dimension.
            generator_params.update(idim=idim, odim=odim)
self.generator = generator_class(
**generator_params, )
discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
self.discriminator = discriminator_class(
**discriminator_params, )
# cache
self.cache_generator_outputs = cache_generator_outputs
self._cache = None
# store sampling rate for saving wav file
# (not used for the training)
self.fs = sampling_rate
# store parameters for test compatibility
self.spks = self.generator.spks
self.langs = self.generator.langs
self.spk_embed_dim = self.generator.spk_embed_dim
self.reuse_cache_gen = True
self.reuse_cache_dis = True
self.reset_parameters()
self.generator._reset_parameters(
init_type=generator_params["init_type"],
init_enc_alpha=generator_params["init_enc_alpha"],
init_dec_alpha=generator_params["init_dec_alpha"], )
def forward(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
durations: paddle.Tensor,
durations_lengths: paddle.Tensor,
pitch: paddle.Tensor,
energy: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
forward_generator: bool=True,
use_alignment_module: bool=False,
**kwargs, ) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor):
Text index tensor (B, T_text).
text_lengths (Tensor):
Text length tensor (B,).
feats (Tensor):
Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor):
Feature length tensor (B,).
            durations (Tensor(int64)):
                Batch of padded durations (B, Tmax).
            durations_lengths (Tensor):
                Durations length tensor (B,).
            pitch (Tensor):
                Batch of padded token-averaged pitch (B, Tmax, 1).
            energy (Tensor):
                Batch of padded token-averaged energy (B, Tmax, 1).
sids (Optional[Tensor]):
Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]):
Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]):
Language index tensor (B,) or (B, 1).
forward_generator (bool):
Whether to forward generator.
use_alignment_module (bool):
Whether to use alignment module.
        Returns:
            The output of ``_forward_generator`` when ``forward_generator`` is
            True, otherwise the output of ``_forward_discriminator``.
        """
if forward_generator:
return self._forward_generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
durations=durations,
durations_lengths=durations_lengths,
pitch=pitch,
energy=energy,
sids=sids,
spembs=spembs,
lids=lids,
use_alignment_module=use_alignment_module, )
else:
            return self._forward_discriminator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
durations=durations,
durations_lengths=durations_lengths,
pitch=pitch,
energy=energy,
sids=sids,
spembs=spembs,
lids=lids,
use_alignment_module=use_alignment_module, )
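    # Usage sketch for the alternating GAN-style updates (hypothetical tensors
    # whose shapes follow the docstring above):
    #
    #     # generator step
    #     gen_outs = model(text, text_lengths, feats, feats_lengths,
    #                      durations, durations_lengths, pitch, energy,
    #                      forward_generator=True)
    #     # discriminator step; reuses the cached generator outputs when
    #     # cache_generator_outputs=True
    #     dis_outs = model(text, text_lengths, feats, feats_lengths,
    #                      durations, durations_lengths, pitch, energy,
    #                      forward_generator=False)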
def _forward_generator(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
durations: paddle.Tensor,
durations_lengths: paddle.Tensor,
pitch: paddle.Tensor,
energy: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
use_alignment_module: bool=False,
**kwargs, ) -> Dict[str, Any]:
"""Perform generator forward.
Args:
text (Tensor):
Text index tensor (B, T_text).
text_lengths (Tensor):
Text length tensor (B,).
feats (Tensor):
Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor):
Feature length tensor (B,).
            durations (Tensor(int64)):
                Batch of padded durations (B, Tmax).
            durations_lengths (Tensor):
                Durations length tensor (B,).
            pitch (Tensor):
                Batch of padded token-averaged pitch (B, Tmax, 1).
            energy (Tensor):
                Batch of padded token-averaged energy (B, Tmax, 1).
sids (Optional[Tensor]):
Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]):
Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]):
Language index tensor (B,) or (B, 1).
use_alignment_module (bool):
Whether to use alignment module.
        Returns:
            The (possibly cached) generator outputs.
        """
        # calculate generator outputs (reuse the cache when available)
self.reuse_cache_gen = True
if not self.cache_generator_outputs or self._cache is None:
self.reuse_cache_gen = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
durations=durations,
durations_lengths=durations_lengths,
pitch=pitch,
energy=energy,
sids=sids,
spembs=spembs,
lids=lids,
use_alignment_module=use_alignment_module, )
else:
outs = self._cache
# store cache
if self.training and self.cache_generator_outputs and not self.reuse_cache_gen:
self._cache = outs
return outs
    def _forward_discriminator(
self,
text: paddle.Tensor,
text_lengths: paddle.Tensor,
feats: paddle.Tensor,
feats_lengths: paddle.Tensor,
durations: paddle.Tensor,
durations_lengths: paddle.Tensor,
pitch: paddle.Tensor,
energy: paddle.Tensor,
sids: Optional[paddle.Tensor]=None,
spembs: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None,
use_alignment_module: bool=False,
**kwargs, ) -> Dict[str, Any]:
"""Perform discriminator forward.
Args:
text (Tensor):
Text index tensor (B, T_text).
text_lengths (Tensor):
Text length tensor (B,).
feats (Tensor):
Feature tensor (B, T_feats, aux_channels).
feats_lengths (Tensor):
Feature length tensor (B,).
            durations (Tensor(int64)):
                Batch of padded durations (B, Tmax).
            durations_lengths (Tensor):
                Durations length tensor (B,).
            pitch (Tensor):
                Batch of padded token-averaged pitch (B, Tmax, 1).
            energy (Tensor):
                Batch of padded token-averaged energy (B, Tmax, 1).
sids (Optional[Tensor]):
Speaker index tensor (B,) or (B, 1).
spembs (Optional[Tensor]):
Speaker embedding tensor (B, spk_embed_dim).
lids (Optional[Tensor]):
Language index tensor (B,) or (B, 1).
use_alignment_module (bool):
Whether to use alignment module.
        Returns:
            The (possibly cached) generator outputs used for the discriminator
            step.
        """
        # calculate generator outputs for the discriminator step
        # (reuse the cache when available)
self.reuse_cache_dis = True
if not self.cache_generator_outputs or self._cache is None:
self.reuse_cache_dis = False
outs = self.generator(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
durations=durations,
durations_lengths=durations_lengths,
pitch=pitch,
energy=energy,
sids=sids,
spembs=spembs,
lids=lids,
use_alignment_module=use_alignment_module,
**kwargs, )
else:
outs = self._cache
# store cache
if self.cache_generator_outputs and not self.reuse_cache_dis:
self._cache = outs
return outs
def inference(self,
text: paddle.Tensor,
feats: Optional[paddle.Tensor]=None,
pitch: Optional[paddle.Tensor]=None,
energy: Optional[paddle.Tensor]=None,
use_alignment_module: bool=False,
**kwargs) -> Dict[str, paddle.Tensor]:
"""Run inference.
Args:
text (Tensor):
Input text index tensor (T_text,).
            feats (Optional[Tensor]):
                Feature tensor (T_feats, aux_channels); required when
                ``use_alignment_module`` is True.
            pitch (Optional[Tensor]):
                Pitch tensor (T_feats, 1); required when
                ``use_alignment_module`` is True.
            energy (Optional[Tensor]):
                Energy tensor (T_feats, 1); required when
                ``use_alignment_module`` is True.
use_alignment_module (bool):
Whether to use alignment module.
Returns:
Dict[str, Tensor]:
* wav (Tensor):
Generated waveform tensor (T_wav,).
* duration (Tensor):
Predicted duration tensor (T_text,).
"""
# setup
text = text[None]
text_lengths = paddle.to_tensor(paddle.shape(text)[1])
# inference
if use_alignment_module:
            assert feats is not None and pitch is not None and energy is not None
feats = feats[None]
feats_lengths = paddle.to_tensor(paddle.shape(feats)[1])
pitch = pitch[None]
energy = energy[None]
wav, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
feats=feats,
feats_lengths=feats_lengths,
pitch=pitch,
energy=energy,
use_alignment_module=use_alignment_module,
**kwargs)
else:
wav, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
**kwargs, )
return dict(wav=paddle.reshape(wav, [-1]), duration=dur[0])
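    # Inference sketch (hypothetical phoneme-id tensor; shapes follow the
    # docstring above):
    #
    #     phone_ids = paddle.to_tensor([12, 5, 33, 7], dtype="int64")
    #     out = model.inference(phone_ids)
    #     wav = out["wav"]        # (T_wav,)
    #     dur = out["duration"]   # (T_text,)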
def reset_parameters(self):
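        # The scheme below mirrors the common framework-default initialization:
        # Kaiming-uniform weights with a=sqrt(5) and a uniform bias bound of
        # 1/sqrt(fan_in) for conv/linear layers, ones/zeros for normalization
        # layers, and a normal distribution for embeddings (with the padding
        # index zeroed out).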
def _reset_parameters(module):
if isinstance(
module,
(nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)):
kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
uniform_(module.bias, -bound, bound)
if isinstance(
module,
(nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)):
ones_(module.weight)
zeros_(module.bias)
if isinstance(module, nn.Linear):
kaiming_uniform_(module.weight, a=math.sqrt(5))
if module.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(module.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(module.bias, -bound, bound)
if isinstance(module, nn.Embedding):
normal_(module.weight)
if module._padding_idx is not None:
with paddle.no_grad():
module.weight[module._padding_idx] = 0
self.apply(_reset_parameters)
class JETSInference(nn.Layer):
    """Lightweight wrapper that exposes JETS waveform inference as a forward pass."""

    def __init__(self, model):
        super().__init__()
        self.acoustic_model = model

    def forward(self, text, sids=None):
        # NOTE: `sids` is accepted for API compatibility with other inference
        # wrappers but is not used by this model.
        out = self.acoustic_model.inference(text)
        wav = out['wav']
        return wav
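# Export sketch (hypothetical; `model` would be a trained JETS instance
# restored from a checkpoint):
#
#     jets_inference = JETSInference(model)
#     wav = jets_inference(phone_ids)  # (T_wav,) waveform tensor
#
# The wrapper keeps the interface to a single tensor-in / tensor-out forward,
# which makes it convenient for static-graph export with paddle.jit.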