Merge pull request #1549 from yt605155624/add_ljspeech_hifigan

[TTS]add ljspeech hifigan egs.
pull/1553/head
Hui Zhang 3 years ago committed by GitHub
commit 29f45f6e90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,4 +1,13 @@
# Changelog
Date: 2022-3-08, Author: yt605155624.
Add features to: T2S:
- Add aishell3 hifigan egs.
- PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1545
Date: 2022-3-08, Author: yt605155624.
Add features to: T2S:
- Add vctk hifigan egs.
- PRLink: https://github.com/PaddlePaddle/PaddleSpeech/pull/1544
Date: 2022-1-29, Author: yt605155624.
Add features to: T2S:

@ -0,0 +1,133 @@
# HiFiGAN with the LJSpeech-1.1
This example contains code used to train a [HiFiGAN](https://arxiv.org/abs/2010.05646) model with [LJSpeech-1.1](https://keithito.com/LJ-Speech-Dataset/).
## Dataset
### Download and Extract
Download LJSpeech-1.1 from the [official website](https://keithito.com/LJ-Speech-Dataset/).
### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) results to cut the silence in the edge of audio.
You can download from here [ljspeech_alignment.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/LJSpeech-1.1/ljspeech_alignment.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/LJSpeech-1.1`.
Assume the path to the MFA result of LJSpeech-1.1 is `./ljspeech_alignment`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
```text
dump
├── dev
│ ├── norm
│ └── raw
├── test
│ ├── norm
│ └── raw
└── train
├── norm
├── raw
└── feats_stats.npy
```
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and `raw` subfolder. The `raw` folder contains the log magnitude of the mel spectrogram of each utterance, while the norm folder contains the normalized spectrogram. The statistics used to normalize the spectrogram are computed from the training set, which is located in `dump/train/feats_stats.npy`.
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains id and paths to the spectrogram of each utterance.
### Model Training
`./local/train.sh` calls `${BIN_DIR}/train.py`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--batch-size BATCH_SIZE] [--max-iter MAX_ITER]
[--run-benchmark RUN_BENCHMARK]
[--profiler_options PROFILER_OPTIONS]
Train a ParallelWaveGAN model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG config file to overwrite default config.
--train-metadata TRAIN_METADATA
training data.
--dev-metadata DEV_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--ngpu NGPU if ngpu == 0, use cpu.
benchmark:
arguments related to benchmark.
--batch-size BATCH_SIZE
batch size.
--max-iter MAX_ITER train max steps.
--run-benchmark RUN_BENCHMARK
runing benchmark or not, if True, use the --batch-size
and --max-iter.
--profiler_options PROFILER_OPTIONS
The option of profiler, which should be in format
"key1=value1;key2=value2;key3=value3".
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--generator-type GENERATOR_TYPE] [--config CONFIG]
[--checkpoint CHECKPOINT] [--test-metadata TEST_METADATA]
[--output-dir OUTPUT_DIR] [--ngpu NGPU]
Synthesize with GANVocoder.
optional arguments:
-h, --help show this help message and exit
--generator-type GENERATOR_TYPE
type of GANVocoder, should in {pwgan, mb_melgan,
style_melgan, } now
--config CONFIG GANVocoder config file.
--checkpoint CHECKPOINT
snapshot to load.
--test-metadata TEST_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--ngpu NGPU if ngpu == 0, use cpu.
```
1. `--config` parallel wavegan config file. You should use the same config with which the model is trained.
2. `--checkpoint` is the checkpoint to load. Pick one of the checkpoints from `checkpoints` inside the training output directory.
3. `--test-metadata` is the metadata of the test dataset. Use the `metadata.jsonl` in the `dev/norm` subfolder from the processed directory.
4. `--output-dir` is the directory to save the synthesized audio files.
5. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
## Pretrained Model
## Acknowledgement
We adapted some code from https://github.com/kan-bayashi/ParallelWaveGAN.

@ -0,0 +1,167 @@
# This is the configuration file for LJSpeech dataset.
# This configuration is based on HiFiGAN V1, which is an official configuration.
# But I found that the optimizer setting does not work well with my implementation.
# So I changed optimizer settings as follows:
# - AdamW -> Adam
# - betas: [0.8, 0.99] -> betas: [0.5, 0.9]
# - Scheduler: ExponentialLR -> MultiStepLR
# To match the shift size difference, the upsample scales is also modified from the original 256 shift setting.
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 22050 # Sampling rate.
n_fft: 1024 # FFT size (samples).
n_shift: 256 # Hop size (samples). 11.6ms
win_length: null # Window length (samples).
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
n_mels: 80 # Number of mel basis.
fmin: 80 # Minimum freq in mel basis calculation. (Hz)
fmax: 7600 # Maximum frequency in mel basis calculation. (Hz)
###########################################################
# GENERATOR NETWORK ARCHITECTURE SETTING #
###########################################################
generator_params:
in_channels: 80 # Number of input channels.
out_channels: 1 # Number of output channels.
channels: 512 # Number of initial channels.
kernel_size: 7 # Kernel size of initial and final conv layers.
upsample_scales: [8, 8, 2, 2] # Upsampling scales.
upsample_kernel_sizes: [16, 16, 4, 4] # Kernel size for upsampling layers.
resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks.
resblock_dilations: # Dilations for residual blocks.
- [1, 3, 5]
- [1, 3, 5]
- [1, 3, 5]
use_additional_convs: True # Whether to use additional conv layer in residual blocks.
bias: True # Whether to use bias parameter in conv.
nonlinear_activation: "leakyrelu" # Nonlinear activation type.
nonlinear_activation_params: # Nonlinear activation paramters.
negative_slope: 0.1
use_weight_norm: True # Whether to apply weight normalization.
###########################################################
# DISCRIMINATOR NETWORK ARCHITECTURE SETTING #
###########################################################
discriminator_params:
scales: 3 # Number of multi-scale discriminator.
scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator.
scale_downsample_pooling_params:
kernel_size: 4 # Pooling kernel size.
stride: 2 # Pooling stride.
padding: 2 # Padding size.
scale_discriminator_params:
in_channels: 1 # Number of input channels.
out_channels: 1 # Number of output channels.
kernel_sizes: [15, 41, 5, 3] # List of kernel sizes.
channels: 128 # Initial number of channels.
max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
max_groups: 16 # Maximum number of groups in downsampling conv layers.
bias: True
downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales.
nonlinear_activation: "leakyrelu" # Nonlinear activation.
nonlinear_activation_params:
negative_slope: 0.1
follow_official_norm: True # Whether to follow the official norm setting.
periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator.
period_discriminator_params:
in_channels: 1 # Number of input channels.
out_channels: 1 # Number of output channels.
kernel_sizes: [5, 3] # List of kernel sizes.
channels: 32 # Initial number of channels.
downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales.
max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers.
bias: True # Whether to use bias parameter in conv layer."
nonlinear_activation: "leakyrelu" # Nonlinear activation.
nonlinear_activation_params: # Nonlinear activation paramters.
negative_slope: 0.1
use_weight_norm: True # Whether to apply weight normalization.
use_spectral_norm: False # Whether to apply spectral normalization.
###########################################################
# STFT LOSS SETTING #
###########################################################
use_stft_loss: False # Whether to use multi-resolution STFT loss.
use_mel_loss: True # Whether to use Mel-spectrogram loss.
mel_loss_params:
fs: 22050
fft_size: 1024
hop_size: 256
win_length: null
window: "hann"
num_mels: 80
fmin: 0
fmax: 11025
log_base: null
generator_adv_loss_params:
average_by_discriminators: False # Whether to average loss by #discriminators.
discriminator_adv_loss_params:
average_by_discriminators: False # Whether to average loss by #discriminators.
use_feat_match_loss: True
feat_match_loss_params:
average_by_discriminators: False # Whether to average loss by #discriminators.
average_by_layers: False # Whether to average loss by #layers in each discriminator.
include_final_outputs: False # Whether to include final outputs in feat match loss calculation.
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
lambda_aux: 45.0 # Loss balancing coefficient for STFT loss.
lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss.
lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss..
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 16 # Batch size.
batch_max_steps: 8192 # Length of each audio in batch. Make sure dividable by hop_size.
num_workers: 2 # Number of workers in DataLoader.
###########################################################
# OPTIMIZER & SCHEDULER SETTING #
###########################################################
generator_optimizer_params:
beta1: 0.5
beta2: 0.9
weight_decay: 0.0 # Generator's weight decay coefficient.
generator_scheduler_params:
learning_rate: 2.0e-4 # Generator's learning rate.
gamma: 0.5 # Generator's scheduler gamma.
milestones: # At each milestone, lr will be multiplied by gamma.
- 200000
- 400000
- 600000
- 800000
generator_grad_norm: -1 # Generator's gradient norm.
discriminator_optimizer_params:
beta1: 0.5
beta2: 0.9
weight_decay: 0.0 # Discriminator's weight decay coefficient.
discriminator_scheduler_params:
learning_rate: 2.0e-4 # Discriminator's learning rate.
gamma: 0.5 # Discriminator's scheduler gamma.
milestones: # At each milestone, lr will be multiplied by gamma.
- 200000
- 400000
- 600000
- 800000
discriminator_grad_norm: -1 # Discriminator's gradient norm.
###########################################################
# INTERVAL SETTING #
###########################################################
generator_train_start_steps: 1 # Number of steps to start to train discriminator.
discriminator_train_start_steps: 0 # Number of steps to start to train discriminator.
train_max_steps: 2500000 # Number of training steps.
save_interval_steps: 5000 # Interval steps to save checkpoint.
eval_interval_steps: 1000 # Interval steps to evaluate the network.
###########################################################
# OTHER SETTING #
###########################################################
num_snapshots: 10 # max number of snapshots to keep while training
seed: 42 # random seed for paddle, random, and np.random

@ -0,0 +1,55 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./ljspeech_alignment \
--output=durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/../preprocess.py \
--rootdir=~/datasets/LJSpeech-1.1/ \
--dataset=ljspeech \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--cut-sil=True \
--num-cpu=20
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--stats=dump/train/feats_stats.npy
python3 ${BIN_DIR}/../normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--stats=dump/train/feats_stats.npy
fi

@ -0,0 +1,14 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--config=${config_path} \
--checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test \
--generator-type=hifigan

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
FLAGS_cudnn_exhaustive_search=true \
FLAGS_conv_workspace_size_limit=4000 \
python ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=hifigan
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/gan_vocoder/${MODEL}

@ -0,0 +1,32 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_5000.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
Loading…
Cancel
Save