diff --git a/examples/aishell3/ernie_sat/conf/default.yaml b/examples/aishell3/ernie_sat/conf/default.yaml new file mode 100644 index 000000000..fdc767fb0 --- /dev/null +++ b/examples/aishell3/ernie_sat/conf/default.yaml @@ -0,0 +1,283 @@ +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### + +fs: 24000 # sr +n_fft: 2048 # FFT size (samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. + +# Only used for feats_type != raw + +fmin: 80 # Minimum frequency of Mel basis. +fmax: 7600 # Maximum frequency of Mel basis. +n_mels: 80 # The number of mel basis. + +mean_phn_span: 8 +mlm_prob: 0.8 + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 20 +num_workers: 2 + +########################################################### +# MODEL SETTING # +########################################################### +model: + text_masking: false + postnet_layers: 5 + postnet_filts: 5 + postnet_chans: 256 + encoder_type: conformer + decoder_type: conformer + enc_input_layer: sega_mlm + enc_pre_speech_layer: 0 + enc_cnn_module_kernel: 7 + enc_attention_dim: 384 + enc_attention_heads: 2 + enc_linear_units: 1536 + enc_num_blocks: 4 + enc_dropout_rate: 0.2 + enc_positional_dropout_rate: 0.2 + enc_attention_dropout_rate: 0.2 + enc_normalize_before: true + enc_macaron_style: true + enc_use_cnn_module: true + enc_selfattention_layer_type: legacy_rel_selfattn + enc_activation_type: swish + enc_pos_enc_layer_type: legacy_rel_pos + enc_positionwise_layer_type: conv1d + enc_positionwise_conv_kernel_size: 3 + dec_cnn_module_kernel: 31 + dec_attention_dim: 384 + dec_attention_heads: 2 + dec_linear_units: 1536 + dec_num_blocks: 4 + dec_dropout_rate: 0.2 + dec_positional_dropout_rate: 0.2 + dec_attention_dropout_rate: 0.2 + dec_macaron_style: true + dec_use_cnn_module: true + dec_selfattention_layer_type: legacy_rel_selfattn + dec_activation_type: swish + dec_pos_enc_layer_type: legacy_rel_pos + dec_positionwise_layer_type: conv1d + dec_positionwise_conv_kernel_size: 3 + +########################################################### +# OPTIMIZER SETTING # +########################################################### +scheduler_params: + d_model: 384 + warmup_steps: 4000 +grad_clip: 1.0 + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 1500 +num_snapshots: 50 + +########################################################### +# OTHER SETTING # +########################################################### +seed: 0 + +token_list: +- +- +- d +- sp +- sh +- ii +- j +- zh +- l +- x +- b +- g +- uu +- e5 +- h +- q +- m +- i1 +- t +- z +- ch +- f +- s +- u4 +- ix4 +- i4 +- n +- i3 +- iu3 +- vv +- ian4 +- ix2 +- r +- e4 +- ai4 +- k +- ing2 +- a1 +- en2 +- ui4 +- ong1 +- uo3 +- u2 +- u3 +- ao4 +- ee +- p +- an1 +- eng2 +- i2 +- in1 +- c +- ai2 +- ian2 +- e2 +- an4 +- ing4 +- v4 +- ai3 +- a5 +- ian3 +- eng1 +- ong4 +- ang4 +- ian1 +- ing1 +- iy4 +- ao3 +- ang1 +- uo4 +- u1 +- iao4 +- iu4 +- a4 +- van2 +- ie4 +- ang2 +- ou4 +- iang4 +- ix1 +- er4 +- iy1 +- e1 +- en1 +- ui2 +- an3 +- ei4 +- ong2 +- uo1 +- ou3 +- uo2 +- iao1 +- ou1 +- an2 +- uan4 +- ia4 +- ia1 +- ang3 +- v3 +- iu2 +- iao3 +- in4 +- a3 +- ei3 
+- iang3 +- v2 +- eng4 +- en3 +- aa +- uan1 +- v1 +- ao1 +- ve4 +- ie3 +- ai1 +- ing3 +- iang1 +- a2 +- ui1 +- en4 +- en5 +- in3 +- uan3 +- e3 +- ie1 +- ve2 +- ei2 +- in2 +- ix3 +- uan2 +- iang2 +- ie2 +- ua4 +- ou2 +- uai4 +- er2 +- eng3 +- uang3 +- un1 +- ong3 +- uang4 +- vn4 +- un2 +- iy3 +- iz4 +- ui3 +- iao2 +- iong4 +- un4 +- van4 +- ao2 +- uang1 +- iy5 +- o2 +- ei1 +- ua1 +- iu1 +- uang2 +- er5 +- o1 +- un3 +- vn1 +- vn2 +- o4 +- ve1 +- van3 +- ua2 +- er3 +- iong3 +- van1 +- ia2 +- iy2 +- ia3 +- iong1 +- uo5 +- oo +- ve3 +- ou5 +- uai3 +- ian5 +- iong2 +- uai2 +- uai1 +- ua3 +- vn3 +- ia5 +- ie5 +- ueng1 +- o5 +- o3 +- iang5 +- ei5 +- \ No newline at end of file diff --git a/examples/aishell3/ernie_sat/local/preprocess.sh b/examples/aishell3/ernie_sat/local/preprocess.sh new file mode 100755 index 000000000..d7a91d08f --- /dev/null +++ b/examples/aishell3/ernie_sat/local/preprocess.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./aishell3_alignment_tone \ + --output durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=aishell3 \ + --rootdir=~/datasets/data_aishell3/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="speech" +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." 
+ python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/aishell3/ernie_sat/local/synthesize.sh b/examples/aishell3/ernie_sat/local/synthesize.sh new file mode 100755 index 000000000..3e907427c --- /dev/null +++ b/examples/aishell3/ernie_sat/local/synthesize.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +stage=1 +stop_stage=1 + +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --erniesat_config=${config_path} \ + --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --erniesat_stat=dump/train/speech_stats.npy \ + --voc=pwgan_aishell3 \ + --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# hifigan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --erniesat_config=${config_path} \ + --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --erniesat_stat=dump/train/speech_stats.npy \ + --voc=hifigan_aishell3 \ + --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \ + --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \ + --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi diff --git a/examples/aishell3/ernie_sat/local/train.sh b/examples/aishell3/ernie_sat/local/train.sh new file mode 100755 index 000000000..30720e8f5 --- /dev/null +++ b/examples/aishell3/ernie_sat/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=2 \ + --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/examples/aishell3/ernie_sat/path.sh b/examples/aishell3/ernie_sat/path.sh new file mode 100755 index 000000000..4ecab0251 --- /dev/null +++ b/examples/aishell3/ernie_sat/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=ernie_sat +export 
BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} \ No newline at end of file diff --git a/examples/aishell3/ernie_sat/run.sh b/examples/aishell3/ernie_sat/run.sh new file mode 100755 index 000000000..d75a19f23 --- /dev/null +++ b/examples/aishell3/ernie_sat/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_153.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/aishell3/tts3/conf/conformer.yaml b/examples/aishell3/tts3/conf/conformer.yaml index ea73593d7..0834bfe3f 100644 --- a/examples/aishell3/tts3/conf/conformer.yaml +++ b/examples/aishell3/tts3/conf/conformer.yaml @@ -94,8 +94,8 @@ updater: # OPTIMIZER SETTING # ########################################################### optimizer: - optim: adam # optimizer type - learning_rate: 0.001 # learning rate + optim: adam # optimizer type + learning_rate: 0.001 # learning rate ########################################################### # TRAINING SETTING # diff --git a/examples/aishell3/tts3/conf/default.yaml b/examples/aishell3/tts3/conf/default.yaml index ac4956742..e65b5d0ec 100644 --- a/examples/aishell3/tts3/conf/default.yaml +++ b/examples/aishell3/tts3/conf/default.yaml @@ -88,8 +88,8 @@ updater: # OPTIMIZER SETTING # ########################################################### optimizer: - optim: adam # optimizer type - learning_rate: 0.001 # learning rate + optim: adam # optimizer type + learning_rate: 0.001 # learning rate ########################################################### # TRAINING SETTING # diff --git a/examples/aishell3/tts3/local/synthesize.sh b/examples/aishell3/tts3/local/synthesize.sh index d3978833f..9134e0426 100755 --- a/examples/aishell3/tts3/local/synthesize.sh +++ b/examples/aishell3/tts3/local/synthesize.sh @@ -37,7 +37,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --am_stat=dump/train/speech_stats.npy \ --voc=hifigan_aishell3 \ --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \ - --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pd \ + --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \ --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \ --test_metadata=dump/test/norm/metadata.jsonl \ --output_dir=${train_output_path}/test \ diff --git a/examples/aishell3_vctk/ernie_sat/conf/default.yaml b/examples/aishell3_vctk/ernie_sat/conf/default.yaml new file mode 100644 index 000000000..abb69fcc0 --- /dev/null +++ b/examples/aishell3_vctk/ernie_sat/conf/default.yaml @@ -0,0 +1,352 @@ +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### + +fs: 24000 # sr +n_fft: 2048 # FFT size 
(samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. + +# Only used for feats_type != raw + +fmin: 80 # Minimum frequency of Mel basis. +fmax: 7600 # Maximum frequency of Mel basis. +n_mels: 80 # The number of mel basis. + +mean_phn_span: 8 +mlm_prob: 0.8 + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 20 +num_workers: 2 + +########################################################### +# MODEL SETTING # +########################################################### +model: + text_masking: true + postnet_layers: 5 + postnet_filts: 5 + postnet_chans: 256 + encoder_type: conformer + decoder_type: conformer + enc_input_layer: sega_mlm + enc_pre_speech_layer: 0 + enc_cnn_module_kernel: 7 + enc_attention_dim: 384 + enc_attention_heads: 2 + enc_linear_units: 1536 + enc_num_blocks: 4 + enc_dropout_rate: 0.2 + enc_positional_dropout_rate: 0.2 + enc_attention_dropout_rate: 0.2 + enc_normalize_before: true + enc_macaron_style: true + enc_use_cnn_module: true + enc_selfattention_layer_type: legacy_rel_selfattn + enc_activation_type: swish + enc_pos_enc_layer_type: legacy_rel_pos + enc_positionwise_layer_type: conv1d + enc_positionwise_conv_kernel_size: 3 + dec_cnn_module_kernel: 31 + dec_attention_dim: 384 + dec_attention_heads: 2 + dec_linear_units: 1536 + dec_num_blocks: 4 + dec_dropout_rate: 0.2 + dec_positional_dropout_rate: 0.2 + dec_attention_dropout_rate: 0.2 + dec_macaron_style: true + dec_use_cnn_module: true + dec_selfattention_layer_type: legacy_rel_selfattn + dec_activation_type: swish + dec_pos_enc_layer_type: legacy_rel_pos + dec_positionwise_layer_type: conv1d + dec_positionwise_conv_kernel_size: 3 + +########################################################### +# OPTIMIZER SETTING # +########################################################### +scheduler_params: + d_model: 384 + warmup_steps: 4000 +grad_clip: 1.0 + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 700 +num_snapshots: 50 + +########################################################### +# OTHER SETTING # +########################################################### +seed: 0 + +token_list: +- +- +- AH0 +- T +- N +- sp +- S +- R +- D +- L +- Z +- DH +- IH1 +- K +- W +- M +- EH1 +- AE1 +- ER0 +- B +- IY1 +- P +- V +- IY0 +- F +- HH +- AA1 +- AY1 +- AH1 +- EY1 +- IH0 +- AO1 +- OW1 +- UW1 +- G +- NG +- SH +- Y +- TH +- ER1 +- JH +- UH1 +- AW1 +- CH +- IH2 +- OW0 +- OW2 +- EY2 +- EH2 +- UW0 +- OY1 +- ZH +- EH0 +- AY2 +- AW2 +- AA2 +- AE2 +- IY2 +- AH2 +- AE0 +- AO2 +- AY0 +- AO0 +- UW2 +- UH2 +- AA0 +- EY0 +- AW0 +- UH0 +- ER2 +- OY2 +- OY0 +- d +- sh +- ii +- j +- zh +- l +- x +- b +- g +- uu +- e5 +- h +- q +- m +- i1 +- t +- z +- ch +- f +- s +- u4 +- ix4 +- i4 +- n +- i3 +- iu3 +- vv +- ian4 +- ix2 +- r +- e4 +- ai4 +- k +- ing2 +- a1 +- en2 +- ui4 +- ong1 +- uo3 +- u2 +- u3 +- ao4 +- ee +- p +- an1 +- eng2 +- i2 +- in1 +- c +- ai2 +- ian2 +- e2 +- an4 +- ing4 +- v4 +- ai3 +- a5 +- ian3 +- eng1 +- ong4 +- ang4 +- ian1 +- ing1 +- iy4 +- ao3 +- ang1 +- uo4 +- u1 +- iao4 +- iu4 +- a4 +- van2 +- ie4 +- ang2 +- ou4 +- iang4 +- ix1 +- er4 +- iy1 +- e1 +- en1 +- ui2 +- an3 +- ei4 +- ong2 +- uo1 +- ou3 +- uo2 +- iao1 +- ou1 +- an2 +- uan4 +- ia4 +- ia1 +- ang3 +- v3 +- iu2 +- iao3 +- in4 +- a3 
+- ei3 +- iang3 +- v2 +- eng4 +- en3 +- aa +- uan1 +- v1 +- ao1 +- ve4 +- ie3 +- ai1 +- ing3 +- iang1 +- a2 +- ui1 +- en4 +- en5 +- in3 +- uan3 +- e3 +- ie1 +- ve2 +- ei2 +- in2 +- ix3 +- uan2 +- iang2 +- ie2 +- ua4 +- ou2 +- uai4 +- er2 +- eng3 +- uang3 +- un1 +- ong3 +- uang4 +- vn4 +- un2 +- iy3 +- iz4 +- ui3 +- iao2 +- iong4 +- un4 +- van4 +- ao2 +- uang1 +- iy5 +- o2 +- ei1 +- ua1 +- iu1 +- uang2 +- er5 +- o1 +- un3 +- vn1 +- vn2 +- o4 +- ve1 +- van3 +- ua2 +- er3 +- iong3 +- van1 +- ia2 +- iy2 +- ia3 +- iong1 +- uo5 +- oo +- ve3 +- ou5 +- uai3 +- ian5 +- iong2 +- uai2 +- uai1 +- ua3 +- vn3 +- ia5 +- ie5 +- ueng1 +- o5 +- o3 +- iang5 +- ei5 +- diff --git a/examples/aishell3_vctk/ernie_sat/local/preprocess.sh b/examples/aishell3_vctk/ernie_sat/local/preprocess.sh new file mode 100755 index 000000000..783fd6333 --- /dev/null +++ b/examples/aishell3_vctk/ernie_sat/local/preprocess.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results for aishell3 ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./aishell3_alignment_tone \ + --output durations_aishell3.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results for vctk ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./vctk_alignment \ + --output durations_vctk.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # get durations from MFA's result + echo "concat durations_aishell3.txt and durations_vctk.txt to durations.txt" + cat durations_aishell3.txt durations_vctk.txt > durations.txt +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=aishell3 \ + --rootdir=~/datasets/data_aishell3/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=vctk \ + --rootdir=~/datasets/VCTK-Corpus-0.92/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True +fi + +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="speech" +fi + +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." 
+ python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/aishell3_vctk/ernie_sat/local/synthesize.sh b/examples/aishell3_vctk/ernie_sat/local/synthesize.sh new file mode 100755 index 000000000..3e907427c --- /dev/null +++ b/examples/aishell3_vctk/ernie_sat/local/synthesize.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +stage=1 +stop_stage=1 + +# pwgan +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --erniesat_config=${config_path} \ + --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --erniesat_stat=dump/train/speech_stats.npy \ + --voc=pwgan_aishell3 \ + --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi + +# hifigan +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + FLAGS_allocator_strategy=naive_best_fit \ + FLAGS_fraction_of_gpu_memory_to_use=0.01 \ + python3 ${BIN_DIR}/synthesize.py \ + --erniesat_config=${config_path} \ + --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --erniesat_stat=dump/train/speech_stats.npy \ + --voc=hifigan_aishell3 \ + --voc_config=hifigan_aishell3_ckpt_0.2.0/default.yaml \ + --voc_ckpt=hifigan_aishell3_ckpt_0.2.0/snapshot_iter_2500000.pdz \ + --voc_stat=hifigan_aishell3_ckpt_0.2.0/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt +fi diff --git a/examples/aishell3_vctk/ernie_sat/local/train.sh b/examples/aishell3_vctk/ernie_sat/local/train.sh new file mode 100755 index 000000000..30720e8f5 --- /dev/null +++ b/examples/aishell3_vctk/ernie_sat/local/train.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=2 \ + --phones-dict=dump/phone_id_map.txt \ No newline at end of file diff --git a/examples/aishell3_vctk/ernie_sat/path.sh b/examples/aishell3_vctk/ernie_sat/path.sh new file mode 100755 index 000000000..4ecab0251 --- /dev/null +++ b/examples/aishell3_vctk/ernie_sat/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + 
+MODEL=ernie_sat +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} \ No newline at end of file diff --git a/examples/aishell3_vctk/ernie_sat/run.sh b/examples/aishell3_vctk/ernie_sat/run.sh new file mode 100755 index 000000000..d75a19f23 --- /dev/null +++ b/examples/aishell3_vctk/ernie_sat/run.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_153.pdz + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi diff --git a/examples/csmsc/tts2/conf/default.yaml b/examples/csmsc/tts2/conf/default.yaml index a3366b8f9..a5a258b7e 100644 --- a/examples/csmsc/tts2/conf/default.yaml +++ b/examples/csmsc/tts2/conf/default.yaml @@ -21,22 +21,22 @@ num_workers: 4 # MODEL SETTING # ########################################################### model: - encoder_hidden_size: 128 - encoder_kernel_size: 3 - encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1] - duration_predictor_hidden_size: 128 - decoder_hidden_size: 128 - decoder_output_size: 80 - decoder_kernel_size: 3 - decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1] + encoder_hidden_size: 128 + encoder_kernel_size: 3 + encoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 1] + duration_predictor_hidden_size: 128 + decoder_hidden_size: 128 + decoder_output_size: 80 + decoder_kernel_size: 3 + decoder_dilations: [1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1] ########################################################### # OPTIMIZER SETTING # ########################################################### optimizer: - optim: adam # optimizer type - learning_rate: 0.002 # learning rate - max_grad_norm: 1 + optim: adam # optimizer type + learning_rate: 0.002 # learning rate + max_grad_norm: 1 ########################################################### # TRAINING SETTING # diff --git a/examples/csmsc/voc3/conf/default.yaml b/examples/csmsc/voc3/conf/default.yaml index fbff54f19..a5ee17808 100644 --- a/examples/csmsc/voc3/conf/default.yaml +++ b/examples/csmsc/voc3/conf/default.yaml @@ -29,7 +29,7 @@ generator_params: out_channels: 4 # Number of output channels. kernel_size: 7 # Kernel size of initial and final conv layers. channels: 384 # Initial number of channels for conv layers. - upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) == n_shift + upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) x out_channels == n_shift stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. stacks: 4 # Number of stacks in a single residual stack module. use_weight_norm: True # Whether to use weight normalization. 
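The comment corrected in the hunk above encodes a real structural constraint of Multi-band MelGAN: the generator predicts out_channels sub-band waveforms, each upsampled by prod(upsample_scales), so the effective number of samples per frame is prod(upsample_scales) * out_channels, and that product must equal the hop size n_shift. A minimal sanity check (an illustrative sketch, not part of the recipe), using the values from this config:

import math

def check_total_upsampling(upsample_scales, out_channels, n_shift):
    # Multi-band MelGAN constraint: prod(upsample_scales) * out_channels == n_shift.
    total = math.prod(upsample_scales) * out_channels
    if total != n_shift:
        raise ValueError(f"total upsampling {total} != hop size {n_shift}")
    return total

# Values from examples/csmsc/voc3/conf/default.yaml (the finetune config below
# uses the same): 5 * 5 * 3 * 4 == 300
print(check_total_upsampling([5, 5, 3], out_channels=4, n_shift=300))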
diff --git a/examples/csmsc/voc3/conf/finetune.yaml b/examples/csmsc/voc3/conf/finetune.yaml index 0a38c2820..8c37ac302 100644 --- a/examples/csmsc/voc3/conf/finetune.yaml +++ b/examples/csmsc/voc3/conf/finetune.yaml @@ -29,7 +29,7 @@ generator_params: out_channels: 4 # Number of output channels. kernel_size: 7 # Kernel size of initial and final conv layers. channels: 384 # Initial number of channels for conv layers. - upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) == n_shift + upsample_scales: [5, 5, 3] # List of Upsampling scales. prod(upsample_scales) x out_channels == n_shift stack_kernel_size: 3 # Kernel size of dilated conv layers in residual stack. stacks: 4 # Number of stacks in a single residual stack module. use_weight_norm: True # Whether to use weight normalization. diff --git a/examples/ernie_sat/local/align.py b/examples/ernie_sat/local/align.py index 025877ddf..ff47cac5b 100755 --- a/examples/ernie_sat/local/align.py +++ b/examples/ernie_sat/local/align.py @@ -1,3 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Usage: align.py wavfile trsfile outwordfile outphonefile """ diff --git a/examples/ernie_sat/local/inference.py b/examples/ernie_sat/local/inference.py index 196d9c6d0..e6a0788fd 100644 --- a/examples/ernie_sat/local/inference.py +++ b/examples/ernie_sat/local/inference.py @@ -1,4 +1,16 @@ -#!/usr/bin/env python3 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os import random from typing import Dict @@ -305,7 +317,6 @@ def get_dur_adj_factor(orig_dur: List[int], def prep_feats_with_dur(wav_path: str, - mlm_model: nn.Layer, source_lang: str="English", target_lang: str="English", old_str: str="", @@ -425,8 +436,7 @@ def prep_feats_with_dur(wav_path: str, return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy -def prep_feats(mlm_model: nn.Layer, - wav_path: str, +def prep_feats(wav_path: str, source_lang: str="english", target_lang: str="english", old_str: str="", @@ -440,7 +450,6 @@ def prep_feats(mlm_model: nn.Layer, wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur( source_lang=source_lang, target_lang=target_lang, - mlm_model=mlm_model, old_str=old_str, new_str=new_str, wav_path=wav_path, @@ -482,7 +491,6 @@ def decode_with_model(mlm_model: nn.Layer, batch, old_span_bdy, new_span_bdy = prep_feats( source_lang=source_lang, target_lang=target_lang, - mlm_model=mlm_model, wav_path=wav_path, old_str=old_str, new_str=new_str, diff --git a/examples/ernie_sat/local/inference_new.py b/examples/ernie_sat/local/inference_new.py new file mode 100644 index 000000000..525967eb1 --- /dev/null +++ b/examples/ernie_sat/local/inference_new.py @@ -0,0 +1,622 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +import random +from typing import Dict +from typing import List + +import librosa +import numpy as np +import paddle +import soundfile as sf +import yaml +from align import alignment +from align import alignment_zh +from align import words2phns +from align import words2phns_zh +from paddle import nn +from sedit_arg_parser import parse_args +from utils import eval_durs +from utils import get_voc_out +from utils import is_chinese +from utils import load_num_sequence_text +from utils import read_2col_text +from yacs.config import CfgNode + +from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn +from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT + +random.seed(0) +np.random.seed(0) + + +def get_wav(wav_path: str, + source_lang: str='english', + target_lang: str='english', + model_name: str="paddle_checkpoint_en", + old_str: str="", + new_str: str="", + non_autoreg: bool=True): + wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output( + source_lang=source_lang, + target_lang=target_lang, + model_name=model_name, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + use_teacher_forcing=non_autoreg) + + masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]] + + alt_wav = get_voc_out(masked_feat) + + old_time_bdy = [hop_length * x for x in old_span_bdy] + + wav_replaced = np.concatenate( + [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]]) + + data_dict = {"origin": wav_org, "output": wav_replaced} + + return data_dict + + +def load_model(model_name: str="paddle_checkpoint_en"): + config_path = './pretrained_model/{}/default.yaml'.format(model_name) + model_path = './pretrained_model/{}/model.pdparams'.format(model_name) + with open(config_path) as f: + conf = CfgNode(yaml.safe_load(f)) + token_list = list(conf.token_list) + vocab_size = len(token_list) + odim = conf.n_mels + mlm_model = ErnieSAT(idim=vocab_size, odim=odim, **conf["model"]) + state_dict = paddle.load(model_path) + new_state_dict = {} + for key, value in state_dict.items(): + new_key = "model." 
+ key + new_state_dict[new_key] = value + mlm_model.set_state_dict(new_state_dict) + mlm_model.eval() + + return mlm_model, conf + + +def read_data(uid: str, prefix: os.PathLike): + # 获取 uid 对应的文本 + mfa_text = read_2col_text(prefix + '/text')[uid] + # 获取 uid 对应的音频路径 + mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid] + if not os.path.isabs(mfa_wav_path): + mfa_wav_path = prefix + mfa_wav_path + return mfa_text, mfa_wav_path + + +def get_align_data(uid: str, prefix: os.PathLike): + mfa_path = prefix + "mfa_" + mfa_text = read_2col_text(mfa_path + 'text')[uid] + mfa_start = load_num_sequence_text( + mfa_path + 'start', loader_type='text_float')[uid] + mfa_end = load_num_sequence_text( + mfa_path + 'end', loader_type='text_float')[uid] + mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid] + return mfa_text, mfa_start, mfa_end, mfa_wav_path + + +# 获取需要被 mask 的 mel 帧的范围 +def get_masked_mel_bdy(mfa_start: List[float], + mfa_end: List[float], + fs: int, + hop_length: int, + span_to_repl: List[List[int]]): + align_start = np.array(mfa_start) + align_end = np.array(mfa_end) + align_start = np.floor(fs * align_start / hop_length).astype('int') + align_end = np.floor(fs * align_end / hop_length).astype('int') + if span_to_repl[0] >= len(mfa_start): + span_bdy = [align_end[-1], align_end[-1]] + else: + span_bdy = [ + align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1] + ] + return span_bdy, align_start, align_end + + +def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]): + dic = {} + keys_to_del = [] + exist_idx = [] + sp_count = 0 + add_sp_count = 0 + for key in word2phns.keys(): + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + exist_idx.append(int(idx)) + else: + keys_to_del.append(key) + + for key in keys_to_del: + del word2phns[key] + + cur_id = 0 + for key in tp_word2phns.keys(): + if cur_id in exist_idx: + dic[str(cur_id) + "_sp"] = 'sp' + cur_id += 1 + add_sp_count += 1 + idx, wrd = key.split('_') + dic[str(cur_id) + "_" + wrd] = tp_word2phns[key] + cur_id += 1 + + if add_sp_count + 1 == sp_count: + dic[str(cur_id) + "_sp"] = 'sp' + add_sp_count += 1 + + assert add_sp_count == sp_count, "sp are not added in dic" + return dic + + +def get_max_idx(dic): + return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1] + + +def get_phns_and_spans(wav_path: str, + old_str: str="", + new_str: str="", + source_lang: str="english", + target_lang: str="english"): + is_append = (old_str == new_str[:len(old_str)]) + old_phns, mfa_start, mfa_end = [], [], [] + # source + if source_lang == "english": + intervals, word2phns = alignment(wav_path, old_str) + elif source_lang == "chinese": + intervals, word2phns = alignment_zh(wav_path, old_str) + _, tp_word2phns = words2phns_zh(old_str) + + for key, value in tp_word2phns.items(): + idx, wrd = key.split('_') + cur_val = " ".join(value) + tp_word2phns[key] = cur_val + + word2phns = recover_dict(word2phns, tp_word2phns) + else: + assert source_lang == "chinese" or source_lang == "english", \ + "source_lang is wrong..." 
+ + for item in intervals: + old_phns.append(item[0]) + mfa_start.append(float(item[1])) + mfa_end.append(float(item[2])) + # target + if is_append and (source_lang != target_lang): + cross_lingual_clone = True + else: + cross_lingual_clone = False + + if cross_lingual_clone: + str_origin = new_str[:len(old_str)] + str_append = new_str[len(old_str):] + + if target_lang == "chinese": + phns_origin, origin_word2phns = words2phns(str_origin) + phns_append, append_word2phns_tmp = words2phns_zh(str_append) + + elif target_lang == "english": + # 原始句子 + phns_origin, origin_word2phns = words2phns_zh(str_origin) + # clone 句子 + phns_append, append_word2phns_tmp = words2phns(str_append) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "cloning is not support for this language, please check it." + + new_phns = phns_origin + phns_append + + append_word2phns = {} + length = len(origin_word2phns) + for key, value in append_word2phns_tmp.items(): + idx, wrd = key.split('_') + append_word2phns[str(int(idx) + length) + '_' + wrd] = value + new_word2phns = origin_word2phns.copy() + new_word2phns.update(append_word2phns) + + else: + if source_lang == target_lang and target_lang == "english": + new_phns, new_word2phns = words2phns(new_str) + elif source_lang == target_lang and target_lang == "chinese": + new_phns, new_word2phns = words2phns_zh(new_str) + else: + assert source_lang == target_lang, \ + "source language is not same with target language..." + + span_to_repl = [0, len(old_phns) - 1] + span_to_add = [0, len(new_phns) - 1] + left_idx = 0 + new_phns_left = [] + sp_count = 0 + # find the left different index + for key in word2phns.keys(): + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + new_phns_left.append('sp') + else: + idx = str(int(idx) - sp_count) + if idx + '_' + wrd in new_word2phns: + left_idx += len(new_word2phns[idx + '_' + wrd]) + new_phns_left.extend(word2phns[key].split()) + else: + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + break + + # reverse word2phns and new_word2phns + right_idx = 0 + new_phns_right = [] + sp_count = 0 + word2phns_max_idx = get_max_idx(word2phns) + new_word2phns_max_idx = get_max_idx(new_word2phns) + new_phns_mid = [] + if is_append: + new_phns_right = [] + new_phns_mid = new_phns[left_idx:] + span_to_repl[0] = len(new_phns_left) + span_to_add[0] = len(new_phns_left) + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + span_to_repl[1] = len(old_phns) - len(new_phns_right) + # speech edit + else: + for key in list(word2phns.keys())[::-1]: + idx, wrd = key.split('_') + if wrd == 'sp': + sp_count += 1 + new_phns_right = ['sp'] + new_phns_right + else: + idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx) + - sp_count)) + if idx + '_' + wrd in new_word2phns: + right_idx -= len(new_word2phns[idx + '_' + wrd]) + new_phns_right = word2phns[key].split() + new_phns_right + else: + span_to_repl[1] = len(old_phns) - len(new_phns_right) + new_phns_mid = new_phns[left_idx:right_idx] + span_to_add[1] = len(new_phns_left) + len(new_phns_mid) + if len(new_phns_mid) == 0: + span_to_add[1] = min(span_to_add[1] + 1, len(new_phns)) + span_to_add[0] = max(0, span_to_add[0] - 1) + span_to_repl[0] = max(0, span_to_repl[0] - 1) + span_to_repl[1] = min(span_to_repl[1] + 1, + len(old_phns)) + break + new_phns = new_phns_left + new_phns_mid + new_phns_right + ''' + For that reason cover should not be given. + For that reason cover is impossible to be given. 
+ span_to_repl: [17, 23] "should not" + span_to_add: [17, 30] "is impossible to" + ''' + return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add + + +# mfa 获得的 duration 和 fs2 的 duration_predictor 获取的 duration 可能不同 +# 此处获得一个缩放比例, 用于预测值和真实值之间的缩放 +def get_dur_adj_factor(orig_dur: List[int], + pred_dur: List[int], + phns: List[str]): + length = 0 + factor_list = [] + for orig, pred, phn in zip(orig_dur, pred_dur, phns): + if pred == 0 or phn == 'sp': + continue + else: + factor_list.append(orig / pred) + factor_list = np.array(factor_list) + factor_list.sort() + if len(factor_list) < 5: + return 1 + length = 2 + avg = np.average(factor_list[length:-length]) + return avg + + +def prep_feats_with_dur(wav_path: str, + source_lang: str="English", + target_lang: str="English", + old_str: str="", + new_str: str="", + mask_reconstruct: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, + fs: int=24000, + hop_length: int=300): + ''' + Returns: + np.ndarray: new wav, replace the part to be edited in original wav with 0 + List[str]: new phones + List[float]: mfa start of new wav + List[float]: mfa end of new wav + List[int]: masked mel boundary of original wav + List[int]: masked mel boundary of new wav + ''' + wav_org, _ = librosa.load(wav_path, sr=fs) + + mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans( + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + source_lang=source_lang, + target_lang=target_lang) + + if start_end_sp: + if new_phns[-1] != 'sp': + new_phns = new_phns + ['sp'] + # 中文的 phns 不一定都在 fastspeech2 的字典里, 用 sp 代替 + if target_lang == "english" or target_lang == "chinese": + old_durs = eval_durs(old_phns, target_lang=source_lang) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "calculate duration_predict is not support for this language..." + + orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)] + if '[MASK]' in new_str: + new_phns = old_phns + span_to_add = span_to_repl + d_factor_left = get_dur_adj_factor( + orig_dur=orig_old_durs[:span_to_repl[0]], + pred_dur=old_durs[:span_to_repl[0]], + phns=old_phns[:span_to_repl[0]]) + d_factor_right = get_dur_adj_factor( + orig_dur=orig_old_durs[span_to_repl[1]:], + pred_dur=old_durs[span_to_repl[1]:], + phns=old_phns[span_to_repl[1]:]) + d_factor = (d_factor_left + d_factor_right) / 2 + new_durs_adjusted = [d_factor * i for i in old_durs] + else: + if duration_adjust: + d_factor = get_dur_adj_factor( + orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns) + d_factor = d_factor * 1.25 + else: + d_factor = 1 + + if target_lang == "english" or target_lang == "chinese": + new_durs = eval_durs(new_phns, target_lang=target_lang) + else: + assert target_lang == "chinese" or target_lang == "english", \ + "calculate duration_predict is not support for this language..." 
+ + new_durs_adjusted = [d_factor * i for i in new_durs] + + new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]]) + old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]]) + dur_offset = new_span_dur_sum - old_span_dur_sum + new_mfa_start = mfa_start[:span_to_repl[0]] + new_mfa_end = mfa_end[:span_to_repl[0]] + for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]: + if len(new_mfa_end) == 0: + new_mfa_start.append(0) + new_mfa_end.append(i) + else: + new_mfa_start.append(new_mfa_end[-1]) + new_mfa_end.append(new_mfa_end[-1] + i) + new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]] + new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]] + + # 3. get new wav + # 在原始句子后拼接 + if span_to_repl[0] >= len(mfa_start): + left_idx = len(wav_org) + right_idx = left_idx + # 在原始句子中间替换 + else: + left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs)) + right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs)) + blank_wav = np.zeros( + (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype) + # 原始音频,需要编辑的部分替换成空音频,空音频的时间由 fs2 的 duration_predictor 决定 + new_wav = np.concatenate( + [wav_org[:left_idx], blank_wav, wav_org[right_idx:]]) + + # 4. get old and new mel span to be mask + # [92, 92] + + old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy( + mfa_start=mfa_start, + mfa_end=mfa_end, + fs=fs, + hop_length=hop_length, + span_to_repl=span_to_repl) + # [92, 174] + # new_mfa_start, new_mfa_end 时间级别的开始和结束时间 -> 帧级别 + new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy( + mfa_start=new_mfa_start, + mfa_end=new_mfa_end, + fs=fs, + hop_length=hop_length, + span_to_repl=span_to_add) + + # old_span_bdy, new_span_bdy 是帧级别的范围 + return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy + + +def prep_feats(wav_path: str, + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + duration_adjust: bool=True, + start_end_sp: bool=False, + mask_reconstruct: bool=False, + fs: int=24000, + hop_length: int=300, + token_list: List[str]=[]): + wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur( + source_lang=source_lang, + target_lang=target_lang, + old_str=old_str, + new_str=new_str, + wav_path=wav_path, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + mask_reconstruct=mask_reconstruct, + fs=fs, + hop_length=hop_length) + + token_to_id = {item: i for i, item in enumerate(token_list)} + text = np.array( + list(map(lambda x: token_to_id.get(x, token_to_id['']), phns))) + span_bdy = np.array(new_span_bdy) + + batch = [('1', { + "speech": wav, + "align_start": mfa_start, + "align_end": mfa_end, + "text": text, + "span_bdy": span_bdy + })] + + return batch, old_span_bdy, new_span_bdy + + +def decode_with_model(mlm_model: nn.Layer, + collate_fn, + wav_path: str, + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + use_teacher_forcing: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False, + fs: int=24000, + hop_length: int=300, + token_list: List[str]=[]): + batch, old_span_bdy, new_span_bdy = prep_feats( + source_lang=source_lang, + target_lang=target_lang, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + fs=fs, + hop_length=hop_length, + token_list=token_list) + + feats = collate_fn(batch)[1] + + if 'text_masked_pos' in feats.keys(): + feats.pop('text_masked_pos') + + output = 
mlm_model.inference( + text=feats['text'], + speech=feats['speech'], + masked_pos=feats['masked_pos'], + speech_mask=feats['speech_mask'], + text_mask=feats['text_mask'], + speech_seg_pos=feats['speech_seg_pos'], + text_seg_pos=feats['text_seg_pos'], + span_bdy=new_span_bdy, + use_teacher_forcing=use_teacher_forcing) + + # 拼接音频 + output_feat = paddle.concat(x=output, axis=0) + wav_org, _ = librosa.load(wav_path, sr=fs) + return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length + + +def get_mlm_output(wav_path: str, + model_name: str="paddle_checkpoint_en", + source_lang: str="english", + target_lang: str="english", + old_str: str="", + new_str: str="", + use_teacher_forcing: bool=False, + duration_adjust: bool=True, + start_end_sp: bool=False): + mlm_model, train_conf = load_model(model_name) + + collate_fn = build_mlm_collate_fn( + sr=train_conf.fs, + n_fft=train_conf.n_fft, + hop_length=train_conf.n_shift, + win_length=train_conf.win_length, + n_mels=train_conf.n_mels, + fmin=train_conf.fmin, + fmax=train_conf.fmax, + mlm_prob=train_conf.mlm_prob, + mean_phn_span=train_conf.mean_phn_span, + seg_emb=train_conf.model['enc_input_layer'] == 'sega_mlm') + + return decode_with_model( + source_lang=source_lang, + target_lang=target_lang, + mlm_model=mlm_model, + collate_fn=collate_fn, + wav_path=wav_path, + old_str=old_str, + new_str=new_str, + use_teacher_forcing=use_teacher_forcing, + duration_adjust=duration_adjust, + start_end_sp=start_end_sp, + fs=train_conf.fs, + hop_length=train_conf.n_shift, + token_list=train_conf.token_list) + + +def evaluate(uid: str, + source_lang: str="english", + target_lang: str="english", + prefix: os.PathLike="./prompt/dev/", + model_name: str="paddle_checkpoint_en", + new_str: str="", + prompt_decoding: bool=False, + task_name: str=None): + + # get origin text and path of origin wav + old_str, wav_path = read_data(uid=uid, prefix=prefix) + + if task_name == 'edit': + new_str = new_str + elif task_name == 'synthesize': + new_str = old_str + new_str + else: + new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)]) + + print('new_str is ', new_str) + + results_dict = get_wav( + source_lang=source_lang, + target_lang=target_lang, + model_name=model_name, + wav_path=wav_path, + old_str=old_str, + new_str=new_str) + return results_dict + + +if __name__ == "__main__": + # parse config and args + args = parse_args() + + data_dict = evaluate( + uid=args.uid, + source_lang=args.source_lang, + target_lang=args.target_lang, + prefix=args.prefix, + model_name=args.model_name, + new_str=args.new_str, + task_name=args.task_name) + sf.write(args.output_name, data_dict['output'], samplerate=24000) + print("finished...") diff --git a/examples/ernie_sat/local/sedit_arg_parser.py b/examples/ernie_sat/local/sedit_arg_parser.py index 21c6d0b4b..ad7e57191 100644 --- a/examples/ernie_sat/local/sedit_arg_parser.py +++ b/examples/ernie_sat/local/sedit_arg_parser.py @@ -1,3 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
diff --git a/examples/ernie_sat/local/utils.py b/examples/ernie_sat/local/utils.py
index 836942a26..f2dce504a 100644
--- a/examples/ernie_sat/local/utils.py
+++ b/examples/ernie_sat/local/utils.py
@@ -1,3 +1,16 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from pathlib import Path
 from typing import Dict
 from typing import List
diff --git a/examples/ernie_sat/run_clone_en_to_zh_new.sh b/examples/ernie_sat/run_clone_en_to_zh_new.sh
new file mode 100755
index 000000000..12fdf23f1
--- /dev/null
+++ b/examples/ernie_sat/run_clone_en_to_zh_new.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+# en --> zh speech synthesis
+# Use the speech of Prompt_003_new as the prompt: "This was not the show for me." to synthesize: '今天天气很好'
+# NOTE: the input new_str must consist of Chinese characters; otherwise preprocessing keeps only the Chinese characters, i.e. the preprocessed Chinese text is what gets synthesized.
+
+python local/inference_new.py \
+    --task_name=cross-lingual_clone \
+    --model_name=paddle_checkpoint_dual_mask_enzh \
+    --uid=Prompt_003_new \
+    --new_str='今天天气很好.' \
+    --prefix='./prompt/dev/' \
+    --source_lang=english \
+    --target_lang=chinese \
+    --output_name=pred_clone.wav \
+    --voc=pwgan_aishell3 \
+    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
+    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
+    --am=fastspeech2_csmsc \
+    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
+    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
+    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
+    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_gen_en_new.sh b/examples/ernie_sat/run_gen_en_new.sh
new file mode 100755
index 000000000..d76b00430
--- /dev/null
+++ b/examples/ernie_sat/run_gen_en_new.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+# English-only personalized speech synthesis
+# Example: use the speech of p299_096 as the prompt: "This was not the show for me." to synthesize: 'I enjoy my life, do you?'
+
+python local/inference_new.py \
+    --task_name=synthesize \
+    --model_name=paddle_checkpoint_en \
+    --uid=p299_096 \
+    --new_str='I enjoy my life, do you?' \
+    --prefix='./prompt/dev/' \
+    --source_lang=english \
+    --target_lang=english \
+    --output_name=pred_gen.wav \
+    --voc=pwgan_aishell3 \
+    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
+    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
+    --am=fastspeech2_ljspeech \
+    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
+    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
+    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
+    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/run_sedit_en_new.sh b/examples/ernie_sat/run_sedit_en_new.sh
new file mode 100755
index 000000000..0952d280c
--- /dev/null
+++ b/examples/ernie_sat/run_sedit_en_new.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+# English-only speech editing
+# Example: edit the original speech of p243_new: "For that reason cover should not be given." into the speech of 'for that reason cover is impossible to be given.'
+# NOTE: the speech editing task currently supports replacing or inserting text at only 1 position in a sentence
+
+python local/inference_new.py \
+    --task_name=edit \
+    --model_name=paddle_checkpoint_en \
+    --uid=p243_new \
+    --new_str='for that reason cover is impossible to be given.' \
+    --prefix='./prompt/dev/' \
+    --source_lang=english \
+    --target_lang=english \
+    --output_name=pred_edit.wav \
+    --voc=pwgan_aishell3 \
+    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
+    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
+    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
+    --am=fastspeech2_ljspeech \
+    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
+    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
+    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
+    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
diff --git a/examples/ernie_sat/test_run_new.sh b/examples/ernie_sat/test_run_new.sh
new file mode 100755
index 000000000..bf8a4e02d
--- /dev/null
+++ b/examples/ernie_sat/test_run_new.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+rm -rf *.wav
+./run_sedit_en_new.sh          # speech editing task (English)
+./run_gen_en_new.sh            # personalized speech synthesis task (English)
+./run_clone_en_to_zh_new.sh    # cross-lingual speech synthesis task (English-to-Chinese voice cloning)
\ No newline at end of file
diff --git a/examples/vctk/ernie_sat/conf/default.yaml b/examples/vctk/ernie_sat/conf/default.yaml
new file mode 100644
index 000000000..672f937ef
--- /dev/null
+++ b/examples/vctk/ernie_sat/conf/default.yaml
@@ -0,0 +1,163 @@
+###########################################################
+#                FEATURE EXTRACTION SETTING               #
+###########################################################
+
+fs: 24000          # sr
+n_fft: 2048        # FFT size (samples).
+n_shift: 300       # Hop size (samples). 12.5ms
+win_length: 1200   # Window length (samples). 50ms
+                   # If set to null, it will be the same as fft_size.
+window: "hann"     # Window function.
+
+# Only used for feats_type != raw
+
+fmin: 80           # Minimum frequency of Mel basis.
+fmax: 7600         # Maximum frequency of Mel basis.
+n_mels: 80         # The number of mel basis.
+ +mean_phn_span: 8 +mlm_prob: 0.8 + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 20 +num_workers: 2 + +########################################################### +# MODEL SETTING # +########################################################### +model: + text_masking: false + postnet_layers: 5 + postnet_filts: 5 + postnet_chans: 256 + encoder_type: conformer + decoder_type: conformer + enc_input_layer: sega_mlm + enc_pre_speech_layer: 0 + enc_cnn_module_kernel: 7 + enc_attention_dim: 384 + enc_attention_heads: 2 + enc_linear_units: 1536 + enc_num_blocks: 4 + enc_dropout_rate: 0.2 + enc_positional_dropout_rate: 0.2 + enc_attention_dropout_rate: 0.2 + enc_normalize_before: true + enc_macaron_style: true + enc_use_cnn_module: true + enc_selfattention_layer_type: legacy_rel_selfattn + enc_activation_type: swish + enc_pos_enc_layer_type: legacy_rel_pos + enc_positionwise_layer_type: conv1d + enc_positionwise_conv_kernel_size: 3 + dec_cnn_module_kernel: 31 + dec_attention_dim: 384 + dec_attention_heads: 2 + dec_linear_units: 1536 + dec_num_blocks: 4 + dec_dropout_rate: 0.2 + dec_positional_dropout_rate: 0.2 + dec_attention_dropout_rate: 0.2 + dec_macaron_style: true + dec_use_cnn_module: true + dec_selfattention_layer_type: legacy_rel_selfattn + dec_activation_type: swish + dec_pos_enc_layer_type: legacy_rel_pos + dec_positionwise_layer_type: conv1d + dec_positionwise_conv_kernel_size: 3 + +########################################################### +# OPTIMIZER SETTING # +########################################################### +scheduler_params: + d_model: 384 + warmup_steps: 4000 +grad_clip: 1.0 + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 1500 +num_snapshots: 50 + +########################################################### +# OTHER SETTING # +########################################################### +seed: 0 + +token_list: +- +- +- AH0 +- T +- N +- sp +- D +- S +- R +- L +- IH1 +- DH +- AE1 +- M +- EH1 +- K +- Z +- W +- HH +- ER0 +- AH1 +- IY1 +- P +- V +- F +- B +- AY1 +- IY0 +- EY1 +- AA1 +- AO1 +- UW1 +- IH0 +- OW1 +- NG +- G +- SH +- ER1 +- Y +- TH +- AW1 +- CH +- UH1 +- IH2 +- JH +- OW0 +- EH2 +- OY1 +- AY2 +- EH0 +- EY2 +- UW0 +- AE2 +- AA2 +- OW2 +- AH2 +- ZH +- AO2 +- IY2 +- AE0 +- UW2 +- AY0 +- AA0 +- AO0 +- AW2 +- EY0 +- UH2 +- ER2 +- OY2 +- UH0 +- AW0 +- OY0 +- diff --git a/examples/vctk/ernie_sat/local/preprocess.sh b/examples/vctk/ernie_sat/local/preprocess.sh new file mode 100755 index 000000000..a0a3881f0 --- /dev/null +++ b/examples/vctk/ernie_sat/local/preprocess.sh @@ -0,0 +1,61 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./vctk_alignment \ + --output durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # extract features + echo "Extract features ..." 
diff --git a/examples/vctk/ernie_sat/local/preprocess.sh b/examples/vctk/ernie_sat/local/preprocess.sh
new file mode 100755
index 000000000..a0a3881f0
--- /dev/null
+++ b/examples/vctk/ernie_sat/local/preprocess.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+
+stage=0
+stop_stage=100
+
+config_path=$1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # get durations from MFA's result
+    echo "Generate durations.txt from MFA results ..."
+    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
+        --inputdir=./vctk_alignment \
+        --output durations.txt \
+        --config=${config_path}
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # extract features
+    echo "Extract features ..."
+    python3 ${BIN_DIR}/preprocess.py \
+        --dataset=vctk \
+        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
+        --dumpdir=dump \
+        --dur-file=durations.txt \
+        --config=${config_path} \
+        --num-cpu=20 \
+        --cut-sil=True
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # get features' stats (mean and std)
+    echo "Get features' stats ..."
+    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --field-name="speech"
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    # normalize and convert phone/speaker to id; dev and test should use train's stats
+    echo "Normalize ..."
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/train/raw/metadata.jsonl \
+        --dumpdir=dump/train/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/dev/raw/metadata.jsonl \
+        --dumpdir=dump/dev/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+
+    python3 ${BIN_DIR}/normalize.py \
+        --metadata=dump/test/raw/metadata.jsonl \
+        --dumpdir=dump/test/norm \
+        --speech-stats=dump/train/speech_stats.npy \
+        --phones-dict=dump/phone_id_map.txt \
+        --speaker-dict=dump/speaker_id_map.txt
+fi
diff --git a/examples/vctk/ernie_sat/local/synthesize.sh b/examples/vctk/ernie_sat/local/synthesize.sh
new file mode 100755
index 000000000..b24db018a
--- /dev/null
+++ b/examples/vctk/ernie_sat/local/synthesize.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+ckpt_name=$3
+
+stage=1
+stop_stage=1
+
+# use am to predict duration here
+# add am_phones_dict, am_tones_dict, etc.; the am could also be built in the new way,
+# which would no longer need this many parameters
+
+# pwgan
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+        --erniesat_config=${config_path} \
+        --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --erniesat_stat=dump/train/speech_stats.npy \
+        --voc=pwgan_vctk \
+        --voc_config=pwg_vctk_ckpt_0.1.1/default.yaml \
+        --voc_ckpt=pwg_vctk_ckpt_0.1.1/snapshot_iter_1500000.pdz \
+        --voc_stat=pwg_vctk_ckpt_0.1.1/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt
+fi
+
+# hifigan
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    FLAGS_allocator_strategy=naive_best_fit \
+    FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+    python3 ${BIN_DIR}/synthesize.py \
+        --erniesat_config=${config_path} \
+        --erniesat_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
+        --erniesat_stat=dump/train/speech_stats.npy \
+        --voc=hifigan_vctk \
+        --voc_config=hifigan_vctk_ckpt_0.2.0/default.yaml \
+        --voc_ckpt=hifigan_vctk_ckpt_0.2.0/snapshot_iter_2500000.pdz \
+        --voc_stat=hifigan_vctk_ckpt_0.2.0/feats_stats.npy \
+        --test_metadata=dump/test/norm/metadata.jsonl \
+        --output_dir=${train_output_path}/test \
+        --phones_dict=dump/phone_id_map.txt
+fi
diff --git a/examples/vctk/ernie_sat/local/train.sh b/examples/vctk/ernie_sat/local/train.sh
new file mode 100755
index 000000000..30720e8f5
--- /dev/null
+++ b/examples/vctk/ernie_sat/local/train.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+config_path=$1
+train_output_path=$2
+
+python3 ${BIN_DIR}/train.py \
+    --train-metadata=dump/train/norm/metadata.jsonl \
+    --dev-metadata=dump/dev/norm/metadata.jsonl \
+    --config=${config_path} \
+    --output-dir=${train_output_path} \
+    --ngpu=2 \
+    --phones-dict=dump/phone_id_map.txt
\ No newline at end of file
diff --git a/examples/vctk/ernie_sat/path.sh b/examples/vctk/ernie_sat/path.sh
new file mode 100755
index 000000000..4ecab0251
--- /dev/null
+++ b/examples/vctk/ernie_sat/path.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+export MAIN_ROOT=`realpath ${PWD}/../../../`
+
+export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
+export LC_ALL=C
+
+export PYTHONDONTWRITEBYTECODE=1
+# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
+
+MODEL=ernie_sat
+export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
\ No newline at end of file
diff --git a/examples/vctk/ernie_sat/run.sh b/examples/vctk/ernie_sat/run.sh
new file mode 100755
index 000000000..d75a19f23
--- /dev/null
+++ b/examples/vctk/ernie_sat/run.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+set -e
+source path.sh
+
+gpus=0,1
+stage=0
+stop_stage=100
+
+conf_path=conf/default.yaml
+train_output_path=exp/default
+ckpt_name=snapshot_iter_153.pdz
+
+# with the following command, you can choose the stage range you want to run
+# such as `./run.sh --stage 0 --stop-stage 0`
+# this cannot be mixed with `$1`, `$2` ...
+source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    # prepare data
+    ./local/preprocess.sh ${conf_path} || exit -1
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    # train the model; all ckpts are saved under `train_output_path/checkpoints/`
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    # synthesize; the vocoder (pwgan or hifigan) is selected by the stage in local/synthesize.sh
+    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
+fi
diff --git a/examples/vctk/tts3/conf/default.yaml b/examples/vctk/tts3/conf/default.yaml
index 1bca9107b..a75658d3d 100644
--- a/examples/vctk/tts3/conf/default.yaml
+++ b/examples/vctk/tts3/conf/default.yaml
@@ -24,7 +24,7 @@ f0max: 400               # Maximum f0 for pitch extraction.
 ###########################################################
 #                       DATA SETTING                      #
 ###########################################################
 batch_size: 64
-num_workers: 4
+num_workers: 2
 
 
 ###########################################################
@@ -88,8 +88,8 @@ updater:
 #                     OPTIMIZER SETTING                   #
 ###########################################################
 optimizer:
-  optim: adam            # optimizer type
-  learning_rate: 0.001   # learning rate
+  optim: adam              # optimizer type
+  learning_rate: 0.001     # learning rate
 
 ###########################################################
 #                      TRAINING SETTING                   #
diff --git a/paddlespeech/t2s/datasets/am_batch_fn.py b/paddlespeech/t2s/datasets/am_batch_fn.py
index 1c70b1cdc..2cb7a11a2 100644
--- a/paddlespeech/t2s/datasets/am_batch_fn.py
+++ b/paddlespeech/t2s/datasets/am_batch_fn.py
@@ -28,6 +28,150 @@ from paddlespeech.t2s.modules.nets_utils import phones_masking
 from paddlespeech.t2s.modules.nets_utils import phones_text_masking
 
 
+# an extra builder is needed here because parameters have to be passed in
+def build_erniesat_collate_fn(mlm_prob: float=0.8,
+                              mean_phn_span: int=8,
+                              seg_emb: bool=False,
+                              text_masking: bool=False):
+
+    return ErnieSATCollateFn(
+        mlm_prob=mlm_prob,
+        mean_phn_span=mean_phn_span,
+        seg_emb=seg_emb,
+        text_masking=text_masking)
+
+
+class ErnieSATCollateFn:
+    """Functor class of common_collate_fn()"""
+
+    def __init__(self,
+                 mlm_prob: float=0.8,
+                 mean_phn_span: int=8,
+                 seg_emb: bool=False,
+                 text_masking: bool=False):
+        self.mlm_prob = mlm_prob
+        self.mean_phn_span = mean_phn_span
+        self.seg_emb = seg_emb
+        self.text_masking = text_masking
+
+    def __call__(self, examples):
+        return erniesat_batch_fn(
+            examples,
+            mlm_prob=self.mlm_prob,
+            mean_phn_span=self.mean_phn_span,
+            seg_emb=self.seg_emb,
+            text_masking=self.text_masking)
+
+
+def erniesat_batch_fn(examples,
+                      mlm_prob: float=0.8,
+                      mean_phn_span: int=8,
+                      seg_emb: bool=False,
+                      text_masking: bool=False):
+    # fields = ["text", "text_lengths", "speech", "speech_lengths", "align_start", "align_end"]
+    text = [np.array(item["text"], dtype=np.int64) for item in examples]
+    speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
+
+    text_lengths = [
+        np.array(item["text_lengths"], dtype=np.int64) for item in examples
+    ]
+    speech_lengths = [
+        np.array(item["speech_lengths"], dtype=np.int64) for item in examples
+    ]
+
+    align_start = [
+        np.array(item["align_start"], dtype=np.int64) for item in examples
+    ]
+
+    align_end = [
+        np.array(item["align_end"], dtype=np.int64) for item in examples
+    ]
+
+    align_start_lengths = [
+        np.array(len(item["align_start"]), dtype=np.int64) for item in examples
+    ]
+
+    # add_pad
+    text = batch_sequences(text)
+    speech = batch_sequences(speech)
+    align_start = batch_sequences(align_start)
+    align_end = batch_sequences(align_end)
+
+    # convert each batch to paddle.Tensor
+    text = paddle.to_tensor(text)
+    speech = paddle.to_tensor(speech)
+    text_lengths = paddle.to_tensor(text_lengths)
+    speech_lengths = paddle.to_tensor(speech_lengths)
+    align_start_lengths = paddle.to_tensor(align_start_lengths)
+
+    speech_pad = speech
+    text_pad = text
+
+    text_mask = make_non_pad_mask(
+        text_lengths, text_pad, length_dim=1).unsqueeze(-2)
+    speech_mask = make_non_pad_mask(
+        speech_lengths, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
+
+    # for training
+    span_bdy = None
+    # for inference
+    if 'span_bdy' in examples[0].keys():
+        span_bdy = [
+            np.array(item["span_bdy"], dtype=np.int64) for item in examples
+        ]
+        span_bdy = paddle.to_tensor(span_bdy)
+
+    # dual_mask: for mixed Chinese-English input, speech and text are masked at the same time
+    # ERNIE-SAT masks both when doing the cross-lingual task
+    if text_masking:
+        masked_pos, text_masked_pos = phones_text_masking(
+            xs_pad=speech_pad,
+            src_mask=speech_mask,
+            text_pad=text_pad,
+            text_mask=text_mask,
+            align_start=align_start,
+            align_end=align_end,
+            align_start_lens=align_start_lengths,
+            mlm_prob=mlm_prob,
+            mean_phn_span=mean_phn_span,
+            span_bdy=span_bdy)
+    # training pure-Chinese or pure-English models -> A3T does not mask phonemes, it only masks speech
+    # the main difference between A3T and ERNIE-SAT lies in how the masking is done
+    else:
+        masked_pos = phones_masking(
+            xs_pad=speech_pad,
+            src_mask=speech_mask,
+            align_start=align_start,
+            align_end=align_end,
+            align_start_lens=align_start_lengths,
+            mlm_prob=mlm_prob,
+            mean_phn_span=mean_phn_span,
+            span_bdy=span_bdy)
+        text_masked_pos = paddle.zeros(paddle.shape(text_pad))
+
+    speech_seg_pos, text_seg_pos = get_seg_pos(
+        speech_pad=speech_pad,
+        text_pad=text_pad,
+        align_start=align_start,
+        align_end=align_end,
+        align_start_lens=align_start_lengths,
+        seg_emb=seg_emb)
+
+    batch = {
+        "text": text,
+        "speech": speech,
+        # need to generate
+        "masked_pos": masked_pos,
+        "speech_mask": speech_mask,
+        "text_mask": text_mask,
+        "speech_seg_pos": speech_seg_pos,
+        "text_seg_pos": text_seg_pos,
+        "text_masked_pos": text_masked_pos
+    }
+
+    return batch
+
+
 def tacotron2_single_spk_batch_fn(examples):
     # fields = ["text", "text_lengths", "speech", "speech_lengths"]
     text = [np.array(item["text"], dtype=np.int64) for item in examples]
@@ -378,7 +522,6 @@ class MLMCollateFn:
             mean_phn_span=self.mean_phn_span,
             seg_emb=self.seg_emb,
             text_masking=self.text_masking,
-            attention_window=self.attention_window,
             not_sequence=self.not_sequence)
 
 
@@ -389,7 +532,6 @@ def mlm_collate_fn(
         mean_phn_span: int=8,
         seg_emb: bool=False,
         text_masking: bool=False,
-        attention_window: int=0,
         pad_value: int=0,
         not_sequence: Collection[str]=(), ) -> Tuple[List[str], Dict[str, paddle.Tensor]]:
@@ -420,6 +562,7 @@ def mlm_collate_fn(
     feats = feats_extract.get_log_mel_fbank(np.array(output["speech"][0]))
     feats = paddle.to_tensor(feats)
+    print("feats.shape:", feats.shape)
     feats_lens = paddle.shape(feats)[0]
     feats = paddle.unsqueeze(feats, 0)
@@ -439,6 +582,7 @@ def mlm_collate_fn(
         text_lens, text_pad, length_dim=1).unsqueeze(-2)
     speech_mask = make_non_pad_mask(
         feats_lens, speech_pad[:, :, 0], length_dim=1).unsqueeze(-2)
+    span_bdy = None
     if 'span_bdy' in output.keys():
         span_bdy = output['span_bdy']
diff --git a/paddlespeech/t2s/exps/ernie_sat/__init__.py b/paddlespeech/t2s/exps/ernie_sat/__init__.py
new file mode 100644
index 000000000..e69de29bb
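A minimal usage sketch of the new collate function: it takes raw per-utterance fields and returns the padded, masked batch the model consumes. The shapes and values below are made up for illustration, and paddlespeech must be importable:

```python
import numpy as np
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn

collate_fn = build_erniesat_collate_fn(
    mlm_prob=0.8, mean_phn_span=8, seg_emb=True, text_masking=False)

# one toy utterance: 4 phones aligned to 20 mel frames of 80 bins
example = {
    "text": np.array([3, 17, 5, 9], dtype=np.int64),  # phone ids
    "text_lengths": 4,
    "speech": np.random.randn(20, 80).astype(np.float32),
    "speech_lengths": 20,
    "align_start": np.array([0, 5, 10, 15]),  # first frame of each phone
    "align_end": np.array([5, 10, 15, 20]),   # end frame (exclusive)
}
batch = collate_fn([example])
# masked_pos marks the speech frames the model has to reconstruct
print({k: tuple(v.shape) for k, v in batch.items()})
```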
diff --git a/paddlespeech/t2s/exps/ernie_sat/align.py b/paddlespeech/t2s/exps/ernie_sat/align.py
new file mode 100755
index 000000000..529a8221c
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/align.py
@@ -0,0 +1,386 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+from pathlib import Path
+
+import librosa
+import numpy as np
+import pypinyin
+from praatio import textgrid
+
+from paddlespeech.t2s.exps.ernie_sat.utils import get_dict
+from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
+
+DICT_EN = 'tools/aligner/cmudict-0.7b'
+DICT_ZH = 'tools/aligner/simple.lexicon'
+MODEL_DIR_EN = 'tools/aligner/vctk_model.zip'
+MODEL_DIR_ZH = 'tools/aligner/aishell3_model.zip'
+MFA_PATH = 'tools/montreal-forced-aligner/bin'
+os.environ['PATH'] = MFA_PATH + '/:' + os.environ['PATH']
+
+
+def _get_max_idx(dic):
+    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]
+
+
+def _readtg(tg_path: str, lang: str='en', fs: int=24000, n_shift: int=300):
+    alignment = textgrid.openTextgrid(tg_path, includeEmptyIntervals=True)
+    phones = []
+    ends = []
+    words = []
+
+    for interval in alignment.tierDict['words'].entryList:
+        word = interval.label
+        if word:
+            words.append(word)
+    for interval in alignment.tierDict['phones'].entryList:
+        phone = interval.label
+        phones.append(phone)
+        ends.append(interval.end)
+    frame_pos = librosa.time_to_frames(ends, sr=fs, hop_length=n_shift)
+    durations = np.diff(frame_pos, prepend=0)
+    assert len(durations) == len(phones)
+    # merge '' and sp at the end
+    if phones[-1] == '' and len(phones) > 1 and phones[-2] == 'sp':
+        phones = phones[:-1]
+        durations[-2] += durations[-1]
+        durations = durations[:-1]
+
+    # replace '' and 'sil' with 'sp'
+    phones = ['sp' if (phn == '' or phn == 'sil') else phn for phn in phones]
+
+    if lang == 'en':
+        DICT = DICT_EN
+
+    elif lang == 'zh':
+        DICT = DICT_ZH
+
+    word2phns_dict = get_dict(DICT)
+
+    phn2word_dict = []
+    for word in words:
+        if lang == 'en':
+            word = word.upper()
+        phn2word_dict.append([word2phns_dict[word].split(), word])
+
+    non_sp_idx = 0
+    word_idx = 0
+    i = 0
+    word2phns = {}
+    while i < len(phones):
+        phn = phones[i]
+        if phn == 'sp':
+            word2phns[str(word_idx) + '_sp'] = ['sp']
+            i += 1
+        else:
+            phns, word = phn2word_dict[non_sp_idx]
+            word2phns[str(word_idx) + '_' + word] = phns
+            non_sp_idx += 1
+            i += len(phns)
+        word_idx += 1
+    sum_phn = sum(len(word2phns[k]) for k in word2phns)
+    assert sum_phn == len(phones)
+
+    results = ''
+    for (p, d) in zip(phones, durations):
+        results += p + ' ' + str(d) + ' '
+    return results.strip(), word2phns
+
+
+def alignment(wav_path: str,
+              text: str,
+              fs: int=24000,
+              lang='en',
+              n_shift: int=300):
+    wav_name = os.path.basename(wav_path)
+    utt = wav_name.split('.')[0]
+    # prepare data for MFA
+    tmp_name = get_tmp_name(text=text)
+    tmpbase = './tmp_dir/' + tmp_name
+    tmpbase = Path(tmpbase)
+    tmpbase.mkdir(parents=True, exist_ok=True)
+    print("tmp_name in alignment:", tmp_name)
+
+    shutil.copyfile(wav_path, tmpbase / wav_name)
+    txt_name = utt + '.txt'
+    txt_path = tmpbase / txt_name
+    with open(txt_path, 'w') as wf:
+        wf.write(text + '\n')
+    # MFA
+    if lang == 'en':
+        DICT = DICT_EN
+        MODEL_DIR = MODEL_DIR_EN
+
+    elif lang == 'zh':
+        DICT = DICT_ZH
+        MODEL_DIR = MODEL_DIR_ZH
+    else:
+        print('please input the right lang!!')
+
+    CMD = 'mfa_align' + ' ' + str(
+        tmpbase) + ' ' + DICT + ' ' + MODEL_DIR + ' ' + str(tmpbase)
+    os.system(CMD)
+    tg_path = str(tmpbase) + '/' + tmp_name + '/' + utt + '.TextGrid'
+    phn_dur, word2phns = _readtg(tg_path, lang=lang)
+    phn_dur = phn_dur.split()
+    phns = phn_dur[::2]
+    durs = phn_dur[1::2]
+    durs = [int(d) for d in durs]
+    assert len(phns) == len(durs)
+    return phns, durs, word2phns
+
+
+def words2phns(text: str, lang='en'):
+    '''
+    Args:
+        text (str):
+            input text.
+            eg: for that reason cover is impossible to be given.
+        lang (str):
+            'en' or 'zh'
+    Returns:
+        List[str]: phones of input text.
+            eg:
+            ['F', 'AO1', 'R', 'DH', 'AE1', 'T', 'R', 'IY1', 'Z', 'AH0', 'N', 'K', 'AH1', 'V', 'ER0',
+            'IH1', 'Z', 'IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L', 'T', 'UW1', 'B', 'IY1',
+            'G', 'IH1', 'V', 'AH0', 'N']
+        Dict(str, str): key - idx_word
+                        value - phones
+            eg:
+            {'0_FOR': ['F', 'AO1', 'R'], '1_THAT': ['DH', 'AE1', 'T'],
+            '2_REASON': ['R', 'IY1', 'Z', 'AH0', 'N'], '3_COVER': ['K', 'AH1', 'V', 'ER0'], '4_IS': ['IH1', 'Z'],
+            '5_IMPOSSIBLE': ['IH2', 'M', 'P', 'AA1', 'S', 'AH0', 'B', 'AH0', 'L'],
+            '6_TO': ['T', 'UW1'], '7_BE': ['B', 'IY1'], '8_GIVEN': ['G', 'IH1', 'V', 'AH0', 'N']}
+    '''
+    text = text.strip()
+    words = []
+    for pun in [
+            ',', '.', ':', ';', '!', '?', '"', '(', ')', '--', '---', u',',
+            u'。', u':', u';', u'!', u'?', u'(', u')'
+    ]:
+        text = text.replace(pun, ' ')
+    for wrd in text.split():
+        if (wrd[-1] == '-'):
+            wrd = wrd[:-1]
+        if (wrd[0] == "'"):
+            wrd = wrd[1:]
+        if wrd:
+            words.append(wrd)
+    if lang == 'en':
+        dictfile = DICT_EN
+    elif lang == 'zh':
+        dictfile = DICT_ZH
+    else:
+        print('please input the right lang!!')
+
+    word2phns_dict = get_dict(dictfile)
+    ds = word2phns_dict.keys()
+    phns = []
+    wrd2phns = {}
+    for index, wrd in enumerate(words):
+        if lang == 'en':
+            wrd = wrd.upper()
+        if (wrd not in ds):
+            wrd2phns[str(index) + '_' + wrd] = 'spn'
+            # 'spn' is a single phone token, so append it as one element
+            # instead of extending character-by-character
+            phns.append('spn')
+        else:
+            wrd2phns[str(index) + '_' + wrd] = word2phns_dict[wrd].split()
+            phns.extend(word2phns_dict[wrd].split())
+    return phns, wrd2phns
+
+
+def get_phns_spans(wav_path: str,
+                   old_str: str='',
+                   new_str: str='',
+                   source_lang: str='en',
+                   target_lang: str='en',
+                   fs: int=24000,
+                   n_shift: int=300):
+    is_append = (old_str == new_str[:len(old_str)])
+    old_phns, mfa_start, mfa_end = [], [], []
+    # source
+    lang = source_lang
+    phn, dur, w2p = alignment(
+        wav_path=wav_path, text=old_str, lang=lang, fs=fs, n_shift=n_shift)
+
+    new_d_cumsum = np.pad(np.array(dur).cumsum(0), (1, 0), 'constant').tolist()
+    mfa_start = new_d_cumsum[:-1]
+    mfa_end = new_d_cumsum[1:]
+    old_phns = phn
+
+    # target
+    if is_append and (source_lang != target_lang):
+        cross_lingual_clone = True
+    else:
+        cross_lingual_clone = False
+
+    if cross_lingual_clone:
+        str_origin = new_str[:len(old_str)]
+        str_append = new_str[len(old_str):]
+
+        if target_lang == 'zh':
+            phns_origin, origin_w2p = words2phns(str_origin, lang='en')
+            phns_append, append_w2p_tmp = words2phns(str_append, lang='zh')
+        elif target_lang == 'en':
+            # the original sentence
+            phns_origin, origin_w2p = words2phns(str_origin, lang='zh')
+            # the sentence to clone
+            phns_append, append_w2p_tmp = words2phns(str_append, lang='en')
+        else:
+            assert target_lang == 'zh' or target_lang == 'en', \
+                'cloning is not supported for this language, please check it.'
+
+        new_phns = phns_origin + phns_append
+
+        append_w2p = {}
+        length = len(origin_w2p)
+        for key, value in append_w2p_tmp.items():
+            idx, wrd = key.split('_')
+            append_w2p[str(int(idx) + length) + '_' + wrd] = value
+        new_w2p = origin_w2p.copy()
+        new_w2p.update(append_w2p)
+
+    else:
+        if source_lang == target_lang:
+            new_phns, new_w2p = words2phns(new_str, lang=source_lang)
+        else:
+            assert source_lang == target_lang, \
+                'the source language is not the same as the target language...'
+
+    span_to_repl = [0, len(old_phns) - 1]
+    span_to_add = [0, len(new_phns) - 1]
+    left_idx = 0
+    new_phns_left = []
+    sp_count = 0
+    # find the left different index
+    # words2phns on the alignment result may differ from words2phns on the
+    # raw text: the former may contain sp
+    for key in w2p.keys():
+        idx, wrd = key.split('_')
+        if wrd == 'sp':
+            sp_count += 1
+            new_phns_left.append('sp')
+        else:
+            idx = str(int(idx) - sp_count)
+            if idx + '_' + wrd in new_w2p:
+                # index into the phone sequence of new_str
+                left_idx += len(new_w2p[idx + '_' + wrd])
+                # the old phone sequence
+                new_phns_left.extend(w2p[key])
+            else:
+                span_to_repl[0] = len(new_phns_left)
+                span_to_add[0] = len(new_phns_left)
+                break
+
+    # scan w2p and new_w2p from the right
+    right_idx = 0
+    new_phns_right = []
+    sp_count = 0
+    w2p_max_idx = _get_max_idx(w2p)
+    new_w2p_max_idx = _get_max_idx(new_w2p)
+    new_phns_mid = []
+    if is_append:
+        new_phns_right = []
+        new_phns_mid = new_phns[left_idx:]
+        span_to_repl[0] = len(new_phns_left)
+        span_to_add[0] = len(new_phns_left)
+        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
+        span_to_repl[1] = len(old_phns) - len(new_phns_right)
+    # speech edit
+    else:
+        for key in list(w2p.keys())[::-1]:
+            idx, wrd = key.split('_')
+            if wrd == 'sp':
+                sp_count += 1
+                new_phns_right = ['sp'] + new_phns_right
+            else:
+                idx = str(new_w2p_max_idx - (w2p_max_idx - int(idx) - sp_count))
+                if idx + '_' + wrd in new_w2p:
+                    right_idx -= len(new_w2p[idx + '_' + wrd])
+                    new_phns_right = w2p[key] + new_phns_right
+                else:
+                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
+                    new_phns_mid = new_phns[left_idx:right_idx]
+                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
+                    if len(new_phns_mid) == 0:
+                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
+                        span_to_add[0] = max(0, span_to_add[0] - 1)
+                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
+                        span_to_repl[1] = min(span_to_repl[1] + 1,
+                                              len(old_phns))
+                    break
+    new_phns = new_phns_left + new_phns_mid + new_phns_right
+    '''
+    For that reason cover should not be given.
+    For that reason cover is impossible to be given.
+    span_to_repl: [17, 23] "should not"
+    span_to_add: [17, 30] "is impossible to"
+    '''
+    outs = {}
+    outs['mfa_start'] = mfa_start
+    outs['mfa_end'] = mfa_end
+    outs['old_phns'] = old_phns
+    outs['new_phns'] = new_phns
+    outs['span_to_repl'] = span_to_repl
+    outs['span_to_add'] = span_to_add
+
+    return outs
+
+
+if __name__ == '__main__':
+    text = "For that reason cover should not be given."
+    phn, dur, word2phns = alignment("exp/p243_313.wav", text, lang='en')
+    print(phn, dur)
+    print(word2phns)
+    print("---------------------------------")
+    # here the pinyin sequence can be obtained with our Chinese frontend
+    text_zh = "卡尔普陪外孙玩滑梯。"
+    text_zh = pypinyin.lazy_pinyin(
+        text_zh,
+        neutral_tone_with_five=True,
+        style=pypinyin.Style.TONE3,
+        tone_sandhi=True)
+    text_zh = " ".join(text_zh)
+    phn, dur, word2phns = alignment("exp/000001.wav", text_zh, lang='zh')
+    print(phn, dur)
+    print(word2phns)
+    print("---------------------------------")
+
+    phns, wrd2phns = words2phns(text, lang='en')
+    print("phns:", phns)
+    print("wrd2phns:", wrd2phns)
+    print("---------------------------------")
+
+    phns, wrd2phns = words2phns(text_zh, lang='zh')
+    print("phns:", phns)
+    print("wrd2phns:", wrd2phns)
+    print("---------------------------------")
+
+    outs = get_phns_spans(
+        wav_path="exp/p243_313.wav",
+        old_str="For that reason cover should not be given.",
+        new_str="for that reason cover is impossible to be given.")
+
+    mfa_start = outs["mfa_start"]
+    mfa_end = outs["mfa_end"]
+    old_phns = outs["old_phns"]
+    new_phns = outs["new_phns"]
+    span_to_repl = outs["span_to_repl"]
+    span_to_add = outs["span_to_add"]
+    print("mfa_start:", mfa_start)
+    print("mfa_end:", mfa_end)
+    print("old_phns:", old_phns)
+    print("new_phns:", new_phns)
+    print("span_to_repl:", span_to_repl)
+    print("span_to_add:", span_to_add)
+    print("---------------------------------")
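`alignment()` and `words2phns()` both rely on `get_dict()` (added in utils.py later in this patch) to parse the MFA lexicons pointed to by `DICT_EN`/`DICT_ZH`. A sketch of the expected one-word-per-line format; the two entries are illustrative stand-ins, not copied from the real cmudict-0.7b/simple.lexicon files:

```python
lexicon = """\
REASON R IY1 Z AH0 N
GIVEN G IH1 V AH0 N
"""
word2phns_dict = {}
for line in lexicon.splitlines():
    word, *phns = line.split()  # first column: word, remaining columns: phones
    word2phns_dict.setdefault(word, " ".join(phns))

print(word2phns_dict["GIVEN"].split())  # ['G', 'IH1', 'V', 'AH0', 'N']
```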
" + "you need to specify either *-scp or rootdir.") + + parser.add_argument( + "--dumpdir", + type=str, + required=True, + help="directory to dump normalized feature files.") + parser.add_argument( + "--speech-stats", + type=str, + required=True, + help="speech statistics file.") + parser.add_argument( + "--phones-dict", type=str, default=None, help="phone vocabulary file.") + parser.add_argument( + "--speaker-dict", type=str, default=None, help="speaker id map file.") + + args = parser.parse_args() + + dumpdir = Path(args.dumpdir).expanduser() + # use absolute path + dumpdir = dumpdir.resolve() + dumpdir.mkdir(parents=True, exist_ok=True) + + # get dataset + with jsonlines.open(args.metadata, 'r') as reader: + metadata = list(reader) + dataset = DataTable( + metadata, converters={ + "speech": np.load, + }) + logging.info(f"The number of files = {len(dataset)}.") + + # restore scaler + speech_scaler = StandardScaler() + speech_scaler.mean_ = np.load(args.speech_stats)[0] + speech_scaler.scale_ = np.load(args.speech_stats)[1] + speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0] + + vocab_phones = {} + with open(args.phones_dict, 'rt') as f: + phn_id = [line.strip().split() for line in f.readlines()] + for phn, id in phn_id: + vocab_phones[phn] = int(id) + + vocab_speaker = {} + with open(args.speaker_dict, 'rt') as f: + spk_id = [line.strip().split() for line in f.readlines()] + for spk, id in spk_id: + vocab_speaker[spk] = int(id) + + # process each file + output_metadata = [] + + for item in tqdm(dataset): + utt_id = item['utt_id'] + speech = item['speech'] + + # normalize + speech = speech_scaler.transform(speech) + speech_dir = dumpdir / "data_speech" + speech_dir.mkdir(parents=True, exist_ok=True) + speech_path = speech_dir / f"{utt_id}_speech.npy" + np.save(speech_path, speech.astype(np.float32), allow_pickle=False) + + phone_ids = [vocab_phones[p] for p in item['phones']] + spk_id = vocab_speaker[item["speaker"]] + record = { + "utt_id": item['utt_id'], + "spk_id": spk_id, + "text": phone_ids, + "text_lengths": item['text_lengths'], + "speech_lengths": item['speech_lengths'], + "durations": item['durations'], + "speech": str(speech_path), + "align_start": item['align_start'], + "align_end": item['align_end'], + } + # add spk_emb for voice cloning + if "spk_emb" in item: + record["spk_emb"] = str(item["spk_emb"]) + + output_metadata.append(record) + output_metadata.sort(key=itemgetter('utt_id')) + output_metadata_path = Path(args.dumpdir) / "metadata.jsonl" + with jsonlines.open(output_metadata_path, 'w') as writer: + for item in output_metadata: + writer.write(item) + logging.info(f"metadata dumped into {output_metadata_path}") + + +if __name__ == "__main__": + main() diff --git a/paddlespeech/t2s/exps/ernie_sat/preprocess.py b/paddlespeech/t2s/exps/ernie_sat/preprocess.py new file mode 100644 index 000000000..fc9e0888b --- /dev/null +++ b/paddlespeech/t2s/exps/ernie_sat/preprocess.py @@ -0,0 +1,342 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/paddlespeech/t2s/exps/ernie_sat/preprocess.py b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
new file mode 100644
index 000000000..fc9e0888b
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/preprocess.py
@@ -0,0 +1,342 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import os
+from concurrent.futures import ThreadPoolExecutor
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
+from typing import Dict
+from typing import List
+
+import jsonlines
+import librosa
+import numpy as np
+import tqdm
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
+from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
+from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
+from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
+from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
+from paddlespeech.t2s.utils import str2bool
+
+
+def process_sentence(config: Dict[str, Any],
+                     fp: Path,
+                     sentences: Dict,
+                     output_dir: Path,
+                     mel_extractor=None,
+                     cut_sil: bool=True,
+                     spk_emb_dir: Path=None):
+    utt_id = fp.stem
+    # for vctk
+    if utt_id.endswith("_mic2"):
+        utt_id = utt_id[:-5]
+    record = None
+    if utt_id in sentences:
+        # reading; resampling may occur
+        wav, _ = librosa.load(str(fp), sr=config.fs)
+        if len(wav.shape) != 1:
+            return record
+        max_value = np.abs(wav).max()
+        if max_value > 1.0:
+            wav = wav / max_value
+        assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
+        assert np.abs(wav).max(
+        ) <= 1.0, f"{utt_id} seems to be different from 16-bit PCM."
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        speaker = sentences[utt_id][2]
+        d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+
+        # slightly less precise than using *.TextGrid directly
+        times = librosa.frames_to_time(
+            d_cumsum, sr=config.fs, hop_length=config.n_shift)
+        if cut_sil:
+            start = 0
+            end = d_cumsum[-1]
+            if phones[0] == "sil" and len(durations) > 1:
+                start = times[1]
+                durations = durations[1:]
+                phones = phones[1:]
+            if phones[-1] == 'sil' and len(durations) > 1:
+                end = times[-2]
+                durations = durations[:-1]
+                phones = phones[:-1]
+            sentences[utt_id][0] = phones
+            sentences[utt_id][1] = durations
+            start, end = librosa.time_to_samples([start, end], sr=config.fs)
+            wav = wav[start:end]
+
+        # extract mel feats
+        logmel = mel_extractor.get_log_mel_fbank(wav)
+        # change duration according to mel_length
+        compare_duration_and_mel_length(sentences, utt_id, logmel)
+        # utt_id may be popped in compare_duration_and_mel_length
+        if utt_id not in sentences:
+            return None
+        phones = sentences[utt_id][0]
+        durations = sentences[utt_id][1]
+        num_frames = logmel.shape[0]
+        assert sum(durations) == num_frames
+
+        new_d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
+        align_start = new_d_cumsum[:-1]
+        align_end = new_d_cumsum[1:]
+        assert len(align_start) == len(align_end) == len(durations)
+
+        mel_dir = output_dir / "data_speech"
+        mel_dir.mkdir(parents=True, exist_ok=True)
+        mel_path = mel_dir / (utt_id + "_speech.npy")
+        np.save(mel_path, logmel)
+        # align_start_lengths == text_lengths
+        record = {
+            "utt_id": utt_id,
+            "phones": phones,
+            "text_lengths": len(phones),
+            "speech_lengths": num_frames,
+            "durations": durations,
+            "speech": str(mel_path),
+            "speaker": speaker,
+            "align_start": align_start.tolist(),
+            "align_end": align_end.tolist(),
+        }
+        if spk_emb_dir:
+            if speaker in os.listdir(spk_emb_dir):
+                embed_name = utt_id + ".npy"
+                embed_path = spk_emb_dir / speaker / embed_name
+                if embed_path.is_file():
+                    record["spk_emb"] = str(embed_path)
+                else:
+                    return None
+    return record
+
+
+def process_sentences(config,
+                      fps: List[Path],
+                      sentences: Dict,
+                      output_dir: Path,
+                      mel_extractor=None,
+                      nprocs: int=1,
+                      cut_sil: bool=True,
+                      spk_emb_dir: Path=None):
+    if nprocs == 1:
+        results = []
+        for fp in tqdm.tqdm(fps, total=len(fps)):
+            record = process_sentence(
+                config=config,
+                fp=fp,
+                sentences=sentences,
+                output_dir=output_dir,
+                mel_extractor=mel_extractor,
+                cut_sil=cut_sil,
+                spk_emb_dir=spk_emb_dir)
+            if record:
+                results.append(record)
+    else:
+        with ThreadPoolExecutor(nprocs) as pool:
+            futures = []
+            with tqdm.tqdm(total=len(fps)) as progress:
+                for fp in fps:
+                    future = pool.submit(process_sentence, config, fp,
+                                         sentences, output_dir, mel_extractor,
+                                         cut_sil, spk_emb_dir)
+                    future.add_done_callback(lambda p: progress.update())
+                    futures.append(future)
+
+                results = []
+                for ft in futures:
+                    record = ft.result()
+                    if record:
+                        results.append(record)
+
+    results.sort(key=itemgetter("utt_id"))
+    # open with 'a' (append) instead of 'w' so results are written to the end of the file
+    with jsonlines.open(output_dir / "metadata.jsonl", 'a') as writer:
+        for item in results:
+            writer.write(item)
+    print("Done")
+
+
+def main():
+    # parse config and args
+    parser = argparse.ArgumentParser(
+        description="Preprocess audio and then extract features.")
+
+    parser.add_argument(
+        "--dataset",
+        default="baker",
+        type=str,
+        help="name of dataset, should be in {baker, aishell3, ljspeech, vctk} now")
+
+    parser.add_argument(
+        "--rootdir", default=None, type=str, help="directory to dataset.")
+
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump feature files.")
+    parser.add_argument(
+        "--dur-file", default=None, type=str, help="path to durations.txt.")
+
+    parser.add_argument("--config", type=str, help="fastspeech2 config file.")
+
+    parser.add_argument(
+        "--num-cpu", type=int, default=1, help="number of process.")
+
+    parser.add_argument(
+        "--cut-sil",
+        type=str2bool,
+        default=True,
+        help="whether to cut sil at the edge of the audio")
+
+    parser.add_argument(
+        "--spk_emb_dir",
+        default=None,
+        type=str,
+        help="directory to speaker embedding files.")
+    args = parser.parse_args()
+
+    rootdir = Path(args.rootdir).expanduser()
+    dumpdir = Path(args.dumpdir).expanduser()
+    # use absolute path
+    dumpdir = dumpdir.resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+    dur_file = Path(args.dur_file).expanduser()
+
+    if args.spk_emb_dir:
+        spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
+    else:
+        spk_emb_dir = None
+
+    assert rootdir.is_dir()
+    assert dur_file.is_file()
+
+    with open(args.config, 'rt') as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    sentences, speaker_set = get_phn_dur(dur_file)
+
+    merge_silence(sentences)
+    phone_id_map_path = dumpdir / "phone_id_map.txt"
+    speaker_id_map_path = dumpdir / "speaker_id_map.txt"
+    get_input_token(sentences, phone_id_map_path, args.dataset)
+    get_spk_id_map(speaker_set, speaker_id_map_path)
+
+    if args.dataset == "baker":
+        wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
+        # split data into 3 sections
+        num_train = 9800
+        num_dev = 100
+        train_wav_files = wav_files[:num_train]
+        dev_wav_files = wav_files[num_train:num_train + num_dev]
+        test_wav_files = wav_files[num_train + num_dev:]
+    elif args.dataset == "aishell3":
+        sub_num_dev = 5
+        wav_dir = rootdir / "train" / "wav"
+        train_wav_files = []
+        dev_wav_files = []
+        test_wav_files = []
+        for speaker in os.listdir(wav_dir):
+            wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
+            if len(wav_files) > 100:
+                train_wav_files += wav_files[:-sub_num_dev * 2]
+                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+                test_wav_files += wav_files[-sub_num_dev:]
+            else:
+                train_wav_files += wav_files
+
+    elif args.dataset == "ljspeech":
+        wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
+        # split data into 3 sections
+        num_train = 12900
+        num_dev = 100
+        train_wav_files = wav_files[:num_train]
+        dev_wav_files = wav_files[num_train:num_train + num_dev]
+        test_wav_files = wav_files[num_train + num_dev:]
+    elif args.dataset == "vctk":
+        sub_num_dev = 5
+        wav_dir = rootdir / "wav48_silence_trimmed"
+        train_wav_files = []
+        dev_wav_files = []
+        test_wav_files = []
+        for speaker in os.listdir(wav_dir):
+            wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
+            if len(wav_files) > 100:
+                train_wav_files += wav_files[:-sub_num_dev * 2]
+                dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
+                test_wav_files += wav_files[-sub_num_dev:]
+            else:
+                train_wav_files += wav_files
+
+    else:
+        print("dataset should be in {baker, aishell3, ljspeech, vctk} now!")
+
+    train_dump_dir = dumpdir / "train" / "raw"
+    train_dump_dir.mkdir(parents=True, exist_ok=True)
+    dev_dump_dir = dumpdir / "dev" / "raw"
+    dev_dump_dir.mkdir(parents=True, exist_ok=True)
+    test_dump_dir = dumpdir / "test" / "raw"
+    test_dump_dir.mkdir(parents=True, exist_ok=True)
+
+    # Extractor
+    mel_extractor = LogMelFBank(
+        sr=config.fs,
+        n_fft=config.n_fft,
+        hop_length=config.n_shift,
+        win_length=config.win_length,
+        window=config.window,
+        n_mels=config.n_mels,
+        fmin=config.fmin,
+        fmax=config.fmax)
+
+    # process for the 3 sections
+    if train_wav_files:
+        process_sentences(
+            config=config,
+            fps=train_wav_files,
+            sentences=sentences,
+            output_dir=train_dump_dir,
+            mel_extractor=mel_extractor,
+            nprocs=args.num_cpu,
+            cut_sil=args.cut_sil,
+            spk_emb_dir=spk_emb_dir)
+    if dev_wav_files:
+        process_sentences(
+            config=config,
+            fps=dev_wav_files,
+            sentences=sentences,
+            output_dir=dev_dump_dir,
+            mel_extractor=mel_extractor,
+            cut_sil=args.cut_sil,
+            spk_emb_dir=spk_emb_dir)
+    if test_wav_files:
+        process_sentences(
+            config=config,
+            fps=test_wav_files,
+            sentences=sentences,
+            output_dir=test_dump_dir,
+            mel_extractor=mel_extractor,
+            nprocs=args.num_cpu,
+            cut_sil=args.cut_sil,
+            spk_emb_dir=spk_emb_dir)
+
+
+if __name__ == "__main__":
+    main()
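A tiny worked example of the alignment bookkeeping in `process_sentence()` above: frame-level durations become per-phone `[align_start, align_end)` frame indices via a padded cumulative sum. The duration values are illustrative:

```python
import numpy as np

durations = [3, 5, 2]  # frames per phone
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
align_start, align_end = d_cumsum[:-1], d_cumsum[1:]

print(align_start.tolist())  # [0, 3, 8]
print(align_end.tolist())    # [3, 8, 10]
assert align_end[-1] == sum(durations)  # == number of mel frames
```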
diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize.py b/paddlespeech/t2s/exps/ernie_sat/synthesize.py
new file mode 100644
index 000000000..2e3582948
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/synthesize.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.exps.syn_utils import denorm
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_test_dataset
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+
+
+def evaluate(args):
+    # the DataLoader logger is too verbose, disable it
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for evaluation
+    with jsonlines.open(args.test_metadata, 'r') as reader:
+        test_metadata = list(reader)
+
+    # Init body.
+    with open(args.erniesat_config) as f:
+        erniesat_config = CfgNode(yaml.safe_load(f))
+    with open(args.voc_config) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(erniesat_config)
+    print(voc_config)
+
+    # ERNIE-SAT model
+    erniesat_inference = get_am_inference(
+        am='erniesat_dataset',
+        am_config=erniesat_config,
+        am_ckpt=args.erniesat_ckpt,
+        am_stat=args.erniesat_stat,
+        phones_dict=args.phones_dict)
+
+    test_dataset = get_test_dataset(
+        test_metadata=test_metadata, am='erniesat_dataset')
+
+    # vocoder
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    collate_fn = build_erniesat_collate_fn(
+        mlm_prob=erniesat_config.mlm_prob,
+        mean_phn_span=erniesat_config.mean_phn_span,
+        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
+        text_masking=False)
+
+    gen_raw = True
+    erniesat_mu, erniesat_std = np.load(args.erniesat_stat)
+
+    for datum in test_dataset:
+        # collate function and dataloader
+        utt_id = datum["utt_id"]
+        speech_len = datum["speech_lengths"]
+
+        # mask the middle 1/3 of the speech
+        left_bdy, right_bdy = speech_len // 3, 2 * speech_len // 3
+        span_bdy = [left_bdy, right_bdy]
+        datum.update({"span_bdy": span_bdy})
+
+        batch = collate_fn([datum])
+        with paddle.no_grad():
+            out_mels = erniesat_inference(
+                speech=batch["speech"],
+                text=batch["text"],
+                masked_pos=batch["masked_pos"],
+                speech_mask=batch["speech_mask"],
+                text_mask=batch["text_mask"],
+                speech_seg_pos=batch["speech_seg_pos"],
+                text_seg_pos=batch["text_seg_pos"],
+                span_bdy=span_bdy)
+
+        # vocoder
+        wav_list = []
+        for mel in out_mels:
+            part_wav = voc_inference(mel)
+            wav_list.append(part_wav)
+        wav = paddle.concat(wav_list)
+        wav = wav.numpy()
+        if gen_raw:
+            speech = datum['speech']
+            denorm_mel = denorm(speech, erniesat_mu, erniesat_std)
+            denorm_mel = paddle.to_tensor(denorm_mel)
+            wav_raw = voc_inference(denorm_mel)
+            wav_raw = wav_raw.numpy()
+
+        sf.write(
+            str(output_dir / (utt_id + ".wav")),
+            wav,
+            samplerate=erniesat_config.fs)
+        if gen_raw:
+            sf.write(
+                str(output_dir / (utt_id + "_raw" + ".wav")),
+                wav_raw,
+                samplerate=erniesat_config.fs)
+
+        print(f"{utt_id} done!")
+
+
+def parse_args():
+    # parse args and config
+    parser = argparse.ArgumentParser(
+        description="Synthesize with acoustic model & vocoder")
+    # ERNIE-SAT
+    parser.add_argument(
+        '--erniesat_config',
+        type=str,
+        default=None,
+        help='Config of acoustic model.')
+    parser.add_argument(
+        '--erniesat_ckpt',
+        type=str,
+        default=None,
+        help='Checkpoint file of acoustic model.')
+    parser.add_argument(
+        "--erniesat_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training the acoustic model."
+    )
+    parser.add_argument(
+        "--phones_dict", type=str, default=None, help="phone vocabulary file.")
+    # vocoder
+    parser.add_argument(
+        '--voc',
+        type=str,
+        default='pwgan_csmsc',
+        choices=[
+            'pwgan_aishell3',
+            'pwgan_vctk',
+            'hifigan_aishell3',
+            'hifigan_vctk',
+        ],
+        help='Choose the vocoder type of the tts task.')
+    parser.add_argument(
+        '--voc_config', type=str, default=None, help='Config of voc.')
+    parser.add_argument(
+        '--voc_ckpt', type=str, default=None, help='Checkpoint file of voc.')
+    parser.add_argument(
+        "--voc_stat",
+        type=str,
+        default=None,
+        help="mean and standard deviation used to normalize spectrogram when training voc."
+    )
+    # other
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.")
+    parser.add_argument("--test_metadata", type=str, help="test metadata.")
+    parser.add_argument("--output_dir", type=str, help="output dir.")
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+    if args.ngpu == 0:
+        paddle.set_device("cpu")
+    elif args.ngpu > 0:
+        paddle.set_device("gpu")
+    else:
+        print("ngpu should be >= 0!")
+
+    evaluate(args)
+
+
+if __name__ == "__main__":
+    main()
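A worked example of the span selection in `evaluate()` above: the middle third of each utterance is masked and re-generated, and with the 300-sample hop at 24 kHz from the config, frames convert to seconds as `frames * n_shift / fs`. The frame count is illustrative:

```python
speech_len = 480  # mel frames of a toy utterance
left_bdy, right_bdy = speech_len // 3, 2 * speech_len // 3
span_bdy = [left_bdy, right_bdy]

masked_seconds = (right_bdy - left_bdy) * 300 / 24000
print(span_bdy, masked_seconds)  # [160, 320] 2.0
```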
diff --git a/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
new file mode 100644
index 000000000..95b07367c
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/synthesize_e2e.py
@@ -0,0 +1,346 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+from typing import List
+
+import librosa
+import numpy as np
+import paddle
+import soundfile as sf
+import yaml
+from paddle import nn
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.datasets.get_feats import LogMelFBank
+from paddlespeech.t2s.exps.ernie_sat.align import get_phns_spans
+from paddlespeech.t2s.exps.ernie_sat.utils import eval_durs
+from paddlespeech.t2s.exps.ernie_sat.utils import get_dur_adj_factor
+from paddlespeech.t2s.exps.ernie_sat.utils import get_span_bdy
+from paddlespeech.t2s.exps.ernie_sat.utils import get_tmp_name
+from paddlespeech.t2s.exps.syn_utils import get_frontend
+from paddlespeech.t2s.exps.syn_utils import norm
+
+
+def _p2id(phonemes: List[str]) -> np.ndarray:
+    # NOTE: reads the module-level vocab_phones built in __main__
+    # replace unk phone with sp
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phonemes]
+    phone_ids = [vocab_phones[item] for item in phonemes]
+    return np.array(phone_ids, np.int64)
+
+
+def prep_feats_with_dur(wav_path: str,
+                        old_str: str='',
+                        new_str: str='',
+                        source_lang: str='en',
+                        target_lang: str='en',
+                        duration_adjust: bool=True,
+                        fs: int=24000,
+                        n_shift: int=300):
+    '''
+    Returns:
+        np.ndarray: new wav, replace the part to be edited in original wav with 0
+        List[str]: new phones
+        List[float]: mfa start of new wav
+        List[float]: mfa end of new wav
+        List[int]: masked mel boundary of original wav
+        List[int]: masked mel boundary of new wav
+    '''
+    wav_org, _ = librosa.load(wav_path, sr=fs)
+    phns_spans_outs = get_phns_spans(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        source_lang=source_lang,
+        target_lang=target_lang,
+        fs=fs,
+        n_shift=n_shift)
+
+    mfa_start = phns_spans_outs["mfa_start"]
+    mfa_end = phns_spans_outs["mfa_end"]
+    old_phns = phns_spans_outs["old_phns"]
+    new_phns = phns_spans_outs["new_phns"]
+    span_to_repl = phns_spans_outs["span_to_repl"]
+    span_to_add = phns_spans_outs["span_to_add"]
+
+    # Chinese phones are not necessarily all in the fastspeech2 dictionary; replace them with sp
+    if target_lang in {'en', 'zh'}:
+        old_durs = eval_durs(old_phns, target_lang=source_lang)
+    else:
+        assert target_lang in {'en', 'zh'}, \
+            "calculating duration_predict is not supported for this language..."
+
+    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
+
+    if duration_adjust:
+        d_factor = get_dur_adj_factor(
+            orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
+        d_factor = d_factor * 1.25
+    else:
+        d_factor = 1
+
+    if target_lang in {'en', 'zh'}:
+        new_durs = eval_durs(new_phns, target_lang=target_lang)
+    else:
+        assert target_lang == "zh" or target_lang == "en", \
+            "calculating duration_predict is not supported for this language..."
+
+    # durations must be integers
+    new_durs_adjusted = [int(np.ceil(d_factor * i)) for i in new_durs]
+
+    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
+    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
+    dur_offset = new_span_dur_sum - old_span_dur_sum
+    new_mfa_start = mfa_start[:span_to_repl[0]]
+    new_mfa_end = mfa_end[:span_to_repl[0]]
+
+    for dur in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
+        if len(new_mfa_end) == 0:
+            new_mfa_start.append(0)
+            new_mfa_end.append(dur)
+        else:
+            new_mfa_start.append(new_mfa_end[-1])
+            new_mfa_end.append(new_mfa_end[-1] + dur)
+
+    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
+    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]
+
+    # 3. get new wav
+    # appended after the original sentence
+    if span_to_repl[0] >= len(mfa_start):
+        wav_left_idx = len(wav_org)
+        wav_right_idx = wav_left_idx
+    # replaced in the middle of the original sentence
+    else:
+        wav_left_idx = int(np.floor(mfa_start[span_to_repl[0]] * n_shift))
+        wav_right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * n_shift))
+    blank_wav = np.zeros(
+        (int(np.ceil(new_span_dur_sum * n_shift)), ), dtype=wav_org.dtype)
+    # the original audio with the part to be edited replaced by silence;
+    # the length of the silence is decided by fs2's duration_predictor
+    new_wav = np.concatenate(
+        [wav_org[:wav_left_idx], blank_wav, wav_org[wav_right_idx:]])
+
+    # the audio is masked as expected
+    sf.write(str("new_wav.wav"), new_wav, samplerate=fs)
+
+    # 4. get old and new mel spans to be masked
+    old_span_bdy = get_span_bdy(
+        mfa_start=mfa_start, mfa_end=mfa_end, span_to_repl=span_to_repl)
+
+    new_span_bdy = get_span_bdy(
+        mfa_start=new_mfa_start, mfa_end=new_mfa_end, span_to_repl=span_to_add)
+
+    # old_span_bdy and new_span_bdy are frame-level ranges
+    outs = {}
+    outs['new_wav'] = new_wav
+    outs['new_phns'] = new_phns
+    outs['new_mfa_start'] = new_mfa_start
+    outs['new_mfa_end'] = new_mfa_end
+    outs['old_span_bdy'] = old_span_bdy
+    outs['new_span_bdy'] = new_span_bdy
+    return outs
+
+
+def prep_feats(wav_path: str,
+               old_str: str='',
+               new_str: str='',
+               source_lang: str='en',
+               target_lang: str='en',
+               duration_adjust: bool=True,
+               fs: int=24000,
+               n_shift: int=300):
+
+    outs = prep_feats_with_dur(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        source_lang=source_lang,
+        target_lang=target_lang,
+        duration_adjust=duration_adjust,
+        fs=fs,
+        n_shift=n_shift)
+
+    wav_name = os.path.basename(wav_path)
+    utt_id = wav_name.split('.')[0]
+
+    wav = outs['new_wav']
+    phns = outs['new_phns']
+    mfa_start = outs['new_mfa_start']
+    mfa_end = outs['new_mfa_end']
+    old_span_bdy = outs['old_span_bdy']
+    new_span_bdy = outs['new_span_bdy']
+    span_bdy = np.array(new_span_bdy)
+
+    text = _p2id(phns)
+    mel = mel_extractor.get_log_mel_fbank(wav)
+    erniesat_mean, erniesat_std = np.load(erniesat_stat)
+    normed_mel = norm(mel, erniesat_mean, erniesat_std)
+    tmp_name = get_tmp_name(text=old_str)
+    tmpbase = './tmp_dir/' + tmp_name
+    tmpbase = Path(tmpbase)
+    tmpbase.mkdir(parents=True, exist_ok=True)
+    print("tmp_name in synthesize_e2e:", tmp_name)
+
+    mel_path = tmpbase / 'mel.npy'
+    print("mel_path:", mel_path)
+    np.save(mel_path, normed_mel)
+    durations = [e - s for e, s in zip(mfa_end, mfa_start)]
+
+    datum = {
+        "utt_id": utt_id,
+        "spk_id": 0,
+        "text": text,
+        "text_lengths": len(text),
+        "speech_lengths": len(normed_mel),  # number of mel frames
+        "durations": durations,
+        "speech": mel_path,
+        "align_start": mfa_start,
+        "align_end": mfa_end,
+        "span_bdy": span_bdy
+    }
+
+    batch = collate_fn([datum])
+    print("batch:", batch)
+
+    return batch, old_span_bdy, new_span_bdy
+
+
+def decode_with_model(mlm_model: nn.Layer,
+                      collate_fn,
+                      wav_path: str,
+                      old_str: str='',
+                      new_str: str='',
+                      source_lang: str='en',
+                      target_lang: str='en',
+                      use_teacher_forcing: bool=False,
+                      duration_adjust: bool=True,
+                      fs: int=24000,
+                      n_shift: int=300,
+                      token_list: List[str]=[]):
+    batch, old_span_bdy, new_span_bdy = prep_feats(
+        source_lang=source_lang,
+        target_lang=target_lang,
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        duration_adjust=duration_adjust,
+        fs=fs,
+        n_shift=n_shift)
+
+    # prep_feats() already returns a collated batch
+    feats = batch
+
+    if 'text_masked_pos' in feats.keys():
+        feats.pop('text_masked_pos')
+
+    output = mlm_model.inference(
+        text=feats['text'],
+        speech=feats['speech'],
+        masked_pos=feats['masked_pos'],
+        speech_mask=feats['speech_mask'],
+        text_mask=feats['text_mask'],
+        speech_seg_pos=feats['speech_seg_pos'],
+        text_seg_pos=feats['text_seg_pos'],
+        span_bdy=new_span_bdy,
+        use_teacher_forcing=use_teacher_forcing)
+
+    # concatenate the generated audio segments
+    output_feat = paddle.concat(x=output, axis=0)
+    wav_org, _ = librosa.load(wav_path, sr=fs)
+    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, n_shift
+
+
+if __name__ == '__main__':
+    fs = 24000
+    n_shift = 300
+    wav_path = "exp/p243_313.wav"
+    old_str = "For that reason cover should not be given."
+    # for edit
+    # new_str = "for that reason cover is impossible to be given."
+    # for synthesize
+    append_str = "do you love me i love you so much"
+    new_str = old_str + append_str
+
+    '''
+    outs = prep_feats_with_dur(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        fs=fs,
+        n_shift=n_shift)
+
+    new_wav = outs['new_wav']
+    new_phns = outs['new_phns']
+    new_mfa_start = outs['new_mfa_start']
+    new_mfa_end = outs['new_mfa_end']
+    old_span_bdy = outs['old_span_bdy']
+    new_span_bdy = outs['new_span_bdy']
+
+    print("---------------------------------")
+    print("new_wav:", new_wav)
+    print("new_phns:", new_phns)
+    print("new_mfa_start:", new_mfa_start)
+    print("new_mfa_end:", new_mfa_end)
+    print("old_span_bdy:", old_span_bdy)
+    print("new_span_bdy:", new_span_bdy)
+    print("---------------------------------")
+    '''
+
+    erniesat_config = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/local/default.yaml"
+
+    with open(erniesat_config) as f:
+        erniesat_config = CfgNode(yaml.safe_load(f))
+
+    erniesat_stat = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/train/speech_stats.npy"
+
+    # Extractor
+    mel_extractor = LogMelFBank(
+        sr=erniesat_config.fs,
+        n_fft=erniesat_config.n_fft,
+        hop_length=erniesat_config.n_shift,
+        win_length=erniesat_config.win_length,
+        window=erniesat_config.window,
+        n_mels=erniesat_config.n_mels,
+        fmin=erniesat_config.fmin,
+        fmax=erniesat_config.fmax)
+
+    collate_fn = build_erniesat_collate_fn(
+        mlm_prob=erniesat_config.mlm_prob,
+        mean_phn_span=erniesat_config.mean_phn_span,
+        seg_emb=erniesat_config.model['enc_input_layer'] == 'sega_mlm',
+        text_masking=False)
+
+    phones_dict = '/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/vctk/ernie_sat/dump/phone_id_map.txt'
+    vocab_phones = {}
+
+    with open(phones_dict, 'rt') as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    for phn, id in phn_id:
+        vocab_phones[phn] = int(id)
+
+    prep_feats(
+        wav_path=wav_path,
+        old_str=old_str,
+        new_str=new_str,
+        fs=fs,
+        n_shift=n_shift)
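A worked example of the duration bookkeeping in `prep_feats_with_dur()`: if the replaced span lasted 40 frames and the factor-adjusted predicted span lasts 55, every alignment boundary after the span shifts by +15 frames. All numbers below are toy values:

```python
orig_old_durs = [10, 20, 40, 30]      # per-phone frames from MFA
new_durs_adjusted = [10, 20, 55, 30]  # FastSpeech2 prediction scaled by d_factor
span_to_repl, span_to_add = [2, 3], [2, 3]

old_span = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])    # 40
new_span = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])  # 55
dur_offset = new_span - old_span
print(dur_offset)  # 15 -> added to every mfa_start/mfa_end entry after the span
```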
diff --git a/paddlespeech/t2s/exps/ernie_sat/train.py b/paddlespeech/t2s/exps/ernie_sat/train.py
new file mode 100644
index 000000000..ccd1245e1
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/train.py
@@ -0,0 +1,203 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+import os
+import shutil
+from pathlib import Path
+
+import jsonlines
+import numpy as np
+import paddle
+import yaml
+from paddle import DataParallel
+from paddle import distributed as dist
+from paddle import nn
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+from paddle.optimizer import Adam
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
+from paddlespeech.t2s.datasets.data_table import DataTable
+from paddlespeech.t2s.models.ernie_sat import ErnieSAT
+from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
+from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
+from paddlespeech.t2s.training.extensions.snapshot import Snapshot
+from paddlespeech.t2s.training.extensions.visualizer import VisualDL
+from paddlespeech.t2s.training.seeding import seed_everything
+from paddlespeech.t2s.training.trainer import Trainer
+
+
+def train_sp(args, config):
+    # decide device type and whether to run in parallel,
+    # i.e. set up the running environment correctly
+    if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
+        paddle.set_device("cpu")
+    else:
+        paddle.set_device("gpu")
+        world_size = paddle.distributed.get_world_size()
+        if world_size > 1:
+            paddle.distributed.init_parallel_env()
+
+    # set the random seed, it is a must for multiprocess training
+    seed_everything(config.seed)
+
+    print(
+        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
+    )
+    fields = [
+        "text", "text_lengths", "speech", "speech_lengths", "align_start",
+        "align_end"
+    ]
+    converters = {"speech": np.load}
+    # the DataLoader logger is too verbose, disable it
+    logging.getLogger("DataLoader").disabled = True
+
+    # construct dataset for training and validation
+    with jsonlines.open(args.train_metadata, 'r') as reader:
+        train_metadata = list(reader)
+    train_dataset = DataTable(
+        data=train_metadata,
+        fields=fields,
+        converters=converters, )
+    with jsonlines.open(args.dev_metadata, 'r') as reader:
+        dev_metadata = list(reader)
+    dev_dataset = DataTable(
+        data=dev_metadata,
+        fields=fields,
+        converters=converters, )
+
+    # collate function and dataloader
+    collate_fn = build_erniesat_collate_fn(
+        mlm_prob=config.mlm_prob,
+        mean_phn_span=config.mean_phn_span,
+        seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
+        text_masking=config["model"]["text_masking"])
+
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        drop_last=True)
+
+    print("samplers done!")
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        collate_fn=collate_fn,
+        num_workers=config.num_workers)
+
+    dev_dataloader = DataLoader(
+        dev_dataset,
+        shuffle=False,
+        drop_last=False,
+        batch_size=config.batch_size,
+        collate_fn=collate_fn,
+        num_workers=config.num_workers)
+    print("dataloaders done!")
+
+    with open(args.phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    vocab_size = len(phn_id)
+    print("vocab_size:", vocab_size)
+
+    odim = config.n_mels
+    model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
+
+    if world_size > 1:
+        model = DataParallel(model)
+    print("model done!")
+
+    scheduler = paddle.optimizer.lr.NoamDecay(
+        d_model=config["scheduler_params"]["d_model"],
+        warmup_steps=config["scheduler_params"]["warmup_steps"])
+    grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"])
+    optimizer = Adam(
+        learning_rate=scheduler,
+        grad_clip=grad_clip,
+        parameters=model.parameters())
+
+    print("optimizer done!")
+
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+    if dist.get_rank() == 0:
+        config_name = args.config.split("/")[-1]
+        # copy conf to output_dir
+        shutil.copyfile(args.config, output_dir / config_name)
+
+    updater = ErnieSATUpdater(
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        dataloader=train_dataloader,
+        text_masking=config["model"]["text_masking"],
+        odim=odim,
+        vocab_size=vocab_size,
+        output_dir=output_dir)
+
+    trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
+
+    evaluator = ErnieSATEvaluator(
+        model=model,
+        dataloader=dev_dataloader,
+        text_masking=config["model"]["text_masking"],
+        odim=odim,
+        vocab_size=vocab_size,
+        output_dir=output_dir, )
+
+    if dist.get_rank() == 0:
+        trainer.extend(evaluator, trigger=(1, "epoch"))
+        trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
+        trainer.extend(
+            Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
+    trainer.run()
+
+
+def main():
+    # parse args and config and redirect to train_sp
+    parser = argparse.ArgumentParser(description="Train an ErnieSAT model.")
+    parser.add_argument("--config", type=str, help="ErnieSAT config file.")
+    parser.add_argument("--train-metadata", type=str, help="training data.")
+    parser.add_argument("--dev-metadata", type=str, help="dev data.")
+    parser.add_argument("--output-dir", type=str, help="output dir.")
+    parser.add_argument(
+        "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
+    parser.add_argument(
+        "--phones-dict", type=str, default=None, help="phone vocabulary file.")
+
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        config = CfgNode(yaml.safe_load(f))
+
+    print("========Args========")
+    print(yaml.safe_dump(vars(args)))
+    print("========Config========")
+    print(config)
+    print(
+        f"master sees the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
+    )
+
+    # dispatch
+    if args.ngpu > 1:
+        dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
+    else:
+        train_sp(args, config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/paddlespeech/t2s/exps/ernie_sat/utils.py b/paddlespeech/t2s/exps/ernie_sat/utils.py
new file mode 100644
index 000000000..9169efa36
--- /dev/null
+++ b/paddlespeech/t2s/exps/ernie_sat/utils.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
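+
+# A quick sanity check for the naming helpers below (illustrative; the exact
+# get_tmp_name() result depends on the current user and pid):
+#
+#   >>> str2md5('hello')
+#   '5d41402abc4b2a76b9719d911017c592'
+#   >>> get_tmp_name('hello')
+#   'someuser_12345_5d41402abc4b2a76b9719d911017c592'  # user and pid vary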
+import hashlib
+import os
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import Union
+
+import numpy as np
+import paddle
+import yaml
+from yacs.config import CfgNode
+
+from paddlespeech.t2s.exps.syn_utils import get_am_inference
+from paddlespeech.t2s.exps.syn_utils import get_voc_inference
+
+
+def _get_user():
+    return os.path.expanduser('~').split('/')[-1]
+
+
+def str2md5(string):
+    md5_val = hashlib.md5(string.encode('utf8')).hexdigest()
+    return md5_val
+
+
+def get_tmp_name(text: str):
+    return _get_user() + '_' + str(os.getpid()) + '_' + str2md5(text)
+
+
+def get_dict(dictfile: str):
+    word2phns_dict = {}
+    with open(dictfile, 'r') as fid:
+        for line in fid:
+            line_lst = line.split()
+            word, phn_lst = line_lst[0], line_lst[1:]
+            if word not in word2phns_dict:
+                word2phns_dict[word] = ' '.join(phn_lst)
+    return word2phns_dict
+
+
+# get the range of mel frames that should be masked
+def get_span_bdy(mfa_start: List[float],
+                 mfa_end: List[float],
+                 span_to_repl: List[int]):
+    if span_to_repl[0] >= len(mfa_start):
+        span_bdy = [mfa_end[-1], mfa_end[-1]]
+    else:
+        span_bdy = [mfa_start[span_to_repl[0]], mfa_end[span_to_repl[1] - 1]]
+    return span_bdy
+
+
+# the durations obtained from MFA and those predicted by the fs2
+# duration_predictor may differ, so compute a scaling factor that maps
+# the predicted values to the ground-truth scale
+def get_dur_adj_factor(orig_dur: List[int],
+                       pred_dur: List[int],
+                       phns: List[str]):
+    factor_list = []
+    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
+        if pred == 0 or phn == 'sp':
+            continue
+        else:
+            factor_list.append(orig / pred)
+    factor_list = np.array(factor_list)
+    factor_list.sort()
+    if len(factor_list) < 5:
+        return 1
+    # trimmed mean: drop the 2 smallest and 2 largest ratios before averaging
+    length = 2
+    avg = np.average(factor_list[length:-length])
+    return avg
+
+
+def read_2col_text(path: Union[Path, str]) -> Dict[str, str]:
+    """Read a two-column text file as a dict object.
+
+    Examples:
+        wav.scp:
+            key1 /some/path/a.wav
+            key2 /some/path/b.wav
+
+        >>> read_2col_text('wav.scp')
+        {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'}
+
+    """
+
+    data = {}
+    with Path(path).open("r", encoding="utf-8") as f:
+        for linenum, line in enumerate(f, 1):
+            sps = line.rstrip().split(maxsplit=1)
+            if len(sps) == 1:
+                k, v = sps[0], ""
+            else:
+                k, v = sps
+            if k in data:
+                raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+            data[k] = v
+    return data
+
+
+def load_num_sequence_text(path: Union[Path, str], loader_type: str="csv_int"
+                           ) -> Dict[str, List[Union[float, int]]]:
+    """Read a text file indicating sequences of numbers.
+
+    Examples:
+        key1 1 2 3
+        key2 34 5 6
+
+        >>> d = load_num_sequence_text('text', loader_type='text_int')
+        >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3]))
+    """
+    if loader_type == "text_int":
+        delimiter = " "
+        dtype = int
+    elif loader_type == "text_float":
+        delimiter = " "
+        dtype = float
+    elif loader_type == "csv_int":
+        delimiter = ","
+        dtype = int
+    elif loader_type == "csv_float":
+        delimiter = ","
+        dtype = float
+    else:
+        raise ValueError(f"Unsupported loader_type={loader_type}")
+
+    # path looks like:
+    #   utta 1,0
+    #   uttb 3,4,5
+    # -> return {'utta': np.ndarray([1, 0]),
+    #            'uttb': np.ndarray([3, 4, 5])}
+    d = read_2col_text(path)
+    # Using for-loop instead of dict-comprehension for debuggability
+    retval = {}
+    for k, v in d.items():
+        try:
+            retval[k] = [dtype(i) for i in v.split(delimiter)]
+        except (TypeError, ValueError):
+            print(f'Error happened with path="{path}", id="{k}", value="{v}"')
+            raise
+    return retval
+
+
+def is_chinese(ch):
+    if u'\u4e00' <= ch <= u'\u9fff':
+        return True
+    else:
+        return False
+
+
+def get_voc_out(mel):
+    # vocoder
+    # NOTE: parse_args() is not defined in this module; it is assumed to be
+    # provided by the CLI parser of the calling script
+    args = parse_args()
+    with open(args.voc_config) as f:
+        voc_config = CfgNode(yaml.safe_load(f))
+    voc_inference = get_voc_inference(
+        voc=args.voc,
+        voc_config=voc_config,
+        voc_ckpt=args.voc_ckpt,
+        voc_stat=args.voc_stat)
+
+    with paddle.no_grad():
+        wav = voc_inference(mel)
+    return np.squeeze(wav)
+
+
+def eval_durs(phns, target_lang: str='zh', fs: int=24000, n_shift: int=300):
+
+    if target_lang == 'en':
+        am = "fastspeech2_ljspeech"
+        am_config = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml"
+        am_ckpt = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz"
+        am_stat = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy"
+        phones_dict = "download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt"
+
+    elif target_lang == 'zh':
+        am = "fastspeech2_csmsc"
+        am_config = "download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml"
+        am_ckpt = "download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz"
+        am_stat = "download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy"
+        phones_dict = "download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt"
+
+    else:
+        # guard against unbound locals below when an unknown language is passed
+        raise ValueError(f"target_lang {target_lang} is not supported!")
+
+    # Init body.
+    with open(am_config) as f:
+        am_config = CfgNode(yaml.safe_load(f))
+
+    am_inference, am = get_am_inference(
+        am=am,
+        am_config=am_config,
+        am_ckpt=am_ckpt,
+        am_stat=am_stat,
+        phones_dict=phones_dict,
+        return_am=True)
+
+    vocab_phones = {}
+    with open(phones_dict, "r") as f:
+        phn_id = [line.strip().split() for line in f.readlines()]
+    for phn, id in phn_id:
+        vocab_phones[phn] = int(id)
+    vocab_size = len(vocab_phones)
+    phonemes = [phn if phn in vocab_phones else "sp" for phn in phns]
+
+    phone_ids = [vocab_phones[item] for item in phonemes]
+    phone_ids = paddle.to_tensor(np.array(phone_ids, np.int64))
+    _, d_outs, _, _ = am.inference(phone_ids)
+    d_outs = d_outs.tolist()
+    return d_outs
diff --git a/paddlespeech/t2s/exps/syn_utils.py b/paddlespeech/t2s/exps/syn_utils.py
index 77abf97d9..bade62aca 100644
--- a/paddlespeech/t2s/exps/syn_utils.py
+++ b/paddlespeech/t2s/exps/syn_utils.py
@@ -69,6 +69,10 @@ model_alias = {
     "paddlespeech.t2s.models.wavernn:WaveRNN",
     "wavernn_inference":
     "paddlespeech.t2s.models.wavernn:WaveRNNInference",
+    "erniesat":
+    "paddlespeech.t2s.models.ernie_sat:ErnieSAT",
+    "erniesat_inference":
+    "paddlespeech.t2s.models.ernie_sat:ErnieSATInference",
 }
 
 
@@ -112,6 +116,7 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
     # model: {model_name}_{dataset}
     am_name = am[:am.rindex('_')]
     am_dataset = am[am.rindex('_') + 1:]
+    converters = {}
     if am_name == 'fastspeech2':
         fields = ["utt_id", "text"]
         if am_dataset in {"aishell3", "vctk",
@@ -130,8 +135,17 @@ def get_test_dataset(test_metadata: List[Dict[str, Any]],
         if voice_cloning:
             print("voice cloning!")
             fields += ["spk_emb"]
+    elif am_name == 'erniesat':
+        fields = [
+            "utt_id", "text", "text_lengths", "speech", "speech_lengths",
+            "align_start", "align_end"
+        ]
+        converters = {"speech": np.load}
+    else:
+        print("wrong am, please input the right am!")
 
-    test_dataset = DataTable(data=test_metadata, fields=fields)
+    test_dataset = DataTable(
+        data=test_metadata, fields=fields, converters=converters)
     return test_dataset
 
 
@@ -201,6 +215,10 @@ def get_am_inference(am: str='fastspeech2_csmsc',
             **am_config["model"])
     elif am_name == 'tacotron2':
         am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
+    elif am_name == 'erniesat':
+        am = am_class(idim=vocab_size, odim=odim, **am_config["model"])
+    else:
+        print("wrong am, please input the right am!")
 
     am.set_state_dict(paddle.load(am_ckpt)["main_params"])
     am.eval()
diff --git a/paddlespeech/t2s/models/ernie_sat/__init__.py b/paddlespeech/t2s/models/ernie_sat/__init__.py
index dc86fa514..7e795370e 100644
--- a/paddlespeech/t2s/models/ernie_sat/__init__.py
+++ b/paddlespeech/t2s/models/ernie_sat/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,4 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .ernie_sat import *
+from .ernie_sat_updater import *
 from .mlm import *
diff --git a/paddlespeech/t2s/models/ernie_sat/ernie_sat.py b/paddlespeech/t2s/models/ernie_sat/ernie_sat.py
new file mode 100644
index 000000000..54f5d542d
--- /dev/null
+++ b/paddlespeech/t2s/models/ernie_sat/ernie_sat.py
@@ -0,0 +1,705 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import paddle
+from paddle import nn
+
+from paddlespeech.t2s.modules.activation import get_activation
+from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
+from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
+from paddlespeech.t2s.modules.layer_norm import LayerNorm
+from paddlespeech.t2s.modules.masked_fill import masked_fill
+from paddlespeech.t2s.modules.nets_utils import initialize
+from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
+from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
+from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
+from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
+from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
+from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
+from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
+from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
+from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
+from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
+from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from paddlespeech.t2s.modules.transformer.repeat import repeat
+from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
+
+
+# MLM -> Masked Language Model
+class mySequential(nn.Sequential):
+    def forward(self, *inputs):
+        for module in self._sub_layers.values():
+            if isinstance(inputs, tuple):
+                inputs = module(*inputs)
+            else:
+                inputs = module(inputs)
+        return inputs
+
+
+class MaskInputLayer(nn.Layer):
+    def __init__(self, out_features: int) -> None:
+        super().__init__()
+        self.mask_feature = paddle.create_parameter(
+            shape=(1, 1, out_features),
+            dtype=paddle.float32,
+            default_initializer=paddle.nn.initializer.Assign(
+                paddle.normal(shape=(1, 1, out_features))))
+
+    def forward(self, input: paddle.Tensor,
+                masked_pos: paddle.Tensor=None) -> paddle.Tensor:
+        masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
+        masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
+            paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
+        return masked_input
+
+
+class MLMEncoder(nn.Layer):
+    """Conformer encoder module.
+
+    Args:
+        idim (int): Input dimension.
+        attention_dim (int): Dimension of attention.
+        attention_heads (int): The number of heads of multi-head attention.
+        linear_units (int): The number of units of position-wise feed forward.
+        num_blocks (int): The number of encoder blocks.
+        dropout_rate (float): Dropout rate.
+        positional_dropout_rate (float): Dropout rate after adding positional encoding.
+        attention_dropout_rate (float): Dropout rate in attention.
+        input_layer (Union[str, paddle.nn.Layer]): Input layer type.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
+        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
+        macaron_style (bool): Whether to use macaron style for positionwise layer.
+        pos_enc_layer_type (str): Encoder positional encoding layer type.
+        selfattention_layer_type (str): Encoder attention layer type.
+        activation_type (str): Encoder activation function type.
+        use_cnn_module (bool): Whether to use convolution module.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+        cnn_module_kernel (int): Kernel size of convolution module.
+        padding_idx (int): Padding idx for input_layer=embed.
+        stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
+
+    """
+
+    def __init__(self,
+                 idim: int,
+                 vocab_size: int=0,
+                 pre_speech_layer: int=0,
+                 attention_dim: int=256,
+                 attention_heads: int=4,
+                 linear_units: int=2048,
+                 num_blocks: int=6,
+                 dropout_rate: float=0.1,
+                 positional_dropout_rate: float=0.1,
+                 attention_dropout_rate: float=0.0,
+                 input_layer: str="conv2d",
+                 normalize_before: bool=True,
+                 concat_after: bool=False,
+                 positionwise_layer_type: str="linear",
+                 positionwise_conv_kernel_size: int=1,
+                 macaron_style: bool=False,
+                 pos_enc_layer_type: str="abs_pos",
+                 pos_enc_class=None,
+                 selfattention_layer_type: str="selfattn",
+                 activation_type: str="swish",
+                 use_cnn_module: bool=False,
+                 zero_triu: bool=False,
+                 cnn_module_kernel: int=31,
+                 padding_idx: int=-1,
+                 stochastic_depth_rate: float=0.0,
+                 text_masking: bool=False):
+        """Construct an Encoder object."""
+        super().__init__()
+        self._output_size = attention_dim
+        self.text_masking = text_masking
+        if self.text_masking:
+            self.text_masking_layer = MaskInputLayer(attention_dim)
+        activation = get_activation(activation_type)
+        # NOTE: pos_enc_class is always reassigned below according to
+        # pos_enc_layer_type
+        if pos_enc_layer_type == "abs_pos":
+            pos_enc_class = PositionalEncoding
+        elif pos_enc_layer_type == "scaled_abs_pos":
+            pos_enc_class = ScaledPositionalEncoding
+        elif pos_enc_layer_type == "rel_pos":
+            assert selfattention_layer_type == "rel_selfattn"
+            pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "legacy_rel_pos":
+            pos_enc_class = LegacyRelPositionalEncoding
+            assert selfattention_layer_type == "legacy_rel_selfattn"
+        else:
+            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+        self.conv_subsampling_factor = 1
+        if input_layer == "linear":
+            self.embed = nn.Sequential(
+                nn.Linear(idim, attention_dim),
+                nn.LayerNorm(attention_dim),
+                nn.Dropout(dropout_rate),
+                nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate), )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(
+                idim,
+                attention_dim,
+                dropout_rate,
+                pos_enc_class(attention_dim, positional_dropout_rate), )
+            self.conv_subsampling_factor = 4
+        elif input_layer == "embed":
+            self.embed = nn.Sequential(
+                nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
+                pos_enc_class(attention_dim, positional_dropout_rate), )
+        elif input_layer == "mlm":
+            self.segment_emb = None
+            self.speech_embed = mySequential(
+                MaskInputLayer(idim),
+                nn.Linear(idim, attention_dim),
+                nn.LayerNorm(attention_dim),
+                nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate))
+            self.text_embed = nn.Sequential(
+                nn.Embedding(
+                    vocab_size, attention_dim, padding_idx=padding_idx),
+                pos_enc_class(attention_dim, positional_dropout_rate), )
+        elif input_layer == "sega_mlm":
+            # the segment embedding supports segment indices 0..499
+            self.segment_emb = nn.Embedding(
+                500, attention_dim, padding_idx=padding_idx)
+            self.speech_embed = mySequential(
+                MaskInputLayer(idim),
+                nn.Linear(idim, attention_dim),
+                nn.LayerNorm(attention_dim),
+                nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate))
+            self.text_embed = nn.Sequential(
+                nn.Embedding(
+                    vocab_size, attention_dim, padding_idx=padding_idx),
+                pos_enc_class(attention_dim, positional_dropout_rate), )
+        elif isinstance(input_layer, nn.Layer):
+            self.embed = nn.Sequential(
+                input_layer,
+                pos_enc_class(attention_dim, positional_dropout_rate), )
+        elif input_layer is None:
+            self.embed = nn.Sequential(
+                pos_enc_class(attention_dim, positional_dropout_rate))
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+
+        # self-attention module definition
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, attention_dim,
+                                           attention_dropout_rate, )
+        elif selfattention_layer_type == "legacy_rel_selfattn":
+            assert pos_enc_layer_type == "legacy_rel_pos"
+            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, attention_dim,
+                                           attention_dropout_rate, )
+        elif selfattention_layer_type == "rel_selfattn":
+            assert pos_enc_layer_type == "rel_pos"
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (attention_heads, attention_dim,
+                                           attention_dropout_rate, zero_triu, )
+        else:
+            raise ValueError("unknown encoder_attn_layer: " +
+                             selfattention_layer_type)
+
+        # feed-forward module definition
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (attention_dim, linear_units,
+                                       dropout_rate, activation, )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (attention_dim, linear_units,
+                                       positionwise_conv_kernel_size,
+                                       dropout_rate, )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (attention_dim, linear_units,
+                                       positionwise_conv_kernel_size,
+                                       dropout_rate, )
+        else:
+            raise NotImplementedError(
+                "Support only linear, conv1d, or conv1d-linear.")
+
+        # convolution module definition
+        convolution_layer = ConvolutionModule
+        convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
+
+        self.encoders = repeat(
+            num_blocks,
+            lambda lnum: EncoderLayer(
+                attention_dim,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
+                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
+        self.pre_speech_layer = pre_speech_layer
+        self.pre_speech_encoders = repeat(
+            self.pre_speech_layer,
+            lambda lnum: EncoderLayer(
+                attention_dim,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
+                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor=None,
+                text_mask: paddle.Tensor=None,
+                speech_seg_pos: paddle.Tensor=None,
+                text_seg_pos: paddle.Tensor=None):
+        """Encode input sequence.
+
+        """
+        if masked_pos is not None:
+            speech = self.speech_embed(speech, masked_pos)
+        else:
+            speech = self.speech_embed(speech)
+        if text is not None:
+            text = self.text_embed(text)
+        if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
+            speech_seg_emb = self.segment_emb(speech_seg_pos)
+            text_seg_emb = self.segment_emb(text_seg_pos)
+            text = (text[0] + text_seg_emb, text[1])
+            speech = (speech[0] + speech_seg_emb, speech[1])
+        if self.pre_speech_encoders:
+            speech, _ = self.pre_speech_encoders(speech, speech_mask)
+
+        if text is not None:
+            xs = paddle.concat([speech[0], text[0]], axis=1)
+            xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
+            masks = paddle.concat([speech_mask, text_mask], axis=-1)
+        else:
+            xs = speech[0]
+            xs_pos_emb = speech[1]
+            masks = speech_mask
+
+        xs, masks = self.encoders((xs, xs_pos_emb), masks)
+
+        if isinstance(xs, tuple):
+            xs = xs[0]
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+
+        return xs, masks
+
+
+class MLMDecoder(MLMEncoder):
+    def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
+        """Decode input sequence.
+
+        Args:
+            xs (paddle.Tensor): Input tensor (#batch, time, idim).
+            masks (paddle.Tensor): Mask tensor (#batch, time).
+
+        Returns:
+            paddle.Tensor: Output tensor (#batch, time, attention_dim).
+            paddle.Tensor: Mask tensor (#batch, time).
+
+        """
+        xs = self.embed(xs)
+        xs, masks = self.encoders(xs, masks)
+
+        if isinstance(xs, tuple):
+            xs = xs[0]
+        if self.normalize_before:
+            xs = self.after_norm(xs)
+
+        return xs, masks
+
+
+# encoder and decoder are nn.Layer instances here, not strings
+class MLM(nn.Layer):
+    def __init__(self,
+                 odim: int,
+                 encoder: nn.Layer,
+                 decoder: Optional[nn.Layer],
+                 postnet_layers: int=0,
+                 postnet_chans: int=0,
+                 postnet_filts: int=0,
+                 text_masking: bool=False):
+
+        super().__init__()
+        self.odim = odim
+        self.encoder = encoder
+        self.decoder = decoder
+        self.vocab_size = encoder.text_embed[0]._num_embeddings
+
+        if self.decoder is None or not (hasattr(self.decoder,
+                                                'output_layer') and
+                                        self.decoder.output_layer is not None):
+            self.sfc = nn.Linear(self.encoder._output_size, odim)
+        else:
+            self.sfc = None
+        if text_masking:
+            self.text_sfc = nn.Linear(
+                self.encoder.text_embed[0]._embedding_dim,
+                self.vocab_size,
+                weight_attr=self.encoder.text_embed[0]._weight_attr)
+        else:
+            self.text_sfc = None
+
+        self.postnet = (None if postnet_layers == 0 else Postnet(
+            idim=self.encoder._output_size,
+            odim=odim,
+            n_layers=postnet_layers,
+            n_chans=postnet_chans,
+            n_filts=postnet_filts,
+            use_batch_norm=True,
+            dropout_rate=0.5, ))
+
+    def inference(
+            self,
+            speech: paddle.Tensor,
+            text: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            speech_mask: paddle.Tensor,
+            text_mask: paddle.Tensor,
+            speech_seg_pos: paddle.Tensor,
+            text_seg_pos: paddle.Tensor,
+            span_bdy: List[int],
+            use_teacher_forcing: bool=False, ) -> List[paddle.Tensor]:
+        '''
+        Args:
+            speech (paddle.Tensor): input speech (1, Tmax, D).
+            text (paddle.Tensor): input text (1, Tmax2).
+            masked_pos (paddle.Tensor): masked positions of input speech (1, Tmax).
+            speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
+            text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
+            speech_seg_pos (paddle.Tensor): index of the phone each mel frame belongs to, 0<=n<=Tmax2 (1, Tmax).
+            text_seg_pos (paddle.Tensor): index of each phone, 0<=n<=Tmax2 (1, Tmax2).
+            span_bdy (List[int]): masked mel boundary of input speech (2,).
+            use_teacher_forcing (bool): whether to use teacher forcing.
+        Returns:
+            List[Tensor]:
+                eg:
+                [Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
+        '''
+
+        if use_teacher_forcing:
+            before_outs, zs, *_ = self.forward(
+                speech=speech,
+                text=text,
+                masked_pos=masked_pos,
+                speech_mask=speech_mask,
+                text_mask=text_mask,
+                speech_seg_pos=speech_seg_pos,
+                text_seg_pos=text_seg_pos)
+            if zs is None:
+                zs = before_outs
+
+            speech = speech.squeeze(0)
+            outs = [speech[:span_bdy[0]]]
+            outs += [zs[0][span_bdy[0]:span_bdy[1]]]
+            outs += [speech[span_bdy[1]:]]
+            return outs
+        return None
+
+
+class MLMEncAsDecoder(MLM):
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        encoder_out, h_masks = self.encoder(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)
+        if self.decoder is not None:
+            zs, _ = self.decoder(encoder_out, h_masks)
+        else:
+            zs = encoder_out
+        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
+        if self.sfc is not None:
+            before_outs = paddle.reshape(
+                self.sfc(speech_hidden_states),
+                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
+        else:
+            before_outs = speech_hidden_states
+        if self.postnet is not None:
+            after_outs = before_outs + paddle.transpose(
+                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
+                [0, 2, 1])
+        else:
+            after_outs = None
+        return before_outs, after_outs, None
+
+
+class MLMDualMasking(MLM):
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        encoder_out, h_masks = self.encoder(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)
+        if self.decoder is not None:
+            zs, _ = self.decoder(encoder_out, h_masks)
+        else:
+            zs = encoder_out
+        speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
+        # text_outs stays None when text masking is disabled, so the return
+        # statement below is always well defined
+        text_outs = None
+        if self.text_sfc:
+            text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
+            text_outs = paddle.reshape(
+                self.text_sfc(text_hidden_states),
+                (paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
+        if self.sfc is not None:
+            before_outs = paddle.reshape(
+                self.sfc(speech_hidden_states),
+                (paddle.shape(speech_hidden_states)[0], -1, self.odim))
+        else:
+            before_outs = speech_hidden_states
+        if self.postnet is not None:
+            after_outs = before_outs + paddle.transpose(
+                self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
+                [0, 2, 1])
+        else:
+            after_outs = None
+        return before_outs, after_outs, text_outs
+
+
+class ErnieSAT(nn.Layer):
+    def __init__(
+            self,
+            # network structure related
+            idim: int,
+            odim: int,
+            postnet_layers: int=5,
+            postnet_filts: int=5,
+            postnet_chans: int=256,
+            use_scaled_pos_enc: bool=False,
+            encoder_type: str='conformer',
+            decoder_type: str='conformer',
+            enc_input_layer: str='sega_mlm',
+            enc_pre_speech_layer: int=0,
+            enc_cnn_module_kernel: int=7,
+            enc_attention_dim: int=384,
+            enc_attention_heads: int=2,
+            enc_linear_units: int=1536,
+            enc_num_blocks: int=4,
+            enc_dropout_rate: float=0.2,
+            enc_positional_dropout_rate: float=0.2,
+            enc_attention_dropout_rate: float=0.2,
+            enc_normalize_before: bool=True,
+            enc_macaron_style: bool=True,
+            enc_use_cnn_module: bool=True,
+            enc_selfattention_layer_type: str='legacy_rel_selfattn',
+            enc_activation_type: str='swish',
+            enc_pos_enc_layer_type: str='legacy_rel_pos',
+            enc_positionwise_layer_type: str='conv1d',
+            enc_positionwise_conv_kernel_size: int=3,
+            text_masking: bool=False,
+            dec_cnn_module_kernel: int=31,
+            dec_attention_dim: int=384,
+            dec_attention_heads: int=2,
+            dec_linear_units: int=1536,
+            dec_num_blocks: int=4,
+            dec_dropout_rate: float=0.2,
+            dec_positional_dropout_rate: float=0.2,
+            dec_attention_dropout_rate: float=0.2,
+            dec_macaron_style: bool=True,
+            dec_use_cnn_module: bool=True,
+            dec_selfattention_layer_type: str='legacy_rel_selfattn',
+            dec_activation_type: str='swish',
+            dec_pos_enc_layer_type: str='legacy_rel_pos',
+            dec_positionwise_layer_type: str='conv1d',
+            dec_positionwise_conv_kernel_size: int=3,
+            init_type: str="xavier_uniform", ):
+        super().__init__()
+        # store hyperparameters
+        self.odim = odim
+
+        self.use_scaled_pos_enc = use_scaled_pos_enc
+
+        # initialize parameters
+        initialize(self, init_type)
+
+        # Encoder
+        if encoder_type == "conformer":
+            encoder = MLMEncoder(
+                idim=odim,
+                vocab_size=idim,
+                pre_speech_layer=enc_pre_speech_layer,
+                attention_dim=enc_attention_dim,
+                attention_heads=enc_attention_heads,
+                linear_units=enc_linear_units,
+                num_blocks=enc_num_blocks,
+                dropout_rate=enc_dropout_rate,
+                positional_dropout_rate=enc_positional_dropout_rate,
+                attention_dropout_rate=enc_attention_dropout_rate,
+                input_layer=enc_input_layer,
+                normalize_before=enc_normalize_before,
+                positionwise_layer_type=enc_positionwise_layer_type,
+                positionwise_conv_kernel_size=enc_positionwise_conv_kernel_size,
+                macaron_style=enc_macaron_style,
+                pos_enc_layer_type=enc_pos_enc_layer_type,
+                selfattention_layer_type=enc_selfattention_layer_type,
+                activation_type=enc_activation_type,
+                use_cnn_module=enc_use_cnn_module,
+                cnn_module_kernel=enc_cnn_module_kernel,
+                text_masking=text_masking)
+        else:
+            raise ValueError(f"{encoder_type} is not supported.")
+
+        # Decoder
+        if decoder_type != 'no_decoder':
+            decoder = MLMDecoder(
+                idim=0,
+                input_layer=None,
+                cnn_module_kernel=dec_cnn_module_kernel,
+                attention_dim=dec_attention_dim,
+                attention_heads=dec_attention_heads,
+                linear_units=dec_linear_units,
+                num_blocks=dec_num_blocks,
+                dropout_rate=dec_dropout_rate,
+                positional_dropout_rate=dec_positional_dropout_rate,
+                macaron_style=dec_macaron_style,
+                use_cnn_module=dec_use_cnn_module,
+                selfattention_layer_type=dec_selfattention_layer_type,
+                activation_type=dec_activation_type,
+                pos_enc_layer_type=dec_pos_enc_layer_type,
+                positionwise_layer_type=dec_positionwise_layer_type,
+                positionwise_conv_kernel_size=dec_positionwise_conv_kernel_size)
+        else:
+            decoder = None
+
+        model_class = MLMDualMasking if text_masking else MLMEncAsDecoder
+
+        self.model = model_class(
+            odim=odim,
+            encoder=encoder,
+            decoder=decoder,
+            postnet_layers=postnet_layers,
+            postnet_filts=postnet_filts,
+            postnet_chans=postnet_chans,
+            text_masking=text_masking)
+
+        nn.initializer.set_global_initializer(None)
+
+    def forward(self,
+                speech: paddle.Tensor,
+                text: paddle.Tensor,
+                masked_pos: paddle.Tensor,
+                speech_mask: paddle.Tensor,
+                text_mask: paddle.Tensor,
+                speech_seg_pos: paddle.Tensor,
+                text_seg_pos: paddle.Tensor):
+        return self.model(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos)
+
+    def inference(
+            self,
+            speech: paddle.Tensor,
+            text: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            speech_mask: paddle.Tensor,
+            text_mask: paddle.Tensor,
+            speech_seg_pos: paddle.Tensor,
+            text_seg_pos: paddle.Tensor,
+            span_bdy: List[int],
+            use_teacher_forcing: bool=False, ) -> List[paddle.Tensor]:
+        return self.model.inference(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos,
+            span_bdy=span_bdy,
+            use_teacher_forcing=use_teacher_forcing)
+
+
+class ErnieSATInference(nn.Layer):
+    def __init__(self, normalizer, model):
+        super().__init__()
+        self.normalizer = normalizer
+        self.acoustic_model = model
+
+    def forward(
+            self,
+            speech: paddle.Tensor,
+            text: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            speech_mask: paddle.Tensor,
+            text_mask: paddle.Tensor,
+            speech_seg_pos: paddle.Tensor,
+            text_seg_pos: paddle.Tensor,
+            span_bdy: List[int],
+            use_teacher_forcing: bool=True, ):
+        outs = self.acoustic_model.inference(
+            speech=speech,
+            text=text,
+            masked_pos=masked_pos,
+            speech_mask=speech_mask,
+            text_mask=text_mask,
+            speech_seg_pos=speech_seg_pos,
+            text_seg_pos=text_seg_pos,
+            span_bdy=span_bdy,
+            use_teacher_forcing=use_teacher_forcing)
+
+        normed_mel_pre, normed_mel_masked, normed_mel_post = outs
+        logmel_pre = self.normalizer.inverse(normed_mel_pre)
+        logmel_masked = self.normalizer.inverse(normed_mel_masked)
+        logmel_post = self.normalizer.inverse(normed_mel_post)
+        return logmel_pre, logmel_masked, logmel_post
diff --git a/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
new file mode 100644
index 000000000..219341c88
--- /dev/null
+++ b/paddlespeech/t2s/models/ernie_sat/ernie_sat_updater.py
@@ -0,0 +1,158 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
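+
+# The updater below is driven by paddle.optimizer.lr.NoamDecay (see
+# exps/ernie_sat/train.py). For reference, the standard Noam schedule it
+# follows is:
+#
+#   lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
+#
+# e.g. with d_model=384 and warmup_steps=4000 from conf/default.yaml, the
+# learning rate warms up linearly for 4000 steps, then decays as step**-0.5.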
+import logging
+from pathlib import Path
+
+from paddle import distributed as dist
+from paddle.io import DataLoader
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+
+from paddlespeech.t2s.modules.losses import MLMLoss
+from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
+from paddlespeech.t2s.training.reporter import report
+from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
+
+logging.basicConfig(
+    format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
+    datefmt='[%Y-%m-%d %H:%M:%S]')
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class ErnieSATUpdater(StandardUpdater):
+    def __init__(self,
+                 model: Layer,
+                 optimizer: Optimizer,
+                 scheduler: LRScheduler,
+                 dataloader: DataLoader,
+                 init_state=None,
+                 text_masking: bool=False,
+                 odim: int=80,
+                 vocab_size: int=100,
+                 output_dir: Path=None):
+        # forward the caller's init_state instead of discarding it
+        super().__init__(model, optimizer, dataloader, init_state=init_state)
+        self.scheduler = scheduler
+
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+    def update_core(self, batch):
+        self.msg = "Rank: {}, ".format(dist.get_rank())
+        losses_dict = {}
+
+        before_outs, after_outs, text_outs = self.model(
+            speech=batch["speech"],
+            text=batch["text"],
+            masked_pos=batch["masked_pos"],
+            speech_mask=batch["speech_mask"],
+            text_mask=batch["text_mask"],
+            speech_seg_pos=batch["speech_seg_pos"],
+            text_seg_pos=batch["text_seg_pos"])
+
+        mlm_loss, text_mlm_loss = self.criterion(
+            speech=batch["speech"],
+            before_outs=before_outs,
+            after_outs=after_outs,
+            masked_pos=batch["masked_pos"],
+            text=batch["text"],
+            # maybe None
+            text_outs=text_outs,
+            # maybe None
+            text_masked_pos=batch["text_masked_pos"])
+
+        loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss
+
+        self.optimizer.clear_grad()
+
+        loss.backward()
+        self.optimizer.step()
+        self.scheduler.step()
+        scheduler_msg = 'lr: {}'.format(self.scheduler.last_lr)
+
+        report("train/loss", float(loss))
+        report("train/mlm_loss", float(mlm_loss))
+        if text_mlm_loss is not None:
+            report("train/text_mlm_loss", float(text_mlm_loss))
+            losses_dict["text_mlm_loss"] = float(text_mlm_loss)
+
+        losses_dict["mlm_loss"] = float(mlm_loss)
+        losses_dict["loss"] = float(loss)
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.msg += ', ' + scheduler_msg
+
+
+class ErnieSATEvaluator(StandardEvaluator):
+    def __init__(self,
+                 model: Layer,
+                 dataloader: DataLoader,
+                 text_masking: bool=False,
+                 odim: int=80,
+                 vocab_size: int=100,
+                 output_dir: Path=None):
+        super().__init__(model, dataloader)
+
+        log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
+        self.filehandler = logging.FileHandler(str(log_file))
+        logger.addHandler(self.filehandler)
+        self.logger = logger
+        self.msg = ""
+
+        self.criterion = MLMLoss(
+            text_masking=text_masking, odim=odim, vocab_size=vocab_size)
+
+    def evaluate_core(self, batch):
+        self.msg = "Evaluate: "
+        losses_dict = {}
+
+        before_outs, after_outs, text_outs = self.model(
+            speech=batch["speech"],
+            text=batch["text"],
+            masked_pos=batch["masked_pos"],
+            speech_mask=batch["speech_mask"],
+            text_mask=batch["text_mask"],
+            speech_seg_pos=batch["speech_seg_pos"],
+            text_seg_pos=batch["text_seg_pos"])
+
+        mlm_loss, text_mlm_loss = self.criterion(
+            speech=batch["speech"],
+            before_outs=before_outs,
+            after_outs=after_outs,
+            masked_pos=batch["masked_pos"],
+            text=batch["text"],
+            # maybe None
+            text_outs=text_outs,
+            # maybe None
+            text_masked_pos=batch["text_masked_pos"])
+        loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss
+
+        report("eval/loss", float(loss))
+        report("eval/mlm_loss", float(mlm_loss))
+        if text_mlm_loss is not None:
+            report("eval/text_mlm_loss", float(text_mlm_loss))
+            losses_dict["text_mlm_loss"] = float(text_mlm_loss)
+
+        losses_dict["mlm_loss"] = float(mlm_loss)
+        losses_dict["loss"] = float(loss)
+
+        self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
+                              for k, v in losses_dict.items())
+        self.logger.info(self.msg)
diff --git a/paddlespeech/t2s/models/ernie_sat/mlm.py b/paddlespeech/t2s/models/ernie_sat/mlm.py
index c9c3d67a6..647fdd9b4 100644
--- a/paddlespeech/t2s/models/ernie_sat/mlm.py
+++ b/paddlespeech/t2s/models/ernie_sat/mlm.py
@@ -1,9 +1,20 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import argparse
 from typing import Dict
 from typing import List
 from typing import Optional
-from typing import Tuple
-from typing import Union
 
 import paddle
 import yaml
@@ -109,7 +120,6 @@ class MLMEncoder(nn.Layer):
                  positionwise_conv_kernel_size: int=1,
                  macaron_style: bool=False,
                  pos_enc_layer_type: str="abs_pos",
-                 pos_enc_class=None,
                  selfattention_layer_type: str="selfattn",
                  activation_type: str="swish",
                  use_cnn_module: bool=False,
@@ -334,7 +344,6 @@ class MLMDecoder(MLMEncoder):
 
 # encoder and decoder is nn.Layer, not str
 class MLM(nn.Layer):
     def __init__(self,
-                 token_list: Union[Tuple[str, ...], List[str]],
                  odim: int,
                  encoder: nn.Layer,
                  decoder: Optional[nn.Layer],
@@ -345,7 +354,6 @@ class MLM(nn.Layer):
 
         super().__init__()
         self.odim = odim
-        self.token_list = token_list.copy()
         self.encoder = encoder
         self.decoder = decoder
         self.vocab_size = encoder.text_embed[0]._num_embeddings
@@ -535,32 +543,6 @@ def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
     vocab_size = len(token_list)
     odim = 80
 
-    pos_enc_class = ScaledPositionalEncoding if args.use_scaled_pos_enc else PositionalEncoding
-
-    if "conformer" == args.encoder:
-        conformer_self_attn_layer_type = args.encoder_conf[
-            'selfattention_layer_type']
-        conformer_pos_enc_layer_type = args.encoder_conf['pos_enc_layer_type']
-        conformer_rel_pos_type = "legacy"
-        if conformer_rel_pos_type == "legacy":
-            if conformer_pos_enc_layer_type == "rel_pos":
-                conformer_pos_enc_layer_type = "legacy_rel_pos"
-            if conformer_self_attn_layer_type == "rel_selfattn":
-                conformer_self_attn_layer_type = "legacy_rel_selfattn"
-        elif conformer_rel_pos_type == "latest":
-            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
-            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
-        else:
-            raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")
-        args.encoder_conf[
-            'selfattention_layer_type'] = conformer_self_attn_layer_type
-        args.encoder_conf['pos_enc_layer_type'] = conformer_pos_enc_layer_type
-        if "conformer" == args.decoder:
-            args.decoder_conf[
-                'selfattention_layer_type'] = conformer_self_attn_layer_type
-            args.decoder_conf[
-                'pos_enc_layer_type'] = conformer_pos_enc_layer_type
-
     # Encoder
     encoder_class = MLMEncoder
 
@@ -571,10 +553,7 @@ def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
     args.encoder_conf['text_masking'] = False
 
     encoder = encoder_class(
-        args.input_size,
-        vocab_size=vocab_size,
-        pos_enc_class=pos_enc_class,
-        **args.encoder_conf)
+        args.input_size, vocab_size=vocab_size, **args.encoder_conf)
 
     # Decoder
     if args.decoder != 'no_decoder':
@@ -591,7 +570,6 @@ def build_model(args: argparse.Namespace, model_class=MLMEncAsDecoder) -> MLM:
         odim=odim,
         encoder=encoder,
         decoder=decoder,
-        token_list=token_list,
         **args.model_conf, )
 
     # Initialize
diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
index 347a10e90..9905765db 100644
--- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
+++ b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py
@@ -274,9 +274,7 @@ class FastSpeech2(nn.Layer):
         super().__init__()
 
         # store hyperparameters
-        self.idim = idim
         self.odim = odim
-        self.eos = idim - 1
         self.reduction_factor = reduction_factor
         self.encoder_type = encoder_type
         self.decoder_type = decoder_type
diff --git a/paddlespeech/t2s/modules/losses.py b/paddlespeech/t2s/modules/losses.py
index b2a31a321..1a43f5ef3 100644
--- a/paddlespeech/t2s/modules/losses.py
+++ b/paddlespeech/t2s/modules/losses.py
@@ -1068,6 +1068,8 @@ class KLDivergenceLoss(nn.Layer):
 # loss for ERNIE SAT
 class MLMLoss(nn.Layer):
     def __init__(self,
+                 odim: int,
+                 vocab_size: int=0,
                  lsm_weight: float=0.1,
                  ignore_id: int=-1,
                  text_masking: bool=False):
@@ -1079,15 +1081,19 @@ class KLDivergenceLoss(nn.Layer):
         else:
             self.l1_loss_func = nn.L1Loss(reduction='none')
         self.text_masking = text_masking
+        self.odim = odim
+        self.vocab_size = vocab_size
 
-    def forward(self,
-                speech: paddle.Tensor,
-                before_outs: paddle.Tensor,
-                after_outs: paddle.Tensor,
-                masked_pos: paddle.Tensor,
-                text: paddle.Tensor=None,
-                text_outs: paddle.Tensor=None,
-                text_masked_pos: paddle.Tensor=None):
+    def forward(
+            self,
+            speech: paddle.Tensor,
+            before_outs: paddle.Tensor,
+            after_outs: paddle.Tensor,
+            masked_pos: paddle.Tensor,
+            # for text_loss when text_masking == True
+            text: paddle.Tensor=None,
+            text_outs: paddle.Tensor=None,
+            text_masked_pos: paddle.Tensor=None):
 
         xs_pad = speech
         mlm_loss_pos = masked_pos > 0
@@ -1102,16 +1108,21 @@ class MLMLoss(nn.Layer):
                 paddle.reshape(after_outs, (-1, self.odim)),
                 paddle.reshape(xs_pad, (-1, self.odim))),
             axis=-1)
-        loss_mlm = paddle.sum((loss * paddle.reshape(
+        mlm_loss = paddle.sum((loss * paddle.reshape(
             mlm_loss_pos, [-1]))) / paddle.sum((mlm_loss_pos) + 1e-10)
-        if self.text_masking:
-            loss_text = paddle.sum((self.text_mlm_loss(
-                paddle.reshape(text_outs, (-1, self.vocab_size)),
-                paddle.reshape(text, (-1))) * paddle.reshape(
-                    text_masked_pos,
-                    (-1)))) / paddle.sum((text_masked_pos) + 1e-10)
+        text_mlm_loss = None
 
-            return loss_mlm, loss_text
-
-        return loss_mlm
+        if self.text_masking:
+            assert text is not None
+            assert text_outs is not None
+            assert text_masked_pos is not None
+            text_outs = paddle.reshape(text_outs, [-1, self.vocab_size])
+            text = paddle.reshape(text, [-1])
+            text_mlm_loss = self.text_mlm_loss(text_outs, text)
+            text_masked_pos_reshape = paddle.reshape(text_masked_pos, [-1])
+            text_mlm_loss = paddle.sum(
+                text_mlm_loss *
+                text_masked_pos_reshape) / paddle.sum((text_masked_pos) + 1e-10)
+
+        return mlm_loss, text_mlm_loss
diff --git a/paddlespeech/t2s/modules/nets_utils.py b/paddlespeech/t2s/modules/nets_utils.py
index a3d5d1354..798e4dee8 100644
--- a/paddlespeech/t2s/modules/nets_utils.py
+++ b/paddlespeech/t2s/modules/nets_utils.py
@@ -418,7 +418,6 @@ def phones_masking(xs_pad: paddle.Tensor,
                 mean_phn_span=mean_phn_span).nonzero()
             masked_start = align_start[idx][masked_phn_idxs].tolist()
             masked_end = align_end[idx][masked_phn_idxs].tolist()
-
         for s, e in zip(masked_start, masked_end):
             masked_pos[idx, s:e] = 1
     non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
@@ -500,14 +499,15 @@ def phones_text_masking(xs_pad: paddle.Tensor,
                 set(range(length)) - set(masked_phn_idxs[0].tolist()))
             np.random.shuffle(unmasked_phn_idxs)
             masked_text_idxs = unmasked_phn_idxs[:text_mask_num_lower]
-            text_masked_pos[idx][masked_text_idxs] = 1
+            text_masked_pos[idx, masked_text_idxs] = 1
             masked_start = align_start[idx][masked_phn_idxs].tolist()
             masked_end = align_end[idx][masked_phn_idxs].tolist()
             for s, e in zip(masked_start, masked_end):
                 masked_pos[idx, s:e] = 1
-    non_eos_mask = paddle.reshape(src_mask, paddle.shape(xs_pad)[:2])
+    non_eos_mask = paddle.reshape(src_mask, shape=paddle.shape(xs_pad)[:2])
     masked_pos = masked_pos * non_eos_mask
-    non_eos_text_mask = paddle.reshape(text_mask, paddle.shape(xs_pad)[:2])
+    non_eos_text_mask = paddle.reshape(
+        text_mask, shape=paddle.shape(text_pad)[:2])
     text_masked_pos = text_masked_pos * non_eos_text_mask
     masked_pos = paddle.cast(masked_pos, 'bool')
     text_masked_pos = paddle.cast(text_masked_pos, 'bool')
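
# A minimal numpy sketch of the frame-masking idea behind phones_masking /
# phones_text_masking above (illustrative only; the real functions operate on
# padded paddle tensors and draw the masked phone spans at random):
#
#   import numpy as np
#
#   align_start = [0, 3, 7]    # first mel frame of each phone
#   align_end = [3, 7, 12]     # one-past-last mel frame of each phone
#   masked_phn_idxs = [1]      # suppose phone 1 was drawn for masking
#
#   masked_pos = np.zeros(12, dtype=bool)
#   for i in masked_phn_idxs:
#       masked_pos[align_start[i]:align_end[i]] = True  # masks frames 3..6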