parent 0ea9def0b8
commit 94688264c7
@@ -0,0 +1,282 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################

fs: 24000                # Sampling rate (Hz).
n_fft: 2048              # FFT size (samples).
n_shift: 300             # Hop size (samples). 12.5ms
win_length: 1200         # Window length (samples). 50ms
# If set to null, it will be the same as n_fft.
window: "hann"           # Window function.

# Only used for feats_type != raw

fmin: 80                 # Minimum frequency of Mel basis.
fmax: 7600               # Maximum frequency of Mel basis.
n_mels: 80               # The number of mel basis.

mean_phn_span: 8         # Mean length (in phones) of each masked span.
mlm_prob: 0.8            # Masking probability for masked language model training.

###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 2

###########################################################
# MODEL SETTING #
###########################################################
model:
    text_masking: false
    postnet_layers: 5
    postnet_filts: 5
    postnet_chans: 256
    encoder_type: conformer
    decoder_type: conformer
    enc_input_layer: sega_mlm
    enc_pre_speech_layer: 0
    enc_cnn_module_kernel: 7
    enc_attention_dim: 384
    enc_attention_heads: 2
    enc_linear_units: 1536
    enc_num_blocks: 4
    enc_dropout_rate: 0.2
    enc_positional_dropout_rate: 0.2
    enc_attention_dropout_rate: 0.2
    enc_normalize_before: true
    enc_macaron_style: true
    enc_use_cnn_module: true
    enc_selfattention_layer_type: legacy_rel_selfattn
    enc_activation_type: swish
    enc_pos_enc_layer_type: legacy_rel_pos
    enc_positionwise_layer_type: conv1d
    enc_positionwise_conv_kernel_size: 3
    dec_cnn_module_kernel: 31
    dec_attention_dim: 384
    dec_attention_heads: 2
    dec_linear_units: 1536
    dec_num_blocks: 4
    dec_dropout_rate: 0.2
    dec_positional_dropout_rate: 0.2
    dec_attention_dropout_rate: 0.2
    dec_macaron_style: true
    dec_use_cnn_module: true
    dec_selfattention_layer_type: legacy_rel_selfattn
    dec_activation_type: swish
    dec_pos_enc_layer_type: legacy_rel_pos
    dec_positionwise_layer_type: conv1d
    dec_positionwise_conv_kernel_size: 3

###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
    optim: adam              # optimizer type
    learning_rate: 0.001     # learning rate

###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 200
num_snapshots: 5

###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

token_list:
- <blank>
- <unk>
- d
- sp
- sh
- ii
- j
- zh
- l
- x
- b
- g
- uu
- e5
- h
- q
- m
- i1
- t
- z
- ch
- f
- s
- u4
- ix4
- i4
- n
- i3
- iu3
- vv
- ian4
- ix2
- r
- e4
- ai4
- k
- ing2
- a1
- en2
- ui4
- ong1
- uo3
- u2
- u3
- ao4
- ee
- p
- an1
- eng2
- i2
- in1
- c
- ai2
- ian2
- e2
- an4
- ing4
- v4
- ai3
- a5
- ian3
- eng1
- ong4
- ang4
- ian1
- ing1
- iy4
- ao3
- ang1
- uo4
- u1
- iao4
- iu4
- a4
- van2
- ie4
- ang2
- ou4
- iang4
- ix1
- er4
- iy1
- e1
- en1
- ui2
- an3
- ei4
- ong2
- uo1
- ou3
- uo2
- iao1
- ou1
- an2
- uan4
- ia4
- ia1
- ang3
- v3
- iu2
- iao3
- in4
- a3
- ei3
- iang3
- v2
- eng4
- en3
- aa
- uan1
- v1
- ao1
- ve4
- ie3
- ai1
- ing3
- iang1
- a2
- ui1
- en4
- en5
- in3
- uan3
- e3
- ie1
- ve2
- ei2
- in2
- ix3
- uan2
- iang2
- ie2
- ua4
- ou2
- uai4
- er2
- eng3
- uang3
- un1
- ong3
- uang4
- vn4
- un2
- iy3
- iz4
- ui3
- iao2
- iong4
- un4
- van4
- ao2
- uang1
- iy5
- o2
- ei1
- ua1
- iu1
- uang2
- er5
- o1
- un3
- vn1
- vn2
- o4
- ve1
- van3
- ua2
- er3
- iong3
- van1
- ia2
- iy2
- ia3
- iong1
- uo5
- oo
- ve3
- ou5
- uai3
- ian5
- iong2
- uai2
- uai1
- ua3
- vn3
- ia5
- ie5
- ueng1
- o5
- o3
- iang5
- ei5
- <sos/eos>
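As a sanity check on the feature-extraction block above: n_shift and win_length are given in samples, so at fs 24000 they correspond to 12.5 ms and 50 ms. A minimal sketch, assuming the config path, that loads the file the same way load_model() in inference_new.py (further down) does:

    # Sketch only: verify the frame timing implied by the settings above.
    import yaml
    from yacs.config import CfgNode

    with open("conf/default.yaml") as f:          # path assumed
        conf = CfgNode(yaml.safe_load(f))

    hop_ms = 1000 * conf.n_shift / conf.fs        # 300 / 24000 -> 12.5 ms per frame
    win_ms = 1000 * conf.win_length / conf.fs     # 1200 / 24000 -> 50 ms analysis window
    print(hop_ms, win_ms, len(conf.token_list))   # vocab size includes <blank>/<unk>/<sos/eos>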
@@ -0,0 +1,61 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./aishell3_alignment_tone \
        --output durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=aishell3 \
        --rootdir=~/datasets/data_aishell3/ \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --num-cpu=20 \
        --cut-sil=True
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone/speaker to id; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi
@@ -0,0 +1 @@
#!/bin/bash
@@ -0,0 +1,12 @@
#!/bin/bash

config_path=$1
train_output_path=$2

python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
@@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@@ -0,0 +1,32 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz

# With the following options you can choose the range of stages you want to run,
# e.g. `./run.sh --stage 0 --stop-stage 0`.
# This cannot be mixed with the use of `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@@ -0,0 +1,351 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################

fs: 24000                # Sampling rate (Hz).
n_fft: 2048              # FFT size (samples).
n_shift: 300             # Hop size (samples). 12.5ms
win_length: 1200         # Window length (samples). 50ms
# If set to null, it will be the same as n_fft.
window: "hann"           # Window function.

# Only used for feats_type != raw

fmin: 80                 # Minimum frequency of Mel basis.
fmax: 7600               # Maximum frequency of Mel basis.
n_mels: 80               # The number of mel basis.

mean_phn_span: 8         # Mean length (in phones) of each masked span.
mlm_prob: 0.8            # Masking probability for masked language model training.

###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 2

###########################################################
# MODEL SETTING #
###########################################################
model:
    text_masking: true
    postnet_layers: 5
    postnet_filts: 5
    postnet_chans: 256
    encoder_type: conformer
    decoder_type: conformer
    enc_input_layer: sega_mlm
    enc_pre_speech_layer: 0
    enc_cnn_module_kernel: 7
    enc_attention_dim: 384
    enc_attention_heads: 2
    enc_linear_units: 1536
    enc_num_blocks: 4
    enc_dropout_rate: 0.2
    enc_positional_dropout_rate: 0.2
    enc_attention_dropout_rate: 0.2
    enc_normalize_before: true
    enc_macaron_style: true
    enc_use_cnn_module: true
    enc_selfattention_layer_type: legacy_rel_selfattn
    enc_activation_type: swish
    enc_pos_enc_layer_type: legacy_rel_pos
    enc_positionwise_layer_type: conv1d
    enc_positionwise_conv_kernel_size: 3
    dec_cnn_module_kernel: 31
    dec_attention_dim: 384
    dec_attention_heads: 2
    dec_linear_units: 1536
    dec_num_blocks: 4
    dec_dropout_rate: 0.2
    dec_positional_dropout_rate: 0.2
    dec_attention_dropout_rate: 0.2
    dec_macaron_style: true
    dec_use_cnn_module: true
    dec_selfattention_layer_type: legacy_rel_selfattn
    dec_activation_type: swish
    dec_pos_enc_layer_type: legacy_rel_pos
    dec_positionwise_layer_type: conv1d
    dec_positionwise_conv_kernel_size: 3

###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
    optim: adam              # optimizer type
    learning_rate: 0.001     # learning rate

###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 100
num_snapshots: 5

###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

token_list:
- <blank>
- <unk>
- AH0
- T
- N
- sp
- S
- R
- D
- L
- Z
- DH
- IH1
- K
- W
- M
- EH1
- AE1
- ER0
- B
- IY1
- P
- V
- IY0
- F
- HH
- AA1
- AY1
- AH1
- EY1
- IH0
- AO1
- OW1
- UW1
- G
- NG
- SH
- Y
- TH
- ER1
- JH
- UH1
- AW1
- CH
- IH2
- OW0
- OW2
- EY2
- EH2
- UW0
- OY1
- ZH
- EH0
- AY2
- AW2
- AA2
- AE2
- IY2
- AH2
- AE0
- AO2
- AY0
- AO0
- UW2
- UH2
- AA0
- EY0
- AW0
- UH0
- ER2
- OY2
- OY0
- d
- sh
- ii
- j
- zh
- l
- x
- b
- g
- uu
- e5
- h
- q
- m
- i1
- t
- z
- ch
- f
- s
- u4
- ix4
- i4
- n
- i3
- iu3
- vv
- ian4
- ix2
- r
- e4
- ai4
- k
- ing2
- a1
- en2
- ui4
- ong1
- uo3
- u2
- u3
- ao4
- ee
- p
- an1
- eng2
- i2
- in1
- c
- ai2
- ian2
- e2
- an4
- ing4
- v4
- ai3
- a5
- ian3
- eng1
- ong4
- ang4
- ian1
- ing1
- iy4
- ao3
- ang1
- uo4
- u1
- iao4
- iu4
- a4
- van2
- ie4
- ang2
- ou4
- iang4
- ix1
- er4
- iy1
- e1
- en1
- ui2
- an3
- ei4
- ong2
- uo1
- ou3
- uo2
- iao1
- ou1
- an2
- uan4
- ia4
- ia1
- ang3
- v3
- iu2
- iao3
- in4
- a3
- ei3
- iang3
- v2
- eng4
- en3
- aa
- uan1
- v1
- ao1
- ve4
- ie3
- ai1
- ing3
- iang1
- a2
- ui1
- en4
- en5
- in3
- uan3
- e3
- ie1
- ve2
- ei2
- in2
- ix3
- uan2
- iang2
- ie2
- ua4
- ou2
- uai4
- er2
- eng3
- uang3
- un1
- ong3
- uang4
- vn4
- un2
- iy3
- iz4
- ui3
- iao2
- iong4
- un4
- van4
- ao2
- uang1
- iy5
- o2
- ei1
- ua1
- iu1
- uang2
- er5
- o1
- un3
- vn1
- vn2
- o4
- ve1
- van3
- ua2
- er3
- iong3
- van1
- ia2
- iy2
- ia3
- iong1
- uo5
- oo
- ve3
- ou5
- uai3
- ian5
- iong2
- uai2
- uai1
- ua3
- vn3
- ia5
- ie5
- ueng1
- o5
- o3
- iang5
- ei5
- <sos/eos>
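The keys under model: above are passed straight through as keyword arguments when the network is built (this is how load_model() in inference_new.py below consumes them); a hedged sketch of that wiring, with idim/odim derived from token_list and n_mels:

    # Sketch only: how the model: section is consumed (mirrors load_model() below).
    import yaml
    from yacs.config import CfgNode
    from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT

    with open("conf/default.yaml") as f:   # path assumed
        conf = CfgNode(yaml.safe_load(f))

    model = ErnieSAT(
        idim=len(conf.token_list),   # vocabulary size, including <blank>/<unk>/<sos/eos>
        odim=conf.n_mels,            # mel dimension
        **conf["model"])             # text_masking, encoder/decoder hyper-parameters, ...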
@@ -0,0 +1,67 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./aishell3_alignment_tone \
        --output durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=aishell3 \
        --rootdir=~/datasets/data_aishell3/ \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --num-cpu=20 \
        --cut-sil=True
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone/speaker to id; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --pitch-stats=dump/train/pitch_stats.npy \
        --energy-stats=dump/train/energy_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi
@@ -0,0 +1 @@
#!/bin/bash
@@ -0,0 +1,12 @@
#!/bin/bash

config_path=$1
train_output_path=$2

python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
@@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@@ -0,0 +1,32 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz

# With the following options you can choose the range of stages you want to run,
# e.g. `./run.sh --stage 0 --stop-stage 0`.
# This cannot be mixed with the use of `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@@ -0,0 +1,622 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from typing import Dict
from typing import List

import librosa
import numpy as np
import paddle
import soundfile as sf
import yaml
from align import alignment
from align import alignment_zh
from align import words2phns
from align import words2phns_zh
from paddle import nn
from sedit_arg_parser import parse_args
from utils import eval_durs
from utils import get_voc_out
from utils import is_chinese
from utils import load_num_sequence_text
from utils import read_2col_text
from yacs.config import CfgNode

from paddlespeech.t2s.datasets.am_batch_fn import build_mlm_collate_fn
from paddlespeech.t2s.models.ernie_sat.ernie_sat import ErnieSAT

random.seed(0)
np.random.seed(0)


def get_wav(wav_path: str,
            source_lang: str='english',
            target_lang: str='english',
            model_name: str="paddle_checkpoint_en",
            old_str: str="",
            new_str: str="",
            non_autoreg: bool=True):
    wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length = get_mlm_output(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=non_autoreg)

    masked_feat = output_feat[new_span_bdy[0]:new_span_bdy[1]]

    alt_wav = get_voc_out(masked_feat)

    old_time_bdy = [hop_length * x for x in old_span_bdy]

    wav_replaced = np.concatenate(
        [wav_org[:old_time_bdy[0]], alt_wav, wav_org[old_time_bdy[1]:]])

    data_dict = {"origin": wav_org, "output": wav_replaced}

    return data_dict


def load_model(model_name: str="paddle_checkpoint_en"):
    config_path = './pretrained_model/{}/default.yaml'.format(model_name)
    model_path = './pretrained_model/{}/model.pdparams'.format(model_name)
    with open(config_path) as f:
        conf = CfgNode(yaml.safe_load(f))
    token_list = list(conf.token_list)
    vocab_size = len(token_list)
    odim = conf.n_mels
    mlm_model = ErnieSAT(idim=vocab_size, odim=odim, **conf["model"])
    state_dict = paddle.load(model_path)
    new_state_dict = {}
    for key, value in state_dict.items():
        new_key = "model." + key
        new_state_dict[new_key] = value
    mlm_model.set_state_dict(new_state_dict)
    mlm_model.eval()

    return mlm_model, conf


def read_data(uid: str, prefix: os.PathLike):
    # get the text for the given uid
    mfa_text = read_2col_text(prefix + '/text')[uid]
    # get the audio path for the given uid
    mfa_wav_path = read_2col_text(prefix + '/wav.scp')[uid]
    if not os.path.isabs(mfa_wav_path):
        mfa_wav_path = prefix + mfa_wav_path
    return mfa_text, mfa_wav_path


def get_align_data(uid: str, prefix: os.PathLike):
    mfa_path = prefix + "mfa_"
    mfa_text = read_2col_text(mfa_path + 'text')[uid]
    mfa_start = load_num_sequence_text(
        mfa_path + 'start', loader_type='text_float')[uid]
    mfa_end = load_num_sequence_text(
        mfa_path + 'end', loader_type='text_float')[uid]
    mfa_wav_path = read_2col_text(mfa_path + 'wav.scp')[uid]
    return mfa_text, mfa_start, mfa_end, mfa_wav_path


# get the range of mel frames to be masked
def get_masked_mel_bdy(mfa_start: List[float],
                       mfa_end: List[float],
                       fs: int,
                       hop_length: int,
                       span_to_repl: List[List[int]]):
    align_start = np.array(mfa_start)
    align_end = np.array(mfa_end)
    align_start = np.floor(fs * align_start / hop_length).astype('int')
    align_end = np.floor(fs * align_end / hop_length).astype('int')
    if span_to_repl[0] >= len(mfa_start):
        span_bdy = [align_end[-1], align_end[-1]]
    else:
        span_bdy = [
            align_start[span_to_repl[0]], align_end[span_to_repl[1] - 1]
        ]
    return span_bdy, align_start, align_end
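For intuition, the conversion above maps MFA phone times (in seconds) to mel-frame indices with floor(fs * t / hop_length). A small worked example with made-up alignment times (values are hypothetical, only the arithmetic matters):

    import numpy as np

    # Hypothetical alignment: three phones ending at 0.10 s, 0.25 s, 0.40 s (fs=24000, hop=300).
    mfa_start = [0.00, 0.10, 0.25]
    mfa_end = [0.10, 0.25, 0.40]
    span_to_repl = [1, 3]  # replace the 2nd and 3rd phones

    span_bdy, align_start, align_end = get_masked_mel_bdy(
        mfa_start, mfa_end, fs=24000, hop_length=300, span_to_repl=span_to_repl)
    # align_start -> frames 0, 8, 20 and align_end -> frames 8, 20, 32
    print(span_bdy)  # [8, 32]: mel frames 8..32 are the region to be masked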


def recover_dict(word2phns: Dict[str, str], tp_word2phns: Dict[str, str]):
    dic = {}
    keys_to_del = []
    exist_idx = []
    sp_count = 0
    add_sp_count = 0
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            exist_idx.append(int(idx))
        else:
            keys_to_del.append(key)

    for key in keys_to_del:
        del word2phns[key]

    cur_id = 0
    for key in tp_word2phns.keys():
        if cur_id in exist_idx:
            dic[str(cur_id) + "_sp"] = 'sp'
            cur_id += 1
            add_sp_count += 1
        idx, wrd = key.split('_')
        dic[str(cur_id) + "_" + wrd] = tp_word2phns[key]
        cur_id += 1

    if add_sp_count + 1 == sp_count:
        dic[str(cur_id) + "_sp"] = 'sp'
        add_sp_count += 1

    assert add_sp_count == sp_count, "sp are not added in dic"
    return dic


def get_max_idx(dic):
    return sorted([int(key.split('_')[0]) for key in dic.keys()])[-1]


def get_phns_and_spans(wav_path: str,
                       old_str: str="",
                       new_str: str="",
                       source_lang: str="english",
                       target_lang: str="english"):
    is_append = (old_str == new_str[:len(old_str)])
    old_phns, mfa_start, mfa_end = [], [], []
    # source
    if source_lang == "english":
        intervals, word2phns = alignment(wav_path, old_str)
    elif source_lang == "chinese":
        intervals, word2phns = alignment_zh(wav_path, old_str)
        _, tp_word2phns = words2phns_zh(old_str)

        for key, value in tp_word2phns.items():
            idx, wrd = key.split('_')
            cur_val = " ".join(value)
            tp_word2phns[key] = cur_val

        word2phns = recover_dict(word2phns, tp_word2phns)
    else:
        assert source_lang == "chinese" or source_lang == "english", \
            "source_lang is wrong..."

    for item in intervals:
        old_phns.append(item[0])
        mfa_start.append(float(item[1]))
        mfa_end.append(float(item[2]))
    # target
    if is_append and (source_lang != target_lang):
        cross_lingual_clone = True
    else:
        cross_lingual_clone = False

    if cross_lingual_clone:
        str_origin = new_str[:len(old_str)]
        str_append = new_str[len(old_str):]

        if target_lang == "chinese":
            phns_origin, origin_word2phns = words2phns(str_origin)
            phns_append, append_word2phns_tmp = words2phns_zh(str_append)

        elif target_lang == "english":
            # original sentence
            phns_origin, origin_word2phns = words2phns_zh(str_origin)
            # cloned sentence
            phns_append, append_word2phns_tmp = words2phns(str_append)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "cloning is not supported for this language, please check it."

        new_phns = phns_origin + phns_append

        append_word2phns = {}
        length = len(origin_word2phns)
        for key, value in append_word2phns_tmp.items():
            idx, wrd = key.split('_')
            append_word2phns[str(int(idx) + length) + '_' + wrd] = value
        new_word2phns = origin_word2phns.copy()
        new_word2phns.update(append_word2phns)

    else:
        if source_lang == target_lang and target_lang == "english":
            new_phns, new_word2phns = words2phns(new_str)
        elif source_lang == target_lang and target_lang == "chinese":
            new_phns, new_word2phns = words2phns_zh(new_str)
        else:
            assert source_lang == target_lang, \
                "source language is not the same as target language..."

    span_to_repl = [0, len(old_phns) - 1]
    span_to_add = [0, len(new_phns) - 1]
    left_idx = 0
    new_phns_left = []
    sp_count = 0
    # find the left different index
    for key in word2phns.keys():
        idx, wrd = key.split('_')
        if wrd == 'sp':
            sp_count += 1
            new_phns_left.append('sp')
        else:
            idx = str(int(idx) - sp_count)
            if idx + '_' + wrd in new_word2phns:
                left_idx += len(new_word2phns[idx + '_' + wrd])
                new_phns_left.extend(word2phns[key].split())
            else:
                span_to_repl[0] = len(new_phns_left)
                span_to_add[0] = len(new_phns_left)
                break

    # reverse word2phns and new_word2phns
    right_idx = 0
    new_phns_right = []
    sp_count = 0
    word2phns_max_idx = get_max_idx(word2phns)
    new_word2phns_max_idx = get_max_idx(new_word2phns)
    new_phns_mid = []
    if is_append:
        new_phns_right = []
        new_phns_mid = new_phns[left_idx:]
        span_to_repl[0] = len(new_phns_left)
        span_to_add[0] = len(new_phns_left)
        span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
        span_to_repl[1] = len(old_phns) - len(new_phns_right)
    # speech edit
    else:
        for key in list(word2phns.keys())[::-1]:
            idx, wrd = key.split('_')
            if wrd == 'sp':
                sp_count += 1
                new_phns_right = ['sp'] + new_phns_right
            else:
                idx = str(new_word2phns_max_idx - (word2phns_max_idx - int(idx)
                                                   - sp_count))
                if idx + '_' + wrd in new_word2phns:
                    right_idx -= len(new_word2phns[idx + '_' + wrd])
                    new_phns_right = word2phns[key].split() + new_phns_right
                else:
                    span_to_repl[1] = len(old_phns) - len(new_phns_right)
                    new_phns_mid = new_phns[left_idx:right_idx]
                    span_to_add[1] = len(new_phns_left) + len(new_phns_mid)
                    if len(new_phns_mid) == 0:
                        span_to_add[1] = min(span_to_add[1] + 1, len(new_phns))
                        span_to_add[0] = max(0, span_to_add[0] - 1)
                        span_to_repl[0] = max(0, span_to_repl[0] - 1)
                        span_to_repl[1] = min(span_to_repl[1] + 1,
                                              len(old_phns))
                    break
    new_phns = new_phns_left + new_phns_mid + new_phns_right
    '''
    For that reason cover should not be given.
    For that reason cover is impossible to be given.
    span_to_repl: [17, 23] "should not"
    span_to_add: [17, 30] "is impossible to"
    '''
    return mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add


# durations from MFA may differ from the durations predicted by the
# FastSpeech2 duration_predictor; compute a scaling factor between
# predicted and ground-truth durations
def get_dur_adj_factor(orig_dur: List[int],
                       pred_dur: List[int],
                       phns: List[str]):
    length = 0
    factor_list = []
    for orig, pred, phn in zip(orig_dur, pred_dur, phns):
        if pred == 0 or phn == 'sp':
            continue
        else:
            factor_list.append(orig / pred)
    factor_list = np.array(factor_list)
    factor_list.sort()
    if len(factor_list) < 5:
        return 1
    length = 2
    avg = np.average(factor_list[length:-length])
    return avg
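The factor above is essentially a trimmed mean of per-phone ratios between MFA durations and predicted durations: silences and zero predictions are skipped, and the two smallest and two largest ratios are dropped once at least five remain. A tiny numeric illustration with invented durations:

    # Invented numbers, for illustration only.
    orig_dur = [6, 8, 10, 9, 7, 12, 5]   # durations measured from MFA
    pred_dur = [5, 8, 8, 10, 7, 10, 4]   # durations from the FastSpeech2 predictor
    phns = ['d', 'a1', 'sp', 'n', 'i3', 'h', 'ao3']

    factor = get_dur_adj_factor(orig_dur, pred_dur, phns)
    # ratios with 'sp' skipped: 1.2, 1.0, 0.9, 1.0, 1.2, 1.25 -> sorted, drop 2 from each end
    # -> average of 1.0 and 1.2 = 1.1, so predicted durations are stretched by ~10%
    print(factor)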


def prep_feats_with_dur(wav_path: str,
                        source_lang: str="English",
                        target_lang: str="English",
                        old_str: str="",
                        new_str: str="",
                        mask_reconstruct: bool=False,
                        duration_adjust: bool=True,
                        start_end_sp: bool=False,
                        fs: int=24000,
                        hop_length: int=300):
    '''
    Returns:
        np.ndarray: new wav, in which the part to be edited is replaced with zeros
        List[str]: new phones
        List[float]: mfa start of new wav
        List[float]: mfa end of new wav
        List[int]: masked mel boundary of original wav
        List[int]: masked mel boundary of new wav
    '''
    wav_org, _ = librosa.load(wav_path, sr=fs)

    mfa_start, mfa_end, old_phns, new_phns, span_to_repl, span_to_add = get_phns_and_spans(
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        source_lang=source_lang,
        target_lang=target_lang)

    if start_end_sp:
        if new_phns[-1] != 'sp':
            new_phns = new_phns + ['sp']
    # Chinese phones may be missing from the fastspeech2 dictionary; such phones are replaced with sp
    if target_lang == "english" or target_lang == "chinese":
        old_durs = eval_durs(old_phns, target_lang=source_lang)
    else:
        assert target_lang == "chinese" or target_lang == "english", \
            "duration prediction is not supported for this language..."

    orig_old_durs = [e - s for e, s in zip(mfa_end, mfa_start)]
    if '[MASK]' in new_str:
        new_phns = old_phns
        span_to_add = span_to_repl
        d_factor_left = get_dur_adj_factor(
            orig_dur=orig_old_durs[:span_to_repl[0]],
            pred_dur=old_durs[:span_to_repl[0]],
            phns=old_phns[:span_to_repl[0]])
        d_factor_right = get_dur_adj_factor(
            orig_dur=orig_old_durs[span_to_repl[1]:],
            pred_dur=old_durs[span_to_repl[1]:],
            phns=old_phns[span_to_repl[1]:])
        d_factor = (d_factor_left + d_factor_right) / 2
        new_durs_adjusted = [d_factor * i for i in old_durs]
    else:
        if duration_adjust:
            d_factor = get_dur_adj_factor(
                orig_dur=orig_old_durs, pred_dur=old_durs, phns=old_phns)
            d_factor = d_factor * 1.25
        else:
            d_factor = 1

        if target_lang == "english" or target_lang == "chinese":
            new_durs = eval_durs(new_phns, target_lang=target_lang)
        else:
            assert target_lang == "chinese" or target_lang == "english", \
                "duration prediction is not supported for this language..."

        new_durs_adjusted = [d_factor * i for i in new_durs]

    new_span_dur_sum = sum(new_durs_adjusted[span_to_add[0]:span_to_add[1]])
    old_span_dur_sum = sum(orig_old_durs[span_to_repl[0]:span_to_repl[1]])
    dur_offset = new_span_dur_sum - old_span_dur_sum
    new_mfa_start = mfa_start[:span_to_repl[0]]
    new_mfa_end = mfa_end[:span_to_repl[0]]
    for i in new_durs_adjusted[span_to_add[0]:span_to_add[1]]:
        if len(new_mfa_end) == 0:
            new_mfa_start.append(0)
            new_mfa_end.append(i)
        else:
            new_mfa_start.append(new_mfa_end[-1])
            new_mfa_end.append(new_mfa_end[-1] + i)
    new_mfa_start += [i + dur_offset for i in mfa_start[span_to_repl[1]:]]
    new_mfa_end += [i + dur_offset for i in mfa_end[span_to_repl[1]:]]

    # 3. get new wav
    # append after the original sentence
    if span_to_repl[0] >= len(mfa_start):
        left_idx = len(wav_org)
        right_idx = left_idx
    # replace inside the original sentence
    else:
        left_idx = int(np.floor(mfa_start[span_to_repl[0]] * fs))
        right_idx = int(np.ceil(mfa_end[span_to_repl[1] - 1] * fs))
    blank_wav = np.zeros(
        (int(np.ceil(new_span_dur_sum * fs)), ), dtype=wav_org.dtype)
    # in the original audio, the segment to be edited is replaced with silence
    # whose length is decided by the duration_predictor of fastspeech2
    new_wav = np.concatenate(
        [wav_org[:left_idx], blank_wav, wav_org[right_idx:]])

    # 4. get old and new mel spans to be masked
    # [92, 92]

    old_span_bdy, mfa_start, mfa_end = get_masked_mel_bdy(
        mfa_start=mfa_start,
        mfa_end=mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_repl)
    # [92, 174]
    # convert new_mfa_start / new_mfa_end from time-level to frame-level
    new_span_bdy, new_mfa_start, new_mfa_end = get_masked_mel_bdy(
        mfa_start=new_mfa_start,
        mfa_end=new_mfa_end,
        fs=fs,
        hop_length=hop_length,
        span_to_repl=span_to_add)

    # old_span_bdy and new_span_bdy are frame-level ranges
    return new_wav, new_phns, new_mfa_start, new_mfa_end, old_span_bdy, new_span_bdy


def prep_feats(wav_path: str,
               source_lang: str="english",
               target_lang: str="english",
               old_str: str="",
               new_str: str="",
               duration_adjust: bool=True,
               start_end_sp: bool=False,
               mask_reconstruct: bool=False,
               fs: int=24000,
               hop_length: int=300,
               token_list: List[str]=[]):
    wav, phns, mfa_start, mfa_end, old_span_bdy, new_span_bdy = prep_feats_with_dur(
        source_lang=source_lang,
        target_lang=target_lang,
        old_str=old_str,
        new_str=new_str,
        wav_path=wav_path,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        mask_reconstruct=mask_reconstruct,
        fs=fs,
        hop_length=hop_length)

    token_to_id = {item: i for i, item in enumerate(token_list)}
    text = np.array(
        list(map(lambda x: token_to_id.get(x, token_to_id['<unk>']), phns)))
    span_bdy = np.array(new_span_bdy)

    batch = [('1', {
        "speech": wav,
        "align_start": mfa_start,
        "align_end": mfa_end,
        "text": text,
        "span_bdy": span_bdy
    })]

    return batch, old_span_bdy, new_span_bdy


def decode_with_model(mlm_model: nn.Layer,
                      collate_fn,
                      wav_path: str,
                      source_lang: str="english",
                      target_lang: str="english",
                      old_str: str="",
                      new_str: str="",
                      use_teacher_forcing: bool=False,
                      duration_adjust: bool=True,
                      start_end_sp: bool=False,
                      fs: int=24000,
                      hop_length: int=300,
                      token_list: List[str]=[]):
    batch, old_span_bdy, new_span_bdy = prep_feats(
        source_lang=source_lang,
        target_lang=target_lang,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=fs,
        hop_length=hop_length,
        token_list=token_list)

    feats = collate_fn(batch)[1]

    if 'text_masked_pos' in feats.keys():
        feats.pop('text_masked_pos')

    output = mlm_model.inference(
        text=feats['text'],
        speech=feats['speech'],
        masked_pos=feats['masked_pos'],
        speech_mask=feats['speech_mask'],
        text_mask=feats['text_mask'],
        speech_seg_pos=feats['speech_seg_pos'],
        text_seg_pos=feats['text_seg_pos'],
        span_bdy=new_span_bdy,
        use_teacher_forcing=use_teacher_forcing)

    # concatenate the output frames
    output_feat = paddle.concat(x=output, axis=0)
    wav_org, _ = librosa.load(wav_path, sr=fs)
    return wav_org, output_feat, old_span_bdy, new_span_bdy, fs, hop_length


def get_mlm_output(wav_path: str,
                   model_name: str="paddle_checkpoint_en",
                   source_lang: str="english",
                   target_lang: str="english",
                   old_str: str="",
                   new_str: str="",
                   use_teacher_forcing: bool=False,
                   duration_adjust: bool=True,
                   start_end_sp: bool=False):
    mlm_model, train_conf = load_model(model_name)

    collate_fn = build_mlm_collate_fn(
        sr=train_conf.fs,
        n_fft=train_conf.n_fft,
        hop_length=train_conf.n_shift,
        win_length=train_conf.win_length,
        n_mels=train_conf.n_mels,
        fmin=train_conf.fmin,
        fmax=train_conf.fmax,
        mlm_prob=train_conf.mlm_prob,
        mean_phn_span=train_conf.mean_phn_span,
        seg_emb=train_conf.model['enc_input_layer'] == 'sega_mlm')

    return decode_with_model(
        source_lang=source_lang,
        target_lang=target_lang,
        mlm_model=mlm_model,
        collate_fn=collate_fn,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str,
        use_teacher_forcing=use_teacher_forcing,
        duration_adjust=duration_adjust,
        start_end_sp=start_end_sp,
        fs=train_conf.fs,
        hop_length=train_conf.n_shift,
        token_list=train_conf.token_list)


def evaluate(uid: str,
             source_lang: str="english",
             target_lang: str="english",
             prefix: os.PathLike="./prompt/dev/",
             model_name: str="paddle_checkpoint_en",
             new_str: str="",
             prompt_decoding: bool=False,
             task_name: str=None):

    # get origin text and path of origin wav
    old_str, wav_path = read_data(uid=uid, prefix=prefix)

    if task_name == 'edit':
        new_str = new_str
    elif task_name == 'synthesize':
        new_str = old_str + new_str
    else:
        new_str = old_str + ' '.join([ch for ch in new_str if is_chinese(ch)])

    print('new_str is ', new_str)

    results_dict = get_wav(
        source_lang=source_lang,
        target_lang=target_lang,
        model_name=model_name,
        wav_path=wav_path,
        old_str=old_str,
        new_str=new_str)
    return results_dict


if __name__ == "__main__":
    # parse config and args
    args = parse_args()

    data_dict = evaluate(
        uid=args.uid,
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        prefix=args.prefix,
        model_name=args.model_name,
        new_str=args.new_str,
        task_name=args.task_name)
    sf.write(args.output_name, data_dict['output'], samplerate=24000)
    print("finished...")
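A hedged sketch of driving the editing entry point directly from Python instead of through the command line; argument values mirror run_sedit_en_new.sh below, and it assumes the vocoder/acoustic-model options consumed by utils.get_voc_out and utils.eval_durs have already been set up (in the shipped scripts that happens via sedit_arg_parser):

    # Sketch only, not the recipe's actual invocation path.
    import soundfile as sf

    data_dict = evaluate(
        uid='p243_new',
        source_lang='english',
        target_lang='english',
        prefix='./prompt/dev/',
        model_name='paddle_checkpoint_en',
        new_str='for that reason cover is impossible to be given.',
        task_name='edit')
    sf.write('pred_edit.wav', data_dict['output'], samplerate=24000)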
@@ -0,0 +1,27 @@
#!/bin/bash

set -e
source path.sh

# en --> zh cross-lingual speech synthesis (voice cloning)
# Use Prompt_003_new as the prompt speech ("This was not the show for me.") to synthesize: '今天天气很好'
# Note: new_str must consist of Chinese characters; otherwise preprocessing keeps only the
# Chinese characters, and the preprocessed Chinese text is what gets synthesized.

python local/inference_new.py \
    --task_name=cross-lingual_clone \
    --model_name=paddle_checkpoint_dual_mask_enzh \
    --uid=Prompt_003_new \
    --new_str='今天天气很好.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=chinese \
    --output_name=pred_clone.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_csmsc \
    --am_config=download/fastspeech2_conformer_baker_ckpt_0.5/conformer.yaml \
    --am_ckpt=download/fastspeech2_conformer_baker_ckpt_0.5/snapshot_iter_76000.pdz \
    --am_stat=download/fastspeech2_conformer_baker_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_conformer_baker_ckpt_0.5/phone_id_map.txt
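The note above about new_str being restricted to Chinese characters corresponds to the default branch of evaluate() in inference_new.py, which keeps only Chinese characters and joins them with spaces before appending them to the prompt text. A small sketch of that behaviour (is_chinese is the same helper the script imports from its local utils module, assumed here to accept only CJK characters):

    from utils import is_chinese   # same helper used by evaluate()

    new_str = '今天天气很好. Great weather!'
    filtered = ' '.join([ch for ch in new_str if is_chinese(ch)])
    print(filtered)   # '今 天 天 气 很 好' -- punctuation and English text are dropped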
@@ -0,0 +1,26 @@
#!/bin/bash

set -e
source path.sh

# English-only personalized speech synthesis
# Example: use the speech of p299_096 ("This was not the show for me.") as the prompt to synthesize: 'I enjoy my life.'

python local/inference_new.py \
    --task_name=synthesize \
    --model_name=paddle_checkpoint_en \
    --uid=p299_096 \
    --new_str='I enjoy my life, do you?' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_gen.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
@@ -0,0 +1,27 @@
#!/bin/bash

set -e
source path.sh

# English-only speech editing
# Example: edit the original speech of p243_new ("For that reason cover should not be given.")
# into speech for 'for that reason cover is impossible to be given.'
# NOTE: the speech editing task currently supports replacing or inserting text at only one position in a sentence.

python local/inference_new.py \
    --task_name=edit \
    --model_name=paddle_checkpoint_en \
    --uid=p243_new \
    --new_str='for that reason cover is impossible to be given.' \
    --prefix='./prompt/dev/' \
    --source_lang=english \
    --target_lang=english \
    --output_name=pred_edit.wav \
    --voc=pwgan_aishell3 \
    --voc_config=download/pwg_aishell3_ckpt_0.5/default.yaml \
    --voc_ckpt=download/pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
    --voc_stat=download/pwg_aishell3_ckpt_0.5/feats_stats.npy \
    --am=fastspeech2_ljspeech \
    --am_config=download/fastspeech2_nosil_ljspeech_ckpt_0.5/default.yaml \
    --am_ckpt=download/fastspeech2_nosil_ljspeech_ckpt_0.5/snapshot_iter_100000.pdz \
    --am_stat=download/fastspeech2_nosil_ljspeech_ckpt_0.5/speech_stats.npy \
    --phones_dict=download/fastspeech2_nosil_ljspeech_ckpt_0.5/phone_id_map.txt
@@ -0,0 +1,6 @@
#!/bin/bash

rm -rf *.wav
./run_sedit_en_new.sh          # speech editing task (English)
./run_gen_en_new.sh            # personalized speech synthesis task (English)
./run_clone_en_to_zh_new.sh    # cross-lingual synthesis task (English-to-Chinese voice cloning)
@@ -0,0 +1,162 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################

fs: 24000                # Sampling rate (Hz).
n_fft: 2048              # FFT size (samples).
n_shift: 300             # Hop size (samples). 12.5ms
win_length: 1200         # Window length (samples). 50ms
# If set to null, it will be the same as n_fft.
window: "hann"           # Window function.

# Only used for feats_type != raw

fmin: 80                 # Minimum frequency of Mel basis.
fmax: 7600               # Maximum frequency of Mel basis.
n_mels: 80               # The number of mel basis.

mean_phn_span: 8         # Mean length (in phones) of each masked span.
mlm_prob: 0.8            # Masking probability for masked language model training.

###########################################################
# DATA SETTING #
###########################################################
batch_size: 20
num_workers: 2

###########################################################
# MODEL SETTING #
###########################################################
model:
    text_masking: false
    postnet_layers: 5
    postnet_filts: 5
    postnet_chans: 256
    encoder_type: conformer
    decoder_type: conformer
    enc_input_layer: sega_mlm
    enc_pre_speech_layer: 0
    enc_cnn_module_kernel: 7
    enc_attention_dim: 384
    enc_attention_heads: 2
    enc_linear_units: 1536
    enc_num_blocks: 4
    enc_dropout_rate: 0.2
    enc_positional_dropout_rate: 0.2
    enc_attention_dropout_rate: 0.2
    enc_normalize_before: true
    enc_macaron_style: true
    enc_use_cnn_module: true
    enc_selfattention_layer_type: legacy_rel_selfattn
    enc_activation_type: swish
    enc_pos_enc_layer_type: legacy_rel_pos
    enc_positionwise_layer_type: conv1d
    enc_positionwise_conv_kernel_size: 3
    dec_cnn_module_kernel: 31
    dec_attention_dim: 384
    dec_attention_heads: 2
    dec_linear_units: 1536
    dec_num_blocks: 4
    dec_dropout_rate: 0.2
    dec_positional_dropout_rate: 0.2
    dec_attention_dropout_rate: 0.2
    dec_macaron_style: true
    dec_use_cnn_module: true
    dec_selfattention_layer_type: legacy_rel_selfattn
    dec_activation_type: swish
    dec_pos_enc_layer_type: legacy_rel_pos
    dec_positionwise_layer_type: conv1d
    dec_positionwise_conv_kernel_size: 3

###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
    optim: adam              # optimizer type
    learning_rate: 0.001     # learning rate

###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 200
num_snapshots: 5

###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

token_list:
- <blank>
- <unk>
- AH0
- T
- N
- sp
- D
- S
- R
- L
- IH1
- DH
- AE1
- M
- EH1
- K
- Z
- W
- HH
- ER0
- AH1
- IY1
- P
- V
- F
- B
- AY1
- IY0
- EY1
- AA1
- AO1
- UW1
- IH0
- OW1
- NG
- G
- SH
- ER1
- Y
- TH
- AW1
- CH
- UH1
- IH2
- JH
- OW0
- EH2
- OY1
- AY2
- EH0
- EY2
- UW0
- AE2
- AA2
- OW2
- AH2
- ZH
- AO2
- IY2
- AE0
- UW2
- AY0
- AA0
- AO0
- AW2
- EY0
- UH2
- ER2
- OY2
- UH0
- AW0
- OY0
- <sos/eos>
@@ -0,0 +1,61 @@
#!/bin/bash

stage=0
stop_stage=100

config_path=$1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # get durations from MFA's result
    echo "Generate durations.txt from MFA results ..."
    python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
        --inputdir=./vctk_alignment \
        --output durations.txt \
        --config=${config_path}
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # extract features
    echo "Extract features ..."
    python3 ${BIN_DIR}/preprocess.py \
        --dataset=vctk \
        --rootdir=~/datasets/VCTK-Corpus-0.92/ \
        --dumpdir=dump \
        --dur-file=durations.txt \
        --config=${config_path} \
        --num-cpu=20 \
        --cut-sil=True
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # get features' stats (mean and std)
    echo "Get features' stats ..."
    python3 ${MAIN_ROOT}/utils/compute_statistics.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --field-name="speech"
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    # normalize and convert phone/speaker to id; dev and test should use train's stats
    echo "Normalize ..."
    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/train/raw/metadata.jsonl \
        --dumpdir=dump/train/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/dev/raw/metadata.jsonl \
        --dumpdir=dump/dev/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt

    python3 ${BIN_DIR}/normalize.py \
        --metadata=dump/test/raw/metadata.jsonl \
        --dumpdir=dump/test/norm \
        --speech-stats=dump/train/speech_stats.npy \
        --phones-dict=dump/phone_id_map.txt \
        --speaker-dict=dump/speaker_id_map.txt
fi
@@ -0,0 +1 @@
#!/bin/bash
@@ -0,0 +1,12 @@
#!/bin/bash

config_path=$1
train_output_path=$2

python3 ${BIN_DIR}/train.py \
    --train-metadata=dump/train/norm/metadata.jsonl \
    --dev-metadata=dump/dev/norm/metadata.jsonl \
    --config=${config_path} \
    --output-dir=${train_output_path} \
    --ngpu=1 \
    --phones-dict=dump/phone_id_map.txt
@@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`

export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C

export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}

MODEL=ernie_sat
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
@@ -0,0 +1,32 @@
#!/bin/bash

set -e
source path.sh

gpus=0,1
stage=0
stop_stage=100

conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_153.pdz

# With the following options you can choose the range of stages you want to run,
# e.g. `./run.sh --stage 0 --stop-stage 0`.
# This cannot be mixed with the use of `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    # prepare data
    ./local/preprocess.sh ${conf_path} || exit -1
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    # train model, all `ckpt` files are saved under the `train_output_path/checkpoints/` dir
    CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    # synthesize, vocoder is pwgan
    CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
@ -0,0 +1,130 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Normalize feature files and dump them."""
|
||||
import argparse
|
||||
import logging
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from tqdm import tqdm
|
||||
|
||||
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||
|
||||
|
||||
def main():
|
||||
"""Run preprocessing process."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
required=True,
|
||||
help="directory including feature files to be normalized. "
|
||||
"you need to specify either *-scp or rootdir.")
|
||||
|
||||
parser.add_argument(
|
||||
"--dumpdir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="directory to dump normalized feature files.")
|
||||
parser.add_argument(
|
||||
"--speech-stats",
|
||||
type=str,
|
||||
required=True,
|
||||
help="speech statistics file.")
|
||||
parser.add_argument(
|
||||
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
|
||||
parser.add_argument(
|
||||
"--speaker-dict", type=str, default=None, help="speaker id map file.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dumpdir = Path(args.dumpdir).expanduser()
|
||||
# use absolute path
|
||||
dumpdir = dumpdir.resolve()
|
||||
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# get dataset
|
||||
with jsonlines.open(args.metadata, 'r') as reader:
|
||||
metadata = list(reader)
|
||||
dataset = DataTable(
|
||||
metadata, converters={
|
||||
"speech": np.load,
|
||||
})
|
||||
logging.info(f"The number of files = {len(dataset)}.")
|
||||
|
||||
# restore scaler
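# the stats file stores the per-dimension mean in row 0 and the scale (std) in row 1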
|
||||
speech_scaler = StandardScaler()
|
||||
speech_scaler.mean_ = np.load(args.speech_stats)[0]
|
||||
speech_scaler.scale_ = np.load(args.speech_stats)[1]
|
||||
speech_scaler.n_features_in_ = speech_scaler.mean_.shape[0]
|
||||
|
||||
vocab_phones = {}
|
||||
with open(args.phones_dict, 'rt') as f:
|
||||
phn_id = [line.strip().split() for line in f.readlines()]
|
||||
for phn, id in phn_id:
|
||||
vocab_phones[phn] = int(id)
|
||||
|
||||
vocab_speaker = {}
|
||||
with open(args.speaker_dict, 'rt') as f:
|
||||
spk_id = [line.strip().split() for line in f.readlines()]
|
||||
for spk, id in spk_id:
|
||||
vocab_speaker[spk] = int(id)
|
||||
|
||||
# process each file
|
||||
output_metadata = []
|
||||
|
||||
for item in tqdm(dataset):
|
||||
utt_id = item['utt_id']
|
||||
speech = item['speech']
|
||||
|
||||
# normalize
|
||||
speech = speech_scaler.transform(speech)
|
||||
speech_dir = dumpdir / "data_speech"
|
||||
speech_dir.mkdir(parents=True, exist_ok=True)
|
||||
speech_path = speech_dir / f"{utt_id}_speech.npy"
|
||||
np.save(speech_path, speech.astype(np.float32), allow_pickle=False)
|
||||
|
||||
phone_ids = [vocab_phones[p] for p in item['phones']]
|
||||
spk_id = vocab_speaker[item["speaker"]]
|
||||
record = {
|
||||
"utt_id": item['utt_id'],
|
||||
"spk_id": spk_id,
|
||||
"text": phone_ids,
|
||||
"text_lengths": item['text_lengths'],
|
||||
"speech_lengths": item['speech_lengths'],
|
||||
"durations": item['durations'],
|
||||
"speech": str(speech_path),
|
||||
"align_start": item['align_start'],
|
||||
"align_end": item['align_end'],
|
||||
}
|
||||
# add spk_emb for voice cloning
|
||||
if "spk_emb" in item:
|
||||
record["spk_emb"] = str(item["spk_emb"])
|
||||
|
||||
output_metadata.append(record)
|
||||
output_metadata.sort(key=itemgetter('utt_id'))
|
||||
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
|
||||
with jsonlines.open(output_metadata_path, 'w') as writer:
|
||||
for item in output_metadata:
|
||||
writer.write(item)
|
||||
logging.info(f"metadata dumped into {output_metadata_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,341 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
|
||||
import jsonlines
|
||||
import librosa
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import yaml
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.t2s.datasets.get_feats import LogMelFBank
|
||||
from paddlespeech.t2s.datasets.preprocess_utils import compare_duration_and_mel_length
|
||||
from paddlespeech.t2s.datasets.preprocess_utils import get_input_token
|
||||
from paddlespeech.t2s.datasets.preprocess_utils import get_phn_dur
|
||||
from paddlespeech.t2s.datasets.preprocess_utils import get_spk_id_map
|
||||
from paddlespeech.t2s.datasets.preprocess_utils import merge_silence
|
||||
from paddlespeech.t2s.utils import str2bool
|
||||
|
||||
|
||||
def process_sentence(config: Dict[str, Any],
|
||||
fp: Path,
|
||||
sentences: Dict,
|
||||
output_dir: Path,
|
||||
mel_extractor=None,
|
||||
cut_sil: bool=True,
|
||||
spk_emb_dir: Path=None):
|
||||
utt_id = fp.stem
|
||||
# for vctk
|
||||
if utt_id.endswith("_mic2"):
|
||||
utt_id = utt_id[:-5]
|
||||
record = None
|
||||
if utt_id in sentences:
|
||||
# reading, resampling may occur
|
||||
wav, _ = librosa.load(str(fp), sr=config.fs)
|
||||
if len(wav.shape) != 1:
|
||||
return record
|
||||
max_value = np.abs(wav).max()
|
||||
if max_value > 1.0:
|
||||
wav = wav / max_value
|
||||
assert len(wav.shape) == 1, f"{utt_id} is not a mono-channel audio."
assert np.abs(wav).max(
) <= 1.0, f"{utt_id} seems to be different from 16 bit PCM."
|
||||
phones = sentences[utt_id][0]
|
||||
durations = sentences[utt_id][1]
|
||||
speaker = sentences[utt_id][2]
|
||||
d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
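# d_cumsum[i] is the start frame of the i-th phone; durations are counted in mel frames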
|
||||
|
||||
# slightly less precise than using the *.TextGrid alignments directly
|
||||
times = librosa.frames_to_time(
|
||||
d_cumsum, sr=config.fs, hop_length=config.n_shift)
|
||||
if cut_sil:
|
||||
start = 0
|
||||
end = d_cumsum[-1]
|
||||
if phones[0] == "sil" and len(durations) > 1:
|
||||
start = times[1]
|
||||
durations = durations[1:]
|
||||
phones = phones[1:]
|
||||
if phones[-1] == 'sil' and len(durations) > 1:
|
||||
end = times[-2]
|
||||
durations = durations[:-1]
|
||||
phones = phones[:-1]
|
||||
sentences[utt_id][0] = phones
|
||||
sentences[utt_id][1] = durations
|
||||
start, end = librosa.time_to_samples([start, end], sr=config.fs)
|
||||
wav = wav[start:end]
|
||||
|
||||
# extract mel feats
|
||||
logmel = mel_extractor.get_log_mel_fbank(wav)
|
||||
# change duration according to mel_length
|
||||
compare_duration_and_mel_length(sentences, utt_id, logmel)
|
||||
# utt_id may be popped in compare_duration_and_mel_length
|
||||
if utt_id not in sentences:
|
||||
return None
|
||||
phones = sentences[utt_id][0]
|
||||
durations = sentences[utt_id][1]
|
||||
num_frames = logmel.shape[0]
|
||||
assert sum(durations) == num_frames
|
||||
|
||||
new_d_cumsum = np.pad(np.array(durations).cumsum(0), (1, 0), 'constant')
|
||||
align_start = new_d_cumsum[:-1]
|
||||
align_end = new_d_cumsum[1:]
|
||||
assert len(align_start) == len(align_end) == len(durations)
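# align_start[i] / align_end[i] give the frame span of the i-th phone in the extracted mel spectrogram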
|
||||
|
||||
mel_dir = output_dir / "data_speech"
|
||||
mel_dir.mkdir(parents=True, exist_ok=True)
|
||||
mel_path = mel_dir / (utt_id + "_speech.npy")
|
||||
np.save(mel_path, logmel)
|
||||
# align_start_lengths == text_lengths
|
||||
record = {
|
||||
"utt_id": utt_id,
|
||||
"phones": phones,
|
||||
"text_lengths": len(phones),
|
||||
"speech_lengths": num_frames,
|
||||
"durations": durations,
|
||||
"speech": str(mel_path),
|
||||
"speaker": speaker,
|
||||
"align_start": align_start.tolist(),
|
||||
"align_end": align_end.tolist(),
|
||||
}
|
||||
if spk_emb_dir:
|
||||
if speaker in os.listdir(spk_emb_dir):
|
||||
embed_name = utt_id + ".npy"
|
||||
embed_path = spk_emb_dir / speaker / embed_name
|
||||
if embed_path.is_file():
|
||||
record["spk_emb"] = str(embed_path)
|
||||
else:
|
||||
return None
|
||||
return record
|
||||
|
||||
|
||||
def process_sentences(config,
|
||||
fps: List[Path],
|
||||
sentences: Dict,
|
||||
output_dir: Path,
|
||||
mel_extractor=None,
|
||||
nprocs: int=1,
|
||||
cut_sil: bool=True,
|
||||
spk_emb_dir: Path=None):
|
||||
if nprocs == 1:
|
||||
results = []
|
||||
for fp in tqdm.tqdm(fps, total=len(fps)):
|
||||
record = process_sentence(
|
||||
config=config,
|
||||
fp=fp,
|
||||
sentences=sentences,
|
||||
output_dir=output_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
cut_sil=cut_sil,
|
||||
spk_emb_dir=spk_emb_dir)
|
||||
if record:
|
||||
results.append(record)
|
||||
else:
|
||||
with ThreadPoolExecutor(nprocs) as pool:
|
||||
futures = []
|
||||
with tqdm.tqdm(total=len(fps)) as progress:
|
||||
for fp in fps:
|
||||
future = pool.submit(process_sentence, config, fp,
|
||||
sentences, output_dir, mel_extractor,
|
||||
cut_sil, spk_emb_dir)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append(future)
|
||||
|
||||
results = []
|
||||
for ft in futures:
|
||||
record = ft.result()
|
||||
if record:
|
||||
results.append(record)
|
||||
|
||||
results.sort(key=itemgetter("utt_id"))
|
||||
with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
|
||||
for item in results:
|
||||
writer.write(item)
|
||||
print("Done")
|
||||
|
||||
|
||||
def main():
|
||||
# parse config and args
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocess audio and then extract features.")
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default="baker",
|
||||
type=str,
|
||||
help="name of dataset, should in {baker, aishell3, ljspeech, vctk} now")
|
||||
|
||||
parser.add_argument(
|
||||
"--rootdir", default=None, type=str, help="directory to dataset.")
|
||||
|
||||
parser.add_argument(
|
||||
"--dumpdir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="directory to dump feature files.")
|
||||
parser.add_argument(
|
||||
"--dur-file", default=None, type=str, help="path to durations.txt.")
|
||||
|
||||
parser.add_argument("--config", type=str, help="fastspeech2 config file.")
|
||||
|
||||
parser.add_argument(
|
||||
"--num-cpu", type=int, default=1, help="number of process.")
|
||||
|
||||
parser.add_argument(
|
||||
"--cut-sil",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="whether cut sil in the edge of audio")
|
||||
|
||||
parser.add_argument(
|
||||
"--spk_emb_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
help="directory to speaker embedding files.")
|
||||
args = parser.parse_args()
|
||||
|
||||
rootdir = Path(args.rootdir).expanduser()
|
||||
dumpdir = Path(args.dumpdir).expanduser()
|
||||
# use absolute path
|
||||
dumpdir = dumpdir.resolve()
|
||||
dumpdir.mkdir(parents=True, exist_ok=True)
|
||||
dur_file = Path(args.dur_file).expanduser()
|
||||
|
||||
if args.spk_emb_dir:
|
||||
spk_emb_dir = Path(args.spk_emb_dir).expanduser().resolve()
|
||||
else:
|
||||
spk_emb_dir = None
|
||||
|
||||
assert rootdir.is_dir()
|
||||
assert dur_file.is_file()
|
||||
|
||||
with open(args.config, 'rt') as f:
|
||||
config = CfgNode(yaml.safe_load(f))
|
||||
|
||||
sentences, speaker_set = get_phn_dur(dur_file)
|
||||
|
||||
merge_silence(sentences)
|
||||
phone_id_map_path = dumpdir / "phone_id_map.txt"
|
||||
speaker_id_map_path = dumpdir / "speaker_id_map.txt"
|
||||
get_input_token(sentences, phone_id_map_path, args.dataset)
|
||||
get_spk_id_map(speaker_set, speaker_id_map_path)
|
||||
|
||||
if args.dataset == "baker":
|
||||
wav_files = sorted(list((rootdir / "Wave").rglob("*.wav")))
|
||||
# split data into 3 sections
|
||||
num_train = 9800
|
||||
num_dev = 100
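# 9800 utterances for train, 100 for dev, the rest for test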
|
||||
train_wav_files = wav_files[:num_train]
|
||||
dev_wav_files = wav_files[num_train:num_train + num_dev]
|
||||
test_wav_files = wav_files[num_train + num_dev:]
|
||||
elif args.dataset == "aishell3":
|
||||
sub_num_dev = 5
|
||||
wav_dir = rootdir / "train" / "wav"
|
||||
train_wav_files = []
|
||||
dev_wav_files = []
|
||||
test_wav_files = []
|
||||
for speaker in os.listdir(wav_dir):
|
||||
wav_files = sorted(list((wav_dir / speaker).rglob("*.wav")))
|
||||
if len(wav_files) > 100:
|
||||
train_wav_files += wav_files[:-sub_num_dev * 2]
|
||||
dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
|
||||
test_wav_files += wav_files[-sub_num_dev:]
|
||||
else:
|
||||
train_wav_files += wav_files
|
||||
|
||||
elif args.dataset == "ljspeech":
|
||||
wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
|
||||
# split data into 3 sections
|
||||
num_train = 12900
|
||||
num_dev = 100
|
||||
train_wav_files = wav_files[:num_train]
|
||||
dev_wav_files = wav_files[num_train:num_train + num_dev]
|
||||
test_wav_files = wav_files[num_train + num_dev:]
|
||||
elif args.dataset == "vctk":
|
||||
sub_num_dev = 5
|
||||
wav_dir = rootdir / "wav48_silence_trimmed"
|
||||
train_wav_files = []
|
||||
dev_wav_files = []
|
||||
test_wav_files = []
|
||||
for speaker in os.listdir(wav_dir):
|
||||
wav_files = sorted(list((wav_dir / speaker).rglob("*_mic2.flac")))
|
||||
if len(wav_files) > 100:
|
||||
train_wav_files += wav_files[:-sub_num_dev * 2]
|
||||
dev_wav_files += wav_files[-sub_num_dev * 2:-sub_num_dev]
|
||||
test_wav_files += wav_files[-sub_num_dev:]
|
||||
else:
|
||||
train_wav_files += wav_files
|
||||
|
||||
else:
|
||||
print("dataset should in {baker, aishell3, ljspeech, vctk} now!")
|
||||
|
||||
train_dump_dir = dumpdir / "train" / "raw"
|
||||
train_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
dev_dump_dir = dumpdir / "dev" / "raw"
|
||||
dev_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
test_dump_dir = dumpdir / "test" / "raw"
|
||||
test_dump_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Extractor
|
||||
mel_extractor = LogMelFBank(
|
||||
sr=config.fs,
|
||||
n_fft=config.n_fft,
|
||||
hop_length=config.n_shift,
|
||||
win_length=config.win_length,
|
||||
window=config.window,
|
||||
n_mels=config.n_mels,
|
||||
fmin=config.fmin,
|
||||
fmax=config.fmax)
|
||||
|
||||
# process for the 3 sections
|
||||
if train_wav_files:
|
||||
process_sentences(
|
||||
config=config,
|
||||
fps=train_wav_files,
|
||||
sentences=sentences,
|
||||
output_dir=train_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu,
|
||||
cut_sil=args.cut_sil,
|
||||
spk_emb_dir=spk_emb_dir)
|
||||
if dev_wav_files:
|
||||
process_sentences(
|
||||
config=config,
|
||||
fps=dev_wav_files,
|
||||
sentences=sentences,
|
||||
output_dir=dev_dump_dir,
|
||||
mel_extractor=mel_extractor,
nprocs=args.num_cpu,
|
||||
cut_sil=args.cut_sil,
|
||||
spk_emb_dir=spk_emb_dir)
|
||||
if test_wav_files:
|
||||
process_sentences(
|
||||
config=config,
|
||||
fps=test_wav_files,
|
||||
sentences=sentences,
|
||||
output_dir=test_dump_dir,
|
||||
mel_extractor=mel_extractor,
|
||||
nprocs=args.num_cpu,
|
||||
cut_sil=args.cut_sil,
|
||||
spk_emb_dir=spk_emb_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,194 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import paddle
|
||||
import yaml
|
||||
from paddle import DataParallel
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from yacs.config import CfgNode
|
||||
|
||||
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
|
||||
from paddlespeech.t2s.datasets.data_table import DataTable
|
||||
from paddlespeech.t2s.models.ernie_sat import ErnieSAT
|
||||
from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
|
||||
from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
|
||||
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
|
||||
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
|
||||
from paddlespeech.t2s.training.optimizer import build_optimizers
|
||||
from paddlespeech.t2s.training.seeding import seed_everything
|
||||
from paddlespeech.t2s.training.trainer import Trainer
|
||||
|
||||
|
||||
def train_sp(args, config):
|
||||
# decides device type and whether to run in parallel
|
||||
# setup running environment correctly
|
||||
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
|
||||
paddle.set_device("cpu")
|
||||
else:
|
||||
paddle.set_device("gpu")
|
||||
world_size = paddle.distributed.get_world_size()
|
||||
if world_size > 1:
|
||||
paddle.distributed.init_parallel_env()
|
||||
|
||||
# set the random seed, it is a must for multiprocess training
|
||||
seed_everything(config.seed)
|
||||
|
||||
print(
|
||||
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||
)
|
||||
fields = [
|
||||
"text", "text_lengths", "speech", "speech_lengths", "align_start",
|
||||
"align_end"
|
||||
]
|
||||
converters = {"speech": np.load}
|
||||
spk_num = None
|
||||
|
||||
# the DataLoader logger is too verbose
|
||||
logging.getLogger("DataLoader").disabled = True
|
||||
|
||||
# construct dataset for training and validation
|
||||
with jsonlines.open(args.train_metadata, 'r') as reader:
|
||||
train_metadata = list(reader)
|
||||
train_dataset = DataTable(
|
||||
data=train_metadata,
|
||||
fields=fields,
|
||||
converters=converters, )
|
||||
with jsonlines.open(args.dev_metadata, 'r') as reader:
|
||||
dev_metadata = list(reader)
|
||||
dev_dataset = DataTable(
|
||||
data=dev_metadata,
|
||||
fields=fields,
|
||||
converters=converters, )
|
||||
|
||||
# collate function and dataloader
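# segment embeddings are only added when the encoder input layer is 'sega_mlm',
# and text masking follows the model config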
|
||||
collate_fn = build_erniesat_collate_fn(
|
||||
mlm_prob=config.mlm_prob,
|
||||
mean_phn_span=config.mean_phn_span,
|
||||
seg_emb=config.model['enc_input_layer'] == 'sega_mlm',
|
||||
text_masking=config["model"]["text_masking"],
|
||||
epoch=config["max_epoch"])
|
||||
|
||||
train_sampler = DistributedBatchSampler(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
|
||||
print("samplers done!")
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=config.num_workers)
|
||||
|
||||
dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
shuffle=False,
|
||||
drop_last=False,
|
||||
batch_size=config.batch_size,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=config.num_workers)
|
||||
print("dataloaders done!")
|
||||
|
||||
with open(args.phones_dict, "r") as f:
|
||||
phn_id = [line.strip().split() for line in f.readlines()]
|
||||
vocab_size = len(phn_id)
|
||||
print("vocab_size:", vocab_size)
|
||||
|
||||
odim = config.n_mels
|
||||
model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
|
||||
|
||||
if world_size > 1:
|
||||
model = DataParallel(model)
|
||||
print("model done!")
|
||||
|
||||
optimizer = build_optimizers(model, **config["optimizer"])
|
||||
print("optimizer done!")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
if dist.get_rank() == 0:
|
||||
config_name = args.config.split("/")[-1]
|
||||
# copy conf to output_dir
|
||||
shutil.copyfile(args.config, output_dir / config_name)
|
||||
|
||||
updater = ErnieSATUpdater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
dataloader=train_dataloader,
|
||||
text_masking=config["model"]["text_masking"],
|
||||
odim=odim,
|
||||
output_dir=output_dir)
|
||||
|
||||
trainer = Trainer(updater, (config.max_epoch, 'epoch'), output_dir)
|
||||
|
||||
evaluator = ErnieSATEvaluator(
|
||||
model=model,
|
||||
dataloader=dev_dataloader,
|
||||
text_masking=config["model"]["text_masking"],
|
||||
odim=odim,
|
||||
output_dir=output_dir, )
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
trainer.extend(evaluator, trigger=(1, "epoch"))
|
||||
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
|
||||
trainer.extend(
|
||||
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
|
||||
trainer.run()
|
||||
|
||||
|
||||
def main():
|
||||
# parse args and config and redirect to train_sp
|
||||
parser = argparse.ArgumentParser(description="Train an ErnieSAT model.")
|
||||
parser.add_argument("--config", type=str, help="ErnieSAT config file.")
|
||||
parser.add_argument("--train-metadata", type=str, help="training data.")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data.")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||
parser.add_argument(
|
||||
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
|
||||
parser.add_argument(
|
||||
"--phones-dict", type=str, default=None, help="phone vocabulary file.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.config) as f:
|
||||
config = CfgNode(yaml.safe_load(f))
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
print(config)
|
||||
print(
|
||||
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||
)
|
||||
|
||||
# dispatch
|
||||
if args.ngpu > 1:
|
||||
dist.spawn(train_sp, (args, config), nprocs=args.ngpu)
|
||||
else:
|
||||
train_sp(args, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,670 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
from paddlespeech.t2s.modules.activation import get_activation
|
||||
from paddlespeech.t2s.modules.conformer.convolution import ConvolutionModule
|
||||
from paddlespeech.t2s.modules.conformer.encoder_layer import EncoderLayer
|
||||
from paddlespeech.t2s.modules.layer_norm import LayerNorm
|
||||
from paddlespeech.t2s.modules.masked_fill import masked_fill
|
||||
from paddlespeech.t2s.modules.nets_utils import initialize
|
||||
from paddlespeech.t2s.modules.tacotron2.decoder import Postnet
|
||||
from paddlespeech.t2s.modules.transformer.attention import LegacyRelPositionMultiHeadedAttention
|
||||
from paddlespeech.t2s.modules.transformer.attention import MultiHeadedAttention
|
||||
from paddlespeech.t2s.modules.transformer.attention import RelPositionMultiHeadedAttention
|
||||
from paddlespeech.t2s.modules.transformer.embedding import LegacyRelPositionalEncoding
|
||||
from paddlespeech.t2s.modules.transformer.embedding import PositionalEncoding
|
||||
from paddlespeech.t2s.modules.transformer.embedding import RelPositionalEncoding
|
||||
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
|
||||
from paddlespeech.t2s.modules.transformer.multi_layer_conv import Conv1dLinear
|
||||
from paddlespeech.t2s.modules.transformer.multi_layer_conv import MultiLayeredConv1d
|
||||
from paddlespeech.t2s.modules.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from paddlespeech.t2s.modules.transformer.repeat import repeat
|
||||
from paddlespeech.t2s.modules.transformer.subsampling import Conv2dSubsampling
|
||||
|
||||
|
||||
# MLM -> Masked Language Model
|
||||
class mySequential(nn.Sequential):
|
||||
def forward(self, *inputs):
|
||||
for module in self._sub_layers.values():
|
||||
if type(inputs) == tuple:
|
||||
inputs = module(*inputs)
|
||||
else:
|
||||
inputs = module(inputs)
|
||||
return inputs
|
||||
|
||||
|
||||
class MaskInputLayer(nn.Layer):
|
||||
def __init__(self, out_features: int) -> None:
|
||||
super().__init__()
|
||||
self.mask_feature = paddle.create_parameter(
|
||||
shape=(1, 1, out_features),
|
||||
dtype=paddle.float32,
|
||||
default_initializer=paddle.nn.initializer.Assign(
|
||||
paddle.normal(shape=(1, 1, out_features))))
|
||||
|
||||
def forward(self, input: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor=None) -> paddle.Tensor:
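# replace the features at masked positions with a learned mask embedding and keep the rest unchanged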
|
||||
masked_pos = paddle.expand_as(paddle.unsqueeze(masked_pos, -1), input)
|
||||
masked_input = masked_fill(input, masked_pos, 0) + masked_fill(
|
||||
paddle.expand_as(self.mask_feature, input), ~masked_pos, 0)
|
||||
return masked_input
|
||||
|
||||
|
||||
class MLMEncoder(nn.Layer):
|
||||
"""Conformer encoder module.
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
attention_dim (int): Dimension of attention.
|
||||
attention_heads (int): The number of heads of multi head attention.
|
||||
linear_units (int): The number of units of position-wise feed forward.
|
||||
num_blocks (int): The number of decoder blocks.
|
||||
dropout_rate (float): Dropout rate.
|
||||
positional_dropout_rate (float): Dropout rate after adding positional encoding.
|
||||
attention_dropout_rate (float): Dropout rate in attention.
|
||||
input_layer (Union[str, paddle.nn.Layer]): Input layer type.
|
||||
normalize_before (bool): Whether to use layer_norm before the first block.
|
||||
concat_after (bool): Whether to concat attention layer's input and output.
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
|
||||
macaron_style (bool): Whether to use macaron style for positionwise layer.
|
||||
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
selfattention_layer_type (str): Encoder attention layer type.
|
||||
activation_type (str): Encoder activation function type.
|
||||
use_cnn_module (bool): Whether to use convolution module.
|
||||
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
padding_idx (int): Padding idx for input_layer=embed.
|
||||
stochastic_depth_rate (float): Maximum probability to skip the encoder layer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
idim: int,
|
||||
vocab_size: int=0,
|
||||
pre_speech_layer: int=0,
|
||||
attention_dim: int=256,
|
||||
attention_heads: int=4,
|
||||
linear_units: int=2048,
|
||||
num_blocks: int=6,
|
||||
dropout_rate: float=0.1,
|
||||
positional_dropout_rate: float=0.1,
|
||||
attention_dropout_rate: float=0.0,
|
||||
input_layer: str="conv2d",
|
||||
normalize_before: bool=True,
|
||||
concat_after: bool=False,
|
||||
positionwise_layer_type: str="linear",
|
||||
positionwise_conv_kernel_size: int=1,
|
||||
macaron_style: bool=False,
|
||||
pos_enc_layer_type: str="abs_pos",
|
||||
pos_enc_class=None,
|
||||
selfattention_layer_type: str="selfattn",
|
||||
activation_type: str="swish",
|
||||
use_cnn_module: bool=False,
|
||||
zero_triu: bool=False,
|
||||
cnn_module_kernel: int=31,
|
||||
padding_idx: int=-1,
|
||||
stochastic_depth_rate: float=0.0,
|
||||
text_masking: bool=False):
|
||||
"""Construct an Encoder object."""
|
||||
super().__init__()
|
||||
self._output_size = attention_dim
|
||||
self.text_masking = text_masking
|
||||
if self.text_masking:
|
||||
self.text_masking_layer = MaskInputLayer(attention_dim)
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
elif pos_enc_layer_type == "legacy_rel_pos":
|
||||
pos_enc_class = LegacyRelPositionalEncoding
|
||||
assert selfattention_layer_type == "legacy_rel_selfattn"
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
self.conv_subsampling_factor = 1
|
||||
if input_layer == "linear":
|
||||
self.embed = nn.Sequential(
|
||||
nn.Linear(idim, attention_dim),
|
||||
nn.LayerNorm(attention_dim),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate), )
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate), )
|
||||
self.conv_subsampling_factor = 4
|
||||
elif input_layer == "embed":
|
||||
self.embed = nn.Sequential(
|
||||
nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate), )
|
||||
elif input_layer == "mlm":
|
||||
self.segment_emb = None
|
||||
self.speech_embed = mySequential(
|
||||
MaskInputLayer(idim),
|
||||
nn.Linear(idim, attention_dim),
|
||||
nn.LayerNorm(attention_dim),
|
||||
nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate))
|
||||
self.text_embed = nn.Sequential(
|
||||
nn.Embedding(
|
||||
vocab_size, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate), )
|
||||
elif input_layer == "sega_mlm":
|
||||
self.segment_emb = nn.Embedding(
|
||||
500, attention_dim, padding_idx=padding_idx)
|
||||
self.speech_embed = mySequential(
|
||||
MaskInputLayer(idim),
|
||||
nn.Linear(idim, attention_dim),
|
||||
nn.LayerNorm(attention_dim),
|
||||
nn.ReLU(),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate))
|
||||
self.text_embed = nn.Sequential(
|
||||
nn.Embedding(
|
||||
vocab_size, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate), )
|
||||
elif isinstance(input_layer, nn.Layer):
|
||||
self.embed = nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate), )
|
||||
elif input_layer is None:
|
||||
self.embed = nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate))
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
# self-attention module definition
|
||||
if selfattention_layer_type == "selfattn":
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (attention_heads, attention_dim,
|
||||
attention_dropout_rate, )
|
||||
elif selfattention_layer_type == "legacy_rel_selfattn":
|
||||
assert pos_enc_layer_type == "legacy_rel_pos"
|
||||
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (attention_heads, attention_dim,
|
||||
attention_dropout_rate, )
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (attention_heads, attention_dim,
|
||||
attention_dropout_rate, zero_triu, )
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " +
|
||||
selfattention_layer_type)
|
||||
|
||||
# feed-forward module definition
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (attention_dim, linear_units,
|
||||
dropout_rate, activation, )
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (attention_dim, linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate, )
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (attention_dim, linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate, )
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
# convolution module definition
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
stochastic_depth_rate * float(1 + lnum) / num_blocks, ), )
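# in the blocks above, the stochastic-depth skip probability grows linearly with the block index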
|
||||
self.pre_speech_layer = pre_speech_layer
|
||||
self.pre_speech_encoders = repeat(
|
||||
self.pre_speech_layer,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
stochastic_depth_rate * float(1 + lnum) / self.pre_speech_layer, ),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def forward(self,
|
||||
speech: paddle.Tensor,
|
||||
text: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor,
|
||||
speech_mask: paddle.Tensor=None,
|
||||
text_mask: paddle.Tensor=None,
|
||||
speech_seg_pos: paddle.Tensor=None,
|
||||
text_seg_pos: paddle.Tensor=None):
|
||||
"""Encode input sequence.
|
||||
|
||||
"""
|
||||
if masked_pos is not None:
|
||||
speech = self.speech_embed(speech, masked_pos)
|
||||
else:
|
||||
speech = self.speech_embed(speech)
|
||||
if text is not None:
|
||||
text = self.text_embed(text)
|
||||
if speech_seg_pos is not None and text_seg_pos is not None and self.segment_emb:
|
||||
speech_seg_emb = self.segment_emb(speech_seg_pos)
|
||||
text_seg_emb = self.segment_emb(text_seg_pos)
|
||||
text = (text[0] + text_seg_emb, text[1])
|
||||
speech = (speech[0] + speech_seg_emb, speech[1])
|
||||
if self.pre_speech_encoders:
|
||||
speech, _ = self.pre_speech_encoders(speech, speech_mask)
|
||||
|
||||
if text is not None:
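# speech and text embeddings are concatenated along the time axis and encoded jointly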
|
||||
xs = paddle.concat([speech[0], text[0]], axis=1)
|
||||
xs_pos_emb = paddle.concat([speech[1], text[1]], axis=1)
|
||||
masks = paddle.concat([speech_mask, text_mask], axis=-1)
|
||||
else:
|
||||
xs = speech[0]
|
||||
xs_pos_emb = speech[1]
|
||||
masks = speech_mask
|
||||
|
||||
xs, masks = self.encoders((xs, xs_pos_emb), masks)
|
||||
|
||||
if isinstance(xs, tuple):
|
||||
xs = xs[0]
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
|
||||
return xs, masks
|
||||
|
||||
|
||||
class MLMDecoder(MLMEncoder):
|
||||
def forward(self, xs: paddle.Tensor, masks: paddle.Tensor):
|
||||
"""Encode input sequence.
|
||||
|
||||
Args:
|
||||
xs (paddle.Tensor): Input tensor (#batch, time, idim).
|
||||
masks (paddle.Tensor): Mask tensor (#batch, time).
|
||||
|
||||
Returns:
|
||||
paddle.Tensor: Output tensor (#batch, time, attention_dim).
|
||||
paddle.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
xs = self.embed(xs)
|
||||
xs, masks = self.encoders(xs, masks)
|
||||
|
||||
if isinstance(xs, tuple):
|
||||
xs = xs[0]
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
|
||||
return xs, masks
|
||||
|
||||
|
||||
# encoder and decoder are nn.Layer instances, not strings
|
||||
class MLM(nn.Layer):
|
||||
def __init__(self,
|
||||
odim: int,
|
||||
encoder: nn.Layer,
|
||||
decoder: Optional[nn.Layer],
|
||||
postnet_layers: int=0,
|
||||
postnet_chans: int=0,
|
||||
postnet_filts: int=0,
|
||||
text_masking: bool=False):
|
||||
|
||||
super().__init__()
|
||||
self.odim = odim
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.vocab_size = encoder.text_embed[0]._num_embeddings
|
||||
|
||||
if self.decoder is None or not (hasattr(self.decoder,
|
||||
'output_layer') and
|
||||
self.decoder.output_layer is not None):
|
||||
self.sfc = nn.Linear(self.encoder._output_size, odim)
|
||||
else:
|
||||
self.sfc = None
|
||||
if text_masking:
|
||||
self.text_sfc = nn.Linear(
|
||||
self.encoder.text_embed[0]._embedding_dim,
|
||||
self.vocab_size,
|
||||
weight_attr=self.encoder.text_embed[0]._weight_attr)
|
||||
else:
|
||||
self.text_sfc = None
|
||||
|
||||
self.postnet = (None if postnet_layers == 0 else Postnet(
|
||||
idim=self.encoder._output_size,
|
||||
odim=odim,
|
||||
n_layers=postnet_layers,
|
||||
n_chans=postnet_chans,
|
||||
n_filts=postnet_filts,
|
||||
use_batch_norm=True,
|
||||
dropout_rate=0.5, ))
|
||||
|
||||
def inference(
|
||||
self,
|
||||
speech: paddle.Tensor,
|
||||
text: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor,
|
||||
speech_mask: paddle.Tensor,
|
||||
text_mask: paddle.Tensor,
|
||||
speech_seg_pos: paddle.Tensor,
|
||||
text_seg_pos: paddle.Tensor,
|
||||
span_bdy: List[int],
|
||||
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
|
||||
'''
|
||||
Args:
|
||||
speech (paddle.Tensor): input speech (1, Tmax, D).
|
||||
text (paddle.Tensor): input text (1, Tmax2).
|
||||
masked_pos (paddle.Tensor): masked position of input speech (1, Tmax)
|
||||
speech_mask (paddle.Tensor): mask of speech (1, 1, Tmax).
|
||||
text_mask (paddle.Tensor): mask of text (1, 1, Tmax2).
|
||||
speech_seg_pos (paddle.Tensor): n-th phone of each mel, 0<=n<=Tmax2 (1, Tmax).
|
||||
text_seg_pos (paddle.Tensor): n-th phone of each phone, 0<=n<=Tmax2 (1, Tmax2).
|
||||
span_bdy (List[int]): masked mel boundary of input speech (2,)
|
||||
use_teacher_forcing (bool): whether to use teacher forcing
|
||||
Returns:
|
||||
List[Tensor]:
|
||||
e.g.:
|
||||
[Tensor(shape=[1, 181, 80]), Tensor(shape=[80, 80]), Tensor(shape=[1, 67, 80])]
|
||||
'''
|
||||
|
||||
z_cache = None
|
||||
if use_teacher_forcing:
|
||||
before_outs, zs, *_ = self.forward(
|
||||
speech=speech,
|
||||
text=text,
|
||||
masked_pos=masked_pos,
|
||||
speech_mask=speech_mask,
|
||||
text_mask=text_mask,
|
||||
speech_seg_pos=speech_seg_pos,
|
||||
text_seg_pos=text_seg_pos)
|
||||
if zs is None:
|
||||
zs = before_outs
|
||||
|
||||
speech = speech.squeeze(0)
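# splice the generated masked span (zs) back between the untouched left and right context of the original mel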
|
||||
outs = [speech[:span_bdy[0]]]
|
||||
outs += [zs[0][span_bdy[0]:span_bdy[1]]]
|
||||
outs += [speech[span_bdy[1]:]]
|
||||
return outs
|
||||
return None
|
||||
|
||||
|
||||
class MLMEncAsDecoder(MLM):
|
||||
def forward(self,
|
||||
speech: paddle.Tensor,
|
||||
text: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor,
|
||||
speech_mask: paddle.Tensor,
|
||||
text_mask: paddle.Tensor,
|
||||
speech_seg_pos: paddle.Tensor,
|
||||
text_seg_pos: paddle.Tensor):
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
encoder_out, h_masks = self.encoder(
|
||||
speech=speech,
|
||||
text=text,
|
||||
masked_pos=masked_pos,
|
||||
speech_mask=speech_mask,
|
||||
text_mask=text_mask,
|
||||
speech_seg_pos=speech_seg_pos,
|
||||
text_seg_pos=text_seg_pos)
|
||||
if self.decoder is not None:
|
||||
zs, _ = self.decoder(encoder_out, h_masks)
|
||||
else:
|
||||
zs = encoder_out
|
||||
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
|
||||
if self.sfc is not None:
|
||||
before_outs = paddle.reshape(
|
||||
self.sfc(speech_hidden_states),
|
||||
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
|
||||
else:
|
||||
before_outs = speech_hidden_states
|
||||
if self.postnet is not None:
|
||||
after_outs = before_outs + paddle.transpose(
|
||||
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
|
||||
[0, 2, 1])
|
||||
else:
|
||||
after_outs = None
|
||||
return before_outs, after_outs, None
|
||||
|
||||
|
||||
class MLMDualMaksing(MLM):
|
||||
def forward(self,
|
||||
speech: paddle.Tensor,
|
||||
text: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor,
|
||||
speech_mask: paddle.Tensor,
|
||||
text_mask: paddle.Tensor,
|
||||
speech_seg_pos: paddle.Tensor,
|
||||
text_seg_pos: paddle.Tensor):
|
||||
# feats: (Batch, Length, Dim)
|
||||
# -> encoder_out: (Batch, Length2, Dim2)
|
||||
encoder_out, h_masks = self.encoder(
|
||||
speech=speech,
|
||||
text=text,
|
||||
masked_pos=masked_pos,
|
||||
speech_mask=speech_mask,
|
||||
text_mask=text_mask,
|
||||
speech_seg_pos=speech_seg_pos,
|
||||
text_seg_pos=text_seg_pos)
|
||||
if self.decoder is not None:
|
||||
zs, _ = self.decoder(encoder_out, h_masks)
|
||||
else:
|
||||
zs = encoder_out
|
||||
speech_hidden_states = zs[:, :paddle.shape(speech)[1], :]
|
||||
# text_outs stays None when the text output layer is not built
text_outs = None
if self.text_sfc:
text_hidden_states = zs[:, paddle.shape(speech)[1]:, :]
text_outs = paddle.reshape(
self.text_sfc(text_hidden_states),
(paddle.shape(text_hidden_states)[0], -1, self.vocab_size))
|
||||
if self.sfc is not None:
|
||||
before_outs = paddle.reshape(
|
||||
self.sfc(speech_hidden_states),
|
||||
(paddle.shape(speech_hidden_states)[0], -1, self.odim))
|
||||
else:
|
||||
before_outs = speech_hidden_states
|
||||
if self.postnet is not None:
|
||||
after_outs = before_outs + paddle.transpose(
|
||||
self.postnet(paddle.transpose(before_outs, [0, 2, 1])),
|
||||
[0, 2, 1])
|
||||
else:
|
||||
after_outs = None
|
||||
return before_outs, after_outs, text_outs
|
||||
|
||||
|
||||
class ErnieSAT(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
# network structure related
|
||||
idim: int,
|
||||
odim: int,
|
||||
postnet_layers: int=5,
|
||||
postnet_filts: int=5,
|
||||
postnet_chans: int=256,
|
||||
use_scaled_pos_enc: bool=False,
|
||||
encoder_type: str='conformer',
|
||||
decoder_type: str='conformer',
|
||||
enc_input_layer: str='sega_mlm',
|
||||
enc_pre_speech_layer: int=0,
|
||||
enc_cnn_module_kernel: int=7,
|
||||
enc_attention_dim: int=384,
|
||||
enc_attention_heads: int=2,
|
||||
enc_linear_units: int=1536,
|
||||
enc_num_blocks: int=4,
|
||||
enc_dropout_rate: float=0.2,
|
||||
enc_positional_dropout_rate: float=0.2,
|
||||
enc_attention_dropout_rate: float=0.2,
|
||||
enc_normalize_before: bool=True,
|
||||
enc_macaron_style: bool=True,
|
||||
enc_use_cnn_module: bool=True,
|
||||
enc_selfattention_layer_type: str='legacy_rel_selfattn',
|
||||
enc_activation_type: str='swish',
|
||||
enc_pos_enc_layer_type: str='legacy_rel_pos',
|
||||
enc_positionwise_layer_type: str='conv1d',
|
||||
enc_positionwise_conv_kernel_size: int=3,
|
||||
text_masking: bool=False,
|
||||
dec_cnn_module_kernel: int=31,
|
||||
dec_attention_dim: int=384,
|
||||
dec_attention_heads: int=2,
|
||||
dec_linear_units: int=1536,
|
||||
dec_num_blocks: int=4,
|
||||
dec_dropout_rate: float=0.2,
|
||||
dec_positional_dropout_rate: float=0.2,
|
||||
dec_attention_dropout_rate: float=0.2,
|
||||
dec_macaron_style: bool=True,
|
||||
dec_use_cnn_module: bool=True,
|
||||
dec_selfattention_layer_type: str='legacy_rel_selfattn',
|
||||
dec_activation_type: str='swish',
|
||||
dec_pos_enc_layer_type: str='legacy_rel_pos',
|
||||
dec_positionwise_layer_type: str='conv1d',
|
||||
dec_positionwise_conv_kernel_size: int=3,
|
||||
init_type: str="xavier_uniform", ):
|
||||
super().__init__()
|
||||
# store hyperparameters
|
||||
self.odim = odim
|
||||
|
||||
self.use_scaled_pos_enc = use_scaled_pos_enc
|
||||
|
||||
# initialize parameters
|
||||
initialize(self, init_type)
|
||||
|
||||
# Encoder
|
||||
if encoder_type == "conformer":
|
||||
encoder = MLMEncoder(
|
||||
idim=odim,
|
||||
vocab_size=idim,
|
||||
pre_speech_layer=enc_pre_speech_layer,
|
||||
attention_dim=enc_attention_dim,
|
||||
attention_heads=enc_attention_heads,
|
||||
linear_units=enc_linear_units,
|
||||
num_blocks=enc_num_blocks,
|
||||
dropout_rate=enc_dropout_rate,
|
||||
positional_dropout_rate=enc_positional_dropout_rate,
|
||||
attention_dropout_rate=enc_attention_dropout_rate,
|
||||
input_layer=enc_input_layer,
|
||||
normalize_before=enc_normalize_before,
|
||||
positionwise_layer_type=enc_positionwise_layer_type,
|
||||
positionwise_conv_kernel_size=enc_positionwise_conv_kernel_size,
|
||||
macaron_style=enc_macaron_style,
|
||||
pos_enc_layer_type=enc_pos_enc_layer_type,
|
||||
selfattention_layer_type=enc_selfattention_layer_type,
|
||||
activation_type=enc_activation_type,
|
||||
use_cnn_module=enc_use_cnn_module,
|
||||
cnn_module_kernel=enc_cnn_module_kernel,
|
||||
text_masking=text_masking)
|
||||
else:
|
||||
raise ValueError(f"{encoder_type} is not supported.")
|
||||
|
||||
# Decoder
|
||||
if decoder_type != 'no_decoder':
|
||||
decoder = MLMDecoder(
|
||||
idim=0,
|
||||
input_layer=None,
|
||||
cnn_module_kernel=dec_cnn_module_kernel,
|
||||
attention_dim=dec_attention_dim,
|
||||
attention_heads=dec_attention_heads,
|
||||
linear_units=dec_linear_units,
|
||||
num_blocks=dec_num_blocks,
|
||||
dropout_rate=dec_dropout_rate,
|
||||
positional_dropout_rate=dec_positional_dropout_rate,
|
||||
macaron_style=dec_macaron_style,
|
||||
use_cnn_module=dec_use_cnn_module,
|
||||
selfattention_layer_type=dec_selfattention_layer_type,
|
||||
activation_type=dec_activation_type,
|
||||
pos_enc_layer_type=dec_pos_enc_layer_type,
|
||||
positionwise_layer_type=dec_positionwise_layer_type,
|
||||
positionwise_conv_kernel_size=dec_positionwise_conv_kernel_size)
|
||||
|
||||
else:
|
||||
decoder = None
|
||||
|
||||
model_class = MLMDualMaksing if text_masking else MLMEncAsDecoder
|
||||
|
||||
self.model = model_class(
|
||||
odim=odim,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
postnet_layers=postnet_layers,
|
||||
postnet_filts=postnet_filts,
|
||||
postnet_chans=postnet_chans,
|
||||
text_masking=text_masking)
|
||||
|
||||
nn.initializer.set_global_initializer(None)
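# reset the global initializer (assumed to be set by `initialize(self, init_type)` above),
# so layers created afterwards fall back to the framework defaults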
|
||||
|
||||
def forward(self,
|
||||
speech: paddle.Tensor,
|
||||
text: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor,
|
||||
speech_mask: paddle.Tensor,
|
||||
text_mask: paddle.Tensor,
|
||||
speech_seg_pos: paddle.Tensor,
|
||||
text_seg_pos: paddle.Tensor):
|
||||
return self.model(
|
||||
speech=speech,
|
||||
text=text,
|
||||
masked_pos=masked_pos,
|
||||
speech_mask=speech_mask,
|
||||
text_mask=text_mask,
|
||||
speech_seg_pos=speech_seg_pos,
|
||||
text_seg_pos=text_seg_pos)
|
||||
|
||||
def inference(
|
||||
self,
|
||||
speech: paddle.Tensor,
|
||||
text: paddle.Tensor,
|
||||
masked_pos: paddle.Tensor,
|
||||
speech_mask: paddle.Tensor,
|
||||
text_mask: paddle.Tensor,
|
||||
speech_seg_pos: paddle.Tensor,
|
||||
text_seg_pos: paddle.Tensor,
|
||||
span_bdy: List[int],
|
||||
use_teacher_forcing: bool=False, ) -> Dict[str, paddle.Tensor]:
|
||||
return self.model.inference(
|
||||
speech=speech,
|
||||
text=text,
|
||||
masked_pos=masked_pos,
|
||||
speech_mask=speech_mask,
|
||||
text_mask=text_mask,
|
||||
speech_seg_pos=speech_seg_pos,
|
||||
text_seg_pos=text_seg_pos,
|
||||
span_bdy=span_bdy,
|
||||
use_teacher_forcing=use_teacher_forcing)
|
@ -0,0 +1,148 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from paddle import distributed as dist
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
|
||||
from paddlespeech.t2s.modules.losses import MLMLoss
|
||||
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
|
||||
from paddlespeech.t2s.training.reporter import report
|
||||
from paddlespeech.t2s.training.updaters.standard_updater import StandardUpdater
|
||||
logging.basicConfig(
|
||||
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
|
||||
datefmt='[%Y-%m-%d %H:%M:%S]')
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class ErnieSATUpdater(StandardUpdater):
|
||||
def __init__(self,
|
||||
model: Layer,
|
||||
optimizer: Optimizer,
|
||||
dataloader: DataLoader,
|
||||
init_state=None,
|
||||
text_masking: bool=False,
|
||||
odim: int=80,
|
||||
output_dir: Path=None):
|
||||
super().__init__(model, optimizer, dataloader, init_state=init_state)
|
||||
|
||||
self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
def update_core(self, batch):
|
||||
self.msg = "Rank: {}, ".format(dist.get_rank())
|
||||
losses_dict = {}
|
||||
|
||||
before_outs, after_outs, text_outs = self.model(
|
||||
speech=batch["speech"],
|
||||
text=batch["text"],
|
||||
masked_pos=batch["masked_pos"],
|
||||
speech_mask=batch["speech_mask"],
|
||||
text_mask=batch["text_mask"],
|
||||
speech_seg_pos=batch["speech_seg_pos"],
|
||||
text_seg_pos=batch["text_seg_pos"])
|
||||
|
||||
mlm_loss, text_mlm_loss = self.criterion(
|
||||
speech=batch["speech"],
|
||||
before_outs=before_outs,
|
||||
after_outs=after_outs,
|
||||
masked_pos=batch["masked_pos"],
|
||||
text=batch["text"],
|
||||
# maybe None
|
||||
text_outs=text_outs,
|
||||
# maybe None
|
||||
text_masked_pos=batch["text_masked_pos"])
|
||||
|
||||
loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss
|
||||
|
||||
optimizer = self.optimizer
|
||||
optimizer.clear_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
report("train/loss", float(loss))
|
||||
report("train/mlm_loss", float(mlm_loss))
|
||||
if text_mlm_loss is not None:
|
||||
report("train/text_mlm_loss", float(text_mlm_loss))
|
||||
losses_dict["text_mlm_loss"] = float(text_mlm_loss)
|
||||
|
||||
losses_dict["mlm_loss"] = float(mlm_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
|
||||
|
||||
class ErnieSATEvaluator(StandardEvaluator):
|
||||
def __init__(self,
|
||||
model: Layer,
|
||||
dataloader: DataLoader,
|
||||
text_masking: bool=False,
|
||||
odim: int=80,
|
||||
output_dir: Path=None):
|
||||
super().__init__(model, dataloader)
|
||||
|
||||
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
|
||||
self.filehandler = logging.FileHandler(str(log_file))
|
||||
logger.addHandler(self.filehandler)
|
||||
self.logger = logger
|
||||
self.msg = ""
|
||||
|
||||
self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
|
||||
|
||||
def evaluate_core(self, batch):
|
||||
self.msg = "Evaluate: "
|
||||
losses_dict = {}
|
||||
|
||||
before_outs, after_outs, text_outs = self.model(
|
||||
speech=batch["speech"],
|
||||
text=batch["text"],
|
||||
masked_pos=batch["masked_pos"],
|
||||
speech_mask=batch["speech_mask"],
|
||||
text_mask=batch["text_mask"],
|
||||
speech_seg_pos=batch["speech_seg_pos"],
|
||||
text_seg_pos=batch["text_seg_pos"])
|
||||
|
||||
mlm_loss, text_mlm_loss = self.criterion(
|
||||
speech=batch["speech"],
|
||||
before_outs=before_outs,
|
||||
after_outs=after_outs,
|
||||
masked_pos=batch["masked_pos"],
|
||||
text=batch["text"],
|
||||
# maybe None
|
||||
text_outs=text_outs,
|
||||
# maybe None
|
||||
text_masked_pos=batch["text_masked_pos"])
|
||||
loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss
|
||||
|
||||
report("eval/loss", float(loss))
|
||||
report("eval/mlm_loss", float(mlm_loss))
|
||||
if text_mlm_loss is not None:
|
||||
report("eval/text_mlm_loss", float(text_mlm_loss))
|
||||
losses_dict["text_mlm_loss"] = float(text_mlm_loss)
|
||||
|
||||
losses_dict["mlm_loss"] = float(mlm_loss)
|
||||
losses_dict["loss"] = float(loss)
|
||||
|
||||
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
|
||||
for k, v in losses_dict.items())
|
||||
self.logger.info(self.msg)
|