diff --git a/examples/aishell3/README.md b/examples/aishell3/README.md index e022cef4..dd09bdfb 100644 --- a/examples/aishell3/README.md +++ b/examples/aishell3/README.md @@ -10,4 +10,5 @@ * voc3 - MultiBand MelGAN * vc0 - Tacotron2 Voice Cloning with GE2E * vc1 - FastSpeech2 Voice Cloning with GE2E +* vc2 - FastSpeech2 Voice Cloning with ECAPA-TDNN * ernie_sat - ERNIE-SAT diff --git a/examples/aishell3/vc1/README.md b/examples/aishell3/vc1/README.md index aab52510..93e0fd7e 100644 --- a/examples/aishell3/vc1/README.md +++ b/examples/aishell3/vc1/README.md @@ -99,7 +99,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_p The synthesizing step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/../synthesize.py`. ### Voice Cloning -Assume there are some reference audios in `./ref_audio` +Assume there are some reference audios in `./ref_audio` ```text ref_audio ├── 001238.wav @@ -116,7 +116,7 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_outpu Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss :-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: -default|2(gpu) x 96400|0.99699|0.62013|0.53057|0.11954| 0.20426| +default|2(gpu) x 96400|0.99699|0.62013|0.053057|0.11954| 0.20426| FastSpeech2 checkpoint contains files listed below. (There is no need for `speaker_id_map.txt` here ) diff --git a/examples/aishell3/vc2/README.md b/examples/aishell3/vc2/README.md new file mode 100644 index 00000000..77482367 --- /dev/null +++ b/examples/aishell3/vc2/README.md @@ -0,0 +1,126 @@ +# FastSpeech2 + AISHELL-3 Voice Cloning (ECAPA-TDNN) +This example contains code used to train a [FastSpeech2](https://arxiv.org/abs/2006.04558) model with [AISHELL-3](http://www.aishelltech.com/aishell_3). The trained model can be used in Voice Cloning Task, We refer to the model structure of [Transfer Learning from Speaker Verification to Multispeaker Text-To-Speech Synthesis](https://arxiv.org/pdf/1806.04558.pdf). The general steps are as follows: +1. Speaker Encoder: We use Speaker Verification to train a speaker encoder. Datasets used in this task are different from those used in `FastSpeech2` because the transcriptions are not needed, we use more datasets, refer to [ECAPA-TDNN](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/voxceleb/sv0). +2. Synthesizer: We use the trained speaker encoder to generate speaker embedding for each sentence in AISHELL-3. This embedding is an extra input of `FastSpeech2` which will be concated with encoder outputs. +3. Vocoder: We use [Parallel Wave GAN](http://arxiv.org/abs/1910.11480) as the neural Vocoder, refer to [voc1](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1). + +## Dataset +### Download and Extract +Download AISHELL-3 from it's [Official Website](http://www.aishelltech.com/aishell_3) and extract it to `~/datasets`. Then the dataset is in the directory `~/datasets/data_aishell3`. + +### Get MFA Result and Extract +We use [MFA2.x](https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner) to get durations for aishell3_fastspeech2. +You can download from here [aishell3_alignment_tone.tar.gz](https://paddlespeech.bj.bcebos.com/MFA/AISHELL-3/with_tone/aishell3_alignment_tone.tar.gz), or train your MFA model reference to [mfa example](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/other/mfa) (use MFA1.x now) of our repo. + +## Get Started +Assume the path to the dataset is `~/datasets/data_aishell3`. +Assume the path to the MFA result of AISHELL-3 is `./aishell3_alignment_tone`. + +Run the command below to +1. **source path**. +2. preprocess the dataset. +3. train the model. +4. synthesize waveform from `metadata.jsonl`. +5. start a voice cloning inference. +```bash +./run.sh +``` +You can choose a range of stages you want to run, or set `stage` equal to `stop-stage` to use only one stage, for example, running the following command will only preprocess the dataset. +```bash +./run.sh --stage 0 --stop-stage 0 +``` +### Data Preprocessing +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} +``` +When it is done. A `dump` folder is created in the current directory. The structure of the dump folder is listed below. +```text +dump +├── dev +│ ├── norm +│ └── raw +├── embed +│ ├── SSB0005 +│ ├── SSB0009 +│ ├── ... +│ └── ... +├── phone_id_map.txt +├── speaker_id_map.txt +├── test +│ ├── norm +│ └── raw +└── train + ├── energy_stats.npy + ├── norm + ├── pitch_stats.npy + ├── raw + └── speech_stats.npy +``` +The `embed` contains the generated speaker embedding for each sentence in AISHELL-3, which has the same file structure with wav files and the format is `.npy`. + +The computing time of utterance embedding can be x hours. + +The dataset is split into 3 parts, namely `train`, `dev`, and` test`, each of which contains a `norm` and `raw` subfolder. The raw folder contains speech、pitch and energy features of each utterance, while the norm folder contains normalized ones. The statistics used to normalize features are computed from the training set, which is located in `dump/train/*_stats.npy`. + +Also, there is a `metadata.jsonl` in each subfolder. It is a table-like file that contains phones, text_lengths, speech_lengths, durations, the path of speech features, the path of pitch features, the path of energy features, speaker, and id of each utterance. + +The preprocessing step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but there is one more `ECAPA-TDNN/inference` step here. + +### Model Training +`./local/train.sh` calls `${BIN_DIR}/train.py`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} +``` +The training step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/train.py`. + +### Synthesizing +We use [parallel wavegan](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/voc1) as the neural vocoder. +Download pretrained parallel wavegan model from [pwg_aishell3_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/pwgan/pwg_aishell3_ckpt_0.5.zip) and unzip it. +```bash +unzip pwg_aishell3_ckpt_0.5.zip +``` +Parallel WaveGAN checkpoint contains files listed below. +```text +pwg_aishell3_ckpt_0.5 +├── default.yaml # default config used to train parallel wavegan +├── feats_stats.npy # statistics used to normalize spectrogram when training parallel wavegan +└── snapshot_iter_1000000.pdz # generator parameters of parallel wavegan +``` +`./local/synthesize.sh` calls `${BIN_DIR}/../synthesize.py`, which can synthesize waveform from `metadata.jsonl`. +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} +``` +The synthesizing step is very similar to that one of [tts3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3), but we should set `--voice-cloning=True` when calling `${BIN_DIR}/../synthesize.py`. + +### Voice Cloning +Assume there are some reference audios in `./ref_audio` (the format must be wav here) +```text +ref_audio +├── 001238.wav +├── LJ015-0254.wav +└── audio_self_test.wav +``` +`./local/voice_cloning.sh` calls `${BIN_DIR}/../voice_cloning.py` + +```bash +CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ref_audio_dir} +``` +## Pretrained Model +- [fastspeech2_aishell3_ckpt_vc2_1.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_aishell3_ckpt_vc2_1.2.0.zip) + +Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/pitch_loss| eval/energy_loss +:-------------:| :------------:| :-----: | :-----: | :--------: |:--------:|:---------: +default|2(gpu) x 96400|0.991855|0.599517|0.052142|0.094877| 0.245318| + +FastSpeech2 checkpoint contains files listed below. +(There is no need for `speaker_id_map.txt` here ) + +```text +fastspeech2_aishell3_ckpt_vc2_1.2.0 +├── default.yaml # default config used to train fastspeech2 +├── energy_stats.npy # statistics used to normalize energy when training fastspeech2 +├── phone_id_map.txt # phone vocabulary file when training fastspeech2 +├── pitch_stats.npy # statistics used to normalize pitch when training fastspeech2 +├── snapshot_iter_96400.pdz # model parameters and optimizer states +└── speech_stats.npy # statistics used to normalize spectrogram when training fastspeech2 +``` diff --git a/examples/aishell3/vc2/conf/default.yaml b/examples/aishell3/vc2/conf/default.yaml new file mode 100644 index 00000000..5ef37f81 --- /dev/null +++ b/examples/aishell3/vc2/conf/default.yaml @@ -0,0 +1,104 @@ +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### + +fs: 24000 # sr +n_fft: 2048 # FFT size (samples). +n_shift: 300 # Hop size (samples). 12.5ms +win_length: 1200 # Window length (samples). 50ms + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. + +# Only used for feats_type != raw + +fmin: 80 # Minimum frequency of Mel basis. +fmax: 7600 # Maximum frequency of Mel basis. +n_mels: 80 # The number of mel basis. + +# Only used for the model using pitch features (e.g. FastSpeech2) +f0min: 80 # Minimum f0 for pitch extraction. +f0max: 400 # Maximum f0 for pitch extraction. + + +########################################################### +# DATA SETTING # +########################################################### +batch_size: 64 +num_workers: 2 + + +########################################################### +# MODEL SETTING # +########################################################### +model: + adim: 384 # attention dimension + aheads: 2 # number of attention heads + elayers: 4 # number of encoder layers + eunits: 1536 # number of encoder ff units + dlayers: 4 # number of decoder layers + dunits: 1536 # number of decoder ff units + positionwise_layer_type: conv1d # type of position-wise layer + positionwise_conv_kernel_size: 3 # kernel size of position wise conv layer + duration_predictor_layers: 2 # number of layers of duration predictor + duration_predictor_chans: 256 # number of channels of duration predictor + duration_predictor_kernel_size: 3 # filter size of duration predictor + postnet_layers: 5 # number of layers of postnset + postnet_filts: 5 # filter size of conv layers in postnet + postnet_chans: 256 # number of channels of conv layers in postnet + use_scaled_pos_enc: True # whether to use scaled positional encoding + encoder_normalize_before: True # whether to perform layer normalization before the input + decoder_normalize_before: True # whether to perform layer normalization before the input + reduction_factor: 1 # reduction factor + init_type: xavier_uniform # initialization type + init_enc_alpha: 1.0 # initial value of alpha of encoder scaled position encoding + init_dec_alpha: 1.0 # initial value of alpha of decoder scaled position encoding + transformer_enc_dropout_rate: 0.2 # dropout rate for transformer encoder layer + transformer_enc_positional_dropout_rate: 0.2 # dropout rate for transformer encoder positional encoding + transformer_enc_attn_dropout_rate: 0.2 # dropout rate for transformer encoder attention layer + transformer_dec_dropout_rate: 0.2 # dropout rate for transformer decoder layer + transformer_dec_positional_dropout_rate: 0.2 # dropout rate for transformer decoder positional encoding + transformer_dec_attn_dropout_rate: 0.2 # dropout rate for transformer decoder attention layer + pitch_predictor_layers: 5 # number of conv layers in pitch predictor + pitch_predictor_chans: 256 # number of channels of conv layers in pitch predictor + pitch_predictor_kernel_size: 5 # kernel size of conv leyers in pitch predictor + pitch_predictor_dropout: 0.5 # dropout rate in pitch predictor + pitch_embed_kernel_size: 1 # kernel size of conv embedding layer for pitch + pitch_embed_dropout: 0.0 # dropout rate after conv embedding layer for pitch + stop_gradient_from_pitch_predictor: True # whether to stop the gradient from pitch predictor to encoder + energy_predictor_layers: 2 # number of conv layers in energy predictor + energy_predictor_chans: 256 # number of channels of conv layers in energy predictor + energy_predictor_kernel_size: 3 # kernel size of conv leyers in energy predictor + energy_predictor_dropout: 0.5 # dropout rate in energy predictor + energy_embed_kernel_size: 1 # kernel size of conv embedding layer for energy + energy_embed_dropout: 0.0 # dropout rate after conv embedding layer for energy + stop_gradient_from_energy_predictor: False # whether to stop the gradient from energy predictor to encoder + spk_embed_dim: 192 # speaker embedding dimension + spk_embed_integration_type: concat # speaker embedding integration type + + + +########################################################### +# UPDATER SETTING # +########################################################### +updater: + use_masking: True # whether to apply masking for padded part in loss calculation + + +########################################################### +# OPTIMIZER SETTING # +########################################################### +optimizer: + optim: adam # optimizer type + learning_rate: 0.001 # learning rate + +########################################################### +# TRAINING SETTING # +########################################################### +max_epoch: 200 +num_snapshots: 5 + + +########################################################### +# OTHER SETTING # +########################################################### +seed: 10086 diff --git a/examples/aishell3/vc2/local/preprocess.sh b/examples/aishell3/vc2/local/preprocess.sh new file mode 100755 index 00000000..f5262a26 --- /dev/null +++ b/examples/aishell3/vc2/local/preprocess.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +stage=0 +stop_stage=100 + +config_path=$1 + +# gen speaker embedding +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/vc2_infer.py \ + --input=~/datasets/data_aishell3/train/wav/ \ + --output=dump/embed \ + --num-cpu=20 +fi + +# copy from tts3/preprocess +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # get durations from MFA's result + echo "Generate durations.txt from MFA results ..." + python3 ${MAIN_ROOT}/utils/gen_duration_from_textgrid.py \ + --inputdir=./aishell3_alignment_tone \ + --output durations.txt \ + --config=${config_path} +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # extract features + echo "Extract features ..." + python3 ${BIN_DIR}/preprocess.py \ + --dataset=aishell3 \ + --rootdir=~/datasets/data_aishell3/ \ + --dumpdir=dump \ + --dur-file=durations.txt \ + --config=${config_path} \ + --num-cpu=20 \ + --cut-sil=True \ + --spk_emb_dir=dump/embed +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # get features' stats(mean and std) + echo "Get features' stats ..." + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="speech" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="pitch" + + python3 ${MAIN_ROOT}/utils/compute_statistics.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --field-name="energy" +fi + +if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then + # normalize and covert phone/speaker to id, dev and test should use train's stats + echo "Normalize ..." + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/train/raw/metadata.jsonl \ + --dumpdir=dump/train/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/dev/raw/metadata.jsonl \ + --dumpdir=dump/dev/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt + + python3 ${BIN_DIR}/normalize.py \ + --metadata=dump/test/raw/metadata.jsonl \ + --dumpdir=dump/test/norm \ + --speech-stats=dump/train/speech_stats.npy \ + --pitch-stats=dump/train/pitch_stats.npy \ + --energy-stats=dump/train/energy_stats.npy \ + --phones-dict=dump/phone_id_map.txt \ + --speaker-dict=dump/speaker_id_map.txt +fi diff --git a/examples/aishell3/vc2/local/synthesize.sh b/examples/aishell3/vc2/local/synthesize.sh new file mode 100755 index 00000000..8c61e3f3 --- /dev/null +++ b/examples/aishell3/vc2/local/synthesize.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/../synthesize.py \ + --am=fastspeech2_aishell3 \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=pwgan_aishell3 \ + --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/test \ + --phones_dict=dump/phone_id_map.txt \ + --speaker_dict=dump/speaker_id_map.txt \ + --voice-cloning=True diff --git a/examples/aishell3/vc2/local/train.sh b/examples/aishell3/vc2/local/train.sh new file mode 100755 index 00000000..c775fcad --- /dev/null +++ b/examples/aishell3/vc2/local/train.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 + +python3 ${BIN_DIR}/train.py \ + --train-metadata=dump/train/norm/metadata.jsonl \ + --dev-metadata=dump/dev/norm/metadata.jsonl \ + --config=${config_path} \ + --output-dir=${train_output_path} \ + --ngpu=2 \ + --phones-dict=dump/phone_id_map.txt \ + --voice-cloning=True \ No newline at end of file diff --git a/examples/aishell3/vc2/local/voice_cloning.sh b/examples/aishell3/vc2/local/voice_cloning.sh new file mode 100755 index 00000000..09c5e436 --- /dev/null +++ b/examples/aishell3/vc2/local/voice_cloning.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +config_path=$1 +train_output_path=$2 +ckpt_name=$3 +ref_audio_dir=$4 + +FLAGS_allocator_strategy=naive_best_fit \ +FLAGS_fraction_of_gpu_memory_to_use=0.01 \ +python3 ${BIN_DIR}/../voice_cloning.py \ + --am=fastspeech2_aishell3 \ + --am_config=${config_path} \ + --am_ckpt=${train_output_path}/checkpoints/${ckpt_name} \ + --am_stat=dump/train/speech_stats.npy \ + --voc=pwgan_aishell3 \ + --voc_config=pwg_aishell3_ckpt_0.5/default.yaml \ + --voc_ckpt=pwg_aishell3_ckpt_0.5/snapshot_iter_1000000.pdz \ + --voc_stat=pwg_aishell3_ckpt_0.5/feats_stats.npy \ + --text="凯莫瑞安联合体的经济崩溃迫在眉睫。" \ + --input-dir=${ref_audio_dir} \ + --output-dir=${train_output_path}/vc_syn \ + --phones-dict=dump/phone_id_map.txt \ + --use_ecapa=True diff --git a/examples/aishell3/vc2/path.sh b/examples/aishell3/vc2/path.sh new file mode 100755 index 00000000..fb7e8411 --- /dev/null +++ b/examples/aishell3/vc2/path.sh @@ -0,0 +1,13 @@ +#!/bin/bash +export MAIN_ROOT=`realpath ${PWD}/../../../` + +export PATH=${MAIN_ROOT}:${MAIN_ROOT}/utils:${PATH} +export LC_ALL=C + +export PYTHONDONTWRITEBYTECODE=1 +# Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C +export PYTHONIOENCODING=UTF-8 +export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} + +MODEL=fastspeech2 +export BIN_DIR=${MAIN_ROOT}/paddlespeech/t2s/exps/${MODEL} diff --git a/examples/aishell3/vc2/run.sh b/examples/aishell3/vc2/run.sh new file mode 100755 index 00000000..06d56298 --- /dev/null +++ b/examples/aishell3/vc2/run.sh @@ -0,0 +1,39 @@ +#!/bin/bash + +set -e +source path.sh + +gpus=0,1 +stage=0 +stop_stage=100 + +conf_path=conf/default.yaml +train_output_path=exp/default +ckpt_name=snapshot_iter_96400.pdz +ref_audio_dir=ref_audio + + +# with the following command, you can choose the stage range you want to run +# such as `./run.sh --stage 0 --stop-stage 0` +# this can not be mixed use with `$1`, `$2` ... +source ${MAIN_ROOT}/utils/parse_options.sh || exit 1 + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + # prepare data + CUDA_VISIBLE_DEVICES=${gpus} ./local/preprocess.sh ${conf_path} || exit -1 +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + # train model, all `ckpt` under `train_output_path/checkpoints/` dir + CUDA_VISIBLE_DEVICES=${gpus} ./local/train.sh ${conf_path} ${train_output_path} || exit -1 +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/synthesize.sh ${conf_path} ${train_output_path} ${ckpt_name} || exit -1 +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + # synthesize, vocoder is pwgan + CUDA_VISIBLE_DEVICES=${gpus} ./local/voice_cloning.sh ${conf_path} ${train_output_path} ${ckpt_name} ${ref_audio_dir} || exit -1 +fi diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index f9b4439e..7296776f 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -99,8 +99,9 @@ class ASRExecutor(BaseExecutor): '-y', action="store_true", default=False, - help='No additional parameters required. Once set this parameter, it means accepting the request of the program by default, which includes transforming the audio sample rate' - ) + help='No additional parameters required. \ + Once set this parameter, it means accepting the request of the program by default, \ + which includes transforming the audio sample rate') self.parser.add_argument( '--rtf', action="store_true", @@ -340,7 +341,7 @@ class ASRExecutor(BaseExecutor): audio = np.round(audio).astype("int16") return audio - def _check(self, audio_file: str, sample_rate: int, force_yes: bool): + def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): self.sample_rate = sample_rate if self.sample_rate != 16000 and self.sample_rate != 8000: logger.error( @@ -434,8 +435,17 @@ class ASRExecutor(BaseExecutor): for id_, input_ in task_source.items(): try: - res = self(input_, model, lang, sample_rate, config, ckpt_path, - decode_method, force_yes, rtf, device) + res = self( + audio_file=input_, + model=model, + lang=lang, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + decode_method=decode_method, + force_yes=force_yes, + rtf=rtf, + device=device) task_results[id_] = res except Exception as e: has_exceptions = True diff --git a/paddlespeech/cli/vector/infer.py b/paddlespeech/cli/vector/infer.py index 48ca1f98..11198724 100644 --- a/paddlespeech/cli/vector/infer.py +++ b/paddlespeech/cli/vector/infer.py @@ -70,6 +70,14 @@ class VectorExecutor(BaseExecutor): type=str, default=None, help="Checkpoint file of model.") + self.parser.add_argument( + '--yes', + '-y', + action="store_true", + default=False, + help='No additional parameters required. \ + Once set this parameter, it means accepting the request of the program by default, \ + which includes transforming the audio sample rate') self.parser.add_argument( '--config', type=str, @@ -109,6 +117,7 @@ class VectorExecutor(BaseExecutor): sample_rate = parser_args.sample_rate config = parser_args.config ckpt_path = parser_args.ckpt_path + force_yes = parser_args.yes device = parser_args.device # stage 1: configurate the verbose flag @@ -128,8 +137,14 @@ class VectorExecutor(BaseExecutor): # extract the speaker audio embedding if parser_args.task == "spk": logger.debug("do vector spk task") - res = self(input_, model, sample_rate, config, ckpt_path, - device) + res = self( + audio_file=input_, + model=model, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + force_yes=force_yes, + device=device) task_result[id_] = res elif parser_args.task == "score": logger.debug("do vector score task") @@ -145,10 +160,22 @@ class VectorExecutor(BaseExecutor): logger.debug( f"score task, enroll audio: {enroll_audio}, test audio: {test_audio}" ) - enroll_embedding = self(enroll_audio, model, sample_rate, - config, ckpt_path, device) - test_embedding = self(test_audio, model, sample_rate, - config, ckpt_path, device) + enroll_embedding = self( + audio_file=enroll_audio, + model=model, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + force_yes=force_yes, + device=device) + test_embedding = self( + audio_file=test_audio, + model=model, + sample_rate=sample_rate, + config=config, + ckpt_path=ckpt_path, + force_yes=force_yes, + device=device) # get the score res = self.get_embeddings_score(enroll_embedding, @@ -222,6 +249,7 @@ class VectorExecutor(BaseExecutor): sample_rate: int=16000, config: os.PathLike=None, ckpt_path: os.PathLike=None, + force_yes: bool=False, device=paddle.get_device()): """Extract the audio embedding @@ -240,7 +268,7 @@ class VectorExecutor(BaseExecutor): """ # stage 0: check the audio format audio_file = os.path.abspath(audio_file) - if not self._check(audio_file, sample_rate): + if not self._check(audio_file, sample_rate, force_yes): sys.exit(-1) # stage 1: set the paddle runtime host device @@ -418,7 +446,7 @@ class VectorExecutor(BaseExecutor): logger.debug("audio extract the feat success") - def _check(self, audio_file: str, sample_rate: int): + def _check(self, audio_file: str, sample_rate: int, force_yes: bool=False): """Check if the model sample match the audio sample rate Args: @@ -462,13 +490,34 @@ class VectorExecutor(BaseExecutor): logger.debug(f"The sample rate is {audio_sample_rate}") if audio_sample_rate != self.sample_rate: - logger.error("The sample rate of the input file is not {}.\n \ + logger.debug("The sample rate of the input file is not {}.\n \ The program will resample the wav file to {}.\n \ If the result does not meet your expectations,\n \ Please input the 16k 16 bit 1 channel wav file. \ ".format(self.sample_rate, self.sample_rate)) - sys.exit(-1) + if force_yes is False: + while (True): + logger.debug( + "Whether to change the sample rate and the channel. Y: change the sample. N: exit the prgream." + ) + content = input("Input(Y/N):") + if content.strip() == "Y" or content.strip( + ) == "y" or content.strip() == "yes" or content.strip( + ) == "Yes": + logger.debug( + "change the sampele rate, channel to 16k and 1 channel" + ) + break + elif content.strip() == "N" or content.strip( + ) == "n" or content.strip() == "no" or content.strip( + ) == "No": + logger.debug("Exit the program") + return False + else: + logger.warning("Not regular input, please input again") + self.change_format = True else: logger.debug("The audio file format is right") + self.change_format = False return True diff --git a/paddlespeech/server/engine/vector/python/vector_engine.py b/paddlespeech/server/engine/vector/python/vector_engine.py index f7d60648..7b8f667d 100644 --- a/paddlespeech/server/engine/vector/python/vector_engine.py +++ b/paddlespeech/server/engine/vector/python/vector_engine.py @@ -105,7 +105,8 @@ class PaddleVectorConnectionHandler: # we can not reuse the cache io.BytesIO(audio) data, # because the soundfile will change the io.BytesIO(audio) to the end # thus we should convert the base64 string to io.BytesIO when we need the audio data - if not self.executor._check(io.BytesIO(audio), sample_rate): + if not self.executor._check( + io.BytesIO(audio), sample_rate, force_yes=True): logger.debug("check the audio sample rate occurs error") return np.array([0.0]) diff --git a/paddlespeech/t2s/exps/fastspeech2/vc2_infer.py b/paddlespeech/t2s/exps/fastspeech2/vc2_infer.py new file mode 100644 index 00000000..3d0a8366 --- /dev/null +++ b/paddlespeech/t2s/exps/fastspeech2/vc2_infer.py @@ -0,0 +1,70 @@ +import argparse +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import numpy as np +import tqdm + +from paddlespeech.cli.vector import VectorExecutor + + +def _process_utterance(ifpath: Path, + input_dir: Path, + output_dir: Path, + vec_executor): + rel_path = ifpath.relative_to(input_dir) + ofpath = (output_dir / rel_path).with_suffix(".npy") + ofpath.parent.mkdir(parents=True, exist_ok=True) + embed = vec_executor(audio_file=ifpath, force_yes=True) + np.save(ofpath, embed) + return ofpath + + +def main(args): + # input output preparation + input_dir = Path(args.input).expanduser() + ifpaths = list(input_dir.rglob(args.pattern)) + print(f"{len(ifpaths)} utterances in total") + output_dir = Path(args.output).expanduser() + output_dir.mkdir(parents=True, exist_ok=True) + vec_executor = VectorExecutor() + nprocs = args.num_cpu + + # warm up + vec_executor(audio_file=ifpaths[0], force_yes=True) + + if nprocs == 1: + results = [] + for ifpath in tqdm.tqdm(ifpaths, total=len(ifpaths)): + _process_utterance( + ifpath=ifpath, + input_dir=input_dir, + output_dir=output_dir, + vec_executor=vec_executor) + else: + with ThreadPoolExecutor(nprocs) as pool: + with tqdm.tqdm(total=len(ifpaths)) as progress: + for ifpath in ifpaths: + future = pool.submit(_process_utterance, ifpath, input_dir, + output_dir, vec_executor) + future.add_done_callback(lambda p: progress.update()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="compute utterance embed.") + parser.add_argument( + "--input", type=str, help="path of the audio_file folder.") + parser.add_argument( + "--pattern", + type=str, + default="*.wav", + help="pattern to filter audio files.") + parser.add_argument( + "--output", + metavar="OUTPUT_DIR", + help="path to save spk embedding results.") + parser.add_argument( + "--num-cpu", type=int, default=1, help="number of process.") + args = parser.parse_args() + + main(args) diff --git a/paddlespeech/t2s/exps/voice_cloning.py b/paddlespeech/t2s/exps/voice_cloning.py index b51a4d7b..80cfea4a 100644 --- a/paddlespeech/t2s/exps/voice_cloning.py +++ b/paddlespeech/t2s/exps/voice_cloning.py @@ -21,13 +21,28 @@ import soundfile as sf import yaml from yacs.config import CfgNode +from paddlespeech.cli.vector import VectorExecutor from paddlespeech.t2s.exps.syn_utils import get_am_inference from paddlespeech.t2s.exps.syn_utils import get_voc_inference from paddlespeech.t2s.frontend.zh_frontend import Frontend +from paddlespeech.t2s.utils import str2bool from paddlespeech.vector.exps.ge2e.audio_processor import SpeakerVerificationPreprocessor from paddlespeech.vector.models.lstm_speaker_encoder import LSTMSpeakerEncoder +def gen_random_embed(use_ecapa: bool=False): + if use_ecapa: + # Randomly generate numbers of -25 ~ 25, 192 is the dim of spk_emb + random_spk_emb = (-1 + 2 * np.random.rand(192)) * 25 + + # GE2E + else: + # Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb + random_spk_emb = np.random.rand(256) * 0.2 + random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32') + return random_spk_emb + + def voice_cloning(args): # Init body. with open(args.am_config) as f: @@ -41,30 +56,47 @@ def voice_cloning(args): print(am_config) print(voc_config) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + input_dir = Path(args.input_dir) + # speaker encoder - p = SpeakerVerificationPreprocessor( - sampling_rate=16000, - audio_norm_target_dBFS=-30, - vad_window_length=30, - vad_moving_average_width=8, - vad_max_silence_length=6, - mel_window_length=25, - mel_window_step=10, - n_mels=40, - partial_n_frames=160, - min_pad_coverage=0.75, - partial_overlap_ratio=0.5) - print("Audio Processor Done!") - - speaker_encoder = LSTMSpeakerEncoder( - n_mels=40, num_layers=3, hidden_size=256, output_size=256) - speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path)) - speaker_encoder.eval() - print("GE2E Done!") + if args.use_ecapa: + vec_executor = VectorExecutor() + # warm up + vec_executor( + audio_file=input_dir / os.listdir(input_dir)[0], force_yes=True) + print("ECAPA-TDNN Done!") + # use GE2E + else: + p = SpeakerVerificationPreprocessor( + sampling_rate=16000, + audio_norm_target_dBFS=-30, + vad_window_length=30, + vad_moving_average_width=8, + vad_max_silence_length=6, + mel_window_length=25, + mel_window_step=10, + n_mels=40, + partial_n_frames=160, + min_pad_coverage=0.75, + partial_overlap_ratio=0.5) + print("Audio Processor Done!") + + speaker_encoder = LSTMSpeakerEncoder( + n_mels=40, num_layers=3, hidden_size=256, output_size=256) + speaker_encoder.set_state_dict(paddle.load(args.ge2e_params_path)) + speaker_encoder.eval() + print("GE2E Done!") frontend = Frontend(phone_vocab_path=args.phones_dict) print("frontend done!") + sentence = args.text + input_ids = frontend.get_input_ids(sentence, merge_sentences=True) + phone_ids = input_ids["phone_ids"][0] + # acoustic model am_inference = get_am_inference( am=args.am, @@ -80,26 +112,19 @@ def voice_cloning(args): voc_ckpt=args.voc_ckpt, voc_stat=args.voc_stat) - output_dir = Path(args.output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - input_dir = Path(args.input_dir) - - sentence = args.text - - input_ids = frontend.get_input_ids(sentence, merge_sentences=True) - phone_ids = input_ids["phone_ids"][0] - for name in os.listdir(input_dir): utt_id = name.split(".")[0] ref_audio_path = input_dir / name - mel_sequences = p.extract_mel_partials(p.preprocess_wav(ref_audio_path)) - # print("mel_sequences: ", mel_sequences.shape) - with paddle.no_grad(): - spk_emb = speaker_encoder.embed_utterance( - paddle.to_tensor(mel_sequences)) - # print("spk_emb shape: ", spk_emb.shape) - + if args.use_ecapa: + spk_emb = vec_executor(audio_file=ref_audio_path, force_yes=True) + spk_emb = paddle.to_tensor(spk_emb) + # GE2E + else: + mel_sequences = p.extract_mel_partials( + p.preprocess_wav(ref_audio_path)) + with paddle.no_grad(): + spk_emb = speaker_encoder.embed_utterance( + paddle.to_tensor(mel_sequences)) with paddle.no_grad(): wav = voc_inference(am_inference(phone_ids, spk_emb=spk_emb)) @@ -108,16 +133,17 @@ def voice_cloning(args): wav.numpy(), samplerate=am_config.fs) print(f"{utt_id} done!") - # Randomly generate numbers of 0 ~ 0.2, 256 is the dim of spk_emb - random_spk_emb = np.random.rand(256) * 0.2 - random_spk_emb = paddle.to_tensor(random_spk_emb, dtype='float32') - utt_id = "random_spk_emb" - with paddle.no_grad(): - wav = voc_inference(am_inference(phone_ids, spk_emb=random_spk_emb)) - sf.write( - str(output_dir / (utt_id + ".wav")), - wav.numpy(), - samplerate=am_config.fs) + + # generate 5 random_spk_emb + for i in range(5): + random_spk_emb = gen_random_embed(args.use_ecapa) + utt_id = "random_spk_emb" + with paddle.no_grad(): + wav = voc_inference(am_inference(phone_ids, spk_emb=random_spk_emb)) + sf.write( + str(output_dir / (utt_id + "_" + str(i) + ".wav")), + wav.numpy(), + samplerate=am_config.fs) print(f"{utt_id} done!") @@ -171,13 +197,15 @@ def parse_args(): type=str, default="每当你觉得,想要批评什么人的时候,你切要记着,这个世界上的人,并非都具备你禀有的条件。", help="text to synthesize, a line") - parser.add_argument( "--ge2e_params_path", type=str, help="ge2e params path.") - + parser.add_argument( + "--use_ecapa", + type=str2bool, + default=False, + help="whether to use ECAPA-TDNN as speaker encoder.") parser.add_argument( "--ngpu", type=int, default=1, help="if ngpu=0, use cpu.") - parser.add_argument( "--input-dir", type=str,