diff --git a/examples/opencpop/README.md b/examples/opencpop/README.md new file mode 100644 index 000000000..5a574dc80 --- /dev/null +++ b/examples/opencpop/README.md @@ -0,0 +1,6 @@ + +# Opencpop + +* svs1 - DiffSinger +* voc1 - Parallel WaveGAN +* voc5 - HiFiGAN diff --git a/examples/opencpop/voc5/conf/default.yaml b/examples/opencpop/voc5/conf/default.yaml new file mode 100644 index 000000000..10449f860 --- /dev/null +++ b/examples/opencpop/voc5/conf/default.yaml @@ -0,0 +1,167 @@ +# This is the configuration file for CSMSC dataset. +# This configuration is based on HiFiGAN V1, which is an official configuration. +# But I found that the optimizer setting does not work well with my implementation. +# So I changed optimizer settings as follows: +# - AdamW -> Adam +# - betas: [0.8, 0.99] -> betas: [0.5, 0.9] +# - Scheduler: ExponentialLR -> MultiStepLR +# To match the shift size difference, the upsample scales is also modified from the original 256 shift setting. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 512 # FFT size (samples). +n_shift: 128 # Hop size (samples). 12.5ms +win_length: 512 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 12000 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [8, 4, 2, 2] # Upsampling scales. + upsample_kernel_sizes: [16, 8, 4, 4] # Kernel size for upsampling layers. + resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. + resblock_dilations: # Dilations for residual blocks. + - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer." + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 24000 + fft_size: 512 + hop_size: 128 + win_length: 512 + window: "hann" + num_mels: 80 + fmin: 30 + fmax: 12000 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. +lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. +lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 16 # Batch size. +batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 1 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 2.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +generator_train_start_steps: 1 # Number of steps to start to train discriminator. +discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. +train_max_steps: 2500000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 4 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/opencpop/voc5/conf/finetune.yaml b/examples/opencpop/voc5/conf/finetune.yaml new file mode 100644 index 000000000..0022a67aa --- /dev/null +++ b/examples/opencpop/voc5/conf/finetune.yaml @@ -0,0 +1,168 @@ +# This is the configuration file for CSMSC dataset. +# This configuration is based on HiFiGAN V1, which is an official configuration. +# But I found that the optimizer setting does not work well with my implementation. +# So I changed optimizer settings as follows: +# - AdamW -> Adam +# - betas: [0.8, 0.99] -> betas: [0.5, 0.9] +# - Scheduler: ExponentialLR -> MultiStepLR +# To match the shift size difference, the upsample scales is also modified from the original 256 shift setting. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fs: 24000 # Sampling rate. +n_fft: 512 # FFT size (samples). +n_shift: 128 # Hop size (samples). 12.5ms +win_length: 512 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. (Hz) +fmax: 12000 # Maximum frequency in mel basis calculation. (Hz) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 80 # Number of input channels. + out_channels: 1 # Number of output channels. + channels: 512 # Number of initial channels. + kernel_size: 7 # Kernel size of initial and final conv layers. + upsample_scales: [8, 4, 2, 2] # Upsampling scales. + upsample_kernel_sizes: [16, 8, 4, 4] # Kernel size for upsampling layers. + resblock_kernel_sizes: [3, 7, 11] # Kernel size for residual blocks. + resblock_dilations: # Dilations for residual blocks. + - [1, 3, 5] + - [1, 3, 5] + - [1, 3, 5] + use_additional_convs: True # Whether to use additional conv layer in residual blocks. + bias: True # Whether to use bias parameter in conv. + nonlinear_activation: "leakyrelu" # Nonlinear activation type. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + scales: 3 # Number of multi-scale discriminator. + scale_downsample_pooling: "AvgPool1D" # Pooling operation for scale discriminator. + scale_downsample_pooling_params: + kernel_size: 4 # Pooling kernel size. + stride: 2 # Pooling stride. + padding: 2 # Padding size. + scale_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [15, 41, 5, 3] # List of kernel sizes. + channels: 128 # Initial number of channels. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + max_groups: 16 # Maximum number of groups in downsampling conv layers. + bias: True + downsample_scales: [4, 4, 4, 4, 1] # Downsampling scales. + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: + negative_slope: 0.1 + follow_official_norm: True # Whether to follow the official norm setting. + periods: [2, 3, 5, 7, 11] # List of period for multi-period discriminator. + period_discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_sizes: [5, 3] # List of kernel sizes. + channels: 32 # Initial number of channels. + downsample_scales: [3, 3, 3, 3, 1] # Downsampling scales. + max_downsample_channels: 1024 # Maximum number of channels in downsampling conv layers. + bias: True # Whether to use bias parameter in conv layer." + nonlinear_activation: "leakyrelu" # Nonlinear activation. + nonlinear_activation_params: # Nonlinear activation paramters. + negative_slope: 0.1 + use_weight_norm: True # Whether to apply weight normalization. + use_spectral_norm: False # Whether to apply spectral normalization. + + +########################################################### +# STFT LOSS SETTING # +########################################################### +use_stft_loss: False # Whether to use multi-resolution STFT loss. +use_mel_loss: True # Whether to use Mel-spectrogram loss. +mel_loss_params: + fs: 24000 + fft_size: 512 + hop_size: 128 + win_length: 512 + window: "hann" + num_mels: 80 + fmin: 30 + fmax: 12000 + log_base: null +generator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +discriminator_adv_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. +use_feat_match_loss: True +feat_match_loss_params: + average_by_discriminators: False # Whether to average loss by #discriminators. + average_by_layers: False # Whether to average loss by #layers in each discriminator. + include_final_outputs: False # Whether to include final outputs in feat match loss calculation. + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_aux: 45.0 # Loss balancing coefficient for STFT loss. +lambda_adv: 1.0 # Loss balancing coefficient for adversarial loss. +lambda_feat_match: 2.0 # Loss balancing coefficient for feat match loss.. + +########################################################### +# DATA LOADER SETTING # +########################################################### +#batch_size: 16 # Batch size. +batch_size: 1 # Batch size. +batch_max_steps: 8400 # Length of each audio in batch. Make sure dividable by hop_size. +num_workers: 1 # Number of workers in DataLoader. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 2.0e-4 # Generator's learning rate. + gamma: 0.5 # Generator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +generator_grad_norm: -1 # Generator's gradient norm. +discriminator_optimizer_params: + beta1: 0.5 + beta2: 0.9 + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 2.0e-4 # Discriminator's learning rate. + gamma: 0.5 # Discriminator's scheduler gamma. + milestones: # At each milestone, lr will be multiplied by gamma. + - 200000 + - 400000 + - 600000 + - 800000 +discriminator_grad_norm: -1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +generator_train_start_steps: 1 # Number of steps to start to train discriminator. +discriminator_train_start_steps: 0 # Number of steps to start to train discriminator. +train_max_steps: 2600000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint. +eval_interval_steps: 1000 # Interval steps to evaluate the network. + +########################################################### +# OTHER SETTING # +########################################################### +num_snapshots: 4 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random diff --git a/examples/opencpop/voc5/finetune.sh b/examples/opencpop/voc5/finetune.sh new file mode 100755 index 000000000..76f363295 --- /dev/null +++ b/examples/opencpop/voc5/finetune.sh @@ -0,0 +1,74 @@ +#!/bin/bash + +source path.sh + +gpus=0 +stage=0 +stop_stage=100 + +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${MAIN_ROOT}/paddlespeech/t2s/exps/diffsinger/gen_gta_mel.py \ + --diffsinger-config=diffsinger_opencpop_ckpt_1.4.0/default.yaml \ + --diffsinger-checkpoint=diffsinger_opencpop_ckpt_1.4.0/snapshot_iter_160000.pdz \ + --diffsinger-stat=diffsinger_opencpop_ckpt_1.4.0/speech_stats.npy \ + --diffsinger-stretch=diffsinger_opencpop_ckpt_1.4.0/speech_stretchs.npy \ + --dur-file=~/datasets/Opencpop/segments/transcriptions.txt \ + --output-dir=dump_finetune \ + --phones-dict=diffsinger_opencpop_ckpt_1.4.0/phone_id_map.txt \ + --dataset=opencpop \ + --rootdir=~/datasets/Opencpop/segments/ +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + python3 ${MAIN_ROOT}/utils/link_wav.py \ + --old-dump-dir=dump \ + --dump-dir=dump_finetune +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + cp dump/train/feats_stats.npy dump_finetune/train/ +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize, dev and test should use train's stats + echo "Normalize ..." + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump_finetune/train/raw/metadata.jsonl \ + --dumpdir=dump_finetune/train/norm \ + --stats=dump_finetune/train/feats_stats.npy + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump_finetune/dev/raw/metadata.jsonl \ + --dumpdir=dump_finetune/dev/norm \ + --stats=dump_finetune/train/feats_stats.npy + + python3 ${BIN_DIR}/../normalize.py \ + --metadata=dump_finetune/test/raw/metadata.jsonl \ + --dumpdir=dump_finetune/test/norm \ + --stats=dump_finetune/train/feats_stats.npy +fi + +# create finetune env +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + echo "create finetune env" + python3 local/prepare_env.py \ + --pretrained_model_dir=exp/default/checkpoints/ \ + --output_dir=exp/finetune/ +fi + +# finetune +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + CUDA_VISIBLE_DEVICES=${gpus} \ + FLAGS_cudnn_exhaustive_search=true \ + FLAGS_conv_workspace_size_limit=4000 \ + python ${BIN_DIR}/train.py \ + --train-metadata=dump_finetune/train/norm/metadata.jsonl \ + --dev-metadata=dump_finetune/dev/norm/metadata.jsonl \ + --config=conf/finetune.yaml \ + --output-dir=exp/finetune \ + --ngpu=1 +fi diff --git a/examples/opencpop/voc5/local/PTQ_static.sh b/examples/opencpop/voc5/local/PTQ_static.sh new file mode 120000 index 000000000..247ce5c74 --- /dev/null +++ b/examples/opencpop/voc5/local/PTQ_static.sh @@ -0,0 +1 @@ +../../../csmsc/voc1/local/PTQ_static.sh \ No newline at end of file diff --git a/examples/opencpop/voc5/local/dygraph_to_static.sh b/examples/opencpop/voc5/local/dygraph_to_static.sh new file mode 100755 index 000000000..65077661a --- /dev/null +++ b/examples/opencpop/voc5/local/dygraph_to_static.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/../../dygraph_to_static.py \ + --type=voc \ + --voc=hifigan_opencpop \ + --voc_config=${config_path} \ + --voc_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --voc_stat=dump/train/feats_stats.npy \ + --inference_dir=exp/default/inference/ diff --git a/examples/opencpop/voc5/local/prepare_env.py b/examples/opencpop/voc5/local/prepare_env.py new file mode 120000 index 000000000..be03c86b3 --- /dev/null +++ b/examples/opencpop/voc5/local/prepare_env.py @@ -0,0 +1 @@ +../../../other/tts_finetune/tts3/local/prepare_env.py \ No newline at end of file diff --git a/examples/opencpop/voc5/local/preprocess.sh b/examples/opencpop/voc5/local/preprocess.sh new file mode 120000 index 000000000..f0cb24de9 --- /dev/null +++ b/examples/opencpop/voc5/local/preprocess.sh @@ -0,0 +1 @@ +../../voc1/local/preprocess.sh \ No newline at end of file diff --git a/examples/opencpop/voc5/local/synthesize.sh b/examples/opencpop/voc5/local/synthesize.sh new file mode 120000 index 000000000..c887112c0 --- /dev/null +++ b/examples/opencpop/voc5/local/synthesize.sh @@ -0,0 +1 @@ +../../../csmsc/voc5/local/synthesize.sh \ No newline at end of file diff --git a/examples/opencpop/voc5/local/train.sh b/examples/opencpop/voc5/local/train.sh new file mode 120000 index 000000000..2942893d2 --- /dev/null +++ b/examples/opencpop/voc5/local/train.sh @@ -0,0 +1 @@ +../../../csmsc/voc1/local/train.sh \ No newline at end of file diff --git a/examples/opencpop/voc5/path.sh b/examples/opencpop/voc5/path.sh new file mode 120000 index 000000000..b67fe2b39 --- /dev/null +++ b/examples/opencpop/voc5/path.sh @@ -0,0 +1 @@ +../../csmsc/voc5/path.sh \ No newline at end of file diff --git a/examples/opencpop/voc5/run.sh b/examples/opencpop/voc5/run.sh new file mode 100755 index 000000000..290c90d25 --- /dev/null +++ b/examples/opencpop/voc5/run.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_2500000.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +# dygraph to static +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/dygraph_to_static.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +# PTQ_static +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + CUDA_VISIBLE_DEVICES=${gpus} ./local/PTQ_static.sh ${train_output_path} hifigan_opencpop || exit -1 +fi diff --git a/paddlespeech/t2s/exps/PTQ_static.py b/paddlespeech/t2s/exps/PTQ_static.py index 8849abe58..a95786450 100644 --- a/paddlespeech/t2s/exps/PTQ_static.py +++ b/paddlespeech/t2s/exps/PTQ_static.py @@ -43,6 +43,7 @@ def parse_args(): 'hifigan_ljspeech', 'hifigan_vctk', 'pwgan_opencpop', + 'hifigan_opencpop', ], help='Choose model type of tts task.') diff --git a/paddlespeech/t2s/exps/diffsinger/gen_gta_mel.py b/paddlespeech/t2s/exps/diffsinger/gen_gta_mel.py new file mode 100644 index 000000000..519808f2a --- /dev/null +++ b/paddlespeech/t2s/exps/diffsinger/gen_gta_mel.py @@ -0,0 +1,240 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# generate mels using durations.txt +# for mb melgan finetune +import argparse +import os +from pathlib import Path + +import numpy as np +import paddle +import yaml +from tqdm import tqdm +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.preprocess_utils import get_sentences_svs +from paddlespeech.t2s.models.diffsinger import DiffSinger +from paddlespeech.t2s.models.diffsinger import DiffSingerInference +from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.t2s.utils import str2bool + + +def evaluate(args, diffsinger_config): + rootdir = Path(args.rootdir).expanduser() + assert rootdir.is_dir() + + # construct dataset for evaluation + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + phone_dict = {} + for phn, id in phn_id: + phone_dict[phn] = int(id) + + if args.speaker_dict: + with open(args.speaker_dict, 'rt') as f: + spk_id_list = [line.strip().split() for line in f.readlines()] + spk_num = len(spk_id_list) + else: + spk_num = None + + with open(args.diffsinger_stretch, "r") as f: + spec_min = np.load(args.diffsinger_stretch)[0] + spec_max = np.load(args.diffsinger_stretch)[1] + spec_min = paddle.to_tensor(spec_min) + spec_max = paddle.to_tensor(spec_max) + print("min and max spec done!") + + odim = diffsinger_config.n_mels + diffsinger_config["model"]["fastspeech2_params"]["spk_num"] = spk_num + model = DiffSinger( + spec_min=spec_min, + spec_max=spec_max, + idim=vocab_size, + odim=odim, + **diffsinger_config["model"], ) + + model.set_state_dict(paddle.load(args.diffsinger_checkpoint)["main_params"]) + model.eval() + + stat = np.load(args.diffsinger_stat) + mu, std = stat + mu = paddle.to_tensor(mu) + std = paddle.to_tensor(std) + diffsinger_normalizer = ZScore(mu, std) + + diffsinger_inference = DiffSingerInference(diffsinger_normalizer, model) + diffsinger_inference.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + sentences, speaker_set = get_sentences_svs( + args.dur_file, + dataset=args.dataset, + sample_rate=diffsinger_config.fs, + n_shift=diffsinger_config.n_shift, ) + + if args.dataset == "opencpop": + wavdir = rootdir / "wavs" + # split data into 3 sections + train_file = rootdir / "train.txt" + train_wav_files = [] + with open(train_file, "r") as f_train: + for line in f_train.readlines(): + utt = line.split("|")[0] + wav_name = utt + ".wav" + wav_path = wavdir / wav_name + train_wav_files.append(wav_path) + + test_file = rootdir / "test.txt" + dev_wav_files = [] + test_wav_files = [] + num_dev = 106 + count = 0 + with open(test_file, "r") as f_test: + for line in f_test.readlines(): + count += 1 + utt = line.split("|")[0] + wav_name = utt + ".wav" + wav_path = wavdir / wav_name + if count > num_dev: + test_wav_files.append(wav_path) + else: + dev_wav_files.append(wav_path) + else: + print("dataset should in {opencpop} now!") + + train_wav_files = [ + os.path.basename(str(str_path)) for str_path in train_wav_files + ] + dev_wav_files = [ + os.path.basename(str(str_path)) for str_path in dev_wav_files + ] + test_wav_files = [ + os.path.basename(str(str_path)) for str_path in test_wav_files + ] + + for i, utt_id in enumerate(tqdm(sentences)): + phones = sentences[utt_id][0] + durations = sentences[utt_id][1] + note = sentences[utt_id][2] + note_dur = sentences[utt_id][3] + is_slur = sentences[utt_id][4] + speaker = sentences[utt_id][-1] + + phone_ids = [phone_dict[phn] for phn in phones] + phone_ids = paddle.to_tensor(np.array(phone_ids)) + + if args.speaker_dict: + speaker_id = int( + [item[1] for item in spk_id_list if speaker == item[0]][0]) + speaker_id = paddle.to_tensor(speaker_id) + else: + speaker_id = None + + durations = paddle.to_tensor(np.array(durations)) + note = paddle.to_tensor(np.array(note)) + note_dur = paddle.to_tensor(np.array(note_dur)) + is_slur = paddle.to_tensor(np.array(is_slur)) + # 生成的和真实的可能有 1, 2 帧的差距,但是 batch_fn 会修复 + # split data into 3 sections + + wav_path = utt_id + ".wav" + + if wav_path in train_wav_files: + sub_output_dir = output_dir / ("train/raw") + elif wav_path in dev_wav_files: + sub_output_dir = output_dir / ("dev/raw") + elif wav_path in test_wav_files: + sub_output_dir = output_dir / ("test/raw") + + sub_output_dir.mkdir(parents=True, exist_ok=True) + + with paddle.no_grad(): + mel = diffsinger_inference( + text=phone_ids, + note=note, + note_dur=note_dur, + is_slur=is_slur, + get_mel_fs2=False) + np.save(sub_output_dir / (utt_id + "_feats.npy"), mel) + + +def main(): + # parse args and config and redirect to train_sp + parser = argparse.ArgumentParser( + description="Generate mel with diffsinger.") + parser.add_argument( + "--dataset", + default="opencpop", + type=str, + help="name of dataset, should in {opencpop} now") + parser.add_argument( + "--rootdir", default=None, type=str, help="directory to dataset.") + parser.add_argument( + "--diffsinger-config", type=str, help="diffsinger config file.") + parser.add_argument( + "--diffsinger-checkpoint", + type=str, + help="diffsinger checkpoint to load.") + parser.add_argument( + "--diffsinger-stat", + type=str, + help="mean and standard deviation used to normalize spectrogram when training diffsinger." + ) + parser.add_argument( + "--diffsinger-stretch", + type=str, + help="min and max mel used to stretch before training diffusion.") + + parser.add_argument( + "--phones-dict", + type=str, + default="phone_id_map.txt", + help="phone vocabulary file.") + + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") + + parser.add_argument( + "--dur-file", default=None, type=str, help="path to durations.txt.") + parser.add_argument("--output-dir", type=str, help="output dir.") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + + args = parser.parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + with open(args.diffsinger_config) as f: + diffsinger_config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(diffsinger_config) + + evaluate(args, diffsinger_config) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/dygraph_to_static.py b/paddlespeech/t2s/exps/dygraph_to_static.py index 3e6e94857..5e15ca4ca 100644 --- a/paddlespeech/t2s/exps/dygraph_to_static.py +++ b/paddlespeech/t2s/exps/dygraph_to_static.py @@ -132,6 +132,7 @@ def parse_args(): 'pwgan_male', 'hifigan_male', 'pwgan_opencpop', + 'hifigan_opencpop', ], help='Choose vocoder type of tts task.') parser.add_argument(