[TTS]add Diffsinger with opencpop dataset (#3005)
parent acf943007e
commit 1afd14acd9
@ -0,0 +1,174 @@
([简体中文](./README_cn.md)|English)

# DiffSinger with Opencpop

This example contains code used to train a [DiffSinger](https://arxiv.org/abs/2105.02446) model with the [Mandarin singing corpus](https://wenet.org.cn/opencpop/).

## Dataset

### Download and Extract

Download Opencpop from its [official website](https://wenet.org.cn/opencpop/download/) and extract it to `~/datasets`. The dataset is then in the directory `~/datasets/Opencpop`.
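
For example (the archive name is an assumption; use whatever the website actually serves):

```bash
mkdir -p ~/datasets
# hypothetical archive name; adjust to the actual download
unzip opencpop.zip -d ~/datasets
# the extracted dataset should end up at ~/datasets/Opencpop
```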

## Get Started

Assume the path to the dataset is `~/datasets/Opencpop`.

Run the command below to

1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
    - synthesize waveform from `metadata.jsonl`.
    - (to be supported) synthesize waveform from a text file.
5. (to be supported) inference using the static model.

```bash
./run.sh
```

You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage. For example, running the following command will only preprocess the dataset.

```bash
./run.sh --stage 0 --stop-stage 0
```

### Data Preprocessing

```bash
./local/preprocess.sh ${conf_path}
```

When it is done, a `dump` folder is created in the current directory. The structure of the dump folder is listed below.

```text
dump
├── dev
│   ├── norm
│   └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│   ├── norm
│   └── raw
└── train
    ├── energy_stats.npy
    ├── norm
    ├── pitch_stats.npy
    ├── raw
    ├── speech_stats.npy
    └── speech_stretchs.npy
```

The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the speech, pitch, and energy features of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize features are computed from the training set and stored in `dump/train/*_stats.npy`. `speech_stretchs.npy` contains the minimum and maximum values of each dimension of the mel spectrum, which are used for linear stretching before training/inference of the diffusion module.
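
If you want to inspect the stretch file, here is a minimal sketch, assuming the standard DiffSinger linear stretch that maps each mel dimension to `[-1, 1]`:

```python
import numpy as np

# speech_stretchs.npy stores per-dimension minima and maxima, shape (2, n_mels)
stretch = np.load("dump/train/speech_stretchs.npy")
spec_min, spec_max = stretch[0], stretch[1]  # each of shape (80,)


def stretch_mel(mel):
    """Linearly stretch a mel spectrum of shape (T, 80) to [-1, 1] per dimension."""
    return (mel - spec_min) / (spec_max - spec_min) * 2.0 - 1.0
```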

Note: since training on un-normalized features turns out to work better than on normalized ones, the features saved under `norm` are in fact not normalized.

There is also a `metadata.jsonl` in each subfolder. It is a table-like file that contains the utterance id, speaker id, phones, text lengths, speech lengths, phone durations, the paths of the speech/pitch/energy features, notes, note durations, and slur flags.
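
For illustration, a single (hypothetical) record could look like the following; the field names match the normalized metadata, but all values are made up and the line is wrapped here for readability:

```text
{"utt_id": "2001000001", "spk_id": 0, "text": [27, 8, 31, 14],
 "text_lengths": 4, "speech_lengths": 320, "durations": [40, 60, 100, 120],
 "speech": "dump/train/norm/data_speech/2001000001_speech.npy",
 "pitch": "dump/train/norm/data_pitch/2001000001_pitch.npy",
 "energy": "dump/train/norm/data_energy/2001000001_energy.npy",
 "note": [60, 60, 62, 64], "note_dur": [0.5, 0.5, 0.8, 1.0],
 "is_slur": [0, 0, 0, 1]}
```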

### Model Training

```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```

`./local/train.sh` calls `${BIN_DIR}/train.py`.

Here's the complete help message.

```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
                [--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
                [--ngpu NGPU] [--phones-dict PHONES_DICT]
                [--speaker-dict SPEAKER_DICT] [--speech-stretchs SPEECH_STRETCHS]

Train a DiffSinger model.

optional arguments:
  -h, --help            show this help message and exit
  --config CONFIG       diffsinger config file.
  --train-metadata TRAIN_METADATA
                        training data.
  --dev-metadata DEV_METADATA
                        dev data.
  --output-dir OUTPUT_DIR
                        output dir.
  --ngpu NGPU           if ngpu=0, use cpu.
  --phones-dict PHONES_DICT
                        phone vocabulary file.
  --speaker-dict SPEAKER_DICT
                        speaker id map file for multiple speaker model.
  --speech-stretchs SPEECH_STRETCHS
                        min and max mel for stretching.
```

1. `--config` is a config file in yaml format to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata files in the normalized subfolders of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, use cpu.
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speech-stretchs` is the path of the mel min/max data file. A direct invocation with these flags is sketched below.
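
Putting the flags together, the call that `./local/train.sh` makes looks like this (after `source path.sh`, so that `${BIN_DIR}` is set):

```bash
python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=conf/default.yaml \
    --output-dir=exp/default \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt \
    --speech-stretchs=dump/train/speech_stretchs.npy
```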

### Synthesizing

We use Parallel WaveGAN as the neural vocoder.

Download the pretrained Parallel WaveGAN model from [pwgan_opencpop_ckpt_1.4.0.zip](https://paddlespeech.bj.bcebos.com/t2s/svs/opencpop/pwgan_opencpop_ckpt_1.4.0.zip) and unzip it.

```bash
unzip pwgan_opencpop_ckpt_1.4.0.zip
```

The Parallel WaveGAN checkpoint contains the files listed below.

```text
pwgan_opencpop_ckpt_1.4.0.zip
├── default.yaml             # default config used to train parallel wavegan
├── snapshot_iter_100000.pdz # model parameters of parallel wavegan
└── feats_stats.npy          # statistics used to normalize spectrogram when training parallel wavegan
```

`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveforms from `metadata.jsonl`.

```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```

```text
usage: synthesize.py [-h]
                     [--am {diffsinger_opencpop}]
                     [--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
                     [--am_stat AM_STAT] [--phones_dict PHONES_DICT]
                     [--voc {pwgan_opencpop}]
                     [--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
                     [--voc_stat VOC_STAT] [--ngpu NGPU]
                     [--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
                     [--speech_stretchs SPEECH_STRETCHS]

Synthesize with acoustic model & vocoder

optional arguments:
  -h, --help            show this help message and exit
  --am {diffsinger_opencpop}
                        Choose acoustic model type of tts task.
  --am_config AM_CONFIG
                        Config of acoustic model.
  --am_ckpt AM_CKPT     Checkpoint file of acoustic model.
  --am_stat AM_STAT     mean and standard deviation used to normalize
                        spectrogram when training acoustic model.
  --phones_dict PHONES_DICT
                        phone vocabulary file.
  --tones_dict TONES_DICT
                        tone vocabulary file.
  --speaker_dict SPEAKER_DICT
                        speaker id map file.
  --voice-cloning VOICE_CLONING
                        whether training voice cloning model.
  --voc {pwgan_opencpop}
                        Choose vocoder type of tts task.
  --voc_config VOC_CONFIG
                        Config of voc.
  --voc_ckpt VOC_CKPT   Checkpoint file of voc.
  --voc_stat VOC_STAT   mean and standard deviation used to normalize
                        spectrogram when training voc.
  --ngpu NGPU           if ngpu == 0, use cpu.
  --test_metadata TEST_METADATA
                        test metadata.
  --output_dir OUTPUT_DIR
                        output dir.
  --speech_stretchs SPEECH_STRETCHS
                        mel min and max values file.
```

## Pretrained Model

Pretrained DiffSinger model:
- [diffsinger_opencpop_ckpt_1.4.0.zip](https://paddlespeech.bj.bcebos.com/t2s/svs/opencpop/diffsinger_opencpop_ckpt_1.4.0.zip)

The DiffSinger checkpoint contains the files listed below.

```text
diffsinger_opencpop_ckpt_1.4.0.zip
├── default.yaml             # default config used to train diffsinger
├── energy_stats.npy         # statistics used to normalize energy when training diffsinger if norm is needed
├── phone_id_map.txt         # phone vocabulary file when training diffsinger
├── pitch_stats.npy          # statistics used to normalize pitch when training diffsinger if norm is needed
├── snapshot_iter_160000.pdz # model parameters of diffsinger
├── speech_stats.npy         # statistics used to normalize mel when training diffsinger if norm is needed
└── speech_stretchs.npy      # min and max values used for mel spectrum stretching before training the diffusion module
```

At present, the text frontend is not complete, so the `synthesize_e2e` method is not yet supported for synthesizing audio. Use `synthesize` instead.
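
To synthesize with the pretrained checkpoints instead of your own, the flags mirror `local/synthesize.sh`; this is a sketch assuming both zips are unzipped in the example root and the dataset has been preprocessed:

```bash
source path.sh
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
    --am=diffsinger_opencpop \
    --am_config=diffsinger_opencpop_ckpt_1.4.0/default.yaml \
    --am_ckpt=diffsinger_opencpop_ckpt_1.4.0/snapshot_iter_160000.pdz \
    --am_stat=diffsinger_opencpop_ckpt_1.4.0/speech_stats.npy \
    --voc=pwgan_opencpop \
    --voc_config=pwgan_opencpop_ckpt_1.4.0/default.yaml \
    --voc_ckpt=pwgan_opencpop_ckpt_1.4.0/snapshot_iter_100000.pdz \
    --voc_stat=pwgan_opencpop_ckpt_1.4.0/feats_stats.npy \
    --test_metadata=dump/test/norm/metadata.jsonl \
    --output_dir=exp/default/test \
    --phones_dict=diffsinger_opencpop_ckpt_1.4.0/phone_id_map.txt \
    --speech_stretchs=diffsinger_opencpop_ckpt_1.4.0/speech_stretchs.npy
```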
@ -0,0 +1,159 @@
###########################################################
#                FEATURE EXTRACTION SETTING                #
###########################################################

fs: 24000          # sampling rate (Hz)
n_fft: 512         # FFT size (samples).
n_shift: 128       # Hop size (samples). ~5.3 ms at fs=24000.
win_length: 512    # Window length (samples). ~21.3 ms at fs=24000.
                   # If set to null, it will be the same as fft_size.
window: "hann"     # Window function.
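# Note: with fs=24000 and n_shift=128, features are extracted at
# 24000 / 128 = 187.5 frames per second.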

# Only used for feats_type != raw

fmin: 30           # Minimum frequency of Mel basis.
fmax: 12000        # Maximum frequency of Mel basis.
n_mels: 80         # The number of mel basis.

# Only used for models using pitch features (e.g. FastSpeech2)
f0min: 80          # Minimum f0 for pitch extraction.
f0max: 750         # Maximum f0 for pitch extraction.


###########################################################
#                       DATA SETTING                       #
###########################################################
batch_size: 48     # batch size
num_workers: 1     # number of workers for DataLoader


###########################################################
#                       MODEL SETTING                      #
###########################################################
model:
    # music score related
    note_num: 300                  # number of distinct note ids
    is_slur_num: 2                 # number of slur labels (slur / not slur)
    # fastspeech2 module options
    use_energy_pred: False         # whether to use energy predictor
    use_postnet: False             # whether to use postnet

    # fastspeech2 module
    fastspeech2_params:
        adim: 256                                      # attention dimension
        aheads: 2                                      # number of attention heads
        elayers: 4                                     # number of encoder layers
        eunits: 1024                                   # number of encoder ff units
        dlayers: 4                                     # number of decoder layers
        dunits: 1024                                   # number of decoder ff units
        positionwise_layer_type: conv1d-linear         # type of position-wise layer
        positionwise_conv_kernel_size: 9               # kernel size of position wise conv layer
        transformer_enc_dropout_rate: 0.1              # dropout rate for transformer encoder layer
        transformer_enc_positional_dropout_rate: 0.1   # dropout rate for transformer encoder positional encoding
        transformer_enc_attn_dropout_rate: 0.0         # dropout rate for transformer encoder attention layer
        transformer_activation_type: "gelu"            # activation function type in transformer
        encoder_normalize_before: True                 # whether to perform layer normalization before encoder input
        decoder_normalize_before: True                 # whether to perform layer normalization before decoder input
        reduction_factor: 1                            # reduction factor
        init_type: xavier_uniform                      # initialization type
        init_enc_alpha: 1.0                            # initial value of alpha of encoder scaled position encoding
        init_dec_alpha: 1.0                            # initial value of alpha of decoder scaled position encoding
        use_scaled_pos_enc: True                       # whether to use scaled positional encoding
        transformer_dec_dropout_rate: 0.1              # dropout rate for transformer decoder layer
        transformer_dec_positional_dropout_rate: 0.1   # dropout rate for transformer decoder positional encoding
        transformer_dec_attn_dropout_rate: 0.0         # dropout rate for transformer decoder attention layer
        duration_predictor_layers: 5                   # number of layers of duration predictor
        duration_predictor_chans: 256                  # number of channels of duration predictor
        duration_predictor_kernel_size: 3              # filter size of duration predictor
        duration_predictor_dropout_rate: 0.5           # dropout rate in duration predictor
        pitch_predictor_layers: 5                      # number of conv layers in pitch predictor
        pitch_predictor_chans: 256                     # number of channels of conv layers in pitch predictor
        pitch_predictor_kernel_size: 5                 # kernel size of conv layers in pitch predictor
        pitch_predictor_dropout: 0.5                   # dropout rate in pitch predictor
        pitch_embed_kernel_size: 1                     # kernel size of conv embedding layer for pitch
        pitch_embed_dropout: 0.0                       # dropout rate after conv embedding layer for pitch
        stop_gradient_from_pitch_predictor: True       # whether to stop the gradient from pitch predictor to encoder
        energy_predictor_layers: 2                     # number of conv layers in energy predictor
        energy_predictor_chans: 256                    # number of channels of conv layers in energy predictor
        energy_predictor_kernel_size: 3                # kernel size of conv layers in energy predictor
        energy_predictor_dropout: 0.5                  # dropout rate in energy predictor
        energy_embed_kernel_size: 1                    # kernel size of conv embedding layer for energy
        energy_embed_dropout: 0.0                      # dropout rate after conv embedding layer for energy
        stop_gradient_from_energy_predictor: False     # whether to stop the gradient from energy predictor to encoder
        postnet_layers: 5                              # number of layers of postnet
        postnet_filts: 5                               # filter size of conv layers in postnet
        postnet_chans: 256                             # number of channels of conv layers in postnet
        postnet_dropout_rate: 0.5                      # dropout rate for postnet

    # denoiser module
    denoiser_params:
        in_channels: 80              # number of channels of the input mel-spectrogram
        out_channels: 80             # number of channels of the output mel-spectrogram
        kernel_size: 3               # kernel size of the residual blocks inside
        layers: 20                   # number of residual blocks inside
        stacks: 5                    # number of groups to split the residual blocks into
        residual_channels: 256       # residual channels of the residual blocks
        gate_channels: 512           # gate channels of the residual blocks
        skip_channels: 256           # skip channels of the residual blocks
        aux_channels: 256            # auxiliary channels of the residual blocks
        dropout: 0.1                 # dropout of the residual blocks
        bias: True                   # whether to use bias in residual blocks
        use_weight_norm: False       # whether to use weight norm in all convolutions
        init_type: "kaiming_normal"  # type of weight initialization

    diffusion_params:
        num_train_timesteps: 100     # number of diffusion timesteps between noise and data during training
        beta_start: 0.0001           # beta start parameter for the scheduler
        beta_end: 0.06               # beta end parameter for the scheduler
        beta_schedule: "linear"      # beta schedule parameter for the scheduler
        num_max_timesteps: 100       # max timestep for the transition from data to noise
        stretch: True                # whether to stretch the mel spectrum before diffusion
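        # when stretch is True, each mel dimension x is mapped linearly to
        # [-1, 1] before diffusion using the min/max from speech_stretchs.npy,
        # i.e. x_stretched = (x - spec_min) / (spec_max - spec_min) * 2 - 1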


###########################################################
#                      UPDATER SETTING                     #
###########################################################
fs2_updater:
    use_masking: True             # whether to apply masking for padded parts in loss calculation

ds_updater:
    use_masking: True             # whether to apply masking for padded parts in loss calculation


###########################################################
#                     OPTIMIZER SETTING                    #
###########################################################
# fastspeech2 optimizer
fs2_optimizer:
    optim: adam                   # optimizer type
    learning_rate: 0.001          # learning rate

# diffusion optimizer
ds_optimizer_params:
    beta1: 0.9
    beta2: 0.98
    weight_decay: 0.0

ds_scheduler_params:
    learning_rate: 0.001
    gamma: 0.5
    step_size: 50000
ds_grad_norm: 1


###########################################################
#                     INTERVAL SETTING                     #
###########################################################
only_train_diffusion: True    # whether to freeze fastspeech2 parameters when training diffusion
ds_train_start_steps: 160000  # step at which training of the diffusion module starts
train_max_steps: 320000       # total number of training steps
save_interval_steps: 2000     # interval (steps) to save checkpoints
eval_interval_steps: 2000     # interval (steps) to evaluate the network
num_snapshots: 5


###########################################################
#                       OTHER SETTING                      #
###########################################################
seed: 10086
@ -0,0 +1,74 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=opencpop \
        --rootdir=~/datasets/Opencpop/segments \
        --dumpdir=dump \
        --label-file=~/datasets/Opencpop/segments/transcriptions.txt \
        --config=${config_path} \
        --num-cpu=20 \
        --cut-sil=True
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"

    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="pitch"

    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="energy"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone/speaker to id; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi

if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    # get feature (mel) extrema for diffusion stretch
    echo "Get feature (mel) extremum ..."
    python3 ${BIN_DIR}/get_minmax.py \
        --metadata=dump/train/norm/metadata.jsonl \
        --speech-stretchs=dump/train/speech_stretchs.npy
fi
@ -0,0 +1,27 @@
#!/bin/bash

config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0

# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    FLAGS_allocator_strategy=naive_best_fit \
    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
    python3 ${BIN_DIR}/../synthesize.py \
        --am=diffsinger_opencpop \
        --am_config=${config_path} \
        --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
        --am_stat=dump/train/speech_stats.npy \
        --voc=pwgan_opencpop \
        --voc_config=pwgan_opencpop_ckpt_1.4.0/default.yaml \
        --voc_ckpt=pwgan_opencpop_ckpt_1.4.0/snapshot_iter_100000.pdz \
        --voc_stat=pwgan_opencpop_ckpt_1.4.0/feats_stats.npy \
        --test_metadata=dump/test/norm/metadata.jsonl \
        --output_dir=${train_output_path}/test \
        --phones_dict=dump/phone_id_map.txt \
        --speech_stretchs=dump/train/speech_stretchs.npy
fi
@ -0,0 +1,13 @@
#!/bin/bash

config_path=$1
train_output_path=$2

python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt \
    --speech-stretchs=dump/train/speech_stretchs.npy
@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=diffsinger
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@ -0,0 +1,32 @@
#!/bin/bash

set -e
source path.sh

gpus=0
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_320000.pdz
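# note: this checkpoint name matches train_max_steps (320000) in conf/default.yaml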

# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with positional arguments `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model; all `ckpt` files are under `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize; vocoder is pwgan by default
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -0,0 +1,82 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging

import jsonlines
import numpy as np
from tqdm import tqdm

from paddlespeech.t2s.datasets.data_table import DataTable


def get_minmax(spec, min_spec, max_spec):
    # spec: [T, 80]
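    # an equivalent vectorized form (not used here) would be:
    #   min_spec = np.minimum(min_spec, spec.min(axis=0))
    #   max_spec = np.maximum(max_spec, spec.max(axis=0))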
    for i in range(spec.shape[1]):
        min_value = np.min(spec[:, i])
        max_value = np.max(spec[:, i])
        min_spec[i] = min(min_value, min_spec[i])
        max_spec[i] = max(max_value, max_spec[i])

    return min_spec, max_spec


def main():
    """Compute the per-dimension min and max of the mel spectrum."""
    parser = argparse.ArgumentParser(
        description="Compute the per-dimension min and max of the mel spectrum for diffusion stretching."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="metadata file (jsonl) of the features to scan.")

    parser.add_argument(
        "--speech-stretchs",
        type=str,
        required=True,
        help="output path of the min/max spec file. only computed on train data.")

    args = parser.parse_args()

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata, converters={
            "speech": np.load,
        })
    logging.info(f"The number of files = {len(dataset)}.")

    n_mel = 80
    min_spec = 100.0 * np.ones(shape=(n_mel), dtype=np.float32)
    max_spec = -100.0 * np.ones(shape=(n_mel), dtype=np.float32)

    for item in tqdm(dataset):
        spec = item['speech']
        min_spec, max_spec = get_minmax(spec, min_spec, max_spec)

    # empirically, fixing min_spec at -6.0 has given better training results so far
    min_spec = -6.0 * np.ones(shape=(n_mel), dtype=np.float32)
    min_max_spec = np.stack([min_spec, max_spec], axis=0)
    np.save(
        str(args.speech_stretchs),
        min_max_spec.astype(np.float32),
        allow_pickle=False)


if __name__ == "__main__":
    main()
@ -0,0 +1,189 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""
import argparse
import logging
from operator import itemgetter
from pathlib import Path

import jsonlines
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.utils import str2bool


def main():
    """Run preprocessing process."""
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="metadata file of the dumped raw features to be normalized.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.")
    parser.add_argument(
        "--speech-stats",
        type=str,
        required=True,
        help="speech statistics file.")
    parser.add_argument(
        "--pitch-stats", type=str, required=True, help="pitch statistics file.")
    parser.add_argument(
        "--energy-stats",
        type=str,
        required=True,
        help="energy statistics file.")
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict", type=str, default=None, help="speaker id map file.")
    parser.add_argument(
        "--norm-feats",
        type=str2bool,
        default=False,
        help="whether to normalize features")

    args = parser.parse_args()

    dumpdir = Path(args.dumpdir).expanduser()
    # use absolute path
    dumpdir = dumpdir.resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        converters={
            "speech": np.load,
            "pitch": np.load,
            "energy": np.load,
        })
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler
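    # if --norm-feats is False, the scalers below degenerate to identity
    # transforms (mean 0, scale 1), so features pass through unchanged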
    speech_scaler = StandardScaler()
    if args.norm_feats:
        speech_scaler.mean_ = np.load(args.speech_stats)[0]
        speech_scaler.scale_ = np.load(args.speech_stats)[1]
    else:
        speech_scaler.mean_ = np.zeros(
            np.load(args.speech_stats)[0].shape, dtype="float32")
        speech_scaler.scale_ = np.ones(
            np.load(args.speech_stats)[1].shape, dtype="float32")
    speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]

    pitch_scaler = StandardScaler()
    if args.norm_feats:
        pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
        pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
    else:
        pitch_scaler.mean_ = np.zeros(
            np.load(args.pitch_stats)[0].shape, dtype="float32")
        pitch_scaler.scale_ = np.ones(
            np.load(args.pitch_stats)[1].shape, dtype="float32")
    pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]

    energy_scaler = StandardScaler()
    if args.norm_feats:
        energy_scaler.mean_ = np.load(args.energy_stats)[0]
        energy_scaler.scale_ = np.load(args.energy_stats)[1]
    else:
        energy_scaler.mean_ = np.zeros(
            np.load(args.energy_stats)[0].shape, dtype="float32")
        energy_scaler.scale_ = np.ones(
            np.load(args.energy_stats)[1].shape, dtype="float32")
    energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]

    vocab_phones = {}
    with open(args.phones_dict, 'rt') as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    for phn, id in phn_id:
        vocab_phones[phn] = int(id)

    vocab_speaker = {}
    with open(args.speaker_dict, 'rt') as f:
        spk_id = [line.strip().split() for line in f.readlines()]
    for spk, id in spk_id:
        vocab_speaker[spk] = int(id)

    # process each file
    output_metadata = []

    for item in tqdm(dataset):
        utt_id = item['utt_id']
        speech = item['speech']
        pitch = item['pitch']
        energy = item['energy']
        # normalize
        speech = speech_scaler.transform(speech)
        speech_dir = dumpdir / "data_speech"
        speech_dir.mkdir(parents=True, exist_ok=True)
        speech_path = speech_dir / f"{utt_id}_speech.npy"
        np.save(speech_path, speech.astype(np.float32), allow_pickle=False)

        pitch = pitch_scaler.transform(pitch)
        pitch_dir = dumpdir / "data_pitch"
        pitch_dir.mkdir(parents=True, exist_ok=True)
        pitch_path = pitch_dir / f"{utt_id}_pitch.npy"
        np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False)

        energy = energy_scaler.transform(energy)
        energy_dir = dumpdir / "data_energy"
        energy_dir.mkdir(parents=True, exist_ok=True)
        energy_path = energy_dir / f"{utt_id}_energy.npy"
        np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
        phone_ids = [vocab_phones[p] for p in item['phones']]
        spk_id = vocab_speaker[item["speaker"]]
        record = {
            "utt_id": item['utt_id'],
            "spk_id": spk_id,
            "text": phone_ids,
            "text_lengths": item['text_lengths'],
            "speech_lengths": item['speech_lengths'],
            "durations": item['durations'],
            "speech": str(speech_path),
            "pitch": str(pitch_path),
            "energy": str(energy_path),
            "note": item['note'],
            "note_dur": item['note_dur'],
            "is_slur": item['is_slur'],
        }
        # add spk_emb for voice cloning
        if "spk_emb" in item:
            record["spk_emb"] = str(item["spk_emb"])

        output_metadata.append(record)
    output_metadata.sort(key=itemgetter('utt_id'))
    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:
            writer.write(item)
    logging.info(f"metadata dumped into {output_metadata_path}")


if __name__ == "__main__":
    main()
@ -0,0 +1,376 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List

import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.get_feats import Energy
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.get_feats import Pitch
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_sentences_svs
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.utils import str2bool

ALL_INITIALS = [
    'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h',
    'j', 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'
]
ALL_FINALS = [
    'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia',
    'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong',
    'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've',
    'vn'
]


def process_sentence(
        config: Dict[str, Any],
        fp: Path,
        sentences: Dict,
        output_dir: Path,
        mel_extractor=None,
        pitch_extractor=None,
        energy_extractor=None,
        cut_sil: bool=True,
        spk_emb_dir: Path=None, ):
    utt_id = fp.stem
    record = None
    if utt_id in sentences:
        # reading; resampling may occur
        wav, _ = librosa.load(str(fp), sr=config.fs)
        if len(wav.shape) != 1:
            return record
        max_value = np.abs(wav).max()
        if max_value > 1.0:
            wav = wav / max_value
        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
        assert np.abs(wav).max(
        ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
        phones = sentences[utt_id][0]
        durations = sentences[utt_id][1]
        note = sentences[utt_id][2]
        note_dur = sentences[utt_id][3]
        is_slur = sentences[utt_id][4]
        speaker = sentences[utt_id][-1]

        # extract mel feats
        logmel = mel_extractor.get_log_mel_fbank(wav)
        # change duration according to mel_length
        compare_duration_and_mel_length(sentences, utt_id, logmel)
        # utt_id may be popped in compare_duration_and_mel_length
        if utt_id not in sentences:
            return None
        phones = sentences[utt_id][0]
        durations = sentences[utt_id][1]
        num_frames = logmel.shape[0]

        assert sum(
            durations
        ) == num_frames, "the sum of durations doesn't equal the number of mel frames. "
        speech_dir = output_dir / "data_speech"
        speech_dir.mkdir(parents=True, exist_ok=True)
        speech_path = speech_dir / (utt_id + "_speech.npy")
        np.save(speech_path, logmel)
        # extract pitch and energy
        pitch = pitch_extractor.get_pitch(wav)
        assert pitch.shape[0] == num_frames
        pitch_dir = output_dir / "data_pitch"
        pitch_dir.mkdir(parents=True, exist_ok=True)
        pitch_path = pitch_dir / (utt_id + "_pitch.npy")
        np.save(pitch_path, pitch)
        energy = energy_extractor.get_energy(wav)
        assert energy.shape[0] == num_frames
        energy_dir = output_dir / "data_energy"
        energy_dir.mkdir(parents=True, exist_ok=True)
        energy_path = energy_dir / (utt_id + "_energy.npy")
        np.save(energy_path, energy)

        record = {
            "utt_id": utt_id,
            "phones": phones,
            "text_lengths": len(phones),
            "speech_lengths": num_frames,
            "durations": durations,
            "speech": str(speech_path),
            "pitch": str(pitch_path),
            "energy": str(energy_path),
            "speaker": speaker,
            "note": note,
            "note_dur": note_dur,
            "is_slur": is_slur,
        }
        if spk_emb_dir:
            if speaker in os.listdir(spk_emb_dir):
                embed_name = utt_id + ".npy"
                embed_path = spk_emb_dir / speaker / embed_name
                if embed_path.is_file():
                    record["spk_emb"] = str(embed_path)
                else:
                    return None
    return record


def process_sentences(
        config,
        fps: List[Path],
        sentences: Dict,
        output_dir: Path,
        mel_extractor=None,
        pitch_extractor=None,
        energy_extractor=None,
        nprocs: int=1,
        cut_sil: bool=True,
        spk_emb_dir: Path=None,
        write_metadata_method: str='w', ):
    if nprocs == 1:
        results = []
        for fp in tqdm.tqdm(fps, total=len(fps)):
            record = process_sentence(
                config=config,
                fp=fp,
                sentences=sentences,
                output_dir=output_dir,
                mel_extractor=mel_extractor,
                pitch_extractor=pitch_extractor,
                energy_extractor=energy_extractor,
                cut_sil=cut_sil,
                spk_emb_dir=spk_emb_dir, )
            if record:
                results.append(record)
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp in fps:
                    future = pool.submit(
                        process_sentence,
                        config,
                        fp,
                        sentences,
                        output_dir,
                        mel_extractor,
                        pitch_extractor,
                        energy_extractor,
                        cut_sil,
                        spk_emb_dir, )
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                results = []
                for ft in futures:
                    record = ft.result()
                    if record:
                        results.append(record)

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl",
                        write_metadata_method) as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features.")

    parser.add_argument(
        "--dataset",
        default="opencpop",
        type=str,
        help="name of the dataset; only {opencpop} is supported now")

    parser.add_argument(
        "--rootdir", default=None, type=str, help="directory to the dataset.")

    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")

    parser.add_argument(
        "--label-file", default=None, type=str, help="path to the label file.")

    parser.add_argument("--config", type=str, help="diffsinger config file.")

    parser.add_argument(
        "--num-cpu", type=int, default=1, help="number of processes.")

    parser.add_argument(
        "--cut-sil",
        type=str2bool,
        default=True,
        help="whether to cut silence at the edges of the audio")

    parser.add_argument(
        "--spk_emb_dir",
        default=None,
        type=str,
        help="directory to speaker embedding files.")

    parser.add_argument(
        "--write_metadata_method",
        default="w",
        type=str,
        choices=["w", "a"],
        help="How the metadata.jsonl file is written.")
    args = parser.parse_args()

    rootdir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    # use absolute path
    dumpdir = dumpdir.resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)
    label_file = Path(args.label_file).expanduser()

    if args.spk_emb_dir:
        spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
    else:
        spk_emb_dir = None

    assert rootdir.is_dir()
    assert label_file.is_file()

    with open(args.config, 'rt') as f:
        config = CfgNode(yaml.safe_load(f))

    sentences, speaker_set = get_sentences_svs(
        label_file,
        dataset=args.dataset,
        sample_rate=config.fs,
        n_shift=config.n_shift, )

    phone_id_map_path = dumpdir / "phone_id_map.txt"
    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
    get_input_token(sentences, phone_id_map_path, args.dataset)
    get_spk_id_map(speaker_set, speaker_id_map_path)

    if args.dataset == "opencpop":
        wavdir = rootdir / "wavs"
        # split data into 3 sections
        train_file = rootdir / "train.txt"
        train_wav_files = []
        with open(train_file, "r") as f_train:
            for line in f_train.readlines():
                utt = line.split("|")[0]
                wav_name = utt + ".wav"
                wav_path = wavdir / wav_name
                train_wav_files.append(wav_path)

        test_file = rootdir / "test.txt"
        dev_wav_files = []
        test_wav_files = []
        num_dev = 106
        count = 0
        with open(test_file, "r") as f_test:
            for line in f_test.readlines():
                count += 1
                utt = line.split("|")[0]
                wav_name = utt + ".wav"
                wav_path = wavdir / wav_name
                if count > num_dev:
                    test_wav_files.append(wav_path)
                else:
                    dev_wav_files.append(wav_path)

    else:
        print("dataset should be in {opencpop} now!")

    train_dump_dir = dumpdir / "train" / "raw"
    train_dump_dir.mkdir(parents=True, exist_ok=True)
    dev_dump_dir = dumpdir / "dev" / "raw"
    dev_dump_dir.mkdir(parents=True, exist_ok=True)
    test_dump_dir = dumpdir / "test" / "raw"
    test_dump_dir.mkdir(parents=True, exist_ok=True)

    # Extractor
    mel_extractor = LogMelFBank(
        sr=config.fs,
        n_fft=config.n_fft,
        hop_length=config.n_shift,
        win_length=config.win_length,
        window=config.window,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)
    pitch_extractor = Pitch(
        sr=config.fs,
        hop_length=config.n_shift,
        f0min=config.f0min,
        f0max=config.f0max)
    energy_extractor = Energy(
        n_fft=config.n_fft,
        hop_length=config.n_shift,
        win_length=config.win_length,
        window=config.window)

    # process for the 3 sections
    if train_wav_files:
        process_sentences(
            config=config,
            fps=train_wav_files,
            sentences=sentences,
            output_dir=train_dump_dir,
            mel_extractor=mel_extractor,
            pitch_extractor=pitch_extractor,
            energy_extractor=energy_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir,
            write_metadata_method=args.write_metadata_method)
    if dev_wav_files:
        process_sentences(
            config=config,
            fps=dev_wav_files,
            sentences=sentences,
            output_dir=dev_dump_dir,
            mel_extractor=mel_extractor,
            pitch_extractor=pitch_extractor,
            energy_extractor=energy_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir,
            write_metadata_method=args.write_metadata_method)
    if test_wav_files:
        process_sentences(
            config=config,
            fps=test_wav_files,
            sentences=sentences,
            output_dir=test_dump_dir,
            mel_extractor=mel_extractor,
            pitch_extractor=pitch_extractor,
            energy_extractor=energy_extractor,
            nprocs=args.num_cpu,
            cut_sil=args.cut_sil,
            spk_emb_dir=spk_emb_dir,
            write_metadata_method=args.write_metadata_method)


if __name__ == "__main__":
    main()
@ -0,0 +1,257 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path

import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import AdamW
from paddle.optimizer.lr import StepDecay
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.diffsinger import DiffSinger
from paddlespeech.t2s.models.diffsinger import DiffSingerEvaluator
from paddlespeech.t2s.models.diffsinger import DiffSingerUpdater
from paddlespeech.t2s.models.diffsinger import DiffusionLoss
from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDILoss
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer


def train_sp(args, config):
    # decides device type and whether to run in parallel
    # setup running environment correctly
    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
        paddle.set_device("cpu")
    else:
        paddle.set_device("gpu")
    world_size = paddle.distributed.get_world_size()
    if world_size > 1:
        paddle.distributed.init_parallel_env()

    # set the random seed; it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )
    fields = [
        "text", "text_lengths", "speech", "speech_lengths", "durations",
        "pitch", "energy", "note", "note_dur", "is_slur"
    ]
    converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
    spk_num = None
    if args.speaker_dict is not None:
        print("multiple speaker diffsinger!")
        collate_fn = diffsinger_multi_spk_batch_fn
        with open(args.speaker_dict, 'rt') as f:
            spk_id = [line.strip().split() for line in f.readlines()]
        spk_num = len(spk_id)
        fields += ["spk_id"]
    else:
        collate_fn = diffsinger_single_spk_batch_fn
        print("single speaker diffsinger!")

    print("spk_num:", spk_num)

    # the dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct datasets for training and validation
    with jsonlines.open(args.train_metadata, 'r') as reader:
        train_metadata = list(reader)
    train_dataset = DataTable(
        data=train_metadata,
        fields=fields,
        converters=converters, )
    with jsonlines.open(args.dev_metadata, 'r') as reader:
        dev_metadata = list(reader)
    dev_dataset = DataTable(
        data=dev_metadata,
        fields=fields,
        converters=converters, )

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)

    print("samplers done!")

    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=collate_fn,
        num_workers=config.num_workers)

    dev_dataloader = DataLoader(
        dev_dataset,
        shuffle=False,
        drop_last=False,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")

    with open(args.phones_dict, "r") as f:
        phn_id = [line.strip().split() for line in f.readlines()]
    vocab_size = len(phn_id)
    print("vocab_size:", vocab_size)

    # load the per-dimension min/max of the mel spectrum used for diffusion stretching
    spec_min = np.load(args.speech_stretchs)[0]
    spec_max = np.load(args.speech_stretchs)[1]
    spec_min = paddle.to_tensor(spec_min)
    spec_max = paddle.to_tensor(spec_max)
    print("min and max spec done!")

    odim = config.n_mels
    config["model"]["fastspeech2_params"]["spk_num"] = spk_num
    model = DiffSinger(
        spec_min=spec_min,
        spec_max=spec_max,
        idim=vocab_size,
        odim=odim,
        **config["model"], )
    model_fs2 = model.fs2
    model_ds = model.diffusion
    if world_size > 1:
        model = DataParallel(model)
        model_fs2 = model._layers.fs2
        model_ds = model._layers.diffusion
    print("models done!")

    criterion_fs2 = FastSpeech2MIDILoss(**config["fs2_updater"])
    criterion_ds = DiffusionLoss(**config["ds_updater"])
    print("criterions done!")

    optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
    lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
    gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
    optimizer_ds = AdamW(
        learning_rate=lr_schedule_ds,
        grad_clip=gradient_clip_ds,
        parameters=model_ds.parameters(),
        **config["ds_optimizer_params"])
    print("optimizer done!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    if dist.get_rank() == 0:
        config_name = args.config.split("/")[-1]
        # copy conf to output_dir
        shutil.copyfile(args.config, output_dir / config_name)

    updater = DiffSingerUpdater(
        model=model,
        optimizers={
            "fs2": optimizer_fs2,
            "ds": optimizer_ds,
        },
        criterions={
            "fs2": criterion_fs2,
            "ds": criterion_ds,
        },
        dataloader=train_dataloader,
        ds_train_start_steps=config.ds_train_start_steps,
        output_dir=output_dir,
        only_train_diffusion=config["only_train_diffusion"])

    evaluator = DiffSingerEvaluator(
        model=model,
        criterions={
            "fs2": criterion_fs2,
            "ds": criterion_ds,
        },
        dataloader=dev_dataloader,
        output_dir=output_dir, )

    trainer = Trainer(
        updater,
        stop_trigger=(config.train_max_steps, "iteration"),
        out=output_dir, )

    if dist.get_rank() == 0:
        trainer.extend(
            evaluator, trigger=(config.eval_interval_steps, 'iteration'))
        trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots),
            trigger=(config.save_interval_steps, 'iteration'))

    print("Trainer Done!")
    trainer.run()


def main():
    # parse args and config and redirect to train_sp
    parser = argparse.ArgumentParser(description="Train a DiffSinger model.")
    parser.add_argument("--config", type=str, help="diffsinger config file.")
    parser.add_argument("--train-metadata", type=str, help="training data.")
    parser.add_argument("--dev-metadata", type=str, help="dev data.")
    parser.add_argument("--output-dir", type=str, help="output dir.")
    parser.add_argument(
        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
    parser.add_argument(
        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
    parser.add_argument(
        "--speaker-dict",
        type=str,
        default=None,
        help="speaker id map file for multiple speaker model.")
    parser.add_argument(
        "--speech-stretchs",
        type=str,
        help="The min and max values of the mel spectrum.")

    args = parser.parse_args()

    with open(args.config) as f:
        config = CfgNode(yaml.safe_load(f))

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.ngpu > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .diffsinger import *
from .diffsinger_updater import *
@ -0,0 +1,399 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""DiffSinger related modules for paddle"""
from typing import Any
from typing import Dict
from typing import Tuple

import numpy as np
import paddle
from paddle import nn
from typeguard import check_argument_types

from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI
from paddlespeech.t2s.modules.diffnet import DiffNet
from paddlespeech.t2s.modules.diffusion import GaussianDiffusion


class DiffSinger(nn.Layer):
    """DiffSinger module.

    This is a module of DiffSinger described in `DiffSinger: Singing Voice
    Synthesis via Shallow Diffusion Mechanism`_.

    .. _`DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism`:
        https://arxiv.org/pdf/2105.02446.pdf
    """

    def __init__(
            self,
            # min and max spec for stretching before diffusion
            spec_min: paddle.Tensor,
            spec_max: paddle.Tensor,
            # fastspeech2midi config
            idim: int,
            odim: int,
            use_energy_pred: bool=False,
            use_postnet: bool=False,
            # music score related
            note_num: int=300,
            is_slur_num: int=2,
            fastspeech2_params: Dict[str, Any]={
                "adim": 256,
                "aheads": 2,
                "elayers": 4,
                "eunits": 1024,
                "dlayers": 4,
                "dunits": 1024,
                "positionwise_layer_type": "conv1d",
                "positionwise_conv_kernel_size": 1,
                "use_scaled_pos_enc": True,
                "use_batch_norm": True,
                "encoder_normalize_before": True,
                "decoder_normalize_before": True,
                "encoder_concat_after": False,
                "decoder_concat_after": False,
                "reduction_factor": 1,
                # for transformer
                "transformer_enc_dropout_rate": 0.1,
                "transformer_enc_positional_dropout_rate": 0.1,
                "transformer_enc_attn_dropout_rate": 0.1,
                "transformer_dec_dropout_rate": 0.1,
                "transformer_dec_positional_dropout_rate": 0.1,
                "transformer_dec_attn_dropout_rate": 0.1,
                "transformer_activation_type": "gelu",
                # duration predictor
                "duration_predictor_layers": 2,
                "duration_predictor_chans": 384,
                "duration_predictor_kernel_size": 3,
                "duration_predictor_dropout_rate": 0.1,
                # pitch predictor
                "use_pitch_embed": True,
                "pitch_predictor_layers": 2,
                "pitch_predictor_chans": 384,
                "pitch_predictor_kernel_size": 3,
                "pitch_predictor_dropout": 0.5,
                "pitch_embed_kernel_size": 9,
                "pitch_embed_dropout": 0.5,
                "stop_gradient_from_pitch_predictor": False,
                # energy predictor
                "use_energy_embed": False,
                "energy_predictor_layers": 2,
                "energy_predictor_chans": 384,
                "energy_predictor_kernel_size": 3,
                "energy_predictor_dropout": 0.5,
                "energy_embed_kernel_size": 9,
                "energy_embed_dropout": 0.5,
                "stop_gradient_from_energy_predictor": False,
                # postnet
                "postnet_layers": 5,
                "postnet_chans": 512,
                "postnet_filts": 5,
                "postnet_dropout_rate": 0.5,
                # spk emb
                "spk_num": None,
                "spk_embed_dim": None,
                "spk_embed_integration_type": "add",
                # training related
                "init_type": "xavier_uniform",
                "init_enc_alpha": 1.0,
                "init_dec_alpha": 1.0,
                # speaker classifier
                "enable_speaker_classifier": False,
                "hidden_sc_dim": 256,
            },
            # denoiser config
            denoiser_params: Dict[str, Any]={
                "in_channels": 80,
                "out_channels": 80,
                "kernel_size": 3,
                "layers": 20,
                "stacks": 5,
                "residual_channels": 256,
                "gate_channels": 512,
                "skip_channels": 256,
                "aux_channels": 256,
                "dropout": 0.,
                "bias": True,
                "use_weight_norm": False,
                "init_type": "kaiming_normal",
            },
            # diffusion config
            diffusion_params: Dict[str, Any]={
                "num_train_timesteps": 100,
                "beta_start": 0.0001,
                "beta_end": 0.06,
                "beta_schedule": "squaredcos_cap_v2",
                "num_max_timesteps": 60,
                "stretch": True,
            }, ):
        """Initialize DiffSinger module.

        Args:
            spec_min (paddle.Tensor): The minimum value of the feature (mel) to stretch before diffusion.
            spec_max (paddle.Tensor): The maximum value of the feature (mel) to stretch before diffusion.
            idim (int): Dimension of the inputs (input vocabulary size).
            odim (int): Dimension of the outputs (acoustic feature dimension).
            use_energy_pred (bool, optional): Whether to use the energy predictor. Defaults to False.
            use_postnet (bool, optional): Whether to use the postnet. Defaults to False.
            note_num (int, optional): The number of notes. Defaults to 300.
            is_slur_num (int, optional): The number of slur classes. Defaults to 2.
            fastspeech2_params (Dict[str, Any]): Parameter dict for the fastspeech2 module.
            denoiser_params (Dict[str, Any]): Parameter dict for the denoiser module.
            diffusion_params (Dict[str, Any]): Parameter dict for the diffusion module.
        """
        assert check_argument_types()
        super().__init__()
        self.fs2 = FastSpeech2MIDI(
            idim=idim,
            odim=odim,
            fastspeech2_params=fastspeech2_params,
            note_num=note_num,
            is_slur_num=is_slur_num,
            use_energy_pred=use_energy_pred,
            use_postnet=use_postnet, )
        denoiser = DiffNet(**denoiser_params)
        self.diffusion = GaussianDiffusion(
            denoiser,
            **diffusion_params,
            min_values=spec_min,
            max_values=spec_max, )

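    # Construction sketch (illustrative, not part of the model): spec_min and
    # spec_max can be loaded from the dump/train/speech_stretchs.npy written by
    # preprocessing. The row order (min first, then max) and the idim value
    # below are assumptions:
    #
    #   stats = np.load("dump/train/speech_stretchs.npy")
    #   model = DiffSinger(
    #       spec_min=paddle.to_tensor(stats[0]),
    #       spec_max=paddle.to_tensor(stats[1]),
    #       idim=62,   # hypothetical phone vocabulary size
    #       odim=80)   # mel bins, matching the denoiser defaults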
    def forward(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            spk_emb: paddle.Tensor=None,
            spk_id: paddle.Tensor=None,
            only_train_fs2: bool=True,
    ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            text(Tensor(int64)):
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur (element in music score) ids (B, Tmax).
            text_lengths(Tensor(int64)):
                Batch of phone lengths of each input (B,).
            speech(Tensor(float32)):
                Batch of padded target features (e.g. mel) (B, Lmax, odim).
            speech_lengths(Tensor(int64)):
                Batch of the lengths of each target features (B,).
            durations(Tensor(int64)):
                Batch of padded token durations in frame (B, Tmax).
            pitch(Tensor(float32)):
                Batch of padded frame-averaged pitch (B, Lmax, 1).
            energy(Tensor(float32)):
                Batch of padded frame-averaged energy (B, Lmax, 1).
            spk_emb(Tensor(float32), optional):
                Batch of speaker embeddings (B, spk_embed_dim).
            spk_id(Tensor(int64), optional):
                Batch of speaker ids (B,).
            only_train_fs2(bool):
                Whether to train only the fastspeech2 module.

        Returns:
            The fastspeech2 outputs when `only_train_fs2` is True, otherwise the
            predicted noise, the target noise and the mel masks of the diffusion module.
        """
        # only train fastspeech2 module firstly
        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.fs2(
            text=text,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            text_lengths=text_lengths,
            speech=speech,
            speech_lengths=speech_lengths,
            durations=durations,
            pitch=pitch,
            energy=energy,
            spk_id=spk_id,
            spk_emb=spk_emb)
        if only_train_fs2:
            return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits

        # get the encoder output from fastspeech2 as the condition of denoiser module
        cond_fs2, mel_masks = self.fs2.encoder_infer_batch(
            text=text,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            text_lengths=text_lengths,
            speech_lengths=speech_lengths,
            ds=durations,
            ps=pitch,
            es=energy)
        cond_fs2 = cond_fs2.transpose((0, 2, 1))

        # get the output(final mel) from diffusion module
        noise_pred, noise_target = self.diffusion(
            speech.transpose((0, 2, 1)), cond_fs2)
        return noise_pred, noise_target, mel_masks

    def inference(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            get_mel_fs2: bool=False, ):
        """Run inference.

        Args:
            text(Tensor(int64)):
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur (element in music score) ids (B, Tmax).
            get_mel_fs2 (bool, optional):
                Whether to return the mel from the fastspeech2 module only. Defaults to False.

        Returns:
            Tensor: Output mel (T, odim).
        """
        mel_fs2, _, _, _ = self.fs2.inference(text, note, note_dur, is_slur)
        if get_mel_fs2:
            return mel_fs2
        mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1))
        cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur)
        cond_fs2 = cond_fs2.transpose((0, 2, 1))
        noise = paddle.randn(mel_fs2.shape)
        # shallow diffusion: denoising is guided by the fastspeech2 mel (ref_x)
        mel = self.diffusion.inference(
            noise=noise,
            cond=cond_fs2,
            ref_x=mel_fs2,
            scheduler_type="ddpm",
            num_inference_steps=60)
        mel = mel.transpose((0, 2, 1))
        return mel[0]


class DiffSingerInference(nn.Layer):
    def __init__(self, normalizer, model):
        super().__init__()
        self.normalizer = normalizer
        self.acoustic_model = model

    def forward(self, text, note, note_dur, is_slur, get_mel_fs2: bool=False):
        """Calculate forward propagation.

        Args:
            text(Tensor(int64)):
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur (element in music score) ids (B, Tmax).
            get_mel_fs2 (bool, optional):
                Whether to return the mel from the fastspeech2 module only. Defaults to False.

        Returns:
            logmel(Tensor(float32)): denorm logmel, [T, mel_bin]
        """
        normalized_mel = self.acoustic_model.inference(
            text=text,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            get_mel_fs2=get_mel_fs2)
        logmel = normalized_mel
        return logmel


class DiffusionLoss(nn.Layer):
    """Loss function module for the diffusion module of DiffSinger."""

    def __init__(self, use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize diffusion loss module.

        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.
        """
        assert check_argument_types()
        super().__init__()

        assert (use_masking != use_weighted_masking) or not use_masking
        self.use_masking = use_masking
        self.use_weighted_masking = use_weighted_masking

        # define criterions
        reduction = "none" if self.use_weighted_masking else "mean"
        self.l1_criterion = nn.L1Loss(reduction=reduction)

    def forward(
            self,
            noise_pred: paddle.Tensor,
            noise_target: paddle.Tensor,
            mel_masks: paddle.Tensor, ) -> paddle.Tensor:
        """Calculate forward propagation.

        Args:
            noise_pred(Tensor):
                Batch of predicted noise (B, Lmax, odim).
            noise_target(Tensor):
                Batch of target noise (B, Lmax, odim).
            mel_masks(Tensor):
                Batch of mask of real mel (B, Lmax, 1).

        Returns:
            Tensor: L1 loss value.
        """
        # apply mask to remove padded part
        if self.use_masking:
            noise_pred = noise_pred.masked_select(
                mel_masks.broadcast_to(noise_pred.shape))
            noise_target = noise_target.masked_select(
                mel_masks.broadcast_to(noise_target.shape))

        # calculate loss
        l1_loss = self.l1_criterion(noise_pred, noise_target)

        # make weighted mask and apply it
        if self.use_weighted_masking:
            mel_masks = mel_masks.unsqueeze(-1)
            out_weights = mel_masks.cast(dtype=paddle.float32) / mel_masks.cast(
                dtype=paddle.float32).sum(
                    axis=1, keepdim=True)
            out_weights /= noise_target.shape[0] * noise_target.shape[2]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                mel_masks.broadcast_to(l1_loss.shape)).sum()

        return l1_loss
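A quick usage sketch of `DiffusionLoss` with masking enabled, on dummy tensors shaped as in the docstring (all values are placeholders):

```python
import paddle

criterion = DiffusionLoss(use_masking=True)
noise_pred = paddle.randn([2, 50, 80])             # (B, Lmax, odim)
noise_target = paddle.randn([2, 50, 80])
mel_masks = paddle.ones([2, 50, 1], dtype='bool')  # True on real (non-padded) frames
loss = criterion(
    noise_pred=noise_pred,
    noise_target=noise_target,
    mel_masks=mel_masks)
print(float(loss))  # mean L1 distance over the selected frames
```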
@ -0,0 +1,302 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict

import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer

from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
    datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class DiffSingerUpdater(StandardUpdater):
    def __init__(self,
                 model: Layer,
                 optimizers: Dict[str, Optimizer],
                 criterions: Dict[str, Layer],
                 dataloader: DataLoader,
                 ds_train_start_steps: int=160000,
                 output_dir: Path=None,
                 only_train_diffusion: bool=True):
        super().__init__(model, optimizers, dataloader, init_state=None)
        self.model = model._layers if isinstance(model,
                                                 paddle.DataParallel) else model
        self.only_train_diffusion = only_train_diffusion

        self.optimizers = optimizers
        self.optimizer_fs2: Optimizer = optimizers['fs2']
        self.optimizer_ds: Optimizer = optimizers['ds']

        self.criterions = criterions
        self.criterion_fs2 = criterions['fs2']
        self.criterion_ds = criterions['ds']

        self.dataloader = dataloader

        # iteration at which training switches from fastspeech2 to diffusion
        self.ds_train_start_steps = ds_train_start_steps

        self.state = UpdaterState(iteration=0, epoch=0)
        self.train_iterator = iter(self.dataloader)

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def update_core(self, batch):
        self.msg = "Rank: {}, ".format(dist.get_rank())
        losses_dict = {}
        # spk_id!=None in multiple spk diffsinger
        spk_id = batch["spk_id"] if "spk_id" in batch else None
        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
        # No explicit speaker identifier labels are used during voice cloning training.
        if spk_emb is not None:
            spk_id = None

        # only train fastspeech2 module firstly
        if self.state.iteration < self.ds_train_start_steps:
            before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
                text=batch["text"],
                note=batch["note"],
                note_dur=batch["note_dur"],
                is_slur=batch["is_slur"],
                text_lengths=batch["text_lengths"],
                speech=batch["speech"],
                speech_lengths=batch["speech_lengths"],
                durations=batch["durations"],
                pitch=batch["pitch"],
                energy=batch["energy"],
                spk_id=spk_id,
                spk_emb=spk_emb,
                only_train_fs2=True, )

            l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2(
                after_outs=after_outs,
                before_outs=before_outs,
                d_outs=d_outs,
                p_outs=p_outs,
                e_outs=e_outs,
                ys=ys,
                ds=batch["durations"],
                ps=batch["pitch"],
                es=batch["energy"],
                ilens=batch["text_lengths"],
                olens=olens,
                spk_logits=spk_logits,
                spk_ids=spk_id, )

            loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss + speaker_loss

            self.optimizer_fs2.clear_grad()
            loss_fs2.backward()
            self.optimizer_fs2.step()

            report("train/loss_fs2", float(loss_fs2))
            report("train/l1_loss_fs2", float(l1_loss_fs2))
            report("train/ssim_loss_fs2", float(ssim_loss_fs2))
            report("train/duration_loss", float(duration_loss))
            report("train/pitch_loss", float(pitch_loss))

            losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
            losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
            losses_dict["duration_loss"] = float(duration_loss)
            losses_dict["pitch_loss"] = float(pitch_loss)

            if speaker_loss != 0.:
                report("train/speaker_loss", float(speaker_loss))
                losses_dict["speaker_loss"] = float(speaker_loss)
            if energy_loss != 0.:
                report("train/energy_loss", float(energy_loss))
                losses_dict["energy_loss"] = float(energy_loss)

            losses_dict["loss_fs2"] = float(loss_fs2)
            self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in losses_dict.items())

        # Then only train diffusion module, freeze fastspeech2 parameters.
        if self.state.iteration > self.ds_train_start_steps:
            for param in self.model.fs2.parameters():
                param.trainable = not self.only_train_diffusion

            noise_pred, noise_target, mel_masks = self.model(
                text=batch["text"],
                note=batch["note"],
                note_dur=batch["note_dur"],
                is_slur=batch["is_slur"],
                text_lengths=batch["text_lengths"],
                speech=batch["speech"],
                speech_lengths=batch["speech_lengths"],
                durations=batch["durations"],
                pitch=batch["pitch"],
                energy=batch["energy"],
                spk_id=spk_id,
                spk_emb=spk_emb,
                only_train_fs2=False, )

            noise_pred = noise_pred.transpose((0, 2, 1))
            noise_target = noise_target.transpose((0, 2, 1))
            mel_masks = mel_masks.transpose((0, 2, 1))
            l1_loss_ds = self.criterion_ds(
                noise_pred=noise_pred,
                noise_target=noise_target,
                mel_masks=mel_masks, )

            loss_ds = l1_loss_ds

            self.optimizer_ds.clear_grad()
            loss_ds.backward()
            self.optimizer_ds.step()

            report("train/loss_ds", float(loss_ds))
            report("train/l1_loss_ds", float(l1_loss_ds))
            losses_dict["l1_loss_ds"] = float(l1_loss_ds)
            losses_dict["loss_ds"] = float(loss_ds)
            self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                                  for k, v in losses_dict.items())

        self.logger.info(self.msg)

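# Note on the schedule implemented above: iterations below ds_train_start_steps
# update only the fastspeech2 branch; iterations above it update only the
# diffusion branch (fs2 frozen when only_train_diffusion is True). Both
# comparisons are strict, so the iteration equal to ds_train_start_steps
# updates neither branch. Illustrative trace with the default of 160000:
#
#   for it in (1, 159999, 160000, 160001):
#       train_fs2 = it < 160000   # True, True, False, False
#       train_ds = it > 160000    # False, False, False, True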
class DiffSingerEvaluator(StandardEvaluator):
    def __init__(
            self,
            model: Layer,
            criterions: Dict[str, Layer],
            dataloader: DataLoader,
            output_dir: Path=None, ):
        super().__init__(model, dataloader)
        self.model = model._layers if isinstance(model,
                                                 paddle.DataParallel) else model

        self.criterions = criterions
        self.criterion_fs2 = criterions['fs2']
        self.criterion_ds = criterions['ds']
        self.dataloader = dataloader

        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
        self.filehandler = logging.FileHandler(str(log_file))
        logger.addHandler(self.filehandler)
        self.logger = logger
        self.msg = ""

    def evaluate_core(self, batch):
        self.msg = "Evaluate: "
        losses_dict = {}
        # spk_id!=None in multiple spk diffsinger
        spk_id = batch["spk_id"] if "spk_id" in batch else None
        spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
        if spk_emb is not None:
            spk_id = None

        # fastspeech2 eval
        before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
            text=batch["text"],
            note=batch["note"],
            note_dur=batch["note_dur"],
            is_slur=batch["is_slur"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            durations=batch["durations"],
            pitch=batch["pitch"],
            energy=batch["energy"],
            spk_id=spk_id,
            spk_emb=spk_emb,
            only_train_fs2=True, )

        l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=batch["durations"],
            ps=batch["pitch"],
            es=batch["energy"],
            ilens=batch["text_lengths"],
            olens=olens,
            spk_logits=spk_logits,
            spk_ids=spk_id, )

        loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss + speaker_loss

        report("eval/loss_fs2", float(loss_fs2))
        report("eval/l1_loss_fs2", float(l1_loss_fs2))
        report("eval/ssim_loss_fs2", float(ssim_loss_fs2))
        report("eval/duration_loss", float(duration_loss))
        report("eval/pitch_loss", float(pitch_loss))

        losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
        losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
        losses_dict["duration_loss"] = float(duration_loss)
        losses_dict["pitch_loss"] = float(pitch_loss)

        if speaker_loss != 0.:
            report("eval/speaker_loss", float(speaker_loss))
            losses_dict["speaker_loss"] = float(speaker_loss)
        if energy_loss != 0.:
            report("eval/energy_loss", float(energy_loss))
            losses_dict["energy_loss"] = float(energy_loss)

        losses_dict["loss_fs2"] = float(loss_fs2)

        # diffusion eval
        noise_pred, noise_target, mel_masks = self.model(
            text=batch["text"],
            note=batch["note"],
            note_dur=batch["note_dur"],
            is_slur=batch["is_slur"],
            text_lengths=batch["text_lengths"],
            speech=batch["speech"],
            speech_lengths=batch["speech_lengths"],
            durations=batch["durations"],
            pitch=batch["pitch"],
            energy=batch["energy"],
            spk_id=spk_id,
            spk_emb=spk_emb,
            only_train_fs2=False, )

        noise_pred = noise_pred.transpose((0, 2, 1))
        noise_target = noise_target.transpose((0, 2, 1))
        mel_masks = mel_masks.transpose((0, 2, 1))
        l1_loss_ds = self.criterion_ds(
            noise_pred=noise_pred,
            noise_target=noise_target,
            mel_masks=mel_masks, )

        loss_ds = l1_loss_ds

        report("eval/loss_ds", float(loss_ds))
        report("eval/l1_loss_ds", float(l1_loss_ds))
        losses_dict["l1_loss_ds"] = float(l1_loss_ds)
        losses_dict["loss_ds"] = float(loss_ds)
        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
                              for k, v in losses_dict.items())

        self.logger.info(self.msg)
@ -0,0 +1,654 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Sequence
from typing import Tuple

import paddle
from paddle import nn
from typeguard import check_argument_types

from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.modules.losses import ssim
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask


class FastSpeech2MIDI(FastSpeech2):
    """The FastSpeech2 module of DiffSinger."""

    def __init__(
            self,
            # fastspeech2 network structure related
            idim: int,
            odim: int,
            fastspeech2_params: Dict[str, Any],
            # note emb
            note_num: int=300,
            # is_slur emb
            is_slur_num: int=2,
            use_energy_pred: bool=False,
            use_postnet: bool=False, ):
        """Initialize FastSpeech2 module for svs.

        Args:
            fastspeech2_params (Dict):
                The config of the FastSpeech2 module on the DiffSinger model.
            note_num (Optional[int]):
                Number of notes. If not None, assume that the
                note_ids will be provided as the input and use note_embedding_table.
            is_slur_num (Optional[int]):
                Number of slur classes. If not None, assume that the
                is_slur_ids will be provided as the input.
        """
        assert check_argument_types()
        super().__init__(idim=idim, odim=odim, **fastspeech2_params)
        self.use_energy_pred = use_energy_pred
        self.use_postnet = use_postnet
        if not self.use_postnet:
            self.postnet = None

        self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
            "adim"]

        # note embed
        self.note_embedding_table = nn.Embedding(
            num_embeddings=note_num,
            embedding_dim=self.note_embed_dim,
            padding_idx=self.padding_idx)
        self.note_dur_layer = nn.Linear(1, self.note_embed_dim)

        # slur embed
        self.is_slur_embedding_table = nn.Embedding(
            num_embeddings=is_slur_num,
            embedding_dim=self.is_slur_embed_dim,
            padding_idx=self.padding_idx)

    def forward(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            durations: paddle.Tensor,
            pitch: paddle.Tensor,
            energy: paddle.Tensor,
            spk_emb: paddle.Tensor=None,
            spk_id: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
        """Calculate forward propagation.

        Args:
            text(Tensor(int64)):
                Batch of padded token (phone) ids (B, Tmax).
            note(Tensor(int64)):
                Batch of padded note (element in music score) ids (B, Tmax).
            note_dur(Tensor(float32)):
                Batch of padded note durations in seconds (element in music score) (B, Tmax).
            is_slur(Tensor(int64)):
                Batch of padded slur (element in music score) ids (B, Tmax).
            text_lengths(Tensor(int64)):
                Batch of phone lengths of each input (B,).
            speech(Tensor(float32)):
                Batch of padded target features (e.g. mel) (B, Lmax, odim).
            speech_lengths(Tensor(int64)):
                Batch of the lengths of each target features (B,).
            durations(Tensor(int64)):
                Batch of padded token durations in frame (B, Tmax).
            pitch(Tensor(float32)):
                Batch of padded frame-averaged pitch (B, Lmax, 1).
            energy(Tensor(float32)):
                Batch of padded frame-averaged energy (B, Lmax, 1).
            spk_emb(Tensor(float32), optional):
                Batch of speaker embeddings (B, spk_embed_dim).
            spk_id(Tensor(int64), optional):
                Batch of speaker ids (B,).

        Returns:
            Decoder outputs before and after the postnet, the predicted
            durations, pitch and energy, the (possibly trimmed) targets,
            their lengths and the speaker logits.
        """
        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        olens = paddle.cast(speech_lengths, 'int64')
        ds = paddle.cast(durations, 'int64')
        ps = pitch
        es = energy
        ys = speech
        olens = speech_lengths
        if spk_id is not None:
            spk_id = paddle.cast(spk_id, 'int64')
        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            is_inference=False,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens - olens % self.reduction_factor
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits

    def _forward(
            self,
            xs: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor=None,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            is_inference: bool=False,
            is_train_diffusion: bool=False,
            return_after_enc=False,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Sequence[paddle.Tensor]:

        before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
        # forward encoder
        masks = self._source_mask(ilens)
        note_emb = self.note_embedding_table(note)
        note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
        is_slur_emb = self.is_slur_embedding_table(is_slur)

        # (B, Tmax, adim)
        hs, _ = self.encoder(
            xs=xs,
            masks=masks,
            note_emb=note_emb,
            note_dur_emb=note_dur_emb,
            is_slur_emb=is_slur_emb, )

        if self.spk_num and self.enable_speaker_classifier and not is_inference:
            hs_for_spk_cls = self.grad_reverse(hs)
            spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
        else:
            spk_logits = None

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            # spk_emb has a higher priority than spk_id
            if spk_emb is not None:
                hs = self._integrate_with_spk_embed(hs, spk_emb)
            elif spk_id is not None:
                spk_emb = self.spk_embedding_table(spk_id)
                hs = self._integrate_with_spk_embed(hs, spk_emb)

        # forward duration predictor (phone-level) and variance predictors (frame-level)
        d_masks = make_pad_mask(ilens)
        if olens is not None:
            pitch_masks = make_pad_mask(olens).unsqueeze(-1)
        else:
            pitch_masks = None

        # inference for decoder input for diffusion
        if is_train_diffusion:
            hs = self.length_regulator(hs, ds, is_inference=False)
            p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs
            if self.use_energy_pred:
                e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        elif is_inference:
            # (B, Tmax)
            if ds is not None:
                d_outs = ds
            else:
                d_outs = self.duration_predictor.inference(hs, d_masks)

            # (B, Lmax, adim)
            hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)

            if ps is not None:
                p_outs = ps
            else:
                if self.stop_gradient_from_pitch_predictor:
                    p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
                else:
                    p_outs = self.pitch_predictor(hs, pitch_masks)
            p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if es is not None:
                    e_outs = es
                else:
                    if self.stop_gradient_from_energy_predictor:
                        e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                    else:
                        e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(
                    e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
                hs += e_embs

        # training
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # (B, Lmax, adim)
            hs = self.length_regulator(hs, ds, is_inference=False)
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
            else:
                p_outs = self.pitch_predictor(hs, pitch_masks)
            p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
                (0, 2, 1))
            hs += p_embs

            if self.use_energy_pred:
                if self.stop_gradient_from_energy_predictor:
                    e_outs = self.energy_predictor(hs.detach(), pitch_masks)
                else:
                    e_outs = self.energy_predictor(hs, pitch_masks)
                e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
                    (0, 2, 1))
                hs += e_embs

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = paddle.to_tensor(
                    [olen // self.reduction_factor for olen in olens.numpy()])
            else:
                olens_in = olens
            # (B, 1, T)
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None

        if return_after_enc:
            return hs, h_masks

        if self.decoder_type == 'cnndecoder':
            # remove output masks for dygraph to static graph
            zs = self.decoder(hs, h_masks)
            before_outs = zs
        else:
            # (B, Lmax, adim)
            zs, _ = self.decoder(hs, h_masks)
            # (B, Lmax, odim)
            before_outs = self.feat_out(zs).reshape(
                (paddle.shape(zs)[0], -1, self.odim))

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))

        return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits

    def encoder_infer(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        # setup batch axis
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, odim)
        # use *_ to avoid bug in dygraph to static graph
        hs, _ = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            is_inference=True,
            return_after_enc=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs

    # get encoder output for diffusion training
    def encoder_infer_batch(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            text_lengths: paddle.Tensor,
            speech_lengths: paddle.Tensor,
            ds: paddle.Tensor=None,
            ps: paddle.Tensor=None,
            es: paddle.Tensor=None,
            alpha: float=1.0,
            spk_emb=None,
            spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:

        xs = paddle.cast(text, 'int64')
        note = paddle.cast(note, 'int64')
        note_dur = paddle.cast(note_dur, 'float32')
        is_slur = paddle.cast(is_slur, 'int64')
        ilens = paddle.cast(text_lengths, 'int64')
        olens = paddle.cast(speech_lengths, 'int64')

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        # (1, L, odim)
        # use *_ to avoid bug in dygraph to static graph
        hs, h_masks = self._forward(
            xs=xs,
            note=note,
            note_dur=note_dur,
            is_slur=is_slur,
            ilens=ilens,
            olens=olens,
            ds=ds,
            ps=ps,
            es=es,
            return_after_enc=True,
            is_train_diffusion=True,
            alpha=alpha,
            spk_emb=spk_emb,
            spk_id=spk_id, )
        return hs, h_masks

    def inference(
            self,
            text: paddle.Tensor,
            note: paddle.Tensor,
            note_dur: paddle.Tensor,
            is_slur: paddle.Tensor,
            durations: paddle.Tensor=None,
            pitch: paddle.Tensor=None,
            energy: paddle.Tensor=None,
            alpha: float=1.0,
            use_teacher_forcing: bool=False,
            spk_emb=None,
            spk_id=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text(Tensor(int64)):
                Input sequence of characters (T,).
            note(Tensor(int64)):
                Input note (element in music score) ids (T,).
            note_dur(Tensor(float32)):
                Input note durations in seconds (element in music score) (T,).
            is_slur(Tensor(int64)):
                Input slur (element in music score) ids (T,).
            durations(Tensor(int64), optional):
                Groundtruth of duration (T,).
            pitch(Tensor, optional):
                Groundtruth of token-averaged pitch (T, 1).
            energy(Tensor, optional):
                Groundtruth of token-averaged energy (T, 1).
            alpha(float, optional):
                Alpha to control the speed.
            use_teacher_forcing(bool, optional):
                Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.
            spk_emb(Tensor, optional):
                Speaker embedding vector (spk_embed_dim,). (Default value = None)
            spk_id(Tensor(int64), optional):
                spk ids (1,). (Default value = None)

        Returns:
            The output mel (Lmax, odim) and the predicted durations, pitch and
            energy of the single element in the batch.
        """
        xs = paddle.cast(text, 'int64').unsqueeze(0)
        note = paddle.cast(note, 'int64').unsqueeze(0)
        note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
        is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
        d, p, e = durations, pitch, energy
        # setup batch axis
        ilens = paddle.shape(xs)[1]

        if spk_emb is not None:
            spk_emb = spk_emb.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds = d.unsqueeze(0) if d is not None else None
            ps = p.unsqueeze(0) if p is not None else None
            es = e.unsqueeze(0) if e is not None else None

            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                ds=ds,
                ps=ps,
                es=es,
                spk_emb=spk_emb,
                spk_id=spk_id,
                is_inference=True)
        else:
            # (1, L, odim)
            _, outs, d_outs, p_outs, e_outs, _ = self._forward(
                xs=xs,
                note=note,
                note_dur=note_dur,
                is_slur=is_slur,
                ilens=ilens,
                is_inference=True,
                alpha=alpha,
                spk_emb=spk_emb,
                spk_id=spk_id, )

        if e_outs is None:
            e_outs = [None]

        return outs[0], d_outs[0], p_outs[0], e_outs[0]


class FastSpeech2MIDILoss(FastSpeech2Loss):
    """Loss function module for DiffSinger."""

    def __init__(self, use_masking: bool=True,
                 use_weighted_masking: bool=False):
        """Initialize FastSpeech2MIDI loss module.

        Args:
            use_masking (bool):
                Whether to apply masking for padded part in loss calculation.
            use_weighted_masking (bool):
                Whether to apply weighted masking in loss calculation.
        """
        assert check_argument_types()
        super().__init__(use_masking, use_weighted_masking)

    def forward(
            self,
            after_outs: paddle.Tensor,
            before_outs: paddle.Tensor,
            d_outs: paddle.Tensor,
            p_outs: paddle.Tensor,
            e_outs: paddle.Tensor,
            ys: paddle.Tensor,
            ds: paddle.Tensor,
            ps: paddle.Tensor,
            es: paddle.Tensor,
            ilens: paddle.Tensor,
            olens: paddle.Tensor,
            spk_logits: paddle.Tensor=None,
            spk_ids: paddle.Tensor=None,
    ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
               paddle.Tensor, ]:
        """Calculate forward propagation.

        Args:
            after_outs(Tensor):
                Batch of outputs after postnets (B, Lmax, odim).
            before_outs(Tensor):
                Batch of outputs before postnets (B, Lmax, odim).
            d_outs(Tensor):
                Batch of outputs of duration predictor (B, Tmax).
            p_outs(Tensor):
                Batch of outputs of pitch predictor (B, Lmax, 1).
            e_outs(Tensor):
                Batch of outputs of energy predictor (B, Lmax, 1).
            ys(Tensor):
                Batch of target features (B, Lmax, odim).
            ds(Tensor):
                Batch of durations (B, Tmax).
            ps(Tensor):
                Batch of target frame-averaged pitch (B, Lmax, 1).
            es(Tensor):
                Batch of target frame-averaged energy (B, Lmax, 1).
            ilens(Tensor):
                Batch of the lengths of each input (B,).
            olens(Tensor):
                Batch of the lengths of each target (B,).
            spk_logits(Optional[Tensor]):
                Batch of outputs after speaker classifier (B, Lmax, num_spk).
            spk_ids(Optional[Tensor]):
                Batch of target spk_id (B,).

        Returns:
            L1 loss, SSIM loss, duration loss, pitch loss, energy loss and
            speaker loss values.
        """
        l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0

        # apply mask to remove padded part
        if self.use_masking:
            # make feature for ssim loss
            out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
            before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
            if not paddle.equal_all(after_outs, before_outs):
                after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
            ys_ssim = masked_fill(ys, out_pad_masks, 0.0)

            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            before_outs = before_outs.masked_select(
                out_masks.broadcast_to(before_outs.shape))
            if not paddle.equal_all(after_outs, before_outs):
                after_outs = after_outs.masked_select(
                    out_masks.broadcast_to(after_outs.shape))
            ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
            duration_masks = make_non_pad_mask(ilens)
            d_outs = d_outs.masked_select(
                duration_masks.broadcast_to(d_outs.shape))
            ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
            pitch_masks = out_masks
            p_outs = p_outs.masked_select(
                pitch_masks.broadcast_to(p_outs.shape))
            ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
            if e_outs is not None:
                e_outs = e_outs.masked_select(
                    pitch_masks.broadcast_to(e_outs.shape))
                es = es.masked_select(pitch_masks.broadcast_to(es.shape))

            if spk_logits is not None and spk_ids is not None:
                batch_size = spk_ids.shape[0]
                spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
                                                   None)
                spk_logits = paddle.reshape(spk_logits,
                                            [-1, spk_logits.shape[-1]])
                mask_index = spk_logits.abs().sum(axis=1) != 0
                spk_ids = spk_ids[mask_index]
                spk_logits = spk_logits[mask_index]

        # calculate loss
        l1_loss = self.l1_criterion(before_outs, ys)
        ssim_loss = 1.0 - ssim(
            before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
        if not paddle.equal_all(after_outs, before_outs):
            l1_loss += self.l1_criterion(after_outs, ys)
            ssim_loss += (
                1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
        l1_loss = l1_loss * 0.5
        ssim_loss = ssim_loss * 0.5

        duration_loss = self.duration_criterion(d_outs, ds)
        pitch_loss = self.l1_criterion(p_outs, ps)
        if e_outs is not None:
            energy_loss = self.l1_criterion(e_outs, es)

        if spk_logits is not None and spk_ids is not None:
            speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size

        # make weighted mask and apply it
        if self.use_weighted_masking:
            out_masks = make_non_pad_mask(olens).unsqueeze(-1)
            out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
                dtype=paddle.float32).sum(
                    axis=1, keepdim=True)
            out_weights /= ys.shape[0] * ys.shape[2]
            duration_masks = make_non_pad_mask(ilens)
            duration_weights = (duration_masks.cast(dtype=paddle.float32) /
                                duration_masks.cast(dtype=paddle.float32).sum(
                                    axis=1, keepdim=True))
            duration_weights /= ds.shape[0]

            # apply weight
            l1_loss = l1_loss.multiply(out_weights)
            l1_loss = l1_loss.masked_select(
                out_masks.broadcast_to(l1_loss.shape)).sum()
            ssim_loss = ssim_loss.multiply(out_weights)
            ssim_loss = ssim_loss.masked_select(
                out_masks.broadcast_to(ssim_loss.shape)).sum()
            duration_loss = (duration_loss.multiply(duration_weights)
                             .masked_select(duration_masks).sum())
            pitch_masks = out_masks
            pitch_weights = out_weights
            pitch_loss = pitch_loss.multiply(pitch_weights)
            pitch_loss = pitch_loss.masked_select(
                pitch_masks.broadcast_to(pitch_loss.shape)).sum()
            if e_outs is not None:
                energy_loss = energy_loss.multiply(pitch_weights)
                energy_loss = energy_loss.masked_select(
                    pitch_masks.broadcast_to(energy_loss.shape)).sum()

        return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
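The masking in both loss modules relies on `make_non_pad_mask` marking valid frames as True; a small sketch of the select-and-flatten pattern used above (shapes are illustrative):

```python
import paddle

from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask

olens = paddle.to_tensor([3, 2])                    # valid frames per utterance
out_masks = make_non_pad_mask(olens).unsqueeze(-1)  # (B, Lmax, 1), True on valid frames
ys = paddle.rand([2, 3, 4])                         # (B, Lmax, odim)
flat = ys.masked_select(out_masks.broadcast_to(ys.shape))
print(flat.shape)  # [20]: (3 + 2) valid frames x 4 feature bins
```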
@ -0,0 +1,245 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import paddle
import paddle.nn.functional as F
from paddle import nn

from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_normal_
from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import uniform_
from paddlespeech.utils.initialize import zeros_


def Conv1D(*args, **kwargs):
    layer = nn.Conv1D(*args, **kwargs)
    # Initialize the weight to be consistent with the official implementation
    kaiming_normal_(layer.weight)

    # Initialization is consistent with torch
    if layer.bias is not None:
        fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
        if fan_in != 0:
            bound = 1 / math.sqrt(fan_in)
            uniform_(layer.bias, -bound, bound)
    return layer


# Initialization is consistent with torch
def Linear(*args, **kwargs):
    layer = nn.Linear(*args, **kwargs)
    kaiming_uniform_(layer.weight, a=math.sqrt(5))
    if layer.bias is not None:
        fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        uniform_(layer.bias, -bound, bound)
    return layer


class ResidualBlock(nn.Layer):
    """ResidualBlock

    Args:
        encoder_hidden (int, optional):
            Input feature size of the 1D convolution, by default 256
        residual_channels (int, optional):
            Feature size of the residual output (and also the input), by default 256
        gate_channels (int, optional):
            Output feature size of the 1D convolution, by default 512
        kernel_size (int, optional):
            Kernel size of the 1D convolution, by default 3
        dilation (int, optional):
            Dilation of the 1D convolution, by default 4
    """

    def __init__(self,
                 encoder_hidden: int=256,
                 residual_channels: int=256,
                 gate_channels: int=512,
                 kernel_size: int=3,
                 dilation: int=4):
        super().__init__()
        self.dilated_conv = Conv1D(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=dilation,
            dilation=dilation)
        self.diffusion_projection = Linear(residual_channels, residual_channels)
        self.conditioner_projection = Conv1D(encoder_hidden, gate_channels, 1)
        self.output_projection = Conv1D(residual_channels, gate_channels, 1)

    def forward(
            self,
            x: paddle.Tensor,
            diffusion_step: paddle.Tensor,
            cond: paddle.Tensor, ):
        """Calculate forward propagation.

        Args:
            x (Tensor(float32)): input feature. (B, residual_channels, T)
            diffusion_step (Tensor(int64)): The timestep input (adding noise step). (B,)
            cond (Tensor(float32)): The auxiliary input (e.g. fastspeech2 encoder output). (B, encoder_hidden, T)

        Returns:
            Tuple of the residual output (B, residual_channels, T) and the
            skip connection (B, residual_channels, T).
        """
        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
        cond = self.conditioner_projection(cond)
        y = x + diffusion_step

        y = self.dilated_conv(y) + cond

        # gated activation unit, as in WaveNet
        gate, filter = paddle.chunk(y, 2, axis=1)
        y = F.sigmoid(gate) * paddle.tanh(filter)

        y = self.output_projection(y)
        residual, skip = paddle.chunk(y, 2, axis=1)
        return (x + residual) / math.sqrt(2.0), skip


class SinusoidalPosEmb(nn.Layer):
    """Sinusoidal positional embedding for diffusion timesteps."""

    def __init__(self, dim: int=256):
        super().__init__()
        self.dim = dim

    def forward(self, x: paddle.Tensor):
        x = paddle.cast(x, 'float32')
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = paddle.exp(paddle.arange(half_dim) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = paddle.concat([emb.sin(), emb.cos()], axis=-1)
        return emb

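# Shape sketch for SinusoidalPosEmb (illustrative): for a batch of timesteps
# with shape (B,), the embedding is (B, dim), with dim/2 sine components
# concatenated with dim/2 cosine components:
#
#   pos_emb = SinusoidalPosEmb(dim=256)
#   emb = pos_emb(paddle.to_tensor([0, 10, 99]))   # -> shape [3, 256]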
class DiffNet(nn.Layer):
    """A Mel-Spectrogram Denoiser

    Args:
        in_channels (int, optional):
            Number of channels of the input mel-spectrogram, by default 80
        out_channels (int, optional):
            Number of channels of the output mel-spectrogram, by default 80
        kernel_size (int, optional):
            Kernel size of the residual blocks inside, by default 3
        layers (int, optional):
            Number of residual blocks inside, by default 20
        stacks (int, optional):
            The number of groups to split the residual blocks into, by default 5.
            Within each group, the dilation of the residual block grows exponentially.
        residual_channels (int, optional):
            Residual channel of the residual blocks, by default 256
        gate_channels (int, optional):
            Gate channel of the residual blocks, by default 512
        skip_channels (int, optional):
            Skip channel of the residual blocks, by default 256
        aux_channels (int, optional):
            Auxiliary channel of the residual blocks, by default 256
        dropout (float, optional):
            Dropout of the residual blocks, by default 0.
        bias (bool, optional):
            Whether to use bias in residual blocks, by default True
        use_weight_norm (bool, optional):
            Whether to use weight norm in all convolutions, by default False
    """

    def __init__(
            self,
            in_channels: int=80,
            out_channels: int=80,
            kernel_size: int=3,
            layers: int=20,
            stacks: int=5,
            residual_channels: int=256,
            gate_channels: int=512,
            skip_channels: int=256,
            aux_channels: int=256,
            dropout: float=0.,
            bias: bool=True,
            use_weight_norm: bool=False,
            init_type: str="kaiming_normal", ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.layers = layers
        self.aux_channels = aux_channels
        self.residual_channels = residual_channels
        self.gate_channels = gate_channels
        self.kernel_size = kernel_size
        self.dilation_cycle_length = layers // stacks
        self.skip_channels = skip_channels

        self.input_projection = Conv1D(self.in_channels, self.residual_channels,
                                       1)
        self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
        dim = self.residual_channels
        self.mlp = nn.Sequential(
            Linear(dim, dim * 4), nn.Mish(), Linear(dim * 4, dim))
        self.residual_layers = nn.LayerList([
            ResidualBlock(
                encoder_hidden=self.aux_channels,
                residual_channels=self.residual_channels,
                gate_channels=self.gate_channels,
                kernel_size=self.kernel_size,
                dilation=2**(i % self.dilation_cycle_length))
            for i in range(self.layers)
        ])
        self.skip_projection = Conv1D(self.residual_channels,
                                      self.skip_channels, 1)
        self.output_projection = Conv1D(self.residual_channels,
                                        self.out_channels, 1)
        zeros_(self.output_projection.weight)

    def forward(
            self,
            spec: paddle.Tensor,
            diffusion_step: paddle.Tensor,
            cond: paddle.Tensor, ):
        """Calculate forward propagation.

        Args:
            spec (Tensor(float32)): The input mel-spectrogram. (B, n_mel, T)
            diffusion_step (Tensor(int64)): The timestep input (adding noise step). (B,)
            cond (Tensor(float32)): The auxiliary input (e.g. fastspeech2 encoder output). (B, D_enc_out, T)

        Returns:
            x (Tensor(float32)): pred noise (B, n_mel, T)
        """
        x = spec
        x = self.input_projection(x)  # x [B, residual_channel, T]

        x = F.relu(x)
        diffusion_step = self.diffusion_embedding(diffusion_step)
        diffusion_step = self.mlp(diffusion_step)
        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(
                x=x,
                diffusion_step=diffusion_step,
                cond=cond, )
            skip.append(skip_connection)
        x = paddle.sum(
            paddle.stack(skip), axis=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)
        x = F.relu(x)
        x = self.output_projection(x)  # [B, 80, T]
        return x
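A smoke test for `DiffNet` with its defaults (batch size and length are arbitrary); the conditioning input must carry `aux_channels` channels:

```python
import paddle

net = DiffNet()                            # 80-bin mel in/out by default
spec = paddle.randn([2, 80, 100])          # (B, n_mel, T) noisy mel
step = paddle.randint(0, 100, shape=[2])   # (B,) diffusion timesteps
cond = paddle.randn([2, 256, 100])         # (B, aux_channels, T) encoder condition
noise_pred = net(spec, step, cond)
print(noise_pred.shape)  # [2, 80, 100]
```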
@ -0,0 +1,191 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable
from typing import Optional
from typing import Tuple

import numpy as np
import paddle
import ppdiffusers
from paddle import nn
from ppdiffusers.models.embeddings import Timesteps
from ppdiffusers.schedulers import DDPMScheduler

from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock


class WaveNetDenoiser(nn.Layer):
    """A Mel-Spectrogram Denoiser modified from WaveNet

    Args:
        in_channels (int, optional):
            Number of channels of the input mel-spectrogram, by default 80
        out_channels (int, optional):
            Number of channels of the output mel-spectrogram, by default 80
        kernel_size (int, optional):
            Kernel size of the residual blocks inside, by default 3
        layers (int, optional):
            Number of residual blocks inside, by default 20
        stacks (int, optional):
            The number of groups to split the residual blocks into, by default 5.
            Within each group, the dilation of the residual blocks grows exponentially.
        residual_channels (int, optional):
            Residual channel of the residual blocks, by default 256
        gate_channels (int, optional):
            Gate channel of the residual blocks, by default 512
        skip_channels (int, optional):
            Skip channel of the residual blocks, by default 256
        aux_channels (int, optional):
            Auxiliary channel of the residual blocks, by default 256
        dropout (float, optional):
            Dropout rate of the residual blocks, by default 0.
        bias (bool, optional):
            Whether to use bias in the residual blocks, by default True
        use_weight_norm (bool, optional):
            Whether to use weight norm in all convolutions, by default False
    """

    def __init__(
            self,
            in_channels: int=80,
            out_channels: int=80,
            kernel_size: int=3,
            layers: int=20,
            stacks: int=5,
            residual_channels: int=256,
            gate_channels: int=512,
            skip_channels: int=256,
            aux_channels: int=256,
            dropout: float=0.,
            bias: bool=True,
            use_weight_norm: bool=False,
            init_type: str="kaiming_normal", ):
        super().__init__()

        # initialize parameters
        initialize(self, init_type)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        self.first_t_emb = nn.Sequential(
            Timesteps(
                residual_channels,
                flip_sin_to_cos=False,
                downscale_freq_shift=1),
            nn.Linear(residual_channels, residual_channels * 4),
            nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
        self.t_emb_layers = nn.LayerList([
            nn.Linear(residual_channels, residual_channels)
            for _ in range(layers)
        ])

        self.first_conv = nn.Conv1D(
            in_channels, residual_channels, 1, bias_attr=True)
        self.first_act = nn.ReLU()

        self.conv_layers = nn.LayerList()
        for layer in range(layers):
            dilation = 2**(layer % layers_per_stack)
            conv = WaveNetResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias)
            self.conv_layers.append(conv)

        # zero-initialize the weights of the final projection
        final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
        nn.initializer.Constant(0.0)(final_conv.weight)
        self.last_conv_layers = nn.Sequential(nn.ReLU(),
                                              nn.Conv1D(
                                                  skip_channels,
                                                  skip_channels,
                                                  1,
                                                  bias_attr=True),
                                              nn.ReLU(), final_conv)

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x: paddle.Tensor, t: paddle.Tensor, c: paddle.Tensor):
        """Denoise mel-spectrogram.

        Args:
            x (Tensor):
                Shape (B, C_in, T). The input mel-spectrogram.
            t (Tensor):
                Shape (B,). The timestep input.
            c (Tensor):
                Shape (B, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).

        Returns:
            Tensor: Shape (B, C_out, T), the predicted noise.
        """
        assert c.shape[-1] == x.shape[-1]

        if t.shape[0] != x.shape[0]:
            t = t.tile([x.shape[0]])
        t_emb = self.first_t_emb(t)
        t_embs = [
            t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
        ]

        x = self.first_conv(x)
        x = self.first_act(x)
        skips = 0
        for f, t in zip(self.conv_layers, t_embs):
            x = x + t  # add the per-layer timestep embedding
            x, s = f(x, c)
            skips += s
        # sum of skip connections, scaled by 1/sqrt(num_layers)
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        x = self.last_conv_layers(skips)
        return x

    def apply_weight_norm(self):
        """Recursively apply weight normalization to all the Convolution layers
        in the sublayers.
        """

        def _apply_weight_norm(layer):
            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
                nn.utils.weight_norm(layer)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Recursively remove weight normalization from all the Convolution
        layers in the sublayers.
        """

        def _remove_weight_norm(layer):
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                pass

        self.apply(_remove_weight_norm)
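A minimal usage sketch of the `WaveNetDenoiser` defined above (a sketch, assuming the surrounding PaddleSpeech modules are importable). With the defaults `layers=20` and `stacks=5`, each stack of 4 blocks cycles through dilations 1, 2, 4, 8; a scalar timestep is tiled to the batch size inside `forward`, and `remove_weight_norm` is the usual step before exporting the model for inference.
```python
import paddle

# Assumes WaveNetDenoiser (defined above) is importable from wherever this
# file lands in the package; the module path is not visible in this hunk.
net = WaveNetDenoiser(use_weight_norm=True)

# Default dilation schedule: 2 ** (layer % (layers // stacks))
print([2**(i % 4) for i in range(8)])      # [1, 2, 4, 8, 1, 2, 4, 8]

B, T = 2, 311
x = paddle.randn([B, 80, T])               # noisy mel, (B, C_in, T)
t = paddle.to_tensor([50], dtype="int64")  # scalar step; tiled to (B,) in forward
c = paddle.randn([B, 256, T])              # conditioning; T must match x

noise_pred = net(x, t, c)                  # (B, 80, T)

net.remove_weight_norm()                   # strip weight norm before inference/export
```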