[TTS] Add DiffSinger with Opencpop dataset (#3005)


@ -0,0 +1,174 @@
([简体中文](./README_cn.md)|English)
# DiffSinger with Opencpop
This example contains code used to train a [DiffSinger](https://arxiv.org/abs/2105.02446) model with [Mandarin singing corpus](https://wenet.org.cn/opencpop/).
## Dataset
### Download and Extract
Download Opencpop from its [Official Website](https://wenet.org.cn/opencpop/download/) and extract it to `~/datasets`. The dataset is then in the directory `~/datasets/Opencpop`.
## Get Started
Assume the path to the dataset is `~/datasets/Opencpop`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
- (support in progress) synthesize waveform from a text file.
5. (support in progress) inference using the static model.
```bash
./run.sh
```
You can choose a range of stages to run, or set `stage` equal to `stop-stage` to run only one stage. For example, the following command only preprocesses the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. The structure of the `dump` folder is listed below.
```text
dump
├── dev
│ ├── norm
│ └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│ ├── norm
│ └── raw
└── train
├── energy_stats.npy
├── norm
├── pitch_stats.npy
├── raw
├── speech_stats.npy
└── speech_stretchs.npy
```
The dataset is split into 3 parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the speech, pitch, and energy features of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize the features are computed from the training set and are located in `dump/train/*_stats.npy`. `speech_stretchs.npy` contains the minimum and maximum values of each dimension of the mel spectrogram, which are used for linear stretching before training/inference of the diffusion module.
Note: since training with un-normalized features works better than with normalized ones here, the features saved under `norm` are actually not normalized.
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains the utterance id, speaker id, phones, text lengths, speech lengths, phone durations, the paths of the speech, pitch, and energy features, as well as the notes, note durations, and slur flags.
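For reference, the snippet below is a minimal sketch (not part of the recipe) of how one record of the normalized metadata can be inspected; the field names follow the records written by `normalize.py` in this PR.
```python
# Minimal sketch: inspect one record of dump/train/norm/metadata.jsonl.
# Field names follow the record dict written by normalize.py in this PR.
import jsonlines
import numpy as np

with jsonlines.open("dump/train/norm/metadata.jsonl") as reader:
    record = next(iter(reader))

print(record["utt_id"], record["text_lengths"], record["speech_lengths"])
# "speech", "pitch" and "energy" hold paths to per-utterance .npy feature files
mel = np.load(record["speech"])   # shape: [num_frames, n_mels]
print(mel.shape)
print(record["note"][:5], record["note_dur"][:5], record["is_slur"][:5])
```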
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--phones-dict PHONES_DICT]
[--speaker-dict SPEAKER_DICT] [--speech-stretchs SPEECH_STRETCHS]
Train a DiffSinger model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG diffsinger config file.
--train-metadata TRAIN_METADATA
training data.
--dev-metadata DEV_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--ngpu NGPU if ngpu=0, use cpu.
--phones-dict PHONES_DICT
phone vocabulary file.
--speaker-dict SPEAKER_DICT
speaker id map file for multiple speaker model.
--speech-stretchs SPEECH_STRETCHS
min and max mel for stretching.
```
1. `--config` is a config file in yaml format used to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the metadata files in the normalized subfolders of `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory in which to save the results of the experiment. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, the cpu is used.
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speech-stretchs` is the path of the file holding the min and max mel values; a hedged sketch of how it can be used is shown below.
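The following is a hedged sketch of how `speech_stretchs.npy` can be applied. The file layout (row 0 = per-dimension min, row 1 = per-dimension max) matches `get_minmax.py` in this PR; the exact target range of the linear stretch ([-1, 1] here) is an assumption about the usual DiffSinger convention, not something this recipe spells out.
```python
# Hedged sketch of the linear mel stretch controlled by --speech-stretchs.
# Row 0 / row 1 layout follows get_minmax.py; the [-1, 1] target range is an assumption.
import numpy as np

stretch = np.load("dump/train/speech_stretchs.npy")   # shape: [2, n_mels]
spec_min, spec_max = stretch[0], stretch[1]

def stretch_mel(mel: np.ndarray) -> np.ndarray:
    """Linearly map each mel dimension to roughly [-1, 1] before the diffusion module."""
    return (mel - spec_min) / (spec_max - spec_min) * 2.0 - 1.0

def unstretch_mel(mel: np.ndarray) -> np.ndarray:
    """Inverse mapping, applied to the diffusion output before vocoding."""
    return (mel + 1.0) / 2.0 * (spec_max - spec_min) + spec_min
```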
### Synthesizing
We use Parallel WaveGAN as the neural vocoder.
Download the pretrained Parallel WaveGAN model from [pwgan_opencpop_ckpt_1.4.0.zip](https://paddlespeech.bj.bcebos.com/t2s/svs/opencpop/pwgan_opencpop_ckpt_1.4.0.zip) and unzip it.
```bash
unzip pwgan_opencpop_ckpt_1.4.0.zip
```
The Parallel WaveGAN checkpoint contains the files listed below.
```text
pwgan_opencpop_ckpt_1.4.0.zip
├── default.yaml # default config used to train parallel wavegan
├── snapshot_iter_100000.pdz # model parameters of parallel wavegan
└── feats_stats.npy # statistics used to normalize spectrogram when training parallel wavegan
```
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h]
[--am {diffsinger_opencpop}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--voc {pwgan_opencpop}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
[--speech_stretchs SPEECH_STRETCHS]
Synthesize with acoustic model & vocoder
optional arguments:
-h, --help show this help message and exit
--am {diffsinger_opencpop}
Choose acoustic model type of tts task.
--am_config AM_CONFIG
Config of acoustic model.
--am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model.
--phones_dict PHONES_DICT
phone vocabulary file.
--tones_dict TONES_DICT
tone vocabulary file.
--speaker_dict SPEAKER_DICT
speaker id map file.
--voice-cloning VOICE_CLONING
whether training voice cloning model.
--voc {pwgan_opencpop}
Choose vocoder type of tts task.
--voc_config VOC_CONFIG
Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc.
--ngpu NGPU if ngpu == 0, use cpu.
--test_metadata TEST_METADATA
test metadata.
--output_dir OUTPUT_DIR
output dir.
--speech_stretchs SPEECH_STRETCHS
mel min and max values file.
```
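As a side note, the `*_stats.npy` files passed as `--am_stat` / `--voc_stat` store the feature mean in row 0 and the standard deviation in row 1, which is the layout `normalize.py` in this PR reads. The sketch below only illustrates that layout and is not part of the recipe.
```python
# Sketch: layout of the *_stats.npy files used by --am_stat / --voc_stat.
# Row 0 = mean, row 1 = standard deviation, as read by normalize.py in this PR.
import numpy as np

stats = np.load("dump/train/speech_stats.npy")   # shape: [2, n_mels]
mean, std = stats[0], stats[1]

def zscore(feat):
    # normalization applied at training time (StandardScaler convention)
    return (feat - mean) / std

def inverse_zscore(feat):
    # inverse transform, e.g. to map model output back to the original feature scale
    return feat * std + mean
```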
## Pretrained Model
Pretrained DiffSinger model:
- [diffsinger_opencpop_ckpt_1.4.0.zip](https://paddlespeech.bj.bcebos.com/t2s/svs/opencpop/diffsinger_opencpop_ckpt_1.4.0.zip)
The DiffSinger checkpoint contains the files listed below.
```text
diffsinger_opencpop_ckpt_1.4.0.zip
├── default.yaml # default config used to train diffsinger
├── energy_stats.npy # statistics used to normalize energy when training diffsinger if norm is needed
├── phone_id_map.txt # phone vocabulary file when training diffsinger
├── pitch_stats.npy # statistics used to normalize pitch when training diffsinger if norm is needed
├── snapshot_iter_160000.pdz # model parameters of diffsinger
├── speech_stats.npy # statistics used to normalize mel when training diffsinger if norm is needed
└── speech_stretchs.npy # min and max values used for linear stretching of the mel spectrogram before training the diffusion module
```
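For reference, `phone_id_map.txt` is a plain text file with one `phone id` pair per line; the sketch below (an illustration, not part of the recipe, and assuming the zip extracts to a directory of the same name) reads it the same way `normalize.py` and `train.py` in this PR do.
```python
# Sketch: read the phone vocabulary shipped with the pretrained checkpoint.
# One "phone id" pair per line, as read by normalize.py / train.py in this PR.
vocab_phones = {}
with open("diffsinger_opencpop_ckpt_1.4.0/phone_id_map.txt", "rt") as f:
    for line in f:
        phn, idx = line.strip().split()
        vocab_phones[phn] = int(idx)

print(f"{len(vocab_phones)} phones, e.g. {list(vocab_phones.items())[:3]}")
```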
At present, the text frontend is not yet complete, so synthesizing audio with `synthesize_e2e` is not supported. Use `synthesize` instead for now.

@ -0,0 +1,179 @@
(Simplified Chinese|[English](./README.md))
# Train DiffSinger with the Opencpop Dataset
This example contains code used to train a [DiffSinger](https://arxiv.org/abs/2105.02446) model with the [Mandarin singing corpus](https://wenet.org.cn/opencpop/).
## Dataset
### Download and Extract
Download the dataset from the [Official Website](https://wenet.org.cn/opencpop/download/) and extract it to `~/datasets`. The dataset is then in the directory `~/datasets/Opencpop`.
## Get Started
Assume the path to the dataset is `~/datasets/Opencpop`.
Running the command below will:
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize wavs.
- synthesize waveform from `metadata.jsonl`.
- (support in progress) synthesize waveform from a text file.
5. (support in progress) inference using the static model.
```bash
./run.sh
```
You can choose a range of stages to run, or set `stage` equal to `stop-stage` to run only one stage. For example, the following command only preprocesses the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
./local/preprocess.sh ${conf_path}
```
When it is done, a `dump` folder is created in the current directory. Its structure is listed below.
```text
dump
├── dev
│ ├── norm
│ └── raw
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│ ├── norm
│ └── raw
└── train
├── energy_stats.npy
├── norm
├── pitch_stats.npy
├── raw
├── speech_stats.npy
└── speech_stretchs.npy
```
The dataset is split into three parts, namely `train`, `dev`, and `test`, each of which contains a `norm` and a `raw` subfolder. The `raw` folder contains the speech, pitch, and energy features of each utterance, while the `norm` folder contains the normalized ones. The statistics used to normalize the features are computed from the training set and stored in `dump/train/*_stats.npy`. `speech_stretchs.npy` contains the minimum and maximum values of each dimension of the mel spectrogram, used for linear stretching before training/inference of the diffusion module.
Note: since training with un-normalized features works better than with normalized ones here, the features saved under `norm` are actually not normalized.
There is also a `metadata.jsonl` in each subfolder. It is a table-like file that contains the utterance id, speaker id, phones, text lengths, speech lengths, phone durations, the paths of the speech, pitch, and energy features, as well as the notes, note durations, and slur flags.
### Model Training
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
`./local/train.sh` calls `${BIN_DIR}/train.py`.
Here's the complete help message.
```text
usage: train.py [-h] [--config CONFIG] [--train-metadata TRAIN_METADATA]
[--dev-metadata DEV_METADATA] [--output-dir OUTPUT_DIR]
[--ngpu NGPU] [--phones-dict PHONES_DICT]
[--speaker-dict SPEAKER_DICT] [--speech-stretchs SPEECH_STRETCHS]
Train a DiffSinger model.
optional arguments:
-h, --help show this help message and exit
--config CONFIG diffsinger config file.
--train-metadata TRAIN_METADATA
training data.
--dev-metadata DEV_METADATA
dev data.
--output-dir OUTPUT_DIR
output dir.
--ngpu NGPU if ngpu=0, use cpu.
--phones-dict PHONES_DICT
phone vocabulary file.
--speaker-dict SPEAKER_DICT
speaker id map file for multiple speaker model.
--speech-stretchs SPEECH_STRETCHS
min and max mel for stretching.
```
1. `--config` is a config file in yaml format used to overwrite the default config, which can be found at `conf/default.yaml`.
2. `--train-metadata` and `--dev-metadata` should be the normalized metadata files under `train` and `dev` in the `dump` folder.
3. `--output-dir` is the directory in which to save the results. Checkpoints are saved in `checkpoints/` inside this directory.
4. `--ngpu` is the number of gpus to use; if ngpu == 0, the cpu is used.
5. `--phones-dict` is the path of the phone vocabulary file.
6. `--speech-stretchs` is the path of the file holding the min and max values of the mel spectrogram.
### Synthesizing
We use Parallel WaveGAN as the neural vocoder.
Download the pretrained Parallel WaveGAN model from [pwgan_opencpop_ckpt_1.4.0.zip](https://paddlespeech.bj.bcebos.com/t2s/svs/opencpop/pwgan_opencpop_ckpt_1.4.0.zip) and unzip it.
```bash
unzip pwgan_opencpop_ckpt_1.4.0.zip
```
The Parallel WaveGAN checkpoint contains the files listed below.
```text
pwgan_opencpop_ckpt_1.4.0.zip
├── default.yaml # default config used to train parallel wavegan
├── snapshot_iter_100000.pdz # model parameters of parallel wavegan
└── feats_stats.npy # statistics used to normalize the spectrogram when training parallel wavegan
```
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
```text
usage: synthesize.py [-h]
[--am {diffsinger_opencpop}]
[--am_config AM_CONFIG] [--am_ckpt AM_CKPT]
[--am_stat AM_STAT] [--phones_dict PHONES_DICT]
[--voc {pwgan_opencpop}]
[--voc_config VOC_CONFIG] [--voc_ckpt VOC_CKPT]
[--voc_stat VOC_STAT] [--ngpu NGPU]
[--test_metadata TEST_METADATA] [--output_dir OUTPUT_DIR]
[--speech_stretchs SPEECH_STRETCHS]
Synthesize with acoustic model & vocoder
optional arguments:
-h, --help show this help message and exit
--am {diffsinger_opencpop}
Choose acoustic model type of tts task.
--am_config AM_CONFIG
Config of acoustic model.
--am_ckpt AM_CKPT Checkpoint file of acoustic model.
--am_stat AM_STAT mean and standard deviation used to normalize
spectrogram when training acoustic model.
--phones_dict PHONES_DICT
phone vocabulary file.
--tones_dict TONES_DICT
tone vocabulary file.
--speaker_dict SPEAKER_DICT
speaker id map file.
--voice-cloning VOICE_CLONING
whether training voice cloning model.
--voc {pwgan_opencpop}
Choose vocoder type of tts task.
--voc_config VOC_CONFIG
Config of voc.
--voc_ckpt VOC_CKPT Checkpoint file of voc.
--voc_stat VOC_STAT mean and standard deviation used to normalize
spectrogram when training voc.
--ngpu NGPU if ngpu == 0, use cpu.
--test_metadata TEST_METADATA
test metadata.
--output_dir OUTPUT_DIR
output dir.
--speech_stretchs SPEECH_STRETCHS
mel min and max values file.
```
## Pretrained Model
Pretrained DiffSinger model:
- [diffsinger_opencpop_ckpt_1.4.0.zip](https://paddlespeech.bj.bcebos.com/t2s/svs/opencpop/diffsinger_opencpop_ckpt_1.4.0.zip)
The DiffSinger checkpoint contains the files listed below.
```text
diffsinger_opencpop_ckpt_1.4.0.zip
├── default.yaml # default config used to train diffsinger
├── energy_stats.npy # statistics used to normalize energy when training diffsinger, if norm is needed
├── phone_id_map.txt # phone vocabulary file used when training diffsinger
├── pitch_stats.npy # statistics used to normalize pitch when training diffsinger, if norm is needed
├── snapshot_iter_160000.pdz # model parameters and optimizer states of diffsinger
├── speech_stats.npy # statistics used to normalize the mel spectrogram when training diffsinger, if norm is needed
└── speech_stretchs.npy # min and max values used for linear stretching of the mel spectrogram before training the diffusion module
```
At present, the text frontend is not yet complete, so synthesizing audio with `synthesize_e2e` is not supported. Use `synthesize` instead for now.

@ -0,0 +1,159 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # sr
n_fft: 512 # FFT size (samples).
n_shift: 128 # Hop size (samples). ~5.3 ms at 24 kHz
win_length: 512 # Window length (samples). ~21.3 ms at 24 kHz
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
# Only used for feats_type != raw
fmin: 30 # Minimum frequency of Mel basis.
fmax: 12000 # Maximum frequency of Mel basis.
n_mels: 80 # The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80 # Minimum f0 for pitch extraction.
f0max: 750 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 48 # batch size
num_workers: 1 # number of workers used for data loading
###########################################################
# MODEL SETTING #
###########################################################
model:
# music score related
note_num: 300 # number of notes
is_slur_num: 2 # number of slur states
# fastspeech2 module options
use_energy_pred: False # whether use energy predictor
use_postnet: False # whether use postnet
# fastspeech2 module
fastspeech2_params:
adim: 256 # attention dimension
aheads: 2 # number of attention heads
elayers: 4 # number of encoder layers
eunits: 1024 # number of encoder ff units
dlayers: 4 # number of decoder layers
dunits: 1024 # number of decoder ff units
positionwise_layer_type: conv1d-linear # type of position-wise layer
positionwise_conv_kernel_size: 9 # kernel size of position wise conv layer
transformer_enc_dropout_rate: 0.1 # dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate: 0.1 # dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate: 0.0 # dropout rate for transformer encoder attention layer
transformer_activation_type: "gelu" # Activation function type in transformer.
encoder_normalize_before: True # whether to perform layer normalization before the input
decoder_normalize_before: True # whether to perform layer normalization before the input
reduction_factor: 1 # reduction factor
init_type: xavier_uniform # initialization type
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
use_scaled_pos_enc: True # whether to use scaled positional encoding
transformer_dec_dropout_rate: 0.1 # dropout rate for transformer decoder layer
transformer_dec_positional_dropout_rate: 0.1 # dropout rate for transformer decoder positional encoding
transformer_dec_attn_dropout_rate: 0.0 # dropout rate for transformer decoder attention layer
duration_predictor_layers: 5 # number of layers of duration predictor
duration_predictor_chans: 256 # number of channels of duration predictor
duration_predictor_kernel_size: 3 # filter size of duration predictor
duration_predictor_dropout_rate: 0.5 # dropout rate in duration predictor
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size: 5 # kernel size of conv layers in pitch predictor
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers: 2 # number of conv layers in energy predictor
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
energy_predictor_kernel_size: 3 # kernel size of conv layers in energy predictor
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
postnet_layers: 5 # number of layers of postnet
postnet_filts: 5 # filter size of conv layers in postnet
postnet_chans: 256 # number of channels of conv layers in postnet
postnet_dropout_rate: 0.5 # dropout rate for postnet
# denoiser module
denoiser_params:
in_channels: 80 # Number of channels of the input mel-spectrogram
out_channels: 80 # Number of channels of the output mel-spectrogram
kernel_size: 3 # Kernel size of the residual blocks inside
layers: 20 # Number of residual blocks inside
stacks: 5 # The number of groups to split the residual blocks into
residual_channels: 256 # Residual channel of the residual blocks
gate_channels: 512 # Gate channel of the residual blocks
skip_channels: 256 # Skip channel of the residual blocks
aux_channels: 256 # Auxiliary channel of the residual blocks
dropout: 0.1 # Dropout of the residual blocks
bias: True # Whether to use bias in residual blocks
use_weight_norm: False # Whether to use weight norm in all convolutions
init_type: "kaiming_normal" # Type of initialize weights of a neural network module
diffusion_params:
num_train_timesteps: 100 # The number of timesteps between the noise and the real during training
beta_start: 0.0001 # beta start parameter for the scheduler
beta_end: 0.06 # beta end parameter for the scheduler
beta_schedule: "linear" # beta schedule parameter for the scheduler
num_max_timesteps: 100 # The max timestep transition from real to noise
stretch: True # whether to stretch before diffusion
###########################################################
# UPDATER SETTING #
###########################################################
fs2_updater:
use_masking: True # whether to apply masking for padded part in loss calculation
ds_updater:
use_masking: True # whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
# fastspeech2 optimizer
fs2_optimizer:
optim: adam # optimizer type
learning_rate: 0.001 # learning rate
# diffusion optimizer
ds_optimizer_params:
beta1: 0.9
beta2: 0.98
weight_decay: 0.0
ds_scheduler_params:
learning_rate: 0.001
gamma: 0.5
step_size: 50000
ds_grad_norm: 1
###########################################################
# INTERVAL SETTING #
###########################################################
only_train_diffusion: True # Whether to freeze fastspeech2 parameters when training diffusion
ds_train_start_steps: 160000 # Number of steps to start to train diffusion module.
train_max_steps: 320000 # Number of training steps.
save_interval_steps: 2000 # Interval steps to save checkpoint.
eval_interval_steps: 2000 # Interval steps to evaluate the network.
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

@ -0,0 +1,74 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=opencpop \
--rootdir=~/datasets/Opencpop/segments \
--dumpdir=dump \
--label-file=~/datasets/Opencpop/segments/transcriptions.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="speech"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="pitch"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="energy"
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# normalize and convert phone/speaker to id, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# Get feature(mel) extremum for diffusion stretch
echo "Get feature(mel) extremum ..."
python3 ${BIN_DIR}/get_minmax.py \
--metadata=dump/train/norm/metadata.jsonl \
--speech-stretchs=dump/train/speech_stretchs.npy
fi

@ -0,0 +1,27 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
stage=0
stop_stage=0
# pwgan
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=diffsinger_opencpop \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=pwgan_opencpop \
--voc_config=pwgan_opencpop_ckpt_1.4.0/default.yaml \
--voc_ckpt=pwgan_opencpop_ckpt_1.4.0/snapshot_iter_100000.pdz \
--voc_stat=pwgan_opencpop_ckpt_1.4.0/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt \
--speech_stretchs=dump/train/speech_stretchs.npy
fi

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=1 \
--phones-dict=dump/phone_id_map.txt \
--speech-stretchs=dump/train/speech_stretchs.npy

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=diffsinger
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,32 @@
#!/bin/bash
set -e
source path.sh
gpus=0
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_320000.pdz
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be used together with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize, vocoder is pwgan by default
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi

@ -414,6 +414,129 @@ def fastspeech2_multi_spk_batch_fn(examples):
return batch
def diffsinger_single_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
energy = [np.array(item["energy"], dtype=np.float32) for item in examples]
durations = [
np.array(item["durations"], dtype=np.int64) for item in examples
]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
speech_lengths = [
np.array(item["speech_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text)
note = batch_sequences(note)
note_dur = batch_sequences(note_dur)
is_slur = batch_sequences(is_slur)
pitch = batch_sequences(pitch)
speech = batch_sequences(speech)
durations = batch_sequences(durations)
energy = batch_sequences(energy)
# convert each batch to paddle.Tensor
text = paddle.to_tensor(text)
note = paddle.to_tensor(note)
note_dur = paddle.to_tensor(note_dur)
is_slur = paddle.to_tensor(is_slur)
pitch = paddle.to_tensor(pitch)
speech = paddle.to_tensor(speech)
durations = paddle.to_tensor(durations)
energy = paddle.to_tensor(energy)
text_lengths = paddle.to_tensor(text_lengths)
speech_lengths = paddle.to_tensor(speech_lengths)
batch = {
"text": text,
"note": note,
"note_dur": note_dur,
"is_slur": is_slur,
"text_lengths": text_lengths,
"durations": durations,
"speech": speech,
"speech_lengths": speech_lengths,
"pitch": pitch,
"energy": energy
}
return batch
def diffsinger_multi_spk_batch_fn(examples):
# fields = ["text", "note", "note_dur", "is_slur", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
note = [np.array(item["note"], dtype=np.int64) for item in examples]
note_dur = [np.array(item["note_dur"], dtype=np.float32) for item in examples]
is_slur = [np.array(item["is_slur"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
energy = [np.array(item["energy"], dtype=np.float32) for item in examples]
durations = [
np.array(item["durations"], dtype=np.int64) for item in examples
]
text_lengths = [
np.array(item["text_lengths"], dtype=np.int64) for item in examples
]
speech_lengths = [
np.array(item["speech_lengths"], dtype=np.int64) for item in examples
]
text = batch_sequences(text)
note = batch_sequences(note)
note_dur = batch_sequences(note_dur)
is_slur = batch_sequences(is_slur)
pitch = batch_sequences(pitch)
speech = batch_sequences(speech)
durations = batch_sequences(durations)
energy = batch_sequences(energy)
# convert each batch to paddle.Tensor
text = paddle.to_tensor(text)
note = paddle.to_tensor(note)
note_dur = paddle.to_tensor(note_dur)
is_slur = paddle.to_tensor(is_slur)
pitch = paddle.to_tensor(pitch)
speech = paddle.to_tensor(speech)
durations = paddle.to_tensor(durations)
energy = paddle.to_tensor(energy)
text_lengths = paddle.to_tensor(text_lengths)
speech_lengths = paddle.to_tensor(speech_lengths)
batch = {
"text": text,
"note": note,
"note_dur": note_dur,
"is_slur": is_slur,
"text_lengths": text_lengths,
"durations": durations,
"speech": speech,
"speech_lengths": speech_lengths,
"pitch": pitch,
"energy": energy
}
# spk_emb has a higher priority than spk_id
if "spk_emb" in examples[0]:
spk_emb = [
np.array(item["spk_emb"], dtype=np.float32) for item in examples
]
spk_emb = batch_sequences(spk_emb)
spk_emb = paddle.to_tensor(spk_emb)
batch["spk_emb"] = spk_emb
elif "spk_id" in examples[0]:
spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
spk_id = paddle.to_tensor(spk_id)
batch["spk_id"] = spk_id
return batch
def transformer_single_spk_batch_fn(examples):
# fields = ["text", "text_lengths", "speech", "speech_lengths"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]

@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import List
from typing import Optional
from typing import Union
import librosa
import numpy as np
import pyworld
from scipy.interpolate import interp1d
from typing import Optional
from typing import Union
from typing_extensions import Literal
class LogMelFBank():
def __init__(self,
sr: int=24000,
@ -79,7 +79,7 @@ class LogMelFBank():
def _spectrogram(self, wav: np.ndarray):
D = self._stft(wav)
return np.abs(D) ** self.power
return np.abs(D)**self.power
def _mel_spectrogram(self, wav: np.ndarray):
S = self._spectrogram(wav)
@ -117,7 +117,6 @@ class Pitch():
if (f0 == 0).all():
print("All frames seems to be unvoiced, this utt will be removed.")
return f0
# padding start and end of f0 sequence
start_f0 = f0[f0 != 0][0]
end_f0 = f0[f0 != 0][-1]
@ -179,6 +178,8 @@ class Pitch():
f0 = self._calculate_f0(wav, use_continuous_f0, use_log_f0)
if use_token_averaged_f0 and duration is not None:
f0 = self._average_by_duration(f0, duration)
else:
f0 = np.expand_dims(np.array(f0), 0).T
return f0
@ -237,6 +238,8 @@ class Energy():
energy = self._calculate_energy(wav)
if use_token_averaged_energy and duration is not None:
energy = self._average_by_duration(energy, duration)
else:
energy = np.expand_dims(np.array(energy), 0).T
return energy

@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import List
import librosa
import numpy as np
# speaker|utt_id|phn dur phn dur ...
@ -41,6 +45,90 @@ def get_phn_dur(file_name):
return sentence, speaker_set
def note2midi(notes: List[str]) -> List[int]:
"""Convert note strings to MIDI note ids, for example: ["C1"] -> [24]
Args:
notes (List[str]): the list of note strings
Returns:
List[int]: the list of MIDI note ids
"""
midis = []
for note in notes:
if note == 'rest':
midi = 0
else:
midi = librosa.note_to_midi(note.split("/")[0])
midis.append(midi)
return midis
def time2frame(
times: List[float],
sample_rate: int=24000,
n_shift: int=128, ) -> List[int]:
"""Convert the phoneme duration of time(s) into frames
Args:
times (List[float]): phoneme duration of time(s)
sample_rate (int, optional): sample rate. Defaults to 24000.
n_shift (int, optional): frame shift. Defaults to 128.
Returns:
List[int]: phoneme duration of frame
"""
end = 0.0
ends = []
for t in times:
end += t
ends.append(end)
frame_pos = librosa.time_to_frames(ends, sr=sample_rate, hop_length=n_shift)
durations = np.diff(frame_pos, prepend=0)
return durations
def get_sentences_svs(
file_name,
dataset: str='opencpop',
sample_rate: int=24000,
n_shift: int=128, ):
'''
read the label file of a singing voice synthesis dataset
Args:
file_name (str or Path): path of the label file (e.g. Opencpop's transcriptions.txt)
dataset (str): dataset name
Returns:
Dict: sentence info keyed by utt id: ([phone (str)], [phone duration in frames (int)], [note id (int)], [note duration (float)], [is slur (int)], text (str), speaker name (str))
set: speaker names
'''
f = open(file_name, 'r')
sentence = {}
speaker_set = set()
if dataset == 'opencpop':
speaker_set.add("opencpop")
for line in f:
line_list = line.strip().split('|')
utt = line_list[0]
text = line_list[1]
ph = line_list[2].split()
midi = note2midi(line_list[3].split())
midi_dur = line_list[4].split()
ph_dur = time2frame([float(t) for t in line_list[5].split()], sample_rate=sample_rate, n_shift=n_shift)
is_slur = line_list[6].split()
assert len(ph) == len(midi) == len(midi_dur) == len(is_slur)
sentence[utt] = (ph, [int(i) for i in ph_dur],
[int(i) for i in midi],
[float(i) for i in midi_dur],
[int(i) for i in is_slur], text, "opencpop")
else:
print("dataset should in {opencpop} now!")
f.close()
return sentence, speaker_set
def merge_silence(sentence):
'''
merge silences
@ -88,6 +176,9 @@ def get_input_token(sentence, output_path, dataset="baker"):
phn_token = ["<pad>", "<unk>"] + phn_token
if dataset in {"baker", "aishell3"}:
phn_token += ["", "", "", ""]
# svs dataset
elif dataset in {"opencpop"}:
pass
else:
phn_token += [",", ".", "?", "!"]
phn_token += ["<eos>"]

@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -0,0 +1,82 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import jsonlines
import numpy as np
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
def get_minmax(spec, min_spec, max_spec):
# spec: [T, 80]
for i in range(spec.shape[1]):
min_value = np.min(spec[:, i])
max_value = np.max(spec[:, i])
min_spec[i] = min(min_value, min_spec[i])
max_spec[i] = max(max_value, max_spec[i])
return min_spec, max_spec
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
)
parser.add_argument(
"--metadata",
type=str,
required=True,
help="directory including feature files to be normalized. "
"you need to specify either *-scp or rootdir.")
parser.add_argument(
"--speech-stretchs",
type=str,
required=True,
help="min max spec file. only computer on train data")
args = parser.parse_args()
# get dataset
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata, converters={
"speech": np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
n_mel = 80
min_spec = 100.0 * np.ones(shape=(n_mel), dtype=np.float32)
max_spec = -100.0 * np.ones(shape=(n_mel), dtype=np.float32)
for item in tqdm(dataset):
spec = item['speech']
min_spec, max_spec = get_minmax(spec, min_spec, max_spec)
# using min_spec=-6.0 gives better training results so far
min_spec = -6.0 * np.ones(shape=(n_mel), dtype=np.float32)
min_max_spec = np.stack([min_spec, max_spec], axis=0)
np.save(
str(args.speech_stretchs),
min_max_spec.astype(np.float32),
allow_pickle=False)
if __name__ == "__main__":
main()

@ -0,0 +1,189 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalize feature files and dump them."""
import argparse
import logging
from operator import itemgetter
from pathlib import Path
import jsonlines
import numpy as np
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.utils import str2bool
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
)
parser.add_argument(
"--metadata",
type=str,
required=True,
help="directory including feature files to be normalized. "
"you need to specify either *-scp or rootdir.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump normalized feature files.")
parser.add_argument(
"--speech-stats",
type=str,
required=True,
help="speech statistics file.")
parser.add_argument(
"--pitch-stats", type=str, required=True, help="pitch statistics file.")
parser.add_argument(
"--energy-stats",
type=str,
required=True,
help="energy statistics file.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict", type=str, default=None, help="speaker id map file.")
parser.add_argument(
"--norm-feats",
type=str2bool,
default=False,
help="whether to norm features")
args = parser.parse_args()
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
# get dataset
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata,
converters={
"speech": np.load,
"pitch": np.load,
"energy": np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
# restore scaler
speech_scaler = StandardScaler()
if args.norm_feats:
speech_scaler.mean_ = np.load(args.speech_stats)[0]
speech_scaler.scale_ = np.load(args.speech_stats)[1]
else:
speech_scaler.mean_ = np.zeros(
np.load(args.speech_stats)[0].shape, dtype="float32")
speech_scaler.scale_ = np.ones(
np.load(args.speech_stats)[1].shape, dtype="float32")
speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
pitch_scaler = StandardScaler()
if args.norm_feats:
pitch_scaler.mean_ = np.load(args.pitch_stats)[0]
pitch_scaler.scale_ = np.load(args.pitch_stats)[1]
else:
pitch_scaler.mean_ = np.zeros(
np.load(args.pitch_stats)[0].shape, dtype="float32")
pitch_scaler.scale_ = np.ones(
np.load(args.pitch_stats)[1].shape, dtype="float32")
pitch_scaler.n_features_in_ = pitch_scaler.mean_.shape[0]
energy_scaler = StandardScaler()
if args.norm_feats:
energy_scaler.mean_ = np.load(args.energy_stats)[0]
energy_scaler.scale_ = np.load(args.energy_stats)[1]
else:
energy_scaler.mean_ = np.zeros(
np.load(args.energy_stats)[0].shape, dtype="float32")
energy_scaler.scale_ = np.ones(
np.load(args.energy_stats)[1].shape, dtype="float32")
energy_scaler.n_features_in_ = energy_scaler.mean_.shape[0]
vocab_phones = {}
with open(args.phones_dict, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
for phn, id in phn_id:
vocab_phones[phn] = int(id)
vocab_speaker = {}
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
for spk, id in spk_id:
vocab_speaker[spk] = int(id)
# process each file
output_metadata = []
for item in tqdm(dataset):
utt_id = item['utt_id']
speech = item['speech']
pitch = item['pitch']
energy = item['energy']
# normalize
speech = speech_scaler.transform(speech)
speech_dir = dumpdir / "data_speech"
speech_dir.mkdir(parents=True, exist_ok=True)
speech_path = speech_dir / f"{utt_id}_speech.npy"
np.save(speech_path, speech.astype(np.float32), allow_pickle=False)
pitch = pitch_scaler.transform(pitch)
pitch_dir = dumpdir / "data_pitch"
pitch_dir.mkdir(parents=True, exist_ok=True)
pitch_path = pitch_dir / f"{utt_id}_pitch.npy"
np.save(pitch_path, pitch.astype(np.float32), allow_pickle=False)
energy = energy_scaler.transform(energy)
energy_dir = dumpdir / "data_energy"
energy_dir.mkdir(parents=True, exist_ok=True)
energy_path = energy_dir / f"{utt_id}_energy.npy"
np.save(energy_path, energy.astype(np.float32), allow_pickle=False)
phone_ids = [vocab_phones[p] for p in item['phones']]
spk_id = vocab_speaker[item["speaker"]]
record = {
"utt_id": item['utt_id'],
"spk_id": spk_id,
"text": phone_ids,
"text_lengths": item['text_lengths'],
"speech_lengths": item['speech_lengths'],
"durations": item['durations'],
"speech": str(speech_path),
"pitch": str(pitch_path),
"energy": str(energy_path),
"note": item['note'],
"note_dur": item['note_dur'],
"is_slur": item['is_slur'],
}
# add spk_emb for voice cloning
if "spk_emb" in item:
record["spk_emb"] = str(item["spk_emb"])
output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:
writer.write(item)
logging.info(f"metadata dumped into {output_metadata_path}")
if __name__ == "__main__":
main()

@ -0,0 +1,376 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from concurrent.futures import ThreadPoolExecutor
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import jsonlines
import librosa
import numpy as np
import tqdm
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.get_feats import Energy
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
from paddlespeech.t2s.datasets.get_feats import Pitch
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
from paddlespeech.t2s.datasets.preprocess_utils import get_sentences_svs
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
from paddlespeech.t2s.utils import str2bool
ALL_INITIALS = [
'zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 'h',
'j', 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w'
]
ALL_FINALS = [
'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia',
'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'iu', 'ng', 'o', 'ong',
'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've',
'vn'
]
def process_sentence(
config: Dict[str, Any],
fp: Path,
sentences: Dict,
output_dir: Path,
mel_extractor=None,
pitch_extractor=None,
energy_extractor=None,
cut_sil: bool=True,
spk_emb_dir: Path=None, ):
utt_id = fp.stem
record = None
if utt_id in sentences:
# reading, resampling may occur
wav, _ = librosa.load(str(fp), sr=config.fs)
if len(wav.shape) != 1:
return record
max_value = np.abs(wav).max()
if max_value > 1.0:
wav = wav / max_value
assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
assert np.abs(wav).max(
) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
note = sentences[utt_id][2]
note_dur = sentences[utt_id][3]
is_slur = sentences[utt_id][4]
speaker = sentences[utt_id][-1]
# extract mel feats
logmel = mel_extractor.get_log_mel_fbank(wav)
# change duration according to mel_length
compare_duration_and_mel_length(sentences, utt_id, logmel)
# utt_id may be popped in compare_duration_and_mel_length
if utt_id not in sentences:
return None
phones = sentences[utt_id][0]
durations = sentences[utt_id][1]
num_frames = logmel.shape[0]
assert sum(
durations
) == num_frames, "the sum of durations doesn't equal to the num of mel frames. "
speech_dir = output_dir / "data_speech"
speech_dir.mkdir(parents=True, exist_ok=True)
speech_path = speech_dir / (utt_id + "_speech.npy")
np.save(speech_path, logmel)
# extract pitch and energy
pitch = pitch_extractor.get_pitch(wav)
assert pitch.shape[0] == num_frames
pitch_dir = output_dir / "data_pitch"
pitch_dir.mkdir(parents=True, exist_ok=True)
pitch_path = pitch_dir / (utt_id + "_pitch.npy")
np.save(pitch_path, pitch)
energy = energy_extractor.get_energy(wav)
assert energy.shape[0] == num_frames
energy_dir = output_dir / "data_energy"
energy_dir.mkdir(parents=True, exist_ok=True)
energy_path = energy_dir / (utt_id + "_energy.npy")
np.save(energy_path, energy)
record = {
"utt_id": utt_id,
"phones": phones,
"text_lengths": len(phones),
"speech_lengths": num_frames,
"durations": durations,
"speech": str(speech_path),
"pitch": str(pitch_path),
"energy": str(energy_path),
"speaker": speaker,
"note": note,
"note_dur": note_dur,
"is_slur": is_slur,
}
if spk_emb_dir:
if speaker in os.listdir(spk_emb_dir):
embed_name = utt_id + ".npy"
embed_path = spk_emb_dir / speaker / embed_name
if embed_path.is_file():
record["spk_emb"] = str(embed_path)
else:
return None
return record
def process_sentences(
config,
fps: List[Path],
sentences: Dict,
output_dir: Path,
mel_extractor=None,
pitch_extractor=None,
energy_extractor=None,
nprocs: int=1,
cut_sil: bool=True,
spk_emb_dir: Path=None,
write_metadata_method: str='w', ):
if nprocs == 1:
results = []
for fp in tqdm.tqdm(fps, total=len(fps)):
record = process_sentence(
config=config,
fp=fp,
sentences=sentences,
output_dir=output_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
cut_sil=cut_sil,
spk_emb_dir=spk_emb_dir, )
if record:
results.append(record)
else:
with ThreadPoolExecutor(nprocs) as pool:
futures = []
with tqdm.tqdm(total=len(fps)) as progress:
for fp in fps:
future = pool.submit(
process_sentence,
config,
fp,
sentences,
output_dir,
mel_extractor,
pitch_extractor,
energy_extractor,
cut_sil,
spk_emb_dir, )
future.add_done_callback(lambda p: progress.update())
futures.append(future)
results = []
for ft in futures:
record = ft.result()
if record:
results.append(record)
results.sort(key=itemgetter("utt_id"))
with jsonlines.open(output_dir / "metadata.jsonl",
write_metadata_method) as writer:
for item in results:
writer.write(item)
print("Done")
def main():
# parse config and args
parser = argparse.ArgumentParser(
description="Preprocess audio and then extract features.")
parser.add_argument(
"--dataset",
default="opencpop",
type=str,
help="name of dataset, should in {opencpop} now")
parser.add_argument(
"--rootdir", default=None, type=str, help="directory to dataset.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump feature files.")
parser.add_argument(
"--label-file", default=None, type=str, help="path to label file.")
parser.add_argument("--config", type=str, help="diffsinger config file.")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
parser.add_argument(
"--cut-sil",
type=str2bool,
default=True,
help="whether cut sil in the edge of audio")
parser.add_argument(
"--spk_emb_dir",
default=None,
type=str,
help="directory to speaker embedding files.")
parser.add_argument(
"--write_metadata_method",
default="w",
type=str,
choices=["w", "a"],
help="How the metadata.jsonl file is written.")
args = parser.parse_args()
rootdir = Path(args.rootdir).expanduser()
dumpdir = Path(args.dumpdir).expanduser()
# use absolute path
dumpdir = dumpdir.resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
label_file = Path(args.label_file).expanduser()
if args.spk_emb_dir:
spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
else:
spk_emb_dir = None
assert rootdir.is_dir()
assert label_file.is_file()
with open(args.config, 'rt') as f:
config = CfgNode(yaml.safe_load(f))
sentences, speaker_set = get_sentences_svs(
label_file,
dataset=args.dataset,
sample_rate=config.fs,
n_shift=config.n_shift, )
phone_id_map_path = dumpdir / "phone_id_map.txt"
speaker_id_map_path = dumpdir / "speaker_id_map.txt"
get_input_token(sentences, phone_id_map_path, args.dataset)
get_spk_id_map(speaker_set, speaker_id_map_path)
if args.dataset == "opencpop":
wavdir = rootdir / "wavs"
# split data into 3 sections
train_file = rootdir / "train.txt"
train_wav_files = []
with open(train_file, "r") as f_train:
for line in f_train.readlines():
utt = line.split("|")[0]
wav_name = utt + ".wav"
wav_path = wavdir / wav_name
train_wav_files.append(wav_path)
test_file = rootdir / "test.txt"
dev_wav_files = []
test_wav_files = []
num_dev = 106
count = 0
with open(test_file, "r") as f_test:
for line in f_test.readlines():
count += 1
utt = line.split("|")[0]
wav_name = utt + ".wav"
wav_path = wavdir / wav_name
if count > num_dev:
test_wav_files.append(wav_path)
else:
dev_wav_files.append(wav_path)
else:
print("dataset should in {opencpop} now!")
train_dump_dir = dumpdir / "train" / "raw"
train_dump_dir.mkdir(parents=True, exist_ok=True)
dev_dump_dir = dumpdir / "dev" / "raw"
dev_dump_dir.mkdir(parents=True, exist_ok=True)
test_dump_dir = dumpdir / "test" / "raw"
test_dump_dir.mkdir(parents=True, exist_ok=True)
# Extractor
mel_extractor = LogMelFBank(
sr=config.fs,
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window,
n_mels=config.n_mels,
fmin=config.fmin,
fmax=config.fmax)
pitch_extractor = Pitch(
sr=config.fs,
hop_length=config.n_shift,
f0min=config.f0min,
f0max=config.f0max)
energy_extractor = Energy(
n_fft=config.n_fft,
hop_length=config.n_shift,
win_length=config.win_length,
window=config.window)
# process for the 3 sections
if train_wav_files:
process_sentences(
config=config,
fps=train_wav_files,
sentences=sentences,
output_dir=train_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
if dev_wav_files:
process_sentences(
config=config,
fps=dev_wav_files,
sentences=sentences,
output_dir=dev_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
if test_wav_files:
process_sentences(
config=config,
fps=test_wav_files,
sentences=sentences,
output_dir=test_dump_dir,
mel_extractor=mel_extractor,
pitch_extractor=pitch_extractor,
energy_extractor=energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil,
spk_emb_dir=spk_emb_dir,
write_metadata_method=args.write_metadata_method)
if __name__ == "__main__":
main()

@ -0,0 +1,257 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import shutil
from pathlib import Path
import jsonlines
import numpy as np
import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import AdamW
from paddle.optimizer.lr import StepDecay
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_multi_spk_batch_fn
from paddlespeech.t2s.datasets.am_batch_fn import diffsinger_single_spk_batch_fn
from paddlespeech.t2s.datasets.data_table import DataTable
from paddlespeech.t2s.models.diffsinger import DiffSinger
from paddlespeech.t2s.models.diffsinger import DiffSingerEvaluator
from paddlespeech.t2s.models.diffsinger import DiffSingerUpdater
from paddlespeech.t2s.models.diffsinger import DiffusionLoss
from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDILoss
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
def train_sp(args, config):
# decides device type and whether to run in parallel
# setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu")
world_size = paddle.distributed.get_world_size()
if world_size > 1:
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
fields = [
"text", "text_lengths", "speech", "speech_lengths", "durations",
"pitch", "energy", "note", "note_dur", "is_slur"
]
converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
spk_num = None
if args.speaker_dict is not None:
print("multiple speaker diffsinger!")
collate_fn = diffsinger_multi_spk_batch_fn
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
spk_num = len(spk_id)
fields += ["spk_id"]
else:
collate_fn = diffsinger_single_spk_batch_fn
print("single speaker diffsinger!")
print("spk_num:", spk_num)
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters=converters, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters=converters, )
# collate function and dataloader
train_sampler = DistributedBatchSampler(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
drop_last=True)
print("samplers done!")
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=collate_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
shuffle=False,
drop_last=False,
batch_size=config.batch_size,
collate_fn=collate_fn,
num_workers=config.num_workers)
print("dataloaders done!")
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
with open(args.speech_stretchs, "r") as f:
spec_min = np.load(args.speech_stretchs)[0]
spec_max = np.load(args.speech_stretchs)[1]
spec_min = paddle.to_tensor(spec_min)
spec_max = paddle.to_tensor(spec_max)
print("min and max spec done!")
odim = config.n_mels
config["model"]["fastspeech2_params"]["spk_num"] = spk_num
model = DiffSinger(
spec_min=spec_min,
spec_max=spec_max,
idim=vocab_size,
odim=odim,
**config["model"], )
model_fs2 = model.fs2
model_ds = model.diffusion
if world_size > 1:
model = DataParallel(model)
model_fs2 = model._layers.fs2
model_ds = model._layers.diffusion
print("models done!")
criterion_fs2 = FastSpeech2MIDILoss(**config["fs2_updater"])
criterion_ds = DiffusionLoss(**config["ds_updater"])
print("criterions done!")
optimizer_fs2 = build_optimizers(model_fs2, **config["fs2_optimizer"])
lr_schedule_ds = StepDecay(**config["ds_scheduler_params"])
gradient_clip_ds = nn.ClipGradByGlobalNorm(config["ds_grad_norm"])
optimizer_ds = AdamW(
learning_rate=lr_schedule_ds,
grad_clip=gradient_clip_ds,
parameters=model_ds.parameters(),
**config["ds_optimizer_params"])
print("optimizer done!")
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
if dist.get_rank() == 0:
config_name = args.config.split("/")[-1]
# copy conf to output_dir
shutil.copyfile(args.config, output_dir / config_name)
updater = DiffSingerUpdater(
model=model,
optimizers={
"fs2": optimizer_fs2,
"ds": optimizer_ds,
},
criterions={
"fs2": criterion_fs2,
"ds": criterion_ds,
},
dataloader=train_dataloader,
ds_train_start_steps=config.ds_train_start_steps,
output_dir=output_dir,
only_train_diffusion=config["only_train_diffusion"])
evaluator = DiffSingerEvaluator(
model=model,
criterions={
"fs2": criterion_fs2,
"ds": criterion_ds,
},
dataloader=dev_dataloader,
output_dir=output_dir, )
trainer = Trainer(
updater,
stop_trigger=(config.train_max_steps, "iteration"),
out=output_dir, )
if dist.get_rank() == 0:
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="Train a DiffSinger model.")
parser.add_argument("--config", type=str, help="diffsinger config file.")
parser.add_argument("--train-metadata", type=str, help="training data.")
parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
parser.add_argument(
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
parser.add_argument(
"--speaker-dict",
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
parser.add_argument(
"--speech-stretchs",
type=str,
help="The min and max values of the mel spectrum.")
args = parser.parse_args()
with open(args.config) as f:
config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(config)
print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
)
# dispatch
if args.ngpu > 1:
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
else:
train_sp(args, config)
if __name__ == "__main__":
main()
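
The `spec_min`/`spec_max` indexing above assumes `speech_stretchs.npy` stores the per-bin mel minimum in row 0 and the maximum in row 1. A minimal sketch of that assumed layout (values and `n_mels` are placeholders, not taken from the Opencpop dump):
```python
# Sketch of the assumed (2, n_mels) layout of dump/train/speech_stretchs.npy.
import numpy as np
import paddle

n_mels = 80
stats = np.stack([
    np.full(n_mels, -6.0, dtype=np.float32),  # hypothetical per-bin minimum
    np.full(n_mels, 1.5, dtype=np.float32),   # hypothetical per-bin maximum
])
np.save("speech_stretchs.npy", stats)

mel_stretchs = np.load("speech_stretchs.npy")
spec_min = paddle.to_tensor(mel_stretchs[0])  # same indexing as train_sp above
spec_max = paddle.to_tensor(mel_stretchs[1])
print(spec_min.shape, spec_max.shape)         # [80] [80]
```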

@ -56,6 +56,11 @@ model_alias = {
"paddlespeech.t2s.models.tacotron2:Tacotron2",
"tacotron2_inference":
"paddlespeech.t2s.models.tacotron2:Tacotron2Inference",
"diffsinger":
"paddlespeech.t2s.models.diffsinger:DiffSinger",
"diffsinger_inference":
"paddlespeech.t2s.models.diffsinger:DiffSingerInference",
# voc
"pwgan":
"paddlespeech.t2s.models.parallel_wavegan:PWGGenerator",
@ -142,6 +147,8 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
fields += ["spk_emb"]
else:
print("single speaker fastspeech2!")
elif am_name == 'diffsinger':
fields = ["utt_id", "text", "note", "note_dur", "is_slur"]
elif am_name == 'speedyspeech':
fields = ["utt_id", "phones", "tones"]
elif am_name == 'tacotron2':
@ -326,14 +333,16 @@ def run_frontend(frontend: object,
# dygraph
def get_am_inference(am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
return_am: bool=False):
def get_am_inference(
am: str='fastspeech2_csmsc',
am_config: CfgNode=None,
am_ckpt: Optional[os.PathLike]=None,
am_stat: Optional[os.PathLike]=None,
phones_dict: Optional[os.PathLike]=None,
tones_dict: Optional[os.PathLike]=None,
speaker_dict: Optional[os.PathLike]=None,
return_am: bool=False,
speech_stretchs: Optional[os.PathLike]=None, ):
with open(phones_dict, 'rt', encoding='utf-8') as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
@ -356,6 +365,19 @@ def get_am_inference(am: str='fastspeech2_csmsc',
if am_name == 'fastspeech2':
am = am_class(
idim=vocab_size, odim=odim, spk_num=spk_num, **am_config["model"])
elif am_name == 'diffsinger':
# min and max values of the mel spectrum, saved by the preprocessing stage
mel_stretchs = np.load(speech_stretchs)
spec_min = mel_stretchs[0]
spec_max = mel_stretchs[1]
spec_min = paddle.to_tensor(spec_min)
spec_max = paddle.to_tensor(spec_max)
am_config["model"]["fastspeech2_params"]["spk_num"] = spk_num
am = am_class(
spec_min=spec_min,
spec_max=spec_max,
idim=vocab_size,
odim=odim,
**am_config["model"], )
elif am_name == 'speedyspeech':
am = am_class(
vocab_size=vocab_size,
@ -366,8 +388,6 @@ def get_am_inference(am: str='fastspeech2_csmsc',
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
elif am_name == 'erniesat':
am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
else:
print("wrong am, please input right am!!!")
am.set_state_dict(paddle.load(am_ckpt)["main_params"])
am.eval()
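
For reference, a hypothetical call of the extended helper for DiffSinger; the paths are placeholders, and this assumes `get_am_inference` lives in `paddlespeech.t2s.exps.syn_utils` and returns the `am_inference` wrapper when `return_am=False`:
```python
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.exps.syn_utils import get_am_inference  # assumed module path

with open('conf/default.yaml') as f:          # placeholder config path
    am_config = CfgNode(yaml.safe_load(f))

am_inference = get_am_inference(
    am='diffsinger_opencpop',
    am_config=am_config,
    am_ckpt='exp/default/checkpoints/snapshot_iter_160000.pdz',  # placeholder checkpoint
    am_stat='dump/train/speech_stats.npy',
    phones_dict='dump/phone_id_map.txt',
    speech_stretchs='dump/train/speech_stretchs.npy')
```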

@ -60,7 +60,8 @@ def evaluate(args):
am_stat=args.am_stat,
phones_dict=args.phones_dict,
tones_dict=args.tones_dict,
speaker_dict=args.speaker_dict)
speaker_dict=args.speaker_dict,
speech_stretchs=args.speech_stretchs, )
test_dataset = get_test_dataset(
test_metadata=test_metadata,
am=args.am,
@ -107,6 +108,20 @@ def evaluate(args):
if args.voice_cloning and "spk_emb" in datum:
spk_emb = paddle.to_tensor(np.load(datum["spk_emb"]))
mel = am_inference(phone_ids, spk_emb=spk_emb)
elif am_name == 'diffsinger':
phone_ids = paddle.to_tensor(datum["text"])
note = paddle.to_tensor(datum["note"])
note_dur = paddle.to_tensor(datum["note_dur"])
is_slur = paddle.to_tensor(datum["is_slur"])
# get_mel_fs2=False: take the mel from the diffusion module; get_mel_fs2=True: take the mel from fastspeech2.
get_mel_fs2 = False
# mel: [T, mel_bin]
mel = am_inference(
phone_ids,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=get_mel_fs2)
# vocoder
wav = voc_inference(mel)
@ -134,10 +149,17 @@ def parse_args():
type=str,
default='fastspeech2_csmsc',
choices=[
'speedyspeech_csmsc', 'fastspeech2_csmsc', 'fastspeech2_ljspeech',
'fastspeech2_aishell3', 'fastspeech2_vctk', 'tacotron2_csmsc',
'tacotron2_ljspeech', 'tacotron2_aishell3', 'fastspeech2_mix',
'fastspeech2_canton'
'speedyspeech_csmsc',
'fastspeech2_csmsc',
'fastspeech2_ljspeech',
'fastspeech2_aishell3',
'fastspeech2_vctk',
'tacotron2_csmsc',
'tacotron2_ljspeech',
'tacotron2_aishell3',
'fastspeech2_mix',
'fastspeech2_canton',
'diffsinger_opencpop',
],
help='Choose acoustic model type of tts task.')
parser.add_argument(
@ -170,10 +192,19 @@ def parse_args():
type=str,
default='pwgan_csmsc',
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc', 'wavernn_csmsc', 'hifigan_csmsc',
'hifigan_ljspeech', 'hifigan_aishell3', 'hifigan_vctk',
'style_melgan_csmsc'
'pwgan_csmsc',
'pwgan_ljspeech',
'pwgan_aishell3',
'pwgan_vctk',
'mb_melgan_csmsc',
'wavernn_csmsc',
'hifigan_csmsc',
'hifigan_ljspeech',
'hifigan_aishell3',
'hifigan_vctk',
'style_melgan_csmsc',
"pwgan_opencpop",
"hifigan_opencpop",
],
help='Choose vocoder type of tts task.')
parser.add_argument(
@ -191,6 +222,11 @@ def parse_args():
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
parser.add_argument("--test_metadata", type=str, help="test metadata.")
parser.add_argument("--output_dir", type=str, help="output dir.")
parser.add_argument(
"--speech_stretchs",
type=str,
default=None,
help="The min and max values of the mel spectrum.")
args = parser.parse_args()
return args
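
To make the new `diffsinger` branch of `get_test_dataset` concrete, here is an illustrative test `metadata.jsonl` record with the fields it selects (values are made up, not taken from the Opencpop dump):
```python
import json

record = {
    "utt_id": "2001000001",          # hypothetical utterance id
    "text": [28, 15, 3],             # phone ids
    "note": [60, 60, 62],            # note ids from the music score
    "note_dur": [0.50, 0.50, 0.75],  # note durations in seconds
    "is_slur": [0, 0, 1],            # slur flags
}
print(json.dumps(record))            # one line of metadata.jsonl
```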

@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .diffsinger import *
from .diffsinger_updater import *
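
Assuming the wildcard imports re-export the public classes, this package `__init__` is what makes the imports used elsewhere in this commit resolve:
```python
from paddlespeech.t2s.models.diffsinger import DiffSinger
from paddlespeech.t2s.models.diffsinger import DiffSingerInference
from paddlespeech.t2s.models.diffsinger import DiffusionLoss
```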

@ -0,0 +1,399 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
"""DiffSinger related modules for paddle"""
from typing import Any
from typing import Dict
from typing import Tuple
import numpy as np
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.diffsinger.fastspeech2midi import FastSpeech2MIDI
from paddlespeech.t2s.modules.diffnet import DiffNet
from paddlespeech.t2s.modules.diffusion import GaussianDiffusion
class DiffSinger(nn.Layer):
"""DiffSinger module.
This is a module of DiffSinger described in `DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism`._
.. _`DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism`:
https://arxiv.org/pdf/2105.02446.pdf
Args:
Returns:
"""
def __init__(
self,
# min and max spec for stretching before diffusion
spec_min: paddle.Tensor,
spec_max: paddle.Tensor,
# fastspeech2midi config
idim: int,
odim: int,
use_energy_pred: bool=False,
use_postnet: bool=False,
# music score related
note_num: int=300,
is_slur_num: int=2,
fastspeech2_params: Dict[str, Any]={
"adim": 256,
"aheads": 2,
"elayers": 4,
"eunits": 1024,
"dlayers": 4,
"dunits": 1024,
"positionwise_layer_type": "conv1d",
"positionwise_conv_kernel_size": 1,
"use_scaled_pos_enc": True,
"use_batch_norm": True,
"encoder_normalize_before": True,
"decoder_normalize_before": True,
"encoder_concat_after": False,
"decoder_concat_after": False,
"reduction_factor": 1,
# for transformer
"transformer_enc_dropout_rate": 0.1,
"transformer_enc_positional_dropout_rate": 0.1,
"transformer_enc_attn_dropout_rate": 0.1,
"transformer_dec_dropout_rate": 0.1,
"transformer_dec_positional_dropout_rate": 0.1,
"transformer_dec_attn_dropout_rate": 0.1,
"transformer_activation_type": "gelu",
# duration predictor
"duration_predictor_layers": 2,
"duration_predictor_chans": 384,
"duration_predictor_kernel_size": 3,
"duration_predictor_dropout_rate": 0.1,
# pitch predictor
"use_pitch_embed": True,
"pitch_predictor_layers": 2,
"pitch_predictor_chans": 384,
"pitch_predictor_kernel_size": 3,
"pitch_predictor_dropout": 0.5,
"pitch_embed_kernel_size": 9,
"pitch_embed_dropout": 0.5,
"stop_gradient_from_pitch_predictor": False,
# energy predictor
"use_energy_embed": False,
"energy_predictor_layers": 2,
"energy_predictor_chans": 384,
"energy_predictor_kernel_size": 3,
"energy_predictor_dropout": 0.5,
"energy_embed_kernel_size": 9,
"energy_embed_dropout": 0.5,
"stop_gradient_from_energy_predictor": False,
# postnet
"postnet_layers": 5,
"postnet_chans": 512,
"postnet_filts": 5,
"postnet_dropout_rate": 0.5,
# spk emb
"spk_num": None,
"spk_embed_dim": None,
"spk_embed_integration_type": "add",
# training related
"init_type": "xavier_uniform",
"init_enc_alpha": 1.0,
"init_dec_alpha": 1.0,
# speaker classifier
"enable_speaker_classifier": False,
"hidden_sc_dim": 256,
},
# denoiser config
denoiser_params: Dict[str, Any]={
"in_channels": 80,
"out_channels": 80,
"kernel_size": 3,
"layers": 20,
"stacks": 5,
"residual_channels": 256,
"gate_channels": 512,
"skip_channels": 256,
"aux_channels": 256,
"dropout": 0.,
"bias": True,
"use_weight_norm": False,
"init_type": "kaiming_normal",
},
# diffusion config
diffusion_params: Dict[str, Any]={
"num_train_timesteps": 100,
"beta_start": 0.0001,
"beta_end": 0.06,
"beta_schedule": "squaredcos_cap_v2",
"num_max_timesteps": 60,
"stretch": True,
}, ):
"""Initialize DiffSinger module.
Args:
spec_min (paddle.Tensor): The minimum value of the feature(mel) to stretch before diffusion.
spec_max (paddle.Tensor): The maximum value of the feature(mel) to stretch before diffusion.
idim (int): Dimension of the inputs (input vocabulary size).
odim (int): Dimension of the outputs (acoustic feature dimension).
use_energy_pred (bool, optional): Whether to use the energy predictor. Defaults to False.
use_postnet (bool, optional): Whether to use the postnet. Defaults to False.
note_num (int, optional): The number of notes. Defaults to 300.
is_slur_num (int, optional): The number of slur types. Defaults to 2.
fastspeech2_params (Dict[str, Any]): Parameter dict for the fastspeech2 module.
denoiser_params (Dict[str, Any]): Parameter dict for the denoiser module.
diffusion_params (Dict[str, Any]): Parameter dict for diffusion module.
"""
assert check_argument_types()
super().__init__()
self.fs2 = FastSpeech2MIDI(
idim=idim,
odim=odim,
fastspeech2_params=fastspeech2_params,
note_num=note_num,
is_slur_num=is_slur_num,
use_energy_pred=use_energy_pred,
use_postnet=use_postnet, )
denoiser = DiffNet(**denoiser_params)
self.diffusion = GaussianDiffusion(
denoiser,
**diffusion_params,
min_values=spec_min,
max_values=spec_max, )
def forward(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
text_lengths: paddle.Tensor,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
durations: paddle.Tensor,
pitch: paddle.Tensor,
energy: paddle.Tensor,
spk_emb: paddle.Tensor=None,
spk_id: paddle.Tensor=None,
only_train_fs2: bool=True,
) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
"""Calculate forward propagation.
Args:
text(Tensor(int64)):
Batch of padded token (phone) ids (B, Tmax).
note(Tensor(int64)):
Batch of padded note (element in music score) ids (B, Tmax).
note_dur(Tensor(float32)):
Batch of padded note durations in seconds (element in music score) (B, Tmax).
is_slur(Tensor(int64)):
Batch of padded slur (element in music score) ids (B, Tmax).
text_lengths(Tensor(int64)):
Batch of phone lengths of each input (B,).
speech(Tensor[float32]):
Batch of padded target features (e.g. mel) (B, Lmax, odim).
speech_lengths(Tensor(int64)):
Batch of the lengths of each target features (B,).
durations(Tensor(int64)):
Batch of padded token durations in frame (B, Tmax).
pitch(Tensor[float32]):
Batch of padded frame-averaged pitch (B, Lmax, 1).
energy(Tensor[float32]):
Batch of padded frame-averaged energy (B, Lmax, 1).
spk_emb(Tensor[float32], optional):
Batch of speaker embeddings (B, spk_embed_dim).
spk_id(Tensor(int64), optional):
Batch of speaker ids (B,)
only_train_fs2(bool):
Whether to train only the fastspeech2 module
Returns:
When only_train_fs2 is True: the fastspeech2 outputs (before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits);
otherwise: (noise_pred, noise_target, mel_masks) from the diffusion module.
"""
# only train fastspeech2 module firstly
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.fs2(
text=text,
note=note,
note_dur=note_dur,
is_slur=is_slur,
text_lengths=text_lengths,
speech=speech,
speech_lengths=speech_lengths,
durations=durations,
pitch=pitch,
energy=energy,
spk_id=spk_id,
spk_emb=spk_emb)
if only_train_fs2:
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
# get the encoder output from fastspeech2 as the condition of denoiser module
cond_fs2, mel_masks = self.fs2.encoder_infer_batch(
text=text,
note=note,
note_dur=note_dur,
is_slur=is_slur,
text_lengths=text_lengths,
speech_lengths=speech_lengths,
ds=durations,
ps=pitch,
es=energy)
cond_fs2 = cond_fs2.transpose((0, 2, 1))
# get the output(final mel) from diffusion module
noise_pred, noise_target = self.diffusion(
speech.transpose((0, 2, 1)), cond_fs2)
return noise_pred, noise_target, mel_masks
def inference(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
get_mel_fs2: bool=False, ):
"""Run inference
Args:
text(Tensor(int64)):
Batch of padded token (phone) ids (B, Tmax).
note(Tensor(int64)):
Batch of padded note (element in music score) ids (B, Tmax).
note_dur(Tensor(float32)):
Batch of padded note durations in seconds (element in music score) (B, Tmax).
is_slur(Tensor(int64)):
Batch of padded slur (element in music score) ids (B, Tmax).
get_mel_fs2 (bool, optional):
Whether to return the mel from the fastspeech2 module. Defaults to False.
Returns:
Tensor: predicted mel spectrogram (T, odim), from fastspeech2 when get_mel_fs2 is True, otherwise from the diffusion module.
"""
mel_fs2, _, _, _ = self.fs2.inference(text, note, note_dur, is_slur)
if get_mel_fs2:
return mel_fs2
mel_fs2 = mel_fs2.unsqueeze(0).transpose((0, 2, 1))
cond_fs2 = self.fs2.encoder_infer(text, note, note_dur, is_slur)
cond_fs2 = cond_fs2.transpose((0, 2, 1))
noise = paddle.randn(mel_fs2.shape)
mel = self.diffusion.inference(
noise=noise,
cond=cond_fs2,
ref_x=mel_fs2,
scheduler_type="ddpm",
num_inference_steps=60)
mel = mel.transpose((0, 2, 1))
return mel[0]
class DiffSingerInference(nn.Layer):
def __init__(self, normalizer, model):
super().__init__()
self.normalizer = normalizer
self.acoustic_model = model
def forward(self, text, note, note_dur, is_slur, get_mel_fs2: bool=False):
"""Calculate forward propagation.
Args:
text(Tensor(int64)):
Batch of padded token (phone) ids (B, Tmax).
note(Tensor(int64)):
Batch of padded note (element in music score) ids (B, Tmax).
note_dur(Tensor(float32)):
Batch of padded note durations in seconds (element in music score) (B, Tmax).
is_slur(Tensor(int64)):
Batch of padded slur (element in music score) ids (B, Tmax).
get_mel_fs2 (bool, optional):
Whether to return the mel from the fastspeech2 module. Defaults to False.
Returns:
logmel(Tensor(float32)): denorm logmel, [T, mel_bin]
"""
normalized_mel = self.acoustic_model.inference(
text=text,
note=note,
note_dur=note_dur,
is_slur=is_slur,
get_mel_fs2=get_mel_fs2)
logmel = normalized_mel
return logmel
class DiffusionLoss(nn.Layer):
"""Loss function module for Diffusion module on DiffSinger."""
def __init__(self, use_masking: bool=True,
use_weighted_masking: bool=False):
"""Initialize feed-forward Transformer loss module.
Args:
use_masking (bool):
Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool):
Whether to apply weighted masking in loss calculation.
"""
assert check_argument_types()
super().__init__()
assert (use_masking != use_weighted_masking) or not use_masking
self.use_masking = use_masking
self.use_weighted_masking = use_weighted_masking
# define criterions
reduction = "none" if self.use_weighted_masking else "mean"
self.l1_criterion = nn.L1Loss(reduction=reduction)
def forward(
self,
noise_pred: paddle.Tensor,
noise_target: paddle.Tensor,
mel_masks: paddle.Tensor, ) -> paddle.Tensor:
"""Calculate forward propagation.
Args:
noise_pred(Tensor):
Batch of outputs predict noise (B, Lmax, odim).
noise_target(Tensor):
Batch of target noise (B, Lmax, odim).
mel_masks(Tensor):
Batch of mask of real mel (B, Lmax, 1).
Returns:
Tensor: L1 loss between the predicted and target noise.
"""
# apply mask to remove padded part
if self.use_masking:
noise_pred = noise_pred.masked_select(
mel_masks.broadcast_to(noise_pred.shape))
noise_target = noise_target.masked_select(
mel_masks.broadcast_to(noise_target.shape))
# calculate loss
l1_loss = self.l1_criterion(noise_pred, noise_target)
# make weighted mask and apply it
if self.use_weighted_masking:
mel_masks = mel_masks.unsqueeze(-1)
out_weights = mel_masks.cast(dtype=paddle.float32) / mel_masks.cast(
dtype=paddle.float32).sum(
axis=1, keepdim=True)
out_weights /= noise_target.shape[0] * noise_target.shape[2]
# apply weight
l1_loss = l1_loss.multiply(out_weights)
l1_loss = l1_loss.masked_select(
mel_masks.broadcast_to(l1_loss.shape)).sum()
return l1_loss
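
A small shape-level sketch of `DiffusionLoss` with masking; tensor contents are random and the shapes follow the docstring above:
```python
import paddle
from paddlespeech.t2s.models.diffsinger import DiffusionLoss

B, Lmax, odim = 2, 10, 80
criterion = DiffusionLoss(use_masking=True)
noise_pred = paddle.randn([B, Lmax, odim])
noise_target = paddle.randn([B, Lmax, odim])
# mark trailing frames of the second sample as padding
lengths = paddle.to_tensor([10, 8])
mel_masks = (paddle.arange(Lmax).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1)  # (B, Lmax, 1)
loss = criterion(
    noise_pred=noise_pred,
    noise_target=noise_target,
    mel_masks=mel_masks)
print(float(loss))  # mean L1 over non-padded positions
```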

@ -0,0 +1,302 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
from paddlespeech.t2s.training.reporter import report
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class DiffSingerUpdater(StandardUpdater):
def __init__(self,
model: Layer,
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
dataloader: DataLoader,
ds_train_start_steps: int=160000,
output_dir: Path=None,
only_train_diffusion: bool=True):
super().__init__(model, optimizers, dataloader, init_state=None)
self.model = model._layers if isinstance(model,
paddle.DataParallel) else model
self.only_train_diffusion = only_train_diffusion
self.optimizers = optimizers
self.optimizer_fs2: Optimizer = optimizers['fs2']
self.optimizer_ds: Optimizer = optimizers['ds']
self.criterions = criterions
self.criterion_fs2 = criterions['fs2']
self.criterion_ds = criterions['ds']
self.dataloader = dataloader
self.ds_train_start_steps = ds_train_start_steps
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# spk_id!=None in multiple spk diffsinger
spk_id = batch["spk_id"] if "spk_id" in batch else None
spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
# No explicit speaker identifier labels are used during voice cloning training.
if spk_emb is not None:
spk_id = None
# only train fastspeech2 module firstly
if self.state.iteration < self.ds_train_start_steps:
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb,
only_train_fs2=True, )
l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
p_outs=p_outs,
e_outs=e_outs,
ys=ys,
ds=batch["durations"],
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss + speaker_loss
self.optimizer_fs2.clear_grad()
loss_fs2.backward()
self.optimizer_fs2.step()
report("train/loss_fs2", float(loss_fs2))
report("train/l1_loss_fs2", float(l1_loss_fs2))
report("train/ssim_loss_fs2", float(ssim_loss_fs2))
report("train/duration_loss", float(duration_loss))
report("train/pitch_loss", float(pitch_loss))
losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
if speaker_loss != 0.:
report("train/speaker_loss", float(speaker_loss))
losses_dict["speaker_loss"] = float(speaker_loss)
if energy_loss != 0.:
report("train/energy_loss", float(energy_loss))
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss_fs2"] = float(loss_fs2)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
# Then only train diffusion module, freeze fastspeech2 parameters.
if self.state.iteration > self.ds_train_start_steps:
for param in self.model.fs2.parameters():
param.trainable = False if self.only_train_diffusion else True
noise_pred, noise_target, mel_masks = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb,
only_train_fs2=False, )
noise_pred = noise_pred.transpose((0, 2, 1))
noise_target = noise_target.transpose((0, 2, 1))
mel_masks = mel_masks.transpose((0, 2, 1))
l1_loss_ds = self.criterion_ds(
noise_pred=noise_pred,
noise_target=noise_target,
mel_masks=mel_masks, )
loss_ds = l1_loss_ds
self.optimizer_ds.clear_grad()
loss_ds.backward()
self.optimizer_ds.step()
report("train/loss_ds", float(loss_ds))
report("train/l1_loss_ds", float(l1_loss_ds))
losses_dict["l1_loss_ds"] = float(l1_loss_ds)
losses_dict["loss_ds"] = float(loss_ds)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)
class DiffSingerEvaluator(StandardEvaluator):
def __init__(
self,
model: Layer,
criterions: Dict[str, Layer],
dataloader: DataLoader,
output_dir: Path=None, ):
super().__init__(model, dataloader)
self.model = model._layers if isinstance(model,
paddle.DataParallel) else model
self.criterions = criterions
self.criterion_fs2 = criterions['fs2']
self.criterion_ds = criterions['ds']
self.dataloader = dataloader
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
self.msg = "Evaluate: "
losses_dict = {}
# spk_id!=None in multiple spk diffsinger
spk_id = batch["spk_id"] if "spk_id" in batch else None
spk_emb = batch["spk_emb"] if "spk_emb" in batch else None
if spk_emb is not None:
spk_id = None
# fastspeech2 evaluation
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb,
only_train_fs2=True, )
l1_loss_fs2, ssim_loss_fs2, duration_loss, pitch_loss, energy_loss, speaker_loss = self.criterion_fs2(
after_outs=after_outs,
before_outs=before_outs,
d_outs=d_outs,
p_outs=p_outs,
e_outs=e_outs,
ys=ys,
ds=batch["durations"],
ps=batch["pitch"],
es=batch["energy"],
ilens=batch["text_lengths"],
olens=olens,
spk_logits=spk_logits,
spk_ids=spk_id, )
loss_fs2 = l1_loss_fs2 + ssim_loss_fs2 + duration_loss + pitch_loss + energy_loss + speaker_loss
report("eval/loss_fs2", float(loss_fs2))
report("eval/l1_loss_fs2", float(l1_loss_fs2))
report("eval/ssim_loss_fs2", float(ssim_loss_fs2))
report("eval/duration_loss", float(duration_loss))
report("eval/pitch_loss", float(pitch_loss))
losses_dict["l1_loss_fs2"] = float(l1_loss_fs2)
losses_dict["ssim_loss_fs2"] = float(ssim_loss_fs2)
losses_dict["duration_loss"] = float(duration_loss)
losses_dict["pitch_loss"] = float(pitch_loss)
if speaker_loss != 0.:
report("eval/speaker_loss", float(speaker_loss))
losses_dict["speaker_loss"] = float(speaker_loss)
if energy_loss != 0.:
report("eval/energy_loss", float(energy_loss))
losses_dict["energy_loss"] = float(energy_loss)
losses_dict["loss_fs2"] = float(loss_fs2)
# diffusion evaluation
noise_pred, noise_target, mel_masks = self.model(
text=batch["text"],
note=batch["note"],
note_dur=batch["note_dur"],
is_slur=batch["is_slur"],
text_lengths=batch["text_lengths"],
speech=batch["speech"],
speech_lengths=batch["speech_lengths"],
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id,
spk_emb=spk_emb,
only_train_fs2=False, )
noise_pred = noise_pred.transpose((0, 2, 1))
noise_target = noise_target.transpose((0, 2, 1))
mel_masks = mel_masks.transpose((0, 2, 1))
l1_loss_ds = self.criterion_ds(
noise_pred=noise_pred,
noise_target=noise_target,
mel_masks=mel_masks, )
loss_ds = l1_loss_ds
report("eval/loss_ds", float(loss_ds))
report("eval/l1_loss_ds", float(l1_loss_ds))
losses_dict["l1_loss_ds"] = float(l1_loss_ds)
losses_dict["loss_ds"] = float(loss_ds)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)
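
The updater and evaluator above follow a two-stage schedule controlled by `ds_train_start_steps`. A small illustration of which branch of `update_core` a given iteration takes (step counts are placeholders, not the released Opencpop config):
```python
DS_TRAIN_START_STEPS = 160000  # default of DiffSingerUpdater above

def trained_branch(iteration: int, only_train_diffusion: bool = True) -> str:
    """Mirror the branching in DiffSingerUpdater.update_core."""
    if iteration < DS_TRAIN_START_STEPS:
        return "fastspeech2 (midi) branch"
    if iteration > DS_TRAIN_START_STEPS:
        if only_train_diffusion:
            return "diffusion branch (fs2 parameters frozen)"
        return "diffusion branch (fs2 parameters still trainable)"
    return "neither: the single boundary step matches neither check"

for it in (1000, 160000, 200000):
    print(it, trained_branch(it))
```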

@ -0,0 +1,654 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from espnet(https://github.com/espnet/espnet)
from typing import Any
from typing import Dict
from typing import Sequence
from typing import Tuple
import paddle
from paddle import nn
from typeguard import check_argument_types
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Loss
from paddlespeech.t2s.modules.losses import ssim
from paddlespeech.t2s.modules.masked_fill import masked_fill
from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
from paddlespeech.t2s.modules.nets_utils import make_pad_mask
class FastSpeech2MIDI(FastSpeech2):
"""The Fastspeech2 module of DiffSinger.
"""
def __init__(
self,
# fastspeech2 network structure related
idim: int,
odim: int,
fastspeech2_params: Dict[str, Any],
# note emb
note_num: int=300,
# is_slur emb
is_slur_num: int=2,
use_energy_pred: bool=False,
use_postnet: bool=False, ):
"""Initialize FastSpeech2 module for svs.
Args:
fastspeech2_params (Dict):
The config of FastSpeech2 module on DiffSinger model
note_num (Optional[int]):
Number of note. If not None, assume that the
note_ids will be provided as the input and use note_embedding_table.
is_slur_num (Optional[int]):
Number of note. If not None, assume that the
is_slur_ids will be provided as the input
"""
assert check_argument_types()
super().__init__(idim=idim, odim=odim, **fastspeech2_params)
self.use_energy_pred = use_energy_pred
self.use_postnet = use_postnet
if not self.use_postnet:
self.postnet = None
self.note_embed_dim = self.is_slur_embed_dim = fastspeech2_params[
"adim"]
# note embed
self.note_embedding_table = nn.Embedding(
num_embeddings=note_num,
embedding_dim=self.note_embed_dim,
padding_idx=self.padding_idx)
self.note_dur_layer = nn.Linear(1, self.note_embed_dim)
# slur embed
self.is_slur_embedding_table = nn.Embedding(
num_embeddings=is_slur_num,
embedding_dim=self.is_slur_embed_dim,
padding_idx=self.padding_idx)
def forward(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
text_lengths: paddle.Tensor,
speech: paddle.Tensor,
speech_lengths: paddle.Tensor,
durations: paddle.Tensor,
pitch: paddle.Tensor,
energy: paddle.Tensor,
spk_emb: paddle.Tensor=None,
spk_id: paddle.Tensor=None,
) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]:
"""Calculate forward propagation.
Args:
text(Tensor(int64)):
Batch of padded token (phone) ids (B, Tmax).
note(Tensor(int64)):
Batch of padded note (element in music score) ids (B, Tmax).
note_dur(Tensor(float32)):
Batch of padded note durations in seconds (element in music score) (B, Tmax).
is_slur(Tensor(int64)):
Batch of padded slur (element in music score) ids (B, Tmax).
text_lengths(Tensor(int64)):
Batch of phone lengths of each input (B,).
speech(Tensor[float32]):
Batch of padded target features (e.g. mel) (B, Lmax, odim).
speech_lengths(Tensor(int64)):
Batch of the lengths of each target features (B,).
durations(Tensor(int64)):
Batch of padded token durations in frame (B, Tmax).
pitch(Tensor[float32]):
Batch of padded frame-averaged pitch (B, Lmax, 1).
energy(Tensor[float32]):
Batch of padded frame-averaged energy (B, Lmax, 1).
spk_emb(Tensor[float32], optional):
Batch of speaker embeddings (B, spk_embed_dim).
spk_id(Tensor(int64), optional):
Batch of speaker ids (B,)
Returns:
Tuple: (before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits).
"""
xs = paddle.cast(text, 'int64')
note = paddle.cast(note, 'int64')
note_dur = paddle.cast(note_dur, 'float32')
is_slur = paddle.cast(is_slur, 'int64')
ilens = paddle.cast(text_lengths, 'int64')
olens = paddle.cast(speech_lengths, 'int64')
ds = paddle.cast(durations, 'int64')
ps = pitch
es = energy
ys = speech
olens = speech_lengths
if spk_id is not None:
spk_id = paddle.cast(spk_id, 'int64')
# forward propagation
before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits = self._forward(
xs=xs,
note=note,
note_dur=note_dur,
is_slur=is_slur,
ilens=ilens,
olens=olens,
ds=ds,
ps=ps,
es=es,
is_inference=False,
spk_emb=spk_emb,
spk_id=spk_id, )
# modify mod part of groundtruth
if self.reduction_factor > 1:
olens = olens - olens % self.reduction_factor
max_olen = max(olens)
ys = ys[:, :max_olen]
return before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens, spk_logits
def _forward(
self,
xs: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor=None,
ds: paddle.Tensor=None,
ps: paddle.Tensor=None,
es: paddle.Tensor=None,
is_inference: bool=False,
is_train_diffusion: bool=False,
return_after_enc=False,
alpha: float=1.0,
spk_emb=None,
spk_id=None, ) -> Sequence[paddle.Tensor]:
before_outs = after_outs = d_outs = p_outs = e_outs = spk_logits = None
# forward encoder
masks = self._source_mask(ilens)
note_emb = self.note_embedding_table(note)
note_dur_emb = self.note_dur_layer(paddle.unsqueeze(note_dur, axis=-1))
is_slur_emb = self.is_slur_embedding_table(is_slur)
# (B, Tmax, adim)
hs, _ = self.encoder(
xs=xs,
masks=masks,
note_emb=note_emb,
note_dur_emb=note_dur_emb,
is_slur_emb=is_slur_emb, )
if self.spk_num and self.enable_speaker_classifier and not is_inference:
hs_for_spk_cls = self.grad_reverse(hs)
spk_logits = self.speaker_classifier(hs_for_spk_cls, ilens)
else:
spk_logits = None
# integrate speaker embedding
if self.spk_embed_dim is not None:
# spk_emb has a higher priority than spk_id
if spk_emb is not None:
hs = self._integrate_with_spk_embed(hs, spk_emb)
elif spk_id is not None:
spk_emb = self.spk_embedding_table(spk_id)
hs = self._integrate_with_spk_embed(hs, spk_emb)
# forward duration predictor (phone-level) and variance predictors (frame-level)
d_masks = make_pad_mask(ilens)
if olens is not None:
pitch_masks = make_pad_mask(olens).unsqueeze(-1)
else:
pitch_masks = None
# get the decoder input for training the diffusion module
if is_train_diffusion:
hs = self.length_regulator(hs, ds, is_inference=False)
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs += p_embs
if self.use_energy_pred:
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
e_embs = self.energy_embed(
e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
hs += e_embs
elif is_inference:
# (B, Tmax)
if ds is not None:
d_outs = ds
else:
d_outs = self.duration_predictor.inference(hs, d_masks)
# (B, Lmax, adim)
hs = self.length_regulator(hs, d_outs, alpha, is_inference=True)
if ps is not None:
p_outs = ps
else:
if self.stop_gradient_from_pitch_predictor:
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
else:
p_outs = self.pitch_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(p_outs.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs += p_embs
if self.use_energy_pred:
if es is not None:
e_outs = es
else:
if self.stop_gradient_from_energy_predictor:
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
else:
e_outs = self.energy_predictor(hs, pitch_masks)
e_embs = self.energy_embed(
e_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
hs += e_embs
# training
else:
d_outs = self.duration_predictor(hs, d_masks)
# (B, Lmax, adim)
hs = self.length_regulator(hs, ds, is_inference=False)
if self.stop_gradient_from_pitch_predictor:
p_outs = self.pitch_predictor(hs.detach(), pitch_masks)
else:
p_outs = self.pitch_predictor(hs, pitch_masks)
p_embs = self.pitch_embed(ps.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs += p_embs
if self.use_energy_pred:
if self.stop_gradient_from_energy_predictor:
e_outs = self.energy_predictor(hs.detach(), pitch_masks)
else:
e_outs = self.energy_predictor(hs, pitch_masks)
e_embs = self.energy_embed(es.transpose((0, 2, 1))).transpose(
(0, 2, 1))
hs += e_embs
# forward decoder
if olens is not None and not is_inference:
if self.reduction_factor > 1:
olens_in = paddle.to_tensor(
[olen // self.reduction_factor for olen in olens.numpy()])
else:
olens_in = olens
# (B, 1, T)
h_masks = self._source_mask(olens_in)
else:
h_masks = None
if return_after_enc:
return hs, h_masks
if self.decoder_type == 'cnndecoder':
# remove output masks for dygraph to static graph
zs = self.decoder(hs, h_masks)
before_outs = zs
else:
# (B, Lmax, adim)
zs, _ = self.decoder(hs, h_masks)
# (B, Lmax, odim)
before_outs = self.feat_out(zs).reshape(
(paddle.shape(zs)[0], -1, self.odim))
# postnet -> (B, Lmax//r * r, odim)
if self.postnet is None:
after_outs = before_outs
else:
after_outs = before_outs + self.postnet(
before_outs.transpose((0, 2, 1))).transpose((0, 2, 1))
return before_outs, after_outs, d_outs, p_outs, e_outs, spk_logits
def encoder_infer(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
alpha: float=1.0,
spk_emb=None,
spk_id=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
xs = paddle.cast(text, 'int64').unsqueeze(0)
note = paddle.cast(note, 'int64').unsqueeze(0)
note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
# setup batch axis
ilens = paddle.shape(xs)[1]
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
# (1, L, odim)
# use *_ to avoid bug in dygraph to static graph
hs, _ = self._forward(
xs=xs,
note=note,
note_dur=note_dur,
is_slur=is_slur,
ilens=ilens,
is_inference=True,
return_after_enc=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id, )
return hs
# get encoder output for diffusion training
def encoder_infer_batch(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
text_lengths: paddle.Tensor,
speech_lengths: paddle.Tensor,
ds: paddle.Tensor=None,
ps: paddle.Tensor=None,
es: paddle.Tensor=None,
alpha: float=1.0,
spk_emb=None,
spk_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
xs = paddle.cast(text, 'int64')
note = paddle.cast(note, 'int64')
note_dur = paddle.cast(note_dur, 'float32')
is_slur = paddle.cast(is_slur, 'int64')
ilens = paddle.cast(text_lengths, 'int64')
olens = paddle.cast(speech_lengths, 'int64')
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
# (1, L, odim)
# use *_ to avoid bug in dygraph to static graph
hs, h_masks = self._forward(
xs=xs,
note=note,
note_dur=note_dur,
is_slur=is_slur,
ilens=ilens,
olens=olens,
ds=ds,
ps=ps,
es=es,
return_after_enc=True,
is_train_diffusion=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id, )
return hs, h_masks
def inference(
self,
text: paddle.Tensor,
note: paddle.Tensor,
note_dur: paddle.Tensor,
is_slur: paddle.Tensor,
durations: paddle.Tensor=None,
pitch: paddle.Tensor=None,
energy: paddle.Tensor=None,
alpha: float=1.0,
use_teacher_forcing: bool=False,
spk_emb=None,
spk_id=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Generate the sequence of features given the sequences of characters.
Args:
text(Tensor(int64)):
Input sequence of characters (T,).
note(Tensor(int64)):
Input note (element in music score) ids (T,).
note_dur(Tensor(float32)):
Input note durations in seconds (element in music score) (T,).
is_slur(Tensor(int64)):
Input slur (element in music score) ids (T,).
durations(Tensor, optional (int64)):
Groundtruth of duration (T,).
pitch(Tensor, optional):
Groundtruth of token-averaged pitch (T, 1).
energy(Tensor, optional):
Groundtruth of token-averaged energy (T, 1).
alpha(float, optional):
Alpha to control the speed.
use_teacher_forcing(bool, optional):
Whether to use teacher forcing.
If true, groundtruth of duration, pitch and energy will be used.
spk_emb(Tensor, optional):
Speaker embedding vector (spk_embed_dim,). (Default value = None)
spk_id(Tensor(int64), optional):
Speaker ids (1,). (Default value = None)
Returns:
Tuple: predicted mel (T, odim), durations, pitch and energy.
"""
xs = paddle.cast(text, 'int64').unsqueeze(0)
note = paddle.cast(note, 'int64').unsqueeze(0)
note_dur = paddle.cast(note_dur, 'float32').unsqueeze(0)
is_slur = paddle.cast(is_slur, 'int64').unsqueeze(0)
d, p, e = durations, pitch, energy
# setup batch axis
ilens = paddle.shape(xs)[1]
if spk_emb is not None:
spk_emb = spk_emb.unsqueeze(0)
if use_teacher_forcing:
# use groundtruth of duration, pitch, and energy
ds = d.unsqueeze(0) if d is not None else None
ps = p.unsqueeze(0) if p is not None else None
es = e.unsqueeze(0) if e is not None else None
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs=xs,
note=note,
note_dur=note_dur,
is_slur=is_slur,
ilens=ilens,
ds=ds,
ps=ps,
es=es,
spk_emb=spk_emb,
spk_id=spk_id,
is_inference=True)
else:
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs, _ = self._forward(
xs=xs,
note=note,
note_dur=note_dur,
is_slur=is_slur,
ilens=ilens,
is_inference=True,
alpha=alpha,
spk_emb=spk_emb,
spk_id=spk_id, )
if e_outs is None:
e_outs = [None]
return outs[0], d_outs[0], p_outs[0], e_outs[0]
class FastSpeech2MIDILoss(FastSpeech2Loss):
"""Loss function module for DiffSinger."""
def __init__(self, use_masking: bool=True,
use_weighted_masking: bool=False):
"""Initialize feed-forward Transformer loss module.
Args:
use_masking (bool):
Whether to apply masking for padded part in loss calculation.
use_weighted_masking (bool):
Whether to apply weighted masking in loss calculation.
"""
assert check_argument_types()
super().__init__(use_masking, use_weighted_masking)
def forward(
self,
after_outs: paddle.Tensor,
before_outs: paddle.Tensor,
d_outs: paddle.Tensor,
p_outs: paddle.Tensor,
e_outs: paddle.Tensor,
ys: paddle.Tensor,
ds: paddle.Tensor,
ps: paddle.Tensor,
es: paddle.Tensor,
ilens: paddle.Tensor,
olens: paddle.Tensor,
spk_logits: paddle.Tensor=None,
spk_ids: paddle.Tensor=None,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor,
paddle.Tensor, ]:
"""Calculate forward propagation.
Args:
after_outs(Tensor):
Batch of outputs after postnets (B, Lmax, odim).
before_outs(Tensor):
Batch of outputs before postnets (B, Lmax, odim).
d_outs(Tensor):
Batch of outputs of duration predictor (B, Tmax).
p_outs(Tensor):
Batch of outputs of pitch predictor (B, Lmax, 1).
e_outs(Tensor):
Batch of outputs of energy predictor (B, Lmax, 1).
ys(Tensor):
Batch of target features (B, Lmax, odim).
ds(Tensor):
Batch of durations (B, Tmax).
ps(Tensor):
Batch of target frame-averaged pitch (B, Lmax, 1).
es(Tensor):
Batch of target frame-averaged energy (B, Lmax, 1).
ilens(Tensor):
Batch of the lengths of each input (B,).
olens(Tensor):
Batch of the lengths of each target (B,).
spk_logits(Optional[Tensor]):
Batch of outputs after the speaker classifier (B, Lmax, num_spk).
spk_ids(Optional[Tensor]):
Batch of target spk_ids (B,).
Returns:
Tuple: l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss.
"""
l1_loss = duration_loss = pitch_loss = energy_loss = speaker_loss = ssim_loss = 0.0
# apply mask to remove padded part
if self.use_masking:
# make feature for ssim loss
out_pad_masks = make_pad_mask(olens).unsqueeze(-1)
before_outs_ssim = masked_fill(before_outs, out_pad_masks, 0.0)
if not paddle.equal_all(after_outs, before_outs):
after_outs_ssim = masked_fill(after_outs, out_pad_masks, 0.0)
ys_ssim = masked_fill(ys, out_pad_masks, 0.0)
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
before_outs = before_outs.masked_select(
out_masks.broadcast_to(before_outs.shape))
if not paddle.equal_all(after_outs, before_outs):
after_outs = after_outs.masked_select(
out_masks.broadcast_to(after_outs.shape))
ys = ys.masked_select(out_masks.broadcast_to(ys.shape))
duration_masks = make_non_pad_mask(ilens)
d_outs = d_outs.masked_select(
duration_masks.broadcast_to(d_outs.shape))
ds = ds.masked_select(duration_masks.broadcast_to(ds.shape))
pitch_masks = out_masks
p_outs = p_outs.masked_select(
pitch_masks.broadcast_to(p_outs.shape))
ps = ps.masked_select(pitch_masks.broadcast_to(ps.shape))
if e_outs is not None:
e_outs = e_outs.masked_select(
pitch_masks.broadcast_to(e_outs.shape))
es = es.masked_select(pitch_masks.broadcast_to(es.shape))
if spk_logits is not None and spk_ids is not None:
batch_size = spk_ids.shape[0]
spk_ids = paddle.repeat_interleave(spk_ids, spk_logits.shape[1],
None)
spk_logits = paddle.reshape(spk_logits,
[-1, spk_logits.shape[-1]])
mask_index = spk_logits.abs().sum(axis=1) != 0
spk_ids = spk_ids[mask_index]
spk_logits = spk_logits[mask_index]
# calculate loss
l1_loss = self.l1_criterion(before_outs, ys)
ssim_loss = 1.0 - ssim(
before_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1))
if not paddle.equal_all(after_outs, before_outs):
l1_loss += self.l1_criterion(after_outs, ys)
ssim_loss += (
1.0 - ssim(after_outs_ssim.unsqueeze(1), ys_ssim.unsqueeze(1)))
l1_loss = l1_loss * 0.5
ssim_loss = ssim_loss * 0.5
duration_loss = self.duration_criterion(d_outs, ds)
pitch_loss = self.l1_criterion(p_outs, ps)
if e_outs is not None:
energy_loss = self.l1_criterion(e_outs, es)
if spk_logits is not None and spk_ids is not None:
speaker_loss = self.ce_criterion(spk_logits, spk_ids) / batch_size
# make weighted mask and apply it
if self.use_weighted_masking:
out_masks = make_non_pad_mask(olens).unsqueeze(-1)
out_weights = out_masks.cast(dtype=paddle.float32) / out_masks.cast(
dtype=paddle.float32).sum(
axis=1, keepdim=True)
out_weights /= ys.shape[0] * ys.shape[2]
duration_masks = make_non_pad_mask(ilens)
duration_weights = (duration_masks.cast(dtype=paddle.float32) /
duration_masks.cast(dtype=paddle.float32).sum(
axis=1, keepdim=True))
duration_weights /= ds.shape[0]
# apply weight
l1_loss = l1_loss.multiply(out_weights)
l1_loss = l1_loss.masked_select(
out_masks.broadcast_to(l1_loss.shape)).sum()
ssim_loss = ssim_loss.multiply(out_weights)
ssim_loss = ssim_loss.masked_select(
out_masks.broadcast_to(ssim_loss.shape)).sum()
duration_loss = (duration_loss.multiply(duration_weights)
.masked_select(duration_masks).sum())
pitch_masks = out_masks
pitch_weights = out_weights
pitch_loss = pitch_loss.multiply(pitch_weights)
pitch_loss = pitch_loss.masked_select(
pitch_masks.broadcast_to(pitch_loss.shape)).sum()
if e_outs is not None:
energy_loss = energy_loss.multiply(pitch_weights)
energy_loss = energy_loss.masked_select(
pitch_masks.broadcast_to(energy_loss.shape)).sum()
return l1_loss, ssim_loss, duration_loss, pitch_loss, energy_loss, speaker_loss
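
A standalone sketch of the three music-score embeddings that `FastSpeech2MIDI` adds on top of FastSpeech2; the dimensions are illustrative, while the real layers use `fastspeech2_params["adim"]` and the model's `padding_idx`:
```python
import paddle
from paddle import nn

adim, note_num, is_slur_num, padding_idx = 256, 300, 2, 0  # illustrative values
note_embedding_table = nn.Embedding(note_num, adim, padding_idx=padding_idx)
note_dur_layer = nn.Linear(1, adim)
is_slur_embedding_table = nn.Embedding(is_slur_num, adim, padding_idx=padding_idx)

note = paddle.to_tensor([[60, 60, 62]], dtype='int64')              # (B, Tmax)
note_dur = paddle.to_tensor([[0.50, 0.50, 0.75]], dtype='float32')  # seconds
is_slur = paddle.to_tensor([[0, 0, 1]], dtype='int64')

note_emb = note_embedding_table(note)                  # (B, Tmax, adim)
note_dur_emb = note_dur_layer(note_dur.unsqueeze(-1))  # (B, Tmax, adim)
is_slur_emb = is_slur_embedding_table(is_slur)         # (B, Tmax, adim)
# these three are passed to the encoder together with the phone embeddings (see _forward above)
print(note_emb.shape, note_dur_emb.shape, is_slur_emb.shape)
```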

@ -93,6 +93,7 @@ class FastSpeech2(nn.Layer):
transformer_dec_dropout_rate: float=0.1,
transformer_dec_positional_dropout_rate: float=0.1,
transformer_dec_attn_dropout_rate: float=0.1,
transformer_activation_type: str="relu",
# for conformer
conformer_pos_enc_layer_type: str="rel_pos",
conformer_self_attn_layer_type: str="rel_selfattn",
@ -200,6 +201,8 @@ class FastSpeech2(nn.Layer):
Dropout rate after decoder positional encoding.
transformer_dec_attn_dropout_rate (float):
Dropout rate in decoder self-attention module.
transformer_activation_type (str):
Activation function type in transformer.
conformer_pos_enc_layer_type (str):
Pos encoding layer type in conformer.
conformer_self_attn_layer_type (str):
@ -250,7 +253,7 @@ class FastSpeech2(nn.Layer):
Kernel size of energy embedding.
energy_embed_dropout_rate (float):
Dropout rate for energy embedding.
stop_gradient_from_energy_predictorbool):
stop_gradient_from_energy_predictor (bool):
Whether to stop gradient from energy predictor to encoder.
spk_num (Optional[int]):
Number of speakers. If not None, assume that the spk_embed_dim is not None,
@ -269,7 +272,7 @@ class FastSpeech2(nn.Layer):
How to integrate tone embedding.
init_type (str):
How to initialize transformer parameters.
init_enc_alpha float):
init_enc_alpha (float):
Initial value of alpha in scaled pos encoding of the encoder.
init_dec_alpha (float):
Initial value of alpha in scaled pos encoding of the decoder.
@ -344,7 +347,8 @@ class FastSpeech2(nn.Layer):
normalize_before=encoder_normalize_before,
concat_after=encoder_concat_after,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size, )
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
activation_type=transformer_activation_type)
elif encoder_type == "conformer":
self.encoder = ConformerEncoder(
idim=idim,
@ -453,7 +457,8 @@ class FastSpeech2(nn.Layer):
normalize_before=decoder_normalize_before,
concat_after=decoder_concat_after,
positionwise_layer_type=positionwise_layer_type,
positionwise_conv_kernel_size=positionwise_conv_kernel_size, )
positionwise_conv_kernel_size=positionwise_conv_kernel_size,
activation_type=conformer_activation_type, )
elif decoder_type == "conformer":
self.decoder = ConformerEncoder(
idim=0,

@ -37,7 +37,8 @@ def get_activation(act, **kwargs):
"selu": paddle.nn.SELU,
"leakyrelu": paddle.nn.LeakyReLU,
"swish": paddle.nn.Swish,
"glu": GLU
"glu": GLU,
"gelu": paddle.nn.GELU,
}
return activation_funcs[act](**kwargs)
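
The added "gelu" entry is what lets `transformer_activation_type: gelu` in the DiffSinger config resolve. A quick check, assuming the helper lives in `paddlespeech.t2s.modules.activation`:
```python
import paddle
from paddlespeech.t2s.modules.activation import get_activation  # assumed module path

act = get_activation("gelu")  # -> paddle.nn.GELU()
print(act(paddle.to_tensor([-1.0, 0.0, 1.0])))
```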

@ -0,0 +1,245 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.utils.initialize import _calculate_fan_in_and_fan_out
from paddlespeech.utils.initialize import kaiming_normal_
from paddlespeech.utils.initialize import kaiming_uniform_
from paddlespeech.utils.initialize import uniform_
from paddlespeech.utils.initialize import zeros_
def Conv1D(*args, **kwargs):
layer = nn.Conv1D(*args, **kwargs)
# Initialize the weight to be consistent with the official implementation
kaiming_normal_(layer.weight)
# Initialization is consistent with torch
if layer.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
if fan_in != 0:
bound = 1 / math.sqrt(fan_in)
uniform_(layer.bias, -bound, bound)
return layer
# Initialization is consistent with torch
def Linear(*args, **kwargs):
layer = nn.Linear(*args, **kwargs)
kaiming_uniform_(layer.weight, a=math.sqrt(5))
if layer.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(layer.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(layer.bias, -bound, bound)
return layer
class ResidualBlock(nn.Layer):
"""ResidualBlock
Args:
encoder_hidden (int, optional):
Input feature size of the 1D convolution, by default 256
residual_channels (int, optional):
Feature size of the residual output(and also the input), by default 256
gate_channels (int, optional):
Output feature size of the 1D convolution, by default 512
kernel_size (int, optional):
Kernel size of the 1D convolution, by default 3
dilation (int, optional):
Dilation of the 1D convolution, by default 4
"""
def __init__(self,
encoder_hidden: int=256,
residual_channels: int=256,
gate_channels: int=512,
kernel_size: int=3,
dilation: int=4):
super().__init__()
self.dilated_conv = Conv1D(
residual_channels,
gate_channels,
kernel_size,
padding=dilation,
dilation=dilation)
self.diffusion_projection = Linear(residual_channels, residual_channels)
self.conditioner_projection = Conv1D(encoder_hidden, gate_channels, 1)
self.output_projection = Conv1D(residual_channels, gate_channels, 1)
def forward(
self,
x: paddle.Tensor,
diffusion_step: paddle.Tensor,
cond: paddle.Tensor, ):
"""Calculate forward propagation.
Args:
x (Tensor(float32)): input feature. (B, residual_channels, T)
diffusion_step (Tensor(int64)): The timestep input (adding noise step). (B,)
cond (Tensor(float32)): The auxiliary input (e.g. fastspeech2 encoder output). (B, encoder_hidden, T)
Returns:
res (Tensor(float32)): residual output fed to the next block. (B, residual_channels, T)
skip (Tensor(float32)): skip connection output. (B, residual_channels, T)
"""
diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
cond = self.conditioner_projection(cond)
y = x + diffusion_step
y = self.dilated_conv(y) + cond
gate, filter = paddle.chunk(y, 2, axis=1)
y = F.sigmoid(gate) * paddle.tanh(filter)
y = self.output_projection(y)
residual, skip = paddle.chunk(y, 2, axis=1)
return (x + residual) / math.sqrt(2.0), skip
class SinusoidalPosEmb(nn.Layer):
"""Positional embedding
"""
def __init__(self, dim: int=256):
super().__init__()
self.dim = dim
def forward(self, x: paddle.Tensor):
x = paddle.cast(x, 'float32')
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = paddle.exp(paddle.arange(half_dim) * -emb)
emb = x[:, None] * emb[None, :]
emb = paddle.concat([emb.sin(), emb.cos()], axis=-1)
return emb
class DiffNet(nn.Layer):
"""A Mel-Spectrogram Denoiser
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.layers = layers
self.aux_channels = aux_channels
self.residual_channels = residual_channels
self.gate_channels = gate_channels
self.kernel_size = kernel_size
self.dilation_cycle_length = layers // stacks
self.skip_channels = skip_channels
self.input_projection = Conv1D(self.in_channels, self.residual_channels,
1)
self.diffusion_embedding = SinusoidalPosEmb(self.residual_channels)
dim = self.residual_channels
self.mlp = nn.Sequential(
Linear(dim, dim * 4), nn.Mish(), Linear(dim * 4, dim))
self.residual_layers = nn.LayerList([
ResidualBlock(
encoder_hidden=self.aux_channels,
residual_channels=self.residual_channels,
gate_channels=self.gate_channels,
kernel_size=self.kernel_size,
dilation=2**(i % self.dilation_cycle_length))
for i in range(self.layers)
])
self.skip_projection = Conv1D(self.residual_channels,
self.skip_channels, 1)
self.output_projection = Conv1D(self.residual_channels,
self.out_channels, 1)
zeros_(self.output_projection.weight)
def forward(
self,
spec: paddle.Tensor,
diffusion_step: paddle.Tensor,
cond: paddle.Tensor, ):
"""Calculate forward propagation.
Args:
spec (Tensor(float32)): The input mel-spectrogram. (B, n_mel, T)
diffusion_step (Tensor(int64)): The timestep input (adding noise step). (B,)
cond (Tensor(float32)): The auxiliary input (e.g. fastspeech2 encoder output). (B, D_enc_out, T)
Returns:
x (Tensor(float32)): pred noise (B, n_mel, T)
"""
x = spec
x = self.input_projection(x) # x [B, residual_channel, T]
x = F.relu(x)
diffusion_step = self.diffusion_embedding(diffusion_step)
diffusion_step = self.mlp(diffusion_step)
skip = []
for layer_id, layer in enumerate(self.residual_layers):
x, skip_connection = layer(
x=x,
diffusion_step=diffusion_step,
cond=cond, )
skip.append(skip_connection)
x = paddle.sum(
paddle.stack(skip), axis=0) / math.sqrt(len(self.residual_layers))
x = self.skip_projection(x)
x = F.relu(x)
x = self.output_projection(x) # [B, 80, T]
return x
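A minimal smoke test for `DiffNet` (a sketch; it assumes `DiffNet` and its `Conv1D`/`Linear` helpers are importable from the module this diff adds, and the shapes follow the forward docstring above):

```python
import paddle

# Dummy forward pass: spec (B, n_mel, T), diffusion_step (B,), cond (B, D_enc_out, T).
B, n_mel, T, d_cond = 2, 80, 100, 256
net = DiffNet(in_channels=n_mel, out_channels=n_mel, aux_channels=d_cond)
spec = paddle.randn([B, n_mel, T])
step = paddle.randint(0, 100, [B])       # noise-adding timestep per utterance
cond = paddle.randn([B, d_cond, T])      # e.g. fastspeech2 encoder output
noise_pred = net(spec, step, cond)
print(noise_pred.shape)                  # [2, 80, 100]
```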

@@ -17,6 +17,7 @@ from typing import Callable
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import ppdiffusers
from paddle import nn
@@ -27,170 +28,6 @@ from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
"""A Mel-Spectrogram Denoiser modified from WaveNet
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
assert layers % stacks == 0
layers_per_stack = layers // stacks
self.first_t_emb = nn.Sequential(
Timesteps(
residual_channels,
flip_sin_to_cos=False,
downscale_freq_shift=1),
nn.Linear(residual_channels, residual_channels * 4),
nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
self.t_emb_layers = nn.LayerList([
nn.Linear(residual_channels, residual_channels)
for _ in range(layers)
])
self.first_conv = nn.Conv1D(
in_channels, residual_channels, 1, bias_attr=True)
self.first_act = nn.ReLU()
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = 2**(layer % layers_per_stack)
conv = WaveNetResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias)
self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D(
skip_channels,
skip_channels,
1,
bias_attr=True),
nn.ReLU(), final_conv)
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x, t, c):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (N, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (N), The timestep input.
c(Tensor):
Shape (N, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (N, C_out, T), the denoised mel-spectrogram.
"""
assert c.shape[-1] == x.shape[-1]
if t.shape[0] != x.shape[0]:
t = t.tile([x.shape[0]])
t_emb = self.first_t_emb(t)
t_embs = [
t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
]
x = self.first_conv(x)
x = self.first_act(x)
skips = 0
for f, t in zip(self.conv_layers, t_embs):
x = x + t
x, s = f(x, c)
skips += s
skips *= math.sqrt(1.0 / len(self.conv_layers))
x = self.last_conv_layers(skips)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
class GaussianDiffusion(nn.Layer):
"""Common Gaussian Diffusion Denoising Model Module
@@ -207,6 +44,13 @@ class GaussianDiffusion(nn.Layer):
beta schedule parameter for the scheduler, by default 'squaredcos_cap_v2' (cosine schedule).
num_max_timesteps (int, optional):
The max timestep transition from real to noise, by default None.
stretch (bool, optional):
Whether to stretch the feature before diffusion, by default True.
min_values (Tensor):
The minimum values of the feature, used for stretching.
max_values (Tensor):
The maximum values of the feature, used for stretching.
Examples:
>>> import paddle
@@ -294,13 +138,17 @@ class GaussianDiffusion(nn.Layer):
"""
def __init__(self,
denoiser: nn.Layer,
num_train_timesteps: Optional[int]=1000,
beta_start: Optional[float]=0.0001,
beta_end: Optional[float]=0.02,
beta_schedule: Optional[str]="squaredcos_cap_v2",
num_max_timesteps: Optional[int]=None):
def __init__(
self,
denoiser: nn.Layer,
num_train_timesteps: Optional[int]=1000,
beta_start: Optional[float]=0.0001,
beta_end: Optional[float]=0.02,
beta_schedule: Optional[str]="squaredcos_cap_v2",
num_max_timesteps: Optional[int]=None,
stretch: bool=True,
min_values: paddle.Tensor=None,
max_values: paddle.Tensor=None, ):
super().__init__()
self.num_train_timesteps = num_train_timesteps
@@ -315,6 +163,22 @@ class GaussianDiffusion(nn.Layer):
beta_end=beta_end,
beta_schedule=beta_schedule)
self.num_max_timesteps = num_max_timesteps
self.stretch = stretch
self.min_values = min_values
self.max_values = max_values
def norm_spec(self, x):
"""
Linearly map x to [-1, 1]
Args:
x: [B, T, N]
"""
return (x - self.min_values) / (self.max_values - self.min_values
) * 2 - 1
def denorm_spec(self, x):
"""
Linearly map x in [-1, 1] back to [min_values, max_values]
Args:
x: [B, T, N]
"""
return (x + 1) / 2 * (self.max_values - self.min_values
) + self.min_values
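A toy check of the linear stretch above (the `min_values`/`max_values` here are hypothetical; in practice they come from `dump/train/speech_stretchs.npy`):

```python
import paddle

# norm_spec maps x in [min, max] to [-1, 1]; denorm_spec inverts it exactly.
min_values = paddle.full([80], -6.0)
max_values = paddle.full([80], 1.5)
x = paddle.rand([2, 100, 80]) * (max_values - min_values) + min_values  # (B, T, n_mel)
y = (x - min_values) / (max_values - min_values) * 2 - 1                # norm_spec
x_rec = (y + 1) / 2 * (max_values - min_values) + min_values            # denorm_spec
print(float(paddle.abs(x - x_rec).max()))   # ~0 up to float32 rounding
```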
def forward(self, x: paddle.Tensor, cond: Optional[paddle.Tensor]=None
) -> Tuple[paddle.Tensor, paddle.Tensor]:
@@ -333,6 +197,11 @@
The noises which is added to the input.
"""
if self.stretch:
x = x.transpose((0, 2, 1))
x = self.norm_spec(x)
x = x.transpose((0, 2, 1))
noise_scheduler = self.noise_scheduler
# Sample noise that we'll add to the mel-spectrograms
@@ -360,7 +229,7 @@
num_inference_steps: Optional[int]=1000,
strength: Optional[float]=None,
scheduler_type: Optional[str]="ddpm",
clip_noise: Optional[bool]=True,
clip_noise: Optional[bool]=False,
clip_noise_range: Optional[Tuple[float, float]]=(-1, 1),
callback: Optional[Callable[[int, int, int, paddle.Tensor],
None]]=None,
@@ -369,9 +238,9 @@
Args:
noise (Tensor):
The input tensor as a starting point for denoising.
The input tensor as a starting point for denoising.
cond (Tensor, optional):
Conditional input for compute noises.
Conditional input used to compute the noise. (N, C_aux, T)
ref_x (Tensor, optional):
The real output for the denoising process to refer.
num_inference_steps (int, optional):
@@ -382,6 +251,7 @@
scheduler_type (str, optional):
Noise scheduler for generate noises.
Choosing a proper scheduler can skip many denoising steps, by default 'ddpm'.
Only 'ddpm' is supported now.
clip_noise (bool, optional):
Whether to clip each denoised output, by default False.
clip_noise_range (tuple, optional):
@@ -425,48 +295,33 @@
# set timesteps
scheduler.set_timesteps(num_inference_steps)
# prepare first noise variables
noisy_input = noise
timesteps = scheduler.timesteps
if ref_x is not None:
init_timestep = None
if strength is None or strength < 0. or strength > 1.:
strength = None
if self.num_max_timesteps is not None:
strength = self.num_max_timesteps / self.num_train_timesteps
if strength is not None:
# get the original timestep using init_timestep
init_timestep = min(
int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = scheduler.timesteps[t_start:]
num_inference_steps = num_inference_steps - t_start
noisy_input = scheduler.add_noise(
ref_x, noise, timesteps[:1].tile([noise.shape[0]]))
# denoising loop
if self.stretch and ref_x is not None:
ref_x = ref_x.transpose((0, 2, 1))
ref_x = self.norm_spec(ref_x)
ref_x = ref_x.transpose((0, 2, 1))
# for ddpm
timesteps = paddle.to_tensor(
np.flipud(np.arange(num_inference_steps)))
noisy_input = scheduler.add_noise(ref_x, noise, timesteps[0])
denoised_output = noisy_input
if clip_noise:
n_min, n_max = clip_noise_range
denoised_output = paddle.clip(denoised_output, n_min, n_max)
num_warmup_steps = len(
timesteps) - num_inference_steps * scheduler.order
for i, t in enumerate(timesteps):
denoised_output = scheduler.scale_model_input(denoised_output, t)
# predict the noise residual
noise_pred = self.denoiser(denoised_output, t, cond)
# compute the previous noisy sample x_t -> x_t-1
denoised_output = scheduler.step(noise_pred, t,
denoised_output).prev_sample
if clip_noise:
denoised_output = paddle.clip(denoised_output, n_min, n_max)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
(i + 1) % scheduler.order == 0):
if callback is not None and i % callback_steps == 0:
callback(i, t, len(timesteps), denoised_output)
if self.stretch:
denoised_output = denoised_output.transpose((0, 2, 1))
denoised_output = self.denorm_spec(denoised_output)
denoised_output = denoised_output.transpose((0, 2, 1))
return denoised_output
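A hedged usage sketch of the DDPM inference path above. `DiffNet` and `GaussianDiffusion` are the classes from this diff, but the mel statistics and the conditioning/reference tensors below are dummies, so this only illustrates shapes and argument wiring, not the real DiffSinger synthesis pipeline:

```python
import paddle

B, n_mel, T, d_cond = 1, 80, 200, 256
denoiser = DiffNet(in_channels=n_mel, out_channels=n_mel, aux_channels=d_cond)
diffusion = GaussianDiffusion(
    denoiser,
    num_train_timesteps=100,
    stretch=True,
    min_values=paddle.full([n_mel], -6.0),   # normally loaded from speech_stretchs.npy
    max_values=paddle.full([n_mel], 1.5))
cond = paddle.randn([B, d_cond, T])              # (B, C_aux, T), e.g. encoder output
ref_x = paddle.rand([B, n_mel, T]) * 7.5 - 6.0   # dummy reference mel inside [min, max]
noise = paddle.randn([B, n_mel, T])
mel = diffusion.inference(
    noise, cond=cond, ref_x=ref_x,
    num_inference_steps=100, scheduler_type="ddpm")
print(mel.shape)   # [1, 80, 200]
```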

@@ -38,11 +38,9 @@ def masked_fill(xs: paddle.Tensor,
value: Union[float, int]):
# comment following line for converting dygraph to static graph.
# assert is_broadcastable(xs.shape, mask.shape) is True
# bshape = paddle.broadcast_shape(xs.shape, mask.shape)
bshape = broadcast_shape(xs.shape, mask.shape)
mask.stop_gradient = True
mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value
mask = mask.cast(dtype=paddle.bool)
xs = paddle.where(mask, trues, xs)
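A toy illustration of what `masked_fill` computes: positions where the (broadcastable) mask is True are replaced by `value`. This sketch uses `paddle.broadcast_shape` for brevity; the module itself switches to its own `broadcast_shape` helper for static-graph export, as the change above shows.

```python
import paddle

# Equivalent of masked_fill(xs, mask, 0.0), written out as broadcast + where.
xs = paddle.to_tensor([[1.0, 2.0, 3.0],
                       [4.0, 5.0, 6.0]])
mask = paddle.to_tensor([[True, False, True]])       # broadcasts over the batch axis
bshape = paddle.broadcast_shape(xs.shape, mask.shape)
filled = paddle.where(mask.broadcast_to(bshape), paddle.ones_like(xs) * 0.0, xs)
print(filled)
# [[0., 2., 0.],
#  [0., 5., 0.]]
```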

@@ -96,7 +96,7 @@ class VariancePredictor(nn.Layer):
xs = f(xs)
# (B, Tmax, 1)
xs = self.linear(xs.transpose([0, 2, 1]))
if x_masks is not None:
xs = masked_fill(xs, x_masks, 0.0)
return xs

@@ -15,6 +15,7 @@
from typing import List
from typing import Union
import paddle
from paddle import nn
from paddlespeech.t2s.modules.activation import get_activation
@@ -390,7 +391,13 @@ class TransformerEncoder(BaseEncoder):
padding_idx=padding_idx,
encoder_type="transformer")
def forward(self, xs, masks):
def forward(self,
xs: paddle.Tensor,
masks: paddle.Tensor,
note_emb: paddle.Tensor=None,
note_dur_emb: paddle.Tensor=None,
is_slur_emb: paddle.Tensor=None,
scale: int=16):
"""Encoder input sequence.
Args:
@@ -398,6 +405,12 @@ class TransformerEncoder(BaseEncoder):
Input tensor (#batch, time, idim).
masks(Tensor):
Mask tensor (#batch, 1, time).
note_emb(Tensor):
Input tensor (#batch, time, attention_dim).
note_dur_emb(Tensor):
Input tensor (#batch, time, attention_dim).
is_slur_emb(Tensor):
Input tensor (#batch, time, attention_dim).
Returns:
Tensor:
@@ -406,6 +419,8 @@ class TransformerEncoder(BaseEncoder):
Mask tensor (#batch, 1, time).
"""
xs = self.embed(xs)
if note_emb is not None:
xs = scale * xs + note_emb + note_dur_emb + is_slur_emb
xs, masks = self.encoders(xs, masks)
if self.normalize_before:
xs = self.after_norm(xs)
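A hedged sketch of the embedding fusion added above: the phone embedding is scaled by `scale` (default 16) and the note, note-duration and slur embeddings, all of shape (#batch, time, attention_dim), are summed in before the self-attention blocks. All tensors below are dummies.

```python
import paddle

B, T, D, scale = 2, 10, 256, 16
xs_emb = paddle.randn([B, T, D])         # stands in for the output of self.embed(xs)
note_emb = paddle.randn([B, T, D])
note_dur_emb = paddle.randn([B, T, D])
is_slur_emb = paddle.randn([B, T, D])
fused = scale * xs_emb + note_emb + note_dur_emb + is_slur_emb
print(fused.shape)   # [2, 10, 256]
```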

@@ -0,0 +1,191 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable
from typing import Optional
from typing import Tuple
import numpy as np
import paddle
import ppdiffusers
from paddle import nn
from ppdiffusers.models.embeddings import Timesteps
from ppdiffusers.schedulers import DDPMScheduler
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.residual_block import WaveNetResidualBlock
class WaveNetDenoiser(nn.Layer):
"""A Mel-Spectrogram Denoiser modified from WaveNet
Args:
in_channels (int, optional):
Number of channels of the input mel-spectrogram, by default 80
out_channels (int, optional):
Number of channels of the output mel-spectrogram, by default 80
kernel_size (int, optional):
Kernel size of the residual blocks inside, by default 3
layers (int, optional):
Number of residual blocks inside, by default 20
stacks (int, optional):
The number of groups to split the residual blocks into, by default 5
Within each group, the dilation of the residual block grows exponentially.
residual_channels (int, optional):
Residual channel of the residual blocks, by default 256
gate_channels (int, optional):
Gate channel of the residual blocks, by default 512
skip_channels (int, optional):
Skip channel of the residual blocks, by default 256
aux_channels (int, optional):
Auxiliary channel of the residual blocks, by default 256
dropout (float, optional):
Dropout of the residual blocks, by default 0.
bias (bool, optional):
Whether to use bias in residual blocks, by default True
use_weight_norm (bool, optional):
Whether to use weight norm in all convolutions, by default False
"""
def __init__(
self,
in_channels: int=80,
out_channels: int=80,
kernel_size: int=3,
layers: int=20,
stacks: int=5,
residual_channels: int=256,
gate_channels: int=512,
skip_channels: int=256,
aux_channels: int=256,
dropout: float=0.,
bias: bool=True,
use_weight_norm: bool=False,
init_type: str="kaiming_normal", ):
super().__init__()
# initialize parameters
initialize(self, init_type)
self.in_channels = in_channels
self.out_channels = out_channels
self.aux_channels = aux_channels
self.layers = layers
self.stacks = stacks
self.kernel_size = kernel_size
assert layers % stacks == 0
layers_per_stack = layers // stacks
self.first_t_emb = nn.Sequential(
Timesteps(
residual_channels,
flip_sin_to_cos=False,
downscale_freq_shift=1),
nn.Linear(residual_channels, residual_channels * 4),
nn.Mish(), nn.Linear(residual_channels * 4, residual_channels))
self.t_emb_layers = nn.LayerList([
nn.Linear(residual_channels, residual_channels)
for _ in range(layers)
])
self.first_conv = nn.Conv1D(
in_channels, residual_channels, 1, bias_attr=True)
self.first_act = nn.ReLU()
self.conv_layers = nn.LayerList()
for layer in range(layers):
dilation = 2**(layer % layers_per_stack)
conv = WaveNetResidualBlock(
kernel_size=kernel_size,
residual_channels=residual_channels,
gate_channels=gate_channels,
skip_channels=skip_channels,
aux_channels=aux_channels,
dilation=dilation,
dropout=dropout,
bias=bias)
self.conv_layers.append(conv)
final_conv = nn.Conv1D(skip_channels, out_channels, 1, bias_attr=True)
nn.initializer.Constant(0.0)(final_conv.weight)
self.last_conv_layers = nn.Sequential(nn.ReLU(),
nn.Conv1D(
skip_channels,
skip_channels,
1,
bias_attr=True),
nn.ReLU(), final_conv)
if use_weight_norm:
self.apply_weight_norm()
def forward(self, x: paddle.Tensor, t: paddle.Tensor, c: paddle.Tensor):
"""Denoise mel-spectrogram.
Args:
x(Tensor):
Shape (B, C_in, T), The input mel-spectrogram.
t(Tensor):
Shape (B), The timestep input.
c(Tensor):
Shape (B, C_aux, T'). The auxiliary input (e.g. fastspeech2 encoder output).
Returns:
Tensor: Shape (B, C_out, T), the pred noise.
"""
assert c.shape[-1] == x.shape[-1]
if t.shape[0] != x.shape[0]:
t = t.tile([x.shape[0]])
t_emb = self.first_t_emb(t)
t_embs = [
t_emb_layer(t_emb)[..., None] for t_emb_layer in self.t_emb_layers
]
x = self.first_conv(x)
x = self.first_act(x)
skips = 0
for f, t in zip(self.conv_layers, t_embs):
x = x + t
x, s = f(x, c)
skips += s
skips *= math.sqrt(1.0 / len(self.conv_layers))
x = self.last_conv_layers(skips)
return x
def apply_weight_norm(self):
"""Recursively apply weight normalization to all the Convolution layers
in the sublayers.
"""
def _apply_weight_norm(layer):
if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
nn.utils.weight_norm(layer)
self.apply(_apply_weight_norm)
def remove_weight_norm(self):
"""Recursively remove weight normalization from all the Convolution
layers in the sublayers.
"""
def _remove_weight_norm(layer):
try:
nn.utils.remove_weight_norm(layer)
except ValueError:
pass
self.apply(_remove_weight_norm)
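A minimal smoke test for `WaveNetDenoiser` (a sketch; it assumes the class is importable from the new module above, and the condition length must equal the mel length, as the assert in `forward` requires):

```python
import paddle

B, n_mel, T, d_cond = 2, 80, 64, 256
net = WaveNetDenoiser(in_channels=n_mel, out_channels=n_mel, aux_channels=d_cond)
x = paddle.randn([B, n_mel, T])          # noisy mel, (B, C_in, T)
t = paddle.randint(0, 1000, [B])         # diffusion timesteps, (B,)
c = paddle.randn([B, d_cond, T])         # auxiliary condition, (B, C_aux, T)
eps = net(x, t, c)
print(eps.shape)                         # [2, 80, 64]
```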