[TTS]Update VITS to support VITS and its voice cloning training on AIShell-3 (#2268)

* code for training vits voice clone on aishell3.
Co-authored-by: TianYuan <white-sky@qq.com>
pull/2353/head
艾梦 2 years ago committed by GitHub
parent e8656fdfba
commit ea9ee93739
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -613,7 +613,7 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
</td>
</tr>
<tr>
<td rowspan="3">Voice Cloning</td>
<td rowspan="4">Voice Cloning</td>
<td>GE2E</td>
<td >Librispeech, etc.</td>
<td>
@ -633,13 +633,20 @@ PaddleSpeech supports a series of most popular models. They are summarized in [r
<td>
<a href = "./examples/aishell3/vc1">ge2e-fastspeech2-aishell3</a>
</td>
</tr>
<tr>
<td>GE2E + VITS</td>
<td>AISHELL-3</td>
<td>
<a href = "./examples/aishell3/vits-vc">ge2e-vits-aishell3</a>
</td>
</tr>
<tr>
<td rowspan="3">End-to-End</td>
<td>VITS</td>
<td >CSMSC</td>
<td>CSMSC / AISHELL-3</td>
<td>
<a href = "./examples/csmsc/vits">VITS-csmsc</a>
<a href = "./examples/csmsc/vits">VITS-csmsc</a> / <a href = "./examples/aishell3/vits">VITS-aishell3</a>
</td>
</tr>
</tbody>

@ -608,7 +608,7 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
</td>
</tr>
<tr>
<td rowspan="3">声音克隆</td>
<td rowspan="4">声音克隆</td>
<td>GE2E</td>
<td >Librispeech, etc.</td>
<td>
@ -629,13 +629,20 @@ PaddleSpeech 的 **语音合成** 主要包含三个模块:文本前端、声
<a href = "./examples/aishell3/vc1">ge2e-fastspeech2-aishell3</a>
</td>
</tr>
<tr>
<td>GE2E + VITS</td>
<td>AISHELL-3</td>
<td>
<a href = "./examples/aishell3/vits-vc">ge2e-vits-aishell3</a>
</td>
</tr>
</tr>
<tr>
<td rowspan="3">端到端</td>
<td>VITS</td>
<td >CSMSC</td>
<td>CSMSC / AISHELL-3</td>
<td>
<a href = "./examples/csmsc/vits">VITS-csmsc</a>
<a href = "./examples/csmsc/vits">VITS-csmsc</a> / <a href = "./examples/aishell3/vits">VITS-aishell3</a>
</td>
</tr>
</tbody>

@ -0,0 +1,154 @@
# VITS with AISHELL-3
This example contains code used to train a [VITS](https://arxiv.org/abs/2106.06103) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used in Voice Cloning Task, We refer to the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf). The general steps are as follows:
1. Speaker Encoder: We use Speaker Verification to train a speaker encoder. Datasets used in this task are different from those used in `VITS` because the transcriptions are not needed, we use more datasets, refer to [ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e).
2. Synthesizer and Vocoder: We use the trained speaker encoder to generate speaker embedding for each sentence in AISHELL-3. This embedding is an extra input of `VITS` which will be concated with encoder outputs. The vocoder is part of `VITS` due to its special structure.
## Dataset
### Download and Extract
Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for VITS, the durations of MFA are not needed here.
You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo.
## Pretrained GE2E Model
We use pretrained GE2E model to generate speaker embedding for each sentence.
Download pretrained GE2E model from here [ge2e_ckpt_0.3.zip](https://bj.bcebos.com/paddlespeech/Parakeet/released_models/ge2e/ge2e_ckpt_0.3.zip), and `unzip` it.
## Get Started
Assume the path to the dataset is `~/datasets/data_aishell3`.
Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
Assume the path to the pretrained ge2e model is `./ge2e_ckpt_0.3`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize waveform from `metadata.jsonl`.
5. start a voice cloning inference.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} ${ge2e_ckpt_path}
```
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
```text
dump
├── dev
│   ├── norm
│   └── raw
├── embed
│ ├── SSB0005
│ ├── SSB0009
│ ├── ...
│ └── ...
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
├── feats_stats.npy
├── norm
└── raw
```
The `embed` contains the generated speaker embedding for each sentence in AISHELL-3, which has the same file structure with wav files and the format is `.npy`.
The computing time of utterance embedding can be x hours.
The dataset is split into 3 parts, namely `train`, `dev`, and` test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains wave and linear spectrogram of each utterance, while the norm folder contains normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/feats_stats.npy`.
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, feats, feats_lengths, the path of linear spectrogram features, the path of raw waves, speaker, and the id of each utterance.
The preprocessing step is very similar to that one of [vits](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vits), but there is one more `ge2e/inference` step here.
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
The training step is very similar to that one of [vits](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vits), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/train.py`.
### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--config CONFIG] [--ckpt CKPT]
[--phones_dict PHONES_DICT] [--speaker_dict SPEAKER_DICT]
[--voice-cloning VOICE_CLONING] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
Synthesize with VITS
optional arguments:
-h, --help show this help message and exit
--config CONFIG Config of VITS.
--ckpt CKPT Checkpoint file of VITS.
--phones_dict PHONES_DICT
phone vocabulary file.
--speaker_dict SPEAKER_DICT
speaker id map file.
--voice-cloning VOICE_CLONING
whether training voice cloning model.
--ngpu NGPU if ngpu == 0, use cpu.
--test_metadata TEST_METADATA
test metadata.
--output_dir OUTPUT_DIR
output dir.
```
The synthesizing step is very similar to that one of [vits](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/vits), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/../synthesize.py`.
### Voice Cloning
Assume there are some reference audios in `./ref_audio`
```text
ref_audio
├── 001238.wav
├── LJ015-0254.wav
└── audio_self_test.mp3
```
`./local/voice_cloning.sh` calls `${BIN_DIR}/voice_cloning.py`
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${add_blank} ${ref_audio_dir}
```
If you want to convert a speaker audio file to refered speaker, run:
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${add_blank} ${ref_audio_dir} ${src_audio_path}
```
<!-- TODO display these after we trained the model -->
<!--
## Pretrained Model
The pretrained model can be downloaded here:
- [vits_vc_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/vits/vits_vc_aishell3_ckpt_1.1.0.zip) (add_blank=true)
VITS checkpoint contains files listed below.
(There is no need for `speaker_id_map.txt` here )
```text
vits_vc_aishell3_ckpt_1.1.0
├── default.yaml # default config used to train vitx
├── phone_id_map.txt # phone vocabulary file when training vits
└── snapshot_iter_333000.pdz # model parameters and optimizer states
```
ps: This ckpt is not good enough, a better result is training
-->

@ -0,0 +1,185 @@
# This configuration tested on 4 GPUs (V100) with 32GB GPU
# memory. It takes around 2 weeks to finish the training
# but 100k iters model should generate reasonable results.
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 22050 # sr
n_fft: 1024 # FFT size (samples).
n_shift: 256 # Hop size (samples). 12.5ms
win_length: null # Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
##########################################################
# TTS MODEL SETTING #
##########################################################
model:
# generator related
generator_type: vits_generator
generator_params:
hidden_channels: 192
spk_embed_dim: 256
global_channels: 256
segment_size: 32
text_encoder_attention_heads: 2
text_encoder_ffn_expand: 4
text_encoder_blocks: 6
text_encoder_positionwise_layer_type: "conv1d"
text_encoder_positionwise_conv_kernel_size: 3
text_encoder_positional_encoding_layer_type: "rel_pos"
text_encoder_self_attention_layer_type: "rel_selfattn"
text_encoder_activation_type: "swish"
text_encoder_normalize_before: True
text_encoder_dropout_rate: 0.1
text_encoder_positional_dropout_rate: 0.0
text_encoder_attention_dropout_rate: 0.1
use_macaron_style_in_text_encoder: True
use_conformer_conv_in_text_encoder: False
text_encoder_conformer_kernel_size: -1
decoder_kernel_size: 7
decoder_channels: 512
decoder_upsample_scales: [8, 8, 2, 2]
decoder_upsample_kernel_sizes: [16, 16, 4, 4]
decoder_resblock_kernel_sizes: [3, 7, 11]
decoder_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
use_weight_norm_in_decoder: True
posterior_encoder_kernel_size: 5
posterior_encoder_layers: 16
posterior_encoder_stacks: 1
posterior_encoder_base_dilation: 1
posterior_encoder_dropout_rate: 0.0
use_weight_norm_in_posterior_encoder: True
flow_flows: 4
flow_kernel_size: 5
flow_base_dilation: 1
flow_layers: 4
flow_dropout_rate: 0.0
use_weight_norm_in_flow: True
use_only_mean_in_flow: True
stochastic_duration_predictor_kernel_size: 3
stochastic_duration_predictor_dropout_rate: 0.5
stochastic_duration_predictor_flows: 4
stochastic_duration_predictor_dds_conv_layers: 3
# discriminator related
discriminator_type: hifigan_multi_scale_multi_period_discriminator
discriminator_params:
scales: 1
scale_downsample_pooling: "AvgPool1D"
scale_downsample_pooling_params:
kernel_size: 4
stride: 2
padding: 2
scale_discriminator_params:
in_channels: 1
out_channels: 1
kernel_sizes: [15, 41, 5, 3]
channels: 128
max_downsample_channels: 1024
max_groups: 16
bias: True
downsample_scales: [2, 2, 4, 4, 1]
nonlinear_activation: "leakyrelu"
nonlinear_activation_params:
negative_slope: 0.1
use_weight_norm: True
use_spectral_norm: False
follow_official_norm: False
periods: [2, 3, 5, 7, 11]
period_discriminator_params:
in_channels: 1
out_channels: 1
kernel_sizes: [5, 3]
channels: 32
downsample_scales: [3, 3, 3, 3, 1]
max_downsample_channels: 1024
bias: True
nonlinear_activation: "leakyrelu"
nonlinear_activation_params:
negative_slope: 0.1
use_weight_norm: True
use_spectral_norm: False
# others
sampling_rate: 22050 # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training
###########################################################
# LOSS SETTING #
###########################################################
# loss function related
generator_adv_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
loss_type: mse # loss type, "mse" or "hinge"
discriminator_adv_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
loss_type: mse # loss type, "mse" or "hinge"
feat_match_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
average_by_layers: False # whether to average loss value by #layers of each discriminator
include_final_outputs: True # whether to include final outputs for loss calculation
mel_loss_params:
fs: 22050 # must be the same as the training data
fft_size: 1024 # fft points
hop_size: 256 # hop size
win_length: null # window length
window: hann # window type
num_mels: 80 # number of Mel basis
fmin: 0 # minimum frequency for Mel basis
fmax: null # maximum frequency for Mel basis
log_base: null # null represent natural log
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
lambda_adv: 1.0 # loss scaling coefficient for adversarial loss
lambda_mel: 45.0 # loss scaling coefficient for Mel loss
lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss
lambda_dur: 1.0 # loss scaling coefficient for duration loss
lambda_kl: 1.0 # loss scaling coefficient for KL divergence loss
# others
sampling_rate: 22050 # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 50 # Batch size.
num_workers: 4 # Number of workers in DataLoader.
##########################################################
# OPTIMIZER & SCHEDULER SETTING #
##########################################################
# optimizer setting for generator
generator_optimizer_params:
beta1: 0.8
beta2: 0.99
epsilon: 1.0e-9
weight_decay: 0.0
generator_scheduler: exponential_decay
generator_scheduler_params:
learning_rate: 2.0e-4
gamma: 0.999875
# optimizer setting for discriminator
discriminator_optimizer_params:
beta1: 0.8
beta2: 0.99
epsilon: 1.0e-9
weight_decay: 0.0
discriminator_scheduler: exponential_decay
discriminator_scheduler_params:
learning_rate: 2.0e-4
gamma: 0.999875
generator_first: False # whether to start updating generator first
##########################################################
# OTHER TRAINING SETTING #
##########################################################
num_snapshots: 10 # max number of snapshots to keep while training
train_max_steps: 350000 # Number of training steps. == total_iters / ngpus, total_iters = 1000000
save_interval_steps: 1000 # Interval steps to save checkpoint.
eval_interval_steps: 250 # Interval steps to evaluate the network.
seed: 777 # random seed number

@ -0,0 +1,79 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
add_blank=$2
ge2e_ckpt_path=$3
# gen speaker embedding
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \
--input=~/datasets/data_aishell3/train/wav/ \
--output=dump/embed \
--checkpoint_path=${ge2e_ckpt_path}
fi
# copy from tts3/preprocess
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./aishell3_alignment_tone \
--output durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=aishell3 \
--rootdir=~/datasets/data_aishell3/ \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True \
--spk_emb_dir=dump/embed
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# normalize and covert phone/speaker to id, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy
fi

@ -0,0 +1,19 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
--config=${config_path} \
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--phones_dict=dump/phone_id_map.txt \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--voice-cloning=True
fi

@ -0,0 +1,18 @@
#!/bin/bash
config_path=$1
train_output_path=$2
# install monotonic_align
cd ${MAIN_ROOT}/paddlespeech/t2s/models/vits/monotonic_align
python3 setup.py build_ext --inplace
cd -
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=4 \
--phones-dict=dump/phone_id_map.txt \
--voice-cloning=True

@ -0,0 +1,22 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
ge2e_params_path=$4
add_blank=$5
ref_audio_dir=$6
src_audio_path=$7
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/voice_cloning.py \
--config=${config_path} \
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--ge2e_params_path=${ge2e_params_path} \
--phones_dict=dump/phone_id_map.txt \
--text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \
--audio-path=${src_audio_path} \
--input-dir=${ref_audio_dir} \
--output-dir=${train_output_path}/vc_syn \
--add-blank=${add_blank}

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=vits
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,45 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1,2,3
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz
add_blank=true
ref_audio_dir=ref_audio
src_audio_path=''
# not include ".pdparams" here
ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000
# include ".pdparams" here
ge2e_params_path=${ge2e_ckpt_path}.pdparams
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} ${add_blank} ${ge2e_ckpt_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} \
${ge2e_params_path} ${add_blank} ${ref_audio_dir} ${src_audio_path} || exit -1
fi

@ -0,0 +1,202 @@
# VITS with AISHELL-3
This example contains code used to train a [VITS](https://arxiv.org/abs/2106.06103) model with [AISHELL-3](http://www.aishelltech.com/aishell_3).
AISHELL-3 is a large-scale and high-fidelity multi-speaker Mandarin speech corpus that could be used to train multi-speaker Text-to-Speech (TTS) systems.
We use AISHELL-3 to train a multi-speaker VITS model here.
## Dataset
### Download and Extract
Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get phonemes for VITS, the durations of MFA are not needed here.
You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/data_aishell3`.
Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
- synthesize waveform from a text file.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
```text
dump
├── dev
│   ├── norm
│   └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
├── feats_stats.npy
├── norm
└── raw
```
The dataset is split into 3 parts, namely `train`, `dev`, and` test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains wave and linear spectrogram of each utterance, while the norm folder contains normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/feats_stats.npy`.
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, feats, feats_lengths, the path of linear spectrogram features, the path of raw waves, speaker, and the id of each utterance.
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--phones-dict PHONES_DICT]
[--speaker-dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING]
Train a VITS model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG config file to overwrite default config.
--train-metadata TRAIN_METADATA
training data.
--dev-metadata DEV_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--ngpu NGPU if ngpu == 0, use cpu.
--phones-dict PHONES_DICT
phone vocabulary file.
--speaker-dict SPEAKER_DICT
speaker id map file for multiple speaker model.
--voice-cloning VOICE_CLONING
whether training voice cloning model.
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata file in the normalized subfolder of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speaker-dict` is the path of the speaker id map file when training a multi-speaker VITS.
### Synthesizing
`./local/synthesize.sh` calls `${BIN_DIR}/synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h] [--config CONFIG] [--ckpt CKPT]
[--phones_dict PHONES_DICT] [--speaker_dict SPEAKER_DICT]
[--voice-cloning VOICE_CLONING] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
Synthesize with VITS
optional arguments:
-h, --help show this help message and exit
--config CONFIG Config of VITS.
--ckpt CKPT Checkpoint file of VITS.
--phones_dict PHONES_DICT
phone vocabulary file.
--speaker_dict SPEAKER_DICT
speaker id map file.
--voice-cloning VOICE_CLONING
whether training voice cloning model.
--ngpu NGPU if ngpu == 0, use cpu.
--test_metadata TEST_METADATA
test metadata.
--output_dir OUTPUT_DIR
output dir.
```
`./local/synthesize_e2e.sh` calls `${BIN_DIR}/synthesize_e2e.py`, which can synthesize waveform from text file.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize_e2e.py [-h] [--config CONFIG] [--ckpt CKPT]
[--phones_dict PHONES_DICT]
[--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
[--lang LANG]
[--inference_dir INFERENCE_DIR] [--ngpu NGPU]
[--text TEXT] [--output_dir OUTPUT_DIR]
Synthesize with VITS
optional arguments:
-h, --help show this help message and exit
--config CONFIG Config of VITS.
--ckpt CKPT Checkpoint file of VITS.
--phones_dict PHONES_DICT
phone vocabulary file.
--speaker_dict SPEAKER_DICT
speaker id map file.
--spk_id SPK_ID spk id for multi speaker acoustic model
--lang LANG Choose model language. zh or en
--inference_dir INFERENCE_DIR
dir to save inference models
--ngpu NGPU if ngpu == 0, use cpu.
--text TEXT text to synthesize, a 'utt_id sentence' pair per line.
--output_dir OUTPUT_DIR
output dir.
```
1. `--config`, `--ckpt`, `--phones_dict` and `--speaker_dict` are arguments for acoustic model, which correspond to the 3 files in the VITS pretrained model.
2. `--lang` is the model language, which can be `zh` or `en`.
3. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
4. `--text` is the text file, which contains sentences to synthesize.
5. `--output_dir` is the directory to save synthesized audio files.
6. `--ngpu` is the number of gpus to use, if ngpu == 0, use cpu.
<!-- TODO display these after we trained the model -->
<!--
## Pretrained Model
The pretrained model can be downloaded here:
- [vits_aishell3_ckpt_1.1.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/vits/vits_aishell3_ckpt_1.1.0.zip) (add_blank=true)
VITS checkpoint contains files listed below.
```text
vits_aishell3_ckpt_1.1.0
├── default.yaml # default config used to train vitx
├── phone_id_map.txt # phone vocabulary file when training vits
├── speaker_id_map.txt # speaker id map file when training a multi-speaker vits
└── snapshot_iter_333000.pdz # model parameters and optimizer states
```
ps: This ckpt is not good enough, a better result is training
You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained VITS.
```bash
source path.sh
add_blank=true
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize_e2e.py \
--config=vits_aishell3_ckpt_1.1.0/default.yaml \
--ckpt=vits_aishell3_ckpt_1.1.0/snapshot_iter_333000.pdz \
--phones_dict=vits_aishell3_ckpt_1.1.0/phone_id_map.txt \
--speaker_dict=vits_aishell3_ckpt_1.1.0/speaker_id_map.txt \
--output_dir=exp/default/test_e2e \
--text=${BIN_DIR}/../sentences.txt \
--add-blank=${add_blank}
```
-->

@ -0,0 +1,184 @@
# This configuration tested on 4 GPUs (V100) with 32GB GPU
# memory. It takes around 2 weeks to finish the training
# but 100k iters model should generate reasonable results.
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 22050 # sr
n_fft: 1024 # FFT size (samples).
n_shift: 256 # Hop size (samples). 12.5ms
win_length: null # Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
##########################################################
# TTS MODEL SETTING #
##########################################################
model:
# generator related
generator_type: vits_generator
generator_params:
hidden_channels: 192
global_channels: 256
segment_size: 32
text_encoder_attention_heads: 2
text_encoder_ffn_expand: 4
text_encoder_blocks: 6
text_encoder_positionwise_layer_type: "conv1d"
text_encoder_positionwise_conv_kernel_size: 3
text_encoder_positional_encoding_layer_type: "rel_pos"
text_encoder_self_attention_layer_type: "rel_selfattn"
text_encoder_activation_type: "swish"
text_encoder_normalize_before: True
text_encoder_dropout_rate: 0.1
text_encoder_positional_dropout_rate: 0.0
text_encoder_attention_dropout_rate: 0.1
use_macaron_style_in_text_encoder: True
use_conformer_conv_in_text_encoder: False
text_encoder_conformer_kernel_size: -1
decoder_kernel_size: 7
decoder_channels: 512
decoder_upsample_scales: [8, 8, 2, 2]
decoder_upsample_kernel_sizes: [16, 16, 4, 4]
decoder_resblock_kernel_sizes: [3, 7, 11]
decoder_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
use_weight_norm_in_decoder: True
posterior_encoder_kernel_size: 5
posterior_encoder_layers: 16
posterior_encoder_stacks: 1
posterior_encoder_base_dilation: 1
posterior_encoder_dropout_rate: 0.0
use_weight_norm_in_posterior_encoder: True
flow_flows: 4
flow_kernel_size: 5
flow_base_dilation: 1
flow_layers: 4
flow_dropout_rate: 0.0
use_weight_norm_in_flow: True
use_only_mean_in_flow: True
stochastic_duration_predictor_kernel_size: 3
stochastic_duration_predictor_dropout_rate: 0.5
stochastic_duration_predictor_flows: 4
stochastic_duration_predictor_dds_conv_layers: 3
# discriminator related
discriminator_type: hifigan_multi_scale_multi_period_discriminator
discriminator_params:
scales: 1
scale_downsample_pooling: "AvgPool1D"
scale_downsample_pooling_params:
kernel_size: 4
stride: 2
padding: 2
scale_discriminator_params:
in_channels: 1
out_channels: 1
kernel_sizes: [15, 41, 5, 3]
channels: 128
max_downsample_channels: 1024
max_groups: 16
bias: True
downsample_scales: [2, 2, 4, 4, 1]
nonlinear_activation: "leakyrelu"
nonlinear_activation_params:
negative_slope: 0.1
use_weight_norm: True
use_spectral_norm: False
follow_official_norm: False
periods: [2, 3, 5, 7, 11]
period_discriminator_params:
in_channels: 1
out_channels: 1
kernel_sizes: [5, 3]
channels: 32
downsample_scales: [3, 3, 3, 3, 1]
max_downsample_channels: 1024
bias: True
nonlinear_activation: "leakyrelu"
nonlinear_activation_params:
negative_slope: 0.1
use_weight_norm: True
use_spectral_norm: False
# others
sampling_rate: 22050 # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training
###########################################################
# LOSS SETTING #
###########################################################
# loss function related
generator_adv_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
loss_type: mse # loss type, "mse" or "hinge"
discriminator_adv_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
loss_type: mse # loss type, "mse" or "hinge"
feat_match_loss_params:
average_by_discriminators: False # whether to average loss value by #discriminators
average_by_layers: False # whether to average loss value by #layers of each discriminator
include_final_outputs: True # whether to include final outputs for loss calculation
mel_loss_params:
fs: 22050 # must be the same as the training data
fft_size: 1024 # fft points
hop_size: 256 # hop size
win_length: null # window length
window: hann # window type
num_mels: 80 # number of Mel basis
fmin: 0 # minimum frequency for Mel basis
fmax: null # maximum frequency for Mel basis
log_base: null # null represent natural log
###########################################################
# ADVERSARIAL LOSS SETTING #
###########################################################
lambda_adv: 1.0 # loss scaling coefficient for adversarial loss
lambda_mel: 45.0 # loss scaling coefficient for Mel loss
lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss
lambda_dur: 1.0 # loss scaling coefficient for duration loss
lambda_kl: 1.0 # loss scaling coefficient for KL divergence loss
# others
sampling_rate: 22050 # needed in the inference for saving wav
cache_generator_outputs: True # whether to cache generator outputs in the training
###########################################################
# DATA LOADER SETTING #
###########################################################
batch_size: 50 # Batch size.
num_workers: 4 # Number of workers in DataLoader.
##########################################################
# OPTIMIZER & SCHEDULER SETTING #
##########################################################
# optimizer setting for generator
generator_optimizer_params:
beta1: 0.8
beta2: 0.99
epsilon: 1.0e-9
weight_decay: 0.0
generator_scheduler: exponential_decay
generator_scheduler_params:
learning_rate: 2.0e-4
gamma: 0.999875
# optimizer setting for discriminator
discriminator_optimizer_params:
beta1: 0.8
beta2: 0.99
epsilon: 1.0e-9
weight_decay: 0.0
discriminator_scheduler: exponential_decay
discriminator_scheduler_params:
learning_rate: 2.0e-4
gamma: 0.999875
generator_first: False # whether to start updating generator first
##########################################################
# OTHER TRAINING SETTING #
##########################################################
num_snapshots: 10 # max number of snapshots to keep while training
train_max_steps: 350000 # Number of training steps. == total_iters / ngpus, total_iters = 1000000
save_interval_steps: 1000 # Interval steps to save checkpoint.
eval_interval_steps: 250 # Interval steps to evaluate the network.
seed: 777 # random seed number

@ -0,0 +1,69 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
add_blank=$2
# copy from tts3/preprocess
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./aishell3_alignment_tone \
--output durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=aishell3 \
--rootdir=~/datasets/data_aishell3/ \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="feats"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize and covert phone/speaker to id, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--feats-stats=dump/train/feats_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt \
--add-blank=${add_blank} \
--skip-wav-copy
fi

@ -0,0 +1,19 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
--config=${config_path} \
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test
fi

@ -0,0 +1,24 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
add_blank=$4
stage=0
stop_stage=0
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize_e2e.py \
--config=${config_path} \
--ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--spk_id=0 \
--output_dir=${train_output_path}/test_e2e \
--text=${BIN_DIR}/../sentences.txt \
--add-blank=${add_blank}
fi

@ -0,0 +1,18 @@
#!/bin/bash
config_path=$1
train_output_path=$2
# install monotonic_align
cd ${MAIN_ROOT}/paddlespeech/t2s/models/vits/monotonic_align
python3 setup.py build_ext --inplace
cd -
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=4 \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=vits
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,36 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1,2,3
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz
add_blank=true
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} ${add_blank}|| exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} ${add_blank}|| exit -1
fi

@ -3,7 +3,7 @@
set -e
source path.sh
gpus=0,1
gpus=0,1,2,3
stage=0
stop_stage=100

@ -483,3 +483,58 @@ def vits_single_spk_batch_fn(examples):
"speech": speech
}
return batch
def vits_multi_spk_batch_fn(examples):
"""
Returns:
Dict[str, Any]:
- text (Tensor): Text index tensor (B, T_text).
- text_lengths (Tensor): Text length tensor (B,).
- feats (Tensor): Feature tensor (B, T_feats, aux_channels).
- feats_lengths (Tensor): Feature length tensor (B,).
- speech (Tensor): Speech waveform tensor (B, T_wav).
- spk_id (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
- spk_emb (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim).
"""
# fields = ["text", "text_lengths", "feats", "feats_lengths", "speech", "spk_id"/"spk_emb"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
feats = [np.array(item["feats"], dtype=np.float32) for item in examples]
speech = [np.array(item["wave"], dtype=np.float32) for item in examples]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
feats_lengths = [
np.array(item["feats_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text)
feats = batch_sequences(feats)
speech = batch_sequences(speech)
# convert each batch to paddle.Tensor
text = paddle.to_tensor(text)
feats = paddle.to_tensor(feats)
text_lengths = paddle.to_tensor(text_lengths)
feats_lengths = paddle.to_tensor(feats_lengths)
batch = {
"text": text,
"text_lengths": text_lengths,
"feats": feats,
"feats_lengths": feats_lengths,
"speech": speech
}
# spk_emb has a higher priority than spk_id
if "spk_emb" in examples[0]:
spk_emb = [
np.array(item["spk_emb"], dtype=np.float32) for item in examples
]
spk_emb = batch_sequences(spk_emb)
spk_emb = paddle.to_tensor(spk_emb)
batch["spk_emb"] = spk_emb
elif "spk_id" in examples[0]:
spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
spk_id = paddle.to_tensor(spk_id)
batch["spk_id"] = spk_id
return batch

@ -15,6 +15,7 @@ import argparse
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import soundfile as sf
import yaml
@ -23,6 +24,7 @@ from yacs.config import CfgNode
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.utils import str2bool
def evaluate(args):
@ -40,8 +42,26 @@ def evaluate(args):
print(config)
fields = ["utt_id", "text"]
converters = {}
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker vits!")
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Evaluating voice cloning!")
fields += ["spk_emb"]
else:
print("single speaker vits!")
print("spk_num:", spk_num)
test_dataset = DataTable(data=test_metadata, fields=fields)
test_dataset = DataTable(
data=test_metadata,
fields=fields,
converters=converters, )
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
@ -49,6 +69,7 @@ def evaluate(args):
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
config["model"]["generator_params"]["spks"] = spk_num
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
@ -65,7 +86,15 @@ def evaluate(args):
phone_ids = paddle.to_tensor(datum["text"])
with timer() as t:
with paddle.no_grad():
out = vits.inference(text=phone_ids)
spk_emb = None
spk_id = None
# multi speaker
if args.voice_cloning and "spk_emb" in datum:
spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
elif "spk_id" in datum:
spk_id = paddle.to_tensor(datum["spk_id"])
out = vits.inference(
text=phone_ids, sids=spk_id, spembs=spk_emb)
wav = out["wav"]
wav = wav.numpy()
N += wav.size
@ -90,6 +119,13 @@ def parse_args():
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
# other
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")

@ -42,12 +42,23 @@ def evaluate(args):
# frontend
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker vits!")
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
else:
print("single speaker vits!")
print("spk_num:", spk_num)
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
config["model"]["generator_params"]["spks"] = spk_num
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
@ -78,7 +89,10 @@ def evaluate(args):
flags = 0
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
out = vits.inference(text=part_phone_ids)
spk_id = None
if spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id)
out = vits.inference(text=part_phone_ids, sids=spk_id)
wav = out["wav"]
if flags == 0:
wav_all = wav
@ -109,6 +123,13 @@ def parse_args():
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker_dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
'--spk_id',
type=int,
default=0,
help='spk id for multi speaker acoustic model')
# other
parser.add_argument(
'--lang',

@ -28,6 +28,7 @@ from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import vits_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.vits import VITS
@ -43,6 +44,7 @@ from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import scheduler_classes
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
from paddlespeech.t2s.utils import str2bool
def train_sp(args, config):
@ -72,6 +74,23 @@ def train_sp(args, config):
"wave": np.load,
"feats": np.load,
}
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker vits!")
collate_fn = vits_multi_spk_batch_fn
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Training voice cloning!")
collate_fn = vits_multi_spk_batch_fn
fields += ["spk_emb"]
converters["spk_emb"] = np.load
else:
print("single speaker vits!")
collate_fn = vits_single_spk_batch_fn
print("spk_num:", spk_num)
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
@ -100,18 +119,16 @@ def train_sp(args, config):
drop_last=False)
print("samplers done!")
train_batch_fn = vits_single_spk_batch_fn
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=train_batch_fn,
collate_fn=collate_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=train_batch_fn,
collate_fn=collate_fn,
num_workers=config.num_workers)
print("dataloaders done!")
@ -121,6 +138,7 @@ def train_sp(args, config):
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
config["model"]["generator_params"]["spks"] = spk_num
model = VITS(idim=vocab_size, odim=odim, **config["model"])
gen_parameters = model.generator.parameters()
dis_parameters = model.discriminator.parameters()
@ -240,6 +258,17 @@ def main():
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
args = parser.parse_args()

@ -0,0 +1,213 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram
from paddlespeech.t2s.exps.syn_utils import get_frontend
from paddlespeech.t2s.models.vits import VITS
from paddlespeech.t2s.utils import str2bool
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
def voice_cloning(args):
# Init body.
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
# speaker encoder
spec_extractor = LinearSpectrogram(
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window)
p = SpeakerVerificationPreprocessor(
sampling_rate=16000,
audio_norm_target_dBFS=-30,
vad_window_length=30,
vad_moving_average_width=8,
vad_max_silence_length=6,
mel_window_length=25,
mel_window_step=10,
n_mels=40,
partial_n_frames=160,
min_pad_coverage=0.75,
partial_overlap_ratio=0.5)
print("Audio Processor Done!")
speaker_encoder = LSTMSpeakerEncoder(
n_mels=40, num_layers=3, hidden_size=256, output_size=256)
speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path))
speaker_encoder.eval()
print("GE2E Done!")
frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
print("frontend done!")
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = config.n_fft // 2 + 1
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
vits.eval()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
input_dir = Path(args.input_dir)
if args.audio_path == "":
args.audio_path = None
if args.audio_path is None:
sentence = args.text
merge_sentences = True
add_blank = args.add_blank
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences, add_blank=add_blank)
elif args.lang == 'en':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"][0]
else:
wav, _ = librosa.load(str(args.audio_path), sr=config.fs)
feats = paddle.to_tensor(spec_extractor.get_linear_spectrogram(wav))
mel_sequences = p.extract_mel_partials(
p.preprocess_wav(args.audio_path))
with paddle.no_grad():
spk_emb_src = speaker_encoder.embed_utterance(
paddle.to_tensor(mel_sequences))
for name in os.listdir(input_dir):
utt_id = name.split(".")[0]
ref_audio_path = input_dir / name
mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path))
# print("mel_sequences: ", mel_sequences.shape)
with paddle.no_grad():
spk_emb = speaker_encoder.embed_utterance(
paddle.to_tensor(mel_sequences))
# print("spk_emb shape: ", spk_emb.shape)
with paddle.no_grad():
if args.audio_path is None:
out = vits.inference(text=phone_ids, spembs=spk_emb)
else:
out = vits.voice_conversion(
feats=feats, spembs_src=spk_emb_src, spembs_tgt=spk_emb)
wav = out["wav"]
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
samplerate=config.fs)
print(f"{utt_id} done!")
# Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
random_spk_emb = np.random.rand(256) * 0.2
random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32')
utt_id = "random_spk_emb"
with paddle.no_grad():
if args.audio_path is None:
out = vits.inference(text=phone_ids, spembs=random_spk_emb)
else:
out = vits.voice_conversion(
feats=feats, spembs_src=spk_emb_src, spembs_tgt=random_spk_emb)
wav = out["wav"]
sf.write(
str(output_dir / (utt_id + ".wav")), wav.numpy(), samplerate=config.fs)
print(f"{utt_id} done!")
def parse_args():
# parse args and config
parser = argparse.ArgumentParser(description="")
parser.add_argument(
'--config', type=str, default=None, help='Config of VITS.')
parser.add_argument(
'--ckpt', type=str, default=None, help='Checkpoint file of VITS.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--text",
type=str,
default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。",
help="text to synthesize, a line")
parser.add_argument(
'--lang',
type=str,
default='zh',
help='Choose model language. zh or en')
parser.add_argument(
"--audio-path",
type=str,
default=None,
help="audio as content to synthesize")
parser.add_argument(
"--ge2e_params_path", type=str, help="ge2e params path.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
parser.add_argument(
"--input-dir",
type=str,
help="input dir of *.wav, the sample rate will be resample to 16k.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--add-blank",
type=str2bool,
default=True,
help="whether to add blank between phones")
args = parser.parse_args()
return args
def main():
args = parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
voice_cloning(args)
if __name__ == "__main__":
main()

@ -522,6 +522,82 @@ class VITSGenerator(nn.Layer):
return wav.squeeze(1), attn.squeeze(1), dur.squeeze(1)
def voice_conversion(
self,
feats: paddle.Tensor=None,
feats_lengths: paddle.Tensor=None,
sids_src: Optional[paddle.Tensor]=None,
sids_tgt: Optional[paddle.Tensor]=None,
spembs_src: Optional[paddle.Tensor]=None,
spembs_tgt: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Run voice conversion.
Args:
feats (Tensor): Feature tensor (B, aux_channels, T_feats,).
feats_lengths (Tensor): Feature length tensor (B,).
sids_src (Optional[Tensor]): Speaker index tensor of source feature (B,) or (B, 1).
sids_tgt (Optional[Tensor]): Speaker index tensor of target feature (B,) or (B, 1).
spembs_src (Optional[Tensor]): Speaker embedding tensor of source feature (B, spk_embed_dim).
spembs_tgt (Optional[Tensor]): Speaker embedding tensor of target feature (B, spk_embed_dim).
lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
Returns:
Tensor: Generated waveform tensor (B, T_wav).
"""
# encoder
g_src = None
g_tgt = None
if self.spks is not None:
# (B, global_channels, 1)
g_src = self.global_emb(
paddle.reshape(sids_src, [-1])).unsqueeze(-1)
g_tgt = self.global_emb(
paddle.reshape(sids_tgt, [-1])).unsqueeze(-1)
if self.spk_embed_dim is not None:
# (B, global_channels, 1)
g_src_ = self.spemb_proj(
F.normalize(spembs_src.unsqueeze(0))).unsqueeze(-1)
if g_src is None:
g_src = g_src_
else:
g_src = g_src + g_src_
# (B, global_channels, 1)
g_tgt_ = self.spemb_proj(
F.normalize(spembs_tgt.unsqueeze(0))).unsqueeze(-1)
if g_tgt is None:
g_tgt = g_tgt_
else:
g_tgt = g_tgt + g_tgt_
if self.langs is not None:
# (B, global_channels, 1)
g_ = self.lang_emb(paddle.reshape(lids, [-1])).unsqueeze(-1)
if g_src is None:
g_src = g_
else:
g_src = g_src + g_
if g_tgt is None:
g_tgt = g_
else:
g_tgt = g_tgt + g_
# forward posterior encoder
z, m_q, logs_q, y_mask = self.posterior_encoder(
feats, feats_lengths, g=g_src)
# forward flow
# (B, H, T_feats)
z_p = self.flow(z, y_mask, g=g_src)
# decoder
z_hat = self.flow(z_p, y_mask, g=g_tgt, inverse=True)
wav = self.decoder(z_hat * y_mask, g=g_tgt)
return wav.squeeze(1)
def _generate_path(self, dur: paddle.Tensor,
mask: paddle.Tensor) -> paddle.Tensor:
"""Generate path a.k.a. monotonic attention.

@ -381,7 +381,7 @@ class VITS(nn.Layer):
if use_teacher_forcing:
assert feats is not None
feats = feats[None].transpose([0, 2, 1])
feats_lengths = paddle.to_tensor([paddle.shape(feats)[2]])
feats_lengths = paddle.to_tensor(paddle.shape(feats)[2])
wav, att_w, dur = self.generator.inference(
text=text,
text_lengths=text_lengths,
@ -406,3 +406,43 @@ class VITS(nn.Layer):
max_len=max_len, )
return dict(
wav=paddle.reshape(wav, [-1]), att_w=att_w[0], duration=dur[0])
def voice_conversion(
self,
feats: paddle.Tensor,
sids_src: Optional[paddle.Tensor]=None,
sids_tgt: Optional[paddle.Tensor]=None,
spembs_src: Optional[paddle.Tensor]=None,
spembs_tgt: Optional[paddle.Tensor]=None,
lids: Optional[paddle.Tensor]=None, ) -> paddle.Tensor:
"""Run voice conversion.
Args:
feats (Tensor): Feature tensor (T_feats, aux_channels).
sids_src (Optional[Tensor]): Speaker index tensor of source feature (1,).
sids_tgt (Optional[Tensor]): Speaker index tensor of target feature (1,).
spembs_src (Optional[Tensor]): Speaker embedding tensor of source feature (spk_embed_dim,).
spembs_tgt (Optional[Tensor]): Speaker embedding tensor of target feature (spk_embed_dim,).
lids (Optional[Tensor]): Language index tensor (1,).
Returns:
Dict[str, Tensor]:
* wav (Tensor): Generated waveform tensor (T_wav,).
"""
assert feats is not None
feats = feats[None].transpose([0, 2, 1])
feats_lengths = paddle.to_tensor(paddle.shape(feats)[2])
sids_none = sids_src is None and sids_tgt is None
spembs_none = spembs_src is None and spembs_tgt is None
assert not sids_none or not spembs_none
wav = self.generator.voice_conversion(
feats,
feats_lengths,
sids_src,
sids_tgt,
spembs_src,
spembs_tgt,
lids, )
return dict(wav=paddle.reshape(wav, [-1]))

@ -111,6 +111,8 @@ class VITSUpdater(StandardUpdater):
text_lengths=batch["text_lengths"],
feats=batch["feats"],
feats_lengths=batch["feats_lengths"],
sids=batch.get("spk_id", None),
spembs=batch.get("spk_emb", None),
forward_generator=turn == "generator")
# Generator
if turn == "generator":
@ -268,6 +270,8 @@ class VITSEvaluator(StandardEvaluator):
text_lengths=batch["text_lengths"],
feats=batch["feats"],
feats_lengths=batch["feats_lengths"],
sids=batch.get("spk_id", None),
spembs=batch.get("spk_emb", None),
forward_generator=turn == "generator")
# Generator
if turn == "generator":

Loading…
Cancel
Save