# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Generator module in JETS. This code is based on https://github.com/imdanboy/jets. """ """JETS module""" import math from typing import Any from typing import Dict from typing import Optional import paddle from paddle import nn from typeguard import typechecked from paddlespeech.t2s.models.hifigan import HiFiGANMultiPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator from paddlespeech.t2s.models.jets.generator import JETSGenerator from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out from paddlespeech.utils.initialize import kaiming_uniform_ from paddlespeech.utils.initialize import normal_ from paddlespeech.utils.initialize import ones_ from paddlespeech.utils.initialize import uniform_ from paddlespeech.utils.initialize import zeros_ AVAILABLE_GENERATERS = { "jets_generator": JETSGenerator, } AVAILABLE_DISCRIMINATORS = { "hifigan_period_discriminator": HiFiGANPeriodDiscriminator, "hifigan_scale_discriminator": HiFiGANScaleDiscriminator, "hifigan_multi_period_discriminator": HiFiGANMultiPeriodDiscriminator, "hifigan_multi_scale_discriminator": HiFiGANMultiScaleDiscriminator, "hifigan_multi_scale_multi_period_discriminator": HiFiGANMultiScaleMultiPeriodDiscriminator, } class JETS(nn.Layer): """JETS module (generator + discriminator). This is a module of JETS described in `JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech`_. .. _`JETS: Jointly Training FastSpeech2 and HiFi-GAN for End to End Text to Speech Text-to-Speech`: https://arxiv.org/abs/2203.16852v1 """ @typechecked def __init__( self, # generator related idim: int, odim: int, sampling_rate: int=22050, generator_type: str="jets_generator", generator_params: Dict[str, Any]={ "adim": 256, "aheads": 2, "elayers": 4, "eunits": 1024, "dlayers": 4, "dunits": 1024, "positionwise_layer_type": "conv1d", "positionwise_conv_kernel_size": 1, "use_scaled_pos_enc": True, "use_batch_norm": True, "encoder_normalize_before": True, "decoder_normalize_before": True, "encoder_concat_after": False, "decoder_concat_after": False, "reduction_factor": 1, "encoder_type": "transformer", "decoder_type": "transformer", "transformer_enc_dropout_rate": 0.1, "transformer_enc_positional_dropout_rate": 0.1, "transformer_enc_attn_dropout_rate": 0.1, "transformer_dec_dropout_rate": 0.1, "transformer_dec_positional_dropout_rate": 0.1, "transformer_dec_attn_dropout_rate": 0.1, "conformer_rel_pos_type": "latest", "conformer_pos_enc_layer_type": "rel_pos", "conformer_self_attn_layer_type": "rel_selfattn", "conformer_activation_type": "swish", "use_macaron_style_in_conformer": True, "use_cnn_in_conformer": True, "zero_triu": False, "conformer_enc_kernel_size": 7, "conformer_dec_kernel_size": 31, "duration_predictor_layers": 2, "duration_predictor_chans": 384, "duration_predictor_kernel_size": 3, "duration_predictor_dropout_rate": 0.1, "energy_predictor_layers": 2, "energy_predictor_chans": 384, "energy_predictor_kernel_size": 3, "energy_predictor_dropout": 0.5, "energy_embed_kernel_size": 1, "energy_embed_dropout": 0.5, "stop_gradient_from_energy_predictor": False, "pitch_predictor_layers": 5, "pitch_predictor_chans": 384, "pitch_predictor_kernel_size": 5, "pitch_predictor_dropout": 0.5, "pitch_embed_kernel_size": 1, "pitch_embed_dropout": 0.5, "stop_gradient_from_pitch_predictor": True, "generator_out_channels": 1, "generator_channels": 512, "generator_global_channels": -1, "generator_kernel_size": 7, "generator_upsample_scales": [8, 8, 2, 2], "generator_upsample_kernel_sizes": [16, 16, 4, 4], "generator_resblock_kernel_sizes": [3, 7, 11], "generator_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "generator_use_additional_convs": True, "generator_bias": True, "generator_nonlinear_activation": "LeakyReLU", "generator_nonlinear_activation_params": { "negative_slope": 0.1 }, "generator_use_weight_norm": True, "segment_size": 64, "spks": -1, "langs": -1, "spk_embed_dim": None, "spk_embed_integration_type": "add", "use_gst": False, "gst_tokens": 10, "gst_heads": 4, "gst_conv_layers": 6, "gst_conv_chans_list": [32, 32, 64, 64, 128, 128], "gst_conv_kernel_size": 3, "gst_conv_stride": 2, "gst_gru_layers": 1, "gst_gru_units": 128, "init_type": "xavier_uniform", "init_enc_alpha": 1.0, "init_dec_alpha": 1.0, "use_masking": False, "use_weighted_masking": False, }, # discriminator related discriminator_type: str="hifigan_multi_scale_multi_period_discriminator", discriminator_params: Dict[str, Any]={ "scales": 1, "scale_downsample_pooling": "AvgPool1D", "scale_downsample_pooling_params": { "kernel_size": 4, "stride": 2, "padding": 2, }, "scale_discriminator_params": { "in_channels": 1, "out_channels": 1, "kernel_sizes": [15, 41, 5, 3], "channels": 128, "max_downsample_channels": 1024, "max_groups": 16, "bias": True, "downsample_scales": [2, 2, 4, 4, 1], "nonlinear_activation": "leakyrelu", "nonlinear_activation_params": { "negative_slope": 0.1 }, "use_weight_norm": True, "use_spectral_norm": False, }, "follow_official_norm": False, "periods": [2, 3, 5, 7, 11], "period_discriminator_params": { "in_channels": 1, "out_channels": 1, "kernel_sizes": [5, 3], "channels": 32, "downsample_scales": [3, 3, 3, 3, 1], "max_downsample_channels": 1024, "bias": True, "nonlinear_activation": "leakyrelu", "nonlinear_activation_params": { "negative_slope": 0.1 }, "use_weight_norm": True, "use_spectral_norm": False, }, }, cache_generator_outputs: bool=True, ): """Initialize JETS module. Args: idim (int): Input vocabrary size. odim (int): Acoustic feature dimension. The actual output channels will be 1 since JETS is the end-to-end text-to-wave model but for the compatibility odim is used to indicate the acoustic feature dimension. sampling_rate (int): Sampling rate, not used for the training but it will be referred in saving waveform during the inference. generator_type (str): Generator type. generator_params (Dict[str, Any]): Parameter dict for generator. discriminator_type (str): Discriminator type. discriminator_params (Dict[str, Any]): Parameter dict for discriminator. cache_generator_outputs (bool): Whether to cache generator outputs. """ super().__init__() # define modules generator_class = AVAILABLE_GENERATERS[generator_type] if generator_type == "jets_generator": # NOTE: Update parameters for the compatibility. # The idim and odim is automatically decided from input data, # where idim represents #vocabularies and odim represents # the input acoustic feature dimension. generator_params.update(idim=idim, odim=odim) self.generator = generator_class( **generator_params, ) discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] self.discriminator = discriminator_class( **discriminator_params, ) # cache self.cache_generator_outputs = cache_generator_outputs self._cache = None # store sampling rate for saving wav file # (not used for the training) self.fs = sampling_rate # store parameters for test compatibility self.spks = self.generator.spks self.langs = self.generator.langs self.spk_embed_dim = self.generator.spk_embed_dim self.reuse_cache_gen = True self.reuse_cache_dis = True self.reset_parameters() self.generator._reset_parameters( init_type=generator_params["init_type"], init_enc_alpha=generator_params["init_enc_alpha"], init_dec_alpha=generator_params["init_dec_alpha"], ) def forward( self, text: paddle.Tensor, text_lengths: paddle.Tensor, feats: paddle.Tensor, feats_lengths: paddle.Tensor, durations: paddle.Tensor, durations_lengths: paddle.Tensor, pitch: paddle.Tensor, energy: paddle.Tensor, sids: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, lids: Optional[paddle.Tensor]=None, forward_generator: bool=True, use_alignment_module: bool=False, **kwargs, ) -> Dict[str, Any]: """Perform generator forward. Args: text (Tensor): Text index tensor (B, T_text). text_lengths (Tensor): Text length tensor (B,). feats (Tensor): Feature tensor (B, T_feats, aux_channels). feats_lengths (Tensor): Feature length tensor (B,). durations(Tensor(int64)): Batch of padded durations (B, Tmax). durations_lengths (Tensor): durations length tensor (B,). pitch(Tensor): Batch of padded token-averaged pitch (B, Tmax, 1). energy(Tensor): Batch of padded token-averaged energy (B, Tmax, 1). sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). forward_generator (bool): Whether to forward generator. use_alignment_module (bool): Whether to use alignment module. Returns: """ if forward_generator: return self._forward_generator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, ) else: return self._forward_discrminator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, ) def _forward_generator( self, text: paddle.Tensor, text_lengths: paddle.Tensor, feats: paddle.Tensor, feats_lengths: paddle.Tensor, durations: paddle.Tensor, durations_lengths: paddle.Tensor, pitch: paddle.Tensor, energy: paddle.Tensor, sids: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, lids: Optional[paddle.Tensor]=None, use_alignment_module: bool=False, **kwargs, ) -> Dict[str, Any]: """Perform generator forward. Args: text (Tensor): Text index tensor (B, T_text). text_lengths (Tensor): Text length tensor (B,). feats (Tensor): Feature tensor (B, T_feats, aux_channels). feats_lengths (Tensor): Feature length tensor (B,). durations(Tensor(int64)): Batch of padded durations (B, Tmax). durations_lengths (Tensor): durations length tensor (B,). pitch(Tensor): Batch of padded token-averaged pitch (B, Tmax, 1). energy(Tensor): Batch of padded token-averaged energy (B, Tmax, 1). sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). use_alignment_module (bool): Whether to use alignment module. Returns: """ # setup # calculate generator outputs self.reuse_cache_gen = True if not self.cache_generator_outputs or self._cache is None: self.reuse_cache_gen = False outs = self.generator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, ) else: outs = self._cache # store cache if self.training and self.cache_generator_outputs and not self.reuse_cache_gen: self._cache = outs return outs def _forward_discrminator( self, text: paddle.Tensor, text_lengths: paddle.Tensor, feats: paddle.Tensor, feats_lengths: paddle.Tensor, durations: paddle.Tensor, durations_lengths: paddle.Tensor, pitch: paddle.Tensor, energy: paddle.Tensor, sids: Optional[paddle.Tensor]=None, spembs: Optional[paddle.Tensor]=None, lids: Optional[paddle.Tensor]=None, use_alignment_module: bool=False, **kwargs, ) -> Dict[str, Any]: """Perform discriminator forward. Args: text (Tensor): Text index tensor (B, T_text). text_lengths (Tensor): Text length tensor (B,). feats (Tensor): Feature tensor (B, T_feats, aux_channels). feats_lengths (Tensor): Feature length tensor (B,). durations(Tensor(int64)): Batch of padded durations (B, Tmax). durations_lengths (Tensor): durations length tensor (B,). pitch(Tensor): Batch of padded token-averaged pitch (B, Tmax, 1). energy(Tensor): Batch of padded token-averaged energy (B, Tmax, 1). sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1). spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). use_alignment_module (bool): Whether to use alignment module. Returns: """ # setup # calculate generator outputs self.reuse_cache_dis = True if not self.cache_generator_outputs or self._cache is None: self.reuse_cache_dis = False outs = self.generator( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, durations=durations, durations_lengths=durations_lengths, pitch=pitch, energy=energy, sids=sids, spembs=spembs, lids=lids, use_alignment_module=use_alignment_module, **kwargs, ) else: outs = self._cache # store cache if self.cache_generator_outputs and not self.reuse_cache_dis: self._cache = outs return outs def inference(self, text: paddle.Tensor, feats: Optional[paddle.Tensor]=None, pitch: Optional[paddle.Tensor]=None, energy: Optional[paddle.Tensor]=None, use_alignment_module: bool=False, **kwargs) -> Dict[str, paddle.Tensor]: """Run inference. Args: text (Tensor): Input text index tensor (T_text,). feats (Tensor): Feature tensor (T_feats, aux_channels). pitch (Tensor): Pitch tensor (T_feats, 1). energy (Tensor): Energy tensor (T_feats, 1). use_alignment_module (bool): Whether to use alignment module. Returns: Dict[str, Tensor]: * wav (Tensor): Generated waveform tensor (T_wav,). * duration (Tensor): Predicted duration tensor (T_text,). """ # setup text = text[None] text_lengths = paddle.to_tensor(paddle.shape(text)[1]) # inference if use_alignment_module: assert feats is not None feats = feats[None] feats_lengths = paddle.to_tensor(paddle.shape(feats)[1]) pitch = pitch[None] energy = energy[None] wav, dur = self.generator.inference( text=text, text_lengths=text_lengths, feats=feats, feats_lengths=feats_lengths, pitch=pitch, energy=energy, use_alignment_module=use_alignment_module, **kwargs) else: wav, dur = self.generator.inference( text=text, text_lengths=text_lengths, **kwargs, ) return dict(wav=paddle.reshape(wav, [-1]), duration=dur[0]) def reset_parameters(self): def _reset_parameters(module): if isinstance( module, (nn.Conv1D, nn.Conv1DTranspose, nn.Conv2D, nn.Conv2DTranspose)): kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) if fan_in != 0: bound = 1 / math.sqrt(fan_in) uniform_(module.bias, -bound, bound) if isinstance( module, (nn.BatchNorm1D, nn.BatchNorm2D, nn.GroupNorm, nn.LayerNorm)): ones_(module.weight) zeros_(module.bias) if isinstance(module, nn.Linear): kaiming_uniform_(module.weight, a=math.sqrt(5)) if module.bias is not None: fan_in, _ = _calculate_fan_in_and_fan_out(module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 uniform_(module.bias, -bound, bound) if isinstance(module, nn.Embedding): normal_(module.weight) if module._padding_idx is not None: with paddle.no_grad(): module.weight[module._padding_idx] = 0 self.apply(_reset_parameters) class JETSInference(nn.Layer): def __init__(self, model): super().__init__() self.acoustic_model = model def forward(self, text, sids=None): out = self.acoustic_model.inference(text) wav = out['wav'] return wav