commit
4370c5cfa6
@ -0,0 +1,105 @@
|
|||||||
|
###########################################################
|
||||||
|
# FEATURE EXTRACTION SETTING #
|
||||||
|
###########################################################
|
||||||
|
|
||||||
|
fs: 24000 # sr
|
||||||
|
n_fft: 2048 # FFT size.
|
||||||
|
n_shift: 300 # Hop size.
|
||||||
|
win_length: 1200 # Window length.
|
||||||
|
# If set to null, it will be the same as fft_size.
|
||||||
|
window: "hann" # Window function.
|
||||||
|
|
||||||
|
# Only used for feats_type != raw
|
||||||
|
|
||||||
|
fmin: 80 # Minimum frequency of Mel basis.
|
||||||
|
fmax: 7600 # Maximum frequency of Mel basis.
|
||||||
|
n_mels: 80 # The number of mel basis.
|
||||||
|
|
||||||
|
# Only used for the model using pitch features (e.g. FastSpeech2)
|
||||||
|
f0min: 80 # Maximum f0 for pitch extraction.
|
||||||
|
f0max: 400 # Minimum f0 for pitch extraction.
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# DATA SETTING #
|
||||||
|
###########################################################
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 2
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# MODEL SETTING #
|
||||||
|
###########################################################
|
||||||
|
model:
|
||||||
|
adim: 384 # attention dimension
|
||||||
|
aheads: 2 # number of attention heads
|
||||||
|
elayers: 4 # number of encoder layers
|
||||||
|
eunits: 1536 # number of encoder ff units
|
||||||
|
dlayers: 4 # number of decoder layers
|
||||||
|
dunits: 1536 # number of decoder ff units
|
||||||
|
positionwise_layer_type: conv1d # type of position-wise layer
|
||||||
|
positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer
|
||||||
|
duration_predictor_layers: 2 # number of layers of duration predictor
|
||||||
|
duration_predictor_chans: 256 # number of channels of duration predictor
|
||||||
|
duration_predictor_kernel_size: 3 # filter size of duration predictor
|
||||||
|
postnet_layers: 5 # number of layers of postnset
|
||||||
|
postnet_filts: 5 # filter size of conv layers in postnet
|
||||||
|
postnet_chans: 256 # number of channels of conv layers in postnet
|
||||||
|
use_masking: True # whether to apply masking for padded part in loss calculation
|
||||||
|
use_scaled_pos_enc: True # whether to use scaled positional encoding
|
||||||
|
encoder_normalize_before: True # whether to perform layer normalization before the input
|
||||||
|
decoder_normalize_before: True # whether to perform layer normalization before the input
|
||||||
|
reduction_factor: 1 # reduction factor
|
||||||
|
init_type: xavier_uniform # initialization type
|
||||||
|
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
|
||||||
|
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
|
||||||
|
transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer
|
||||||
|
transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
|
||||||
|
transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer
|
||||||
|
transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer
|
||||||
|
transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding
|
||||||
|
transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer
|
||||||
|
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
|
||||||
|
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
|
||||||
|
pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor
|
||||||
|
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
|
||||||
|
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
|
||||||
|
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
|
||||||
|
stop_gradient_from_pitch_predictor: true # whether to stop the gradient from pitch predictor to encoder
|
||||||
|
energy_predictor_layers: 2 # number of conv layers in energy predictor
|
||||||
|
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
|
||||||
|
energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor
|
||||||
|
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
|
||||||
|
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
|
||||||
|
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
|
||||||
|
stop_gradient_from_energy_predictor: false # whether to stop the gradient from energy predictor to encoder
|
||||||
|
spk_embed_dim: 256 # speaker embedding dimension
|
||||||
|
spk_embed_integration_type: concat # speaker embedding integration type
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# UPDATER SETTING #
|
||||||
|
###########################################################
|
||||||
|
updater:
|
||||||
|
use_masking: True # whether to apply masking for padded part in loss calculation
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OPTIMIZER SETTING #
|
||||||
|
###########################################################
|
||||||
|
optimizer:
|
||||||
|
optim: adam # optimizer type
|
||||||
|
learning_rate: 0.001 # learning rate
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# TRAINING SETTING #
|
||||||
|
###########################################################
|
||||||
|
max_epoch: 200
|
||||||
|
num_snapshots: 5
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
# OTHER SETTING #
|
||||||
|
###########################################################
|
||||||
|
seed: 10086
|
@ -0,0 +1,86 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
ge2e_ckpt_path=$2
|
||||||
|
|
||||||
|
# gen speaker embedding
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
python3 ${MAIN_ROOT}/paddlespeech/vector/exps/ge2e/inference.py \
|
||||||
|
--input=~/datasets/data_aishell3/train/wav/ \
|
||||||
|
--output=dump/embed \
|
||||||
|
--checkpoint_path=${ge2e_ckpt_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
# copy from tts3/preprocess
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# get durations from MFA's result
|
||||||
|
echo "Generate durations.txt from MFA results ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
|
||||||
|
--inputdir=./aishell3_alignment_tone \
|
||||||
|
--output durations.txt \
|
||||||
|
--config=${config_path}
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# extract features
|
||||||
|
echo "Extract features ..."
|
||||||
|
python3 ${BIN_DIR}/preprocess.py \
|
||||||
|
--dataset=aishell3 \
|
||||||
|
--rootdir=~/datasets/data_aishell3/ \
|
||||||
|
--dumpdir=dump \
|
||||||
|
--dur-file=durations.txt \
|
||||||
|
--config=${config_path} \
|
||||||
|
--num-cpu=20 \
|
||||||
|
--cut-sil=True \
|
||||||
|
--spk_emb_dir=dump/embed
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# get features' stats(mean and std)
|
||||||
|
echo "Get features' stats ..."
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="speech"
|
||||||
|
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="pitch"
|
||||||
|
|
||||||
|
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--field-name="energy"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
||||||
|
# normalize and covert phone/speaker to id, dev and test should use train's stats
|
||||||
|
echo "Normalize ..."
|
||||||
|
python3 ${BIN_DIR}/normalize.py \
|
||||||
|
--metadata=dump/train/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/train/norm \
|
||||||
|
--speech-stats=dump/train/speech_stats.npy \
|
||||||
|
--pitch-stats=dump/train/pitch_stats.npy \
|
||||||
|
--energy-stats=dump/train/energy_stats.npy \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--speaker-dict=dump/speaker_id_map.txt
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/normalize.py \
|
||||||
|
--metadata=dump/dev/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/dev/norm \
|
||||||
|
--speech-stats=dump/train/speech_stats.npy \
|
||||||
|
--pitch-stats=dump/train/pitch_stats.npy \
|
||||||
|
--energy-stats=dump/train/energy_stats.npy \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--speaker-dict=dump/speaker_id_map.txt
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/normalize.py \
|
||||||
|
--metadata=dump/test/raw/metadata.jsonl \
|
||||||
|
--dumpdir=dump/test/norm \
|
||||||
|
--speech-stats=dump/train/speech_stats.npy \
|
||||||
|
--pitch-stats=dump/train/pitch_stats.npy \
|
||||||
|
--energy-stats=dump/train/energy_stats.npy \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--speaker-dict=dump/speaker_id_map.txt
|
||||||
|
fi
|
@ -0,0 +1,19 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/synthesize.py \
|
||||||
|
--fastspeech2-config=${config_path} \
|
||||||
|
--fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--fastspeech2-stat=dump/train/speech_stats.npy \
|
||||||
|
--pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \
|
||||||
|
--pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
|
||||||
|
--pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
|
||||||
|
--test-metadata=dump/test/norm/metadata.jsonl \
|
||||||
|
--output-dir=${train_output_path}/test \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--voice-cloning=True
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
|
||||||
|
python3 ${BIN_DIR}/train.py \
|
||||||
|
--train-metadata=dump/train/norm/metadata.jsonl \
|
||||||
|
--dev-metadata=dump/dev/norm/metadata.jsonl \
|
||||||
|
--config=${config_path} \
|
||||||
|
--output-dir=${train_output_path} \
|
||||||
|
--ngpu=2 \
|
||||||
|
--phones-dict=dump/phone_id_map.txt \
|
||||||
|
--voice-cloning=True
|
@ -0,0 +1,22 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
config_path=$1
|
||||||
|
train_output_path=$2
|
||||||
|
ckpt_name=$3
|
||||||
|
ge2e_params_path=$4
|
||||||
|
ref_audio_dir=$5
|
||||||
|
|
||||||
|
FLAGS_allocator_strategy=naive_best_fit \
|
||||||
|
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
|
||||||
|
python3 ${BIN_DIR}/voice_cloning.py \
|
||||||
|
--fastspeech2-config=${config_path} \
|
||||||
|
--fastspeech2-checkpoint=${train_output_path}/checkpoints/${ckpt_name} \
|
||||||
|
--fastspeech2-stat=dump/train/speech_stats.npy \
|
||||||
|
--pwg-config=pwg_aishell3_ckpt_0.5/default.yaml \
|
||||||
|
--pwg-checkpoint=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
|
||||||
|
--pwg-stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
|
||||||
|
--ge2e_params_path=${ge2e_params_path} \
|
||||||
|
--text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \
|
||||||
|
--input-dir=${ref_audio_dir} \
|
||||||
|
--output-dir=${train_output_path}/vc_syn \
|
||||||
|
--phones-dict=dump/phone_id_map.txt
|
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
export MAIN_ROOT=`realpath ${PWD}/../../../`
|
||||||
|
|
||||||
|
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
|
||||||
|
export LC_ALL=C
|
||||||
|
|
||||||
|
export PYTHONDONTWRITEBYTECODE=1
|
||||||
|
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||||
|
export PYTHONIOENCODING=UTF-8
|
||||||
|
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
|
||||||
|
|
||||||
|
MODEL=fastspeech2
|
||||||
|
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}
|
@ -0,0 +1,44 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
source path.sh
|
||||||
|
|
||||||
|
gpus=0,1
|
||||||
|
stage=0
|
||||||
|
stop_stage=100
|
||||||
|
|
||||||
|
conf_path=conf/default.yaml
|
||||||
|
train_output_path=exp/default
|
||||||
|
ckpt_name=snapshot_iter_482.pdz
|
||||||
|
ref_audio_dir=ref_audio
|
||||||
|
|
||||||
|
# not include ".pdparams" here
|
||||||
|
ge2e_ckpt_path=./ge2e_ckpt_0.3/step-3000000
|
||||||
|
|
||||||
|
# include ".pdparams" here
|
||||||
|
ge2e_params_path=${ge2e_ckpt_path}.pdparams
|
||||||
|
|
||||||
|
# with the following command, you can choice the stage range you want to run
|
||||||
|
# such as `./run.sh --stage 0 --stop-stage 0`
|
||||||
|
# this can not be mixed use with `$1`, `$2` ...
|
||||||
|
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
|
||||||
|
|
||||||
|
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
||||||
|
# prepare data
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} ${ge2e_ckpt_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
||||||
|
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
||||||
|
# synthesize, vocoder is pwgan
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
||||||
|
# synthesize, vocoder is pwgan
|
||||||
|
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ge2e_params_path} ${ref_audio_dir} || exit -1
|
||||||
|
fi
|
@ -0,0 +1,208 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
import soundfile as sf
|
||||||
|
import yaml
|
||||||
|
from yacs.config import CfgNode
|
||||||
|
|
||||||
|
from paddlespeech.t2s.frontend.zh_frontend import Frontend
|
||||||
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2
|
||||||
|
from paddlespeech.t2s.models.fastspeech2 import FastSpeech2Inference
|
||||||
|
from paddlespeech.t2s.models.parallel_wavegan import PWGGenerator
|
||||||
|
from paddlespeech.t2s.models.parallel_wavegan import PWGInference
|
||||||
|
from paddlespeech.t2s.modules.normalizer import ZScore
|
||||||
|
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
|
||||||
|
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def voice_cloning(args, fastspeech2_config, pwg_config):
|
||||||
|
# speaker encoder
|
||||||
|
p = SpeakerVerificationPreprocessor(
|
||||||
|
sampling_rate=16000,
|
||||||
|
audio_norm_target_dBFS=-30,
|
||||||
|
vad_window_length=30,
|
||||||
|
vad_moving_average_width=8,
|
||||||
|
vad_max_silence_length=6,
|
||||||
|
mel_window_length=25,
|
||||||
|
mel_window_step=10,
|
||||||
|
n_mels=40,
|
||||||
|
partial_n_frames=160,
|
||||||
|
min_pad_coverage=0.75,
|
||||||
|
partial_overlap_ratio=0.5)
|
||||||
|
print("Audio Processor Done!")
|
||||||
|
|
||||||
|
speaker_encoder = LSTMSpeakerEncoder(
|
||||||
|
n_mels=40, num_layers=3, hidden_size=256, output_size=256)
|
||||||
|
speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path))
|
||||||
|
speaker_encoder.eval()
|
||||||
|
print("GE2E Done!")
|
||||||
|
|
||||||
|
with open(args.phones_dict, "r") as f:
|
||||||
|
phn_id = [line.strip().split() for line in f.readlines()]
|
||||||
|
vocab_size = len(phn_id)
|
||||||
|
print("vocab_size:", vocab_size)
|
||||||
|
odim = fastspeech2_config.n_mels
|
||||||
|
model = FastSpeech2(
|
||||||
|
idim=vocab_size, odim=odim, **fastspeech2_config["model"])
|
||||||
|
|
||||||
|
model.set_state_dict(
|
||||||
|
paddle.load(args.fastspeech2_checkpoint)["main_params"])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
vocoder = PWGGenerator(**pwg_config["generator_params"])
|
||||||
|
vocoder.set_state_dict(paddle.load(args.pwg_checkpoint)["generator_params"])
|
||||||
|
vocoder.remove_weight_norm()
|
||||||
|
vocoder.eval()
|
||||||
|
print("model done!")
|
||||||
|
|
||||||
|
frontend = Frontend(phone_vocab_path=args.phones_dict)
|
||||||
|
print("frontend done!")
|
||||||
|
|
||||||
|
stat = np.load(args.fastspeech2_stat)
|
||||||
|
mu, std = stat
|
||||||
|
mu = paddle.to_tensor(mu)
|
||||||
|
std = paddle.to_tensor(std)
|
||||||
|
fastspeech2_normalizer = ZScore(mu, std)
|
||||||
|
|
||||||
|
stat = np.load(args.pwg_stat)
|
||||||
|
mu, std = stat
|
||||||
|
mu = paddle.to_tensor(mu)
|
||||||
|
std = paddle.to_tensor(std)
|
||||||
|
pwg_normalizer = ZScore(mu, std)
|
||||||
|
|
||||||
|
fastspeech2_inference = FastSpeech2Inference(fastspeech2_normalizer, model)
|
||||||
|
fastspeech2_inference.eval()
|
||||||
|
pwg_inference = PWGInference(pwg_normalizer, vocoder)
|
||||||
|
pwg_inference.eval()
|
||||||
|
|
||||||
|
output_dir = Path(args.output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
input_dir = Path(args.input_dir)
|
||||||
|
|
||||||
|
sentence = args.text
|
||||||
|
|
||||||
|
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
|
||||||
|
phone_ids = input_ids["phone_ids"][0]
|
||||||
|
|
||||||
|
for name in os.listdir(input_dir):
|
||||||
|
utt_id = name.split(".")[0]
|
||||||
|
ref_audio_path = input_dir / name
|
||||||
|
mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path))
|
||||||
|
# print("mel_sequences: ", mel_sequences.shape)
|
||||||
|
with paddle.no_grad():
|
||||||
|
spk_emb = speaker_encoder.embed_utterance(
|
||||||
|
paddle.to_tensor(mel_sequences))
|
||||||
|
# print("spk_emb shape: ", spk_emb.shape)
|
||||||
|
|
||||||
|
with paddle.no_grad():
|
||||||
|
wav = pwg_inference(
|
||||||
|
fastspeech2_inference(phone_ids, spk_emb=spk_emb))
|
||||||
|
|
||||||
|
sf.write(
|
||||||
|
str(output_dir / (utt_id + ".wav")),
|
||||||
|
wav.numpy(),
|
||||||
|
samplerate=fastspeech2_config.fs)
|
||||||
|
print(f"{utt_id} done!")
|
||||||
|
# Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
|
||||||
|
random_spk_emb = np.random.rand(256) * 0.2
|
||||||
|
random_spk_emb = paddle.to_tensor(random_spk_emb)
|
||||||
|
utt_id = "random_spk_emb"
|
||||||
|
with paddle.no_grad():
|
||||||
|
wav = pwg_inference(fastspeech2_inference(phone_ids, spk_emb=spk_emb))
|
||||||
|
sf.write(
|
||||||
|
str(output_dir / (utt_id + ".wav")),
|
||||||
|
wav.numpy(),
|
||||||
|
samplerate=fastspeech2_config.fs)
|
||||||
|
print(f"{utt_id} done!")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# parse args and config and redirect to train_sp
|
||||||
|
parser = argparse.ArgumentParser(description="")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fastspeech2-config", type=str, help="fastspeech2 config file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fastspeech2-checkpoint",
|
||||||
|
type=str,
|
||||||
|
help="fastspeech2 checkpoint to load.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fastspeech2-stat",
|
||||||
|
type=str,
|
||||||
|
help="mean and standard deviation used to normalize spectrogram when training fastspeech2."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pwg-config", type=str, help="parallel wavegan config file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--pwg-checkpoint",
|
||||||
|
type=str,
|
||||||
|
help="parallel wavegan generator parameters to load.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--pwg-stat",
|
||||||
|
type=str,
|
||||||
|
help="mean and standard deviation used to normalize spectrogram when training parallel wavegan."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--phones-dict",
|
||||||
|
type=str,
|
||||||
|
default="phone_id_map.txt",
|
||||||
|
help="phone vocabulary file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--text",
|
||||||
|
type=str,
|
||||||
|
default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。",
|
||||||
|
help="text to synthesize, a line")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ge2e_params_path", type=str, help="ge2e params path.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-dir",
|
||||||
|
type=str,
|
||||||
|
help="input dir of *.wav, the sample rate will be resample to 16k.")
|
||||||
|
parser.add_argument("--output-dir", type=str, help="output dir.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.ngpu == 0:
|
||||||
|
paddle.set_device("cpu")
|
||||||
|
elif args.ngpu > 0:
|
||||||
|
paddle.set_device("gpu")
|
||||||
|
else:
|
||||||
|
print("ngpu should >= 0 !")
|
||||||
|
|
||||||
|
with open(args.fastspeech2_config) as f:
|
||||||
|
fastspeech2_config = CfgNode(yaml.safe_load(f))
|
||||||
|
with open(args.pwg_config) as f:
|
||||||
|
pwg_config = CfgNode(yaml.safe_load(f))
|
||||||
|
|
||||||
|
print("========Args========")
|
||||||
|
print(yaml.safe_dump(vars(args)))
|
||||||
|
print("========Config========")
|
||||||
|
print(fastspeech2_config)
|
||||||
|
print(pwg_config)
|
||||||
|
|
||||||
|
voice_cloning(args, fastspeech2_config, pwg_config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -0,0 +1,348 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0,
|
||||||
|
training=True):
|
||||||
|
r"""Scaled dot product attention with masking.
|
||||||
|
|
||||||
|
Assume that q, k, v all have the same leading dimensions (denoted as * in
|
||||||
|
descriptions below). Dropout is applied to attention weights before
|
||||||
|
weighted sum of values.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
q : Tensor [shape=(\*, T_q, d)]
|
||||||
|
the query tensor.
|
||||||
|
k : Tensor [shape=(\*, T_k, d)]
|
||||||
|
the key tensor.
|
||||||
|
v : Tensor [shape=(\*, T_k, d_v)]
|
||||||
|
the value tensor.
|
||||||
|
mask : Tensor, [shape=(\*, T_q, T_k) or broadcastable shape], optional
|
||||||
|
the mask tensor, zeros correspond to paddings. Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
out : Tensor [shape=(\*, T_q, d_v)]
|
||||||
|
the context vector.
|
||||||
|
attn_weights : Tensor [shape=(\*, T_q, T_k)]
|
||||||
|
the attention weights.
|
||||||
|
"""
|
||||||
|
d = q.shape[-1] # we only support imperative execution
|
||||||
|
qk = paddle.matmul(q, k, transpose_y=True)
|
||||||
|
scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
|
||||||
|
|
||||||
|
attn_weights = F.softmax(scaled_logit, axis=-1)
|
||||||
|
attn_weights = F.dropout(attn_weights, dropout, training=training)
|
||||||
|
out = paddle.matmul(attn_weights, v)
|
||||||
|
return out, attn_weights
|
||||||
|
|
||||||
|
|
||||||
|
def drop_head(x, drop_n_heads, training=True):
|
||||||
|
"""Drop n context vectors from multiple ones.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : Tensor [shape=(batch_size, num_heads, time_steps, channels)]
|
||||||
|
The input, multiple context vectors.
|
||||||
|
drop_n_heads : int [0<= drop_n_heads <= num_heads]
|
||||||
|
Number of vectors to drop.
|
||||||
|
training : bool
|
||||||
|
A flag indicating whether it is in training. If `False`, no dropout is
|
||||||
|
applied.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tensor
|
||||||
|
The output.
|
||||||
|
"""
|
||||||
|
if not training or (drop_n_heads == 0):
|
||||||
|
return x
|
||||||
|
|
||||||
|
batch_size, num_heads, _, _ = x.shape
|
||||||
|
# drop all heads
|
||||||
|
if num_heads == drop_n_heads:
|
||||||
|
return paddle.zeros_like(x)
|
||||||
|
|
||||||
|
mask = np.ones([batch_size, num_heads])
|
||||||
|
mask[:, :drop_n_heads] = 0
|
||||||
|
for subarray in mask:
|
||||||
|
np.random.shuffle(subarray)
|
||||||
|
scale = float(num_heads) / (num_heads - drop_n_heads)
|
||||||
|
mask = scale * np.reshape(mask, [batch_size, num_heads, 1, 1])
|
||||||
|
out = x * paddle.to_tensor(mask)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _split_heads(x, num_heads):
|
||||||
|
batch_size, time_steps, _ = x.shape
|
||||||
|
x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
|
||||||
|
x = paddle.transpose(x, [0, 2, 1, 3])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_heads(x):
|
||||||
|
batch_size, _, time_steps, _ = x.shape
|
||||||
|
x = paddle.transpose(x, [0, 2, 1, 3])
|
||||||
|
x = paddle.reshape(x, [batch_size, time_steps, -1])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# Standard implementations of Monohead Attention & Multihead Attention
|
||||||
|
class MonoheadAttention(nn.Layer):
|
||||||
|
"""Monohead Attention module.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_dim : int
|
||||||
|
Feature size of the query.
|
||||||
|
dropout : float, optional
|
||||||
|
Dropout probability of scaled dot product attention and final context
|
||||||
|
vector. Defaults to 0.0.
|
||||||
|
k_dim : int, optional
|
||||||
|
Feature size of the key of each scaled dot product attention. If not
|
||||||
|
provided, it is set to `model_dim / num_heads`. Defaults to None.
|
||||||
|
v_dim : int, optional
|
||||||
|
Feature size of the key of each scaled dot product attention. If not
|
||||||
|
provided, it is set to `model_dim / num_heads`. Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_dim: int,
|
||||||
|
dropout: float=0.0,
|
||||||
|
k_dim: int=None,
|
||||||
|
v_dim: int=None):
|
||||||
|
super(MonoheadAttention, self).__init__()
|
||||||
|
k_dim = k_dim or model_dim
|
||||||
|
v_dim = v_dim or model_dim
|
||||||
|
self.affine_q = nn.Linear(model_dim, k_dim)
|
||||||
|
self.affine_k = nn.Linear(model_dim, k_dim)
|
||||||
|
self.affine_v = nn.Linear(model_dim, v_dim)
|
||||||
|
self.affine_o = nn.Linear(v_dim, model_dim)
|
||||||
|
|
||||||
|
self.model_dim = model_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
def forward(self, q, k, v, mask):
|
||||||
|
"""Compute context vector and attention weights.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
|
||||||
|
The queries.
|
||||||
|
k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
|
||||||
|
The keys.
|
||||||
|
v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
|
||||||
|
The values.
|
||||||
|
mask : Tensor [shape=(batch_size, times_steps_q, time_steps_k] or broadcastable shape
|
||||||
|
The mask.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
|
||||||
|
The context vector.
|
||||||
|
attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)]
|
||||||
|
The attention weights.
|
||||||
|
"""
|
||||||
|
q = self.affine_q(q) # (B, T, C)
|
||||||
|
k = self.affine_k(k)
|
||||||
|
v = self.affine_v(v)
|
||||||
|
|
||||||
|
context_vectors, attention_weights = scaled_dot_product_attention(
|
||||||
|
q, k, v, mask, self.dropout, self.training)
|
||||||
|
|
||||||
|
out = self.affine_o(context_vectors)
|
||||||
|
return out, attention_weights
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Layer):
|
||||||
|
"""Multihead Attention module.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
model_dim: int
|
||||||
|
The feature size of query.
|
||||||
|
num_heads : int
|
||||||
|
The number of attention heads.
|
||||||
|
dropout : float, optional
|
||||||
|
Dropout probability of scaled dot product attention and final context
|
||||||
|
vector. Defaults to 0.0.
|
||||||
|
k_dim : int, optional
|
||||||
|
Feature size of the key of each scaled dot product attention. If not
|
||||||
|
provided, it is set to ``model_dim / num_heads``. Defaults to None.
|
||||||
|
v_dim : int, optional
|
||||||
|
Feature size of the key of each scaled dot product attention. If not
|
||||||
|
provided, it is set to ``model_dim / num_heads``. Defaults to None.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
---------
|
||||||
|
ValueError
|
||||||
|
If ``model_dim`` is not divisible by ``num_heads``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_dim: int,
|
||||||
|
num_heads: int,
|
||||||
|
dropout: float=0.0,
|
||||||
|
k_dim: int=None,
|
||||||
|
v_dim: int=None):
|
||||||
|
super(MultiheadAttention, self).__init__()
|
||||||
|
if model_dim % num_heads != 0:
|
||||||
|
raise ValueError("model_dim must be divisible by num_heads")
|
||||||
|
depth = model_dim // num_heads
|
||||||
|
k_dim = k_dim or depth
|
||||||
|
v_dim = v_dim or depth
|
||||||
|
self.affine_q = nn.Linear(model_dim, num_heads * k_dim)
|
||||||
|
self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
|
||||||
|
self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
|
||||||
|
self.affine_o = nn.Linear(num_heads * v_dim, model_dim)
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.model_dim = model_dim
|
||||||
|
self.dropout = dropout
|
||||||
|
|
||||||
|
def forward(self, q, k, v, mask):
|
||||||
|
"""Compute context vector and attention weights.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
q : Tensor [shape=(batch_size, time_steps_q, model_dim)]
|
||||||
|
The queries.
|
||||||
|
k : Tensor [shape=(batch_size, time_steps_k, model_dim)]
|
||||||
|
The keys.
|
||||||
|
v : Tensor [shape=(batch_size, time_steps_k, model_dim)]
|
||||||
|
The values.
|
||||||
|
mask : Tensor [shape=(batch_size, times_steps_q, time_steps_k] or broadcastable shape
|
||||||
|
The mask.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
out : Tensor [shape=(batch_size, time_steps_q, model_dim)]
|
||||||
|
The context vector.
|
||||||
|
attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)]
|
||||||
|
The attention weights.
|
||||||
|
"""
|
||||||
|
q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
|
||||||
|
k = _split_heads(self.affine_k(k), self.num_heads)
|
||||||
|
v = _split_heads(self.affine_v(v), self.num_heads)
|
||||||
|
mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
|
||||||
|
|
||||||
|
context_vectors, attention_weights = scaled_dot_product_attention(
|
||||||
|
q, k, v, mask, self.dropout, self.training)
|
||||||
|
# NOTE: there is more sophisticated implementation: Scheduled DropHead
|
||||||
|
context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
|
||||||
|
out = self.affine_o(context_vectors)
|
||||||
|
return out, attention_weights
|
||||||
|
|
||||||
|
|
||||||
|
class LocationSensitiveAttention(nn.Layer):
|
||||||
|
"""Location Sensitive Attention module.
|
||||||
|
|
||||||
|
Reference: `Attention-Based Models for Speech Recognition <https://arxiv.org/pdf/1506.07503.pdf>`_
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
d_query: int
|
||||||
|
The feature size of query.
|
||||||
|
d_key : int
|
||||||
|
The feature size of key.
|
||||||
|
d_attention : int
|
||||||
|
The feature size of dimension.
|
||||||
|
location_filters : int
|
||||||
|
Filter size of attention convolution.
|
||||||
|
location_kernel_size : int
|
||||||
|
Kernel size of attention convolution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
d_query: int,
|
||||||
|
d_key: int,
|
||||||
|
d_attention: int,
|
||||||
|
location_filters: int,
|
||||||
|
location_kernel_size: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
|
||||||
|
self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
|
||||||
|
self.value = nn.Linear(d_attention, 1, bias_attr=False)
|
||||||
|
|
||||||
|
# Location Layer
|
||||||
|
self.location_conv = nn.Conv1D(
|
||||||
|
2,
|
||||||
|
location_filters,
|
||||||
|
kernel_size=location_kernel_size,
|
||||||
|
padding=int((location_kernel_size - 1) / 2),
|
||||||
|
bias_attr=False,
|
||||||
|
data_format='NLC')
|
||||||
|
self.location_layer = nn.Linear(
|
||||||
|
location_filters, d_attention, bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
query,
|
||||||
|
processed_key,
|
||||||
|
value,
|
||||||
|
attention_weights_cat,
|
||||||
|
mask=None):
|
||||||
|
"""Compute context vector and attention weights.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
query : Tensor [shape=(batch_size, d_query)]
|
||||||
|
The queries.
|
||||||
|
processed_key : Tensor [shape=(batch_size, time_steps_k, d_attention)]
|
||||||
|
The keys after linear layer.
|
||||||
|
value : Tensor [shape=(batch_size, time_steps_k, d_key)]
|
||||||
|
The values.
|
||||||
|
attention_weights_cat : Tensor [shape=(batch_size, time_step_k, 2)]
|
||||||
|
Attention weights concat.
|
||||||
|
mask : Tensor, optional
|
||||||
|
The mask. Shape should be (batch_size, times_steps_k, 1).
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
----------
|
||||||
|
attention_context : Tensor [shape=(batch_size, d_attention)]
|
||||||
|
The context vector.
|
||||||
|
attention_weights : Tensor [shape=(batch_size, time_steps_k)]
|
||||||
|
The attention weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
|
||||||
|
processed_attention_weights = self.location_layer(
|
||||||
|
self.location_conv(attention_weights_cat))
|
||||||
|
# (B, T_enc, 1)
|
||||||
|
alignment = self.value(
|
||||||
|
paddle.tanh(processed_attention_weights + processed_key +
|
||||||
|
processed_query))
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
alignment = alignment + (1.0 - mask) * -1e9
|
||||||
|
|
||||||
|
attention_weights = F.softmax(alignment, axis=1)
|
||||||
|
attention_context = paddle.matmul(
|
||||||
|
attention_weights, value, transpose_x=True)
|
||||||
|
|
||||||
|
attention_weights = paddle.squeeze(attention_weights, axis=-1)
|
||||||
|
attention_context = paddle.squeeze(attention_context, axis=1)
|
||||||
|
|
||||||
|
return attention_context, attention_weights
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
@ -0,0 +1,38 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def cycle(iterable):
|
||||||
|
# cycle('ABCD') --> A B C D A B C D A B C D ...
|
||||||
|
saved = []
|
||||||
|
for element in iterable:
|
||||||
|
yield element
|
||||||
|
saved.append(element)
|
||||||
|
while saved:
|
||||||
|
for element in saved:
|
||||||
|
yield element
|
||||||
|
|
||||||
|
|
||||||
|
def random_cycle(iterable):
|
||||||
|
# cycle('ABCD') --> A B C D B C D A A D B C ...
|
||||||
|
saved = []
|
||||||
|
for element in iterable:
|
||||||
|
yield element
|
||||||
|
saved.append(element)
|
||||||
|
random.shuffle(saved)
|
||||||
|
while saved:
|
||||||
|
for element in saved:
|
||||||
|
yield element
|
||||||
|
random.shuffle(saved)
|
@ -0,0 +1,131 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from paddle.io import BatchSampler
|
||||||
|
from paddle.io import Dataset
|
||||||
|
|
||||||
|
from paddlespeech.vector.exps.ge2e.random_cycle import random_cycle
|
||||||
|
|
||||||
|
|
||||||
|
class MultiSpeakerMelDataset(Dataset):
|
||||||
|
"""A 2 layer directory thatn contains mel spectrograms in *.npy format.
|
||||||
|
An Example file structure tree is shown below. We prefer to preprocess
|
||||||
|
raw datasets and organized them like this.
|
||||||
|
|
||||||
|
dataset_root/
|
||||||
|
speaker1/
|
||||||
|
utterance1.npy
|
||||||
|
utterance2.npy
|
||||||
|
utterance3.npy
|
||||||
|
speaker2/
|
||||||
|
utterance1.npy
|
||||||
|
utterance2.npy
|
||||||
|
utterance3.npy
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset_root: Path):
|
||||||
|
self.root = Path(dataset_root).expanduser()
|
||||||
|
speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
|
||||||
|
|
||||||
|
speaker_utterances = {
|
||||||
|
speaker_dir: list(speaker_dir.glob("*.npy"))
|
||||||
|
for speaker_dir in speaker_dirs
|
||||||
|
}
|
||||||
|
|
||||||
|
self.speaker_dirs = speaker_dirs
|
||||||
|
self.speaker_to_utterances = speaker_utterances
|
||||||
|
|
||||||
|
# meta data
|
||||||
|
self.num_speakers = len(self.speaker_dirs)
|
||||||
|
self.num_utterances = np.sum(
|
||||||
|
len(utterances)
|
||||||
|
for speaker, utterances in self.speaker_to_utterances.items())
|
||||||
|
|
||||||
|
def get_example_by_index(self, speaker_index, utterance_index):
|
||||||
|
speaker_dir = self.speaker_dirs[speaker_index]
|
||||||
|
fpath = self.speaker_to_utterances[speaker_dir][utterance_index]
|
||||||
|
return self[fpath]
|
||||||
|
|
||||||
|
def __getitem__(self, fpath):
|
||||||
|
return np.load(fpath)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return int(self.num_utterances)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiSpeakerSampler(BatchSampler):
|
||||||
|
"""A multi-stratal sampler designed for speaker verification task.
|
||||||
|
First, N speakers from all speakers are sampled randomly. Then, for each
|
||||||
|
speaker, randomly sample M utterances from their corresponding utterances.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
dataset: MultiSpeakerMelDataset,
|
||||||
|
speakers_per_batch: int,
|
||||||
|
utterances_per_speaker: int):
|
||||||
|
self._speakers = list(dataset.speaker_dirs)
|
||||||
|
self._speaker_to_utterances = dataset.speaker_to_utterances
|
||||||
|
|
||||||
|
self.speakers_per_batch = speakers_per_batch
|
||||||
|
self.utterances_per_speaker = utterances_per_speaker
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
# yield list of Paths
|
||||||
|
speaker_generator = iter(random_cycle(self._speakers))
|
||||||
|
speaker_utterances_generator = {
|
||||||
|
s: iter(random_cycle(us))
|
||||||
|
for s, us in self._speaker_to_utterances.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
while True:
|
||||||
|
speakers = []
|
||||||
|
for _ in range(self.speakers_per_batch):
|
||||||
|
speakers.append(next(speaker_generator))
|
||||||
|
|
||||||
|
utterances = []
|
||||||
|
for s in speakers:
|
||||||
|
us = speaker_utterances_generator[s]
|
||||||
|
for _ in range(self.utterances_per_speaker):
|
||||||
|
utterances.append(next(us))
|
||||||
|
yield utterances
|
||||||
|
|
||||||
|
|
||||||
|
class RandomClip(object):
|
||||||
|
def __init__(self, frames):
|
||||||
|
self.frames = frames
|
||||||
|
|
||||||
|
def __call__(self, spec):
|
||||||
|
# spec [T, C]
|
||||||
|
T = spec.shape[0]
|
||||||
|
start = random.randint(0, T - self.frames)
|
||||||
|
return spec[start:start + self.frames, :]
|
||||||
|
|
||||||
|
|
||||||
|
class Collate(object):
|
||||||
|
def __init__(self, num_frames):
|
||||||
|
self.random_crop = RandomClip(num_frames)
|
||||||
|
|
||||||
|
def __call__(self, examples):
|
||||||
|
frame_clips = [self.random_crop(mel) for mel in examples]
|
||||||
|
batced_clips = np.stack(frame_clips)
|
||||||
|
return batced_clips
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mydataset = MultiSpeakerMelDataset(
|
||||||
|
Path("/home/chenfeiyu/datasets/SV2TTS/encoder"))
|
||||||
|
print(mydataset.get_example_by_index(0, 10))
|
@ -0,0 +1,123 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import time
|
||||||
|
|
||||||
|
from paddle import DataParallel
|
||||||
|
from paddle import distributed as dist
|
||||||
|
from paddle.io import DataLoader
|
||||||
|
from paddle.nn.clip import ClipGradByGlobalNorm
|
||||||
|
from paddle.optimizer import Adam
|
||||||
|
|
||||||
|
from paddlespeech.t2s.training import default_argument_parser
|
||||||
|
from paddlespeech.t2s.training import ExperimentBase
|
||||||
|
from paddlespeech.vector.exps.ge2e.config import get_cfg_defaults
|
||||||
|
from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import Collate
|
||||||
|
from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerMelDataset
|
||||||
|
from paddlespeech.vector.exps.ge2e.speaker_verification_dataset import MultiSpeakerSampler
|
||||||
|
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
|
||||||
|
|
||||||
|
|
||||||
|
class Ge2eExperiment(ExperimentBase):
|
||||||
|
def setup_model(self):
|
||||||
|
config = self.config
|
||||||
|
model = LSTMSpeakerEncoder(config.data.n_mels, config.model.num_layers,
|
||||||
|
config.model.hidden_size,
|
||||||
|
config.model.embedding_size)
|
||||||
|
optimizer = Adam(
|
||||||
|
config.training.learning_rate_init,
|
||||||
|
parameters=model.parameters(),
|
||||||
|
grad_clip=ClipGradByGlobalNorm(3))
|
||||||
|
self.model = DataParallel(model) if self.parallel else model
|
||||||
|
self.model_core = model
|
||||||
|
self.optimizer = optimizer
|
||||||
|
|
||||||
|
def setup_dataloader(self):
|
||||||
|
config = self.config
|
||||||
|
train_dataset = MultiSpeakerMelDataset(self.args.data)
|
||||||
|
sampler = MultiSpeakerSampler(train_dataset,
|
||||||
|
config.training.speakers_per_batch,
|
||||||
|
config.training.utterances_per_speaker)
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=sampler,
|
||||||
|
collate_fn=Collate(config.data.partial_n_frames),
|
||||||
|
num_workers=16)
|
||||||
|
|
||||||
|
self.train_dataset = train_dataset
|
||||||
|
self.train_loader = train_loader
|
||||||
|
|
||||||
|
def train_batch(self):
|
||||||
|
start = time.time()
|
||||||
|
batch = self.read_batch()
|
||||||
|
data_loader_time = time.time() - start
|
||||||
|
|
||||||
|
self.optimizer.clear_grad()
|
||||||
|
self.model.train()
|
||||||
|
specs = batch
|
||||||
|
loss, eer = self.model(specs, self.config.training.speakers_per_batch)
|
||||||
|
loss.backward()
|
||||||
|
self.model_core.do_gradient_ops()
|
||||||
|
self.optimizer.step()
|
||||||
|
iteration_time = time.time() - start
|
||||||
|
|
||||||
|
# logging
|
||||||
|
loss_value = float(loss)
|
||||||
|
msg = "Rank: {}, ".format(dist.get_rank())
|
||||||
|
msg += "step: {}, ".format(self.iteration)
|
||||||
|
msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
|
||||||
|
iteration_time)
|
||||||
|
msg += 'loss: {:>.6f} err: {:>.6f}'.format(loss_value, eer)
|
||||||
|
self.logger.info(msg)
|
||||||
|
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
self.visualizer.add_scalar("train/loss", loss_value, self.iteration)
|
||||||
|
self.visualizer.add_scalar("train/eer", eer, self.iteration)
|
||||||
|
self.visualizer.add_scalar("param/w",
|
||||||
|
float(self.model_core.similarity_weight),
|
||||||
|
self.iteration)
|
||||||
|
self.visualizer.add_scalar("param/b",
|
||||||
|
float(self.model_core.similarity_bias),
|
||||||
|
self.iteration)
|
||||||
|
|
||||||
|
def valid(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def main_sp(config, args):
|
||||||
|
exp = Ge2eExperiment(config, args)
|
||||||
|
exp.setup()
|
||||||
|
exp.resume_or_load()
|
||||||
|
exp.run()
|
||||||
|
|
||||||
|
|
||||||
|
def main(config, args):
|
||||||
|
if args.ngpu > 1:
|
||||||
|
dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu)
|
||||||
|
else:
|
||||||
|
main_sp(config, args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
config = get_cfg_defaults()
|
||||||
|
parser = default_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.config:
|
||||||
|
config.merge_from_file(args.config)
|
||||||
|
if args.opts:
|
||||||
|
config.merge_from_list(args.opts)
|
||||||
|
config.freeze()
|
||||||
|
print(config)
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
main(config, args)
|
@ -0,0 +1,13 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
Loading…
Reference in new issue