diff --git a/docs/source/released_model.md b/docs/source/released_model.md index 1cbe3989..a7d00f24 100644 --- a/docs/source/released_model.md +++ b/docs/source/released_model.md @@ -37,7 +37,7 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (stati Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)||| Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB| TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)||| -SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB| +SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2)|[speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB| FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)|157MB| FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)||| FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip)||| diff --git a/examples/csmsc/tts2/README.md b/examples/csmsc/tts2/README.md index 4fbe34cb..081d8584 100644 --- a/examples/csmsc/tts2/README.md +++ b/examples/csmsc/tts2/README.md @@ -223,22 +223,28 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} ## Pretrained Model Pretrained SpeedySpeech model with no silence in the edge of audios: - 
[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip) +- [speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip) The static model can be downloaded here: - [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip) - [speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip) +The ONNX model can be downloaded here: +- [speedyspeech_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip) + + Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss :-------------:| :------------:| :-----: | :-----: | :--------:|:--------: -default| 1(gpu) x 11400|0.83655|0.42324|0.03211| 0.38119 +default| 1(gpu) x 11400|0.79532|0.400246|0.030259| 0.36482 SpeedySpeech checkpoint contains files listed below. + ```text -speedyspeech_nosil_baker_ckpt_0.5 +speedyspeech_csmsc_ckpt_0.2.0 ├── default.yaml # default config used to train speedyspeech ├── feats_stats.npy # statistics used to normalize spectrogram when training speedyspeech ├── phone_id_map.txt # phone vocabulary file when training speedyspeech -├── snapshot_iter_11400.pdz # model parameters and optimizer states +├── snapshot_iter_30600.pdz # model parameters and optimizer states └── tone_id_map.txt # tone vocabulary file when training speedyspeech ``` You can use the following scripts to synthesize for `${BIN_DIR}/../sentences.txt` using pretrained speedyspeech and parallel wavegan models. @@ -249,9 +255,9 @@ FLAGS_allocator_strategy=naive_best_fit \ FLAGS_fraction_of_gpu_memory_to_use=0.01 \ python3 ${BIN_DIR}/../synthesize_e2e.py \ --am=speedyspeech_csmsc \ - --am_config=speedyspeech_nosil_baker_ckpt_0.5/default.yaml \ - --am_ckpt=speedyspeech_nosil_baker_ckpt_0.5/snapshot_iter_11400.pdz \ - --am_stat=speedyspeech_nosil_baker_ckpt_0.5/feats_stats.npy \ + --am_config=speedyspeech_csmsc_ckpt_0.2.0/default.yaml \ + --am_ckpt=speedyspeech_csmsc_ckpt_0.2.0/snapshot_iter_30600.pdz \ + --am_stat=speedyspeech_csmsc_ckpt_0.2.0/feats_stats.npy \ --voc=pwgan_csmsc \ --voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \ --voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \ @@ -260,6 +266,6 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \ --text=${BIN_DIR}/../sentences.txt \ --output_dir=exp/default/test_e2e \ --inference_dir=exp/default/inference \ - --phones_dict=speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt \ - --tones_dict=speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt + --phones_dict=speedyspeech_csmsc_ckpt_0.2.0/phone_id_map.txt \ + --tones_dict=speedyspeech_csmsc_ckpt_0.2.0/tone_id_map.txt ``` diff --git a/examples/csmsc/tts2/local/ort_predict.sh b/examples/csmsc/tts2/local/ort_predict.sh new file mode 100755 index 00000000..46b0409b --- /dev/null +++ b/examples/csmsc/tts2/local/ort_predict.sh @@ -0,0 +1,32 @@ +train_output_path=$1 + +stage=0 +stop_stage=0 + +# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now! 
+ + # synthesize from metadata +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + python3 ${BIN_DIR}/../ort_predict.py \ + --inference_dir=${train_output_path}/inference_onnx \ + --am=speedyspeech_csmsc \ + --voc=hifigan_csmsc \ + --test_metadata=dump/test/norm/metadata.jsonl \ + --output_dir=${train_output_path}/onnx_infer_out \ + --device=cpu \ + --cpu_threads=2 +fi + +# e2e, synthesize from text +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + python3 ${BIN_DIR}/../ort_predict_e2e.py \ + --inference_dir=${train_output_path}/inference_onnx \ + --am=speedyspeech_csmsc \ + --voc=hifigan_csmsc \ + --output_dir=${train_output_path}/onnx_infer_out_e2e \ + --text=${BIN_DIR}/../csmsc_test.txt \ + --phones_dict=dump/phone_id_map.txt \ + --tones_dict=dump/tone_id_map.txt \ + --device=cpu \ + --cpu_threads=2 +fi diff --git a/examples/csmsc/tts2/local/paddle2onnx.sh b/examples/csmsc/tts2/local/paddle2onnx.sh new file mode 120000 index 00000000..87c46634 --- /dev/null +++ b/examples/csmsc/tts2/local/paddle2onnx.sh @@ -0,0 +1 @@ +../../tts3/local/paddle2onnx.sh \ No newline at end of file diff --git a/examples/csmsc/tts2/run.sh b/examples/csmsc/tts2/run.sh index 8b8f53bd..766aa882 100755 --- a/examples/csmsc/tts2/run.sh +++ b/examples/csmsc/tts2/run.sh @@ -40,3 +40,25 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then # inference with static model CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1 fi + +# paddle2onnx, please make sure the static models are in ${train_output_path}/inference first +# we have only tested the following models so far +if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then + # install paddle2onnx + version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '0.9.4' ]]; then + pip install paddle2onnx==0.9.4 + fi + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx speedyspeech_csmsc + ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc +fi + +# inference with onnxruntime, use speedyspeech + hifigan by default +if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then + # install onnxruntime + version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}') + if [[ -z "$version" || ${version} != '1.10.0' ]]; then + pip install onnxruntime==1.10.0 + fi + ./local/ort_predict.sh ${train_output_path} +fi diff --git a/examples/csmsc/tts3/local/ort_predict.sh b/examples/csmsc/tts3/local/ort_predict.sh index 3154f6e5..96350c06 100755 --- a/examples/csmsc/tts3/local/ort_predict.sh +++ b/examples/csmsc/tts3/local/ort_predict.sh @@ -3,7 +3,7 @@ train_output_path=$1 stage=0 stop_stage=0 -# only support default_fastspeech2 + hifigan/mb_melgan now! +# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now! 
# synthesize from metadata if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then diff --git a/examples/csmsc/tts3/local/paddle2onnx.sh b/examples/csmsc/tts3/local/paddle2onnx.sh index 505f3b66..0b05a6d6 100755 --- a/examples/csmsc/tts3/local/paddle2onnx.sh +++ b/examples/csmsc/tts3/local/paddle2onnx.sh @@ -19,4 +19,5 @@ paddle2onnx \ --model_filename ${model}.pdmodel \ --params_filename ${model}.pdiparams \ --save_file ${train_output_path}/${output_dir}/${model}.onnx \ + --opset_version 11 \ --enable_dev_version ${enable_dev_version} \ No newline at end of file diff --git a/paddleaudio/paddleaudio/compliance/librosa.py b/paddleaudio/paddleaudio/compliance/librosa.py index 740584ca..168632d7 100644 --- a/paddleaudio/paddleaudio/compliance/librosa.py +++ b/paddleaudio/paddleaudio/compliance/librosa.py @@ -341,7 +341,7 @@ def stft(x: np.ndarray, hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None. win_length (Optional[int], optional): The size of window. Defaults to None. window (str, optional): A string of window specification. Defaults to "hann". - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. dtype (type, optional): Data type of STFT results. Defaults to np.complex64. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". @@ -509,7 +509,7 @@ def melspectrogram(x: np.ndarray, fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0. fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None. window (str, optional): A string of window specification. Defaults to "hann". - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. to_db (bool, optional): Enable db scale. Defaults to True. @@ -564,7 +564,7 @@ def spectrogram(x: np.ndarray, window_size (int, optional): Size of FFT and window length. Defaults to 512. hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320. window (str, optional): A string of window specification. Defaults to "hann". - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect". power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0. diff --git a/paddleaudio/paddleaudio/features/layers.py b/paddleaudio/paddleaudio/features/layers.py index 09037255..292363e6 100644 --- a/paddleaudio/paddleaudio/features/layers.py +++ b/paddleaudio/paddleaudio/features/layers.py @@ -42,7 +42,7 @@ class Spectrogram(nn.Layer): win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. 
Defaults to None. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. dtype (str, optional): Data type of input and window. Defaults to 'float32'. """ @@ -99,7 +99,7 @@ class MelSpectrogram(nn.Layer): win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. n_mels (int, optional): Number of mel bins. Defaults to 64. f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. @@ -176,7 +176,7 @@ class LogMelSpectrogram(nn.Layer): win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. n_mels (int, optional): Number of mel bins. Defaults to 64. f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. @@ -257,7 +257,7 @@ class MFCC(nn.Layer): win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None. window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'. power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0. - center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. 
Defaults to True. + center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True. pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'. n_mels (int, optional): Number of mel bins. Defaults to 64. f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0. diff --git a/paddlespeech/cli/tts/infer.py b/paddlespeech/cli/tts/infer.py index c7a1edc9..1c3fb29f 100644 --- a/paddlespeech/cli/tts/infer.py +++ b/paddlespeech/cli/tts/infer.py @@ -43,13 +43,13 @@ pretrained_models = { # speedyspeech "speedyspeech_csmsc-zh": { 'url': - 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip', + 'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip', 'md5': - '9edce23b1a87f31b814d9477bf52afbc', + '6f6fa967b408454b6662c8c00c0027cb', 'config': 'default.yaml', 'ckpt': - 'snapshot_iter_11400.pdz', + 'snapshot_iter_30600.pdz', 'speech_stats': 'feats_stats.npy', 'phones_dict': diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 62602a01..c5b64ac7 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -38,9 +38,7 @@ def get_predictor(args, filed='am'): config.enable_use_gpu(100, 0) elif args.device == "cpu": config.disable_gpu() - # This line must be commented for fastspeech2, if not, it will OOM - if model_name != 'fastspeech2': - config.enable_memory_optim() + config.enable_memory_optim() predictor = inference.create_predictor(config) return predictor diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index e8d4d61c..2ca9b5be 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -70,8 +70,15 @@ def ort_predict(args): # am warmup for T in [27, 38, 54]: - data = np.random.randint(1, 266, size=(T, )) - am_sess.run(None, {"text": data}) + am_input_feed = {} + if am_name == 'fastspeech2': + phone_ids = np.random.randint(1, 266, size=(T, )) + am_input_feed.update({'text': phone_ids}) + elif am_name == 'speedyspeech': + phone_ids = np.random.randint(1, 92, size=(T, )) + tone_ids = np.random.randint(1, 5, size=(T, )) + am_input_feed.update({'phones': phone_ids, 'tones': tone_ids}) + am_sess.run(None, input_feed=am_input_feed) # voc warmup for T in [227, 308, 544]: @@ -81,14 +88,20 @@ def ort_predict(args): N = 0 T = 0 + am_input_feed = {} for example in test_dataset: utt_id = example['utt_id'] - phone_ids = example["text"] + if am_name == 'fastspeech2': + phone_ids = example["text"] + am_input_feed.update({'text': phone_ids}) + elif am_name == 'speedyspeech': + phone_ids = example["phones"] + tone_ids = example["tones"] + am_input_feed.update({'phones': phone_ids, 'tones': tone_ids}) with timer() as t: - mel = am_sess.run(output_names=None, input_feed={'text': phone_ids}) + mel = am_sess.run(output_names=None, input_feed=am_input_feed) mel = mel[0] wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) - N += len(wav[0]) T += t.elapse speed = len(wav[0]) / t.elapse @@ -110,9 +123,7 @@ def parse_args(): '--am', type=str, default='fastspeech2_csmsc', - choices=[ - 'fastspeech2_csmsc', - ], + choices=['fastspeech2_csmsc', 'speedyspeech_csmsc'], help='Choose acoustic model type of tts task.') # voc diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index 8aa04cbc..c62b7ecd 
100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -68,39 +68,58 @@ def ort_predict(args): # vocoder voc_sess = get_sess(args, filed='voc') + # frontend warmup + # Loading the model costs 0.5+ seconds + if args.lang == 'zh': + frontend.get_input_ids("你好，欢迎使用飞桨框架进行深度学习研究！", merge_sentences=True) + else: + print("lang should be 'zh' here!") + # am warmup for T in [27, 38, 54]: - data = np.random.randint(1, 266, size=(T, )) - am_sess.run(None, {"text": data}) + am_input_feed = {} + if am_name == 'fastspeech2': + phone_ids = np.random.randint(1, 266, size=(T, )) + am_input_feed.update({'text': phone_ids}) + elif am_name == 'speedyspeech': + phone_ids = np.random.randint(1, 92, size=(T, )) + tone_ids = np.random.randint(1, 5, size=(T, )) + am_input_feed.update({'phones': phone_ids, 'tones': tone_ids}) + am_sess.run(None, input_feed=am_input_feed) # voc warmup for T in [227, 308, 544]: data = np.random.rand(T, 80).astype("float32") - voc_sess.run(None, {"logmel": data}) + voc_sess.run(None, input_feed={"logmel": data}) print("warm up done!") - # frontend warmup - # Loading model cost 0.5+ seconds - if args.lang == 'zh': - frontend.get_input_ids("你好，欢迎使用飞桨框架进行深度学习研究！", merge_sentences=True) - else: - print("lang should in be 'zh' here!") - N = 0 T = 0 merge_sentences = True + get_tone_ids = False + am_input_feed = {} + if am_name == 'speedyspeech': + get_tone_ids = True for utt_id, sentence in sentences: with timer() as t: if args.lang == 'zh': input_ids = frontend.get_input_ids( - sentence, merge_sentences=merge_sentences) - + sentence, + merge_sentences=merge_sentences, + get_tone_ids=get_tone_ids) phone_ids = input_ids["phone_ids"] + if get_tone_ids: + tone_ids = input_ids["tone_ids"] else: print("lang should in be 'zh' here!") # merge_sentences=True here, so we only use the first item of phone_ids phone_ids = phone_ids[0].numpy() - mel = am_sess.run(output_names=None, input_feed={'text': phone_ids}) + if am_name == 'fastspeech2': + am_input_feed.update({'text': phone_ids}) + elif am_name == 'speedyspeech': + tone_ids = tone_ids[0].numpy() + am_input_feed.update({'phones': phone_ids, 'tones': tone_ids}) + mel = am_sess.run(output_names=None, input_feed=am_input_feed) mel = mel[0] wav = voc_sess.run(output_names=None, input_feed={'logmel': mel}) @@ -125,9 +144,7 @@ def parse_args(): '--am', type=str, default='fastspeech2_csmsc', - choices=[ - 'fastspeech2_csmsc', - ], + choices=['fastspeech2_csmsc', 'speedyspeech_csmsc'], help='Choose acoustic model type of tts task.') parser.add_argument( "--phones_dict", type=str, default=None, help="phone vocabulary file.") diff --git a/paddlespeech/t2s/exps/synthesize_e2e.py b/paddlespeech/t2s/exps/synthesize_e2e.py index 10b33c60..6c28dc48 100644 --- a/paddlespeech/t2s/exps/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/synthesize_e2e.py @@ -68,13 +68,15 @@ def evaluate(args): # but still not stopping in the end (NOTE by yuantian01 Feb 9 2022) if am_name == 'tacotron2': merge_sentences = True + + get_tone_ids = False + if am_name == 'speedyspeech': + get_tone_ids = True + N = 0 T = 0 for utt_id, sentence in sentences: with timer() as t: - get_tone_ids = False - if am_name == 'speedyspeech': - get_tone_ids = True if args.lang == 'zh': input_ids = frontend.get_input_ids( sentence, diff --git a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py index c2f1e218..8e52f916 100644 --- a/paddlespeech/t2s/models/fastspeech2/fastspeech2.py +++ 
b/paddlespeech/t2s/models/fastspeech2/fastspeech2.py @@ -667,8 +667,8 @@ class FastSpeech2(nn.Layer): use_teacher_forcing(bool, optional): Whether to use teacher forcing. If true, groundtruth of duration, pitch and energy will be used. spk_emb(Tensor, optional, optional): peaker embedding vector (spk_embed_dim,). (Default value = None) - spk_id(Tensor, optional(int64), optional): Batch of padded spk ids (1,). (Default value = None) - tone_id(Tensor, optional(int64), optional): Batch of padded tone ids (T,). (Default value = None) + spk_id(Tensor(int64), optional): spk ids (1,). (Default value = None) + tone_id(Tensor(int64), optional): tone ids (T,). (Default value = None) Returns: @@ -751,7 +751,6 @@ class FastSpeech2(nn.Layer): Returns: - """ if self.tone_embed_integration_type == "add": # apply projection and then add to hidden states diff --git a/paddlespeech/t2s/models/speedyspeech/speedyspeech.py b/paddlespeech/t2s/models/speedyspeech/speedyspeech.py index 44ccfc60..86c84320 100644 --- a/paddlespeech/t2s/models/speedyspeech/speedyspeech.py +++ b/paddlespeech/t2s/models/speedyspeech/speedyspeech.py @@ -11,17 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + import paddle from paddle import nn from paddlespeech.t2s.modules.nets_utils import initialize -from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator +from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding class ResidualBlock(nn.Layer): - def __init__(self, channels, kernel_size, dilation, n=2): + def __init__(self, + channels: int=128, + kernel_size: int=3, + dilation: int=3, + n: int=2): + """SpeedySpeech residual block module. + Args: + channels (int, optional): Feature size of the residual output (and also the input). + kernel_size (int, optional): Kernel size of the 1D convolution. + dilation (int, optional): Dilation of the 1D convolution. + n (int): Number of blocks. + """ + super().__init__() + total_pad = (dilation * (kernel_size - 1)) + begin = total_pad // 2 + end = total_pad - begin + # remove padding='same' here, because onnx doesn't support dilation + 'same' padding blocks = [ nn.Sequential( nn.Conv1D( @@ -29,14 +47,20 @@ class ResidualBlock(nn.Layer): channels, kernel_size, dilation=dilation, - padding="same", - data_format="NLC"), + # make sure output T == input T + padding=((0, 0), (0, 0), (begin, end))), nn.ReLU(), - nn.BatchNorm1D(channels, data_format="NLC"), ) for _ in range(n) + nn.BatchNorm1D(channels), ) for _ in range(n) ] self.blocks = nn.Sequential(*blocks) - def forward(self, x): + def forward(self, x: paddle.Tensor): + """Calculate forward propagation. + Args: + x(Tensor): Batch of input sequences (B, hidden_size, Tmax). + Returns: + Tensor: The residual output (B, hidden_size, Tmax). + """ return x + self.blocks(x) @@ -62,7 +86,15 @@ class TextEmbedding(nn.Layer): tone_vocab_size, tone_embedding_size, tone_padding_idx) self.concat = concat - def forward(self, text, tone=None): + def forward(self, text: paddle.Tensor, tone: paddle.Tensor=None): + """Calculate forward propagation. + Args: + text(Tensor(int64)): Batch of padded token ids (B, Tmax). + tone(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax). + Returns: + Tensor: The embedding output (B, Tmax, embedding_size). 
+ """ + text_embed = self.text_embedding(text) if tone is None: return text_embed @@ -75,13 +107,24 @@ class TextEmbedding(nn.Layer): class SpeedySpeechEncoder(nn.Layer): + """SpeedySpeech encoder module. + Args: + vocab_size (int): Dimension of the inputs. + tone_size (Optional[int]): Number of tones. + hidden_size (int): Number of encoder hidden units. + kernel_size (int): Kernel size of encoder. + dilations (List[int]): Dilations of encoder. + spk_num (Optional[int]): Number of speakers. + """ + def __init__(self, - vocab_size, - tone_size, - hidden_size, - kernel_size, - dilations, + vocab_size: int, + tone_size: int, + hidden_size: int=128, + kernel_size: int=3, + dilations: List[int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1], spk_num=None): + super().__init__() self.embedding = TextEmbedding( vocab_size, @@ -109,34 +152,71 @@ class SpeedySpeechEncoder(nn.Layer): self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size)) self.postnet2 = nn.Sequential( nn.ReLU(), - nn.BatchNorm1D(hidden_size, data_format="NLC"), - nn.Linear(hidden_size, hidden_size), ) - - def forward(self, text, tones, spk_id=None): + nn.BatchNorm1D(hidden_size), ) + self.linear = nn.Linear(hidden_size, hidden_size) + + def forward(self, + text: paddle.Tensor, + tones: paddle.Tensor, + spk_id: paddle.Tensor=None): + """Encoder input sequence. + Args: + text(Tensor(int64)): Batch of padded token ids (B, Tmax). + tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax). + spk_id(Tnesor, optional(int64)): Batch of speaker ids (B,) + + Returns: + Tensor: Output tensor (B, Tmax, hidden_size). + """ embedding = self.embedding(text, tones) if self.spk_emb: embedding += self.spk_emb(spk_id).unsqueeze(1) embedding = self.prenet(embedding) - x = self.res_blocks(embedding) + x = self.res_blocks(embedding.transpose([0, 2, 1])).transpose([0, 2, 1]) + # (B, T, dim) x = embedding + self.postnet1(x) - x = self.postnet2(x) + x = self.postnet2(x.transpose([0, 2, 1])).transpose([0, 2, 1]) + x = self.linear(x) return x class DurationPredictor(nn.Layer): - def __init__(self, hidden_size): + def __init__(self, hidden_size: int=128): super().__init__() self.layers = nn.Sequential( ResidualBlock(hidden_size, 4, 1, n=1), ResidualBlock(hidden_size, 3, 1, n=1), - ResidualBlock(hidden_size, 1, 1, n=1), nn.Linear(hidden_size, 1)) + ResidualBlock(hidden_size, 1, 1, n=1), ) + self.linear = nn.Linear(hidden_size, 1) - def forward(self, x): - return paddle.squeeze(self.layers(x), -1) + def forward(self, x: paddle.Tensor): + """Calculate forward propagation. + Args: + x(Tensor): Batch of input sequences (B, Tmax, hidden_size). + + Returns: + Tensor: Batch of predicted durations in log domain (B, Tmax). + """ + x = self.layers(x.transpose([0, 2, 1])).transpose([0, 2, 1]) + x = self.linear(x) + return paddle.squeeze(x, -1) class SpeedySpeechDecoder(nn.Layer): - def __init__(self, hidden_size, output_size, kernel_size, dilations): + def __init__(self, + hidden_size: int=128, + output_size: int=80, + kernel_size: int=3, + dilations: List[int]=[ + 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1 + ]): + """SpeedySpeech decoder module. + Args: + hidden_size (int): Number of decoder hidden units. + kernel_size (int): Kernel size of decoder. + output_size (int): Dimension of the outputs. + dilations (List[int]): Dilations of decoder. 
+ """ super().__init__() res_blocks = [ ResidualBlock(hidden_size, kernel_size, d, n=2) for d in dilations @@ -144,14 +224,21 @@ class SpeedySpeechDecoder(nn.Layer): self.res_blocks = nn.Sequential(*res_blocks) self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size)) - self.postnet2 = nn.Sequential( - ResidualBlock(hidden_size, kernel_size, 1, n=2), - nn.Linear(hidden_size, output_size)) + self.postnet2 = ResidualBlock(hidden_size, kernel_size, 1, n=2) + self.linear = nn.Linear(hidden_size, output_size) def forward(self, x): - xx = self.res_blocks(x) + """Decoder input sequence. + Args: + x(Tensor): Input tensor (B, time, hidden_size). + + Returns: + Tensor: Output tensor (B, time, output_size). + """ + xx = self.res_blocks(x.transpose([0, 2, 1])).transpose([0, 2, 1]) x = x + self.postnet1(xx) - x = self.postnet2(x) + x = self.postnet2(x.transpose([0, 2, 1])).transpose([0, 2, 1]) + x = self.linear(x) return x @@ -159,17 +246,35 @@ class SpeedySpeech(nn.Layer): def __init__( self, vocab_size, - encoder_hidden_size, - encoder_kernel_size, - encoder_dilations, - duration_predictor_hidden_size, - decoder_hidden_size, - decoder_output_size, - decoder_kernel_size, - decoder_dilations, - tone_size=None, - spk_num=None, - init_type: str="xavier_uniform", ): + encoder_hidden_size: int=128, + encoder_kernel_size: int=3, + encoder_dilations: List[int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1], + duration_predictor_hidden_size: int=128, + decoder_hidden_size: int=128, + decoder_output_size: int=80, + decoder_kernel_size: int=3, + decoder_dilations: List[ + int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1], + tone_size: int=None, + spk_num: int=None, + init_type: str="xavier_uniform", + positional_dropout_rate: int=0.1): + """Initialize SpeedySpeech module. + Args: + vocab_size (int): Dimension of the inputs. + encoder_hidden_size (int): Number of encoder hidden units. + encoder_kernel_size (int): Kernel size of encoder. + encoder_dilations (List[int]): Dilations of encoder. + duration_predictor_hidden_size (int): Number of duration predictor hidden units. + decoder_hidden_size (int): Number of decoder hidden units. + decoder_kernel_size (int): Kernel size of decoder. + decoder_dilations (List[int]): Dilations of decoder. + decoder_output_size (int): Dimension of the outputs. + tone_size (Optional[int]): Number of tones. + spk_num (Optional[int]): Number of speakers. + init_type (str): How to initialize transformer parameters. + + """ super().__init__() # initialize parameters @@ -181,6 +286,8 @@ class SpeedySpeech(nn.Layer): duration_predictor = DurationPredictor(duration_predictor_hidden_size) decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size, decoder_kernel_size, decoder_dilations) + self.position_enc = ScaledPositionalEncoding(encoder_hidden_size, + positional_dropout_rate) self.encoder = encoder self.duration_predictor = duration_predictor @@ -190,7 +297,22 @@ class SpeedySpeech(nn.Layer): nn.initializer.set_global_initializer(None) - def forward(self, text, tones, durations, spk_id: paddle.Tensor=None): + def forward(self, + text: paddle.Tensor, + tones: paddle.Tensor, + durations: paddle.Tensor, + spk_id: paddle.Tensor=None): + """Calculate forward propagation. + Args: + text(Tensor(int64)): Batch of padded token ids (B, Tmax). + durations(Tensor(int64)): Batch of padded durations (B, Tmax). + tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax). 
+ spk_id(Tensor, optional(int64)): Batch of speaker ids (B,). + + Returns: + Tensor: Output tensor (B, T_frames, decoder_output_size). + Tensor: Predicted durations (B, Tmax). + """ # input of embedding must be int64 text = paddle.cast(text, 'int64') tones = paddle.cast(tones, 'int64') @@ -198,23 +320,30 @@ spk_id = paddle.cast(spk_id, 'int64') durations = paddle.cast(durations, 'int64') encodings = self.encoder(text, tones, spk_id) - pred_durations = self.duration_predictor(encodings.detach()) - # expand encodings durations_to_expand = durations encodings = self.length_regulator(encodings, durations_to_expand) - + encodings = self.position_enc(encodings) # decode - # remove positional encoding here - _, t_dec, feature_size = encodings.shape - encodings += sinusoid_position_encoding(t_dec, feature_size) decoded = self.decoder(encodings) return decoded, pred_durations - def inference(self, text, tones=None, durations=None, spk_id=None): - # text: [T] - # tones: [T] + def inference(self, + text: paddle.Tensor, + tones: paddle.Tensor=None, + durations: paddle.Tensor=None, + spk_id: paddle.Tensor=None): + """Generate the sequence of features given the sequences of characters. + Args: + text(Tensor(int64)): Input sequence of characters (T,). + tones(Tensor, optional(int64)): Tone ids (T,). + durations(Tensor, optional(int64)): Ground truth durations (T,). + spk_id(Tensor(int64), optional): spk ids (1,). (Default value = None) + + Returns: + Tensor: logmel (T, decoder_output_size). + """ # input of embedding must be int64 text = paddle.cast(text, 'int64') text = text.unsqueeze(0) @@ -233,10 +362,7 @@ durations_to_expand = durations encodings = self.length_regulator( encodings, durations_to_expand, is_inference=True) - - shape = paddle.shape(encodings) - t_dec, feature_size = shape[1], shape[2] - encodings += sinusoid_position_encoding(t_dec, feature_size) + encodings = self.position_enc(encodings) decoded = self.decoder(encodings) return decoded[0] diff --git a/paddlespeech/t2s/modules/predictor/length_regulator.py b/paddlespeech/t2s/modules/predictor/length_regulator.py index be788e6e..b64aa44a 100644 --- a/paddlespeech/t2s/modules/predictor/length_regulator.py +++ b/paddlespeech/t2s/modules/predictor/length_regulator.py @@ -86,7 +86,7 @@ class LengthRegulator(nn.Layer): M[:, i] = m - init init = m M = paddle.reshape(M, shape=[t_dec_1, batch_size, t_enc]) - M = M[1:, :, :] + M = M[1:t_dec_1, :, :] M = paddle.transpose(M, (1, 0, 2)) encodings = paddle.matmul(M, encodings) return encodings diff --git a/paddlespeech/t2s/modules/transformer/encoder.py b/paddlespeech/t2s/modules/transformer/encoder.py index f6420282..d05516c2 100644 --- a/paddlespeech/t2s/modules/transformer/encoder.py +++ b/paddlespeech/t2s/modules/transformer/encoder.py @@ -347,7 +347,7 @@ class TransformerEncoder(BaseEncoder): encoder_type="transformer") def forward(self, xs, masks): - """Encode input sequence. + """Encode the input sequence. Args: xs(Tensor): Input tensor (#batch, time, idim). @@ -355,7 +355,7 @@ Returns: Tensor: Output tensor (#batch, time, attention_dim). - Tensor:Mask tensor (#batch, 1, time). + Tensor: Mask tensor (#batch, 1, time). """ xs = self.embed(xs) xs, masks = self.encoders(xs, masks)
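For reference, the acoustic-model branching added in `ort_predict.py` and `ort_predict_e2e.py` reduces to a two-input feed for SpeedySpeech (`phones` and `tones`) versus a single `text` input for FastSpeech2. Below is a minimal onnxruntime sketch of that path, assuming models exported by `./local/paddle2onnx.sh` into an `inference_onnx/` directory; the file paths are illustrative, and the vocabulary bounds (92 phones, 5 tones) are taken from the warmup code above, not from the dictionaries.

```python
# Minimal sketch of the SpeedySpeech ONNX path (paths are illustrative).
import numpy as np
import onnxruntime as ort

am_sess = ort.InferenceSession("inference_onnx/speedyspeech_csmsc.onnx",
                               providers=["CPUExecutionProvider"])
voc_sess = ort.InferenceSession("inference_onnx/hifigan_csmsc.onnx",
                                providers=["CPUExecutionProvider"])

T = 38  # phone-sequence length, one of the warmup sizes
# cast for portability; the exported inputs are int64
phone_ids = np.random.randint(1, 92, size=(T, )).astype("int64")
tone_ids = np.random.randint(1, 5, size=(T, )).astype("int64")
# speedyspeech takes two inputs ('phones', 'tones'); fastspeech2 takes one ('text')
mel = am_sess.run(None, {"phones": phone_ids, "tones": tone_ids})[0]
wav = voc_sess.run(None, {"logmel": mel})[0]  # the vocoder consumes the mel directly
print(wav.shape)
```

Keeping all inputs in one `am_input_feed` dict, as the patch does, lets the same `am_sess.run(None, input_feed=am_input_feed)` call serve both acoustic models; the warmup and synthesis loops then differ only in how they fill the dict.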
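The `ResidualBlock` rewrite drops `padding='same'` because ONNX export does not support `'same'` padding combined with dilation, and instead pads explicitly with `begin = total_pad // 2` and `end = total_pad - begin`. A dilated convolution shortens a length-`T` input by `dilation * (kernel_size - 1)` positions, which is exactly `begin + end`, so the output length is unchanged. A small self-contained check of that arithmetic (plain Python, not code from the patch):

```python
# Check that ResidualBlock's explicit (begin, end) padding reproduces the
# length-preserving behavior of padding='same' for a dilated 1-D convolution.
def conv1d_out_len(t: int, kernel_size: int, dilation: int,
                   begin: int, end: int, stride: int = 1) -> int:
    # standard conv output-length formula with an effective (dilated) kernel
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (t + begin + end - effective_kernel) // stride + 1

# kernel sizes and dilations used by the encoder, decoder and duration predictor
for kernel_size in (1, 3, 4):
    for dilation in (1, 3, 9, 27):
        total_pad = dilation * (kernel_size - 1)
        begin = total_pad // 2           # same split as ResidualBlock.__init__
        end = total_pad - begin
        for t in (27, 38, 54):
            assert conv1d_out_len(t, kernel_size, dilation, begin, end) == t
print("output T == input T for every tested configuration")
```

The asymmetric split also matters for even kernels such as the `kernel_size=4` block in `DurationPredictor`, where the total padding is odd and cannot be divided evenly between the two sides.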