diff --git a/examples/aishell3/tts3/conf/default.yaml b/examples/aishell3/tts3/conf/default.yaml
index 90816e7d7..1dd782dbe 100644
--- a/examples/aishell3/tts3/conf/default.yaml
+++ b/examples/aishell3/tts3/conf/default.yaml
@@ -24,7 +24,7 @@ f0max: 400 # Minimum f0 for pitch extraction.
 #                       DATA SETTING                      #
 ###########################################################
 batch_size: 64
-num_workers: 4
+num_workers: 2
 
 
 ###########################################################
diff --git a/examples/aishell3/tts3/run.sh b/examples/aishell3/tts3/run.sh
index 656710763..95e4d38fe 100755
--- a/examples/aishell3/tts3/run.sh
+++ b/examples/aishell3/tts3/run.sh
@@ -7,7 +7,6 @@ gpus=0,1
 stage=0
 stop_stage=100
-
 conf_path=conf/default.yaml
 train_output_path=exp/default
 ckpt_name=snapshot_iter_482.pdz
diff --git a/examples/aishell3/vc0/local/preprocess.sh b/examples/aishell3/vc0/local/preprocess.sh
index eeb1923f1..5bf880667 100755
--- a/examples/aishell3/vc0/local/preprocess.sh
+++ b/examples/aishell3/vc0/local/preprocess.sh
@@ -9,7 +9,7 @@ alignment=$3
 ge2e_ckpt_path=$4
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    python3 ${BIN_DIR}/../../ge2e/inference.py \
+    python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \
        --input=${input}/wav \
        --output=${preprocess_path}/embed \
        --checkpoint_path=${ge2e_ckpt_path}
diff --git a/examples/aishell3/vc1/README.md b/examples/aishell3/vc1/README.md
new file mode 100644
index 000000000..8c0aec3af
--- /dev/null
+++ b/examples/aishell3/vc1/README.md
@@ -0,0 +1,89 @@
+# FastSpeech2 + AISHELL-3 Voice Cloning
+This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2006.04558) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used for the voice cloning task. We follow the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf). The general steps are as follows:
+1. Speaker Encoder: We use speaker verification to train a speaker encoder. The datasets used for this task are different from those used to train the synthesizer; since transcriptions are not needed, more datasets can be used, see [ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e).
+2. Synthesizer: Then, we use the trained speaker encoder to generate an utterance embedding for each sentence in AISHELL-3. This embedding is an extra input of FastSpeech2 which is concatenated with the encoder outputs.
+3. Vocoder: We use Parallel WaveGAN trained on AISHELL-3 as the neural vocoder; `local/synthesize.sh` and `local/voice_cloning.sh` load the `pwg_aishell3_ckpt_0.5` checkpoint.
+
+## Get Started
+Assume the path to the dataset is `~/datasets/data_aishell3`.
+Assume the path to the MFA result of AISHELL-3 is `./alignment`.
+Assume the path to the pretrained ge2e model is `ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000`.
+Run the command below to
+1. **source path**.
+2. preprocess the dataset.
+3. train the model.
+4. start a voice cloning inference.
+```bash
+./run.sh
+```
+### Preprocess the dataset
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${input} ${preprocess_path} ${alignment} ${ge2e_ckpt_path}
+```
+#### generate utterance embedding
+Use the pretrained GE2E model (speaker encoder) to generate an utterance embedding for each sentence in AISHELL-3. The embeddings have the same directory structure as the wav files and are saved in `.npy` format.
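+
+These embeddings are produced by the stage-0 snippet in the next code block. Once it has run, a single embedding can be sanity-checked with a few lines of Python (a minimal sketch; the speaker and utterance names below are only illustrative examples of the mirrored directory layout under `--output`):
+```python
+import numpy as np
+
+# embeddings mirror the wav layout: <output>/<speaker>/<utt_id>.npy
+embed = np.load("dump/embed/SSB0005/SSB00050001.npy")
+print(embed.shape)  # (256,), the GE2E speaker embedding size used in this example
+print(embed.dtype)
+```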
+
+```bash
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    python3 ${BIN_DIR}/../ge2e/inference.py \
+        --input=${input} \
+        --output=${preprocess_path}/embed \
+        --ngpu=1 \
+        --checkpoint_path=${ge2e_ckpt_path}
+fi
+```
+
+Computing the utterance embeddings can take about x hours.
+#### process wav
+There is silence at the edges of AISHELL-3's wavs, and the audio amplitude is very small, so we need to remove the silence and normalize the audio. You could remove the silence based on volume or energy, but the effect is not very good. Instead, we use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get the alignment between text and speech, then use the alignment results to remove the silence.
+
+We use Montreal Forced Aligner 1.0. The labels in AISHELL-3 include pinyin, so the lexicon we provide to MFA is pinyin rather than Chinese characters, and the prosody marks (`$` and `%`) need to be removed. You should preprocess the dataset into the format MFA needs: the text files have the same names as the wavs and the suffix `.lab`.
+
+We use [lexicon.txt](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/lexicon.txt) as the lexicon.
+
+You can download the alignment results from [alignment_aishell3.tar.gz](https://paddlespeech.bj.bcebos.com/Parakeet/alignment_aishell3.tar.gz), or train your own MFA model by following the [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) (which currently uses MFA 1.x) in our repo.
+
+```bash
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "Process wav ..."
+    python3 ${BIN_DIR}/process_wav.py \
+        --input=${input}/wav \
+        --output=${preprocess_path}/normalized_wav \
+        --alignment=${alignment}
+fi
+```
+
+#### preprocess transcription
+We convert the transcription into `phones` and `tones`. It is worth noting that our processing here is different from that used for MFA: we separate the tones. This is only one possible processing choice; you could also segment just the initials and finals.
+
+```bash
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    python3 ${BIN_DIR}/preprocess_transcription.py \
+        --input=${input} \
+        --output=${preprocess_path}
+fi
+```
+The default input is `~/datasets/data_aishell3/train`, which contains `label_train-set.txt`. The processed results are `metadata.yaml` and `metadata.pickle`; the former is a text format for easy viewing, and the latter is a binary format for direct reading.
+#### extract mel
+```bash
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    python3 ${BIN_DIR}/extract_mel.py \
+        --input=${preprocess_path}/normalized_wav \
+        --output=${preprocess_path}/mel
+fi
+```
+
+### Train the model
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${preprocess_path} ${train_output_path}
+```
+
+Our model removes the stop token prediction in Tacotron2, because the proportion of positive and negative samples for stop token prediction is extremely unbalanced and the prediction is very sensitive to how silence is clipped from the audio. Instead, we stop decoding once the peak of the attention reaches the last symbol on the encoder side.
+
+In addition, in order to accelerate the convergence of the model, we add a `guided attention loss` that pushes the alignment between encoder and decoder toward a diagonal more quickly.
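+
+For reference, below is a minimal NumPy sketch of a guided attention loss of this kind (a common formulation; the weight sharpness `g` and the exact reduction used by the trainer here are assumptions, not taken from this example's code):
+```python
+import numpy as np
+
+def guided_attention_weight(mel_len, text_len, g=0.2):
+    # W[n, t] is close to 0 near the diagonal and grows away from it, so
+    # penalizing W * A pushes the attention matrix A toward a diagonal alignment.
+    n = np.arange(mel_len).reshape(-1, 1) / mel_len     # decoder steps
+    t = np.arange(text_len).reshape(1, -1) / text_len   # encoder steps
+    return 1.0 - np.exp(-((n - t) ** 2) / (2.0 * g ** 2))
+
+def guided_attention_loss(attention, g=0.2):
+    # attention: (mel_len, text_len) alignment produced by the decoder
+    W = guided_attention_weight(*attention.shape, g=g)
+    return float(np.mean(W * attention))
+
+# a roughly diagonal alignment is penalized less than a flat one
+diagonal = np.eye(50)[:, :40]
+flat = np.full((50, 40), 1.0 / 40)
+print(guided_attention_loss(diagonal), guided_attention_loss(flat))
+```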
+### Inference
+```bash
+CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir}
+```
+## Pretrained Model
+[tacotron2_aishell3_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_aishell3_ckpt_0.3.zip).
diff --git a/examples/aishell3/vc1/conf/default.yaml b/examples/aishell3/vc1/conf/default.yaml
new file mode 100644
index 000000000..bdd2a765e
--- /dev/null
+++ b/examples/aishell3/vc1/conf/default.yaml
@@ -0,0 +1,105 @@
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+
+fs: 24000          # Sampling rate.
+n_fft: 2048        # FFT size.
+n_shift: 300       # Hop size.
+win_length: 1200   # Window length.
+                   # If set to null, it will be the same as fft_size.
+window: "hann"     # Window function.
+
+# Only used for feats_type != raw
+
+fmin: 80           # Minimum frequency of Mel basis.
+fmax: 7600         # Maximum frequency of Mel basis.
+n_mels: 80         # The number of mel basis.
+
+# Only used for the model using pitch features (e.g. FastSpeech2)
+f0min: 80          # Minimum f0 for pitch extraction.
+f0max: 400         # Maximum f0 for pitch extraction.
+
+
+###########################################################
+#                       DATA SETTING                      #
+###########################################################
+batch_size: 64
+num_workers: 2
+
+
+###########################################################
+#                       MODEL SETTING                     #
+###########################################################
+model:
+    adim: 384                                  # attention dimension
+    aheads: 2                                  # number of attention heads
+    elayers: 4                                 # number of encoder layers
+    eunits: 1536                               # number of encoder ff units
+    dlayers: 4                                 # number of decoder layers
+    dunits: 1536                               # number of decoder ff units
+    positionwise_layer_type: conv1d            # type of position-wise layer
+    positionwise_conv_kernel_size: 3           # kernel size of position wise conv layer
+    duration_predictor_layers: 2               # number of layers of duration predictor
+    duration_predictor_chans: 256              # number of channels of duration predictor
+    duration_predictor_kernel_size: 3          # filter size of duration predictor
+    postnet_layers: 5                          # number of layers of postnet
+    postnet_filts: 5                           # filter size of conv layers in postnet
+    postnet_chans: 256                         # number of channels of conv layers in postnet
+    use_masking: True                          # whether to apply masking for padded part in loss calculation
+    use_scaled_pos_enc: True                   # whether to use scaled positional encoding
+    encoder_normalize_before: True             # whether to perform layer normalization before the input
+    decoder_normalize_before: True             # whether to perform layer normalization before the input
+    reduction_factor: 1                        # reduction factor
+    init_type: xavier_uniform                  # initialization type
+    init_enc_alpha: 1.0                        # initial value of alpha of encoder scaled position encoding
+    init_dec_alpha: 1.0                        # initial value of alpha of decoder scaled position encoding
+    transformer_enc_dropout_rate: 0.2              # dropout rate for transformer encoder layer
+    transformer_enc_positional_dropout_rate: 0.2   # dropout rate for transformer encoder positional encoding
+    transformer_enc_attn_dropout_rate: 0.2         # dropout rate for transformer encoder attention layer
+    transformer_dec_dropout_rate: 0.2              # dropout rate for transformer decoder layer
+    transformer_dec_positional_dropout_rate: 0.2   # dropout rate for transformer decoder positional encoding
+    transformer_dec_attn_dropout_rate: 0.2         # dropout rate for transformer decoder attention layer
+    pitch_predictor_layers: 5                  # number of conv layers in pitch predictor
+    pitch_predictor_chans: 256                 # number of channels of conv layers in pitch predictor
+    pitch_predictor_kernel_size: 5             # kernel size of conv layers in pitch predictor
+    pitch_predictor_dropout: 0.5               # dropout rate in pitch predictor
+    pitch_embed_kernel_size: 1                 # kernel size of conv embedding layer for pitch
+    pitch_embed_dropout: 0.0                   # dropout rate after conv embedding layer for pitch
+    stop_gradient_from_pitch_predictor: true   # whether to stop the gradient from pitch predictor to encoder
+    energy_predictor_layers: 2                 # number of conv layers in energy predictor
+    energy_predictor_chans: 256                # number of channels of conv layers in energy predictor
+    energy_predictor_kernel_size: 3            # kernel size of conv layers in energy predictor
+    energy_predictor_dropout: 0.5              # dropout rate in energy predictor
+    energy_embed_kernel_size: 1                # kernel size of conv embedding layer for energy
+    energy_embed_dropout: 0.0                  # dropout rate after conv embedding layer for energy
+    stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder
+    spk_embed_dim: 256                         # speaker embedding dimension
+    spk_embed_integration_type: concat         # speaker embedding integration type
+
+
+
+###########################################################
+#                      UPDATER SETTING                    #
+###########################################################
+updater:
+    use_masking: True     # whether to apply masking for padded part in loss calculation
+
+
+###########################################################
+#                     OPTIMIZER SETTING                   #
+###########################################################
+optimizer:
+    optim: adam              # optimizer type
+    learning_rate: 0.001     # learning rate
+
+###########################################################
+#                     TRAINING SETTING                    #
+###########################################################
+max_epoch: 200
+num_snapshots: 5
+
+
+###########################################################
+#                       OTHER SETTING                     #
+###########################################################
+seed: 10086
diff --git a/examples/aishell3/vc1/local/preprocess.sh b/examples/aishell3/vc1/local/preprocess.sh
new file mode 100755
index 000000000..5f939a1a8
--- /dev/null
+++ b/examples/aishell3/vc1/local/preprocess.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+
+stage=0
+stop_stage=100
+
+config_path=$1
+ge2e_ckpt_path=$2
+
+# gen speaker embedding
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \
+        --input=~/datasets/data_aishell3/train/wav/ \
+        --output=dump/embed \
+        --checkpoint_path=${ge2e_ckpt_path}
+fi
+
+# copy from tts3/preprocess
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # get durations from MFA's result
+    echo "Generate durations.txt from MFA results ..."
+    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+        --inputdir=./aishell3_alignment_tone \
+        --output durations.txt \
+        --config=${config_path}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/preprocess.py \
+        --dataset=aishell3 \
+        --rootdir=~/datasets/data_aishell3/ \
+        --dumpdir=dump \
+        --dur-file=durations.txt \
+        --config=${config_path} \
+        --num-cpu=20 \
+        --cut-sil=True \
+        --spk_emb_dir=dump/embed
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # get features' stats (mean and std)
+    echo "Get features' stats ..."
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="speech"
+
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="pitch"
+
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="energy"
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    # normalize and convert phone/speaker to id; dev and test should use train's stats
+    echo "Normalize ..."
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --pitch-stats=dump/train/pitch_stats.npy \
+        --energy-stats=dump/train/energy_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --pitch-stats=dump/train/pitch_stats.npy \
+        --energy-stats=dump/train/energy_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --pitch-stats=dump/train/pitch_stats.npy \
+        --energy-stats=dump/train/energy_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+fi
diff --git a/examples/aishell3/vc1/local/synthesize.sh b/examples/aishell3/vc1/local/synthesize.sh
new file mode 100755
index 000000000..35478c784
--- /dev/null
+++ b/examples/aishell3/vc1/local/synthesize.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+FLAGS_allocator_strategy=naive_best_fit \
+FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+python3 ${BIN_DIR}/synthesize.py \
+    --fastspeech2-config=${config_path} \
+    --fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+    --fastspeech2-stat=dump/train/speech_stats.npy \
+    --pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \
+    --pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+    --pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+    --test-metadata=dump/test/norm/metadata.jsonl \
+    --output-dir=${train_output_path}/test \
+    --phones-dict=dump/phone_id_map.txt \
+    --voice-cloning=True
diff --git a/examples/aishell3/vc1/local/train.sh b/examples/aishell3/vc1/local/train.sh
new file mode 100755
index 000000000..c775fcadc
--- /dev/null
+++ b/examples/aishell3/vc1/local/train.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+python3 ${BIN_DIR}/train.py \
+    --train-metadata=dump/train/norm/metadata.jsonl \
+    --dev-metadata=dump/dev/norm/metadata.jsonl \
+    --config=${config_path} \
+    --output-dir=${train_output_path} \
+    --ngpu=2 \
+    --phones-dict=dump/phone_id_map.txt \
+    --voice-cloning=True
\ No newline at end of file
diff --git a/examples/aishell3/vc1/local/voice_cloning.sh b/examples/aishell3/vc1/local/voice_cloning.sh
new file mode 100755
index 000000000..55bdd761e
--- /dev/null
+++ b/examples/aishell3/vc1/local/voice_cloning.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+ge2e_params_path=$4
+ref_audio_dir=$5
+
+FLAGS_allocator_strategy=naive_best_fit \
+FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+python3 ${BIN_DIR}/voice_cloning.py \
+    --fastspeech2-config=${config_path} \
+    --fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
+    --fastspeech2-stat=dump/train/speech_stats.npy \
+    --pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \
+    --pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+    --pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
+    --ge2e_params_path=${ge2e_params_path} \
+    --text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \
+    --input-dir=${ref_audio_dir} \
+    --output-dir=${train_output_path}/vc_syn \
+    --phones-dict=dump/phone_id_map.txt
diff --git a/examples/aishell3/vc1/path.sh b/examples/aishell3/vc1/path.sh
new file mode 100755
index 000000000..fb7e8411c
--- /dev/null
+++ b/examples/aishell3/vc1/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=fastspeech2
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
diff --git a/examples/aishell3/vc1/run.sh b/examples/aishell3/vc1/run.sh
new file mode 100755
index 000000000..4eae1bdd8
--- /dev/null
+++ b/examples/aishell3/vc1/run.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_482.pdz
+ref_audio_dir=ref_audio
+
+# do not include ".pdparams" here
+ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000
+
+# include ".pdparams" here
+ge2e_params_path=${ge2e_ckpt_path}.pdparams
+
+# with the following command, you can choose the stage range you want to run,
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with positional arguments `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} ${ge2e_ckpt_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train model, all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # synthesize, vocoder is pwgan
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # voice cloning, vocoder is pwgan
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir} || exit -1
+fi
diff --git a/examples/csmsc/voc1/conf/default.yaml b/examples/csmsc/voc1/conf/default.yaml
index 5628b7f7c..1363b454f 100644
--- a/examples/csmsc/voc1/conf/default.yaml
+++ b/examples/csmsc/voc1/conf/default.yaml
@@ -80,7 +80,7 @@ lambda_adv: 4.0  # Loss balancing coefficient.
 batch_size: 8              # Batch size.
 batch_max_steps: 25500     # Length of each audio in batch. Make sure dividable by hop_size.
 pin_memory: true           # Whether to pin memory in Pytorch DataLoader.
-num_workers: 4             # Number of workers in Pytorch DataLoader.
+num_workers: 2             # Number of workers in Pytorch DataLoader.
 remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.
 allow_cache: true          # Whether to allow cache in dataset. If true, it requires cpu memory.
diff --git a/examples/other/ge2e/path.sh b/examples/other/ge2e/path.sh index b4f779859..24305ef78 100755 --- a/examples/other/ge2e/path.sh +++ b/examples/other/ge2e/path.sh @@ -10,4 +10,4 @@ export PYTHONIOENCODING=UTF-8 export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} MODEL=ge2e -export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} +export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL} diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py index 5ed9aa7af..9470f9234 100644 --- a/paddlespeech/t2s/datasets/am_batch_fn.py +++ b/paddlespeech/t2s/datasets/am_batch_fn.py @@ -100,7 +100,7 @@ def fastspeech2_single_spk_batch_fn(examples): def fastspeech2_multi_spk_batch_fn(examples): - # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"] + # fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spk_emb"] text = [np.array(item["text"], dtype=np.int64) for item in examples] speech = [np.array(item["speech"], dtype=np.float32) for item in examples] pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples] @@ -114,7 +114,6 @@ def fastspeech2_multi_spk_batch_fn(examples): speech_lengths = [ np.array(item["speech_lengths"], dtype=np.int64) for item in examples ] - spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples] text = batch_sequences(text) pitch = batch_sequences(pitch) @@ -130,7 +129,6 @@ def fastspeech2_multi_spk_batch_fn(examples): energy = paddle.to_tensor(energy) text_lengths = paddle.to_tensor(text_lengths) speech_lengths = paddle.to_tensor(speech_lengths) - spk_id = paddle.to_tensor(spk_id) batch = { "text": text, @@ -139,9 +137,20 @@ def fastspeech2_multi_spk_batch_fn(examples): "speech": speech, "speech_lengths": speech_lengths, "pitch": pitch, - "energy": energy, - "spk_id": spk_id + "energy": energy } + # spk_emb has a higher priority than spk_id + if "spk_emb" in examples[0]: + spk_emb = [ + np.array(item["spk_emb"], dtype=np.float32) for item in examples + ] + spk_emb = batch_sequences(spk_emb) + spk_emb = paddle.to_tensor(spk_emb) + batch["spk_emb"] = spk_emb + elif "spk_id" in examples[0]: + spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples] + spk_id = paddle.to_tensor(spk_id) + batch["spk_id"] = spk_id return batch diff --git a/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e.py b/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e.py index ee9fe0579..1839415e9 100644 --- a/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e.py +++ b/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e.py @@ -46,14 +46,14 @@ def evaluate(args, fastspeech2_config, pwg_config): print("vocab_size:", vocab_size) with open(args.speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] - num_speakers = len(spk_id) - print("num_speakers:", num_speakers) + spk_num = len(spk_id) + print("spk_num:", spk_num) odim = fastspeech2_config.n_mels model = FastSpeech2( idim=vocab_size, odim=odim, - num_speakers=num_speakers, + spk_num=spk_num, **fastspeech2_config["model"]) model.set_state_dict( diff --git a/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py b/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py index b5d0ce171..095d20821 100644 --- a/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py +++ b/paddlespeech/t2s/exps/fastspeech2/multi_spk_synthesize_e2e_en.py @@ -51,14 +51,14 @@ def evaluate(args, 
fastspeech2_config, pwg_config): print("vocab_size:", vocab_size) with open(args.speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] - num_speakers = len(spk_id) - print("num_speakers:", num_speakers) + spk_num = len(spk_id) + print("spk_num:", spk_num) odim = fastspeech2_config.n_mels model = FastSpeech2( idim=vocab_size, odim=odim, - num_speakers=num_speakers, + spk_num=spk_num, **fastspeech2_config["model"]) model.set_state_dict( diff --git a/paddlespeech/t2s/exps/fastspeech2/normalize.py b/paddlespeech/t2s/exps/fastspeech2/normalize.py index 7283f6b43..8ec20ebf0 100644 --- a/paddlespeech/t2s/exps/fastspeech2/normalize.py +++ b/paddlespeech/t2s/exps/fastspeech2/normalize.py @@ -167,6 +167,10 @@ def main(): "pitch": str(pitch_path), "energy": str(energy_path) } + # add spk_emb for voice cloning + if "spk_emb" in item: + record["spk_emb"] = str(item["spk_emb"]) + output_metadata.append(record) output_metadata.sort(key=itemgetter('utt_id')) output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" diff --git a/paddlespeech/t2s/exps/fastspeech2/preprocess.py b/paddlespeech/t2s/exps/fastspeech2/preprocess.py index 3702ecd31..b874b3a70 100644 --- a/paddlespeech/t2s/exps/fastspeech2/preprocess.py +++ b/paddlespeech/t2s/exps/fastspeech2/preprocess.py @@ -44,7 +44,8 @@ def process_sentence(config: Dict[str, Any], mel_extractor=None, pitch_extractor=None, energy_extractor=None, - cut_sil: bool=True): + cut_sil: bool=True, + spk_emb_dir: Path=None): utt_id = fp.stem # for vctk if utt_id.endswith("_mic2"): @@ -116,6 +117,14 @@ def process_sentence(config: Dict[str, Any], "energy": str(energy_path), "speaker": speaker } + if spk_emb_dir: + if speaker in os.listdir(spk_emb_dir): + embed_name = utt_id + ".npy" + embed_path = spk_emb_dir / speaker / embed_name + if embed_path.is_file(): + record["spk_emb"] = str(embed_path) + else: + return None return record @@ -127,13 +136,14 @@ def process_sentences(config, pitch_extractor=None, energy_extractor=None, nprocs: int=1, - cut_sil: bool=True): + cut_sil: bool=True, + spk_emb_dir: Path=None): if nprocs == 1: results = [] for fp in fps: record = process_sentence(config, fp, sentences, output_dir, mel_extractor, pitch_extractor, - energy_extractor, cut_sil) + energy_extractor, cut_sil, spk_emb_dir) if record: results.append(record) else: @@ -144,7 +154,7 @@ def process_sentences(config, future = pool.submit(process_sentence, config, fp, sentences, output_dir, mel_extractor, pitch_extractor, energy_extractor, - cut_sil) + cut_sil, spk_emb_dir) future.add_done_callback(lambda p: progress.update()) futures.append(future) @@ -202,6 +212,11 @@ def main(): default=True, help="whether cut sil in the edge of audio") + parser.add_argument( + "--spk_emb_dir", + default=None, + type=str, + help="directory to speaker embedding files.") args = parser.parse_args() rootdir = Path(args.rootdir).expanduser() @@ -211,6 +226,11 @@ def main(): dumpdir.mkdir(parents=True, exist_ok=True) dur_file = Path(args.dur_file).expanduser() + if args.spk_emb_dir: + spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve() + else: + spk_emb_dir = None + assert rootdir.is_dir() assert dur_file.is_file() @@ -251,6 +271,7 @@ def main(): test_wav_files += wav_files[-sub_num_dev:] else: train_wav_files += wav_files + elif args.dataset == "ljspeech": wav_files = sorted(list((rootdir / "wavs").rglob("*.wav"))) # split data into 3 sections @@ -317,7 +338,8 @@ def main(): pitch_extractor, energy_extractor, nprocs=args.num_cpu, - cut_sil=args.cut_sil) + 
cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) if dev_wav_files: process_sentences( config, @@ -327,7 +349,8 @@ def main(): mel_extractor, pitch_extractor, energy_extractor, - cut_sil=args.cut_sil) + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) if test_wav_files: process_sentences( config, @@ -338,7 +361,8 @@ def main(): pitch_extractor, energy_extractor, nprocs=args.num_cpu, - cut_sil=args.cut_sil) + cut_sil=args.cut_sil, + spk_emb_dir=spk_emb_dir) if __name__ == "__main__": diff --git a/paddlespeech/t2s/exps/fastspeech2/synthesize.py b/paddlespeech/t2s/exps/fastspeech2/synthesize.py index 207275f90..249845e4d 100644 --- a/paddlespeech/t2s/exps/fastspeech2/synthesize.py +++ b/paddlespeech/t2s/exps/fastspeech2/synthesize.py @@ -40,16 +40,19 @@ def evaluate(args, fastspeech2_config, pwg_config): fields = ["utt_id", "text"] + spk_num = None if args.speaker_dict is not None: print("multiple speaker fastspeech2!") with open(args.speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] - num_speakers = len(spk_id) + spk_num = len(spk_id) fields += ["spk_id"] + elif args.voice_cloning: + print("voice cloning!") + fields += ["spk_emb"] else: print("single speaker fastspeech2!") - num_speakers = None - print("num_speakers:", num_speakers) + print("spk_num:", spk_num) test_dataset = DataTable(data=test_metadata, fields=fields) @@ -62,7 +65,7 @@ def evaluate(args, fastspeech2_config, pwg_config): model = FastSpeech2( idim=vocab_size, odim=odim, - num_speakers=num_speakers, + spk_num=spk_num, **fastspeech2_config["model"]) model.set_state_dict( @@ -96,12 +99,15 @@ def evaluate(args, fastspeech2_config, pwg_config): for datum in test_dataset: utt_id = datum["utt_id"] text = paddle.to_tensor(datum["text"]) - if "spk_id" in datum: + spk_emb = None + spk_id = None + if args.voice_cloning and "spk_emb" in datum: + spk_emb = paddle.to_tensor(np.load(datum["spk_emb"])) + elif "spk_id" in datum: spk_id = paddle.to_tensor(datum["spk_id"]) - else: - spk_id = None with paddle.no_grad(): - wav = pwg_inference(fastspeech2_inference(text, spk_id=spk_id)) + wav = pwg_inference( + fastspeech2_inference(text, spk_id=spk_id, spk_emb=spk_emb)) sf.write( str(output_dir / (utt_id + ".wav")), wav.numpy(), @@ -142,6 +148,15 @@ def main(): type=str, default=None, help="speaker id map file for multiple speaker model.") + + def str2bool(str): + return True if str.lower() == 'true' else False + + parser.add_argument( + "--voice-cloning", + type=str2bool, + default=False, + help="whether training voice cloning model.") parser.add_argument("--test-metadata", type=str, help="test metadata.") parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument( diff --git a/paddlespeech/t2s/exps/fastspeech2/train.py b/paddlespeech/t2s/exps/fastspeech2/train.py index 38ac2fe3f..fafded6fc 100644 --- a/paddlespeech/t2s/exps/fastspeech2/train.py +++ b/paddlespeech/t2s/exps/fastspeech2/train.py @@ -61,18 +61,24 @@ def train_sp(args, config): "text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy" ] + converters = {"speech": np.load, "pitch": np.load, "energy": np.load} + spk_num = None if args.speaker_dict is not None: print("multiple speaker fastspeech2!") collate_fn = fastspeech2_multi_spk_batch_fn with open(args.speaker_dict, 'rt') as f: spk_id = [line.strip().split() for line in f.readlines()] - num_speakers = len(spk_id) + spk_num = len(spk_id) fields += ["spk_id"] + elif args.voice_cloning: + print("Training voice cloning!") + collate_fn = 
fastspeech2_multi_spk_batch_fn + fields += ["spk_emb"] + converters["spk_emb"] = np.load else: print("single speaker fastspeech2!") collate_fn = fastspeech2_single_spk_batch_fn - num_speakers = None - print("num_speakers:", num_speakers) + print("spk_num:", spk_num) # dataloader has been too verbose logging.getLogger("DataLoader").disabled = True @@ -83,17 +89,13 @@ def train_sp(args, config): train_dataset = DataTable( data=train_metadata, fields=fields, - converters={"speech": np.load, - "pitch": np.load, - "energy": np.load}, ) + converters=converters, ) with jsonlines.open(args.dev_metadata, 'r') as reader: dev_metadata = list(reader) dev_dataset = DataTable( data=dev_metadata, fields=fields, - converters={"speech": np.load, - "pitch": np.load, - "energy": np.load}, ) + converters=converters, ) # collate function and dataloader @@ -127,10 +129,7 @@ def train_sp(args, config): odim = config.n_mels model = FastSpeech2( - idim=vocab_size, - odim=odim, - num_speakers=num_speakers, - **config["model"]) + idim=vocab_size, odim=odim, spk_num=spk_num, **config["model"]) if world_size > 1: model = DataParallel(model) print("model done!") @@ -184,6 +183,15 @@ def main(): default=None, help="speaker id map file for multiple speaker model.") + def str2bool(str): + return True if str.lower() == 'true' else False + + parser.add_argument( + "--voice-cloning", + type=str2bool, + default=False, + help="whether training voice cloning model.") + args = parser.parse_args() with open(args.config) as f: diff --git a/paddlespeech/t2s/exps/fastspeech2/voice_cloning.py b/paddlespeech/t2s/exps/fastspeech2/voice_cloning.py new file mode 100644 index 000000000..9fbd49641 --- /dev/null +++ b/paddlespeech/t2s/exps/fastspeech2/voice_cloning.py @@ -0,0 +1,208 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
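+# This script performs voice cloning with the FastSpeech2 + GE2E setup from this
+# example. For every reference wav in --input-dir it:
+#   1. extracts mel partials with SpeakerVerificationPreprocessor,
+#   2. embeds them with the pretrained GE2E LSTMSpeakerEncoder to get a speaker embedding,
+#   3. runs --text through FastSpeech2Inference conditioned on that embedding,
+#   4. vocodes the mel with Parallel WaveGAN and writes <utt_id>.wav to --output-dir.
+# Finally, random_spk_emb.wav is synthesized from a random embedding for comparison.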
+import argparse +import os +from pathlib import Path + +import numpy as np +import paddle +import soundfile as sf +import yaml +from yacs.config import CfgNode + +from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2 +from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference +from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator +from paddlespeech.t2s.models.parallel_wavegan import PWGInference +from paddlespeech.t2s.modules.normalizer import ZScore +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder + + +def voice_cloning(args, fastspeech2_config, pwg_config): + # speaker encoder + p = SpeakerVerificationPreprocessor( + sampling_rate=16000, + audio_norm_target_dBFS=-30, + vad_window_length=30, + vad_moving_average_width=8, + vad_max_silence_length=6, + mel_window_length=25, + mel_window_step=10, + n_mels=40, + partial_n_frames=160, + min_pad_coverage=0.75, + partial_overlap_ratio=0.5) + print("Audio Processor Done!") + + speaker_encoder = LSTMSpeakerEncoder( + n_mels=40, num_layers=3, hidden_size=256, output_size=256) + speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path)) + speaker_encoder.eval() + print("GE2E Done!") + + with open(args.phones_dict, "r") as f: + phn_id = [line.strip().split() for line in f.readlines()] + vocab_size = len(phn_id) + print("vocab_size:", vocab_size) + odim = fastspeech2_config.n_mels + model = FastSpeech2( + idim=vocab_size, odim=odim, **fastspeech2_config["model"]) + + model.set_state_dict( + paddle.load(args.fastspeech2_checkpoint)["main_params"]) + model.eval() + + vocoder = PWGGenerator(**pwg_config["generator_params"]) + vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"]) + vocoder.remove_weight_norm() + vocoder.eval() + print("model done!") + + frontend = Frontend(phone_vocab_path=args.phones_dict) + print("frontend done!") + + stat = np.load(args.fastspeech2_stat) + mu, std = stat + mu = paddle.to_tensor(mu) + std = paddle.to_tensor(std) + fastspeech2_normalizer = ZScore(mu, std) + + stat = np.load(args.pwg_stat) + mu, std = stat + mu = paddle.to_tensor(mu) + std = paddle.to_tensor(std) + pwg_normalizer = ZScore(mu, std) + + fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model) + fastspeech2_inference.eval() + pwg_inference = PWGInference(pwg_normalizer, vocoder) + pwg_inference.eval() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + input_dir = Path(args.input_dir) + + sentence = args.text + + input_ids = frontend.get_input_ids(sentence, merge_sentences=True) + phone_ids = input_ids["phone_ids"][0] + + for name in os.listdir(input_dir): + utt_id = name.split(".")[0] + ref_audio_path = input_dir / name + mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path)) + # print("mel_sequences: ", mel_sequences.shape) + with paddle.no_grad(): + spk_emb = speaker_encoder.embed_utterance( + paddle.to_tensor(mel_sequences)) + # print("spk_emb shape: ", spk_emb.shape) + + with paddle.no_grad(): + wav = pwg_inference( + fastspeech2_inference(phone_ids, spk_emb=spk_emb)) + + sf.write( + str(output_dir / (utt_id + ".wav")), + wav.numpy(), + samplerate=fastspeech2_config.fs) + print(f"{utt_id} done!") + # Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb + random_spk_emb = np.random.rand(256) * 0.2 + random_spk_emb = 
paddle.to_tensor(random_spk_emb, dtype='float32')  # cast to float32 to match the model's parameters
+    utt_id = "random_spk_emb"
+    with paddle.no_grad():
+        wav = pwg_inference(
+            fastspeech2_inference(phone_ids, spk_emb=random_spk_emb))
+    sf.write(
+        str(output_dir / (utt_id + ".wav")),
+        wav.numpy(),
+        samplerate=fastspeech2_config.fs)
+    print(f"{utt_id} done!")
+
+
+def main():
+    # parse args and config
+    parser = argparse.ArgumentParser(description="")
+    parser.add_argument(
+        "--fastspeech2-config", type=str, help="fastspeech2 config file.")
+    parser.add_argument(
+        "--fastspeech2-checkpoint",
+        type=str,
+        help="fastspeech2 checkpoint to load.")
+    parser.add_argument(
+        "--fastspeech2-stat",
+        type=str,
+        help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
+    )
+    parser.add_argument(
+        "--pwg-config", type=str, help="parallel wavegan config file.")
+    parser.add_argument(
+        "--pwg-checkpoint",
+        type=str,
+        help="parallel wavegan generator parameters to load.")
+    parser.add_argument(
+        "--pwg-stat",
+        type=str,
+        help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
+    )
+    parser.add_argument(
+        "--phones-dict",
+        type=str,
+        default="phone_id_map.txt",
+        help="phone vocabulary file.")
+    parser.add_argument(
+        "--text",
+        type=str,
+        default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。",
+        help="text to synthesize, a line")
+
+    parser.add_argument(
+        "--ge2e_params_path", type=str, help="ge2e params path.")
+
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
+
+    parser.add_argument(
+        "--input-dir",
+        type=str,
+        help="input dir of *.wav, the sample rate will be resampled to 16k.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+
+    args = parser.parse_args()
+
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should be >= 0 !")
+
+    with open(args.fastspeech2_config) as f:
+        fastspeech2_config = CfgNode(yaml.safe_load(f))
+    with open(args.pwg_config) as f:
+        pwg_config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(fastspeech2_config)
+    print(pwg_config)
+
+    voice_cloning(args, fastspeech2_config, pwg_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/voice_cloning.py
index 2f005e723..4e6b8d362 100644
--- a/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/voice_cloning.py
+++ b/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/voice_cloning.py
@@ -20,14 +20,14 @@ import paddle
 import soundfile as sf
 from matplotlib import pyplot as plt
 
-from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
 from paddlespeech.t2s.exps.voice_cloning.tacotron2_ge2e.aishell3 import voc_phones
 from paddlespeech.t2s.exps.voice_cloning.tacotron2_ge2e.aishell3 import voc_tones
 from paddlespeech.t2s.exps.voice_cloning.tacotron2_ge2e.chinese_g2p import convert_sentence
-from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder
 from paddlespeech.t2s.models.tacotron2 import Tacotron2
 from paddlespeech.t2s.models.waveflow import ConditionalWaveFlow
 from paddlespeech.t2s.utils import display
+from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
+from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
 
 
 def voice_cloning(args):
diff --git 
a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index e650e2550..cf957978e 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -98,7 +98,7 @@ class FastSpeech2(nn.Layer): pitch_embed_dropout: float=0.5, stop_gradient_from_pitch_predictor: bool=False, # spk emb - num_speakers: int=None, + spk_num: int=None, spk_embed_dim: int=None, spk_embed_integration_type: str="add", # tone emb @@ -148,9 +148,9 @@ class FastSpeech2(nn.Layer): # initialize parameters initialize(self, init_type) - if self.spk_embed_dim is not None: + if spk_num and self.spk_embed_dim: self.spk_embedding_table = nn.Embedding( - num_embeddings=num_speakers, + num_embeddings=spk_num, embedding_dim=self.spk_embed_dim, padding_idx=self.padding_idx) @@ -299,7 +299,7 @@ class FastSpeech2(nn.Layer): pitch: paddle.Tensor, energy: paddle.Tensor, tone_id: paddle.Tensor=None, - spembs: paddle.Tensor=None, + spk_emb: paddle.Tensor=None, spk_id: paddle.Tensor=None ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]: """Calculate forward propagation. @@ -322,7 +322,7 @@ class FastSpeech2(nn.Layer): Batch of padded token-averaged energy (B, Tmax, 1). tone_id : Tensor, optional(int64) Batch of padded tone ids (B, Tmax). - spembs : Tensor, optional + spk_emb : Tensor, optional Batch of speaker embeddings (B, spk_embed_dim). spk_id : Tnesor, optional(int64) Batch of speaker ids (B,) @@ -366,7 +366,7 @@ class FastSpeech2(nn.Layer): ps, es, is_inference=False, - spembs=spembs, + spk_emb=spk_emb, spk_id=spk_id, tone_id=tone_id) # modify mod part of groundtruth @@ -387,7 +387,7 @@ class FastSpeech2(nn.Layer): es: paddle.Tensor=None, is_inference: bool=False, alpha: float=1.0, - spembs=None, + spk_emb=None, spk_id=None, tone_id=None) -> Sequence[paddle.Tensor]: # forward encoder @@ -397,11 +397,12 @@ class FastSpeech2(nn.Layer): # integrate speaker embedding if self.spk_embed_dim is not None: - if spembs is not None: - hs = self._integrate_with_spk_embed(hs, spembs) + # spk_emb has a higher priority than spk_id + if spk_emb is not None: + hs = self._integrate_with_spk_embed(hs, spk_emb) elif spk_id is not None: - spembs = self.spk_embedding_table(spk_id) - hs = self._integrate_with_spk_embed(hs, spembs) + spk_emb = self.spk_embedding_table(spk_id) + hs = self._integrate_with_spk_embed(hs, spk_emb) # integrate tone embedding if self.tone_embed_dim is not None: @@ -489,7 +490,7 @@ class FastSpeech2(nn.Layer): energy: paddle.Tensor=None, alpha: float=1.0, use_teacher_forcing: bool=False, - spembs=None, + spk_emb=None, spk_id=None, tone_id=None, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: @@ -512,7 +513,7 @@ class FastSpeech2(nn.Layer): use_teacher_forcing : bool, optional Whether to use teacher forcing. If true, groundtruth of duration, pitch and energy will be used. - spembs : Tensor, optional + spk_emb : Tensor, optional peaker embedding vector (spk_embed_dim,). spk_id : Tensor, optional(int64) Batch of padded spk ids (1,). 
@@ -527,7 +528,6 @@ class FastSpeech2(nn.Layer): # input of embedding must be int64 x = paddle.cast(text, 'int64') y = speech - spemb = spembs d, p, e = durations, pitch, energy # setup batch axis ilens = paddle.shape(x)[0] @@ -537,8 +537,8 @@ class FastSpeech2(nn.Layer): if y is not None: ys = y.unsqueeze(0) - if spemb is not None: - spembs = spemb.unsqueeze(0) + if spk_emb is not None: + spk_emb = spk_emb.unsqueeze(0) if tone_id is not None: tone_id = tone_id.unsqueeze(0) @@ -548,7 +548,7 @@ class FastSpeech2(nn.Layer): ds = d.unsqueeze(0) if d is not None else None ps = p.unsqueeze(0) if p is not None else None es = e.unsqueeze(0) if e is not None else None - # ds, ps, es = , p.unsqueeze(0), e.unsqueeze(0) + # (1, L, odim) _, outs, d_outs, p_outs, e_outs = self._forward( xs, @@ -557,7 +557,7 @@ class FastSpeech2(nn.Layer): ds=ds, ps=ps, es=es, - spembs=spembs, + spk_emb=spk_emb, spk_id=spk_id, tone_id=tone_id, is_inference=True) @@ -569,19 +569,19 @@ class FastSpeech2(nn.Layer): ys, is_inference=True, alpha=alpha, - spembs=spembs, + spk_emb=spk_emb, spk_id=spk_id, tone_id=tone_id) return outs[0], d_outs[0], p_outs[0], e_outs[0] - def _integrate_with_spk_embed(self, hs, spembs): + def _integrate_with_spk_embed(self, hs, spk_emb): """Integrate speaker embedding with hidden states. Parameters ---------- hs : Tensor Batch of hidden state sequences (B, Tmax, adim). - spembs : Tensor + spk_emb : Tensor Batch of speaker embeddings (B, spk_embed_dim). Returns @@ -591,13 +591,13 @@ class FastSpeech2(nn.Layer): """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states - spembs = self.spk_projection(F.normalize(spembs)) - hs = hs + spembs.unsqueeze(1) + spk_emb = self.spk_projection(F.normalize(spk_emb)) + hs = hs + spk_emb.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection - spembs = F.normalize(spembs).unsqueeze(1).expand( + spk_emb = F.normalize(spk_emb).unsqueeze(1).expand( shape=[-1, hs.shape[1], -1]) - hs = self.spk_projection(paddle.concat([hs, spembs], axis=-1)) + hs = self.spk_projection(paddle.concat([hs, spk_emb], axis=-1)) else: raise NotImplementedError("support only add or concat.") @@ -682,9 +682,9 @@ class FastSpeech2Inference(nn.Layer): self.normalizer = normalizer self.acoustic_model = model - def forward(self, text, spk_id=None): + def forward(self, text, spk_id=None, spk_emb=None): normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference( - text, spk_id=spk_id) + text, spk_id=spk_id, spk_emb=spk_emb) logmel = self.normalizer.inverse(normalized_mel) return logmel diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py index 4297c8b61..0dabf934c 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py +++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2_updater.py @@ -54,6 +54,10 @@ class FastSpeech2Updater(StandardUpdater): losses_dict = {} # spk_id!=None in multiple spk fastspeech2 spk_id = batch["spk_id"] if "spk_id" in batch else None + spk_emb = batch["spk_emb"] if "spk_emb" in batch else None + # No explicit speaker identifier labels are used during voice cloning training. 
+ if spk_emb is not None: + spk_id = None before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( text=batch["text"], @@ -63,7 +67,8 @@ class FastSpeech2Updater(StandardUpdater): durations=batch["durations"], pitch=batch["pitch"], energy=batch["energy"], - spk_id=spk_id) + spk_id=spk_id, + spk_emb=spk_emb) l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( after_outs=after_outs, @@ -126,6 +131,9 @@ class FastSpeech2Evaluator(StandardEvaluator): losses_dict = {} # spk_id!=None in multiple spk fastspeech2 spk_id = batch["spk_id"] if "spk_id" in batch else None + spk_emb = batch["spk_emb"] if "spk_emb" in batch else None + if spk_emb is not None: + spk_id = None before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model( text=batch["text"], @@ -135,7 +143,8 @@ class FastSpeech2Evaluator(StandardEvaluator): durations=batch["durations"], pitch=batch["pitch"], energy=batch["energy"], - spk_id=spk_id) + spk_id=spk_id, + spk_emb=spk_emb) l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( after_outs=after_outs, diff --git a/paddlespeech/t2s/models/transformer_tts/transformer_tts.py b/paddlespeech/t2s/models/transformer_tts/transformer_tts.py index 97233c766..5958a1660 100644 --- a/paddlespeech/t2s/models/transformer_tts/transformer_tts.py +++ b/paddlespeech/t2s/models/transformer_tts/transformer_tts.py @@ -391,7 +391,7 @@ class TransformerTTS(nn.Layer): text_lengths: paddle.Tensor, speech: paddle.Tensor, speech_lengths: paddle.Tensor, - spembs: paddle.Tensor=None, + spk_emb: paddle.Tensor=None, ) -> Tuple[paddle.Tensor, Dict[str, paddle.Tensor], paddle.Tensor]: """Calculate forward propagation. @@ -405,7 +405,7 @@ class TransformerTTS(nn.Layer): Batch of padded target features (B, Lmax, odim). speech_lengths : Tensor(int64) Batch of the lengths of each target (B,). - spembs : Tensor, optional + spk_emb : Tensor, optional Batch of speaker embeddings (B, spk_embed_dim). Returns @@ -439,7 +439,7 @@ class TransformerTTS(nn.Layer): # calculate transformer outputs after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens, - spembs) + spk_emb) # modifiy mod part of groundtruth @@ -467,7 +467,7 @@ class TransformerTTS(nn.Layer): ilens: paddle.Tensor, ys: paddle.Tensor, olens: paddle.Tensor, - spembs: paddle.Tensor, + spk_emb: paddle.Tensor, ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: # forward encoder x_masks = self._source_mask(ilens) @@ -480,7 +480,7 @@ class TransformerTTS(nn.Layer): # integrate speaker embedding if self.spk_embed_dim is not None: - hs = self._integrate_with_spk_embed(hs, spembs) + hs = self._integrate_with_spk_embed(hs, spk_emb) # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) if self.reduction_factor > 1: @@ -514,7 +514,7 @@ class TransformerTTS(nn.Layer): self, text: paddle.Tensor, speech: paddle.Tensor=None, - spembs: paddle.Tensor=None, + spk_emb: paddle.Tensor=None, threshold: float=0.5, minlenratio: float=0.0, maxlenratio: float=10.0, @@ -528,7 +528,7 @@ class TransformerTTS(nn.Layer): Input sequence of characters (T,). speech : Tensor, optional Feature sequence to extract style (N, idim). - spembs : Tensor, optional + spk_emb : Tensor, optional Speaker embedding vector (spk_embed_dim,). threshold : float, optional Threshold in inference. 
@@ -551,7 +551,6 @@ class TransformerTTS(nn.Layer): """ # input of embedding must be int64 y = speech - spemb = spembs # add eos at the last of sequence text = numpy.pad( @@ -564,12 +563,12 @@ class TransformerTTS(nn.Layer): # get teacher forcing outputs xs, ys = x.unsqueeze(0), y.unsqueeze(0) - spembs = None if spemb is None else spemb.unsqueeze(0) + spk_emb = None if spk_emb is None else spk_emb.unsqueeze(0) ilens = paddle.to_tensor( [xs.shape[1]], dtype=paddle.int64, place=xs.place) olens = paddle.to_tensor( [ys.shape[1]], dtype=paddle.int64, place=ys.place) - outs, *_ = self._forward(xs, ilens, ys, olens, spembs) + outs, *_ = self._forward(xs, ilens, ys, olens, spk_emb) # get attention weights att_ws = [] @@ -590,9 +589,9 @@ class TransformerTTS(nn.Layer): hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding - if self.spk_embed_dim is not None: - spembs = spemb.unsqueeze(0) - hs = self._integrate_with_spk_embed(hs, spembs) + if spk_emb is not None: + spk_emb = spk_emb.unsqueeze(0) + hs = self._integrate_with_spk_embed(hs, spk_emb) # set limits of length maxlen = int(hs.shape[1] * maxlenratio / self.reduction_factor) @@ -726,14 +725,14 @@ class TransformerTTS(nn.Layer): def _integrate_with_spk_embed(self, hs: paddle.Tensor, - spembs: paddle.Tensor) -> paddle.Tensor: + spk_emb: paddle.Tensor) -> paddle.Tensor: """Integrate speaker embedding with hidden states. Parameters ---------- hs : Tensor Batch of hidden state sequences (B, Tmax, adim). - spembs : Tensor + spk_emb : Tensor Batch of speaker embeddings (B, spk_embed_dim). Returns @@ -744,13 +743,13 @@ class TransformerTTS(nn.Layer): """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states - spembs = self.projection(F.normalize(spembs)) - hs = hs + spembs.unsqueeze(1) + spk_emb = self.projection(F.normalize(spk_emb)) + hs = hs + spk_emb.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection - spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.shape[1], - -1) - hs = self.projection(paddle.concat([hs, spembs], axis=-1)) + spk_emb = F.normalize(spk_emb).unsqueeze(1).expand(-1, hs.shape[1], + -1) + hs = self.projection(paddle.concat([hs, spk_emb], axis=-1)) else: raise NotImplementedError("support only add or concat.") diff --git a/paddlespeech/xvector/__init__.py b/paddlespeech/vector/__init__.py similarity index 100% rename from paddlespeech/xvector/__init__.py rename to paddlespeech/vector/__init__.py diff --git a/paddlespeech/t2s/exps/ge2e/__init__.py b/paddlespeech/vector/exps/__init__.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/__init__.py rename to paddlespeech/vector/exps/__init__.py diff --git a/paddlespeech/vector/exps/ge2e/__init__.py b/paddlespeech/vector/exps/ge2e/__init__.py new file mode 100644 index 000000000..abf198b97 --- /dev/null +++ b/paddlespeech/vector/exps/ge2e/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/paddlespeech/t2s/exps/ge2e/audio_processor.py b/paddlespeech/vector/exps/ge2e/audio_processor.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/audio_processor.py rename to paddlespeech/vector/exps/ge2e/audio_processor.py diff --git a/paddlespeech/t2s/exps/ge2e/config.py b/paddlespeech/vector/exps/ge2e/config.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/config.py rename to paddlespeech/vector/exps/ge2e/config.py diff --git a/paddlespeech/t2s/exps/ge2e/dataset_processors.py b/paddlespeech/vector/exps/ge2e/dataset_processors.py similarity index 98% rename from paddlespeech/t2s/exps/ge2e/dataset_processors.py rename to paddlespeech/vector/exps/ge2e/dataset_processors.py index a9320d985..908c852b2 100644 --- a/paddlespeech/t2s/exps/ge2e/dataset_processors.py +++ b/paddlespeech/vector/exps/ge2e/dataset_processors.py @@ -19,7 +19,7 @@ from typing import List import numpy as np from tqdm import tqdm -from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor def _process_utterance(path_pair, processor: SpeakerVerificationPreprocessor): diff --git a/paddlespeech/t2s/exps/ge2e/inference.py b/paddlespeech/vector/exps/ge2e/inference.py similarity index 95% rename from paddlespeech/t2s/exps/ge2e/inference.py rename to paddlespeech/vector/exps/ge2e/inference.py index eed3b7947..7660de5e8 100644 --- a/paddlespeech/t2s/exps/ge2e/inference.py +++ b/paddlespeech/vector/exps/ge2e/inference.py @@ -18,9 +18,9 @@ import numpy as np import paddle import tqdm -from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor -from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults -from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder def embed_utterance(processor, model, fpath_or_wav): diff --git a/paddlespeech/t2s/exps/ge2e/preprocess.py b/paddlespeech/vector/exps/ge2e/preprocess.py similarity index 87% rename from paddlespeech/t2s/exps/ge2e/preprocess.py rename to paddlespeech/vector/exps/ge2e/preprocess.py index 604ff0c67..dabe0ce76 100644 --- a/paddlespeech/t2s/exps/ge2e/preprocess.py +++ b/paddlespeech/vector/exps/ge2e/preprocess.py @@ -14,14 +14,13 @@ import argparse from pathlib import Path -from audio_processor import SpeakerVerificationPreprocessor - -from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_aidatatang_200zh -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_librispeech -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_magicdata -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_voxceleb1 -from paddlespeech.t2s.exps.ge2e.dataset_processors import process_voxceleb2 +from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor +from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults +from paddlespeech.vector.exps.ge2e.dataset_processors import process_aidatatang_200zh +from paddlespeech.vector.exps.ge2e.dataset_processors import process_librispeech +from 
paddlespeech.vector.exps.ge2e.dataset_processors import process_magicdata +from paddlespeech.vector.exps.ge2e.dataset_processors import process_voxceleb1 +from paddlespeech.vector.exps.ge2e.dataset_processors import process_voxceleb2 if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/paddlespeech/t2s/exps/ge2e/random_cycle.py b/paddlespeech/vector/exps/ge2e/random_cycle.py similarity index 100% rename from paddlespeech/t2s/exps/ge2e/random_cycle.py rename to paddlespeech/vector/exps/ge2e/random_cycle.py diff --git a/paddlespeech/t2s/exps/ge2e/speaker_verification_dataset.py b/paddlespeech/vector/exps/ge2e/speaker_verification_dataset.py similarity index 98% rename from paddlespeech/t2s/exps/ge2e/speaker_verification_dataset.py rename to paddlespeech/vector/exps/ge2e/speaker_verification_dataset.py index a13219969..194eb7f28 100644 --- a/paddlespeech/t2s/exps/ge2e/speaker_verification_dataset.py +++ b/paddlespeech/vector/exps/ge2e/speaker_verification_dataset.py @@ -18,7 +18,7 @@ import numpy as np from paddle.io import BatchSampler from paddle.io import Dataset -from paddlespeech.t2s.exps.ge2e.random_cycle import random_cycle +from paddlespeech.vector.exps.ge2e.random_cycle import random_cycle class MultiSpeakerMelDataset(Dataset): diff --git a/paddlespeech/t2s/exps/ge2e/train.py b/paddlespeech/vector/exps/ge2e/train.py similarity index 91% rename from paddlespeech/t2s/exps/ge2e/train.py rename to paddlespeech/vector/exps/ge2e/train.py index 55c6daf73..bf1cf1074 100644 --- a/paddlespeech/t2s/exps/ge2e/train.py +++ b/paddlespeech/vector/exps/ge2e/train.py @@ -19,13 +19,13 @@ from paddle.io import DataLoader from paddle.nn.clip import ClipGradByGlobalNorm from paddle.optimizer import Adam -from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults -from paddlespeech.t2s.exps.ge2e.speaker_verification_dataset import Collate -from paddlespeech.t2s.exps.ge2e.speaker_verification_dataset import MultiSpeakerMelDataset -from paddlespeech.t2s.exps.ge2e.speaker_verification_dataset import MultiSpeakerSampler -from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder from paddlespeech.t2s.training import default_argument_parser from paddlespeech.t2s.training import ExperimentBase +from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults +from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import Collate +from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerMelDataset +from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerSampler +from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder class Ge2eExperiment(ExperimentBase): diff --git a/paddlespeech/vector/models/__init__.py b/paddlespeech/vector/models/__init__.py new file mode 100644 index 000000000..185a92b8d --- /dev/null +++ b/paddlespeech/vector/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/paddlespeech/t2s/models/lstm_speaker_encoder.py b/paddlespeech/vector/models/lstm_speaker_encoder.py similarity index 100% rename from paddlespeech/t2s/models/lstm_speaker_encoder.py rename to paddlespeech/vector/models/lstm_speaker_encoder.py