add fastspeech2 voice cloning in aishell3

pull/1003/head
TianYuan 3 years ago
parent b9dc017011
commit 8d025451de

@ -24,7 +24,7 @@ f0max: 400 # Maximum f0 for pitch extraction.
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 4
num_workers: 2
###########################################################

@ -7,7 +7,6 @@ gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_482.pdz

@ -9,7 +9,7 @@ alignment=$3
ge2e_ckpt_path=$4
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/../../ge2e/inference.py \
python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \
--input=${input}/wav \
--output=${preprocess_path}/embed \
--checkpoint_path=${ge2e_ckpt_path}

@ -0,0 +1,89 @@
# FastSpeech2 + AISHELL-3 Voice Cloning
This example contains code used to train a [Tacotron2](https://arxiv.org/abs/1712.05884) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used for the voice cloning task. We refer to the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf). The general steps are as follows (see the sketch after this list):
1. Speaker Encoder: We train a speaker encoder on a speaker verification task. The datasets used here differ from those used for Tacotron2: since transcriptions are not needed, we can use more datasets, refer to [ge2e](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/ge2e).
2. Synthesizer: We then use the trained speaker encoder to generate an utterance embedding for each sentence in AISHELL-3. This embedding is an extra input of Tacotron2 that is concatenated with the encoder outputs.
3. Vocoder: We use WaveFlow as the neural vocoder, refer to [waveflow](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/voc0).
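A minimal sketch of how the three stages compose at inference time; the callables below are placeholders standing in for the trained GE2E, Tacotron2, and WaveFlow models, not actual APIs from this repo:
```python
from typing import Callable

import numpy as np


def clone_voice(speaker_encoder: Callable, synthesizer: Callable,
                vocoder: Callable, ref_wav: np.ndarray, text: str) -> np.ndarray:
    """Compose the three stages of the voice cloning pipeline."""
    spk_embed = speaker_encoder(ref_wav)  # 1. utterance embedding from reference audio
    mel = synthesizer(text, spk_embed)    # 2. mel spectrogram conditioned on the embedding
    return vocoder(mel)                   # 3. waveform synthesized from the mel spectrogram
```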
## Get Started
Assume the path to the dataset is `~/datasets/data_aishell3`.
Assume the path to the MFA result of AISHELL-3 is `./alignment`.
Assume the path to the pretrained ge2e model is `ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000`
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. start a voice cloning inference.
```bash
./run.sh
```
### Preprocess the dataset
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${input} ${preprocess_path} ${alignment} ${ge2e_ckpt_path}
```
#### generate utterance embedding
Use the pretrained GE2E model (speaker encoder) to generate an utterance embedding for each sentence in AISHELL-3. The embeddings have the same file structure as the wav files and are saved in `.npy` format.
```bash
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/../ge2e/inference.py \
--input=${input} \
--output=${preprocess_path}/embed \
--ngpu=1 \
--checkpoint_path=${ge2e_ckpt_path}
fi
```
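For reference, the generated embeddings mirror the wav directory layout; the speaker and utterance IDs below are illustrative:
```text
${preprocess_path}/embed/
├── SSB0005/
│   ├── SSB00050001.npy   # one utterance embedding per sentence
│   └── ...
└── SSB0009/
    └── ...
```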
Computing the utterance embeddings can take x hours.
#### process wav
There is silence at the edges of AISHELL-3's wavs, and the audio amplitude is very small, so we need to remove the silence and normalize the audio. You could remove the silence with a method based on volume or energy, but the effect is not very good. Instead, we use [MFA](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get the alignment between text and speech, then utilize the alignment results to remove the silence.
We use Montreal Forced Aligner 1.0. The labels in AISHELL-3 include pinyin, so the lexicon we provide to MFA is pinyin rather than Chinese characters, and the prosody marks (`$` and `%`) need to be removed. You should preprocess the dataset into the format MFA needs: the texts have the same names as the wavs, with the suffix `.lab`.
We use [lexicon.txt](https://github.com/PaddlePaddle/PaddleSpeech/blob/develop/paddlespeech/t2s/exps/voice_cloning/tacotron2_ge2e/lexicon.txt) as the lexicon.
You can download the alignment results from [alignment_aishell3.tar.gz](https://paddlespeech.bj.bcebos.com/Parakeet/alignment_aishell3.tar.gz), or train your own MFA model by referring to the [use_mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/use_mfa) (which currently uses MFA 1.x) in our repo.
```bash
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Process wav ..."
python3 ${BIN_DIR}/process_wav.py \
--input=${input}/wav \
--output=${preprocess_path}/normalized_wav \
--alignment=${alignment}
fi
```
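Conceptually, the trimming step keeps only the span between the first and last non-silence phones. A minimal sketch under that assumption (the interval format here is hypothetical; the real logic lives in `process_wav.py`):
```python
import soundfile as sf

SILENCE_LABELS = {"", "sp", "sil", "spn"}  # assumed silence marks in the alignment


def trim_by_alignment(wav_path: str, intervals: list, out_path: str):
    """Trim edge silence given (start_sec, end_sec, label) phone intervals."""
    wav, sr = sf.read(wav_path)
    voiced = [(s, e) for s, e, label in intervals if label not in SILENCE_LABELS]
    if not voiced:
        return  # nothing but silence; skip this utterance
    start, end = voiced[0][0], voiced[-1][1]
    sf.write(out_path, wav[int(start * sr):int(end * sr)], sr)
```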
#### preprocess transcription
We convert the transcription into `phones` and `tones`. Note that our processing here differs from that used for MFA: we separate the tones, so, for example, the syllable `ka3` becomes the phone `ka` plus the tone `3`. This is just one possible processing method; you could also segment only initials and vowels.
```bash
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
python3 ${BIN_DIR}/preprocess_transcription.py \
--input=${input} \
--output=${preprocess_path}
fi
```
The default input is `~/datasets/data_aishell3/train`, which contains `label_train-set.txt`. The processed results are `metadata.yaml` and `metadata.pickle`; the former is a text format for easy viewing, and the latter is a binary format for direct reading.
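A quick way to inspect the processed results (assuming `dump/` stands in for `${preprocess_path}`):
```python
import pickle

import yaml

# the text form, for easy viewing
with open("dump/metadata.yaml") as f:
    records = yaml.safe_load(f)

# the binary form, for direct reading
with open("dump/metadata.pickle", "rb") as f:
    records = pickle.load(f)
```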
#### extract mel
```bash
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python3 ${BIN_DIR}/extract_mel.py \
--input=${preprocess_path}/normalized_wav \
--output=${preprocess_path}/mel
fi
```
### Train the model
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${preprocess_path} ${train_output_path}
```
Our model removes the stop token prediction in Tacotron2, because the proportion of positive and negative samples for stop token prediction is extremely unbalanced, and the prediction is very sensitive to how audio silence is clipped. Instead, we use the attention peak reaching the last symbol on the encoder side as the termination condition.
In addition, to accelerate the convergence of the model, we add a `guided attention loss` that pushes the alignment between encoder and decoder toward a diagonal faster.
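For intuition, a numpy sketch of a guided attention penalty in the spirit of Tachibana et al. (2017); the exact weighting used in this repo may differ:
```python
import numpy as np


def guided_attention_weight(text_len: int, mel_len: int, g: float = 0.2) -> np.ndarray:
    """Weight matrix that is ~0 on the diagonal and grows away from it."""
    n = np.arange(text_len) / text_len  # normalized encoder positions
    t = np.arange(mel_len) / mel_len    # normalized decoder positions
    return 1.0 - np.exp(-((t[:, None] - n[None, :]) ** 2) / (2 * g ** 2))


def guided_attention_loss(attn: np.ndarray, g: float = 0.2) -> float:
    """attn: (mel_len, text_len) soft alignment; penalize off-diagonal mass."""
    W = guided_attention_weight(attn.shape[1], attn.shape[0], g)
    return float(np.mean(attn * W))
```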
### Inference
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${ge2e_params_path} ${tacotron2_params_path} ${waveflow_params_path} ${vc_input} ${vc_output}
```
## Pretrained Model
[tacotron2_aishell3_ckpt_0.3.zip](https://paddlespeech.bj.bcebos.com/Parakeet/tacotron2_aishell3_ckpt_0.3.zip).

@ -0,0 +1,105 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # Sampling rate.
n_fft: 2048 # FFT size.
n_shift: 300 # Hop size.
win_length: 1200 # Window length.
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
# Only used for feats_type != raw
fmin: 80 # Minimum frequency of Mel basis.
fmax: 7600 # Maximum frequency of Mel basis.
n_mels: 80 # The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80 # Minimum f0 for pitch extraction.
f0max: 400 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 2
###########################################################
# MODEL SETTING #
###########################################################
model:
adim: 384 # attention dimension
aheads: 2 # number of attention heads
elayers: 4 # number of encoder layers
eunits: 1536 # number of encoder ff units
dlayers: 4 # number of decoder layers
dunits: 1536 # number of decoder ff units
positionwise_layer_type: conv1d # type of position-wise layer
positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer
duration_predictor_layers: 2 # number of layers of duration predictor
duration_predictor_chans: 256 # number of channels of duration predictor
duration_predictor_kernel_size: 3 # filter size of duration predictor
postnet_layers: 5 # number of layers of postnet
postnet_filts: 5 # filter size of conv layers in postnet
postnet_chans: 256 # number of channels of conv layers in postnet
use_masking: True # whether to apply masking for padded part in loss calculation
use_scaled_pos_enc: True # whether to use scaled positional encoding
encoder_normalize_before: True # whether to perform layer normalization before the input
decoder_normalize_before: True # whether to perform layer normalization before the input
reduction_factor: 1 # reduction factor
init_type: xavier_uniform # initialization type
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer
transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer
transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding
transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size: 5 # kernel size of conv layers in pitch predictor
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers: 2 # number of conv layers in energy predictor
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
energy_predictor_kernel_size: 3 # kernel size of conv layers in energy predictor
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder
spk_embed_dim: 256 # speaker embedding dimension
spk_embed_integration_type: concat # speaker embedding integration type
###########################################################
# UPDATER SETTING #
###########################################################
updater:
use_masking: True # whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
optim: adam # optimizer type
learning_rate: 0.001 # learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 200
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

@ -0,0 +1,86 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
ge2e_ckpt_path=$2
# gen speaker embedding
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \
--input=~/datasets/data_aishell3/train/wav/ \
--output=dump/embed \
--checkpoint_path=${ge2e_ckpt_path}
fi
# copy from tts3/preprocess
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./aishell3_alignment_tone \
--output durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=aishell3 \
--rootdir=~/datasets/data_aishell3/ \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True \
--embed-dir=dump/embed
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# get features' stats (mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="speech"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="pitch"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="energy"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# normalize and convert phone/speaker to id; dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
fi

@ -0,0 +1,19 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/synthesize.py \
--fastspeech2-config=${config_path} \
--fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \
--pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
--test-metadata=dump/test/norm/metadata.jsonl \
--output-dir=${train_output_path}/test \
--phones-dict=dump/phone_id_map.txt \
--voice-cloning=True

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=2 \
--phones-dict=dump/phone_id_map.txt \
--voice-cloning=True

@ -0,0 +1,22 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
ge2e_params_path=$4
ref_audio_dir=$5
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/voice_cloning.py \
--fastspeech2-config=${config_path} \
--fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
--fastspeech2-stat=dump/train/speech_stats.npy \
--pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \
--pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
--ge2e_params_path=${ge2e_params_path} \
--text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \
--input-dir=${ref_audio_dir} \
--output-dir=${train_output_path}/vc_syn \
--phones-dict=dump/phone_id_map.txt

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=fastspeech2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,44 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_482.pdz
ref_audio_dir=ref_audio
# do not include ".pdparams" here
ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000
# include ".pdparams" here
ge2e_params_path=${ge2e_ckpt_path}.pdparams
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this cannot be mixed with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} ${ge2e_ckpt_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model; all ckpts are saved under the `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# voice cloning, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir} || exit -1
fi

@ -80,7 +80,7 @@ lambda_adv: 4.0 # Loss balancing coefficient.
batch_size: 8 # Batch size.
batch_max_steps: 25500 # Length of each audio in batch. Make sure dividable by hop_size.
pin_memory: true # Whether to pin memory in Pytorch DataLoader.
num_workers: 4 # Number of workers in Pytorch DataLoader.
num_workers: 2 # Number of workers in Pytorch DataLoader.
remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.
allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory.

@ -10,4 +10,4 @@ export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=ge2e
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
export BIN_DIR=${MAIN_ROOT}/paddlespeech/vector/exps/${MODEL}

@ -100,7 +100,7 @@ def fastspeech2_single_spk_batch_fn(examples):
def fastspeech2_multi_spk_batch_fn(examples):
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"]
# fields = ["text", "text_lengths", "speech", "speech_lengths", "durations", "pitch", "energy", "spk_id"/"spembs"]
text = [np.array(item["text"], dtype=np.int64) for item in examples]
speech = [np.array(item["speech"], dtype=np.float32) for item in examples]
pitch = [np.array(item["pitch"], dtype=np.float32) for item in examples]
@ -114,7 +114,6 @@ def fastspeech2_multi_spk_batch_fn(examples):
speech_lengths = [
np.array(item["speech_lengths"], dtype=np.int64) for item in examples
]
spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
text = batch_sequences(text)
pitch = batch_sequences(pitch)
@ -130,7 +129,6 @@ def fastspeech2_multi_spk_batch_fn(examples):
energy = paddle.to_tensor(energy)
text_lengths = paddle.to_tensor(text_lengths)
speech_lengths = paddle.to_tensor(speech_lengths)
spk_id = paddle.to_tensor(spk_id)
batch = {
"text": text,
@ -139,9 +137,20 @@ def fastspeech2_multi_spk_batch_fn(examples):
"speech": speech,
"speech_lengths": speech_lengths,
"pitch": pitch,
"energy": energy,
"spk_id": spk_id
"energy": energy
}
# spembs has a higher priority than spk_id
if "spembs" in examples[0]:
spembs = [
np.array(item["spembs"], dtype=np.float32) for item in examples
]
spembs = batch_sequences(spembs)
spembs = paddle.to_tensor(spembs)
batch["spembs"] = spembs
elif "spk_id" in examples[0]:
spk_id = [np.array(item["spk_id"], dtype=np.int64) for item in examples]
spk_id = paddle.to_tensor(spk_id)
batch["spk_id"] = spk_id
return batch

@ -167,6 +167,10 @@ def main():
"pitch": str(pitch_path),
"energy": str(energy_path)
}
# add spembs for voice cloning
if "spembs" in item:
record["spembs"] = str(item["spembs"])
output_metadata.append(record)
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"

@ -44,7 +44,8 @@ def process_sentence(config: Dict[str, Any],
mel_extractor=None,
pitch_extractor=None,
energy_extractor=None,
cut_sil: bool=True):
cut_sil: bool=True,
embed_dir: Path=None):
utt_id = fp.stem
# for vctk
if utt_id.endswith("_mic2"):
@ -116,6 +117,14 @@ def process_sentence(config: Dict[str, Any],
"energy": str(energy_path),
"speaker": speaker
}
if embed_dir:
if speaker in os.listdir(embed_dir):
embed_name = utt_id + ".npy"
embed_path = embed_dir / speaker / embed_name
if embed_path.is_file():
record["spembs"] = str(embed_path)
else:
return None
return record
@ -127,13 +136,14 @@ def process_sentences(config,
pitch_extractor=None,
energy_extractor=None,
nprocs: int=1,
cut_sil: bool=True):
cut_sil: bool=True,
embed_dir: Path=None):
if nprocs == 1:
results = []
for fp in fps:
record = process_sentence(config, fp, sentences, output_dir,
mel_extractor, pitch_extractor,
energy_extractor, cut_sil)
energy_extractor, cut_sil, embed_dir)
if record:
results.append(record)
else:
@ -144,7 +154,7 @@ def process_sentences(config,
future = pool.submit(process_sentence, config, fp,
sentences, output_dir, mel_extractor,
pitch_extractor, energy_extractor,
cut_sil)
cut_sil, embed_dir)
future.add_done_callback(lambda p: progress.update())
futures.append(future)
@ -202,6 +212,11 @@ def main():
default=True,
help="whether cut sil in the edge of audio")
parser.add_argument(
"--embed-dir",
default=None,
type=str,
help="directory to speaker embedding files.")
args = parser.parse_args()
rootdir = Path(args.rootdir).expanduser()
@ -211,6 +226,11 @@ def main():
dumpdir.mkdir(parents=True, exist_ok=True)
dur_file = Path(args.dur_file).expanduser()
if args.embed_dir:
embed_dir = Path(args.embed_dir).expanduser().resolve()
else:
embed_dir = None
assert rootdir.is_dir()
assert dur_file.is_file()
@ -251,6 +271,7 @@ def main():
test_wav_files += wav_files[-sub_num_dev:]
else:
train_wav_files += wav_files
elif args.dataset == "ljspeech":
wav_files = sorted(list((rootdir / "wavs").rglob("*.wav")))
# split data into 3 sections
@ -317,7 +338,8 @@ def main():
pitch_extractor,
energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil)
cut_sil=args.cut_sil,
embed_dir=embed_dir)
if dev_wav_files:
process_sentences(
config,
@ -327,7 +349,8 @@ def main():
mel_extractor,
pitch_extractor,
energy_extractor,
cut_sil=args.cut_sil)
cut_sil=args.cut_sil,
embed_dir=embed_dir)
if test_wav_files:
process_sentences(
config,
@ -338,7 +361,8 @@ def main():
pitch_extractor,
energy_extractor,
nprocs=args.num_cpu,
cut_sil=args.cut_sil)
cut_sil=args.cut_sil,
embed_dir=embed_dir)
if __name__ == "__main__":

@ -40,15 +40,18 @@ def evaluate(args, fastspeech2_config, pwg_config):
fields = ["utt_id", "text"]
num_speakers = None
if args.speaker_dict is not None:
print("multiple speaker fastspeech2!")
with open(args.speaker_dict, 'rt') as f:
spk_id = [line.strip().split() for line in f.readlines()]
num_speakers = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("voice cloning!")
fields += ["spembs"]
else:
print("single speaker fastspeech2!")
num_speakers = None
print("num_speakers:", num_speakers)
test_dataset = DataTable(data=test_metadata, fields=fields)
@ -96,12 +99,15 @@ def evaluate(args, fastspeech2_config, pwg_config):
for datum in test_dataset:
utt_id = datum["utt_id"]
text = paddle.to_tensor(datum["text"])
if "spk_id" in datum:
spembs = None
spk_id = None
if args.voice_cloning and "spembs" in datum:
spembs = paddle.to_tensor(np.load(datum["spembs"]))
elif "spk_id" in datum:
spk_id = paddle.to_tensor(datum["spk_id"])
else:
spk_id = None
with paddle.no_grad():
wav = pwg_inference(fastspeech2_inference(text, spk_id=spk_id))
wav = pwg_inference(
fastspeech2_inference(text, spk_id=spk_id, spembs=spembs))
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
@ -142,6 +148,15 @@ def main():
type=str,
default=None,
help="speaker id map file for multiple speaker model.")
def str2bool(str):
return True if str.lower() == 'true' else False
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
parser.add_argument("--test-metadata", type=str, help="test metadata.")
parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument(

@ -61,6 +61,8 @@ def train_sp(args, config):
"text", "text_lengths", "speech", "speech_lengths", "durations",
"pitch", "energy"
]
converters = {"speech": np.load, "pitch": np.load, "energy": np.load}
num_speakers = None
if args.speaker_dict is not None:
print("multiple speaker fastspeech2!")
collate_fn = fastspeech2_multi_spk_batch_fn
@ -68,10 +70,14 @@ def train_sp(args, config):
spk_id = [line.strip().split() for line in f.readlines()]
num_speakers = len(spk_id)
fields += ["spk_id"]
elif args.voice_cloning:
print("Training voice cloning!")
collate_fn = fastspeech2_multi_spk_batch_fn
fields += ["spembs"]
converters["spembs"] = np.load
else:
print("single speaker fastspeech2!")
collate_fn = fastspeech2_single_spk_batch_fn
num_speakers = None
print("num_speakers:", num_speakers)
# dataloader has been too verbose
@ -83,17 +89,13 @@ def train_sp(args, config):
train_dataset = DataTable(
data=train_metadata,
fields=fields,
converters={"speech": np.load,
"pitch": np.load,
"energy": np.load}, )
converters=converters, )
with jsonlines.open(args.dev_metadata, 'r') as reader:
dev_metadata = list(reader)
dev_dataset = DataTable(
data=dev_metadata,
fields=fields,
converters={"speech": np.load,
"pitch": np.load,
"energy": np.load}, )
converters=converters, )
# collate function and dataloader
@ -184,6 +186,15 @@ def main():
default=None,
help="speaker id map file for multiple speaker model.")
def str2bool(str):
return True if str.lower() == 'true' else False
parser.add_argument(
"--voice-cloning",
type=str2bool,
default=False,
help="whether training voice cloning model.")
args = parser.parse_args()
with open(args.config) as f:

@ -0,0 +1,207 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
import numpy as np
import paddle
import soundfile as sf
import yaml
from yacs.config import CfgNode
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference
from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator
from paddlespeech.t2s.models.parallel_wavegan import PWGInference
from paddlespeech.t2s.modules.normalizer import ZScore
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
def voice_cloning(args, fastspeech2_config, pwg_config):
# speaker encoder
p = SpeakerVerificationPreprocessor(
sampling_rate=16000,
audio_norm_target_dBFS=-30,
vad_window_length=30,
vad_moving_average_width=8,
vad_max_silence_length=6,
mel_window_length=25,
mel_window_step=10,
n_mels=40,
partial_n_frames=160,
min_pad_coverage=0.75,
partial_overlap_ratio=0.5)
print("Audio Processor Done!")
speaker_encoder = LSTMSpeakerEncoder(
n_mels=40, num_layers=3, hidden_size=256, output_size=256)
speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path))
speaker_encoder.eval()
print("GE2E Done!")
with open(args.phones_dict, "r") as f:
phn_id = [line.strip().split() for line in f.readlines()]
vocab_size = len(phn_id)
print("vocab_size:", vocab_size)
odim = fastspeech2_config.n_mels
model = FastSpeech2(
idim=vocab_size, odim=odim, **fastspeech2_config["model"])
model.set_state_dict(
paddle.load(args.fastspeech2_checkpoint)["main_params"])
model.eval()
vocoder = PWGGenerator(**pwg_config["generator_params"])
vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"])
vocoder.remove_weight_norm()
vocoder.eval()
print("model done!")
frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!")
stat = np.load(args.fastspeech2_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
fastspeech2_normalizer = ZScore(mu, std)
stat = np.load(args.pwg_stat)
mu, std = stat
mu = paddle.to_tensor(mu)
std = paddle.to_tensor(std)
pwg_normalizer = ZScore(mu, std)
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
fastspeech2_inference.eval()
pwg_inference = PWGInference(pwg_normalizer, vocoder)
pwg_inference.eval()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
input_dir = Path(args.input_dir)
sentence = args.text
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"][0]
for name in os.listdir(input_dir):
utt_id = name.split(".")[0]
ref_audio_path = input_dir / name
mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path))
# print("mel_sequences: ", mel_sequences.shape)
with paddle.no_grad():
spembs = speaker_encoder.embed_utterance(
paddle.to_tensor(mel_sequences))
# print("spembs shape: ", spembs.shape)
with paddle.no_grad():
wav = pwg_inference(fastspeech2_inference(phone_ids, spembs=spembs))
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
samplerate=fastspeech2_config.fs)
print(f"{utt_id} done!")
# Randomly generate a speaker embedding with values in [0, 0.2); 256 is the dim of spembs
random_spembs = np.random.rand(256) * 0.2
random_spembs = paddle.to_tensor(random_spembs)
utt_id = "random_spembs"
with paddle.no_grad():
wav = pwg_inference(fastspeech2_inference(phone_ids, spembs=random_spembs))
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
samplerate=fastspeech2_config.fs)
print(f"{utt_id} done!")
def main():
# parse args and config and redirect to train_sp
parser = argparse.ArgumentParser(description="")
parser.add_argument(
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
parser.add_argument(
"--fastspeech2-checkpoint",
type=str,
help="fastspeech2 checkpoint to load.")
parser.add_argument(
"--fastspeech2-stat",
type=str,
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
)
parser.add_argument(
"--pwg-config", type=str, help="parallel wavegan config file.")
parser.add_argument(
"--pwg-checkpoint",
type=str,
help="parallel wavegan generator parameters to load.")
parser.add_argument(
"--pwg-stat",
type=str,
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
)
parser.add_argument(
"--phones-dict",
type=str,
default="phone_id_map.txt",
help="phone vocabulary file.")
parser.add_argument(
"--text",
type=str,
default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。",
help="text to synthesize, a line")
parser.add_argument(
"--ge2e_params_path", type=str, help="ge2e params path.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
parser.add_argument(
"--input-dir",
type=str,
help="input dir of *.wav, the sample rate will be resample to 16k.")
parser.add_argument("--output-dir", type=str, help="output dir.")
args = parser.parse_args()
if args.ngpu == 0:
paddle.set_device("cpu")
elif args.ngpu > 0:
paddle.set_device("gpu")
else:
print("ngpu should >= 0 !")
with open(args.fastspeech2_config) as f:
fastspeech2_config = CfgNode(yaml.safe_load(f))
with open(args.pwg_config) as f:
pwg_config = CfgNode(yaml.safe_load(f))
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
print(fastspeech2_config)
print(pwg_config)
voice_cloning(args, fastspeech2_config, pwg_config)
if __name__ == "__main__":
main()

@ -20,14 +20,14 @@ import paddle
import soundfile as sf
from matplotlib import pyplot as plt
from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.t2s.exps.voice_cloning.tacotron2_ge2e.aishell3 import voc_phones
from paddlespeech.t2s.exps.voice_cloning.tacotron2_ge2e.aishell3 import voc_tones
from paddlespeech.t2s.exps.voice_cloning.tacotron2_ge2e.chinese_g2p import convert_sentence
from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder
from paddlespeech.t2s.models.tacotron2 import Tacotron2
from paddlespeech.t2s.models.waveflow import ConditionalWaveFlow
from paddlespeech.t2s.utils import display
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
def voice_cloning(args):

@ -146,7 +146,7 @@ class FastSpeech2(nn.Layer):
# initialize parameters
initialize(self, init_type)
if self.spk_embed_dim is not None:
if self.spk_embed_dim and num_speakers:
self.spk_embedding_table = nn.Embedding(
num_embeddings=num_speakers,
embedding_dim=self.spk_embed_dim,
@ -395,6 +395,7 @@ class FastSpeech2(nn.Layer):
# integrate speaker embedding
if self.spk_embed_dim is not None:
# spembs has a higher priority than spk_id
if spembs is not None:
hs = self._integrate_with_spk_embed(hs, spembs)
elif spk_id is not None:
@ -525,7 +526,6 @@ class FastSpeech2(nn.Layer):
# input of embedding must be int64
x = paddle.cast(text, 'int64')
y = speech
spemb = spembs
d, p, e = durations, pitch, energy
# setup batch axis
ilens = paddle.shape(x)[0]
@ -535,8 +535,8 @@ class FastSpeech2(nn.Layer):
if y is not None:
ys = y.unsqueeze(0)
if spemb is not None:
spembs = spemb.unsqueeze(0)
if spembs is not None:
spembs = spembs.unsqueeze(0)
if tone_id is not None:
tone_id = tone_id.unsqueeze(0)
@ -546,7 +546,7 @@ class FastSpeech2(nn.Layer):
ds = d.unsqueeze(0) if d is not None else None
ps = p.unsqueeze(0) if p is not None else None
es = e.unsqueeze(0) if e is not None else None
# ds, ps, es = , p.unsqueeze(0), e.unsqueeze(0)
# (1, L, odim)
_, outs, d_outs, p_outs, e_outs = self._forward(
xs,
@ -680,9 +680,9 @@ class FastSpeech2Inference(nn.Layer):
self.normalizer = normalizer
self.acoustic_model = model
def forward(self, text, spk_id=None):
def forward(self, text, spk_id=None, spembs=None):
normalized_mel, d_outs, p_outs, e_outs = self.acoustic_model.inference(
text, spk_id=spk_id)
text, spk_id=spk_id, spembs=spembs)
logmel = self.normalizer.inverse(normalized_mel)
return logmel

@ -54,6 +54,10 @@ class FastSpeech2Updater(StandardUpdater):
losses_dict = {}
# spk_id!=None in multiple spk fastspeech2
spk_id = batch["spk_id"] if "spk_id" in batch else None
spembs = batch["spembs"] if "spembs" in batch else None
# No explicit speaker identifier labels are used during voice cloning training.
if spembs is not None:
spk_id = None
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
@ -63,7 +67,8 @@ class FastSpeech2Updater(StandardUpdater):
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id)
spk_id=spk_id,
spembs=spembs)
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
after_outs=after_outs,
@ -126,6 +131,9 @@ class FastSpeech2Evaluator(StandardEvaluator):
losses_dict = {}
# spk_id!=None in multiple spk fastspeech2
spk_id = batch["spk_id"] if "spk_id" in batch else None
spembs = batch["spembs"] if "spembs" in batch else None
if spembs is not None:
spk_id = None
before_outs, after_outs, d_outs, p_outs, e_outs, ys, olens = self.model(
text=batch["text"],
@ -135,7 +143,8 @@ class FastSpeech2Evaluator(StandardEvaluator):
durations=batch["durations"],
pitch=batch["pitch"],
energy=batch["energy"],
spk_id=spk_id)
spk_id=spk_id,
spembs=spembs)
l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
after_outs=after_outs,

@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@ -19,7 +19,7 @@ from typing import List
import numpy as np
from tqdm import tqdm
from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
def _process_utterance(path_pair, processor: SpeakerVerificationPreprocessor):

@ -18,9 +18,9 @@ import numpy as np
import paddle
import tqdm
from paddlespeech.t2s.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults
from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
def embed_utterance(processor, model, fpath_or_wav):

@ -14,14 +14,13 @@
import argparse
from pathlib import Path
from audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults
from paddlespeech.t2s.exps.ge2e.dataset_processors import process_aidatatang_200zh
from paddlespeech.t2s.exps.ge2e.dataset_processors import process_librispeech
from paddlespeech.t2s.exps.ge2e.dataset_processors import process_magicdata
from paddlespeech.t2s.exps.ge2e.dataset_processors import process_voxceleb1
from paddlespeech.t2s.exps.ge2e.dataset_processors import process_voxceleb2
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults
from paddlespeech.vector.exps.ge2e.dataset_processors import process_aidatatang_200zh
from paddlespeech.vector.exps.ge2e.dataset_processors import process_librispeech
from paddlespeech.vector.exps.ge2e.dataset_processors import process_magicdata
from paddlespeech.vector.exps.ge2e.dataset_processors import process_voxceleb1
from paddlespeech.vector.exps.ge2e.dataset_processors import process_voxceleb2
if __name__ == "__main__":
parser = argparse.ArgumentParser(

@ -18,7 +18,7 @@ import numpy as np
from paddle.io import BatchSampler
from paddle.io import Dataset
from paddlespeech.t2s.exps.ge2e.random_cycle import random_cycle
from paddlespeech.vector.exps.ge2e.random_cycle import random_cycle
class MultiSpeakerMelDataset(Dataset):

@ -19,13 +19,13 @@ from paddle.io import DataLoader
from paddle.nn.clip import ClipGradByGlobalNorm
from paddle.optimizer import Adam
from paddlespeech.t2s.exps.ge2e.config import get_cfg_defaults
from paddlespeech.t2s.exps.ge2e.speaker_verification_dataset import Collate
from paddlespeech.t2s.exps.ge2e.speaker_verification_dataset import MultiSpeakerMelDataset
from paddlespeech.t2s.exps.ge2e.speaker_verification_dataset import MultiSpeakerSampler
from paddlespeech.t2s.models.lstm_speaker_encoder import LSTMSpeakerEncoder
from paddlespeech.t2s.training import default_argument_parser
from paddlespeech.t2s.training import ExperimentBase
from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults
from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import Collate
from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerMelDataset
from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerSampler
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
class Ge2eExperiment(ExperimentBase):

@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.