From faaa75882be8bd8ea3eec67758a7fa1d2568748f Mon Sep 17 00:00:00 2001
From: longRookie
Date: Tue, 7 Mar 2023 09:57:23 +0800
Subject: [PATCH] iSTFTNet implementation based on HiFiGAN; does not affect the function or execution of HiFiGAN

---
 examples/csmsc/voc5/conf/iSTFT.yaml        | 174 +++++++++++++++++++++
 paddlespeech/t2s/models/hifigan/hifigan.py |  33 +++-
 2 files changed, 205 insertions(+), 2 deletions(-)
 create mode 100644 examples/csmsc/voc5/conf/iSTFT.yaml

diff --git a/examples/csmsc/voc5/conf/iSTFT.yaml b/examples/csmsc/voc5/conf/iSTFT.yaml
new file mode 100644
index 000000000..898cf1e68
--- /dev/null
+++ b/examples/csmsc/voc5/conf/iSTFT.yaml
@@ -0,0 +1,174 @@
+# This is the configuration file for the CSMSC dataset.
+# This configuration is based on HiFiGAN V1, which is an official configuration.
+# But I found that the optimizer setting does not work well with my implementation.
+# So I changed the optimizer settings as follows:
+# - AdamW -> Adam
+# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
+# - Scheduler: ExponentialLR -> MultiStepLR
+# To match the shift size difference, the upsample scales are also modified from the original 256-shift setting (5 * 5 * 12 = 300 = n_shift).

+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+fs: 24000          # Sampling rate.
+n_fft: 2048        # FFT size (samples).
+n_shift: 300       # Hop size (samples), i.e. 12.5 ms.
+win_length: 1200   # Window length (samples), i.e. 50 ms.
+                   # If set to null, it will be the same as n_fft.
+window: "hann"     # Window function.
+n_mels: 80         # Number of mel basis.
+fmin: 80           # Minimum frequency in mel basis calculation (Hz).
+fmax: 7600         # Maximum frequency in mel basis calculation (Hz).
+
+###########################################################
+#         GENERATOR NETWORK ARCHITECTURE SETTING          #
+###########################################################
+generator_params:
+    istft: True                       # Whether to use the iSTFTNet output head.
+    post_n_fft: 48                    # FFT size of the iSTFT output head.
+    gen_istft_hop_size: 12            # iSTFT hop length (originally 4).
+    gen_istft_n_fft: 48               # iSTFT FFT size.
+    in_channels: 80                   # Number of input channels.
+    out_channels: 1                   # Number of output channels.
+    channels: 512                     # Number of initial channels.
+    kernel_size: 7                    # Kernel size of initial and final conv layers.
+    upsample_scales: [5, 5]           # Upsampling scales.
+    upsample_kernel_sizes: [10, 10]   # Kernel size for upsampling layers.
+    resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
+    resblock_dilations:               # Dilations for residual blocks.
+        - [1, 3, 5]
+        - [1, 3, 5]
+        - [1, 3, 5]
+    use_additional_convs: True        # Whether to use additional conv layer in residual blocks.
+    bias: True                        # Whether to use bias parameter in conv.
+    nonlinear_activation: "leakyrelu" # Nonlinear activation type.
+    nonlinear_activation_params:      # Nonlinear activation parameters.
+        negative_slope: 0.1
+    use_weight_norm: True             # Whether to apply weight normalization.
+
+
+
+
+
+###########################################################
+#       DISCRIMINATOR NETWORK ARCHITECTURE SETTING        #
+###########################################################
+discriminator_params:
+    scales: 3                             # Number of multi-scale discriminators.
+    scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator.
+    scale_downsample_pooling_params:
+        kernel_size: 4                    # Pooling kernel size.
+        stride: 2                         # Pooling stride.
+        padding: 2                        # Padding size.
+    scale_discriminator_params:
+        in_channels: 1                    # Number of input channels.
+        out_channels: 1                   # Number of output channels.
+        kernel_sizes: [15, 41, 5, 3]      # List of kernel sizes.
+        channels: 128                      # Initial number of channels.
+        max_downsample_channels: 1024      # Maximum number of channels in downsampling conv layers.
+        max_groups: 16                     # Maximum number of groups in downsampling conv layers.
+        bias: True                         # Whether to use bias parameter in conv layer.
+        downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
+        nonlinear_activation: "leakyrelu"  # Nonlinear activation.
+        nonlinear_activation_params:
+            negative_slope: 0.1
+    follow_official_norm: True             # Whether to follow the official norm setting.
+    periods: [2, 3, 5, 7, 11]              # List of periods for multi-period discriminator.
+    period_discriminator_params:
+        in_channels: 1                     # Number of input channels.
+        out_channels: 1                    # Number of output channels.
+        kernel_sizes: [5, 3]               # List of kernel sizes.
+        channels: 32                       # Initial number of channels.
+        downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
+        max_downsample_channels: 1024      # Maximum number of channels in downsampling conv layers.
+        bias: True                         # Whether to use bias parameter in conv layer.
+        nonlinear_activation: "leakyrelu"  # Nonlinear activation.
+        nonlinear_activation_params:       # Nonlinear activation parameters.
+            negative_slope: 0.1
+    use_weight_norm: True                  # Whether to apply weight normalization.
+    use_spectral_norm: False               # Whether to apply spectral normalization.
+
+
+###########################################################
+#                    STFT LOSS SETTING                    #
+###########################################################
+use_stft_loss: False    # Whether to use multi-resolution STFT loss.
+use_mel_loss: True      # Whether to use Mel-spectrogram loss.
+mel_loss_params:
+    fs: 24000
+    fft_size: 2048
+    hop_size: 300
+    win_length: 1200
+    window: "hann"
+    num_mels: 80
+    fmin: 0
+    fmax: 12000
+    log_base: null
+generator_adv_loss_params:
+    average_by_discriminators: False # Whether to average loss by #discriminators.
+discriminator_adv_loss_params:
+    average_by_discriminators: False # Whether to average loss by #discriminators.
+use_feat_match_loss: True
+feat_match_loss_params:
+    average_by_discriminators: False # Whether to average loss by #discriminators.
+    average_by_layers: False         # Whether to average loss by #layers in each discriminator.
+    include_final_outputs: False     # Whether to include final outputs in feat match loss calculation.
+
+###########################################################
+#                 ADVERSARIAL LOSS SETTING                #
+###########################################################
+lambda_aux: 45.0       # Loss balancing coefficient for aux (Mel-spectrogram) loss.
+lambda_adv: 1.0        # Loss balancing coefficient for adversarial loss.
+lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.
+
+###########################################################
+#                   DATA LOADER SETTING                   #
+###########################################################
+batch_size: 16        # Batch size.
+batch_max_steps: 8400 # Length of each audio in batch. Make sure it is divisible by hop_size.
+num_workers: 2        # Number of workers in DataLoader.
+
+###########################################################
+#              OPTIMIZER & SCHEDULER SETTING              #
+###########################################################
+generator_optimizer_params:
+    beta1: 0.5
+    beta2: 0.9
+    weight_decay: 0.0     # Generator's weight decay coefficient.
+generator_scheduler_params:
+    learning_rate: 2.0e-4 # Generator's learning rate.
+    gamma: 0.5            # Generator's scheduler gamma.
+    milestones:           # At each milestone, lr will be multiplied by gamma.
+        - 200000
+        - 400000
+        - 600000
+        - 800000
+generator_grad_norm: -1   # Generator's gradient norm (-1 means no clipping).
+discriminator_optimizer_params:
+    beta1: 0.5
+    beta2: 0.9
+    weight_decay: 0.0     # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+    learning_rate: 2.0e-4 # Discriminator's learning rate.
+    gamma: 0.5            # Discriminator's scheduler gamma.
+    milestones:           # At each milestone, lr will be multiplied by gamma.
+        - 200000
+        - 400000
+        - 600000
+        - 800000
+discriminator_grad_norm: -1 # Discriminator's gradient norm (-1 means no clipping).
+
+###########################################################
+#                     INTERVAL SETTING                    #
+###########################################################
+generator_train_start_steps: 1     # Number of steps to start to train generator.
+discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
+train_max_steps: 50000             # Number of training steps.
+save_interval_steps: 5000          # Interval steps to save checkpoint.
+eval_interval_steps: 1000          # Interval steps to evaluate the network.
+
+###########################################################
+#                      OTHER SETTING                      #
+###########################################################
+num_snapshots: 10 # Max number of snapshots to keep while training.
+seed: 42          # Random seed for paddle, random, and np.random.
diff --git a/paddlespeech/t2s/models/hifigan/hifigan.py b/paddlespeech/t2s/models/hifigan/hifigan.py
index 7a01840e2..b6542c6cd 100644
--- a/paddlespeech/t2s/models/hifigan/hifigan.py
+++ b/paddlespeech/t2s/models/hifigan/hifigan.py
@@ -47,7 +47,12 @@ class HiFiGANGenerator(nn.Layer):
             nonlinear_activation: str="leakyrelu",
             nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.1},
             use_weight_norm: bool=True,
-            init_type: str="xavier_uniform", ):
+            init_type: str="xavier_uniform",
+            istft: bool=False,
+            post_n_fft: int=16,
+            gen_istft_hop_size: int=12,
+            gen_istft_n_fft: int=16,
+            ):
         """Initialize HiFiGANGenerator module.
         Args:
             in_channels (int):
@@ -135,6 +140,13 @@
                 1,
                 padding=(kernel_size - 1) // 2, ),
             nn.Tanh(), )
+        self.istft = istft
+        if self.istft:
+            self.post_n_fft = post_n_fft
+            self.gen_istft_hop_size = gen_istft_hop_size
+            self.gen_istft_n_fft = gen_istft_n_fft
+            self.reflection_pad = nn.Pad1D(padding=[1, 0], mode="reflect")
+            self.conv_post = nn.Conv1D(channels // (2**(i + 1)), self.post_n_fft + 2, 7, 1, padding=3)  # i is the index of the last upsample layer here
 
         if global_channels > 0:
             self.global_conv = nn.Conv1D(global_channels, channels, 1)
@@ -167,7 +179,24 @@
             for j in range(self.num_blocks):
                 cs += self.blocks[i * self.num_blocks + j](c)
             c = cs / self.num_blocks
-        c = self.output_conv(c)
+
+        if self.istft:
+            c = F.leaky_relu(c)
+            c = self.reflection_pad(c)
+            c = self.conv_post(c)
+            # paddle.exp only accepts real-valued (float16/float32/float64) input:
+            # https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/exp_en.html
+            # so spec * paddle.exp(1j * phase) is computed via Euler's formula:
+            # spec * exp(1j * phase) = spec * (cos(phase) + 1j * sin(phase)).
+            # Following iSTFTNet, the phase channels are first bounded by sin.
+            spec = paddle.exp(c[:, :self.post_n_fft // 2 + 1, :])
+            phase = paddle.sin(c[:, self.post_n_fft // 2 + 1:, :])
+
+            c = paddle.complex(spec * paddle.cos(phase), spec * paddle.sin(phase))
+            c = paddle.signal.istft(c, self.gen_istft_n_fft, hop_length=self.gen_istft_hop_size, win_length=self.gen_istft_n_fft)
+            c = c.unsqueeze(1)
+        else:
+            c = self.output_conv(c)
 
         return c
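
Note for reviewers: below is a minimal standalone sketch of the iSTFT output
head added above, using only paddle (no paddlespeech imports). The tensor
names (c, spec, phase) mirror the patch, the sizes follow iSTFT.yaml
(post_n_fft = 48, gen_istft_hop_size = 12), and the random input is just a
stand-in for the real conv_post output:

    import paddle

    post_n_fft = 48        # iSTFT.yaml: generator_params.post_n_fft
    hop_size = 12          # iSTFT.yaml: generator_params.gen_istft_hop_size
    batch, frames = 2, 10  # arbitrary example sizes

    # Stand-in for the conv_post output: post_n_fft + 2 = 50 channels,
    # i.e. 25 log-magnitude channels followed by 25 phase channels.
    c = paddle.randn([batch, post_n_fft + 2, frames])

    spec = paddle.exp(c[:, :post_n_fft // 2 + 1, :])   # log-magnitude -> magnitude
    phase = paddle.sin(c[:, post_n_fft // 2 + 1:, :])  # bounded phase, as in iSTFTNet

    # Euler's formula: spec * exp(1j * phase) = spec * (cos(phase) + 1j * sin(phase))
    stft = paddle.complex(spec * paddle.cos(phase), spec * paddle.sin(phase))
    wav = paddle.signal.istft(
        stft, post_n_fft, hop_length=hop_size, win_length=post_n_fft)
    print(wav.shape)  # [2, 108]: (frames - 1) * hop_size samples with center=True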
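
The shift-size bookkeeping in the new config can also be sanity-checked; this
sketch assumes the patch has been applied and that PyYAML is available:

    import yaml

    with open("examples/csmsc/voc5/conf/iSTFT.yaml") as f:
        cfg = yaml.safe_load(f)

    gen = cfg["generator_params"]
    total_upsample = 1
    for scale in gen["upsample_scales"]:
        total_upsample *= scale

    # The total generator upsampling times the iSTFT hop must reproduce the
    # feature hop size: 5 * 5 * 12 == 300 == n_shift.
    assert total_upsample * gen["gen_istft_hop_size"] == cfg["n_shift"]

    # batch_max_steps must be divisible by the hop size (8400 / 300 = 28 frames).
    assert cfg["batch_max_steps"] % cfg["n_shift"] == 0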