From 8db06444c5f33c2c658de41bca4a4d165d59e8cf Mon Sep 17 00:00:00 2001
From: TianYuan
Date: Fri, 20 May 2022 05:58:21 +0000
Subject: [PATCH] add vits trainer and synthesize

---
 examples/csmsc/tts3/conf/default.yaml         |   4 +-
 examples/csmsc/vits/conf/default.yaml         | 183 +++++++++
 examples/csmsc/vits/local/preprocess.sh       |  64 ++++
 examples/csmsc/vits/local/synthesize.sh       |  18 +
 examples/csmsc/vits/local/synthesize_e2e.sh   |  18 +
 examples/csmsc/vits/local/train.sh            |  12 +
 examples/csmsc/vits/run.sh                    |  36 ++
 examples/csmsc/voc5/README.md                 |   2 +-
 paddlespeech/t2s/datasets/am_batch_fn.py      |  42 +++
 paddlespeech/t2s/datasets/batch.py            |   1 -
 paddlespeech/t2s/datasets/get_feats.py        |   2 -
 .../t2s/exps/fastspeech2/preprocess.py        |  58 +--
 .../t2s/exps/gan_vocoder/preprocess.py        |  46 ++-
 .../t2s/exps/speedyspeech/preprocess.py       |  40 +-
 paddlespeech/t2s/exps/syn_utils.py            |   2 +-
 paddlespeech/t2s/exps/synthesize.py           |  10 +-
 paddlespeech/t2s/exps/synthesize_e2e.py       |   6 +-
 paddlespeech/t2s/exps/synthesize_streaming.py |   6 +-
 paddlespeech/t2s/exps/tacotron2/preprocess.py |  42 ++-
 .../t2s/exps/transformer_tts/preprocess.py    |  39 +-
 paddlespeech/t2s/exps/vits/normalize.py       | 154 +++++++-
 paddlespeech/t2s/exps/vits/preprocess.py      | 335 +++++++++++++++++
 paddlespeech/t2s/exps/vits/synthesize.py      | 104 ++++++
 paddlespeech/t2s/exps/vits/synthesize_e2e.py  | 146 ++++++++
 paddlespeech/t2s/exps/vits/train.py           | 248 ++++++++++++
 paddlespeech/t2s/exps/voice_cloning.py        |   6 +-
 paddlespeech/t2s/models/__init__.py           |   1 +
 .../parallel_wavegan_updater.py               |   2 +-
 paddlespeech/t2s/models/vits/__init__.py      |   2 +
 paddlespeech/t2s/models/vits/generator.py     |   1 -
 paddlespeech/t2s/models/vits/vits.py          | 197 +---------
 paddlespeech/t2s/models/vits/vits_updater.py  | 353 ++++++++++++++++++
 paddlespeech/t2s/training/optimizer.py        |   8 +
 paddlespeech/utils/__init__.py                |  13 +
 paddlespeech/utils/dynamic_import.py          |  38 ++
 35 files changed, 1939 insertions(+), 300 deletions(-)
 create mode 100755 examples/csmsc/vits/local/synthesize_e2e.sh
 create mode 100644 paddlespeech/t2s/exps/vits/synthesize_e2e.py
 create mode 100644 paddlespeech/utils/__init__.py
 create mode 100644 paddlespeech/utils/dynamic_import.py

diff --git a/examples/csmsc/tts3/conf/default.yaml b/examples/csmsc/tts3/conf/default.yaml
index 2c2a1ea1..08b6f75b 100644
--- a/examples/csmsc/tts3/conf/default.yaml
+++ b/examples/csmsc/tts3/conf/default.yaml
@@ -86,8 +86,8 @@ updater:
 #                      OPTIMIZER SETTING                  #
 ###########################################################
 optimizer:
-  optim: adam            # optimizer type
-  learning_rate: 0.001   # learning rate
+  optim: adam                # optimizer type
+  learning_rate: 0.001       # learning rate
 
 ###########################################################
 #                     TRAINING SETTING                    #
diff --git a/examples/csmsc/vits/conf/default.yaml b/examples/csmsc/vits/conf/default.yaml
index e69de29b..47af780d 100644
--- a/examples/csmsc/vits/conf/default.yaml
+++ b/examples/csmsc/vits/conf/default.yaml
@@ -0,0 +1,183 @@
+# This configuration was tested on 4 V100 GPUs with 32GB GPU
# memory. Training takes around 2 weeks to finish, but a model
# trained for 100k iterations should already generate reasonable results.
###########################################################
#               FEATURE EXTRACTION SETTING                #
###########################################################

fs: 22050          # Sampling rate (Hz).
n_fft: 1024        # FFT size (samples), ~46.4 ms at fs=22050.
n_shift: 256       # Hop size (samples), ~11.6 ms at fs=22050.
win_length: null   # Window length (samples).
                   # If set to null, it will be the same as fft_size.
window: "hann"     # Window function.
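
# The settings above feed the recipe's preprocessing step. As a minimal
# sketch (the config path is just an example; this mirrors the
# spec_extractor construction in paddlespeech/t2s/exps/vits/preprocess.py):
#
#     import yaml
#     from yacs.config import CfgNode
#     from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram
#
#     with open("examples/csmsc/vits/conf/default.yaml") as f:
#         config = CfgNode(yaml.safe_load(f))
#     spec_extractor = LinearSpectrogram(
#         n_fft=config.n_fft,            # 1024 -> n_fft // 2 + 1 = 513 linear bins
#         hop_length=config.n_shift,     # 256 samples per frame
#         win_length=config.win_length,  # null -> falls back to n_fft
#         window=config.window)          # "hann"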
+ + +########################################################## +# TTS MODEL SETTING # +########################################################## +model: + # generator related + generator_type: vits_generator + generator_params: + hidden_channels: 192 + spks: -1 + global_channels: -1 + segment_size: 32 + text_encoder_attention_heads: 2 + text_encoder_ffn_expand: 4 + text_encoder_blocks: 6 + text_encoder_positionwise_layer_type: "conv1d" + text_encoder_positionwise_conv_kernel_size: 3 + text_encoder_positional_encoding_layer_type: "rel_pos" + text_encoder_self_attention_layer_type: "rel_selfattn" + text_encoder_activation_type: "swish" + text_encoder_normalize_before: True + text_encoder_dropout_rate: 0.1 + text_encoder_positional_dropout_rate: 0.0 + text_encoder_attention_dropout_rate: 0.1 + use_macaron_style_in_text_encoder: True + use_conformer_conv_in_text_encoder: False + text_encoder_conformer_kernel_size: -1 + decoder_kernel_size: 7 + decoder_channels: 512 + decoder_upsample_scales: [8, 8, 2, 2] + decoder_upsample_kernel_sizes: [16, 16, 4, 4] + decoder_resblock_kernel_sizes: [3, 7, 11] + decoder_resblock_dilations: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] + use_weight_norm_in_decoder: True + posterior_encoder_kernel_size: 5 + posterior_encoder_layers: 16 + posterior_encoder_stacks: 1 + posterior_encoder_base_dilation: 1 + posterior_encoder_dropout_rate: 0.0 + use_weight_norm_in_posterior_encoder: True + flow_flows: 4 + flow_kernel_size: 5 + flow_base_dilation: 1 + flow_layers: 4 + flow_dropout_rate: 0.0 + use_weight_norm_in_flow: True + use_only_mean_in_flow: True + stochastic_duration_predictor_kernel_size: 3 + stochastic_duration_predictor_dropout_rate: 0.5 + stochastic_duration_predictor_flows: 4 + stochastic_duration_predictor_dds_conv_layers: 3 + # discriminator related + discriminator_type: hifigan_multi_scale_multi_period_discriminator + discriminator_params: + scales: 1 + scale_downsample_pooling: "AvgPool1D" + scale_downsample_pooling_params: + kernel_size: 4 + stride: 2 + padding: 2 + scale_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [15, 41, 5, 3] + channels: 128 + max_downsample_channels: 1024 + max_groups: 16 + bias: True + downsample_scales: [2, 2, 4, 4, 1] + nonlinear_activation: "leakyrelu" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + follow_official_norm: False + periods: [2, 3, 5, 7, 11] + period_discriminator_params: + in_channels: 1 + out_channels: 1 + kernel_sizes: [5, 3] + channels: 32 + downsample_scales: [3, 3, 3, 3, 1] + max_downsample_channels: 1024 + bias: True + nonlinear_activation: "leakyrelu" + nonlinear_activation_params: + negative_slope: 0.1 + use_weight_norm: True + use_spectral_norm: False + # others + sampling_rate: 22050 # needed in the inference for saving wav + cache_generator_outputs: True # whether to cache generator outputs in the training + +########################################################### +# LOSS SETTING # +########################################################### +# loss function related +generator_adv_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" +discriminator_adv_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + loss_type: mse # loss type, "mse" or "hinge" +feat_match_loss_params: + average_by_discriminators: False # whether to average loss value by #discriminators + average_by_layers: False 
# whether to average loss value by #layers of each discriminator + include_final_outputs: True # whether to include final outputs for loss calculation +mel_loss_params: + fs: 22050 # must be the same as the training data + fft_size: 1024 # fft points + hop_size: 256 # hop size + win_length: null # window length + window: hann # window type + num_mels: 80 # number of Mel basis + fmin: 0 # minimum frequency for Mel basis + fmax: null # maximum frequency for Mel basis + log_base: null # null represent natural log + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_adv: 1.0 # loss scaling coefficient for adversarial loss +lambda_mel: 45.0 # loss scaling coefficient for Mel loss +lambda_feat_match: 2.0 # loss scaling coefficient for feat match loss +lambda_dur: 1.0 # loss scaling coefficient for duration loss +lambda_kl: 1.0 # loss scaling coefficient for KL divergence loss +# others +sampling_rate: 22050 # needed in the inference for saving wav +cache_generator_outputs: True # whether to cache generator outputs in the training + + +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 64 # Batch size. +num_workers: 4 # Number of workers in DataLoader. + +########################################################## +# OPTIMIZER & SCHEDULER SETTING # +########################################################## +# optimizer setting for generator +generator_optimizer_params: + beta1: 0.8 + beta2: 0.99 + epsilon: 1.0e-9 + weight_decay: 0.0 +generator_scheduler: exponential_decay +generator_scheduler_params: + learning_rate: 2.0e-4 + gamma: 0.999875 + +# optimizer setting for discriminator +discriminator_optimizer_params: + beta1: 0.8 + beta2: 0.99 + epsilon: 1.0e-9 + weight_decay: 0.0 +discriminator_scheduler: exponential_decay +discriminator_scheduler_params: + learning_rate: 2.0e-4 + gamma: 0.999875 +generator_first: False # whether to start updating generator first + +########################################################## +# OTHER TRAINING SETTING # +########################################################## +max_epoch: 1000 # number of epochs +num_snapshots: 10 # max number of snapshots to keep while training +seed: 777 # random seed number diff --git a/examples/csmsc/vits/local/preprocess.sh b/examples/csmsc/vits/local/preprocess.sh index e69de29b..1d3ae593 100755 --- a/examples/csmsc/vits/local/preprocess.sh +++ b/examples/csmsc/vits/local/preprocess.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./baker_alignment_tone \ + --output=durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=baker \ + --rootdir=~/datasets/BZNSYP/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." 
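+    # compute_statistics.py reads the "feats" field of each record in
+    # metadata.jsonl and writes a feats_stats.npy file (used below as
+    # dump/train/feats_stats.npy) holding the per-dimension mean and std;
+    # normalize.py loads it ([0] as mean, [1] as std) to z-score all splits.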
+ python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="feats" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ + --skip-wav-copy + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ + --skip-wav-copy + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --feats-stats=dump/train/feats_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt \ + --skip-wav-copy +fi diff --git a/examples/csmsc/vits/local/synthesize.sh b/examples/csmsc/vits/local/synthesize.sh index e69de29b..c15d5f99 100755 --- a/examples/csmsc/vits/local/synthesize.sh +++ b/examples/csmsc/vits/local/synthesize.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --config=${config_path} \ + --ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --phones_dict=dump/phone_id_map.txt \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test +fi \ No newline at end of file diff --git a/examples/csmsc/vits/local/synthesize_e2e.sh b/examples/csmsc/vits/local/synthesize_e2e.sh new file mode 100755 index 00000000..edbb07bf --- /dev/null +++ b/examples/csmsc/vits/local/synthesize_e2e.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +stage=0 +stop_stage=0 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize_e2e.py \ + --config=${config_path} \ + --ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --phones_dict=dump/phone_id_map.txt \ + --output_dir=${train_output_path}/test_e2e \ + --text=${BIN_DIR}/../sentences.txt +fi diff --git a/examples/csmsc/vits/local/train.sh b/examples/csmsc/vits/local/train.sh index e69de29b..42fff26c 100755 --- a/examples/csmsc/vits/local/train.sh +++ b/examples/csmsc/vits/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=4 \ + --phones-dict=dump/phone_id_map.txt diff --git a/examples/csmsc/vits/run.sh b/examples/csmsc/vits/run.sh index e69de29b..80e56e7c 100755 --- a/examples/csmsc/vits/run.sh +++ b/examples/csmsc/vits/run.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_153.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can 
not be mixed with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` under `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # synthesize_e2e: VITS generates the waveform directly, no separate vocoder is used
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize_e2e.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/examples/csmsc/voc5/README.md b/examples/csmsc/voc5/README.md
index 33e67616..94f93b48 100644
--- a/examples/csmsc/voc5/README.md
+++ b/examples/csmsc/voc5/README.md
@@ -130,7 +130,7 @@ HiFiGAN checkpoint contains files listed below.
 ```text
 hifigan_csmsc_ckpt_0.1.1
 ├── default.yaml               # default config used to train hifigan
-├── feats_stats.npy            # statistics used to normalize spectrogram when training hifigan
+├── feats_stats.npy            # statistics used to normalize spectrogram when training hifigan
 └── snapshot_iter_2500000.pdz  # generator parameters of hifigan
 ```

diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index 4e3ad3c1..0b278aba 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -293,3 +293,45 @@ def transformer_single_spk_batch_fn(examples):
         "speech_lengths": speech_lengths,
     }
     return batch
+
+
+def vits_single_spk_batch_fn(examples):
+    """
+    Returns:
+        Dict[str, Any]:
+            - text (Tensor): Text index tensor (B, T_text).
+            - text_lengths (Tensor): Text length tensor (B,).
+            - feats (Tensor): Feature tensor (B, T_feats, aux_channels).
+            - feats_lengths (Tensor): Feature length tensor (B,).
+            - speech (Tensor): Speech waveform tensor (B, T_wav).
+ + """ + # fields = ["text", "text_lengths", "feats", "feats_lengths", "speech"] + text = [np.array(item["text"], dtype=np.int64) for item in examples] + feats = [np.array(item["feats"], dtype=np.float32) for item in examples] + speech = [np.array(item["wave"], dtype=np.float32) for item in examples] + text_lengths = [ + np.array(item["text_lengths"], dtype=np.int64) for item in examples + ] + feats_lengths = [ + np.array(item["feats_lengths"], dtype=np.int64) for item in examples + ] + + text = batch_sequences(text) + feats = batch_sequences(feats) + speech = batch_sequences(speech) + + # convert each batch to paddle.Tensor + text = paddle.to_tensor(text) + feats = paddle.to_tensor(feats) + text_lengths = paddle.to_tensor(text_lengths) + feats_lengths = paddle.to_tensor(feats_lengths) + + batch = { + "text": text, + "text_lengths": text_lengths, + "feats": feats, + "feats_lengths": feats_lengths, + "speech": speech + } + return batch diff --git a/paddlespeech/t2s/datasets/batch.py b/paddlespeech/t2s/datasets/batch.py index 9d83bbe0..4f21d447 100644 --- a/paddlespeech/t2s/datasets/batch.py +++ b/paddlespeech/t2s/datasets/batch.py @@ -167,7 +167,6 @@ def batch_spec(minibatch, pad_value=0., time_major=False, dtype=np.float32): def batch_sequences(sequences, axis=0, pad_value=0): - # import pdb; pdb.set_trace() seq = sequences[0] ndim = seq.ndim if axis < 0: diff --git a/paddlespeech/t2s/datasets/get_feats.py b/paddlespeech/t2s/datasets/get_feats.py index e1ca0eeb..21458f15 100644 --- a/paddlespeech/t2s/datasets/get_feats.py +++ b/paddlespeech/t2s/datasets/get_feats.py @@ -171,7 +171,6 @@ class Pitch(): class Energy(): def __init__(self, - sr: int=24000, n_fft: int=2048, hop_length: int=300, win_length: int=None, @@ -179,7 +178,6 @@ class Energy(): center: bool=True, pad_mode: str="reflect"): - self.sr = sr self.n_fft = n_fft self.win_length = win_length self.hop_length = hop_length diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py index db1842b2..55dc3808 100644 --- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py +++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py @@ -144,10 +144,17 @@ def process_sentences(config, spk_emb_dir: Path=None): if nprocs == 1: results = [] - for fp in fps: - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, pitch_extractor, - energy_extractor, cut_sil, spk_emb_dir) + for fp in tqdm.tqdm(fps, total=len(fps)): + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir) if record: results.append(record) else: @@ -322,7 +329,6 @@ def main(): f0min=config.f0min, f0max=config.f0max) energy_extractor = Energy( - sr=config.fs, n_fft=config.n_fft, hop_length=config.n_shift, win_length=config.win_length, @@ -331,36 +337,36 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, - pitch_extractor, - energy_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, - 
pitch_extractor, - energy_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, - pitch_extractor, - energy_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, + pitch_extractor=pitch_extractor, + energy_extractor=energy_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) diff --git a/paddlespeech/t2s/exps/gan_vocoder/preprocess.py b/paddlespeech/t2s/exps/gan_vocoder/preprocess.py index 4871bca7..5a407f5b 100644 --- a/paddlespeech/t2s/exps/gan_vocoder/preprocess.py +++ b/paddlespeech/t2s/exps/gan_vocoder/preprocess.py @@ -85,15 +85,17 @@ def process_sentence(config: Dict[str, Any], y, (0, num_frames * config.n_shift - y.size), mode="reflect") else: y = y[:num_frames * config.n_shift] - num_sample = y.shape[0] + num_samples = y.shape[0] mel_path = output_dir / (utt_id + "_feats.npy") wav_path = output_dir / (utt_id + "_wave.npy") - np.save(wav_path, y) # (num_samples, ) - np.save(mel_path, logmel) # (num_frames, n_mels) + # (num_samples, ) + np.save(wav_path, y) + # (num_frames, n_mels) + np.save(mel_path, logmel) record = { "utt_id": utt_id, - "num_samples": num_sample, + "num_samples": num_samples, "num_frames": num_frames, "feats": str(mel_path), "wave": str(wav_path), @@ -108,11 +110,17 @@ def process_sentences(config, mel_extractor=None, nprocs: int=1, cut_sil: bool=True): + if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, cut_sil) + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + cut_sil=cut_sil) if record: results.append(record) else: @@ -147,7 +155,7 @@ def main(): "--dataset", default="baker", type=str, - help="name of dataset, should in {baker, ljspeech, vctk} now") + help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now") parser.add_argument( "--rootdir", default=None, type=str, help="directory to dataset.") parser.add_argument( @@ -261,28 +269,28 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil) diff --git a/paddlespeech/t2s/exps/speedyspeech/preprocess.py b/paddlespeech/t2s/exps/speedyspeech/preprocess.py index e833d139..e8d89a4f 100644 --- a/paddlespeech/t2s/exps/speedyspeech/preprocess.py +++ b/paddlespeech/t2s/exps/speedyspeech/preprocess.py @@ -123,11 +123,17 @@ def process_sentences(config, nprocs: int=1, 
cut_sil: bool=True, use_relative_path: bool=False): + if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, cut_sil) + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + cut_sil=cut_sil) if record: results.append(record) else: @@ -265,30 +271,30 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, use_relative_path=args.use_relative_path) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, cut_sil=args.cut_sil, use_relative_path=args.use_relative_path) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, use_relative_path=args.use_relative_path) diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py index ce0aee05..41fd7317 100644 --- a/paddlespeech/t2s/exps/syn_utils.py +++ b/paddlespeech/t2s/exps/syn_utils.py @@ -27,11 +27,11 @@ from paddle import jit from paddle.static import InputSpec from yacs.config import CfgNode -from paddlespeech.s2t.utils.dynamic_import import dynamic_import from paddlespeech.t2s.datasets.data_table import DataTable from paddlespeech.t2s.frontend import English from paddlespeech.t2s.frontend.zh_frontend import Frontend from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.utils.dynamic_import import dynamic_import model_alias = { # acoustic model diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index 0855a6a2..2b543ef7 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -107,8 +107,8 @@ def evaluate(args): if args.voice_cloning and "spk_emb" in datum: spk_emb = paddle.to_tensor(np.load(datum["spk_emb"])) mel = am_inference(phone_ids, spk_emb=spk_emb) - # vocoder - wav = voc_inference(mel) + # vocoder + wav = voc_inference(mel) wav = wav.numpy() N += wav.size @@ -125,7 +125,7 @@ def evaluate(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") # acoustic model @@ -143,7 +143,7 @@ def parse_args(): '--am_config', type=str, default=None, - help='Config of acoustic model. Use deault config when it is None.') + help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -182,7 +182,7 @@ def parse_args(): '--voc_config', type=str, default=None, - help='Config of voc. 
Use deault config when it is None.') + help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 2f14ef56..6101d159 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -159,7 +159,7 @@ def evaluate(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") # acoustic model @@ -177,7 +177,7 @@ def parse_args(): '--am_config', type=str, default=None, - help='Config of acoustic model. Use deault config when it is None.') + help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -223,7 +223,7 @@ def parse_args(): '--voc_config', type=str, default=None, - help='Config of voc. Use deault config when it is None.') + help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/synthesize_streaming.py b/paddlespeech/t2s/exps/synthesize_streaming.py index 3659cb49..b11bc799 100644 --- a/paddlespeech/t2s/exps/synthesize_streaming.py +++ b/paddlespeech/t2s/exps/synthesize_streaming.py @@ -201,7 +201,7 @@ def evaluate(args): def parse_args(): - # parse args and config and redirect to train_sp + # parse args and config parser = argparse.ArgumentParser( description="Synthesize with acoustic model & vocoder") # acoustic model @@ -215,7 +215,7 @@ def parse_args(): '--am_config', type=str, default=None, - help='Config of acoustic model. Use deault config when it is None.') + help='Config of acoustic model.') parser.add_argument( '--am_ckpt', type=str, @@ -248,7 +248,7 @@ def parse_args(): '--voc_config', type=str, default=None, - help='Config of voc. 
Use deault config when it is None.') + help='Config of voc.') parser.add_argument( '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.') parser.add_argument( diff --git a/paddlespeech/t2s/exps/tacotron2/preprocess.py b/paddlespeech/t2s/exps/tacotron2/preprocess.py index 14a0d7ea..fe2a2b26 100644 --- a/paddlespeech/t2s/exps/tacotron2/preprocess.py +++ b/paddlespeech/t2s/exps/tacotron2/preprocess.py @@ -122,9 +122,15 @@ def process_sentences(config, spk_emb_dir: Path=None): if nprocs == 1: results = [] - for fp in fps: - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor, cut_sil, spk_emb_dir) + for fp in tqdm.tqdm(fps, total=len(fps)): + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor, + cut_sil=cut_sil, + spk_emb_dir=spk_emb_dir) if record: results.append(record) else: @@ -296,30 +302,30 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu, cut_sil=args.cut_sil, spk_emb_dir=spk_emb_dir) diff --git a/paddlespeech/t2s/exps/transformer_tts/preprocess.py b/paddlespeech/t2s/exps/transformer_tts/preprocess.py index 9aa87e91..28ca3de6 100644 --- a/paddlespeech/t2s/exps/transformer_tts/preprocess.py +++ b/paddlespeech/t2s/exps/transformer_tts/preprocess.py @@ -125,11 +125,16 @@ def process_sentences(config, output_dir: Path, mel_extractor=None, nprocs: int=1): + if nprocs == 1: results = [] for fp in tqdm.tqdm(fps, total=len(fps)): - record = process_sentence(config, fp, sentences, output_dir, - mel_extractor) + record = process_sentence( + config=config, + fp=fp, + sentences=sentences, + output_dir=output_dir, + mel_extractor=mel_extractor) if record: results.append(record) else: @@ -247,27 +252,27 @@ def main(): # process for the 3 sections if train_wav_files: process_sentences( - config, - train_wav_files, - sentences, - train_dump_dir, - mel_extractor, + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu) if dev_wav_files: process_sentences( - config, - dev_wav_files, - sentences, - dev_dump_dir, - mel_extractor, + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu) if test_wav_files: process_sentences( - config, - test_wav_files, - sentences, - test_dump_dir, - mel_extractor, + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + mel_extractor=mel_extractor, nprocs=args.num_cpu) diff --git a/paddlespeech/t2s/exps/vits/normalize.py b/paddlespeech/t2s/exps/vits/normalize.py index 97043fd7..6fc8adb0 100644 --- a/paddlespeech/t2s/exps/vits/normalize.py 
+++ b/paddlespeech/t2s/exps/vits/normalize.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,3 +11,155 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Normalize feature files and dump them."""
+import argparse
+import logging
+from operator import itemgetter
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from paddlespeech.t2s.datasets.data_table import DataTable
+
+
+def main():
+    """Run normalization process."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="metadata file of the dumped raw features to be normalized, "
+        "e.g. dump/train/raw/metadata.jsonl.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+    parser.add_argument(
+        "--feats-stats",
+        type=str,
+        required=True,
+        help="speech statistics file.")
+    parser.add_argument(
+        "--skip-wav-copy",
+        default=False,
+        action="store_true",
+        help="whether to skip the copy of wav files.")
+
+    parser.add_argument(
+        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
+    parser.add_argument(
+        "--speaker-dict", type=str, default=None, help="speaker id map file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. 
(default=1)") + args = parser.parse_args() + + # set logger + if args.verbose > 1: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + ) + elif args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + ) + logging.warning('Skip DEBUG/INFO messages') + + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + with jsonlines.open(args.metadata, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, + converters={ + "feats": np.load, + "wave": None if args.skip_wav_copy else np.load, + }) + logging.info(f"The number of files = {len(dataset)}.") + + # restore scaler + feats_scaler = StandardScaler() + feats_scaler.mean_ = np.load(args.feats_stats)[0] + feats_scaler.scale_ = np.load(args.feats_stats)[1] + feats_scaler.n_features_in_ = feats_scaler.mean_.shape[0] + + vocab_phones = {} + with open(args.phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + # process each file + output_metadata = [] + + for item in tqdm(dataset): + utt_id = item['utt_id'] + feats = item['feats'] + wave = item['wave'] + + # normalize + feats = feats_scaler.transform(feats) + feats_path = dumpdir / f"{utt_id}_feats.npy" + np.save(feats_path, feats.astype(np.float32), allow_pickle=False) + + if not args.skip_wav_copy: + wav_path = dumpdir / f"{utt_id}_wave.npy" + np.save(wav_path, wave.astype(np.float32), allow_pickle=False) + else: + wav_path = wave + + phone_ids = [vocab_phones[p] for p in item['phones']] + spk_id = vocab_speaker[item["speaker"]] + + record = { + "utt_id": item['utt_id'], + "text": phone_ids, + "text_lengths": item['text_lengths'], + 'feats': str(feats_path), + "feats_lengths": item['feats_lengths'], + "wave": str(wav_path), + "spk_id": spk_id, + } + + # add spk_emb for voice cloning + if "spk_emb" in item: + record["spk_emb"] = str(item["spk_emb"]) + + output_metadata.append(record) + output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/preprocess.py b/paddlespeech/t2s/exps/vits/preprocess.py index 97043fd7..6aa139fb 100644 --- a/paddlespeech/t2s/exps/vits/preprocess.py +++ b/paddlespeech/t2s/exps/vits/preprocess.py @@ -11,3 +11,338 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
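+# Unlike the mel-based recipes, VITS's posterior encoder consumes linear
+# spectrograms, so this script extracts LinearSpectrogram features, trims
+# edge silences using the MFA durations when --cut-sil is set, and dumps
+# per-utterance *_wave.npy / *_feats.npy plus a metadata.jsonl for each of
+# the train/dev/test splits.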
+import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.get_feats import LinearSpectrogram
+from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
+from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
+from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
+from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.utils import str2bool
+
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     sentences: Dict,
+                     output_dir: Path,
+                     spec_extractor=None,
+                     cut_sil: bool=True,
+                     spk_emb_dir: Path=None):
+    utt_id = fp.stem
+    # for vctk
+    if utt_id.endswith("_mic2"):
+        utt_id = utt_id[:-5]
+    record = None
+    if utt_id in sentences:
+        # reading, resampling may occur
+        wav, _ = librosa.load(str(fp), sr=config.fs)
+        if len(wav.shape) != 1:
+            return record
+        max_value = np.abs(wav).max()
+        if max_value > 1.0:
+            wav = wav / max_value
+        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+        assert np.abs(wav).max(
+        ) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        speaker = sentences[utt_id][2]
+        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+        # slightly less precise than using *.TextGrid directly
+        times = librosa.frames_to_time(
+            d_cumsum, sr=config.fs, hop_length=config.n_shift)
+        if cut_sil:
+            start = 0
+            end = d_cumsum[-1]
+            if phones[0] == "sil" and len(durations) > 1:
+                start = times[1]
+                durations = durations[1:]
+                phones = phones[1:]
+            if phones[-1] == 'sil' and len(durations) > 1:
+                end = times[-2]
+                durations = durations[:-1]
+                phones = phones[:-1]
+            sentences[utt_id][0] = phones
+            sentences[utt_id][1] = durations
+            start, end = librosa.time_to_samples([start, end], sr=config.fs)
+            wav = wav[start:end]
+        # extract linear spectrogram feats
+        spec = spec_extractor.get_linear_spectrogram(wav)
+        # change duration according to the number of spectrogram frames
+        compare_duration_and_mel_length(sentences, utt_id, spec)
+        # utt_id may be popped in compare_duration_and_mel_length
+        if utt_id not in sentences:
+            return None
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        num_frames = spec.shape[0]
+        assert sum(durations) == num_frames
+
+        if wav.size < num_frames * config.n_shift:
+            wav = np.pad(
+                wav, (0, num_frames * config.n_shift - wav.size),
+                mode="reflect")
+        else:
+            wav = wav[:num_frames * config.n_shift]
+        num_samples = wav.shape[0]
+
+        spec_path = output_dir / (utt_id + "_feats.npy")
+        wav_path = output_dir / (utt_id + "_wave.npy")
+        # (num_samples, )
+        np.save(wav_path, wav)
+        # (num_frames, aux_channels)
+        np.save(spec_path, spec)
+
+        record = {
+            "utt_id": utt_id,
+            "phones": phones,
+            "text_lengths": len(phones),
+            "feats": str(spec_path),
+            "feats_lengths": num_frames,
+            "wave": str(wav_path),
+            "speaker": speaker
+        }
+        if spk_emb_dir:
+            if speaker in os.listdir(spk_emb_dir):
+                embed_name = utt_id + ".npy"
+                embed_path = spk_emb_dir / speaker / embed_name
+                if embed_path.is_file():
+                    record["spk_emb"] = str(embed_path)
+                else:
+                    return None
+    return record
+
+
+def process_sentences(config,
+                      fps: List[Path],
+                      sentences: Dict,
+                      output_dir: Path,
+                      spec_extractor=None,
+                      nprocs: int=1,
+                      cut_sil: bool=True,
+                      spk_emb_dir: Path=None):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                spec_extractor=spec_extractor,
+                cut_sil=cut_sil,
+                spk_emb_dir=spk_emb_dir)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         sentences, output_dir, spec_extractor,
+                                         cut_sil, spk_emb_dir)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+
+    results.sort(key=itemgetter("utt_id"))
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
+
+
+def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features.")
+
+    parser.add_argument(
+        "--dataset",
+        default="baker",
+        type=str,
+        help="name of dataset, should be in {baker, aishell3, ljspeech, vctk} now")
+
+    parser.add_argument(
+        "--rootdir", default=None, type=str, help="directory to dataset.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+    parser.add_argument(
+        "--dur-file", default=None, type=str, help="path to durations.txt.")
+
+    parser.add_argument("--config", type=str, help="vits config file.")
+
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. 
(default=1)") + parser.add_argument( + "--num-cpu", type=int, default=1, help="number of process.") + + parser.add_argument( + "--cut-sil", + type=str2bool, + default=True, + help="whether cut sil in the edge of audio") + + parser.add_argument( + "--spk_emb_dir", + default=None, + type=str, + help="directory to speaker embedding files.") + args = parser.parse_args() + + rootdir = Path(args.rootdir).expanduser() + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + dur_file = Path(args.dur_file).expanduser() + + if args.spk_emb_dir: + spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve() + else: + spk_emb_dir = None + + assert rootdir.is_dir() + assert dur_file.is_file() + + with open(args.config, 'rt') as f: + config = CfgNode(yaml.safe_load(f)) + + if args.verbose > 1: + print(vars(args)) + print(config) + + sentences, speaker_set = get_phn_dur(dur_file) + + merge_silence(sentences) + phone_id_map_path = dumpdir / "phone_id_map.txt" + speaker_id_map_path = dumpdir / "speaker_id_map.txt" + get_input_token(sentences, phone_id_map_path, args.dataset) + get_spk_id_map(speaker_set, speaker_id_map_path) + + if args.dataset == "baker": + wav_files = sorted(list((rootdir / "Wave").rglob("*.wav"))) + # split data into 3 sections + num_train = 9800 + num_dev = 100 + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + elif args.dataset == "aishell3": + sub_num_dev = 5 + wav_dir = rootdir / "train" / "wav" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*.wav"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + elif args.dataset == "ljspeech": + wav_files = sorted(list((rootdir / "wavs").rglob("*.wav"))) + # split data into 3 sections + num_train = 12900 + num_dev = 100 + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + elif args.dataset == "vctk": + sub_num_dev = 5 + wav_dir = rootdir / "wav48_silence_trimmed" + train_wav_files = [] + dev_wav_files = [] + test_wav_files = [] + for speaker in os.listdir(wav_dir): + wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac"))) + if len(wav_files) > 100: + train_wav_files += wav_files[:-sub_num_dev * 2] + dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev] + test_wav_files += wav_files[-sub_num_dev:] + else: + train_wav_files += wav_files + + else: + print("dataset should in {baker, aishell3, ljspeech, vctk} now!") + + train_dump_dir = dumpdir / "train" / "raw" + train_dump_dir.mkdir(parents=True, exist_ok=True) + dev_dump_dir = dumpdir / "dev" / "raw" + dev_dump_dir.mkdir(parents=True, exist_ok=True) + test_dump_dir = dumpdir / "test" / "raw" + test_dump_dir.mkdir(parents=True, exist_ok=True) + + # Extractor + + spec_extractor = LinearSpectrogram( + n_fft=config.n_fft, + hop_length=config.n_shift, + win_length=config.win_length, + window=config.window) + + # process for the 3 sections + if train_wav_files: + process_sentences( + config=config, + fps=train_wav_files, + sentences=sentences, + output_dir=train_dump_dir, + spec_extractor=spec_extractor, + 
nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) + if dev_wav_files: + process_sentences( + config=config, + fps=dev_wav_files, + sentences=sentences, + output_dir=dev_dump_dir, + spec_extractor=spec_extractor, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) + if test_wav_files: + process_sentences( + config=config, + fps=test_wav_files, + sentences=sentences, + output_dir=test_dump_dir, + spec_extractor=spec_extractor, + nprocs=args.num_cpu, + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/synthesize.py b/paddlespeech/t2s/exps/vits/synthesize.py index 97043fd7..074b890f 100644 --- a/paddlespeech/t2s/exps/vits/synthesize.py +++ b/paddlespeech/t2s/exps/vits/synthesize.py @@ -11,3 +11,107 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import argparse +from pathlib import Path + +import jsonlines +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.vits import VITS + + +def evaluate(args): + + # construct dataset for evaluation + with jsonlines.open(args.test_metadata, 'r') as reader: + test_metadata = list(reader) + # Init body. + with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + + fields = ["utt_id", "text"] + + test_dataset = DataTable(data=test_metadata, fields=fields) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + + vits = VITS(idim=vocab_size, odim=odim, **config["model"]) + vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) + vits.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + N = 0 + T = 0 + + for datum in test_dataset: + utt_id = datum["utt_id"] + phone_ids = paddle.to_tensor(datum["text"]) + with timer() as t: + with paddle.no_grad(): + out = vits.inference(text=phone_ids) + wav = out["wav"] + wav = wav.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = config.fs / speed + print( + f"{utt_id}, wave: {wav.size}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." 
+ ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser(description="Synthesize with VITS") + # model + parser.add_argument( + '--config', type=str, default=None, help='Config of VITS.') + parser.add_argument( + '--ckpt', type=str, default=None, help='Checkpoint file of VITS.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + # other + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument("--test_metadata", type=str, help="test metadata.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/synthesize_e2e.py b/paddlespeech/t2s/exps/vits/synthesize_e2e.py new file mode 100644 index 00000000..c82e5c03 --- /dev/null +++ b/paddlespeech/t2s/exps/vits/synthesize_e2e.py @@ -0,0 +1,146 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from pathlib import Path + +import paddle +import soundfile as sf +import yaml +from timer import timer +from yacs.config import CfgNode + +from paddlespeech.t2s.exps.syn_utils import get_frontend +from paddlespeech.t2s.exps.syn_utils import get_sentences +from paddlespeech.t2s.models.vits import VITS + + +def evaluate(args): + + # Init body. 
+ with open(args.config) as f: + config = CfgNode(yaml.safe_load(f)) + + print("========Args========") + print(yaml.safe_dump(vars(args))) + print("========Config========") + print(config) + + sentences = get_sentences(text_file=args.text, lang=args.lang) + + # frontend + frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict) + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + + vits = VITS(idim=vocab_size, odim=odim, **config["model"]) + vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) + vits.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + merge_sentences = False + + N = 0 + T = 0 + for utt_id, sentence in sentences: + with timer() as t: + if args.lang == 'zh': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + elif args.lang == 'en': + input_ids = frontend.get_input_ids( + sentence, merge_sentences=merge_sentences) + phone_ids = input_ids["phone_ids"] + else: + print("lang should in {'zh', 'en'}!") + with paddle.no_grad(): + flags = 0 + for i in range(len(phone_ids)): + part_phone_ids = phone_ids[i] + out = vits.inference(text=part_phone_ids) + wav = out["wav"] + if flags == 0: + wav_all = wav + flags = 1 + else: + wav_all = paddle.concat([wav_all, wav]) + wav = wav_all.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + rtf = config.fs / speed + print( + f"{utt_id}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {rtf}." + ) + sf.write(str(output_dir / (utt_id + ".wav")), wav, samplerate=config.fs) + print(f"{utt_id} done!") + print(f"generation speed: {N / T}Hz, RTF: {config.fs / (N / T) }") + + +def parse_args(): + # parse args and config + parser = argparse.ArgumentParser(description="Synthesize with VITS") + + # model + parser.add_argument( + '--config', type=str, default=None, help='Config of VITS.') + parser.add_argument( + '--ckpt', type=str, default=None, help='Checkpoint file of VITS.') + parser.add_argument( + "--phones_dict", type=str, default=None, help="phone vocabulary file.") + # other + parser.add_argument( + '--lang', + type=str, + default='zh', + help='Choose model language. zh or en') + + parser.add_argument( + "--inference_dir", + type=str, + default=None, + help="dir to save inference models") + parser.add_argument( + "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") + parser.add_argument( + "--text", + type=str, + help="text to synthesize, a 'utt_id sentence' pair per line.") + parser.add_argument("--output_dir", type=str, help="output dir.") + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + if args.ngpu == 0: + paddle.set_device("cpu") + elif args.ngpu > 0: + paddle.set_device("gpu") + else: + print("ngpu should >= 0 !") + + evaluate(args) + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index 97043fd7..b921f92a 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -11,3 +11,251 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
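+# VITS trains adversarially: the generator (text encoder, flows, posterior
+# encoder and HiFiGAN-style decoder) and the multi-scale multi-period
+# discriminator each get their own Adam optimizer and exponential-decay
+# schedule, and VITSUpdater alternates their updates per batch (the order
+# is controlled by generator_first in the config).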
+import argparse +import logging +import os +import shutil +from pathlib import Path + +import jsonlines +import numpy as np +import paddle +import yaml +from paddle import DataParallel +from paddle import distributed as dist +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler +from paddle.optimizer import Adam +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import vits_single_spk_batch_fn +from paddlespeech.t2s.datasets.data_table import DataTable +from paddlespeech.t2s.models.vits import VITS +from paddlespeech.t2s.models.vits import VITSEvaluator +from paddlespeech.t2s.models.vits import VITSUpdater +from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss +from paddlespeech.t2s.modules.losses import FeatureMatchLoss +from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss +from paddlespeech.t2s.modules.losses import KLDivergenceLoss +from paddlespeech.t2s.modules.losses import MelSpectrogramLoss +from paddlespeech.t2s.training.extensions.snapshot import Snapshot +from paddlespeech.t2s.training.extensions.visualizer import VisualDL +from paddlespeech.t2s.training.optimizer import scheduler_classes +from paddlespeech.t2s.training.seeding import seed_everything +from paddlespeech.t2s.training.trainer import Trainer + + +def train_sp(args, config): + # decides device type and whether to run in parallel + # setup running environment correctly + world_size = paddle.distributed.get_world_size() + if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: + paddle.set_device("cpu") + else: + paddle.set_device("gpu") + if world_size > 1: + paddle.distributed.init_parallel_env() + + # set the random seed, it is a must for multiprocess training + seed_everything(config.seed) + + print( + f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", + ) + + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + + fields = ["text", "text_lengths", "feats", "feats_lengths", "wave"] + + converters = { + "wave": np.load, + "feats": np.load, + } + + # construct dataset for training and validation + with jsonlines.open(args.train_metadata, 'r') as reader: + train_metadata = list(reader) + train_dataset = DataTable( + data=train_metadata, + fields=fields, + converters=converters, ) + with jsonlines.open(args.dev_metadata, 'r') as reader: + dev_metadata = list(reader) + dev_dataset = DataTable( + data=dev_metadata, + fields=fields, + converters=converters, ) + + # collate function and dataloader + train_sampler = DistributedBatchSampler( + train_dataset, + batch_size=config.batch_size, + shuffle=True, + drop_last=True) + dev_sampler = DistributedBatchSampler( + dev_dataset, + batch_size=config.batch_size, + shuffle=False, + drop_last=False) + print("samplers done!") + + train_batch_fn = vits_single_spk_batch_fn + + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + + dev_dataloader = DataLoader( + dev_dataset, + batch_sampler=dev_sampler, + collate_fn=train_batch_fn, + num_workers=config.num_workers) + print("dataloaders done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + + odim = config.n_fft // 2 + 1 + model = VITS(idim=vocab_size, odim=odim, **config["model"]) + gen_parameters = model.generator.parameters() + dis_parameters = model.discriminator.parameters() 
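+    # DataParallel hides submodules behind model._layers, so in the
+    # multi-GPU branch below the two parameter lists are fetched again
+    # from the wrapped model before the optimizers are built.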
+
+    with open(args.phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)
+
+    odim = config.n_fft // 2 + 1
+    model = VITS(idim=vocab_size, odim=odim, **config["model"])
+    gen_parameters = model.generator.parameters()
+    dis_parameters = model.discriminator.parameters()
+    if world_size > 1:
+        model = DataParallel(model)
+        gen_parameters = model._layers.generator.parameters()
+        dis_parameters = model._layers.discriminator.parameters()
+
+    print("model done!")
+
+    # loss
+    criterion_mel = MelSpectrogramLoss(
+        **config["mel_loss_params"], )
+    criterion_feat_match = FeatureMatchLoss(
+        **config["feat_match_loss_params"], )
+    criterion_gen_adv = GeneratorAdversarialLoss(
+        **config["generator_adv_loss_params"], )
+    criterion_dis_adv = DiscriminatorAdversarialLoss(
+        **config["discriminator_adv_loss_params"], )
+    criterion_kl = KLDivergenceLoss()
+
+    print("criterions done!")
+
+    lr_schedule_g = scheduler_classes[config["generator_scheduler"]](
+        **config["generator_scheduler_params"])
+    optimizer_g = Adam(
+        learning_rate=lr_schedule_g,
+        parameters=gen_parameters,
+        **config["generator_optimizer_params"])
+
+    lr_schedule_d = scheduler_classes[config["discriminator_scheduler"]](
+        **config["discriminator_scheduler_params"])
+    optimizer_d = Adam(
+        learning_rate=lr_schedule_d,
+        parameters=dis_parameters,
+        **config["discriminator_optimizer_params"])
+
+    print("optimizers done!")
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    if dist.get_rank() == 0:
+        config_name = args.config.split("/")[-1]
+        # copy conf to output_dir
+        shutil.copyfile(args.config, output_dir / config_name)
+
+    updater = VITSUpdater(
+        model=model,
+        optimizers={
+            "generator": optimizer_g,
+            "discriminator": optimizer_d,
+        },
+        criterions={
+            "mel": criterion_mel,
+            "feat_match": criterion_feat_match,
+            "gen_adv": criterion_gen_adv,
+            "dis_adv": criterion_dis_adv,
+            "kl": criterion_kl,
+        },
+        schedulers={
+            "generator": lr_schedule_g,
+            "discriminator": lr_schedule_d,
+        },
+        dataloader=train_dataloader,
+        lambda_adv=config.lambda_adv,
+        lambda_mel=config.lambda_mel,
+        lambda_kl=config.lambda_kl,
+        lambda_feat_match=config.lambda_feat_match,
+        lambda_dur=config.lambda_dur,
+        generator_first=config.generator_first,
+        output_dir=output_dir)
+
+    evaluator = VITSEvaluator(
+        model=model,
+        criterions={
+            "mel": criterion_mel,
+            "feat_match": criterion_feat_match,
+            "gen_adv": criterion_gen_adv,
+            "dis_adv": criterion_dis_adv,
+            "kl": criterion_kl,
+        },
+        dataloader=dev_dataloader,
+        lambda_adv=config.lambda_adv,
+        lambda_mel=config.lambda_mel,
+        lambda_kl=config.lambda_kl,
+        lambda_feat_match=config.lambda_feat_match,
+        lambda_dur=config.lambda_dur,
+        generator_first=config.generator_first,
+        output_dir=output_dir)
+
+    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+
+    print("Trainer Done!")
+    trainer.run()
+
+
+def main():
+    # parse args and config and redirect to train_sp
+
+    parser = argparse.ArgumentParser(description="Train a VITS model.")
+    parser.add_argument(
+        "--config", type=str, help="config file to overwrite default config.")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    parser.add_argument(
+        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
+
+    args = parser.parse_args()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.ngpu > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py
index 9257b07d..3cf1cabc 100644
--- a/paddlespeech/t2s/exps/voice_cloning.py
+++ b/paddlespeech/t2s/exps/voice_cloning.py
@@ -122,7 +122,7 @@ def voice_cloning(args):
 
 
 def parse_args():
-    # parse args and config and redirect to train_sp
+    # parse args and config
     parser = argparse.ArgumentParser(description="")
     parser.add_argument(
         '--am',
@@ -134,7 +134,7 @@ def parse_args():
         '--am_config',
         type=str,
         default=None,
-        help='Config of acoustic model. Use deault config when it is None.')
+        help='Config of acoustic model.')
     parser.add_argument(
         '--am_ckpt',
         type=str,
@@ -163,7 +163,7 @@ def parse_args():
         '--voc_config',
         type=str,
         default=None,
-        help='Config of voc. Use deault config when it is None.')
+        help='Config of voc.')
     parser.add_argument(
         '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
     parser.add_argument(
diff --git a/paddlespeech/t2s/models/__init__.py b/paddlespeech/t2s/models/__init__.py
index 41be7c1d..0b6f2911 100644
--- a/paddlespeech/t2s/models/__init__.py
+++ b/paddlespeech/t2s/models/__init__.py
@@ -18,5 +18,6 @@ from .parallel_wavegan import *
 from .speedyspeech import *
 from .tacotron2 import *
 from .transformer_tts import *
+from .vits import *
 from .waveflow import *
 from .wavernn import *
diff --git a/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py b/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py
index 40cfff5a..c1cd7330 100644
--- a/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py
+++ b/paddlespeech/t2s/models/parallel_wavegan/parallel_wavegan_updater.py
@@ -68,8 +68,8 @@ class PWGUpdater(StandardUpdater):
         self.discriminator_train_start_steps = discriminator_train_start_steps
         self.lambda_adv = lambda_adv
         self.lambda_aux = lambda_aux
-        self.state = UpdaterState(iteration=0, epoch=0)
 
+        self.state = UpdaterState(iteration=0, epoch=0)
         self.train_iterator = iter(self.dataloader)
 
         log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
diff --git a/paddlespeech/t2s/models/vits/__init__.py b/paddlespeech/t2s/models/vits/__init__.py
index 97043fd7..2c23aa3e 100644
--- a/paddlespeech/t2s/models/vits/__init__.py
+++ b/paddlespeech/t2s/models/vits/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .vits import *
+from .vits_updater import *
\ No newline at end of file
diff --git a/paddlespeech/t2s/models/vits/generator.py b/paddlespeech/t2s/models/vits/generator.py
index e35f9956..f87de91a 100644
--- a/paddlespeech/t2s/models/vits/generator.py
+++ b/paddlespeech/t2s/models/vits/generator.py
@@ -318,7 +318,6 @@ class VITSGenerator(nn.Layer):
             g = g + g_
 
         # forward posterior encoder
-
         z, m_q, logs_q, y_mask = self.posterior_encoder(
             feats, feats_lengths, g=g)
 
diff --git a/paddlespeech/t2s/models/vits/vits.py b/paddlespeech/t2s/models/vits/vits.py
index f7f5ba96..ab8eda26 100644
--- a/paddlespeech/t2s/models/vits/vits.py
+++ b/paddlespeech/t2s/models/vits/vits.py
@@ -27,12 +27,7 @@ from paddlespeech.t2s.models.hifigan import HiFiGANMultiScaleMultiPeriodDiscrimi
 from paddlespeech.t2s.models.hifigan import HiFiGANPeriodDiscriminator
 from paddlespeech.t2s.models.hifigan import HiFiGANScaleDiscriminator
 from paddlespeech.t2s.models.vits.generator import VITSGenerator
-from paddlespeech.t2s.modules.losses import DiscriminatorAdversarialLoss
-from paddlespeech.t2s.modules.losses import FeatureMatchLoss
-from paddlespeech.t2s.modules.losses import GeneratorAdversarialLoss
-from paddlespeech.t2s.modules.losses import KLDivergenceLoss
-from paddlespeech.t2s.modules.losses import MelSpectrogramLoss
-from paddlespeech.t2s.modules.nets_utils import get_segments
+from paddlespeech.t2s.modules.nets_utils import initialize
 
 AVAILABLE_GENERATERS = {
     "vits_generator": VITSGenerator,
@@ -157,37 +152,8 @@ class VITS(nn.Layer):
                     "use_spectral_norm": False,
                 },
             },
-            # loss related
-            generator_adv_loss_params: Dict[str, Any]={
-                "average_by_discriminators": False,
-                "loss_type": "mse",
-            },
-            discriminator_adv_loss_params: Dict[str, Any]={
-                "average_by_discriminators": False,
-                "loss_type": "mse",
-            },
-            feat_match_loss_params: Dict[str, Any]={
-                "average_by_discriminators": False,
-                "average_by_layers": False,
-                "include_final_outputs": True,
-            },
-            mel_loss_params: Dict[str, Any]={
-                "fs": 22050,
-                "fft_size": 1024,
-                "hop_size": 256,
-                "win_length": None,
-                "window": "hann",
-                "num_mels": 80,
-                "fmin": 0,
-                "fmax": None,
-                "log_base": None,
-            },
-            lambda_adv: float=1.0,
-            lambda_mel: float=45.0,
-            lambda_feat_match: float=2.0,
-            lambda_dur: float=1.0,
-            lambda_kl: float=1.0,
-            cache_generator_outputs: bool=True, ):
+            cache_generator_outputs: bool=True,
+            init_type: str="xavier_uniform", ):
         """Initialize VITS module.
         Args:
             idim (int): Input vocabrary size.
@@ -200,22 +166,14 @@
             generator_params (Dict[str, Any]): Parameter dict for generator.
             discriminator_type (str): Discriminator type.
             discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
-            generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
-                adversarial loss.
-            discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
-                discriminator adversarial loss.
-            feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
-            mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
-            lambda_adv (float): Loss scaling coefficient for adversarial loss.
-            lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
-            lambda_feat_match (float): Loss scaling coefficient for feat match loss.
-            lambda_dur (float): Loss scaling coefficient for duration loss.
-            lambda_kl (float): Loss scaling coefficient for KL divergence loss.
             cache_generator_outputs (bool): Whether to cache generator outputs.
""" assert check_argument_types() super().__init__() + # initialize parameters + initialize(self, init_type) + # define modules generator_class = AVAILABLE_GENERATERS[generator_type] if generator_type == "vits_generator": @@ -229,22 +187,8 @@ class VITS(nn.Layer): discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type] self.discriminator = discriminator_class( **discriminator_params, ) - self.generator_adv_loss = GeneratorAdversarialLoss( - **generator_adv_loss_params, ) - self.discriminator_adv_loss = DiscriminatorAdversarialLoss( - **discriminator_adv_loss_params, ) - self.feat_match_loss = FeatureMatchLoss( - **feat_match_loss_params, ) - self.mel_loss = MelSpectrogramLoss( - **mel_loss_params, ) - self.kl_loss = KLDivergenceLoss() - # coefficients - self.lambda_adv = lambda_adv - self.lambda_mel = lambda_mel - self.lambda_kl = lambda_kl - self.lambda_feat_match = lambda_feat_match - self.lambda_dur = lambda_dur + nn.initializer.set_global_initializer(None) # cache self.cache_generator_outputs = cache_generator_outputs @@ -259,15 +203,8 @@ class VITS(nn.Layer): self.langs = self.generator.langs self.spk_embed_dim = self.generator.spk_embed_dim - @property - def require_raw_speech(self): - """Return whether or not speech is required.""" - return True - - @property - def require_vocoder(self): - """Return whether or not vocoder is required.""" - return False + self.reuse_cache_gen = True + self.reuse_cache_dis = True def forward( self, @@ -334,21 +271,15 @@ class VITS(nn.Layer): spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). Returns: - Dict[str, Any]: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - * weight (Tensor): Weight tensor to summarize losses. - * optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ """ # setup - batch_size = paddle.shape(text)[0] feats = feats.transpose([0, 2, 1]) - # speech = speech.unsqueeze(1) # calculate generator outputs - reuse_cache = True + self.reuse_cache_gen = True if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False + self.reuse_cache_gen = False outs = self.generator( text=text, text_lengths=text_lengths, @@ -361,59 +292,10 @@ class VITS(nn.Layer): outs = self._cache # store cache - if self.training and self.cache_generator_outputs and not reuse_cache: + if self.training and self.cache_generator_outputs and not self.reuse_cache_gen: self._cache = outs return outs - """ - # parse outputs - speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs - _, z_p, m_p, logs_p, _, logs_q = outs_ - speech_ = get_segments( - x=speech, - start_idxs=start_idxs * self.generator.upsample_factor, - segment_size=self.generator.segment_size * - self.generator.upsample_factor, ) - - # calculate discriminator outputs - p_hat = self.discriminator(speech_hat_) - with paddle.no_grad(): - # do not store discriminator gradient in generator turn - p = self.discriminator(speech_) - - # calculate losses - mel_loss = self.mel_loss(speech_hat_, speech_) - kl_loss = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask) - dur_loss = paddle.sum(dur_nll.float()) - adv_loss = self.generator_adv_loss(p_hat) - feat_match_loss = self.feat_match_loss(p_hat, p) - - mel_loss = mel_loss * self.lambda_mel - kl_loss = kl_loss * self.lambda_kl - dur_loss = dur_loss * self.lambda_dur - adv_loss = adv_loss * self.lambda_adv - feat_match_loss = feat_match_loss * self.lambda_feat_match - loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss - - stats = dict( - generator_loss=loss.item(), - generator_mel_loss=mel_loss.item(), - generator_kl_loss=kl_loss.item(), - generator_dur_loss=dur_loss.item(), - generator_adv_loss=adv_loss.item(), - generator_feat_match_loss=feat_match_loss.item(), ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return { - "loss": loss, - "stats": stats, - # "weight": weight, - "optim_idx": 0, # needed for trainer - } - """ def _forward_discrminator( self, @@ -434,21 +316,15 @@ class VITS(nn.Layer): spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). lids (Optional[Tensor]): Language index tensor (B,) or (B, 1). Returns: - Dict[str, Any]: - * loss (Tensor): Loss scalar tensor. - * stats (Dict[str, float]): Statistics to be monitored. - * weight (Tensor): Weight tensor to summarize losses. - * optim_idx (int): Optimizer index (0 for G and 1 for D). 
+ """ # setup - batch_size = paddle.shape(text)[0] feats = feats.transpose([0, 2, 1]) - # speech = speech.unsqueeze(1) # calculate generator outputs - reuse_cache = True + self.reuse_cache_dis = True if not self.cache_generator_outputs or self._cache is None: - reuse_cache = False + self.reuse_cache_dis = False outs = self.generator( text=text, text_lengths=text_lengths, @@ -461,44 +337,10 @@ class VITS(nn.Layer): outs = self._cache # store cache - if self.cache_generator_outputs and not reuse_cache: + if self.cache_generator_outputs and not self.reuse_cache_dis: self._cache = outs return outs - """ - - # parse outputs - speech_hat_, _, _, start_idxs, *_ = outs - speech_ = get_segments( - x=speech, - start_idxs=start_idxs * self.generator.upsample_factor, - segment_size=self.generator.segment_size * - self.generator.upsample_factor, ) - - # calculate discriminator outputs - p_hat = self.discriminator(speech_hat_.detach()) - p = self.discriminator(speech_) - - # calculate losses - real_loss, fake_loss = self.discriminator_adv_loss(p_hat, p) - loss = real_loss + fake_loss - - stats = dict( - discriminator_loss=loss.item(), - discriminator_real_loss=real_loss.item(), - discriminator_fake_loss=fake_loss.item(), ) - - # reset cache - if reuse_cache or not self.training: - self._cache = None - - return { - "loss": loss, - "stats": stats, - # "weight": weight, - "optim_idx": 1, # needed for trainer - } - """ def inference( self, @@ -535,10 +377,7 @@ class VITS(nn.Layer): # setup text = text[None] text_lengths = paddle.to_tensor(paddle.shape(text)[1]) - # if sids is not None: - # sids = sids.view(1) - # if lids is not None: - # lids = lids.view(1) + if durations is not None: durations = paddle.reshape(durations, [1, 1, -1]) diff --git a/paddlespeech/t2s/models/vits/vits_updater.py b/paddlespeech/t2s/models/vits/vits_updater.py index e69de29b..a031dc57 100644 --- a/paddlespeech/t2s/models/vits/vits_updater.py +++ b/paddlespeech/t2s/models/vits/vits_updater.py @@ -0,0 +1,353 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/t2s/models/vits/vits_updater.py b/paddlespeech/t2s/models/vits/vits_updater.py
index e69de29b..a031dc57 100644
--- a/paddlespeech/t2s/models/vits/vits_updater.py
+++ b/paddlespeech/t2s/models/vits/vits_updater.py
@@ -0,0 +1,353 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Dict
+
+import paddle
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+
+from paddlespeech.t2s.modules.nets_utils import get_segments
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+from paddlespeech.t2s.training.updaters.standard_updater import UpdaterState
+
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class VITSUpdater(StandardUpdater):
+    def __init__(self,
+                 model: Layer,
+                 optimizers: Dict[str, Optimizer],
+                 criterions: Dict[str, Layer],
+                 schedulers: Dict[str, LRScheduler],
+                 dataloader: DataLoader,
+                 generator_train_start_steps: int=0,
+                 discriminator_train_start_steps: int=100000,
+                 lambda_adv: float=1.0,
+                 lambda_mel: float=45.0,
+                 lambda_feat_match: float=2.0,
+                 lambda_dur: float=1.0,
+                 lambda_kl: float=1.0,
+                 generator_first: bool=False,
+                 output_dir=None):
+        # the parent class is designed to hold multiple models, but a single
+        # model is passed in here and the parent's __init__() is not used,
+        # so this part has to be rewritten
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        # self.model = model
+        self.model = model._layers if isinstance(model,
+                                                 paddle.DataParallel) else model
+
+        self.optimizers = optimizers
+        self.optimizer_g: Optimizer = optimizers['generator']
+        self.optimizer_d: Optimizer = optimizers['discriminator']
+
+        self.criterions = criterions
+        self.criterion_mel = criterions['mel']
+        self.criterion_feat_match = criterions['feat_match']
+        self.criterion_gen_adv = criterions["gen_adv"]
+        self.criterion_dis_adv = criterions["dis_adv"]
+        self.criterion_kl = criterions["kl"]
+
+        self.schedulers = schedulers
+        self.scheduler_g = schedulers['generator']
+        self.scheduler_d = schedulers['discriminator']
+
+        self.dataloader = dataloader
+
+        self.generator_train_start_steps = generator_train_start_steps
+        self.discriminator_train_start_steps = discriminator_train_start_steps
+
+        self.lambda_adv = lambda_adv
+        self.lambda_mel = lambda_mel
+        self.lambda_feat_match = lambda_feat_match
+        self.lambda_dur = lambda_dur
+        self.lambda_kl = lambda_kl
+
+        if generator_first:
+            self.turns = ["generator", "discriminator"]
+        else:
+            self.turns = ["discriminator", "generator"]
+
+        self.state = UpdaterState(iteration=0, epoch=0)
+        self.train_iterator = iter(self.dataloader)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
+        for turn in self.turns:
+            speech = batch["speech"]
+            speech = speech.unsqueeze(1)
+            outs = self.model(
+                text=batch["text"],
+                text_lengths=batch["text_lengths"],
+                feats=batch["feats"],
+                feats_lengths=batch["feats_lengths"],
+                forward_generator=turn == "generator")
+            # Generator
+            if turn == "generator":
+                # parse outputs
+                speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs
+                _, z_p, m_p, logs_p, _, logs_q = outs_
+                speech_ = get_segments(
+                    x=speech,
+                    start_idxs=start_idxs *
+                    self.model.generator.upsample_factor,
+                    segment_size=self.model.generator.segment_size *
+                    self.model.generator.upsample_factor, )
+
+                # calculate discriminator outputs
+                p_hat = self.model.discriminator(speech_hat_)
+                with paddle.no_grad():
+                    # do not store discriminator gradient in generator turn
+                    p = self.model.discriminator(speech_)
+
+                # calculate losses
+                mel_loss = self.criterion_mel(speech_hat_, speech_)
+                kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask)
+                dur_loss = paddle.sum(dur_nll)
+                adv_loss = self.criterion_gen_adv(p_hat)
+                feat_match_loss = self.criterion_feat_match(p_hat, p)
+
+                mel_loss = mel_loss * self.lambda_mel
+                kl_loss = kl_loss * self.lambda_kl
+                dur_loss = dur_loss * self.lambda_dur
+                adv_loss = adv_loss * self.lambda_adv
+                feat_match_loss = feat_match_loss * self.lambda_feat_match
+                gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss
+
+                report("train/generator_loss", float(gen_loss))
+                report("train/generator_mel_loss", float(mel_loss))
+                report("train/generator_kl_loss", float(kl_loss))
+                report("train/generator_dur_loss", float(dur_loss))
+                report("train/generator_adv_loss", float(adv_loss))
+                report("train/generator_feat_match_loss",
+                       float(feat_match_loss))
+
+                losses_dict["generator_loss"] = float(gen_loss)
+                losses_dict["generator_mel_loss"] = float(mel_loss)
+                losses_dict["generator_kl_loss"] = float(kl_loss)
+                losses_dict["generator_dur_loss"] = float(dur_loss)
+                losses_dict["generator_adv_loss"] = float(adv_loss)
+                losses_dict["generator_feat_match_loss"] = float(
+                    feat_match_loss)
+
+                self.optimizer_g.clear_grad()
+                gen_loss.backward()
+
+                self.optimizer_g.step()
+                self.scheduler_g.step()
+
+                # reset cache
+                if self.model.reuse_cache_gen or not self.model.training:
+                    self.model._cache = None
+
+            # Discriminator
+            elif turn == "discriminator":
+                # parse outputs
+                speech_hat_, _, _, start_idxs, *_ = outs
+                speech_ = get_segments(
+                    x=speech,
+                    start_idxs=start_idxs *
+                    self.model.generator.upsample_factor,
+                    segment_size=self.model.generator.segment_size *
+                    self.model.generator.upsample_factor, )
+
+                # calculate discriminator outputs
+                p_hat = self.model.discriminator(speech_hat_.detach())
+                p = self.model.discriminator(speech_)
+
+                # calculate losses
+                real_loss, fake_loss = self.criterion_dis_adv(p_hat, p)
+                dis_loss = real_loss + fake_loss
+
+                report("train/real_loss", float(real_loss))
+                report("train/fake_loss", float(fake_loss))
+                report("train/discriminator_loss", float(dis_loss))
+                losses_dict["real_loss"] = float(real_loss)
+                losses_dict["fake_loss"] = float(fake_loss)
+                losses_dict["discriminator_loss"] = float(dis_loss)
+
+                self.optimizer_d.clear_grad()
+                dis_loss.backward()
+
+                self.optimizer_d.step()
+                self.scheduler_d.step()
+
+                # reset cache
+                if self.model.reuse_cache_dis or not self.model.training:
+                    self.model._cache = None
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+
+
+class VITSEvaluator(StandardEvaluator):
+    def __init__(self,
+                 model,
+                 criterions: Dict[str, Layer],
+                 dataloader: DataLoader,
+                 lambda_adv: float=1.0,
+                 lambda_mel: float=45.0,
+                 lambda_feat_match: float=2.0,
+                 lambda_dur: float=1.0,
+                 lambda_kl: float=1.0,
+                 generator_first: bool=False,
+                 output_dir=None):
+        # a single model is passed in here and the parent class's __init__()
+        # is not used, so this part has to be rewritten
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        # self.model = model
+        self.model = model._layers if isinstance(model,
+                                                 paddle.DataParallel) else model
+
+        self.criterions = criterions
+        self.criterion_mel = criterions['mel']
+        self.criterion_feat_match = criterions['feat_match']
+        self.criterion_gen_adv = criterions["gen_adv"]
criterions["gen_adv"] + self.criterion_dis_adv = criterions["dis_adv"] + self.criterion_kl = criterions["kl"] + + self.dataloader = dataloader + + self.lambda_adv = lambda_adv + self.lambda_mel = lambda_mel + self.lambda_feat_match = lambda_feat_match + self.lambda_dur = lambda_dur + self.lambda_kl = lambda_kl + + if generator_first: + self.turns = ["generator", "discriminator"] + else: + self.turns = ["discriminator", "generator"] + + log_file = output_dir / 'worker_{}.log'.format(dist.get_rank()) + self.filehandler = logging.FileHandler(str(log_file)) + logger.addHandler(self.filehandler) + self.logger = logger + self.msg = "" + + def evaluate_core(self, batch): + # logging.debug("Evaluate: ") + self.msg = "Evaluate: " + losses_dict = {} + + for turn in self.turns: + speech = batch["speech"] + speech = speech.unsqueeze(1) + outs = self.model( + text=batch["text"], + text_lengths=batch["text_lengths"], + feats=batch["feats"], + feats_lengths=batch["feats_lengths"], + forward_generator=turn == "generator") + # Generator + if turn == "generator": + # parse outputs + speech_hat_, dur_nll, _, start_idxs, _, z_mask, outs_ = outs + _, z_p, m_p, logs_p, _, logs_q = outs_ + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * + self.model.generator.upsample_factor, + segment_size=self.model.generator.segment_size * + self.model.generator.upsample_factor, ) + + # calculate discriminator outputs + p_hat = self.model.discriminator(speech_hat_) + with paddle.no_grad(): + # do not store discriminator gradient in generator turn + p = self.model.discriminator(speech_) + + # calculate losses + mel_loss = self.criterion_mel(speech_hat_, speech_) + kl_loss = self.criterion_kl(z_p, logs_q, m_p, logs_p, z_mask) + dur_loss = paddle.sum(dur_nll) + adv_loss = self.criterion_gen_adv(p_hat) + feat_match_loss = self.criterion_feat_match(p_hat, p) + + mel_loss = mel_loss * self.lambda_mel + kl_loss = kl_loss * self.lambda_kl + dur_loss = dur_loss * self.lambda_dur + adv_loss = adv_loss * self.lambda_adv + feat_match_loss = feat_match_loss * self.lambda_feat_match + gen_loss = mel_loss + kl_loss + dur_loss + adv_loss + feat_match_loss + + report("eval/generator_loss", float(gen_loss)) + report("eval/generator_mel_loss", float(mel_loss)) + report("eval/generator_kl_loss", float(kl_loss)) + report("eval/generator_dur_loss", float(dur_loss)) + report("eval/generator_adv_loss", float(adv_loss)) + report("eval/generator_feat_match_loss", float(feat_match_loss)) + + losses_dict["generator_loss"] = float(gen_loss) + losses_dict["generator_mel_loss"] = float(mel_loss) + losses_dict["generator_kl_loss"] = float(kl_loss) + losses_dict["generator_dur_loss"] = float(dur_loss) + losses_dict["generator_adv_loss"] = float(adv_loss) + losses_dict["generator_feat_match_loss"] = float( + feat_match_loss) + + # reset cache + if self.model.reuse_cache_gen or not self.model.training: + self.model._cache = None + + # Disctiminator + elif turn == "discriminator": + # parse outputs + speech_hat_, _, _, start_idxs, *_ = outs + speech_ = get_segments( + x=speech, + start_idxs=start_idxs * + self.model.generator.upsample_factor, + segment_size=self.model.generator.segment_size * + self.model.generator.upsample_factor, ) + + # calculate discriminator outputs + p_hat = self.model.discriminator(speech_hat_.detach()) + p = self.model.discriminator(speech_) + + # calculate losses + real_loss, fake_loss = self.criterion_dis_adv(p_hat, p) + dis_loss = real_loss + fake_loss + + report("eval/real_loss", float(real_loss)) + 
report("eval/fake_loss", float(fake_loss)) + report("eval/discriminator_loss", float(dis_loss)) + losses_dict["real_loss"] = float(real_loss) + losses_dict["fake_loss"] = float(fake_loss) + losses_dict["discriminator_loss"] = float(dis_loss) + + # reset cache + if self.model.reuse_cache_dis or not self.model.training: + self.model._cache = None + + self.msg += ', '.join('{}: {:>.6f}'.format(k, v) + for k, v in losses_dict.items()) + self.logger.info(self.msg) diff --git a/paddlespeech/t2s/training/optimizer.py b/paddlespeech/t2s/training/optimizer.py index 64274d53..3342cae5 100644 --- a/paddlespeech/t2s/training/optimizer.py +++ b/paddlespeech/t2s/training/optimizer.py @@ -14,6 +14,14 @@ import paddle from paddle import nn +scheduler_classes = dict( + ReduceOnPlateau=paddle.optimizer.lr.ReduceOnPlateau, + lambda_decay=paddle.optimizer.lr.LambdaDecay, + step_decay=paddle.optimizer.lr.StepDecay, + multistep_decay=paddle.optimizer.lr.MultiStepDecay, + exponential_decay=paddle.optimizer.lr.ExponentialDecay, + CosineAnnealingDecay=paddle.optimizer.lr.CosineAnnealingDecay, ) + optim_classes = dict( adadelta=paddle.optimizer.Adadelta, adagrad=paddle.optimizer.Adagrad, diff --git a/paddlespeech/utils/__init__.py b/paddlespeech/utils/__init__.py new file mode 100644 index 00000000..185a92b8 --- /dev/null +++ b/paddlespeech/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/utils/dynamic_import.py b/paddlespeech/utils/dynamic_import.py new file mode 100644 index 00000000..99f93356 --- /dev/null +++ b/paddlespeech/utils/dynamic_import.py @@ -0,0 +1,38 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from espnet(https://github.com/espnet/espnet) +import importlib + +__all__ = ["dynamic_import"] + + +def dynamic_import(import_path, alias=dict()): + """dynamic import module and class + + :param str import_path: syntax 'module_name:class_name' + e.g., 'paddlespeech.s2t.models.u2:U2Model' + :param dict alias: shortcut for registered class + :return: imported class + """ + if import_path not in alias and ":" not in import_path: + raise ValueError( + "import_path should be one of {} or " + 'include ":", e.g. 
"paddlespeech.s2t.models.u2:U2Model" : ' + "{}".format(set(alias), import_path)) + if ":" not in import_path: + import_path = alias[import_path] + + module_name, objname = import_path.split(":") + m = importlib.import_module(module_name) + return getattr(m, objname)