add aishell3 voice cloning with ECAPA-TDNN spk encoder

pull/2359/head
TianYuan 2 years ago
parent ed16f96a9c
commit e622f42d92

@ -10,4 +10,5 @@
* voc3 - MultiBand MelGAN
* vc0 - Tacotron2 Voice Cloning with GE2E
* vc1 - FastSpeech2 Voice Cloning with GE2E
* vc2 - FastSpeech2 Voice Cloning with ECAPA-TDNN
* ernie_sat - ERNIE-SAT

@ -99,7 +99,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p
The synthesizing step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/../synthesize.py`.
### Voice Cloning
Assume there are some reference audios in `./ref_audio`
Assume there are some reference audios in `./ref_audio`
```text
ref_audio
├── 001238.wav
@ -116,7 +116,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_outpu
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
default|2(gpu) x 96400|0.99699|0.62013|0.53057|0.11954| 0.20426|
default|2(gpu) x 96400|0.99699|0.62013|0.053057|0.11954| 0.20426|
FastSpeech2 checkpoint contains files listed below.
(There is no need for `speaker_id_map.txt` here )

@ -0,0 +1,126 @@
# FastSpeech2 + AISHELL-3 Voice Cloning (ECAPA-TDNN)
This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2006.04558) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used in Voice Cloning Task, We refer to the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf). The general steps are as follows:
1. Speaker Encoder: We use Speaker Verification to train a speaker encoder. Datasets used in this task are different from those used in `FastSpeech2` because the transcriptions are not needed, we use more datasets, refer to [ECAPA-TDNN](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0).
2. Synthesizer: We use the trained speaker encoder to generate speaker embedding for each sentence in AISHELL-3. This embedding is an extra input of `FastSpeech2` which will be concated with encoder outputs.
3. Vocoder: We use [Parallel Wave GAN](http://arxiv.org/abs/1910.11480) as the neural Vocoder, refer to [voc1](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1).
## Dataset
### Download and Extract
Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`.
### Get MFA Result and Extract
We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2.
You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo.
## Get Started
Assume the path to the dataset is `~/datasets/data_aishell3`.
Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`.
Run the command below to
1. **source path**.
2. preprocess the dataset.
3. train the model.
4. synthesize waveform from `metadata.jsonl`.
5. start a voice cloning inference.
```bash
./run.sh
```
You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset.
```bash
./run.sh --stage 0 --stop-stage 0
```
### Data Preprocessing
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path}
```
When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below.
```text
dump
├── dev
│ ├── norm
│ └── raw
├── embed
│ ├── SSB0005
│ ├── SSB0009
│ ├── ...
│ └── ...
├── phone_id_map.txt
├── speaker_id_map.txt
├── test
│ ├── norm
│ └── raw
└── train
├── energy_stats.npy
├── norm
├── pitch_stats.npy
├── raw
└── speech_stats.npy
```
The `embed` contains the generated speaker embedding for each sentence in AISHELL-3, which has the same file structure with wav files and the format is `.npy`.
The computing time of utterance embedding can be x hours.
The dataset is split into 3 parts, namely `train`, `dev`, and` test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains speech、pitch and energy features of each utterance, while the norm folder contains normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/*_stats.npy`.
Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, speech_lengths, durations, the path of speech features, the path of pitch features, the path of energy features, speaker, and id of each utterance.
The preprocessing step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but there is one more `ECAPA-TDNN/inference` step here.
### Model Training
`./local/train.sh` calls `${BIN_DIR}/train.py`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path}
```
The training step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/train.py`.
### Synthesizing
We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the neural vocoder.
Download pretrained parallel wavegan model from [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip) and unzip it.
```bash
unzip pwg_aishell3_ckpt_0.5.zip
```
Parallel WaveGAN checkpoint contains files listed below.
```text
pwg_aishell3_ckpt_0.5
├── default.yaml # default config used to train parallel wavegan
├── feats_stats.npy # statistics used to normalize spectrogram when training parallel wavegan
└── snapshot_iter_1000000.pdz # generator parameters of parallel wavegan
```
`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`.
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name}
```
The synthesizing step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/../synthesize.py`.
### Voice Cloning
Assume there are some reference audios in `./ref_audio` (the format must be wav here)
```text
ref_audio
├── 001238.wav
├── LJ015-0254.wav
└── audio_self_test.wav
```
`./local/voice_cloning.sh` calls `${BIN_DIR}/../voice_cloning.py`
```bash
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ref_audio_dir}
```
## Pretrained Model
- [fastspeech2_aishell3_ckpt_vc2_1.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_vc2_1.2.0.zip)
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss
:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------:
default|2(gpu) x 96400|0.991855|0.599517|0.052142|0.094877| 0.245318|
FastSpeech2 checkpoint contains files listed below.
(There is no need for `speaker_id_map.txt` here )
```text
fastspeech2_aishell3_ckpt_vc2_1.2.0
├── default.yaml # default config used to train fastspeech2
├── energy_stats.npy # statistics used to normalize energy when training fastspeech2
├── phone_id_map.txt # phone vocabulary file when training fastspeech2
├── pitch_stats.npy # statistics used to normalize pitch when training fastspeech2
├── snapshot_iter_96400.pdz # model parameters and optimizer states
└── speech_stats.npy # statistics used to normalize spectrogram when training fastspeech2
```

@ -0,0 +1,104 @@
###########################################################
# FEATURE EXTRACTION SETTING #
###########################################################
fs: 24000 # sr
n_fft: 2048 # FFT size (samples).
n_shift: 300 # Hop size (samples). 12.5ms
win_length: 1200 # Window length (samples). 50ms
# If set to null, it will be the same as fft_size.
window: "hann" # Window function.
# Only used for feats_type != raw
fmin: 80 # Minimum frequency of Mel basis.
fmax: 7600 # Maximum frequency of Mel basis.
n_mels: 80 # The number of mel basis.
# Only used for the model using pitch features (e.g. FastSpeech2)
f0min: 80 # Minimum f0 for pitch extraction.
f0max: 400 # Maximum f0 for pitch extraction.
###########################################################
# DATA SETTING #
###########################################################
batch_size: 64
num_workers: 2
###########################################################
# MODEL SETTING #
###########################################################
model:
adim: 384 # attention dimension
aheads: 2 # number of attention heads
elayers: 4 # number of encoder layers
eunits: 1536 # number of encoder ff units
dlayers: 4 # number of decoder layers
dunits: 1536 # number of decoder ff units
positionwise_layer_type: conv1d # type of position-wise layer
positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer
duration_predictor_layers: 2 # number of layers of duration predictor
duration_predictor_chans: 256 # number of channels of duration predictor
duration_predictor_kernel_size: 3 # filter size of duration predictor
postnet_layers: 5 # number of layers of postnset
postnet_filts: 5 # filter size of conv layers in postnet
postnet_chans: 256 # number of channels of conv layers in postnet
use_scaled_pos_enc: True # whether to use scaled positional encoding
encoder_normalize_before: True # whether to perform layer normalization before the input
decoder_normalize_before: True # whether to perform layer normalization before the input
reduction_factor: 1 # reduction factor
init_type: xavier_uniform # initialization type
init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding
init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding
transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer
transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding
transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer
transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer
transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding
transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer
pitch_predictor_layers: 5 # number of conv layers in pitch predictor
pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor
pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor
pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor
pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch
pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch
stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder
energy_predictor_layers: 2 # number of conv layers in energy predictor
energy_predictor_chans: 256 # number of channels of conv layers in energy predictor
energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor
energy_predictor_dropout: 0.5 # dropout rate in energy predictor
energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy
energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy
stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder
spk_embed_dim: 192 # speaker embedding dimension
spk_embed_integration_type: concat # speaker embedding integration type
###########################################################
# UPDATER SETTING #
###########################################################
updater:
use_masking: True # whether to apply masking for padded part in loss calculation
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
optim: adam # optimizer type
learning_rate: 0.001 # learning rate
###########################################################
# TRAINING SETTING #
###########################################################
max_epoch: 200
num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086

@ -0,0 +1,85 @@
#!/bin/bash
stage=0
stop_stage=100
config_path=$1
# gen speaker embedding
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/vc2_infer.py \
--input=~/datasets/data_aishell3/train/wav/ \
--output=dump/embed \
--num-cpu=20
fi
# copy from tts3/preprocess
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# get durations from MFA's result
echo "Generate durations.txt from MFA results ..."
python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \
--inputdir=./aishell3_alignment_tone \
--output durations.txt \
--config=${config_path}
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# extract features
echo "Extract features ..."
python3 ${BIN_DIR}/preprocess.py \
--dataset=aishell3 \
--rootdir=~/datasets/data_aishell3/ \
--dumpdir=dump \
--dur-file=durations.txt \
--config=${config_path} \
--num-cpu=20 \
--cut-sil=True \
--spk_emb_dir=dump/embed
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# get features' stats(mean and std)
echo "Get features' stats ..."
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="speech"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="pitch"
python3 ${MAIN_ROOT}/utils/compute_statistics.py \
--metadata=dump/train/raw/metadata.jsonl \
--field-name="energy"
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# normalize and covert phone/speaker to id, dev and test should use train's stats
echo "Normalize ..."
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/train/raw/metadata.jsonl \
--dumpdir=dump/train/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/dev/raw/metadata.jsonl \
--dumpdir=dump/dev/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
python3 ${BIN_DIR}/normalize.py \
--metadata=dump/test/raw/metadata.jsonl \
--dumpdir=dump/test/norm \
--speech-stats=dump/train/speech_stats.npy \
--pitch-stats=dump/train/pitch_stats.npy \
--energy-stats=dump/train/energy_stats.npy \
--phones-dict=dump/phone_id_map.txt \
--speaker-dict=dump/speaker_id_map.txt
fi

@ -0,0 +1,22 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize.py \
--am=fastspeech2_aishell3 \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=pwgan_aishell3 \
--voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/test \
--phones_dict=dump/phone_id_map.txt \
--speaker_dict=dump/speaker_id_map.txt \
--voice-cloning=True

@ -0,0 +1,13 @@
#!/bin/bash
config_path=$1
train_output_path=$2
python3 ${BIN_DIR}/train.py \
--train-metadata=dump/train/norm/metadata.jsonl \
--dev-metadata=dump/dev/norm/metadata.jsonl \
--config=${config_path} \
--output-dir=${train_output_path} \
--ngpu=2 \
--phones-dict=dump/phone_id_map.txt \
--voice-cloning=True

@ -0,0 +1,23 @@
#!/bin/bash
config_path=$1
train_output_path=$2
ckpt_name=$3
ref_audio_dir=$4
FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../voice_cloning.py \
--am=fastspeech2_aishell3 \
--am_config=${config_path} \
--am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \
--am_stat=dump/train/speech_stats.npy \
--voc=pwgan_aishell3 \
--voc_config=pwg_aishell3_ckpt_0.5/default.yaml \
--voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \
--voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \
--text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \
--input-dir=${ref_audio_dir} \
--output-dir=${train_output_path}/vc_syn \
--phones-dict=dump/phone_id_map.txt \
--use_ecapa=True

@ -0,0 +1,13 @@
#!/bin/bash
export MAIN_ROOT=`realpath ${PWD}/../../../`
export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH}
export LC_ALL=C
export PYTHONDONTWRITEBYTECODE=1
# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
MODEL=fastspeech2
export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL}

@ -0,0 +1,39 @@
#!/bin/bash
set -e
source path.sh
gpus=0,1
stage=0
stop_stage=100
conf_path=conf/default.yaml
train_output_path=exp/default
ckpt_name=snapshot_iter_96400.pdz
ref_audio_dir=ref_audio
# with the following command, you can choose the stage range you want to run
# such as `./run.sh --stage 0 --stop-stage 0`
# this can not be mixed use with `$1`, `$2` ...
source ${MAIN_ROOT}/utils/parse_options.sh || exit 1
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# prepare data
CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} || exit -1
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
# train model, all `ckpt` under `train_output_path/checkpoints/` dir
CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1
fi
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
# synthesize, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1
fi
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# synthesize, vocoder is pwgan
CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ref_audio_dir} || exit -1
fi

@ -99,8 +99,9 @@ class ASRExecutor(BaseExecutor):
'-y',
action="store_true",
default=False,
help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate'
)
help='No additional parameters required. \
Once set this parameter, it means accepting the request of the program by default, \
which includes transforming the audio sample rate')
self.parser.add_argument(
'--rtf',
action="store_true",
@ -340,7 +341,7 @@ class ASRExecutor(BaseExecutor):
audio = np.round(audio).astype("int16")
return audio
def _check(self, audio_file: str, sample_rate: int, force_yes: bool):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False):
self.sample_rate = sample_rate
if self.sample_rate != 16000 and self.sample_rate != 8000:
logger.error(
@ -434,8 +435,17 @@ class ASRExecutor(BaseExecutor):
for id_, input_ in task_source.items():
try:
res = self(input_, model, lang, sample_rate, config, ckpt_path,
decode_method, force_yes, rtf, device)
res = self(
audio_file=input_,
model=model,
lang=lang,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
decode_method=decode_method,
force_yes=force_yes,
rtf=rtf,
device=device)
task_results[id_] = res
except Exception as e:
has_exceptions = True

@ -70,6 +70,14 @@ class VectorExecutor(BaseExecutor):
type=str,
default=None,
help="Checkpoint file of model.")
self.parser.add_argument(
'--yes',
'-y',
action="store_true",
default=False,
help='No additional parameters required. \
Once set this parameter, it means accepting the request of the program by default, \
which includes transforming the audio sample rate')
self.parser.add_argument(
'--config',
type=str,
@ -109,6 +117,7 @@ class VectorExecutor(BaseExecutor):
sample_rate = parser_args.sample_rate
config = parser_args.config
ckpt_path = parser_args.ckpt_path
force_yes = parser_args.yes
device = parser_args.device
# stage 1: configurate the verbose flag
@ -128,8 +137,14 @@ class VectorExecutor(BaseExecutor):
# extract the speaker audio embedding
if parser_args.task == "spk":
logger.debug("do vector spk task")
res = self(input_, model, sample_rate, config, ckpt_path,
device)
res = self(
audio_file=input_,
model=model,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
force_yes=force_yes,
device=device)
task_result[id_] = res
elif parser_args.task == "score":
logger.debug("do vector score task")
@ -145,10 +160,22 @@ class VectorExecutor(BaseExecutor):
logger.debug(
f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}"
)
enroll_embedding = self(enroll_audio, model, sample_rate,
config, ckpt_path, device)
test_embedding = self(test_audio, model, sample_rate,
config, ckpt_path, device)
enroll_embedding = self(
audio_file=enroll_audio,
model=model,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
force_yes=force_yes,
device=device)
test_embedding = self(
audio_file=test_audio,
model=model,
sample_rate=sample_rate,
config=config,
ckpt_path=ckpt_path,
force_yes=force_yes,
device=device)
# get the score
res = self.get_embeddings_score(enroll_embedding,
@ -222,6 +249,7 @@ class VectorExecutor(BaseExecutor):
sample_rate: int=16000,
config: os.PathLike=None,
ckpt_path: os.PathLike=None,
force_yes: bool=False,
device=paddle.get_device()):
"""Extract the audio embedding
@ -240,7 +268,7 @@ class VectorExecutor(BaseExecutor):
"""
# stage 0: check the audio format
audio_file = os.path.abspath(audio_file)
if not self._check(audio_file, sample_rate):
if not self._check(audio_file, sample_rate, force_yes):
sys.exit(-1)
# stage 1: set the paddle runtime host device
@ -418,7 +446,7 @@ class VectorExecutor(BaseExecutor):
logger.debug("audio extract the feat success")
def _check(self, audio_file: str, sample_rate: int):
def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False):
"""Check if the model sample match the audio sample rate
Args:
@ -462,13 +490,34 @@ class VectorExecutor(BaseExecutor):
logger.debug(f"The sample rate is {audio_sample_rate}")
if audio_sample_rate != self.sample_rate:
logger.error("The sample rate of the input file is not {}.\n \
logger.debug("The sample rate of the input file is not {}.\n \
The program will resample the wav file to {}.\n \
If the result does not meet your expectations\n \
Please input the 16k 16 bit 1 channel wav file. \
".format(self.sample_rate, self.sample_rate))
sys.exit(-1)
if force_yes is False:
while (True):
logger.debug(
"Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream."
)
content = input("Input(Y/N):")
if content.strip() == "Y" or content.strip(
) == "y" or content.strip() == "yes" or content.strip(
) == "Yes":
logger.debug(
"change the sampele rate, channel to 16k and 1 channel"
)
break
elif content.strip() == "N" or content.strip(
) == "n" or content.strip() == "no" or content.strip(
) == "No":
logger.debug("Exit the program")
return False
else:
logger.warning("Not regular input, please input again")
self.change_format = True
else:
logger.debug("The audio file format is right")
self.change_format = False
return True

@ -105,7 +105,8 @@ class PaddleVectorConnectionHandler:
# we can not reuse the cache io.BytesIO(audio) data,
# because the soundfile will change the io.BytesIO(audio) to the end
# thus we should convert the base64 string to io.BytesIO when we need the audio data
if not self.executor._check(io.BytesIO(audio), sample_rate):
if not self.executor._check(
io.BytesIO(audio), sample_rate, force_yes=True):
logger.debug("check the audio sample rate occurs error")
return np.array([0.0])

@ -0,0 +1,70 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import numpy as np
import tqdm
from paddlespeech.cli.vector import VectorExecutor
def _process_utterance(ifpath: Path,
input_dir: Path,
output_dir: Path,
vec_executor):
rel_path = ifpath.relative_to(input_dir)
ofpath = (output_dir / rel_path).with_suffix(".npy")
ofpath.parent.mkdir(parents=True, exist_ok=True)
embed = vec_executor(audio_file=ifpath, force_yes=True)
np.save(ofpath, embed)
return ofpath
def main(args):
# input output preparation
input_dir = Path(args.input).expanduser()
ifpaths = list(input_dir.rglob(args.pattern))
print(f"{len(ifpaths)} utterances in total")
output_dir = Path(args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True)
vec_executor = VectorExecutor()
nprocs = args.num_cpu
# warm up
vec_executor(audio_file=ifpaths[0], force_yes=True)
if nprocs == 1:
results = []
for ifpath in tqdm.tqdm(ifpaths, total=len(ifpaths)):
_process_utterance(
ifpath=ifpath,
input_dir=input_dir,
output_dir=output_dir,
vec_executor=vec_executor)
else:
with ThreadPoolExecutor(nprocs) as pool:
with tqdm.tqdm(total=len(ifpaths)) as progress:
for ifpath in ifpaths:
future = pool.submit(_process_utterance, ifpath, input_dir,
output_dir, vec_executor)
future.add_done_callback(lambda p: progress.update())
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="compute utterance embed.")
parser.add_argument(
"--input", type=str, help="path of the audio_file folder.")
parser.add_argument(
"--pattern",
type=str,
default="*.wav",
help="pattern to filter audio files.")
parser.add_argument(
"--output",
metavar="OUTPUT_DIR",
help="path to save spk embedding results.")
parser.add_argument(
"--num-cpu", type=int, default=1, help="number of process.")
args = parser.parse_args()
main(args)

@ -21,13 +21,28 @@ import soundfile as sf
import yaml
from yacs.config import CfgNode
from paddlespeech.cli.vector import VectorExecutor
from paddlespeech.t2s.exps.syn_utils import get_am_inference
from paddlespeech.t2s.exps.syn_utils import get_voc_inference
from paddlespeech.t2s.frontend.zh_frontend import Frontend
from paddlespeech.t2s.utils import str2bool
from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor
from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder
def gen_random_embed(use_ecapa: bool=False):
if use_ecapa:
# Randomly generate numbers of -25 ~ 25, 192 is the dim of spk_emb
random_spk_emb = (-1 + 2 * np.random.rand(192)) * 25
# GE2E
else:
# Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
random_spk_emb = np.random.rand(256) * 0.2
random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32')
return random_spk_emb
def voice_cloning(args):
# Init body.
with open(args.am_config) as f:
@ -41,30 +56,47 @@ def voice_cloning(args):
print(am_config)
print(voc_config)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
input_dir = Path(args.input_dir)
# speaker encoder
p = SpeakerVerificationPreprocessor(
sampling_rate=16000,
audio_norm_target_dBFS=-30,
vad_window_length=30,
vad_moving_average_width=8,
vad_max_silence_length=6,
mel_window_length=25,
mel_window_step=10,
n_mels=40,
partial_n_frames=160,
min_pad_coverage=0.75,
partial_overlap_ratio=0.5)
print("Audio Processor Done!")
speaker_encoder = LSTMSpeakerEncoder(
n_mels=40, num_layers=3, hidden_size=256, output_size=256)
speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path))
speaker_encoder.eval()
print("GE2E Done!")
if args.use_ecapa:
vec_executor = VectorExecutor()
# warm up
vec_executor(
audio_file=input_dir / os.listdir(input_dir)[0], force_yes=True)
print("ECAPA-TDNN Done!")
# use GE2E
else:
p = SpeakerVerificationPreprocessor(
sampling_rate=16000,
audio_norm_target_dBFS=-30,
vad_window_length=30,
vad_moving_average_width=8,
vad_max_silence_length=6,
mel_window_length=25,
mel_window_step=10,
n_mels=40,
partial_n_frames=160,
min_pad_coverage=0.75,
partial_overlap_ratio=0.5)
print("Audio Processor Done!")
speaker_encoder = LSTMSpeakerEncoder(
n_mels=40, num_layers=3, hidden_size=256, output_size=256)
speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path))
speaker_encoder.eval()
print("GE2E Done!")
frontend = Frontend(phone_vocab_path=args.phones_dict)
print("frontend done!")
sentence = args.text
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"][0]
# acoustic model
am_inference = get_am_inference(
am=args.am,
@ -80,26 +112,19 @@ def voice_cloning(args):
voc_ckpt=args.voc_ckpt,
voc_stat=args.voc_stat)
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
input_dir = Path(args.input_dir)
sentence = args.text
input_ids = frontend.get_input_ids(sentence, merge_sentences=True)
phone_ids = input_ids["phone_ids"][0]
for name in os.listdir(input_dir):
utt_id = name.split(".")[0]
ref_audio_path = input_dir / name
mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path))
# print("mel_sequences: ", mel_sequences.shape)
with paddle.no_grad():
spk_emb = speaker_encoder.embed_utterance(
paddle.to_tensor(mel_sequences))
# print("spk_emb shape: ", spk_emb.shape)
if args.use_ecapa:
spk_emb = vec_executor(audio_file=ref_audio_path, force_yes=True)
spk_emb = paddle.to_tensor(spk_emb)
# GE2E
else:
mel_sequences = p.extract_mel_partials(
p.preprocess_wav(ref_audio_path))
with paddle.no_grad():
spk_emb = speaker_encoder.embed_utterance(
paddle.to_tensor(mel_sequences))
with paddle.no_grad():
wav = voc_inference(am_inference(phone_ids, spk_emb=spk_emb))
@ -108,16 +133,17 @@ def voice_cloning(args):
wav.numpy(),
samplerate=am_config.fs)
print(f"{utt_id} done!")
# Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb
random_spk_emb = np.random.rand(256) * 0.2
random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32')
utt_id = "random_spk_emb"
with paddle.no_grad():
wav = voc_inference(am_inference(phone_ids, spk_emb=random_spk_emb))
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
samplerate=am_config.fs)
# generate 5 random_spk_emb
for i in range(5):
random_spk_emb = gen_random_embed(args.use_ecapa)
utt_id = "random_spk_emb"
with paddle.no_grad():
wav = voc_inference(am_inference(phone_ids, spk_emb=random_spk_emb))
sf.write(
str(output_dir / (utt_id + "_" + str(i) + ".wav")),
wav.numpy(),
samplerate=am_config.fs)
print(f"{utt_id} done!")
@ -171,13 +197,15 @@ def parse_args():
type=str,
default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。",
help="text to synthesize, a line")
parser.add_argument(
"--ge2e_params_path", type=str, help="ge2e params path.")
parser.add_argument(
"--use_ecapa",
type=str2bool,
default=False,
help="whether to use ECAPA-TDNN as speaker encoder.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu=0, use cpu.")
parser.add_argument(
"--input-dir",
type=str,

Loading…
Cancel
Save