parent fb238d83f4
commit c088b9a304
@@ -0,0 +1,264 @@
# FastSpeech2 with CSMSC
This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2006.04558) model with the [Chinese Standard Mandarin Speech Corpus](https://www.data-baker.com/open_source.html).

## Dataset
### Download and Extract
Download CSMSC from its [Official Website](https://test.data-baker.com/data/index/source).

### Get MFA Result and Extract
We use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for FastSpeech2.
You can download the result from here [baker_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz), or train your own MFA model following the [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) in our repo.
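For example, a minimal sketch of fetching and unpacking the precomputed alignments (the archive is assumed to unpack to `./baker_alignment_tone`, the path used below):
```bash
wget https://paddlespeech.bj.bcebos.com/MFA/BZNSYP/with_tone/baker_alignment_tone.tar.gz
tar xzvf baker_alignment_tone.tar.gz
```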

## Get Started
Assume the path to the dataset is `~/datasets/BZNSYP`.
Assume the path to the MFA result of CSMSC is `./baker_alignment_tone`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
    - synthesize waveform from `metadata.jsonl`.
    - synthesize waveform from a text file.
5. inference using the static model.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage. For example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.

```text
dump
├── dev
│   ├── norm
│   └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
    ├── energy_stats.npy
    ├── norm
    ├── pitch_stats.npy
    ├── raw
    └── speech_stats.npy
```
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the speech, pitch, and energy features of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize features are computed from the training set and saved in `dump/train/*_stats.npy`.

Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, speech_lengths, durations, the paths of the speech, pitch, and energy features, the speaker, and the id of each utterance.
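A quick way to sanity-check the preprocessing output is to read one record back. A minimal sketch using `jsonlines` (the same library the preprocessing script itself uses), with the field names described above:
```python
import jsonlines

# read the first record of the normalized training metadata
with jsonlines.open("dump/train/norm/metadata.jsonl") as reader:
    record = next(iter(reader))
print(record["utt_id"], record["text_lengths"], record["speech_lengths"])
```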

### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
                [--ngpu NGPU] [--phones-dict PHONES_DICT]
                [--speaker-dict SPEAKER_DICT] [--voice-cloning VOICE_CLONING]

Train a FastSpeech2 model.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       fastspeech2 config file.
  --train-metadata TRAIN_METADATA
                        training data.
  --dev-metadata DEV_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu=0, use cpu.
  --phones-dict PHONES_DICT
                        phone vocabulary file.
  --speaker-dict SPEAKER_DICT
                        speaker id map file for multiple speaker model.
  --voice-cloning VOICE_CLONING
                        whether training voice cloning model.
```
1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata files in the normalized subfolders of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, use cpu.
5. `--phones-dict` is the path of the phone vocabulary file.
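
If you want to bypass the wrapper, the direct invocation below is equivalent to what `./local/train.sh` does (the paths assume the default `dump` layout produced above; see `local/train.sh` in this commit):
```bash
python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=conf/default.yaml \
    --output-dir=exp/default \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
```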

### Synthesizing
We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/voc1) as the neural vocoder.
Download the pretrained parallel wavegan model from [pwg_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_baker_ckpt_0.4.zip) and unzip it.
```bash
unzip pwg_baker_ckpt_0.4.zip
```
The Parallel WaveGAN checkpoint contains the files listed below.
```text
pwg_baker_ckpt_0.4
├── pwg_default.yaml              # default config used to train parallel wavegan
├── pwg_snapshot_iter_400000.pdz  # model parameters of parallel wavegan
└── pwg_stats.npy                 # statistics used to normalize spectrogram when training parallel wavegan
```
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h]
                     [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
                     [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                     [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                     [--tones_dict TONES_DICT] [--speaker_dict SPEAKER_DICT]
                     [--voice-cloning VOICE_CLONING]
                     [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
                     [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                     [--voc_stat VOC_STAT] [--ngpu NGPU]
                     [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]

Synthesize with acoustic model & vocoder

optional arguments:
  -h, --help            show this help message and exit
  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
                        Choose acoustic model type of tts task.
  --am_config AM_CONFIG
                        Config of acoustic model. Use default config when it is
                        None.
  --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
  --am_stat AM_STAT     mean and standard deviation used to normalize
                        spectrogram when training acoustic model.
  --phones_dict PHONES_DICT
                        phone vocabulary file.
  --tones_dict TONES_DICT
                        tone vocabulary file.
  --speaker_dict SPEAKER_DICT
                        speaker id map file.
  --voice-cloning VOICE_CLONING
                        whether training voice cloning model.
  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
                        Choose vocoder type of tts task.
  --voc_config VOC_CONFIG
                        Config of voc. Use default config when it is None.
  --voc_ckpt VOC_CKPT   Checkpoint file of voc.
  --voc_stat VOC_STAT   mean and standard deviation used to normalize
                        spectrogram when training voc.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --test_metadata TEST_METADATA
                        test metadata.
  --output_dir OUTPUT_DIR
                        output dir.
```
`./local/synthesize_e2e.sh` calls `${BIN_DIR}/../synthesize_e2e.py`, which can synthesize waveform from a text file.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize_e2e.py [-h]
                         [--am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}]
                         [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                         [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                         [--tones_dict TONES_DICT]
                         [--speaker_dict SPEAKER_DICT] [--spk_id SPK_ID]
                         [--voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}]
                         [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                         [--voc_stat VOC_STAT] [--lang LANG]
                         [--inference_dir INFERENCE_DIR] [--ngpu NGPU]
                         [--text TEXT] [--output_dir OUTPUT_DIR]

Synthesize with acoustic model & vocoder

optional arguments:
  -h, --help            show this help message and exit
  --am {speedyspeech_csmsc,fastspeech2_csmsc,fastspeech2_ljspeech,fastspeech2_aishell3,fastspeech2_vctk}
                        Choose acoustic model type of tts task.
  --am_config AM_CONFIG
                        Config of acoustic model. Use default config when it is
                        None.
  --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
  --am_stat AM_STAT     mean and standard deviation used to normalize
                        spectrogram when training acoustic model.
  --phones_dict PHONES_DICT
                        phone vocabulary file.
  --tones_dict TONES_DICT
                        tone vocabulary file.
  --speaker_dict SPEAKER_DICT
                        speaker id map file.
  --spk_id SPK_ID       spk id for multi speaker acoustic model
  --voc {pwgan_csmsc,pwgan_ljspeech,pwgan_aishell3,pwgan_vctk,mb_melgan_csmsc}
                        Choose vocoder type of tts task.
  --voc_config VOC_CONFIG
                        Config of voc. Use default config when it is None.
  --voc_ckpt VOC_CKPT   Checkpoint file of voc.
  --voc_stat VOC_STAT   mean and standard deviation used to normalize
                        spectrogram when training voc.
  --lang LANG           Choose model language. zh or en
  --inference_dir INFERENCE_DIR
                        dir to save inference models
  --ngpu NGPU           if ngpu == 0, use cpu.
  --text TEXT           text to synthesize, a 'utt_id sentence' pair per line.
  --output_dir OUTPUT_DIR
                        output dir.
```
1. `--am` is the acoustic model type with the format {model_name}_{dataset}.
2. `--am_config`, `--am_ckpt`, `--am_stat`, and `--phones_dict` are arguments for the acoustic model, which correspond to the 4 files in the fastspeech2 pretrained model.
3. `--voc` is the vocoder type with the format {model_name}_{dataset}.
4. `--voc_config`, `--voc_ckpt`, and `--voc_stat` are arguments for the vocoder, which correspond to the 3 files in the parallel wavegan pretrained model.
5. `--lang` is the model language, which can be `zh` or `en`.
6. `--test_metadata` should be the metadata file in the normalized subfolder of `test` in the `dump` folder.
7. `--text` is the text file, which contains sentences to synthesize.
8. `--output_dir` is the directory to save synthesized audio files.
9. `--ngpu` is the number of gpus to use; if ngpu == 0, use cpu.
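
The `--text` file contains one `utt_id sentence` pair per line, for example (illustrative placeholder lines, not the shipped `sentences.txt`):
```text
001 欢迎使用飞桨语音合成系统。
002 今天天气很好。
```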
|
||||
|
||||
### Inferencing
|
||||
After synthesizing, we will get static models of fastspeech2 and pwgan in `${train_output_path}/inference`.
|
||||
`./local/inference.sh` calls `${BIN_DIR}/inference.py`, which provides a paddle static model inference example for fastspeech2 + pwgan synthesize.
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
|
||||
```
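`inference.py` is the reference implementation, but as a rough sketch of what static-model inference looks like with the Paddle Inference API: the model/param file names and the single int64 phone-id input below are assumptions, so check the files in `${train_output_path}/inference` and `${BIN_DIR}/inference.py` for the exact names.
```python
import numpy as np
from paddle.inference import Config, create_predictor

# load the exported static acoustic model (file names are assumed)
config = Config("exp/default/inference/fastspeech2_csmsc.pdmodel",
                "exp/default/inference/fastspeech2_csmsc.pdiparams")
predictor = create_predictor(config)

# feed a batch of phone ids (dummy values for illustration)
phone_ids = np.array([[1, 2, 3, 4, 5]], dtype=np.int64)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
input_handle.copy_from_cpu(phone_ids)
predictor.run()

# fetch the predicted mel spectrogram, to be fed to the vocoder's static model
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
mel = output_handle.copy_to_cpu()
```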

## Pretrained Model
Pretrained FastSpeech2 models with no silence at the edges of audios:
- [fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)
- [fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)

The static model can be downloaded here: [fastspeech2_nosil_baker_static_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_static_0.4.zip).

Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss | eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
default| 2(gpu) x 76000|1.0991|0.59132|0.035815|0.31915|0.15287
conformer| 2(gpu) x 76000|1.0675|0.56103|0.035869|0.31553|0.15509

The FastSpeech2 checkpoint contains the files listed below.
```text
fastspeech2_nosil_baker_ckpt_0.4
├── default.yaml             # default config used to train fastspeech2
├── phone_id_map.txt         # phone vocabulary file when training fastspeech2
├── snapshot_iter_76000.pdz  # model parameters and optimizer states
└── speech_stats.npy         # statistics used to normalize spectrogram when training fastspeech2
```
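If you want to inspect a downloaded `.pdz` checkpoint programmatically, it can be opened with `paddle.load`. A minimal sketch; the exact dictionary keys are an assumption, so print them to confirm:
```python
import paddle

# .pdz snapshots bundle model parameters and optimizer state (see the tree above)
archive = paddle.load("fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz")
print(list(archive.keys()))  # assumed to contain entries like "main_params"
```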
You can use the following script to synthesize `${BIN_DIR}/../sentences.txt` using the pretrained fastspeech2 and parallel wavegan models.
```bash
source path.sh

FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_e2e.py \
    --am=fastspeech2_csmsc \
    --am_config=fastspeech2_nosil_baker_ckpt_0.4/default.yaml \
    --am_ckpt=fastspeech2_nosil_baker_ckpt_0.4/snapshot_iter_76000.pdz \
    --am_stat=fastspeech2_nosil_baker_ckpt_0.4/speech_stats.npy \
    --voc=pwgan_csmsc \
    --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
    --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
    --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
    --lang=zh \
    --text=${BIN_DIR}/../sentences.txt \
    --output_dir=exp/default/test_e2e \
    --inference_dir=exp/default/inference \
    --phones_dict=fastspeech2_nosil_baker_ckpt_0.4/phone_id_map.txt
```
@@ -0,0 +1,95 @@
# This configuration is for Paddle to train Tacotron 2. Compared to the
# original paper, this configuration additionally uses the guided attention
# loss to accelerate the learning of the diagonal attention. It requires
# only a single GPU with 12 GB memory and takes ~1 day to finish the
# training on a Titan V.

###########################################################
#                FEATURE EXTRACTION SETTING                #
###########################################################

fs: 24000          # sr
n_fft: 2048        # FFT size (samples).
n_shift: 300       # Hop size (samples). 12.5ms
win_length: 1200   # Window length (samples). 50ms
                   # If set to null, it will be the same as fft_size.
window: "hann"     # Window function.

# Only used for feats_type != raw

fmin: 80           # Minimum frequency of Mel basis.
fmax: 7600         # Maximum frequency of Mel basis.
n_mels: 80         # The number of mel basis.

# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80          # Minimum f0 for pitch extraction.
f0max: 400         # Maximum f0 for pitch extraction.
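# Note on the frame math above: n_shift 300 at fs 24000 is 300/24000 = 12.5 ms
# per hop (80 frames per second), and win_length 1200 is 1200/24000 = 50 ms,
# matching the inline comments.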

###########################################################
#                       DATA SETTING                      #
###########################################################
batch_size: 64
num_workers: 2

###########################################################
#                       MODEL SETTING                     #
###########################################################
model:                       # keyword arguments for the selected model
    embed_dim: 512           # char or phn embedding dimension
    elayers: 1               # number of blstm layers in encoder
    eunits: 512              # number of blstm units
    econv_layers: 3          # number of convolutional layers in encoder
    econv_chans: 512         # number of channels in convolutional layer
    econv_filts: 5           # filter size of convolutional layer
    atype: location          # attention function type
    adim: 512                # attention dimension
    aconv_chans: 32          # number of channels in convolutional layer of attention
    aconv_filts: 15          # filter size of convolutional layer of attention
    cumulate_att_w: True     # whether to cumulate attention weight
    dlayers: 2               # number of lstm layers in decoder
    dunits: 1024             # number of lstm units in decoder
    prenet_layers: 2         # number of layers in prenet
    prenet_units: 256        # number of units in prenet
    postnet_layers: 5        # number of layers in postnet
    postnet_chans: 512       # number of channels in postnet
    postnet_filts: 5         # filter size of postnet layer
    output_activation: null  # activation function for the final output
    use_batch_norm: True     # whether to use batch normalization in encoder
    use_concate: True        # whether to concatenate encoder embedding with decoder outputs
    use_residual: False      # whether to use residual connection in encoder
    dropout_rate: 0.5        # dropout rate
    zoneout_rate: 0.1        # zoneout rate
    reduction_factor: 1      # reduction factor
    spk_embed_dim: null      # speaker embedding dimension


###########################################################
#                     UPDATER SETTING                     #
###########################################################
updater:
    use_masking: True             # whether to apply masking for padded part in loss calculation
    bce_pos_weight: 5.0           # weight of positive sample in binary cross entropy calculation
    use_guided_attn_loss: True    # whether to use guided attention loss
    guided_attn_loss_sigma: 0.4   # sigma of guided attention loss
    guided_attn_loss_lambda: 1.0  # strength of guided attention loss


##########################################################
#                   OPTIMIZER SETTING                    #
##########################################################
optimizer:
    optim: adam             # optimizer type
    learning_rate: 1.0e-03  # learning rate
    epsilon: 1.0e-06        # epsilon
    weight_decay: 0.0       # weight decay coefficient

###########################################################
#                     TRAINING SETTING                    #
###########################################################
max_epoch: 200
num_snapshots: 5

###########################################################
#                      OTHER SETTING                      #
###########################################################
seed: 42
@@ -0,0 +1,62 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./baker_alignment_tone \
        --output=durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=baker \
        --rootdir=~/datasets/BZNSYP/ \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --num-cpu=20 \
        --cut-sil=True
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone to id, dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi
@@ -0,0 +1,20 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3

FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
    --am=tacotron2_csmsc \
    --am_config=${config_path} \
    --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
    --am_stat=dump/train/speech_stats.npy \
    --voc=pwgan_csmsc \
    --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
    --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
    --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
    --test_metadata=dump/test/norm/metadata.jsonl \
    --output_dir=${train_output_path}/test \
    --phones_dict=dump/phone_id_map.txt
@@ -0,0 +1,91 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3

stage=0
stop_stage=0

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=tacotron2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=pwgan_csmsc \
        --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
        --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
        --voc_stat=pwg_baker_ckpt_0.4/pwg_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --inference_dir=${train_output_path}/inference \
        --phones_dict=dump/phone_id_map.txt
fi

# for more GAN Vocoders
# multi band melgan
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=mb_melgan_csmsc \
        --voc_config=mb_melgan_baker_finetune_ckpt_0.5/finetune.yaml \
        --voc_ckpt=mb_melgan_baker_finetune_ckpt_0.5/snapshot_iter_2000000.pdz \
        --voc_stat=mb_melgan_baker_finetune_ckpt_0.5/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --inference_dir=${train_output_path}/inference \
        --phones_dict=dump/phone_id_map.txt
fi

# the pretrained models haven't been released yet
# style melgan
# style melgan's dygraph-to-static conversion is not ready yet
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=style_melgan_csmsc \
        --voc_config=style_melgan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=style_melgan_csmsc_ckpt_0.1.1/snapshot_iter_1500000.pdz \
        --voc_stat=style_melgan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --phones_dict=dump/phone_id_map.txt
        # --inference_dir=${train_output_path}/inference
fi

# hifigan
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "in hifigan syn_e2e"
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize_e2e.py \
        --am=fastspeech2_csmsc \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=hifigan_csmsc \
        --voc_config=hifigan_csmsc_ckpt_0.1.1/default.yaml \
        --voc_ckpt=hifigan_csmsc_ckpt_0.1.1/snapshot_iter_2500000.pdz \
        --voc_stat=hifigan_csmsc_ckpt_0.1.1/feats_stats.npy \
        --lang=zh \
        --text=${BIN_DIR}/../sentences.txt \
        --output_dir=${train_output_path}/test_e2e \
        --inference_dir=${train_output_path}/inference \
        --phones_dict=dump/phone_id_map.txt
fi
@@ -0,0 +1,12 @@
#!/bin/bash

config_path=$1
train_output_path=$2

python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
@@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=new_tacotron2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@@ -0,0 +1,37 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz

# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with positional arguments `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` files are saved under `train_output_path/checkpoints/`
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # synthesize_e2e, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@@ -0,0 +1 @@
../transformer_tts/normalize.py
@@ -0,0 +1,353 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.data.get_feats import Energy
from paddlespeech.t2s.data.get_feats import LogMelFBank
from paddlespeech.t2s.data.get_feats import Pitch
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence


def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     sentences: Dict,
                     output_dir: Path,
                     mel_extractor=None,
                     pitch_extractor=None,
                     energy_extractor=None,
                     cut_sil: bool=True,
                     spk_emb_dir: Path=None):
    utt_id = fp.stem
    # for vctk
    if utt_id.endswith("_mic2"):
        utt_id = utt_id[:-5]
    record = None
    if utt_id in sentences:
        # reading, resampling may occur
        wav, _ = librosa.load(str(fp), sr=config.fs)
        # skip audio that is not mono-channel, or whose samples fall outside
        # [-1.0, 1.0] (i.e. seems to be different from 16 bit PCM)
        if len(wav.shape) != 1 or np.abs(wav).max() > 1.0:
            return record
        phones = sentences[utt_id][0]
        durations = sentences[utt_id][1]
        speaker = sentences[utt_id][2]
        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
        # slightly less precise than using *.TextGrid directly
        times = librosa.frames_to_time(
            d_cumsum, sr=config.fs, hop_length=config.n_shift)
        if cut_sil:
            start = 0
            end = d_cumsum[-1]
            if phones[0] == "sil" and len(durations) > 1:
                start = times[1]
                durations = durations[1:]
                phones = phones[1:]
            if phones[-1] == 'sil' and len(durations) > 1:
                end = times[-2]
                durations = durations[:-1]
                phones = phones[:-1]
            sentences[utt_id][0] = phones
            sentences[utt_id][1] = durations
            start, end = librosa.time_to_samples([start, end], sr=config.fs)
            wav = wav[start:end]
        # extract mel feats
        logmel = mel_extractor.get_log_mel_fbank(wav)
        # change duration according to mel_length
        compare_duration_and_mel_length(sentences, utt_id, logmel)
        phones = sentences[utt_id][0]
        durations = sentences[utt_id][1]
        num_frames = logmel.shape[0]
        assert sum(durations) == num_frames
        mel_dir = output_dir / "data_speech"
        mel_dir.mkdir(parents=True, exist_ok=True)
        mel_path = mel_dir / (utt_id + "_speech.npy")
        np.save(mel_path, logmel)
        record = {
            "utt_id": utt_id,
            "phones": phones,
            "text_lengths": len(phones),
            "speech_lengths": num_frames,
            "speech": str(mel_path),
            "speaker": speaker
        }
        if spk_emb_dir:
            if speaker in os.listdir(spk_emb_dir):
                embed_name = utt_id + ".npy"
                embed_path = spk_emb_dir / speaker / embed_name
                if embed_path.is_file():
                    record["spk_emb"] = str(embed_path)
                else:
                    return None
    return record


def process_sentences(config,
                      fps: List[Path],
                      sentences: Dict,
                      output_dir: Path,
                      mel_extractor=None,
                      pitch_extractor=None,
                      energy_extractor=None,
                      nprocs: int=1,
                      cut_sil: bool=True,
                      spk_emb_dir: Path=None):
    if nprocs == 1:
        results = []
        for fp in fps:
            record = process_sentence(config, fp, sentences, output_dir,
                                      mel_extractor, pitch_extractor,
                                      energy_extractor, cut_sil, spk_emb_dir)
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(process_sentence, config, fp,
                                         sentences, output_dir, mel_extractor,
                                         pitch_extractor, energy_extractor,
                                         cut_sil, spk_emb_dir)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    record = ft.result()
                    if record:
                        results.append(record)

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features.")

    parser.add_argument(
        "--dataset",
        default="baker",
        type=str,
        help="name of dataset, should be in {baker, aishell3, ljspeech, vctk} now")

    parser.add_argument(
        "--rootdir", default=None, type=str, help="directory to dataset.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--dur-file", default=None, type=str, help="path to durations.txt.")

    parser.add_argument("--config", type=str, help="fastspeech2 config file.")

    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of process.")

    def str2bool(s):
        # argparse's type=bool treats any non-empty string as True,
        # so parse 'true'/'false' strings explicitly
        return s.lower() == 'true'

    parser.add_argument(
        "--cut-sil",
        type=str2bool,
        default=True,
        help="whether to cut silence at the edges of audio")

    parser.add_argument(
        "--spk_emb_dir",
        default=None,
        type=str,
        help="directory to speaker embedding files.")
    args = parser.parse_args()

    rootdir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    # use absolute path
    dumpdir = dumpdir.resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)
    dur_file = Path(args.dur_file).expanduser()

    if args.spk_emb_dir:
        spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
    else:
        spk_emb_dir = None

    assert rootdir.is_dir()
    assert dur_file.is_file()

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    if args.verbose > 1:
        print(vars(args))
        print(config)

    sentences, speaker_set = get_phn_dur(dur_file)

    merge_silence(sentences)
    phone_id_map_path = dumpdir / "phone_id_map.txt"
    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
    get_input_token(sentences, phone_id_map_path, args.dataset)
    get_spk_id_map(speaker_set, speaker_id_map_path)

    if args.dataset == "baker":
        wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
        # split data into 3 sections
        num_train = 9800
        num_dev = 100
        train_wav_files = wav_files[:num_train]
        dev_wav_files = wav_files[num_train:num_train + num_dev]
        test_wav_files = wav_files[num_train + num_dev:]
    elif args.dataset == "aishell3":
        sub_num_dev = 5
        wav_dir = rootdir / "train" / "wav"
        train_wav_files = []
        dev_wav_files = []
        test_wav_files = []
        for speaker in os.listdir(wav_dir):
            wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
            if len(wav_files) > 100:
                train_wav_files += wav_files[:-sub_num_dev * 2]
                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
                test_wav_files += wav_files[-sub_num_dev:]
            else:
                train_wav_files += wav_files
    elif args.dataset == "ljspeech":
        wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
        # split data into 3 sections
        num_train = 12900
        num_dev = 100
        train_wav_files = wav_files[:num_train]
        dev_wav_files = wav_files[num_train:num_train + num_dev]
        test_wav_files = wav_files[num_train + num_dev:]
    elif args.dataset == "vctk":
        sub_num_dev = 5
        wav_dir = rootdir / "wav48_silence_trimmed"
        train_wav_files = []
        dev_wav_files = []
        test_wav_files = []
        for speaker in os.listdir(wav_dir):
            wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
            if len(wav_files) > 100:
                train_wav_files += wav_files[:-sub_num_dev * 2]
                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
                test_wav_files += wav_files[-sub_num_dev:]
            else:
                train_wav_files += wav_files
    else:
        # fail fast: the split lists below would be undefined otherwise
        raise ValueError(
            "dataset should be in {baker, aishell3, ljspeech, vctk} now!")

    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test" / "raw"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # Extractor
    mel_extractor = LogMelFBank(
        sr=config.fs,
        n_fft=config.n_fft,
        hop_length=config.n_shift,
        win_length=config.win_length,
        window=config.window,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)
    pitch_extractor = Pitch(
        sr=config.fs,
        hop_length=config.n_shift,
        f0min=config.f0min,
        f0max=config.f0max)
    energy_extractor = Energy(
        sr=config.fs,
        n_fft=config.n_fft,
        hop_length=config.n_shift,
        win_length=config.win_length,
        window=config.window)

    # process for the 3 sections
    if train_wav_files:
        process_sentences(
            config,
            train_wav_files,
            sentences,
            train_dump_dir,
            mel_extractor,
            pitch_extractor,
            energy_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir)
    if dev_wav_files:
        process_sentences(
            config,
            dev_wav_files,
            sentences,
            dev_dump_dir,
            mel_extractor,
            pitch_extractor,
            energy_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir)
    if test_wav_files:
        process_sentences(
            config,
            test_wav_files,
            sentences,
            test_dump_dir,
            mel_extractor,
            pitch_extractor,
            energy_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir)


if __name__ == "__main__":
    main()
@@ -0,0 +1,190 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import tacotron2_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.new_tacotron2 import Tacotron2
from paddlespeech.t2s.models.new_tacotron2 import Tacotron2Evaluator
from paddlespeech.t2s.models.new_tacotron2 import Tacotron2Updater
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer


def train_sp(args, config):
    # decides device type and whether to run in parallel
    # setup running environment correctly
    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # the dataloader is too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=[
            "text",
            "text_lengths",
            "speech",
            "speech_lengths",
        ],
        converters={
            "speech": np.load,
        }, )
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=[
            "text",
            "text_lengths",
            "speech",
            "speech_lengths",
        ],
        converters={
            "speech": np.load,
        }, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)

    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=tacotron2_single_spk_batch_fn,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        drop_last=False,
        batch_size=config.batch_size,
        collate_fn=tacotron2_single_spk_batch_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")

    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    odim = config.n_mels
    model = Tacotron2(idim=vocab_size, odim=odim, **config["model"])
    if world_size > 1:
        model = DataParallel(model)
    print("model done!")

    optimizer = build_optimizers(model, **config["optimizer"])
    print("optimizer done!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if dist.get_rank() == 0:
        config_name = args.config.split("/")[-1]
        # copy conf to output_dir
        shutil.copyfile(args.config, output_dir / config_name)

    updater = Tacotron2Updater(
        model=model,
        optimizer=optimizer,
        dataloader=train_dataloader,
        output_dir=output_dir,
        **config["updater"])

    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)

    evaluator = Tacotron2Evaluator(
        model, dev_dataloader, output_dir=output_dir, **config["updater"])

    if dist.get_rank() == 0:
        trainer.extend(evaluator, trigger=(1, "epoch"))
        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
    # print(trainer.extensions)
    trainer.run()


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a Tacotron2 model.")
    parser.add_argument("--config", type=str, help="tacotron2 config file.")
    parser.add_argument("--train-metadata", type=str, help="training data.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")

    args = parser.parse_args()

    with open(args.config) as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.ngpu > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
@@ -0,0 +1,15 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .tacotron2 import *
from .tacotron2_updater import *
@ -0,0 +1,496 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tacotron 2 related modules for paddle"""
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
from paddle import nn
|
||||
from typeguard import check_argument_types
|
||||
|
||||
from paddlespeech.t2s.modules.nets_utils import initialize
|
||||
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
|
||||
from paddlespeech.t2s.modules.tacotron2.attentions import AttForward
|
||||
from paddlespeech.t2s.modules.tacotron2.attentions import AttForwardTA
|
||||
from paddlespeech.t2s.modules.tacotron2.attentions import AttLoc
|
||||
from paddlespeech.t2s.modules.tacotron2.decoder import Decoder
|
||||
from paddlespeech.t2s.modules.tacotron2.encoder import Encoder
|
||||
|
||||
|
||||
class Tacotron2(nn.Layer):
|
||||
"""Tacotron2 module for end-to-end text-to-speech.
|
||||
|
||||
This is a module of Spectrogram prediction network in Tacotron2 described
|
||||
in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
|
||||
which converts the sequence of characters into the sequence of Mel-filterbanks.
|
||||
|
||||
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
|
||||
https://arxiv.org/abs/1712.05884
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# network structure related
|
||||
idim: int,
|
||||
odim: int,
|
||||
embed_dim: int=512,
|
||||
elayers: int=1,
|
||||
eunits: int=512,
|
||||
econv_layers: int=3,
|
||||
econv_chans: int=512,
|
||||
econv_filts: int=5,
|
||||
atype: str="location",
|
||||
adim: int=512,
|
||||
aconv_chans: int=32,
|
||||
aconv_filts: int=15,
|
||||
cumulate_att_w: bool=True,
|
||||
dlayers: int=2,
|
||||
dunits: int=1024,
|
||||
prenet_layers: int=2,
|
||||
prenet_units: int=256,
|
||||
postnet_layers: int=5,
|
||||
postnet_chans: int=512,
|
||||
postnet_filts: int=5,
|
||||
output_activation: str=None,
|
||||
use_batch_norm: bool=True,
|
||||
use_concate: bool=True,
|
||||
use_residual: bool=False,
|
||||
reduction_factor: int=1,
|
||||
# extra embedding related
|
||||
spk_num: Optional[int]=None,
|
||||
lang_num: Optional[int]=None,
|
||||
spk_embed_dim: Optional[int]=None,
|
||||
spk_embed_integration_type: str="concat",
|
||||
dropout_rate: float=0.5,
|
||||
zoneout_rate: float=0.1,
|
||||
# training related
|
||||
init_type: str="xavier_uniform",):
|
||||
"""Initialize Tacotron2 module.
|
||||
Parameters
|
||||
----------
|
||||
idim : int
|
||||
Dimension of the inputs.
|
||||
odim : int
|
||||
Dimension of the outputs.
|
||||
embed_dim : int
|
||||
Dimension of the token embedding.
|
||||
elayers : int
|
||||
Number of encoder blstm layers.
|
||||
eunits : int
|
||||
Number of encoder blstm units.
|
||||
econv_layers : int
|
||||
Number of encoder conv layers.
|
||||
econv_filts : int
|
||||
Number of encoder conv filter size.
|
||||
econv_chans : int
|
||||
Number of encoder conv filter channels.
|
||||
dlayers : int
|
||||
Number of decoder lstm layers.
|
||||
dunits : int
|
||||
Number of decoder lstm units.
|
||||
prenet_layers : int
|
||||
Number of prenet layers.
|
||||
prenet_units : int
|
||||
Number of prenet units.
|
||||
postnet_layers : int
|
||||
Number of postnet layers.
|
||||
postnet_filts : int
|
||||
Number of postnet filter size.
|
||||
postnet_chans : int
|
||||
Number of postnet filter channels.
|
||||
output_activation : str
|
||||
Name of activation function for outputs.
|
||||
adim : int
|
||||
Number of dimension of mlp in attention.
|
||||
aconv_chans : int
|
||||
Number of attention conv filter channels.
|
||||
aconv_filts : int
|
||||
Number of attention conv filter size.
|
||||
cumulate_att_w : bool
|
||||
Whether to cumulate previous attention weight.
|
||||
use_batch_norm : bool
|
||||
Whether to use batch normalization.
|
||||
use_concate : bool
|
||||
Whether to concat enc outputs w/ dec lstm outputs.
|
||||
reduction_factor : int
|
||||
Reduction factor.
|
||||
spk_num : Optional[int]
|
||||
Number of speakers. If set to > 1, assume that the
|
||||
sids will be provided as the input and use sid embedding layer.
|
||||
lang_num : Optional[int]
|
||||
Number of languages. If set to > 1, assume that the
|
||||
lids will be provided as the input and use sid embedding layer.
|
||||
spk_embed_dim : Optional[int]
|
||||
Speaker embedding dimension. If set to > 0,
|
||||
assume that spk_emb will be provided as the input.
|
||||
spk_embed_integration_type : str
|
||||
How to integrate speaker embedding.
|
||||
dropout_rate : float
|
||||
Dropout rate.
|
||||
zoneout_rate : float
|
||||
Zoneout rate.
|
||||
"""
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
# store hyperparameters
|
||||
self.idim = idim
|
||||
self.odim = odim
|
||||
self.eos = idim - 1
|
||||
self.cumulate_att_w = cumulate_att_w
|
||||
self.reduction_factor = reduction_factor
|
||||
|
||||
# define activation function for the final output
|
||||
if output_activation is None:
|
||||
self.output_activation_fn = None
|
||||
elif hasattr(F, output_activation):
|
||||
self.output_activation_fn = getattr(F, output_activation)
|
||||
else:
|
||||
raise ValueError(f"there is no such an activation function. "
|
||||
f"({output_activation})")
|
||||
|
||||
# set padding idx
|
||||
padding_idx = 0
|
||||
self.padding_idx = padding_idx
|
||||
|
||||
# initialize parameters
|
||||
initialize(self, init_type)
|
||||
|
||||
# define network modules
|
||||
self.enc = Encoder(
|
||||
idim=idim,
|
||||
embed_dim=embed_dim,
|
||||
elayers=elayers,
|
||||
eunits=eunits,
|
||||
econv_layers=econv_layers,
|
||||
econv_chans=econv_chans,
|
||||
econv_filts=econv_filts,
|
||||
use_batch_norm=use_batch_norm,
|
||||
use_residual=use_residual,
|
||||
dropout_rate=dropout_rate,
|
||||
padding_idx=padding_idx, )
|
||||
|
||||
self.spk_num = None
|
||||
if spk_num is not None and spk_num > 1:
|
||||
self.spk_num = spk_num
|
||||
self.sid_emb = nn.Embedding(spk_num, eunits)
|
||||
self.lang_num = None
|
||||
if lang_num is not None and lang_num > 1:
|
||||
self.lang_num = lang_num
|
||||
self.lid_emb = nn.Embedding(lang_num, eunits)
|
||||
|
||||
self.spk_embed_dim = None
|
||||
if spk_embed_dim is not None and spk_embed_dim > 0:
|
||||
self.spk_embed_dim = spk_embed_dim
|
||||
self.spk_embed_integration_type = spk_embed_integration_type
|
||||
if self.spk_embed_dim is None:
|
||||
dec_idim = eunits
|
||||
elif self.spk_embed_integration_type == "concat":
|
||||
dec_idim = eunits + spk_embed_dim
|
||||
elif self.spk_embed_integration_type == "add":
|
||||
dec_idim = eunits
|
||||
self.projection = nn.Linear(self.spk_embed_dim, eunits)
|
||||
else:
|
||||
raise ValueError(f"{spk_embed_integration_type} is not supported.")
|
||||
|
||||
if atype == "location":
|
||||
att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
|
||||
elif atype == "forward":
|
||||
att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
|
||||
if self.cumulate_att_w:
|
||||
logging.warning("cumulation of attention weights is disabled "
|
||||
"in forward attention.")
|
||||
self.cumulate_att_w = False
|
||||
elif atype == "forward_ta":
|
||||
att = AttForwardTA(dec_idim, dunits, adim, aconv_chans, aconv_filts,
|
||||
odim)
|
||||
if self.cumulate_att_w:
|
||||
logging.warning("cumulation of attention weights is disabled "
|
||||
"in forward attention.")
|
||||
self.cumulate_att_w = False
|
||||
else:
|
||||
raise NotImplementedError("Support only location or forward")
|
||||
self.dec = Decoder(
|
||||
idim=dec_idim,
|
||||
odim=odim,
|
||||
att=att,
|
||||
dlayers=dlayers,
|
||||
dunits=dunits,
|
||||
prenet_layers=prenet_layers,
|
||||
prenet_units=prenet_units,
|
||||
postnet_layers=postnet_layers,
|
||||
postnet_chans=postnet_chans,
|
||||
postnet_filts=postnet_filts,
|
||||
output_activation_fn=self.output_activation_fn,
|
||||
cumulate_att_w=self.cumulate_att_w,
|
||||
use_batch_norm=use_batch_norm,
|
||||
use_concate=use_concate,
|
||||
dropout_rate=dropout_rate,
|
||||
zoneout_rate=zoneout_rate,
|
||||
reduction_factor=reduction_factor, )
|
||||
|
||||
nn.initializer.set_global_initializer(None)
|
||||
|

    def forward(
            self,
            text: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            spk_emb: Optional[paddle.Tensor]=None,
            spk_id: Optional[paddle.Tensor]=None,
            lang_id: Optional[paddle.Tensor]=None
    ) -> Tuple[paddle.Tensor, ...]:
        """Calculate forward propagation.

        Parameters
        ----------
        text : Tensor(int64)
            Batch of padded character ids (B, T_text).
        text_lengths : Tensor(int64)
            Batch of lengths of each input batch (B,).
        speech : Tensor
            Batch of padded target features (B, T_feats, odim).
        speech_lengths : Tensor(int64)
            Batch of the lengths of each target (B,).
        spk_emb : Optional[Tensor]
            Batch of speaker embeddings (B, spk_embed_dim).
        spk_id : Optional[Tensor]
            Batch of speaker IDs (B, 1).
        lang_id : Optional[Tensor]
            Batch of language IDs (B, 1).

        Returns
        ----------
        Tensor
            Batch of outputs after postnet (B, T_feats, odim).
        Tensor
            Batch of outputs before postnet (B, T_feats, odim).
        Tensor
            Batch of stop logits (B, T_feats).
        Tensor
            Batch of (possibly length-adjusted) target features.
        Tensor
            Batch of stop labels.
        Tensor
            Batch of adjusted output lengths (B,).
        Tensor
            Attention weights.
        Tensor
            Batch of input lengths including eos (B,).

        """
        text = text[:, :text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]

        batch_size = paddle.shape(text)[0]

        # add eos at the end of each sequence
        xs = F.pad(text, [0, 0, 0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens - 1)
        # bool tensors cannot be sliced, so cast to float32 first
        labels = paddle.cast(labels, dtype='float32')
        labels = F.pad(labels, [0, 0, 0, 1], "constant", 1.0)

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spk_emb=spk_emb,
            spk_id=spk_id,
            lang_id=lang_id, )

        # trim the ground truth to a multiple of the reduction factor
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(
            ), "Output length must be greater than or equal to reduction factor."
            olens = olens - olens % self.reduction_factor
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = paddle.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)
        return after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens
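
    # Worked illustration (not part of the original file): for olens = [3, 5]
    # the stop labels built above are
    #     make_pad_mask(olens - 1)      -> [[0, 0, 1, 1], [0, 0, 0, 0]]
    #     after the final 1.0 pad column -> [[0, 0, 1, 1, 1], [0, 0, 0, 0, 1]]
    # i.e. every frame from the last valid one onward is marked "stop", which
    # is what the BCE stop-token loss is trained against.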

    def _forward(
            self,
            xs: paddle.Tensor,
            ilens: paddle.Tensor,
            ys: paddle.Tensor,
            olens: paddle.Tensor,
            spk_emb: paddle.Tensor,
            spk_id: paddle.Tensor,
            lang_id: paddle.Tensor,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:

        hs, hlens = self.enc(xs, ilens)
        if self.spk_num is not None:
            sid_embs = self.sid_emb(spk_id.reshape([-1]))
            hs = hs + sid_embs.unsqueeze(1)
        if self.lang_num is not None:
            lid_embs = self.lid_emb(lang_id.reshape([-1]))
            hs = hs + lid_embs.unsqueeze(1)
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spk_emb)

        return self.dec(hs, hlens, ys)

    def inference(
            self,
            text: paddle.Tensor,
            speech: Optional[paddle.Tensor]=None,
            spk_emb: Optional[paddle.Tensor]=None,
            spk_id: Optional[paddle.Tensor]=None,
            lang_id: Optional[paddle.Tensor]=None,
            threshold: float=0.5,
            minlenratio: float=0.0,
            maxlenratio: float=10.0,
            use_att_constraint: bool=False,
            backward_window: int=1,
            forward_window: int=3,
            use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Parameters
        ----------
        text : Tensor(int64)
            Input sequence of characters (T_text,).
        speech : Optional[Tensor]
            Feature sequence to extract style (N, idim).
        spk_emb : Optional[Tensor]
            Speaker embedding (spk_embed_dim,).
        spk_id : Optional[Tensor]
            Speaker ID (1,).
        lang_id : Optional[Tensor]
            Language ID (1,).
        threshold : float
            Threshold in inference.
        minlenratio : float
            Minimum length ratio in inference.
        maxlenratio : float
            Maximum length ratio in inference.
        use_att_constraint : bool
            Whether to apply attention constraint.
        backward_window : int
            Backward window in attention constraint.
        forward_window : int
            Forward window in attention constraint.
        use_teacher_forcing : bool
            Whether to use teacher forcing.

        Returns
        ----------
        Dict[str, Tensor]
            Output dict including the following items:
            * feat_gen (Tensor): Output sequence of features (T_feats, odim).
            * prob (Tensor): Output sequence of stop probabilities (T_feats,).
            * att_w (Tensor): Attention weights (T_feats, T).

        """
        x = text
        y = speech

        # add eos at the end of the sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # inference with teacher forcing
        if use_teacher_forcing:
            assert speech is not None, "speech must be provided with teacher forcing."

            xs, ys = x.unsqueeze(0), y.unsqueeze(0)
            spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0)
            ilens = paddle.shape(xs)[1]
            olens = paddle.shape(ys)[1]
            outs, _, _, att_ws = self._forward(
                xs=xs,
                ilens=ilens,
                ys=ys,
                olens=olens,
                spk_emb=spk_emb,
                spk_id=spk_id,
                lang_id=lang_id, )

            return dict(feat_gen=outs[0], att_w=att_ws[0])

        # free-running inference
        h = self.enc.inference(x)
        if self.spk_num is not None:
            sid_emb = self.sid_emb(spk_id.reshape([-1]))
            h = h + sid_emb
        if self.lang_num is not None:
            lid_emb = self.lid_emb(lang_id.reshape([-1]))
            h = h + lid_emb
        if self.spk_embed_dim is not None:
            hs, spk_emb = h.unsqueeze(0), spk_emb.unsqueeze(0)
            h = self._integrate_with_spk_embed(hs, spk_emb)[0]
        out, prob, att_w = self.dec.inference(
            h,
            threshold=threshold,
            minlenratio=minlenratio,
            maxlenratio=maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window, )

        return dict(feat_gen=out, prob=prob, att_w=att_w)
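
    # Hedged usage note (illustration only): alignment plots are usually
    # produced through the teacher-forcing branch above, e.g.
    #     out = model.inference(text, speech=feats, use_teacher_forcing=True)
    #     att_w = out["att_w"]  # (T_feats, T_text) attention matrix
    # while free-running synthesis loops the decoder until the stop
    # probability exceeds `threshold`.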

    def _integrate_with_spk_embed(self,
                                  hs: paddle.Tensor,
                                  spk_emb: paddle.Tensor) -> paddle.Tensor:
        """Integrate speaker embedding with hidden states.

        Parameters
        ----------
        hs : Tensor
            Batch of hidden state sequences (B, Tmax, eunits).
        spk_emb : Tensor
            Batch of speaker embeddings (B, spk_embed_dim).

        Returns
        ----------
        Tensor
            Batch of integrated hidden state sequences (B, Tmax, eunits) if
            integration_type is "add" else (B, Tmax, eunits + spk_embed_dim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spk_emb = self.projection(F.normalize(spk_emb))
            hs = hs + spk_emb.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds
            spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(
                -1, paddle.shape(hs)[1], -1)
            hs = paddle.concat([hs, spk_emb], axis=-1)
        else:
            raise NotImplementedError("support only add or concat.")

        return hs
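
    # Shape illustration (hedged, not part of the original file): with B=2,
    # Tmax=50, eunits=256 and spk_embed_dim=192, "add" projects the embedding
    # to 256 dims and keeps hs at (2, 50, 256), while "concat" broadcasts it
    # over time and yields (2, 50, 448).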


class Tacotron2Inference(nn.Layer):
    def __init__(self, normalizer, model):
        super().__init__()
        self.normalizer = normalizer
        self.acoustic_model = model

    def forward(self, text, spk_id=None, spk_emb=None):
        out = self.acoustic_model.inference(
            text, spk_id=spk_id, spk_emb=spk_emb)
        normalized_mel = out["feat_gen"]
        logmel = self.normalizer.inverse(normalized_mel)
        return logmel
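

# Hedged usage sketch (illustration, not part of the original file): how the
# wrapper above might be driven once a model is trained. `tacotron2` and
# `mel_normalizer` are assumed names for a trained Tacotron2 instance and the
# normalizer fitted during preprocessing; `phone_ids` is an int64 Tensor
# (T_text,) of input ids.
def _demo_tacotron2_inference(tacotron2, mel_normalizer, phone_ids):
    am_inference = Tacotron2Inference(mel_normalizer, tacotron2)
    am_inference.eval()
    with paddle.no_grad():
        logmel = am_inference(phone_ids)  # (T_feats, odim), denormalized
    return logmel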
@ -0,0 +1,217 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict

from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer

from paddlespeech.t2s.modules.losses import GuidedAttentionLoss
from paddlespeech.t2s.modules.losses import Tacotron2Loss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Tacotron2Updater(StandardUpdater):
    def __init__(self,
                 model: Dict[str, Layer],
                 optimizer: Dict[str, Optimizer],
                 dataloader: DataLoader,
                 init_state=None,
                 use_masking: bool=True,
                 use_weighted_masking: bool=False,
                 bce_pos_weight: float=5.0,
                 loss_type: str="L1+L2",
                 use_guided_attn_loss: bool=True,
                 guided_attn_loss_sigma: float=0.4,
                 guided_attn_loss_lambda: float=1.0,
                 output_dir: Path=None):
        super().__init__(model, optimizer, dataloader, init_state=init_state)

        self.loss_type = loss_type
        self.use_guided_attn_loss = use_guided_attn_loss

        self.taco2_loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight, )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda, )

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def update_core(self, batch):
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}
        # spk_id is not None in the multi-speaker case
        spk_id = batch["spk_id"] if "spk_id" in batch else None
        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
        # no explicit speaker identifier labels are used during voice cloning training
        if spk_emb is not None:
            spk_id = None

        after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            spk_id=spk_id,
            spk_emb=spk_emb)

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels, olens)

        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        # calculate attention loss
        attn_loss = None
        if self.use_guided_attn_loss:
            # NOTE: length of output for auto-regressive
            # input will be changed when r > 1
            if self.model.reduction_factor > 1:
                olens_in = olens // self.model.reduction_factor
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss

        optimizer = self.optimizer
        optimizer.clear_grad()
        loss.backward()
        optimizer.step()

        report("train/l1_loss", float(l1_loss))
        report("train/mse_loss", float(mse_loss))
        report("train/bce_loss", float(bce_loss))
        report("train/loss", float(loss))

        losses_dict["l1_loss"] = float(l1_loss)
        losses_dict["mse_loss"] = float(mse_loss)
        losses_dict["bce_loss"] = float(bce_loss)
        losses_dict["loss"] = float(loss)
        # attn_loss only exists when guided attention is enabled
        if attn_loss is not None:
            report("train/attn_loss", float(attn_loss))
            losses_dict["attn_loss"] = float(attn_loss)
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
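

# Hedged sketch (illustration, not part of the trainer): how the scalar
# objective above is assembled. `attn` stands in for the guided-attention
# term, which is present only when use_guided_attn_loss=True.
def _demo_total_loss(l1, l2, bce, attn=None, loss_type="L1+L2"):
    base = {"L1+L2": l1 + l2 + bce, "L1": l1 + bce, "L2": l2 + bce}[loss_type]
    return base if attn is None else base + attn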


class Tacotron2Evaluator(StandardEvaluator):
    def __init__(self,
                 model,
                 dataloader,
                 use_masking: bool=True,
                 use_weighted_masking: bool=False,
                 bce_pos_weight: float=5.0,
                 loss_type: str="L1+L2",
                 use_guided_attn_loss: bool=True,
                 guided_attn_loss_sigma: float=0.4,
                 guided_attn_loss_lambda: float=1.0,
                 output_dir=None):
        super().__init__(model, dataloader)

        self.loss_type = loss_type
        self.use_guided_attn_loss = use_guided_attn_loss

        self.taco2_loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight, )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda, )

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
        self.msg = "Evaluate: "
        losses_dict = {}
        # spk_id is not None in the multi-speaker case
        spk_id = batch["spk_id"] if "spk_id" in batch else None
        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
        if spk_emb is not None:
            spk_id = None

        after_outs, before_outs, logits, ys, labels, olens, att_ws, ilens = self.model(
            text=batch["text"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            spk_id=spk_id,
            spk_emb=spk_emb)

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels, olens)

        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        # calculate attention loss
        attn_loss = None
        if self.use_guided_attn_loss:
            # NOTE: length of output for auto-regressive
            # input will be changed when r > 1
            if self.model.reduction_factor > 1:
                olens_in = olens // self.model.reduction_factor
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss

        report("eval/l1_loss", float(l1_loss))
        report("eval/mse_loss", float(mse_loss))
        report("eval/bce_loss", float(bce_loss))
        report("eval/loss", float(loss))

        losses_dict["l1_loss"] = float(l1_loss)
        losses_dict["mse_loss"] = float(mse_loss)
        losses_dict["bce_loss"] = float(bce_loss)
        losses_dict["loss"] = float(loss)
        # attn_loss only exists when guided attention is enabled
        if attn_loss is not None:
            report("eval/attn_loss", float(attn_loss))
            losses_dict["attn_loss"] = float(attn_loss)
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())
        self.logger.info(self.msg)
@ -0,0 +1,519 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attention modules for RNN."""
import paddle
import paddle.nn.functional as F
from paddle import nn

from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_pad_mask


def _apply_attention_constraint(e,
                                last_attended_idx,
                                backward_window=1,
                                forward_window=3):
    """Apply monotonic attention constraint.

    This function applies the monotonic attention constraint
    introduced in `Deep Voice 3: Scaling
    Text-to-Speech with Convolutional Sequence Learning`_.

    Parameters
    ----------
    e : Tensor
        Attention energy before applying softmax (1, T).
    last_attended_idx : int
        The index of the inputs of the last attended [0, T].
    backward_window : int, optional
        Backward window size in attention constraint.
    forward_window : int, optional
        Forward window size in attention constraint.

    Returns
    ----------
    Tensor
        Monotonically constrained attention energy (1, T).

    .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
        https://arxiv.org/abs/1710.07654

    """
    if paddle.shape(e)[0] != 1:
        raise NotImplementedError(
            "Batch attention constraining is not yet supported.")
    backward_idx = last_attended_idx - backward_window
    forward_idx = last_attended_idx + forward_window
    if backward_idx > 0:
        e[:, :backward_idx] = -float("inf")
    if forward_idx < paddle.shape(e)[1]:
        e[:, forward_idx:] = -float("inf")
    return e
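

# Illustrative sketch (not part of the original file): with
# last_attended_idx=5, backward_window=1 and forward_window=3, only energies
# at indices 4..7 survive; the softmax is therefore confined to a small
# window around the previously attended input, which enforces near-monotonic
# alignment at synthesis time.
def _demo_attention_constraint():
    e = paddle.zeros([1, 10])
    e = _apply_attention_constraint(e, last_attended_idx=5)
    # e[0, :4] and e[0, 8:] are now -inf; F.softmax(e) spreads all mass
    # over indices 4..7.
    return e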


class AttLoc(nn.Layer):
    """Location-aware attention module.

    Reference: Attention-Based Models for Speech Recognition
        (https://arxiv.org/pdf/1506.07503.pdf)

    Parameters
    ----------
    eprojs : int
        projection-units of encoder
    dunits : int
        units of decoder
    att_dim : int
        attention dimension
    aconv_chans : int
        channels of attention convolution
    aconv_filts : int
        filter size of attention convolution
    han_mode : bool
        flag to switch on hierarchical attention mode, in which
        pre_compute_enc_h is not stored
    """

    def __init__(self,
                 eprojs,
                 dunits,
                 att_dim,
                 aconv_chans,
                 aconv_filts,
                 han_mode=False):
        super().__init__()
        self.mlp_enc = nn.Linear(eprojs, att_dim)
        self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
        self.loc_conv = nn.Conv2D(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias_attr=False, )
        self.gvec = nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
            self,
            enc_hs_pad,
            enc_hs_len,
            dec_z,
            att_prev,
            scaling=2.0,
            last_attended_idx=None,
            backward_window=1,
            forward_window=3, ):
        """Calculate AttLoc forward propagation.

        Parameters
        ----------
        enc_hs_pad : paddle.Tensor
            padded encoder hidden state (B, T_max, D_enc)
        enc_hs_len : paddle.Tensor
            padded encoder hidden state length (B)
        dec_z : paddle.Tensor
            decoder hidden state (B, D_dec)
        att_prev : paddle.Tensor
            previous attention weight (B, T_max)
        scaling : float
            scaling parameter before applying softmax
        last_attended_idx : int
            index of the inputs of the last attended
        backward_window : int
            backward window size in attention constraint
        forward_window : int
            forward window size in attention constraint

        Returns
        ----------
        paddle.Tensor
            attention weighted encoder state (B, D_enc)
        paddle.Tensor
            attention weights (B, T_max)
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            # (utt, frame, hdim)
            self.enc_h = enc_hs_pad
            self.h_length = paddle.shape(self.enc_h)[1]
            # (utt, frame, att_dim)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = paddle.zeros([batch, self.dunits])
        else:
            dec_z = dec_z.reshape([batch, self.dunits])

        # initialize attention weight with uniform dist.
        if att_prev is None:
            # with no bias, the 0-padded positions stay 0
            att_prev = 1.0 - make_pad_mask(enc_hs_len)
            att_prev = att_prev / enc_hs_len.unsqueeze(-1)

        # att_prev: (utt, frame) -> (utt, 1, 1, frame)
        # -> (utt, att_conv_chans, 1, frame)
        att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
        # att_conv: (utt, att_conv_chans, 1, frame) -> (utt, frame, att_conv_chans)
        att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
        # att_conv: (utt, frame, att_conv_chans) -> (utt, frame, att_dim)
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: (utt, frame, att_dim)
        dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])

        # dot with gvec
        # (utt, frame, att_dim) -> (utt, frame)
        e = self.gvec(
            paddle.tanh(att_conv + self.pre_compute_enc_h +
                        dec_z_tiled)).squeeze(2)

        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = make_pad_mask(enc_hs_len)
        e = masked_fill(e, self.mask, -float("inf"))
        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(e, last_attended_idx,
                                            backward_window, forward_window)

        w = F.softmax(scaling * e, axis=1)

        # weighted sum over frames
        # utt x hdim
        c = paddle.sum(
            self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)

        return c, w
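

# Hedged usage sketch (illustration, not part of the original file): one
# decoder step with location-aware attention. All sizes are made-up examples.
def _demo_attloc_step():
    att = AttLoc(eprojs=256, dunits=512, att_dim=128,
                 aconv_chans=32, aconv_filts=15)
    enc_hs = paddle.randn([2, 50, 256])    # (B, T_max, D_enc)
    enc_lens = paddle.to_tensor([50, 40])  # true length of each utterance
    att.reset()                            # clear cached encoder projections
    # first step: dec_z=None and att_prev=None fall back to a zero decoder
    # state and a uniform distribution over the unpadded frames
    c, w = att(enc_hs, enc_lens, None, None)
    # c: (2, 256) context vectors; w: (2, 50) weights, fed back as att_prev
    # on the next decoder step
    return c, w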


class AttForward(nn.Layer):
    """Forward attention module.

    Reference
    ----------
    Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    Parameters
    ----------
    eprojs : int
        projection-units of encoder
    dunits : int
        units of decoder
    att_dim : int
        attention dimension
    aconv_chans : int
        channels of attention convolution
    aconv_filts : int
        filter size of attention convolution
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super().__init__()
        self.mlp_enc = nn.Linear(eprojs, att_dim)
        self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
        self.loc_conv = nn.Conv2D(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias_attr=False, )
        self.gvec = nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
            self,
            enc_hs_pad,
            enc_hs_len,
            dec_z,
            att_prev,
            scaling=1.0,
            last_attended_idx=None,
            backward_window=1,
            forward_window=3, ):
        """Calculate AttForward forward propagation.

        Parameters
        ----------
        enc_hs_pad : paddle.Tensor
            padded encoder hidden state (B, T_max, D_enc)
        enc_hs_len : paddle.Tensor
            padded encoder hidden state length (B,)
        dec_z : paddle.Tensor
            decoder hidden state (B, D_dec)
        att_prev : paddle.Tensor
            attention weights of previous step (B, T_max)
        scaling : float
            scaling parameter before applying softmax
        last_attended_idx : int
            index of the inputs of the last attended
        backward_window : int
            backward window size in attention constraint
        forward_window : int
            forward window size in attention constraint

        Returns
        ----------
        paddle.Tensor
            attention weighted encoder state (B, D_enc)
        paddle.Tensor
            attention weights (B, T_max)
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = paddle.shape(self.enc_h)[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = paddle.zeros([batch, self.dunits])
        else:
            dec_z = dec_z.reshape([batch, self.dunits])

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = paddle.zeros([*paddle.shape(enc_hs_pad)[:2]])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            paddle.tanh(self.pre_compute_enc_h + dec_z_tiled +
                        att_conv)).squeeze(2)

        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = make_pad_mask(enc_hs_len)
        e = masked_fill(e, self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(e, last_attended_idx,
                                            backward_window, forward_window)

        w = F.softmax(scaling * e, axis=1)

        # forward attention
        att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]

        w = (att_prev + att_prev_shift) * w
        # NOTE: clip is needed to avoid nan gradient
        w = F.normalize(paddle.clip(w, 1e-6), p=1, axis=1)

        # weighted sum over frames
        # utt x hdim
        c = paddle.sum(self.enc_h * w.unsqueeze(-1), axis=1)

        return c, w
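

# Hedged numeric illustration (not part of the original file) of the shift
# used in the forward-attention recursion above: probability mass can stay
# in place or advance by exactly one encoder position per decoder step.
def _demo_forward_attention_shift():
    att_prev = paddle.to_tensor([[0.0, 1.0, 0.0, 0.0]])
    att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]
    # att_prev_shift is [[0., 0., 1., 0.]], so (att_prev + att_prev_shift)
    # is non-zero only at indices 1 and 2 -- the alignment either stays or
    # moves one step forward before renormalization.
    return att_prev + att_prev_shift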


class AttForwardTA(nn.Layer):
    """Forward attention with transition agent module.

    Reference
    ----------
    Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
        (https://arxiv.org/pdf/1807.06736.pdf)

    Parameters
    ----------
    eunits : int
        units of encoder
    dunits : int
        units of decoder
    att_dim : int
        attention dimension
    aconv_chans : int
        channels of attention convolution
    aconv_filts : int
        filter size of attention convolution
    odim : int
        output dimension
    """

    def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
        super().__init__()
        self.mlp_enc = nn.Linear(eunits, att_dim)
        self.mlp_dec = nn.Linear(dunits, att_dim, bias_attr=False)
        self.mlp_ta = nn.Linear(eunits + dunits + odim, 1)
        self.mlp_att = nn.Linear(aconv_chans, att_dim, bias_attr=False)
        self.loc_conv = nn.Conv2D(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias_attr=False, )
        self.gvec = nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eunits = eunits
        self.att_dim = att_dim
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def reset(self):
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def forward(
            self,
            enc_hs_pad,
            enc_hs_len,
            dec_z,
            att_prev,
            out_prev,
            scaling=1.0,
            last_attended_idx=None,
            backward_window=1,
            forward_window=3, ):
        """Calculate AttForwardTA forward propagation.

        Parameters
        ----------
        enc_hs_pad : paddle.Tensor
            padded encoder hidden state (B, Tmax, eunits)
        enc_hs_len : paddle.Tensor
            padded encoder hidden state length (B,)
        dec_z : paddle.Tensor
            decoder hidden state (B, dunits)
        att_prev : paddle.Tensor
            attention weights of previous step (B, T_max)
        out_prev : paddle.Tensor
            decoder outputs of previous step (B, odim)
        scaling : float
            scaling parameter before applying softmax
        last_attended_idx : int
            index of the inputs of the last attended
        backward_window : int
            backward window size in attention constraint
        forward_window : int
            forward window size in attention constraint

        Returns
        ----------
        paddle.Tensor
            attention weighted encoder state (B, dunits)
        paddle.Tensor
            attention weights (B, Tmax)
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = paddle.shape(self.enc_h)[1]
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = paddle.zeros([batch, self.dunits])
        else:
            dec_z = dec_z.reshape([batch, self.dunits])

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = paddle.zeros([*paddle.shape(enc_hs_pad)[:2]])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.reshape([batch, 1, 1, self.h_length]))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose([0, 2, 1])
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).reshape([batch, 1, self.att_dim])

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            paddle.tanh(att_conv + self.pre_compute_enc_h +
                        dec_z_tiled)).squeeze(2)

        # NOTE: consider zero padding when computing w.
        if self.mask is None:
            self.mask = make_pad_mask(enc_hs_len)
        e = masked_fill(e, self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(e, last_attended_idx,
                                            backward_window, forward_window)

        w = F.softmax(scaling * e, axis=1)

        # forward attention
        att_prev_shift = F.pad(att_prev, (0, 0, 1, 0))[:, :-1]
        w = (self.trans_agent_prob * att_prev +
             (1 - self.trans_agent_prob) * att_prev_shift) * w
        # NOTE: clip is needed to avoid nan gradient
        w = F.normalize(paddle.clip(w, 1e-6), p=1, axis=1)

        # weighted sum over frames
        # utt x hdim
        c = paddle.sum(
            self.enc_h * w.reshape([batch, self.h_length, 1]), axis=1)

        # update transition agent prob
        self.trans_agent_prob = F.sigmoid(
            self.mlp_ta(paddle.concat([c, out_prev, dec_z], axis=1)))

        return c, w
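

# Hedged usage sketch (illustration, not part of the original file): one
# decoder step with the transition agent. Sizes are made-up examples; note
# that out_prev (the previous decoder output frame) is an extra input used
# to update trans_agent_prob after every step.
def _demo_attforward_ta_step():
    att = AttForwardTA(eunits=256, dunits=512, att_dim=128,
                       aconv_chans=32, aconv_filts=15, odim=80)
    enc_hs = paddle.randn([2, 50, 256])
    enc_lens = paddle.to_tensor([50, 40])
    out_prev = paddle.zeros([2, 80])  # e.g. an all-zero <go> frame
    att.reset()                       # also resets trans_agent_prob to 0.5
    c, w = att(enc_hs, enc_lens, None, None, out_prev)
    return c, w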