# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Generator module in JETS. This code is based on https://github.com/imdanboy/jets. """ """JETS module""" import math from typing import Any from typing import Dict from typing import Optional import paddle from paddle import nn from typeguard import check_argument_types from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator from paddlespeech.t2s.models.jets.generator import JETSGenerator from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out from paddlespeech.utils.initialize import kaiming_uniform_ from paddlespeech.utils.initialize import normal_ from paddlespeech.utils.initialize import ones_ from paddlespeech.utils.initialize import uniform_ from paddlespeech.utils.initialize import zeros_ AVAILABLE_GENERATERS = { "jets_generator": JETSGenerator, } AVAILABLE_DISCRIMINATORS = { "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, } class JETS(nn.Layer): """JETS module (generator + discriminator). This is a module of JETS described in `JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech`_. .. _`JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech Text-to-Speech`: https://arxiv.org/abs/2203.16852v1 """ def __init__( self, # generator related idim: int, odim: int, sampling_rate: int=22050, generator_type: str="jets_generator", generator_params: Dict[str, Any]={ "adim": 256, "aheads": 2, "elayers": 4, "eunits": 1024, "dlayers": 4, "dunits": 1024, "positionwise_layer_type": "conv1d", "positionwise_conv_kernel_size": 1, "use_scaled_pos_enc": True, "use_batch_norm": True, "encoder_normalize_before": True, "decoder_normalize_before": True, "encoder_concat_after": False, "decoder_concat_after": False, "reduction_factor": 1, "encoder_type": "transformer", "decoder_type": "transformer", "transformer_enc_dropout_rate": 0.1, "transformer_enc_positional_dropout_rate": 0.1, "transformer_enc_attn_dropout_rate": 0.1, "transformer_dec_dropout_rate": 0.1, "transformer_dec_positional_dropout_rate": 0.1, "transformer_dec_attn_dropout_rate": 0.1, "conformer_rel_pos_type": "latest", "conformer_pos_enc_layer_type": "rel_pos", "conformer_self_attn_layer_type": "rel_selfattn", "conformer_activation_type": "swish", "use_macaron_style_in_conformer": True, "use_cnn_in_conformer": True, "zero_triu": False, "conformer_enc_kernel_size": 7, "conformer_dec_kernel_size": 31, "duration_predictor_layers": 2, "duration_predictor_chans": 384, "duration_predictor_kernel_size": 3, "duration_predictor_dropout_rate": 0.1, "energy_predictor_layers": 2, "energy_predictor_chans": 384, "energy_predictor_kernel_size": 3, "energy_predictor_dropout": 0.5, "energy_embed_kernel_size": 1, "energy_embed_dropout": 0.5, "stop_gradient_from_energy_predictor": False, "pitch_predictor_layers": 5, "pitch_predictor_chans": 384, "pitch_predictor_kernel_size": 5, "pitch_predictor_dropout": 0.5, "pitch_embed_kernel_size": 1, "pitch_embed_dropout": 0.5, "stop_gradient_from_pitch_predictor": True, "generator_out_channels": 1, "generator_channels": 512, "generator_global_channels": -1, "generator_kernel_size": 7, "generator_upsample_scales": [8, 8, 2, 2], "generator_upsample_kernel_sizes": [16, 16, 4, 4], "generator_resblock_kernel_sizes": [3, 7, 11], "generator_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "generator_use_additional_convs": True, "generator_bias": True, "generator_nonlinear_activation": "LeakyReLU", "generator_nonlinear_activation_params": { "negative_slope": 0.1 }, "generator_use_weight_norm": True, "segment_size": 64, "spks": -1, "langs": -1, "spk_embed_dim": None, "spk_embed_integration_type": "add", "use_gst": False, "gst_tokens": 10, "gst_heads": 4, "gst_conv_layers": 6, "gst_conv_chans_list": [32, 32, 64, 64, 128, 128], "gst_conv_kernel_size": 3, "gst_conv_stride": 2, "gst_gru_layers": 1, "gst_gru_units": 128, "init_type": "xavier_uniform", "init_enc_alpha": 1.0, "init_dec_alpha": 1.0, "use_masking": False, "use_weighted_masking": False, }, # discriminator related discriminator_type: str="hifigan_multi_scale_multi_period_discriminator", discriminator_params: Dict[str, Any]={ "scales": 1, "scale_downsample_pooling": "AvgPool1D", "scale_downsample_pooling_params": { "kernel_size": 4, "stride": 2, "padding": 2, }, "scale_discriminator_params": { "in_channels": 1, "out_channels": 1, "kernel_sizes": [15, 41, 5, 3], "channels": 128, "max_downsample_channels": 1024, "max_groups": 16, "bias": True, "downsample_scales": [2, 2, 4, 4, 1], "nonlinear_activation": "leakyrelu", "nonlinear_activation_params": { "negative_slope": 0.1 }, "use_weight_norm": True, "use_spectral_norm": False, }, "follow_official_norm": False, "periods": [2, 3, 5, 7, 11], "period_discriminator_params": { "in_channels": 1, "out_channels": 1, "kernel_sizes": [5, 3], "channels": 32, "downsample_scales": [3, 3, 3, 3, 1], "max_downsample_channels": 1024, "bias": True, "nonlinear_activation": "leakyrelu", "nonlinear_activation_params": { "negative_slope": 0.1 }, "use_weight_norm": True, "use_spectral_norm": False, }, }, cache_generator_outputs: bool=True, ): """Initialize JETS module. Args: idim (int): Input vocabrary size. odim (int): Acoustic feature dimension. The actual output channels will be 1 since JETS is the end-to-end text-to-wave model but for the compatibility odim is used to indicate the acoustic feature dimension. sampling_rate (int): Sampling rate, not used for the training but it will be referred in saving waveform during the inference. generator_type (str): Generator type. generator_params (Dict[str, Any]): Parameter dict for generator. discriminator_type (str): Discriminator type. discriminator_params (Dict[str, Any]): Parameter dict for discriminator. cache_generator_outputs (bool): Whether to cache generator outputs. """ assert check_argument_types() super().__init__() # define modules generator_class = AVAILABLE_GENERATERS[generator_type] if generator_type == "jets_generator": # NOTE: Update parameters for the compatibility. # The idim and odim is automatically decided from input data, # where idim represents #vocabularies and odim represents # the input acoustic feature dimension. generator_params.update(idim=idim, odim=odim) self.generator = generator_class( **generator_params, ) discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] self.discriminator = discriminator_class( **discriminator_params, ) # cache self.cache_generator_outputs = cache_generator_outputs self._cache = None # store sampling rate for saving wav file # (not used for the training) self.fs = sampling_rate # store parameters for test compatibility self.spks = self.generator.spks self.langs = self.generator.langs self.spk_embed_dim = self.generator.spk_embed_dim self.reuse_cache_gen = True self.reuse_cache_dis = True self.reset_parameters() self.generator._reset_parameters( init_type=generator_params["init_type"], init_enc_alpha=generator_params["init_enc_alpha"], init_dec_alpha=generator_params["init_dec_alpha"], ) def forward( self, text: paddle.Tensor, text_lengths: paddle.Tensor, feats: paddle.Tensor, feats_lengths: paddle.Tensor, durations: paddle.Tensor, durations_lengths: paddle.Tensor, pitch: paddle.Tensor, energy: paddle.Tensor, sids: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, lids: Optional[paddle.Tensor]=None, forward_generator: bool=True, use_alignment_module: bool=False, **kwargs, ) -> Dict[str, Any]: """Perform generator forward. Args: text (Tensor): Text index tensor (B, T_text). text_lengths (Tensor): Text length tensor (B,). feats (Tensor): Feature tensor (B, T_feats, aux_channels). feats_lengths (Tensor): Feature length tensor (B,). durations(Tensor(int64)): Batch of padded durations (B, Tmax). durations_lengths (Tensor): durations length tensor (B,). pitch(Tensor): Batch of padded token-averaged pitch (B, Tmax, 1). energy(Tensor): Batch of padded token-averaged energy (B, Tmax, 1). sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). forward_generator (bool): Whether to forward generator. use_alignment_module (bool): Whether to use alignment module. Returns: """ if forward_generator: return self._forward_generator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, ) else: return self._forward_discrminator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, ) def _forward_generator( self, text: paddle.Tensor, text_lengths: paddle.Tensor, feats: paddle.Tensor, feats_lengths: paddle.Tensor, durations: paddle.Tensor, durations_lengths: paddle.Tensor, pitch: paddle.Tensor, energy: paddle.Tensor, sids: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, lids: Optional[paddle.Tensor]=None, use_alignment_module: bool=False, **kwargs, ) -> Dict[str, Any]: """Perform generator forward. Args: text (Tensor): Text index tensor (B, T_text). text_lengths (Tensor): Text length tensor (B,). feats (Tensor): Feature tensor (B, T_feats, aux_channels). feats_lengths (Tensor): Feature length tensor (B,). durations(Tensor(int64)): Batch of padded durations (B, Tmax). durations_lengths (Tensor): durations length tensor (B,). pitch(Tensor): Batch of padded token-averaged pitch (B, Tmax, 1). energy(Tensor): Batch of padded token-averaged energy (B, Tmax, 1). sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). use_alignment_module (bool): Whether to use alignment module. Returns: """ # setup # calculate generator outputs self.reuse_cache_gen = True if not self.cache_generator_outputs or self._cache is None: self.reuse_cache_gen = False outs = self.generator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, ) else: outs = self._cache # store cache if self.training and self.cache_generator_outputs and not self.reuse_cache_gen: self._cache = outs return outs def _forward_discrminator( self, text: paddle.Tensor, text_lengths: paddle.Tensor, feats: paddle.Tensor, feats_lengths: paddle.Tensor, durations: paddle.Tensor, durations_lengths: paddle.Tensor, pitch: paddle.Tensor, energy: paddle.Tensor, sids: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, lids: Optional[paddle.Tensor]=None, use_alignment_module: bool=False, **kwargs, ) -> Dict[str, Any]: """Perform discriminator forward. Args: text (Tensor): Text index tensor (B, T_text). text_lengths (Tensor): Text length tensor (B,). feats (Tensor): Feature tensor (B, T_feats, aux_channels). feats_lengths (Tensor): Feature length tensor (B,). durations(Tensor(int64)): Batch of padded durations (B, Tmax). durations_lengths (Tensor): durations length tensor (B,). pitch(Tensor): Batch of padded token-averaged pitch (B, Tmax, 1). energy(Tensor): Batch of padded token-averaged energy (B, Tmax, 1). sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). use_alignment_module (bool): Whether to use alignment module. Returns: """ # setup # calculate generator outputs self.reuse_cache_dis = True if not self.cache_generator_outputs or self._cache is None: self.reuse_cache_dis = False outs = self.generator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, **kwargs, ) else: outs = self._cache # store cache if self.cache_generator_outputs and not self.reuse_cache_dis: self._cache = outs return outs def inference(self, text: paddle.Tensor, feats: Optional[paddle.Tensor]=None, pitch: Optional[paddle.Tensor]=None, energy: Optional[paddle.Tensor]=None, use_alignment_module: bool=False, **kwargs) -> Dict[str, paddle.Tensor]: """Run inference. Args: text (Tensor): Input text index tensor (T_text,). feats (Tensor): Feature tensor (T_feats, aux_channels). pitch (Tensor): Pitch tensor (T_feats, 1). energy (Tensor): Energy tensor (T_feats, 1). use_alignment_module (bool): Whether to use alignment module. Returns: Dict[str, Tensor]: * wav (Tensor): Generated waveform tensor (T_wav,). * duration (Tensor): Predicted duration tensor (T_text,). """ # setup text = text[None] text_lengths = paddle.to_tensor(paddle.shape(text)[1]) # inference if use_alignment_module: assert feats is not None feats = feats[None] feats_lengths = paddle.to_tensor(paddle.shape(feats)[1]) pitch = pitch[None] energy = energy[None] wav, dur = self.generator.inference( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, pitch=pitch, energy=energy, use_alignment_module=use_alignment_module, **kwargs) else: wav, dur = self.generator.inference( text=text, text_lengths=text_lengths, **kwargs, ) return dict(wav=paddle.reshape(wav, [-1]), duration=dur[0]) def reset_parameters(self): def _reset_parameters(module): if isinstance( module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)): kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) if fan_in != 0: bound = 1 / math.sqrt(fan_in) uniform_(module.bias, -bound, bound) if isinstance( module, (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): ones_(module.weight) zeros_(module.bias) if isinstance(module, nn.Linear): kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 uniform_(module.bias, -bound, bound) if isinstance(module, nn.Embedding): normal_(module.weight) if module._padding_idx is not None: with paddle.no_grad(): module.weight[module._padding_idx] = 0 self.apply(_reset_parameters) class JETSInference(nn.Layer): def __init__(self, model): super().__init__() self.acoustic_model = model def forward(self, text, sids=None): out = self.acoustic_model.inference(text) wav = out['wav'] return wav