diff --git a/examples/aishell3/voc1/README.md b/examples/aishell3/voc1/README.md index 9189eb72..de7e04a6 100644 --- a/examples/aishell3/voc1/README.md +++ b/examples/aishell3/voc1/README.md @@ -105,7 +105,7 @@ benchmark: 4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ### Synthesizing -`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. ```bash CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ``` diff --git a/examples/aishell3/voc1/conf/default.yaml b/examples/aishell3/voc1/conf/default.yaml index ba2d9f2e..eb6d350d 100644 --- a/examples/aishell3/voc1/conf/default.yaml +++ b/examples/aishell3/voc1/conf/default.yaml @@ -35,7 +35,7 @@ generator_params: dropout: 0.0 # Dropout rate. 0.0 means no dropout applied. use_weight_norm: true # Whether to use weight norm. # If set to true, it will be applied to all of the conv layers. - upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size. + upsample_scales: [4, 5, 3, 5] # Upsampling scales. prod(upsample_scales) == n_shift ########################################################### # DISCRIMINATOR NETWORK ARCHITECTURE SETTING # ########################################################### @@ -71,7 +71,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. # DATA LOADER SETTING # ########################################################### batch_size: 8 # Batch size. -batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by hop_size. +batch_max_steps: 24000 # Length of each audio in batch. Make sure divisible by n_shift. pin_memory: true # Whether to pin memory in Pytorch DataLoader. num_workers: 4 # Number of workers in Pytorch DataLoader. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. diff --git a/examples/aishell3/voc1/local/synthesize.sh b/examples/aishell3/voc1/local/synthesize.sh index 9f904ac0..d85d1b1d 100755 --- a/examples/aishell3/voc1/local/synthesize.sh +++ b/examples/aishell3/voc1/local/synthesize.sh @@ -6,8 +6,9 @@ ckpt_name=$3 FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize.py \ +python3 ${BIN_DIR}/../synthesize.py \ --config=${config_path} \ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ --test-metadata=dump/test/norm/metadata.jsonl \ - --output-dir=${train_output_path}/test + --output-dir=${train_output_path}/test \ + --generator-type=pwgan diff --git a/examples/csmsc/README.md b/examples/csmsc/README.md index 08a51349..a59a06ed 100644 --- a/examples/csmsc/README.md +++ b/examples/csmsc/README.md @@ -9,3 +9,4 @@ * voc1 - Parallel WaveGAN * voc2 - MelGAN * voc3 - MultiBand MelGAN +* voc4 - Style MelGAN diff --git a/examples/csmsc/voc1/README.md b/examples/csmsc/voc1/README.md index e6ee7b4a..b13d5896 100644 --- a/examples/csmsc/voc1/README.md +++ b/examples/csmsc/voc1/README.md @@ -95,7 +95,7 @@ benchmark: 4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ### Synthesizing -`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. 
```bash CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ``` diff --git a/examples/csmsc/voc1/conf/default.yaml b/examples/csmsc/voc1/conf/default.yaml index 1363b454..21bdd040 100644 --- a/examples/csmsc/voc1/conf/default.yaml +++ b/examples/csmsc/voc1/conf/default.yaml @@ -78,7 +78,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. # DATA LOADER SETTING # ########################################################### batch_size: 8 # Batch size. -batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by hop_size. +batch_max_steps: 25500 # Length of each audio in batch. Make sure divisible by n_shift. pin_memory: true # Whether to pin memory in Pytorch DataLoader. num_workers: 2 # Number of workers in Pytorch DataLoader. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -88,23 +88,23 @@ allow_cache: true # Whether to allow cache in dataset. If true, it requ # OPTIMIZER & SCHEDULER SETTING # ########################################################### generator_optimizer_params: - epsilon: 1.0e-6 # Generator's epsilon. + epsilon: 1.0e-6 # Generator's epsilon. weight_decay: 0.0 # Generator's weight decay coefficient. generator_scheduler_params: - learning_rate: 0.0001 # Generator's learning rate. + learning_rate: 0.0001 # Generator's learning rate. step_size: 200000 # Generator's scheduler step size. gamma: 0.5 # Generator's scheduler gamma. # At each step size, lr will be multiplied by this parameter. generator_grad_norm: 10 # Generator's gradient norm. discriminator_optimizer_params: epsilon: 1.0e-6 # Discriminator's epsilon. - weight_decay: 0.0 # Discriminator's weight decay coefficient. + weight_decay: 0.0 # Discriminator's weight decay coefficient. discriminator_scheduler_params: - learning_rate: 0.00005 # Discriminator's learning rate. - step_size: 200000 # Discriminator's scheduler step size. - gamma: 0.5 # Discriminator's scheduler gamma. - # At each step size, lr will be multiplied by this parameter. -discriminator_grad_norm: 1 # Discriminator's gradient norm. + learning_rate: 0.00005 # Discriminator's learning rate. + step_size: 200000 # Discriminator's scheduler step size. + gamma: 0.5 # Discriminator's scheduler gamma. + # At each step size, lr will be multiplied by this parameter. +discriminator_grad_norm: 1 # Discriminator's gradient norm. ########################################################### # INTERVAL SETTING # diff --git a/examples/csmsc/voc1/local/synthesize.sh b/examples/csmsc/voc1/local/synthesize.sh index 9f904ac0..d85d1b1d 100755 --- a/examples/csmsc/voc1/local/synthesize.sh +++ b/examples/csmsc/voc1/local/synthesize.sh @@ -6,8 +6,9 @@ ckpt_name=$3 FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize.py \ +python3 ${BIN_DIR}/../synthesize.py \ --config=${config_path} \ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ --test-metadata=dump/test/norm/metadata.jsonl \ - --output-dir=${train_output_path}/test + --output-dir=${train_output_path}/test \ + --generator-type=pwgan diff --git a/examples/csmsc/voc3/README.md b/examples/csmsc/voc3/README.md index 52ca51e9..99cef233 100644 --- a/examples/csmsc/voc3/README.md +++ b/examples/csmsc/voc3/README.md @@ -80,7 +80,7 @@ optional arguments: 4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. 
### Synthesizing -`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. ```bash CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ``` diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml index 5dda835a..e66d98a6 100644 --- a/examples/csmsc/voc3/conf/default.yaml +++ b/examples/csmsc/voc3/conf/default.yaml @@ -6,7 +6,7 @@ # This configuration is based on full-band MelGAN but the hop size and sampling # rate is different from the paper (16kHz vs 24kHz). The number of iteraions # is not shown in the paper so currently we train 1M iterations (not sure enough -# to converge). +# to converge). ########################################################### # FEATURE EXTRACTION SETTING # ########################################################### @@ -29,7 +29,7 @@ generator_params: out_channels: 4 # Number of output channels. kernel_size: 7 # Kernel size of initial and final conv layers. channels: 384 # Initial number of channels for conv layers. - upsample_scales: [5, 5, 3] # List of Upsampling scales. + upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) == n_shift stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. stacks: 4 # Number of stacks in a single residual stack module. use_weight_norm: True # Whether to use weight normalization. @@ -66,7 +66,7 @@ discriminator_params: use_stft_loss: true stft_loss_params: fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. - hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss + hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss. win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. window: "hann" # Window function for STFT-based loss use_subband_stft_loss: true @@ -86,7 +86,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. # DATA LOADER SETTING # ########################################################### batch_size: 64 # Batch size. -batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size. +batch_max_steps: 16200 # Length of each audio in batch. Make sure divisible by n_shift. num_workers: 2 # Number of workers in DataLoader. ########################################################### @@ -108,7 +108,7 @@ generator_scheduler_params: - 500000 - 600000 discriminator_optimizer_params: - epsilon: 1.0e-7 # Discriminator's epsilon. + epsilon: 1.0e-7 # Discriminator's epsilon. weight_decay: 0.0 # Discriminator's weight decay coefficient. discriminator_grad_norm: -1 # Discriminator's gradient norm. @@ -128,7 +128,7 @@ discriminator_scheduler_params: ########################################################### discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator. train_max_steps: 1000000 # Number of training steps. -save_interval_steps: 5000 # Interval steps to save checkpoint. +save_interval_steps: 5000 # Interval steps to save checkpoint. eval_interval_steps: 1000 # Interval steps to evaluate the network. ########################################################### diff --git a/examples/csmsc/voc3/conf/finetune.yaml b/examples/csmsc/voc3/conf/finetune.yaml index 30227401..8610c526 100644 --- a/examples/csmsc/voc3/conf/finetune.yaml +++ b/examples/csmsc/voc3/conf/finetune.yaml @@ -29,7 +29,7 @@ generator_params: out_channels: 4 # Number of output channels. 
kernel_size: 7 # Kernel size of initial and final conv layers. channels: 384 # Initial number of channels for conv layers. - upsample_scales: [5, 5, 3] # List of Upsampling scales. + upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) == n_shift stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. stacks: 4 # Number of stacks in a single residual stack module. use_weight_norm: True # Whether to use weight normalization. @@ -72,7 +72,7 @@ stft_loss_params: use_subband_stft_loss: true subband_stft_loss_params: fft_sizes: [384, 683, 171] # List of FFT size for STFT-based loss. - hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss + hop_sizes: [30, 60, 10] # List of hop size for STFT-based loss. win_lengths: [150, 300, 60] # List of window length for STFT-based loss. window: "hann" # Window function for STFT-based loss @@ -86,7 +86,7 @@ lambda_adv: 2.5 # Loss balancing coefficient for adversarial loss. # DATA LOADER SETTING # ########################################################### batch_size: 64 # Batch size. -batch_max_steps: 16200 # Length of each audio in batch. Make sure dividable by hop_size. +batch_max_steps: 16200 # Length of each audio in batch. Make sure divisible by n_shift. num_workers: 2 # Number of workers in DataLoader. ########################################################### @@ -108,7 +108,7 @@ generator_scheduler_params: - 500000 - 600000 discriminator_optimizer_params: - epsilon: 1.0e-7 # Discriminator's epsilon. + epsilon: 1.0e-7 # Discriminator's epsilon. weight_decay: 0.0 # Discriminator's weight decay coefficient. discriminator_grad_norm: -1 # Discriminator's gradient norm. @@ -128,7 +128,7 @@ discriminator_scheduler_params: ########################################################### discriminator_train_start_steps: 200000 # Number of steps to start to train discriminator. train_max_steps: 2000000 # Number of training steps. -save_interval_steps: 1000 # Interval steps to save checkpoint. +save_interval_steps: 1000 # Interval steps to save checkpoint. eval_interval_steps: 1000 # Interval steps to evaluate the network. ########################################################### diff --git a/examples/csmsc/voc3/local/synthesize.sh b/examples/csmsc/voc3/local/synthesize.sh index 9f904ac0..22d879fa 100755 --- a/examples/csmsc/voc3/local/synthesize.sh +++ b/examples/csmsc/voc3/local/synthesize.sh @@ -6,8 +6,9 @@ ckpt_name=$3 FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize.py \ +python3 ${BIN_DIR}/../synthesize.py \ --config=${config_path} \ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ --test-metadata=dump/test/norm/metadata.jsonl \ - --output-dir=${train_output_path}/test + --output-dir=${train_output_path}/test \ + --generator-type=mb_melgan diff --git a/examples/csmsc/voc4/README.md b/examples/csmsc/voc4/README.md new file mode 100644 index 00000000..86030e39 --- /dev/null +++ b/examples/csmsc/voc4/README.md @@ -0,0 +1,111 @@ +# Style MelGAN with CSMSC +This example contains code used to train a [Style MelGAN](https://arxiv.org/abs/2011.01557) model with [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html). +## Dataset +### Download and Extract +Download CSMSC from the [official website](https://www.data-baker.com/data/index/source) and extract it to `~/datasets`. Then the dataset is in directory `~/datasets/BZNSYP`. 
+ +### Get MFA Result and Extract +We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence at the edges of audio. +You can download it from here: [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model by referring to the [mfa example](https://github.com/PaddlePaddle/Parakeet/tree/develop/examples/mfa) in our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/BZNSYP`. +Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`. +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize wavs. + - synthesize waveform from `metadata.jsonl`. +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage; for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +./local/preprocess.sh ${conf_path} +``` +When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below. + +```text +dump +├── dev +│ ├── norm +│ └── raw +├── test +│ ├── norm +│ └── raw +└── train + ├── norm + ├── raw + └── feats_stats.npy +``` +The dataset is split into 3 parts, namely `train`, `dev` and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the `norm` folder contains the normalized spectrograms. The statistics used to normalize the spectrograms are computed from the training set and are stored in `dump/train/feats_stats.npy`. + +There is also a `metadata.jsonl` in each subfolder. It is a table-like file that contains the id and the spectrogram path of each utterance. + +### Model Training +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +`./local/train.sh` calls `${BIN_DIR}/train.py`. +Here's the complete help message. + +```text +usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA] + [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR] + [--ngpu NGPU] [--verbose VERBOSE] + +Train a Style MelGAN model. + +optional arguments: + -h, --help show this help message and exit + --config CONFIG config file to overwrite default config. + --train-metadata TRAIN_METADATA + training data. + --dev-metadata DEV_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. + --verbose VERBOSE verbose. +``` + +1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`. +2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder. +3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory. +4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. + +### Synthesizing +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +```text +usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG] + [--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA] + [--output-dir OUTPUT_DIR] [--ngpu NGPU] [--verbose VERBOSE] + +Synthesize with GANVocoder. 
+ +optional arguments: + -h, --help show this help message and exit + --generator-type GENERATOR_TYPE + type of GANVocoder, should be in {pwgan, mb_melgan, + style_melgan, } now + --config CONFIG GANVocoder config file. + --checkpoint CHECKPOINT + snapshot to load. + --test-metadata TEST_METADATA + dev data. + --output-dir OUTPUT_DIR + output dir. + --ngpu NGPU if ngpu == 0, use cpu. + --verbose VERBOSE verbose. +``` + +1. `--config` is the GAN vocoder config file. You should use the same config with which the model was trained. +2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory. +3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory. +4. `--output-dir` is the directory to save the synthesized audio files. +5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. +6. `--generator-type` is the type of GAN vocoder; `./local/synthesize.sh` passes `style_melgan` for this example. diff --git a/examples/csmsc/voc4/conf/default.yaml b/examples/csmsc/voc4/conf/default.yaml new file mode 100644 index 00000000..cad4cf9b --- /dev/null +++ b/examples/csmsc/voc4/conf/default.yaml @@ -0,0 +1,136 @@ +# This is the configuration file for the CSMSC dataset. This configuration is based +# on the StyleMelGAN paper but uses MSE loss instead of Hinge loss. And I found that +# batch_size = 8 also works well, so if you want to accelerate the training, +# you can reduce the batch size (e.g. to 8 or 16). Upsampling scales are modified to +# fit the shift size of 300 pt. +# NOTE: batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300) + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 2048 # FFT size. (in samples) +n_shift: 300 # Hop size. (in samples) +win_length: 1200 # Window length. (in samples) + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 7600 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 128 # Number of input channels. + aux_channels: 80 + channels: 64 # Initial number of channels for conv layers. + out_channels: 1 # Number of output channels. + kernel_size: 9 # Kernel size of initial and final conv layers. + dilation: 2 + bias: True + noise_upsample_scales: [10, 2, 2, 2] + noise_upsample_activation: "leakyrelu" + noise_upsample_activation_params: + negative_slope: 0.2 + upsample_scales: [5, 1, 5, 1, 3, 1, 2, 2, 1] # List of Upsampling scales. prod(upsample_scales) == n_shift + upsample_mode: "nearest" + gated_function: "softmax" + use_weight_norm: True # Whether to use weight normalization. + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + repeats: 4 + window_sizes: [512, 1024, 2048, 4096] + pqmf_params: + - [1, None, None, None] + - [2, 62, 0.26700, 9.0] + - [4, 62, 0.14200, 9.0] + - [8, 62, 0.07949, 9.0] + discriminator_params: + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel size. + channels: 16 # Number of channels of the initial conv layer. + max_downsample_channels: 512 # Maximum number of channels of downsampling layers. 
+ bias: True + downsample_scales: [4, 4, 4, 1] # List of downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation function. + nonlinear_activation_params: # Parameters of nonlinear activation function. + negative_slope: 0.2 + use_weight_norm: True # Whether to use weight norm. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: true +stft_loss_params: + fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. + hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss. + win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + window: "hann" # Window function for STFT-based loss. +lambda_aux: 1.0 # Loss balancing coefficient for aux loss. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_adv: 1.0 # Loss balancing coefficient for adv loss. +generator_adv_loss_params: + average_by_discriminators: false # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: false # Whether to average loss by #discriminators. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 32 # Batch size. +# batch_max_steps(24000) == prod(noise_upsample_scales)(80) * prod(upsample_scales)(300, n_shift) +batch_max_steps: 24000 # Length of each audio in batch. Make sure divisible by n_shift. +num_workers: 2 # Number of workers in Pytorch DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 1.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 100000 + - 300000 + - 500000 + - 700000 + - 900000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator. +train_max_steps: 1500000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. 
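+# NOTE: until discriminator_train_start_steps, only the multi-resolution STFT
+# (aux) loss updates the generator; the adversarial terms are added once the
+# discriminator starts training, as in the other GAN vocoder recipes here.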
+ +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/csmsc/voc4/local/preprocess.sh b/examples/csmsc/voc4/local/preprocess.sh new file mode 100755 index 00000000..61d6d62b --- /dev/null +++ b/examples/csmsc/voc4/local/preprocess.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/../preprocess.py \ + --rootdir=~/datasets/BZNSYP/ \ + --dataset=baker \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --cut-sil=True \ + --num-cpu=20 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --stats=dump/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --stats=dump/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --stats=dump/train/feats_stats.npy +fi diff --git a/examples/csmsc/voc4/local/synthesize.sh b/examples/csmsc/voc4/local/synthesize.sh new file mode 100755 index 00000000..527e5f83 --- /dev/null +++ b/examples/csmsc/voc4/local/synthesize.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/../synthesize.py \ + --config=${config_path} \ + --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ + --test-metadata=dump/test/norm/metadata.jsonl \ + --output-dir=${train_output_path}/test \ + --generator-type=style_melgan diff --git a/examples/csmsc/voc4/local/train.sh b/examples/csmsc/voc4/local/train.sh new file mode 100755 index 00000000..9695631e --- /dev/null +++ b/examples/csmsc/voc4/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +FLAGS_cudnn_exhaustive_search=true \ +FLAGS_conv_workspace_size_limit=4000 \ +python ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=1 diff --git a/examples/csmsc/voc4/path.sh b/examples/csmsc/voc4/path.sh new file mode 100755 index 00000000..f68ea3be --- /dev/null +++ b/examples/csmsc/voc4/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to 
avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=style_melgan +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL} \ No newline at end of file diff --git a/examples/csmsc/voc4/run.sh b/examples/csmsc/voc4/run.sh new file mode 100755 index 00000000..3e7d7e2a --- /dev/null +++ b/examples/csmsc/voc4/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_50000.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this cannot be mixed with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/ljspeech/voc1/README.md b/examples/ljspeech/voc1/README.md index 3830156f..5c556124 100644 --- a/examples/ljspeech/voc1/README.md +++ b/examples/ljspeech/voc1/README.md @@ -95,7 +95,7 @@ benchmark: 4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ### Synthesizing -`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. ```bash CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ``` diff --git a/examples/ljspeech/voc1/conf/default.yaml b/examples/ljspeech/voc1/conf/default.yaml index 2edec3b9..fb97ea8e 100644 --- a/examples/ljspeech/voc1/conf/default.yaml +++ b/examples/ljspeech/voc1/conf/default.yaml @@ -35,7 +35,7 @@ generator_params: dropout: 0.0 # Dropout rate. 0.0 means no dropout applied. use_weight_norm: true # Whether to use weight norm. # If set to true, it will be applied to all of the conv layers. - upsample_scales: [4, 4, 4, 4] # Upsampling scales. Prodcut of these must be the same as hop size. + upsample_scales: [4, 4, 4, 4] # Upsampling scales. prod(upsample_scales) == n_shift ########################################################### # DISCRIMINATOR NETWORK ARCHITECTURE SETTING # @@ -60,7 +60,7 @@ stft_loss_params: fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. - window: "hann" # Window function for STFT-based loss + window: "hann" # Window function for STFT-based loss ########################################################### # ADVERSARIAL LOSS SETTING # @@ -71,7 +71,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. # DATA LOADER SETTING # ########################################################### batch_size: 8 # Batch size. -batch_max_steps: 25600 # Length of each audio in batch. Make sure dividable by hop_size. +batch_max_steps: 25600 # Length of each audio in batch. Make sure divisible by n_shift. 
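+# e.g. with the upsample_scales above, n_shift = prod([4, 4, 4, 4]) = 256,
+# so 25600 samples = 256 * 100, i.e. exactly 100 frames per training clip.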
pin_memory: true # Whether to pin memory in Pytorch DataLoader. num_workers: 4 # Number of workers in Pytorch DataLoader. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. @@ -84,20 +84,20 @@ generator_optimizer_params: epsilon: 1.0e-6 # Generator's epsilon. weight_decay: 0.0 # Generator's weight decay coefficient. generator_scheduler_params: - learning_rate: 0.0001 # Generator's learning rate. + learning_rate: 0.0001 # Generator's learning rate. step_size: 200000 # Generator's scheduler step size. gamma: 0.5 # Generator's scheduler gamma. # At each step size, lr will be multiplied by this parameter. generator_grad_norm: 10 # Generator's gradient norm. discriminator_optimizer_params: - epsilon: 1.0e-6 # Discriminator's epsilon. - weight_decay: 0.0 # Discriminator's weight decay coefficient. + epsilon: 1.0e-6 # Discriminator's epsilon. + weight_decay: 0.0 # Discriminator's weight decay coefficient. discriminator_scheduler_params: - learning_rate: 0.00005 # Discriminator's learning rate. - step_size: 200000 # Discriminator's scheduler step size. - gamma: 0.5 # Discriminator's scheduler gamma. - # At each step size, lr will be multiplied by this parameter. -discriminator_grad_norm: 1 # Discriminator's gradient norm. + learning_rate: 0.00005 # Discriminator's learning rate. + step_size: 200000 # Discriminator's scheduler step size. + gamma: 0.5 # Discriminator's scheduler gamma. + # At each step size, lr will be multiplied by this parameter. +discriminator_grad_norm: 1 # Discriminator's gradient norm. ########################################################### # INTERVAL SETTING # diff --git a/examples/ljspeech/voc1/local/synthesize.sh b/examples/ljspeech/voc1/local/synthesize.sh index 9f904ac0..d85d1b1d 100755 --- a/examples/ljspeech/voc1/local/synthesize.sh +++ b/examples/ljspeech/voc1/local/synthesize.sh @@ -6,8 +6,9 @@ ckpt_name=$3 FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize.py \ +python3 ${BIN_DIR}/../synthesize.py \ --config=${config_path} \ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ --test-metadata=dump/test/norm/metadata.jsonl \ - --output-dir=${train_output_path}/test + --output-dir=${train_output_path}/test \ + --generator-type=pwgan diff --git a/examples/vctk/voc1/README.md b/examples/vctk/voc1/README.md index 6aa311fb..6d7b3256 100644 --- a/examples/vctk/voc1/README.md +++ b/examples/vctk/voc1/README.md @@ -100,7 +100,7 @@ benchmark: 4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu. ### Synthesizing -`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. ```bash CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} ``` diff --git a/examples/vctk/voc1/conf/default.yaml b/examples/vctk/voc1/conf/default.yaml index ba2d9f2e..eb6d350d 100644 --- a/examples/vctk/voc1/conf/default.yaml +++ b/examples/vctk/voc1/conf/default.yaml @@ -35,7 +35,7 @@ generator_params: dropout: 0.0 # Dropout rate. 0.0 means no dropout applied. use_weight_norm: true # Whether to use weight norm. # If set to true, it will be applied to all of the conv layers. - upsample_scales: [4, 5, 3, 5] # Upsampling scales. Prodcut of these must be the same as hop size. + upsample_scales: [4, 5, 3, 5] # Upsampling scales. 
prod(upsample_scales) == n_shift ########################################################### # DISCRIMINATOR NETWORK ARCHITECTURE SETTING # ########################################################### @@ -71,7 +71,7 @@ lambda_adv: 4.0 # Loss balancing coefficient. # DATA LOADER SETTING # ########################################################### batch_size: 8 # Batch size. -batch_max_steps: 24000 # Length of each audio in batch. Make sure dividable by hop_size. +batch_max_steps: 24000 # Length of each audio in batch. Make sure divisible by n_shift. pin_memory: true # Whether to pin memory in Pytorch DataLoader. num_workers: 4 # Number of workers in Pytorch DataLoader. remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps. diff --git a/examples/vctk/voc1/local/synthesize.sh b/examples/vctk/voc1/local/synthesize.sh index 9f904ac0..d85d1b1d 100755 --- a/examples/vctk/voc1/local/synthesize.sh +++ b/examples/vctk/voc1/local/synthesize.sh @@ -6,8 +6,9 @@ ckpt_name=$3 FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ -python3 ${BIN_DIR}/synthesize.py \ +python3 ${BIN_DIR}/../synthesize.py \ --config=${config_path} \ --checkpoint=${train_output_path}/checkpoints/${ckpt_name} \ --test-metadata=dump/test/norm/metadata.jsonl \ - --output-dir=${train_output_path}/test + --output-dir=${train_output_path}/test \ + --generator-type=pwgan diff --git a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize.py b/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize.py deleted file mode 100644 index f275ed44..00000000 --- a/paddlespeech/t2s/exps/gan_vocoder/parallelwave_gan/synthesize.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import argparse -import os -from pathlib import Path - -import jsonlines -import numpy as np -import paddle -import soundfile as sf -import yaml -from paddle import distributed as dist -from timer import timer -from yacs.config import CfgNode - -from paddlespeech.t2s.datasets.data_table import DataTable -from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator - - -def main(): - parser = argparse.ArgumentParser( - description="Synthesize with parallel wavegan.") - parser.add_argument( - "--config", type=str, help="parallel wavegan config file.") - parser.add_argument("--checkpoint", type=str, help="snapshot to load.") - parser.add_argument("--test-metadata", type=str, help="dev data.") - parser.add_argument("--output-dir", type=str, help="output dir.") - parser.add_argument( - "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") - parser.add_argument("--verbose", type=int, default=1, help="verbose.") - - args = parser.parse_args() - - with open(args.config) as f: - config = CfgNode(yaml.safe_load(f)) - - print("========Args========") - print(yaml.safe_dump(vars(args))) - print("========Config========") - print(config) - print( - f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" - ) - - if args.ngpu == 0: - paddle.set_device("cpu") - elif args.ngpu > 0: - paddle.set_device("gpu") - else: - print("ngpu should >= 0 !") - generator = PWGGenerator(**config["generator_params"]) - state_dict = paddle.load(args.checkpoint) - generator.set_state_dict(state_dict["generator_params"]) - - generator.remove_weight_norm() - generator.eval() - with jsonlines.open(args.test_metadata, 'r') as reader: - metadata = list(reader) - - test_dataset = DataTable( - metadata, - fields=['utt_id', 'feats'], - converters={ - 'utt_id': None, - 'feats': np.load, - }) - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - N = 0 - T = 0 - for example in test_dataset: - utt_id = example['utt_id'] - mel = example['feats'] - mel = paddle.to_tensor(mel) # (T, C) - with timer() as t: - with paddle.no_grad(): - wav = generator.inference(c=mel) - wav = wav.numpy() - N += wav.size - T += t.elapse - speed = wav.size / t.elapse - rtf = config.fs / speed - print( - f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." - ) - sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) - print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") - - -if __name__ == "__main__": - main() diff --git a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/__init__.py b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
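The model-specific `synthesize.py` removed above is superseded by a single shared entry point, `paddlespeech/t2s/exps/gan_vocoder/synthesize.py` (see the rename later in this diff). A condensed sketch of the dispatch that script performs, with names taken from that file (not the full script):

```python
# Sketch of the generator dispatch in the new shared synthesize.py: the
# generator class is resolved by name, so one entry point serves every
# GAN vocoder. The map below is copied verbatim from this diff.
import paddlespeech.t2s.models

class_map = {
    "hifigan": "HiFiGANGenerator",
    "mb_melgan": "MelGANGenerator",
    "pwgan": "PWGGenerator",
    "style_melgan": "StyleMelGANGenerator",
}

def build_generator(generator_type: str, generator_params: dict):
    # Fail early on unknown vocoder names, then look the class up
    # on the paddlespeech.t2s.models package.
    assert generator_type in class_map
    generator_class = getattr(paddlespeech.t2s.models,
                              class_map[generator_type])
    return generator_class(**generator_params)
```

Each recipe's `local/synthesize.sh` then only has to pass the matching `--generator-type` flag (`pwgan`, `mb_melgan`, or `style_melgan` in this diff).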
diff --git a/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py new file mode 100644 index 00000000..bc746467 --- /dev/null +++ b/paddlespeech/t2s/exps/gan_vocoder/style_melgan/train.py @@ -0,0 +1,258 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle import nn +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam +from paddle.optimizer.lr import MultiStepDecay +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.datasets.vocoder_batch_fn import Clip +from paddlespeech.t2s.models.melgan import StyleMelGANDiscriminator +from paddlespeech.t2s.models.melgan import StyleMelGANEvaluator +from paddlespeech.t2s.models.melgan import StyleMelGANGenerator +from paddlespeech.t2s.models.melgan import StyleMelGANUpdater +from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss +from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss +from paddlespeech.t2s.modules.losses import MultiResolutionSTFTLoss +from paddlespeech.t2s.training.extensions.snapshot import Snapshot +from paddlespeech.t2s.training.extensions.visualizer import VisualDL +from paddlespeech.t2s.training.seeding import seed_everything +from paddlespeech.t2s.training.trainer import Trainer + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=["wave", "feats"], + converters={ + "wave": np.load, + "feats": np.load, + }, ) + + # collate function and dataloader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + dev_sampler = DistributedBatchSampler( + 
dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + if "aux_context_window" in config.generator_params: + aux_context_window = config.generator_params.aux_context_window + else: + aux_context_window = 0 + train_batch_fn = Clip( + batch_max_steps=config.batch_max_steps, + hop_size=config.n_shift, + aux_context_window=aux_context_window) + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + generator = StyleMelGANGenerator(**config["generator_params"]) + discriminator = StyleMelGANDiscriminator(**config["discriminator_params"]) + if world_size > 1: + generator = DataParallel(generator) + discriminator = DataParallel(discriminator) + print("models done!") + criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"]) + + criterion_gen_adv = GeneratorAdversarialLoss( + **config["generator_adv_loss_params"]) + criterion_dis_adv = DiscriminatorAdversarialLoss( + **config["discriminator_adv_loss_params"]) + print("criterions done!") + + lr_schedule_g = MultiStepDecay(**config["generator_scheduler_params"]) + # Compared to multi_band_melgan.v1 config, Adam optimizer without gradient norm is used + generator_grad_norm = config["generator_grad_norm"] + gradient_clip_g = nn.ClipGradByGlobalNorm( + generator_grad_norm) if generator_grad_norm > 0 else None + print("gradient_clip_g:", gradient_clip_g) + + optimizer_g = Adam( + learning_rate=lr_schedule_g, + grad_clip=gradient_clip_g, + parameters=generator.parameters(), + **config["generator_optimizer_params"]) + lr_schedule_d = MultiStepDecay(**config["discriminator_scheduler_params"]) + discriminator_grad_norm = config["discriminator_grad_norm"] + gradient_clip_d = nn.ClipGradByGlobalNorm( + discriminator_grad_norm) if discriminator_grad_norm > 0 else None + print("gradient_clip_d:", gradient_clip_d) + optimizer_d = Adam( + learning_rate=lr_schedule_d, + grad_clip=gradient_clip_d, + parameters=discriminator.parameters(), + **config["discriminator_optimizer_params"]) + print("optimizers done!") + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + if dist.get_rank() == 0: + config_name = args.config.split("/")[-1] + # copy conf to output_dir + shutil.copyfile(args.config, output_dir / config_name) + + updater = StyleMelGANUpdater( + models={ + "generator": generator, + "discriminator": discriminator, + }, + optimizers={ + "generator": optimizer_g, + "discriminator": optimizer_d, + }, + criterions={ + "stft": criterion_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + }, + schedulers={ + "generator": lr_schedule_g, + "discriminator": lr_schedule_d, + }, + dataloader=train_dataloader, + discriminator_train_start_steps=config.discriminator_train_start_steps, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + evaluator = StyleMelGANEvaluator( + models={ + "generator": generator, + "discriminator": discriminator, + }, + criterions={ + "stft": criterion_stft, + "gen_adv": criterion_gen_adv, + "dis_adv": criterion_dis_adv, + }, + dataloader=dev_dataloader, + lambda_adv=config.lambda_adv, + output_dir=output_dir) + + trainer = Trainer( + updater, + stop_trigger=(config.train_max_steps, "iteration"), + out=output_dir) + + if dist.get_rank() == 0: + trainer.extend( + 
evaluator, trigger=(config.eval_interval_steps, 'iteration')) + trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) + trainer.extend( + Snapshot(max_size=config.num_snapshots), + trigger=(config.save_interval_steps, 'iteration')) + + print("Trainer Done!") + trainer.run() + + +def main(): + # parse args and config and redirect to train_sp + + parser = argparse.ArgumentParser( + description="Train a Style MelGAN model.") + parser.add_argument( + "--config", type=str, help="config file to overwrite default config.") + parser.add_argument("--train-metadata", type=str, help="training data.") + parser.add_argument("--dev-metadata", type=str, help="dev data.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument("--verbose", type=int, default=1, help="verbose.") + + args = parser.parse_args() + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + print( + f"master see the world size: {dist.get_world_size()}, from pid: {os.getpid()}" + ) + + # dispatch + if args.ngpu > 1: + dist.spawn(train_sp, (args, config), nprocs=args.ngpu) + else: + train_sp(args, config) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/synthesize.py b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py similarity index 79% rename from paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/synthesize.py rename to paddlespeech/t2s/exps/gan_vocoder/synthesize.py index 988d4590..d7fd2f94 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/multi_band_melgan/synthesize.py +++ b/paddlespeech/t2s/exps/gan_vocoder/synthesize.py @@ -24,15 +24,19 @@ from paddle import distributed as dist from timer import timer from yacs.config import CfgNode +import paddlespeech from paddlespeech.t2s.datasets.data_table import DataTable -from paddlespeech.t2s.models.melgan import MelGANGenerator def main(): - parser = argparse.ArgumentParser( - description="Synthesize with multi band melgan.") + parser = argparse.ArgumentParser(description="Synthesize with GANVocoder.") parser.add_argument( - "--config", type=str, help="multi band melgan config file.") + "--generator-type", + type=str, + default="pwgan", + help="type of GANVocoder, should be in {pwgan, mb_melgan, style_melgan, } now" + ) + parser.add_argument("--config", type=str, help="GANVocoder config file.") parser.add_argument("--checkpoint", type=str, help="snapshot to load.") parser.add_argument("--test-metadata", type=str, help="dev data.") parser.add_argument("--output-dir", type=str, help="output dir.") @@ -59,15 +63,29 @@ def main(): paddle.set_device("gpu") else: print("ngpu should >= 0 !") - generator = MelGANGenerator(**config["generator_params"]) + + class_map = { + "hifigan": "HiFiGANGenerator", + "mb_melgan": "MelGANGenerator", + "pwgan": "PWGGenerator", + "style_melgan": "StyleMelGANGenerator", + } + + generator_type = args.generator_type + + assert generator_type in class_map + + print("generator_type:", generator_type) + + generator_class = getattr(paddlespeech.t2s.models, + class_map[generator_type]) + generator = generator_class(**config["generator_params"]) state_dict = paddle.load(args.checkpoint) generator.set_state_dict(state_dict["generator_params"]) - generator.remove_weight_norm() generator.eval() with jsonlines.open(args.test_metadata, 'r') as reader: 
metadata = list(reader) - test_dataset = DataTable( metadata, fields=['utt_id', 'feats'], diff --git a/paddlespeech/t2s/models/__init__.py b/paddlespeech/t2s/models/__init__.py index 66720649..601bd9d6 100644 --- a/paddlespeech/t2s/models/__init__.py +++ b/paddlespeech/t2s/models/__init__.py @@ -14,6 +14,7 @@ from .fastspeech2 import * from .melgan import * from .parallel_wavegan import * +from .speedyspeech import * from .tacotron2 import * from .transformer_tts import * from .waveflow import * diff --git a/paddlespeech/t2s/models/melgan/__init__.py b/paddlespeech/t2s/models/melgan/__init__.py index d4f557db..df8ccd92 100644 --- a/paddlespeech/t2s/models/melgan/__init__.py +++ b/paddlespeech/t2s/models/melgan/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. from .melgan import * from .multi_band_melgan_updater import * +from .style_melgan import * +from .style_melgan_updater import * diff --git a/paddlespeech/t2s/models/melgan/melgan.py b/paddlespeech/t2s/models/melgan/melgan.py index 809403f6..32fcf658 100644 --- a/paddlespeech/t2s/models/melgan/melgan.py +++ b/paddlespeech/t2s/models/melgan/melgan.py @@ -21,6 +21,7 @@ import numpy as np import paddle from paddle import nn +from paddlespeech.t2s.modules.activation import get_activation from paddlespeech.t2s.modules.causal_conv import CausalConv1D from paddlespeech.t2s.modules.causal_conv import CausalConv1DTranspose from paddlespeech.t2s.modules.nets_utils import initialize @@ -41,7 +42,7 @@ class MelGANGenerator(nn.Layer): upsample_scales: List[int]=[8, 8, 2, 2], stack_kernel_size: int=3, stacks: int=3, - nonlinear_activation: str="LeakyReLU", + nonlinear_activation: str="leakyrelu", nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, pad: str="Pad1D", pad_params: Dict[str, Any]={"mode": "reflect"}, @@ -88,16 +89,18 @@ class MelGANGenerator(nn.Layer): """ super().__init__() + # initialize parameters + initialize(self, init_type) + + # for compatibility + nonlinear_activation = nonlinear_activation.lower() + # check hyper parameters is valid assert channels >= np.prod(upsample_scales) assert channels % (2**len(upsample_scales)) == 0 if not use_causal_conv: assert (kernel_size - 1 ) % 2 == 0, "Not support even number kernel size." - - # initialize parameters - initialize(self, init_type) - layers = [] if not use_causal_conv: layers += [ @@ -118,7 +121,8 @@ class MelGANGenerator(nn.Layer): for i, upsample_scale in enumerate(upsample_scales): # add upsampling layer layers += [ - getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + get_activation(nonlinear_activation, + **nonlinear_activation_params) ] if not use_causal_conv: layers += [ @@ -158,7 +162,7 @@ class MelGANGenerator(nn.Layer): # add final layer layers += [ - getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + get_activation(nonlinear_activation, **nonlinear_activation_params) ] if not use_causal_conv: layers += [ @@ -242,7 +246,6 @@ class MelGANGenerator(nn.Layer): This initialization follows official implementation manner. 
https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py """ - # 定义参数为float的正态分布。 dist = paddle.distribution.Normal(loc=0.0, scale=0.02) @@ -287,10 +290,11 @@ class MelGANDiscriminator(nn.Layer): max_downsample_channels: int=1024, bias: bool=True, downsample_scales: List[int]=[4, 4, 4, 4], - nonlinear_activation: str="LeakyReLU", + nonlinear_activation: str="leakyrelu", nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, pad: str="Pad1D", - pad_params: Dict[str, Any]={"mode": "reflect"}, ): + pad_params: Dict[str, Any]={"mode": "reflect"}, + init_type: str="xavier_uniform", ): """Initilize MelGAN discriminator module. Parameters ---------- @@ -321,6 +325,13 @@ class MelGANDiscriminator(nn.Layer): Hyperparameters for padding function. """ super().__init__() + + # for compatibility + nonlinear_activation = nonlinear_activation.lower() + + # initialize parameters + initialize(self, init_type) + self.layers = nn.LayerList() # check kernel size is valid @@ -338,8 +349,8 @@ class MelGANDiscriminator(nn.Layer): channels, int(np.prod(kernel_sizes)), bias_attr=bias), - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), )) + get_activation(nonlinear_activation, ** + nonlinear_activation_params), )) # add downsample layers in_chs = channels @@ -355,8 +366,8 @@ class MelGANDiscriminator(nn.Layer): padding=downsample_scale * 5, groups=in_chs // 4, bias_attr=bias, ), - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), )) + get_activation(nonlinear_activation, ** + nonlinear_activation_params), )) in_chs = out_chs # add final layers @@ -369,8 +380,8 @@ class MelGANDiscriminator(nn.Layer): kernel_sizes[0], padding=(kernel_sizes[0] - 1) // 2, bias_attr=bias, ), - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), )) + get_activation(nonlinear_activation, ** + nonlinear_activation_params), )) self.layers.append( nn.Conv1D( out_chs, @@ -419,7 +430,7 @@ class MelGANMultiScaleDiscriminator(nn.Layer): max_downsample_channels: int=1024, bias: bool=True, downsample_scales: List[int]=[4, 4, 4, 4], - nonlinear_activation: str="LeakyReLU", + nonlinear_activation: str="leakyrelu", nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, pad: str="Pad1D", pad_params: Dict[str, Any]={"mode": "reflect"}, @@ -461,9 +472,13 @@ class MelGANMultiScaleDiscriminator(nn.Layer): Whether to use causal convolution. """ super().__init__() + # initialize parameters initialize(self, init_type) + # for compatibility + nonlinear_activation = nonlinear_activation.lower() + self.discriminators = nn.LayerList() # add discriminators diff --git a/paddlespeech/t2s/models/melgan/multi_band_melgan_updater.py b/paddlespeech/t2s/models/melgan/multi_band_melgan_updater.py index a5d4cdeb..75e99627 100644 --- a/paddlespeech/t2s/models/melgan/multi_band_melgan_updater.py +++ b/paddlespeech/t2s/models/melgan/multi_band_melgan_updater.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 import logging
+from pathlib import Path
 from typing import Dict
 
 import paddle
@@ -41,7 +42,7 @@ class MBMelGANUpdater(StandardUpdater):
                  dataloader: DataLoader,
                  discriminator_train_start_steps: int,
                  lambda_adv: float,
-                 output_dir=None):
+                 output_dir: Path=None):
         self.models = models
         self.generator: Layer = models['generator']
         self.discriminator: Layer = models['discriminator']
@@ -159,11 +160,11 @@ class MBMelGANEvaluator(StandardEvaluator):
     def __init__(self,
-                 models,
-                 criterions,
-                 dataloader,
-                 lambda_adv,
-                 output_dir=None):
+                 models: Dict[str, Layer],
+                 criterions: Dict[str, Layer],
+                 dataloader: DataLoader,
+                 lambda_adv: float,
+                 output_dir: Path=None):
         self.models = models
         self.generator = models['generator']
         self.discriminator = models['discriminator']
diff --git a/paddlespeech/t2s/models/melgan/style_melgan.py b/paddlespeech/t2s/models/melgan/style_melgan.py
new file mode 100644
index 00000000..4725a8d0
--- /dev/null
+++ b/paddlespeech/t2s/models/melgan/style_melgan.py
@@ -0,0 +1,404 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Modified from espnet(https://github.com/espnet/espnet)
+"""StyleMelGAN Modules."""
+import copy
+import math
+from typing import Any
+from typing import Dict
+from typing import List
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+from paddle import nn
+
+from paddlespeech.t2s.models.melgan import MelGANDiscriminator as BaseDiscriminator
+from paddlespeech.t2s.modules.activation import get_activation
+from paddlespeech.t2s.modules.nets_utils import initialize
+from paddlespeech.t2s.modules.pqmf import PQMF
+from paddlespeech.t2s.modules.tade_res_block import TADEResBlock
+
+
+class StyleMelGANGenerator(nn.Layer):
+    """Style MelGAN generator module."""
+
+    def __init__(
+            self,
+            in_channels: int=128,
+            aux_channels: int=80,
+            channels: int=64,
+            out_channels: int=1,
+            kernel_size: int=9,
+            dilation: int=2,
+            bias: bool=True,
+            noise_upsample_scales: List[int]=[11, 2, 2, 2],
+            noise_upsample_activation: str="leakyrelu",
+            noise_upsample_activation_params: Dict[str,
+                                                   Any]={"negative_slope": 0.2},
+            upsample_scales: List[int]=[2, 2, 2, 2, 2, 2, 2, 2, 1],
+            upsample_mode: str="linear",
+            gated_function: str="softmax",
+            use_weight_norm: bool=True,
+            init_type: str="xavier_uniform", ):
+        """Initialize Style MelGAN generator.
+        Parameters
+        ----------
+        in_channels : int
+            Number of input noise channels.
+        aux_channels : int
+            Number of auxiliary input channels.
+        channels : int
+            Number of channels for conv layer.
+        out_channels : int
+            Number of output channels.
+        kernel_size : int
+            Kernel size of conv layers.
+        dilation : int
+            Dilation factor for conv layers.
+        bias : bool
+            Whether to add bias parameter in convolution layers.
+        noise_upsample_scales : list
+            List of noise upsampling scales.
+        noise_upsample_activation : str
+            Activation function module name for noise upsampling.
+        noise_upsample_activation_params : dict
+            Hyperparameters for the above activation function.
+        upsample_scales : list
+            List of upsampling scales.
+        upsample_mode : str
+            Upsampling mode in TADE layer.
+        gated_function : str
+            Gated function in TADEResBlock ("softmax" or "sigmoid").
+        use_weight_norm : bool
+            Whether to use weight norm.
+            If set to true, it will be applied to all of the conv layers.
+        """
+        super().__init__()
+
+        # initialize parameters
+        initialize(self, init_type)
+
+        self.in_channels = in_channels
+        noise_upsample = []
+        in_chs = in_channels
+        for noise_upsample_scale in noise_upsample_scales:
+            noise_upsample.append(
+                nn.Conv1DTranspose(
+                    in_chs,
+                    channels,
+                    noise_upsample_scale * 2,
+                    stride=noise_upsample_scale,
+                    padding=noise_upsample_scale // 2 + noise_upsample_scale %
+                    2,
+                    output_padding=noise_upsample_scale % 2,
+                    bias_attr=bias, ))
+            noise_upsample.append(
+                get_activation(noise_upsample_activation, **
+                               noise_upsample_activation_params))
+            in_chs = channels
+        self.noise_upsample = nn.Sequential(*noise_upsample)
+        self.noise_upsample_factor = np.prod(noise_upsample_scales)
+
+        self.blocks = nn.LayerList()
+        aux_chs = aux_channels
+        for upsample_scale in upsample_scales:
+            self.blocks.append(
+                TADEResBlock(
+                    in_channels=channels,
+                    aux_channels=aux_chs,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                    bias=bias,
+                    upsample_factor=upsample_scale,
+                    upsample_mode=upsample_mode,
+                    gated_function=gated_function, ), )
+            aux_chs = channels
+        self.upsample_factor = np.prod(upsample_scales)
+
+        self.output_conv = nn.Sequential(
+            nn.Conv1D(
+                channels,
+                out_channels,
+                kernel_size,
+                1,
+                bias_attr=bias,
+                padding=(kernel_size - 1) // 2, ),
+            nn.Tanh(), )
+
+        nn.initializer.set_global_initializer(None)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+        # reset parameters
+        self.reset_parameters()
+
+    def forward(self, c, z=None):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        c : Tensor
+            Auxiliary input tensor (B, aux_channels, T).
+        z : Tensor
+            Input noise tensor (B, in_channels, 1).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, out_channels, T * prod(upsample_scales)).
+        """
+        # batch_max_steps(24000) == noise_upsample_factor(80) * upsample_factor(300)
+        if z is None:
+            z = paddle.randn([paddle.shape(c)[0], self.in_channels, 1])
+        # (B, in_channels, noise_upsample_factor).
+        x = self.noise_upsample(z)
+        for block in self.blocks:
+            x, c = block(x, c)
+        x = self.output_conv(x)
+        return x
+
+    def apply_weight_norm(self):
+        """Recursively apply weight normalization to all the Convolution layers
+        in the sublayers.
+        """
+
+        def _apply_weight_norm(layer):
+            if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)):
+                nn.utils.weight_norm(layer)
+
+        self.apply(_apply_weight_norm)
+
+    def remove_weight_norm(self):
+        """Recursively remove weight normalization from all the Convolution
+        layers in the sublayers.
+        """
+
+        def _remove_weight_norm(layer):
+            try:
+                if layer:
+                    nn.utils.remove_weight_norm(layer)
+            except ValueError:
+                pass
+
+        self.apply(_remove_weight_norm)
+
+    def reset_parameters(self):
+        """Reset parameters.
+        This initialization follows official implementation manner.
+        https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
+        """
+        # Define a normal distribution with float parameters.
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
+
+        def _reset_parameters(m):
+            if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
+                w = dist.sample(m.weight.shape)
+                m.weight.set_value(w)
+
+        self.apply(_reset_parameters)
+
+    def inference(self, c):
+        """Perform inference.
+        Parameters
+        ----------
+        c : Tensor
+            Input tensor (T, aux_channels).
+        Returns
+        ----------
+        Tensor
+            Output tensor (T * prod(upsample_scales), out_channels).
+        """
+        # (1, in_channels, T)
+        c = c.transpose([1, 0]).unsqueeze(0)
+        c_shape = paddle.shape(c)
+        # prepare noise input
+        # there is a bug in Paddle int division, we must convert an int tensor to int here
+        noise_size = (1, self.in_channels,
+                      math.ceil(int(c_shape[2]) / self.noise_upsample_factor))
+        # (1, in_channels, T/noise_upsample_factor)
+        noise = paddle.randn(noise_size)
+        # (1, in_channels, T)
+        x = self.noise_upsample(noise)
+        x_shape = paddle.shape(x)
+        total_length = c_shape[2] * self.upsample_factor
+        c = F.pad(
+            c, (0, x_shape[2] - c_shape[2]), "replicate", data_format="NCL")
+        # c.shape[2] == x.shape[2] here
+        # (1, in_channels, T*prod(upsample_scales))
+        for block in self.blocks:
+            x, c = block(x, c)
+        x = self.output_conv(x)[..., :total_length]
+        return x.squeeze(0).transpose([1, 0])
+
+
+# Does StyleMelGANDiscriminator not need remove weight norm?
+class StyleMelGANDiscriminator(nn.Layer):
+    """Style MelGAN discriminator module."""
+
+    def __init__(
+            self,
+            repeats: int=2,
+            window_sizes: List[int]=[512, 1024, 2048, 4096],
+            pqmf_params: List[List[int]]=[
+                [1, None, None, None],
+                [2, 62, 0.26700, 9.0],
+                [4, 62, 0.14200, 9.0],
+                [8, 62, 0.07949, 9.0],
+            ],
+            discriminator_params: Dict[str, Any]={
+                "out_channels": 1,
+                "kernel_sizes": [5, 3],
+                "channels": 16,
+                "max_downsample_channels": 512,
+                "bias": True,
+                "downsample_scales": [4, 4, 4, 1],
+                "nonlinear_activation": "leakyrelu",
+                "nonlinear_activation_params": {
+                    "negative_slope": 0.2
+                },
+                "pad": "Pad1D",
+                "pad_params": {
+                    "mode": "reflect"
+                },
+            },
+            use_weight_norm: bool=True,
+            init_type: str="xavier_uniform", ):
+        """Initialize Style MelGAN discriminator.
+        Parameters
+        ----------
+        repeats : int
+            Number of repetitions to apply RWD.
+        window_sizes : list
+            List of random window sizes.
+        pqmf_params : list
+            List of lists of parameters for PQMF modules.
+        discriminator_params : dict
+            Parameters for base discriminator module.
+        use_weight_norm : bool
+            Whether to apply weight normalization.
+        """
+        super().__init__()
+
+        # initialize parameters
+        initialize(self, init_type)
+
+        # window size check
+        assert len(window_sizes) == len(pqmf_params)
+        sizes = [ws // p[0] for ws, p in zip(window_sizes, pqmf_params)]
+        assert len(window_sizes) == sum([sizes[0] == size for size in sizes])
+
+        self.repeats = repeats
+        self.window_sizes = window_sizes
+        self.pqmfs = nn.LayerList()
+        self.discriminators = nn.LayerList()
+        for pqmf_param in pqmf_params:
+            d_params = copy.deepcopy(discriminator_params)
+            d_params["in_channels"] = pqmf_param[0]
+            if pqmf_param[0] == 1:
+                self.pqmfs.append(nn.Identity())
+            else:
+                self.pqmfs.append(PQMF(*pqmf_param))
+            self.discriminators.append(BaseDiscriminator(**d_params))
+
+        nn.initializer.set_global_initializer(None)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+        # reset parameters
+        self.reset_parameters()
+
+    def forward(self, x):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, 1, T).
+        Returns
+        ----------
+        List
+            List of discriminator outputs, #items in the list will be
+            equal to repeats * #discriminators.
+        """
+        outs = []
+        for _ in range(self.repeats):
+            outs += self._forward(x)
+        return outs
+
+    def _forward(self, x):
+        outs = []
+        for idx, (ws, pqmf, disc) in enumerate(
+                zip(self.window_sizes, self.pqmfs, self.discriminators)):
+            start_idx = int(np.random.randint(paddle.shape(x)[-1] - ws))
+            x_ = x[:, :, start_idx:start_idx + ws]
+            if idx == 0:
+                # nn.Identity()
+                x_ = pqmf(x_)
+            else:
+                x_ = pqmf.analysis(x_)
+            outs += [disc(x_)]
+        return outs
+
+    def apply_weight_norm(self):
+        """Recursively apply weight normalization to all the Convolution layers
+        in the sublayers.
+        """
+
+        def _apply_weight_norm(layer):
+            if isinstance(layer, (nn.Conv1D, nn.Conv1DTranspose)):
+                nn.utils.weight_norm(layer)
+
+        self.apply(_apply_weight_norm)
+
+    def remove_weight_norm(self):
+        """Recursively remove weight normalization from all the Convolution
+        layers in the sublayers.
+        """
+
+        def _remove_weight_norm(layer):
+            try:
+                nn.utils.remove_weight_norm(layer)
+            except ValueError:
+                pass
+
+        self.apply(_remove_weight_norm)
+
+    def reset_parameters(self):
+        """Reset parameters.
+        This initialization follows official implementation manner.
+        https://github.com/descriptinc/melgan-neurips/blob/master/mel2wav/modules.py
+        """
+        # Define a normal distribution with float parameters.
+        dist = paddle.distribution.Normal(loc=0.0, scale=0.02)
+
+        def _reset_parameters(m):
+            if isinstance(m, nn.Conv1D) or isinstance(m, nn.Conv1DTranspose):
+                w = dist.sample(m.weight.shape)
+                m.weight.set_value(w)
+
+        self.apply(_reset_parameters)
+
+
+class StyleMelGANInference(nn.Layer):
+    def __init__(self, normalizer, style_melgan_generator):
+        super().__init__()
+        self.normalizer = normalizer
+        self.style_melgan_generator = style_melgan_generator
+
+    def forward(self, logmel):
+        normalized_mel = self.normalizer(logmel)
+        wav = self.style_melgan_generator.inference(normalized_mel)
+        return wav
diff --git a/paddlespeech/t2s/models/melgan/style_melgan_updater.py b/paddlespeech/t2s/models/melgan/style_melgan_updater.py
new file mode 100644
index 00000000..b0cb4ed6
--- /dev/null
+++ b/paddlespeech/t2s/models/melgan/style_melgan_updater.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
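Aside on `StyleMelGANDiscriminator` above: the window-size assert guarantees that every random window, after PQMF analysis, yields the same effective length, so one set of `discriminator_params` fits all scales. A sanity sketch with the defaults from this file:

```python
window_sizes = [512, 1024, 2048, 4096]
subbands = [1, 2, 4, 8]  # first entry of each pqmf_params row
sizes = [ws // sb for ws, sb in zip(window_sizes, subbands)]
assert len(set(sizes)) == 1  # each discriminator sees 512 sub-band samples
```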
+import logging
+from pathlib import Path
+from typing import Dict
+
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class StyleMelGANUpdater(StandardUpdater):
+    def __init__(self,
+                 models: Dict[str, Layer],
+                 optimizers: Dict[str, Optimizer],
+                 criterions: Dict[str, Layer],
+                 schedulers: Dict[str, LRScheduler],
+                 dataloader: DataLoader,
+                 generator_train_start_steps: int=0,
+                 discriminator_train_start_steps: int=100000,
+                 lambda_adv: float=1.0,
+                 lambda_aux: float=1.0,
+                 output_dir: Path=None):
+        self.models = models
+        self.generator: Layer = models['generator']
+        self.discriminator: Layer = models['discriminator']
+
+        self.optimizers = optimizers
+        self.optimizer_g: Optimizer = optimizers['generator']
+        self.optimizer_d: Optimizer = optimizers['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_gen_adv = criterions["gen_adv"]
+        self.criterion_dis_adv = criterions["dis_adv"]
+
+        self.schedulers = schedulers
+        self.scheduler_g = schedulers['generator']
+        self.scheduler_d = schedulers['discriminator']
+
+        self.dataloader = dataloader
+
+        self.generator_train_start_steps = generator_train_start_steps
+        self.discriminator_train_start_steps = discriminator_train_start_steps
+        self.lambda_adv = lambda_adv
+        self.lambda_aux = lambda_aux
+
+        self.state = UpdaterState(iteration=0, epoch=0)
+        self.train_iterator = iter(self.dataloader)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+        # parse batch
+        wav, mel = batch
+
+        # Generator
+        if self.state.iteration > self.generator_train_start_steps:
+            # (B, out_channels, T * prod(upsample_scales))
+            wav_ = self.generator(mel)
+
+            # initialize
+            gen_loss = 0.0
+            aux_loss = 0.0
+
+            # full band multi-resolution stft loss
+            sc_loss, mag_loss = self.criterion_stft(wav_, wav)
+            aux_loss += sc_loss + mag_loss
+            report("train/spectral_convergence_loss", float(sc_loss))
+            report("train/log_stft_magnitude_loss", float(mag_loss))
+            losses_dict["spectral_convergence_loss"] = float(sc_loss)
+            losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
+
+            gen_loss += aux_loss * self.lambda_aux
+
+            # adversarial loss
+            if self.state.iteration > self.discriminator_train_start_steps:
+                p_ = self.discriminator(wav_)
+                adv_loss = self.criterion_gen_adv(p_)
+                report("train/adversarial_loss", float(adv_loss))
+                losses_dict["adversarial_loss"] = float(adv_loss)
+
+                gen_loss += self.lambda_adv * adv_loss
+
+            report("train/generator_loss", float(gen_loss))
+            losses_dict["generator_loss"] = float(gen_loss)
+
+            self.optimizer_g.clear_grad()
+            gen_loss.backward()
+
+            self.optimizer_g.step()
+            self.scheduler_g.step()
+
+        # Discriminator
+        if self.state.iteration > self.discriminator_train_start_steps:
+            # re-compute wav_, which leads to better quality
+            with paddle.no_grad():
+                wav_ = self.generator(mel)
+
+            p = self.discriminator(wav)
+            p_ = self.discriminator(wav_.detach())
+            real_loss, fake_loss = self.criterion_dis_adv(p_, p)
+            dis_loss = real_loss + fake_loss
+            report("train/real_loss", float(real_loss))
+            report("train/fake_loss", float(fake_loss))
+            report("train/discriminator_loss", float(dis_loss))
+            losses_dict["real_loss"] = float(real_loss)
+            losses_dict["fake_loss"] = float(fake_loss)
+            losses_dict["discriminator_loss"] = float(dis_loss)
+
+            self.optimizer_d.clear_grad()
+            dis_loss.backward()
+
+            self.optimizer_d.step()
+            self.scheduler_d.step()
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class StyleMelGANEvaluator(StandardEvaluator):
+    def __init__(self,
+                 models: Dict[str, Layer],
+                 criterions: Dict[str, Layer],
+                 dataloader: DataLoader,
+                 lambda_adv: float=1.0,
+                 lambda_aux: float=1.0,
+                 output_dir: Path=None):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_gen_adv = criterions["gen_adv"]
+        self.criterion_dis_adv = criterions["dis_adv"]
+
+        self.dataloader = dataloader
+
+        self.lambda_adv = lambda_adv
+        self.lambda_aux = lambda_aux
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+        wav, mel = batch
+
+        # Generator
+        # (B, out_channels, T * prod(upsample_scales))
+        wav_ = self.generator(mel)
+
+        # initialize
+        gen_loss = 0.0
+        aux_loss = 0.0
+
+        # adversarial loss
+        p_ = self.discriminator(wav_)
+        adv_loss = self.criterion_gen_adv(p_)
+        report("eval/adversarial_loss", float(adv_loss))
+        losses_dict["adversarial_loss"] = float(adv_loss)
+
+        gen_loss += self.lambda_adv * adv_loss
+
+        # multi-resolution stft loss
+        sc_loss, mag_loss = self.criterion_stft(wav_, wav)
+        aux_loss += sc_loss + mag_loss
+        report("eval/spectral_convergence_loss", float(sc_loss))
+        report("eval/log_stft_magnitude_loss", float(mag_loss))
+        losses_dict["spectral_convergence_loss"] = float(sc_loss)
+        losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
+
+        gen_loss += aux_loss * self.lambda_aux
+
+        report("eval/generator_loss", float(gen_loss))
+        losses_dict["generator_loss"] = float(gen_loss)
+
+        # Discriminator
+        p = self.discriminator(wav)
+        real_loss, fake_loss = self.criterion_dis_adv(p_, p)
+        dis_loss = real_loss + fake_loss
+        report("eval/real_loss", float(real_loss))
+        report("eval/fake_loss", float(fake_loss))
+        report("eval/discriminator_loss", float(dis_loss))
+
+        losses_dict["real_loss"] = float(real_loss)
+        losses_dict["fake_loss"] = float(fake_loss)
+        losses_dict["discriminator_loss"] = float(dis_loss)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
diff --git a/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py b/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py
index 4e3daaa3..79707aa4 100644
--- a/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py
+++ b/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py
@@ -12,6 +12,7 @@
 # See the License
for the specific language governing permissions and # limitations under the License. import logging +from pathlib import Path from typing import Dict import paddle @@ -26,6 +27,7 @@ from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator from paddlespeech.t2s.training.reporter import report from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState + logging.basicConfig( format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]') @@ -42,7 +44,7 @@ class PWGUpdater(StandardUpdater): dataloader: DataLoader, discriminator_train_start_steps: int, lambda_adv: float, - output_dir=None): + output_dir: Path=None): self.models = models self.generator: Layer = models['generator'] self.discriminator: Layer = models['discriminator'] @@ -155,11 +157,11 @@ class PWGUpdater(StandardUpdater): class PWGEvaluator(StandardEvaluator): def __init__(self, - models, - criterions, - dataloader, - lambda_adv, - output_dir=None): + models: Dict[str, Layer], + criterions: Dict[str, Layer], + dataloader: DataLoader, + lambda_adv: float, + output_dir: Path=None): self.models = models self.generator = models['generator'] self.discriminator = models['discriminator'] diff --git a/paddlespeech/t2s/modules/activation.py b/paddlespeech/t2s/modules/activation.py index f5b0af6e..8d8cd62e 100644 --- a/paddlespeech/t2s/modules/activation.py +++ b/paddlespeech/t2s/modules/activation.py @@ -27,7 +27,7 @@ class GLU(nn.Layer): return F.glu(xs, axis=self.dim) -def get_activation(act): +def get_activation(act, **kwargs): """Return activation function.""" activation_funcs = { @@ -35,8 +35,9 @@ def get_activation(act): "tanh": paddle.nn.Tanh, "relu": paddle.nn.ReLU, "selu": paddle.nn.SELU, + "leakyrelu": paddle.nn.LeakyReLU, "swish": paddle.nn.Swish, "glu": GLU } - return activation_funcs[act]() + return activation_funcs[act](**kwargs) diff --git a/paddlespeech/t2s/modules/residual_stack.py b/paddlespeech/t2s/modules/residual_stack.py index 236f41d3..b4f95229 100644 --- a/paddlespeech/t2s/modules/residual_stack.py +++ b/paddlespeech/t2s/modules/residual_stack.py @@ -18,6 +18,7 @@ from typing import Dict from paddle import nn +from paddlespeech.t2s.modules.activation import get_activation from paddlespeech.t2s.modules.causal_conv import CausalConv1D @@ -30,7 +31,7 @@ class ResidualStack(nn.Layer): channels: int=32, dilation: int=1, bias: bool=True, - nonlinear_activation: str="LeakyReLU", + nonlinear_activation: str="leakyrelu", nonlinear_activation_params: Dict[str, Any]={"negative_slope": 0.2}, pad: str="Pad1D", pad_params: Dict[str, Any]={"mode": "reflect"}, @@ -58,14 +59,16 @@ class ResidualStack(nn.Layer): Whether to use causal convolution. """ super().__init__() + # for compatibility + nonlinear_activation = nonlinear_activation.lower() # defile residual stack part if not use_causal_conv: assert (kernel_size - 1 ) % 2 == 0, "Not support even number kernel size." 
self.stack = nn.Sequential( - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), + get_activation(nonlinear_activation, + **nonlinear_activation_params), getattr(nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params), nn.Conv1D( @@ -74,13 +77,13 @@ class ResidualStack(nn.Layer): kernel_size, dilation=dilation, bias_attr=bias), - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), + get_activation(nonlinear_activation, + **nonlinear_activation_params), nn.Conv1D(channels, channels, 1, bias_attr=bias), ) else: self.stack = nn.Sequential( - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), + get_activation(nonlinear_activation, + **nonlinear_activation_params), CausalConv1D( channels, channels, @@ -89,8 +92,8 @@ class ResidualStack(nn.Layer): bias=bias, pad=pad, pad_params=pad_params, ), - getattr(nn, nonlinear_activation)( - **nonlinear_activation_params), + get_activation(nonlinear_activation, + **nonlinear_activation_params), nn.Conv1D(channels, channels, 1, bias_attr=bias), ) # defile extra layer for skip connection diff --git a/paddlespeech/t2s/modules/style_encoder.py b/paddlespeech/t2s/modules/style_encoder.py index e76226f3..9d4b83a2 100644 --- a/paddlespeech/t2s/modules/style_encoder.py +++ b/paddlespeech/t2s/modules/style_encoder.py @@ -298,8 +298,8 @@ class MultiHeadedAttention(BaseMultiHeadedAttention): def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0): """Initialize multi head attention module.""" - # NOTE(kan-bayashi): Do not use super().__init__() here since we want to - # overwrite BaseMultiHeadedAttention.__init__() method. + # Do not use super().__init__() here since we want to + # overwrite BaseMultiHeadedAttention.__init__() method. nn.Layer.__init__(self) assert n_feat % n_head == 0 # We assume d_v always equals d_k diff --git a/paddlespeech/t2s/modules/tade_res_block.py b/paddlespeech/t2s/modules/tade_res_block.py new file mode 100644 index 00000000..19b07639 --- /dev/null +++ b/paddlespeech/t2s/modules/tade_res_block.py @@ -0,0 +1,164 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
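Aside on `StyleMelGANUpdater.update_core` earlier in this diff: the two `*_train_start_steps` thresholds stage training so the discriminator, and the adversarial term in the generator loss, only kick in after a warm-up. A condensed sketch of which loss terms are active at a given iteration (the step values mirror the updater's defaults; recipes override them in their yaml configs):

```python
generator_train_start_steps = 0
discriminator_train_start_steps = 100000

def active_losses(iteration):
    train_g = iteration > generator_train_start_steps
    train_d = iteration > discriminator_train_start_steps
    return {
        "stft": train_g,                 # multi-resolution STFT loss
        "gen_adv": train_g and train_d,  # adversarial term joins G's loss late
        "dis_adv": train_d,              # D trains on real/fake pairs
    }

assert not active_losses(50000)["gen_adv"]
assert active_losses(150000)["gen_adv"]
```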
+# Modified from espnet(https://github.com/espnet/espnet)
+"""StyleMelGAN's TADEResBlock Modules."""
+from functools import partial
+
+import paddle.nn.functional as F
+from paddle import nn
+
+
+class TADELayer(nn.Layer):
+    """TADE Layer module."""
+
+    def __init__(
+            self,
+            in_channels: int=64,
+            aux_channels: int=80,
+            kernel_size: int=9,
+            bias: bool=True,
+            upsample_factor: int=2,
+            upsample_mode: str="nearest", ):
+        """Initialize TADE layer."""
+        super().__init__()
+        self.norm = nn.InstanceNorm1D(
+            in_channels, momentum=0.1, data_format="NCL")
+        self.aux_conv = nn.Sequential(
+            nn.Conv1D(
+                aux_channels,
+                in_channels,
+                kernel_size,
+                1,
+                bias_attr=bias,
+                padding=(kernel_size - 1) // 2, ), )
+        self.gated_conv = nn.Sequential(
+            nn.Conv1D(
+                in_channels,
+                in_channels * 2,
+                kernel_size,
+                1,
+                bias_attr=bias,
+                padding=(kernel_size - 1) // 2, ), )
+        self.upsample = nn.Upsample(
+            scale_factor=upsample_factor, mode=upsample_mode)
+
+    def forward(self, x, c):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, in_channels, T).
+        c : Tensor
+            Auxiliary input tensor (B, aux_channels, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, in_channels, T * upsample_factor).
+        Tensor
+            Upsampled aux tensor (B, in_channels, T * upsample_factor).
+        """
+
+        x = self.norm(x)
+        # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.
+        c = self.upsample(c.unsqueeze(-1))
+        c = c[:, :, :, 0]
+
+        c = self.aux_conv(c)
+        cg = self.gated_conv(c)
+        cg1, cg2 = cg.split(2, axis=1)
+        # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.
+        y = cg1 * self.upsample(x.unsqueeze(-1))[:, :, :, 0] + cg2
+        return y, c
+
+
+class TADEResBlock(nn.Layer):
+    """TADEResBlock module."""
+
+    def __init__(
+            self,
+            in_channels: int=64,
+            aux_channels: int=80,
+            kernel_size: int=9,
+            dilation: int=2,
+            bias: bool=True,
+            upsample_factor: int=2,
+            # this differs in paddle: the mode can only be "linear" when the input is 3D
+            upsample_mode: str="nearest",
+            gated_function: str="softmax", ):
+        """Initialize TADEResBlock module."""
+        super().__init__()
+        self.tade1 = TADELayer(
+            in_channels=in_channels,
+            aux_channels=aux_channels,
+            kernel_size=kernel_size,
+            bias=bias,
+            upsample_factor=1,
+            upsample_mode=upsample_mode, )
+        self.gated_conv1 = nn.Conv1D(
+            in_channels,
+            in_channels * 2,
+            kernel_size,
+            1,
+            bias_attr=bias,
+            padding=(kernel_size - 1) // 2, )
+        self.tade2 = TADELayer(
+            in_channels=in_channels,
+            aux_channels=in_channels,
+            kernel_size=kernel_size,
+            bias=bias,
+            upsample_factor=upsample_factor,
+            upsample_mode=upsample_mode, )
+        self.gated_conv2 = nn.Conv1D(
+            in_channels,
+            in_channels * 2,
+            kernel_size,
+            1,
+            bias_attr=bias,
+            dilation=dilation,
+            padding=(kernel_size - 1) // 2 * dilation, )
+        self.upsample = nn.Upsample(
+            scale_factor=upsample_factor, mode=upsample_mode)
+        if gated_function == "softmax":
+            self.gated_function = partial(F.softmax, axis=1)
+        elif gated_function == "sigmoid":
+            self.gated_function = F.sigmoid
+        else:
+            raise ValueError(f"{gated_function} is not supported.")
+
+    def forward(self, x, c):
+        """Calculate forward propagation.
+        Parameters
+        ----------
+        x : Tensor
+            Input tensor (B, in_channels, T).
+        c : Tensor
+            Auxiliary input tensor (B, aux_channels, T).
+        Returns
+        ----------
+        Tensor
+            Output tensor (B, in_channels, T * upsample_factor).
+        Tensor
+            Upsampled auxiliary tensor (B, in_channels, T * upsample_factor).
+ """ + residual = x + x, c = self.tade1(x, c) + x = self.gated_conv1(x) + xa, xb = x.split(2, axis=1) + x = self.gated_function(xa) * F.tanh(xb) + x, c = self.tade2(x, c) + x = self.gated_conv2(x) + xa, xb = x.split(2, axis=1) + x = self.gated_function(xa) * F.tanh(xb) + # 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. + return self.upsample(residual.unsqueeze(-1))[:, :, :, 0] + x, c