From df37798598e8f32475892af819377101ace6d0a5 Mon Sep 17 00:00:00 2001 From: longRookie <68834517+longRookie@users.noreply.github.com> Date: Mon, 10 Apr 2023 11:16:13 +0800 Subject: [PATCH] =?UTF-8?q?[TTS]=E3=80=90Hackathon=20+=20No.190=E3=80=91?= =?UTF-8?q?=20+=20=E6=A8=A1=E5=9E=8B=E5=A4=8D=E7=8E=B0=EF=BC=9AiSTFTNet=20?= =?UTF-8?q?(#3006)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * iSTFTNet implementation based on hifigan, not affect the function and execution of HIFIGAN * modify the comment in iSTFT.yaml * add the comments in hifigan * iSTFTNet implementation based on hifigan, not affect the function and execution of HIFIGAN * modify the comment in iSTFT.yaml * add the comments in hifigan * add iSTFTNet.md * modify the format of iSTFTNet.md * modify iSTFT.yaml and hifigan.py * Format code using pre-commit * modify hifigan.py,delete the unused self.istft_layer_id , move the self.output_conv behind else, change conv_post to output_conv * update iSTFTNet_csmsc_ckpt.zip download link * modify iSTFTNet.md * modify hifigan.py and iSTFT.yaml * modify iSTFTNet.md --- examples/csmsc/voc5/conf/iSTFT.yaml | 174 +++++++++++++++++++++ examples/csmsc/voc5/iSTFTNet.md | 145 +++++++++++++++++ paddlespeech/t2s/models/hifigan/hifigan.py | 82 ++++++++-- 3 files changed, 389 insertions(+), 12 deletions(-) create mode 100644 examples/csmsc/voc5/conf/iSTFT.yaml create mode 100644 examples/csmsc/voc5/iSTFTNet.md diff --git a/examples/csmsc/voc5/conf/iSTFT.yaml b/examples/csmsc/voc5/conf/iSTFT.yaml new file mode 100644 index 00000000..06677d79 --- /dev/null +++ b/examples/csmsc/voc5/conf/iSTFT.yaml @@ -0,0 +1,174 @@ +# This is the configuration file for CSMSC dataset. +# This configuration is based on HiFiGAN V1, which is an official configuration. +# But I found that the optimizer setting does not work well with my implementation. +# So I changed optimizer settings as follows: +# - AdamW -> Adam +# - betas: [0.8, 0.99] -> betas: [0.5, 0.9] +# - Scheduler: ExponentialLR -> MultiStepLR +# To match the shift size difference, the upsample scales is also modified from the original 256 shift setting. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size (samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + use_istft: True # Use iSTFTNet. + istft_layer_id: 2 # Use istft after istft_layer_id layers of upsample layer if use_istft=True. + n_fft: 2048 # FFT size (samples) in feature extraction. + win_length: 1200 # Window length (samples) in feature extraction. + in_channels: 80 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [5, 5, 4, 3] # Upsampling scales. + upsample_kernel_sizes: [10, 10, 8, 6] # Kernel size for upsampling layers. + resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. + resblock_dilations: # Dilations for residual blocks. + - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + + + + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer." + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 24000 + fft_size: 2048 + hop_size: 300 + win_length: 1200 + window: "hann" + num_mels: 80 + fmin: 0 + fmax: 12000 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. +lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. +lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size. +batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 2 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 2.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +generator_train_start_steps: 1 # Number of steps to start to train discriminator. +discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. +train_max_steps: 2500000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/csmsc/voc5/iSTFTNet.md b/examples/csmsc/voc5/iSTFTNet.md new file mode 100644 index 00000000..8f121938 --- /dev/null +++ b/examples/csmsc/voc5/iSTFTNet.md @@ -0,0 +1,145 @@ +# iSTFTNet with CSMSC + +This example contains code used to train a [iSTFTNet](https://arxiv.org/abs/2203.02395) model with [Chinese Standard Mandarin Speech Copus](https://www.data-baker.com/open_source.html). + +## Dataset +### Download and Extract +Download CSMSC from it's [official website](https://test.data-baker.com/data/index/TNtts/) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/BZNSYP`. + +### Get MFA Result and Extract +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut silence at the edge of audio. +You can download from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│ ├── norm +│ └── raw +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── feats_stats.npy +``` +The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the norm folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains id and paths to the spectrogram of each utterance. + +### Model Training +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. + +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] + +Train a HiFiGAN model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG HiFiGAN config file. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` + +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/iSTFT.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +### Synthesizing +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG] + [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA] + [--output-dir OUTPUT_DIR] [--ngpu NGPU] + +Synthesize with GANVocoder. + +optional arguments: + -h, --help show this help message and exit + --generator-type GENERATOR_TYPE + type of GANVocoder, should in {pwgan, mb_melgan, + style_melgan, } now + --config CONFIG GANVocoder config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. +``` + +1. `--config` config file. You should use the same config with which the model is trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +## Pretrained Models + +The pretrained model can be downloaded here: + +- [iSTFTNet_csmsc_ckpt.zip](https://pan.baidu.com/s/1SNDlRWOGOcbbrKf5w-TJaA?pwd=r1e5) + +iSTFTNet checkpoint contains files listed below. + +```text +iSTFTNet_csmsc_ckpt +├── iSTFT.yaml                  # config used to train iSTFTNet +├── feats_stats.npy               # statistics used to normalize spectrogram when training hifigan +└── snapshot_iter_50000.pdz     # generator parameters of hifigan +``` + +A Comparison between iSTFTNet and Hifigan +| Model | Step | eval/generator_loss | eval/mel_loss | eval/feature_matching_loss | rtf | +|:--------:|:--------------:|:-------------------:|:-------------:|:--------------------------:| :---: | +| hifigan | 1(gpu) x 50000 | 13.989 | 0.14683 | 1.3484 | 0.01767 | +| istftNet | 1(gpu) x 50000 | 13.319 | 0.14818 | 1.1069 | 0.01069 | + +> Rtf is tested on the CSMSC test dataset, and the test environment is aistudio v100 16G 1GPU, the test command is `./run.sh --stage 2 --stop-stage 2` + +The pretained hifigan model int the comparison can be downloaded here: + +- [hifigan_csmsc_ckpt.zip](https://pan.baidu.com/s/1pGY6RYV7yEB_5hRI_JoWig?pwd=tcaj) + +## Acknowledgement + +We adapted some code from https://github.com/rishikksh20/iSTFTNet-pytorch.git. diff --git a/paddlespeech/t2s/models/hifigan/hifigan.py b/paddlespeech/t2s/models/hifigan/hifigan.py index 7a01840e..2759af9d 100644 --- a/paddlespeech/t2s/models/hifigan/hifigan.py +++ b/paddlespeech/t2s/models/hifigan/hifigan.py @@ -37,8 +37,8 @@ class HiFiGANGenerator(nn.Layer): channels: int=512, global_channels: int=-1, kernel_size: int=7, - upsample_scales: List[int]=(8, 8, 2, 2), - upsample_kernel_sizes: List[int]=(16, 16, 4, 4), + upsample_scales: List[int]=(5, 5, 4, 3), + upsample_kernel_sizes: List[int]=(10, 10, 8, 6), resblock_kernel_sizes: List[int]=(3, 7, 11), resblock_dilations: List[List[int]]=[(1, 3, 5), (1, 3, 5), (1, 3, 5)], @@ -47,8 +47,13 @@ class HiFiGANGenerator(nn.Layer): nonlinear_activation: str="leakyrelu", nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.1}, use_weight_norm: bool=True, - init_type: str="xavier_uniform", ): + init_type: str="xavier_uniform", + use_istft: bool=False, + istft_layer_id: int=2, + n_fft: int=2048, + win_length: int=1200, ): """Initialize HiFiGANGenerator module. + Args: in_channels (int): Number of input channels. @@ -79,6 +84,14 @@ class HiFiGANGenerator(nn.Layer): use_weight_norm (bool): Whether to use weight norm. If set to true, it will be applied to all of the conv layers. + use_istft (bool): + If set to true, it will be a iSTFTNet based on hifigan. + istft_layer_id (int): + Use istft after istft_layer_id layers of upsample layer if use_istft=True + n_fft (int): + Number of fft points in feature extraction + win_length (int): + Window length in feature extraction """ super().__init__() @@ -89,9 +102,11 @@ class HiFiGANGenerator(nn.Layer): assert kernel_size % 2 == 1, "Kernel size must be odd number." assert len(upsample_scales) == len(upsample_kernel_sizes) assert len(resblock_dilations) == len(resblock_kernel_sizes) + assert len(upsample_scales) >= istft_layer_id if use_istft else True # define modules - self.num_upsamples = len(upsample_kernel_sizes) + self.num_upsamples = len( + upsample_kernel_sizes) if not use_istft else istft_layer_id self.num_blocks = len(resblock_kernel_sizes) self.input_conv = nn.Conv1D( in_channels, @@ -101,7 +116,7 @@ class HiFiGANGenerator(nn.Layer): padding=(kernel_size - 1) // 2, ) self.upsamples = nn.LayerList() self.blocks = nn.LayerList() - for i in range(len(upsample_kernel_sizes)): + for i in range(self.num_upsamples): assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] self.upsamples.append( nn.Sequential( @@ -126,15 +141,36 @@ class HiFiGANGenerator(nn.Layer): nonlinear_activation=nonlinear_activation, nonlinear_activation_params=nonlinear_activation_params, )) - self.output_conv = nn.Sequential( - nn.LeakyReLU(), - nn.Conv1D( + self.use_istft = use_istft + if self.use_istft: + self.istft_hop_size = 1 + for j in range(istft_layer_id, len(upsample_scales)): + self.istft_hop_size *= upsample_scales[j] + s = 1 + for j in range(istft_layer_id): + s *= upsample_scales[j] + self.istft_n_fft = int(n_fft / s) if ( + n_fft / s) % 2 == 0 else int((n_fft / s + 2) - n_fft / s % 2) + self.istft_win_length = int(win_length / s) if ( + win_length / + s) % 2 == 0 else int((win_length / s + 2) - win_length / s % 2) + self.reflection_pad = nn.Pad1D(padding=[1, 0], mode='reflect') + self.output_conv = nn.Conv1D( channels // (2**(i + 1)), - out_channels, + (self.istft_n_fft // 2 + 1) * 2, kernel_size, 1, - padding=(kernel_size - 1) // 2, ), - nn.Tanh(), ) + padding=(kernel_size - 1) // 2, ) + else: + self.output_conv = nn.Sequential( + nn.LeakyReLU(), + nn.Conv1D( + channels // (2**(i + 1)), + out_channels, + kernel_size, + 1, + padding=(kernel_size - 1) // 2, ), + nn.Tanh(), ) if global_channels > 0: self.global_conv = nn.Conv1D(global_channels, channels, 1) @@ -167,7 +203,29 @@ class HiFiGANGenerator(nn.Layer): for j in range(self.num_blocks): cs += self.blocks[i * self.num_blocks + j](c) c = cs / self.num_blocks - c = self.output_conv(c) + + if self.use_istft: + c = F.leaky_relu(c) + c = self.reflection_pad(c) + c = self.output_conv(c) + """ + Input of Exp operator, an N-D Tensor, with data type float32, float64 or float16. + https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/exp_en.html + Use Euler's formula to implement spec*paddle.exp(1j*phase) + """ + spec = paddle.exp(c[:, :self.istft_n_fft // 2 + 1, :]) + phase = paddle.sin(c[:, self.istft_n_fft // 2 + 1:, :]) + + c = paddle.complex(spec * (paddle.cos(phase)), + spec * (paddle.sin(phase))) + c = paddle.signal.istft( + c, + n_fft=self.istft_n_fft, + hop_length=self.istft_hop_size, + win_length=self.istft_win_length) + c = c.unsqueeze(1) + else: + c = self.output_conv(c) return c