Merge pull request #1693 from yt605155624/fix_ss_NHWC

[TTS]change NLC to NCL in speedyspeech, test=tts
TianYuan committed 2 years ago (via GitHub)
commit 98f67870ea

@@ -37,7 +37,7 @@ Model Type | Dataset| Example Link | Pretrained Models|Static Models|Size (static)
Tacotron2|LJSpeech|[tacotron2-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts0)|[tacotron2_ljspeech_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_ljspeech_ckpt_0.2.0.zip)|||
Tacotron2|CSMSC|[tacotron2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts0)|[tacotron2_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_ckpt_0.2.0.zip)|[tacotron2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/tacotron2/tacotron2_csmsc_static_0.2.0.zip)|103MB|
TransformerTTS| LJSpeech| [transformer-ljspeech](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/ljspeech/tts1)|[transformer_tts_ljspeech_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/transformer_tts/transformer_tts_ljspeech_ckpt_0.4.zip)|||
SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2) |[speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB|
SpeedySpeech| CSMSC | [speedyspeech-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts2)|[speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)|[speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)|12MB|
FastSpeech2| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_nosil_baker_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_baker_ckpt_0.4.zip)|[fastspeech2_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_csmsc_static_0.2.0.zip)|157MB|
FastSpeech2-Conformer| CSMSC |[fastspeech2-csmsc](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/csmsc/tts3)|[fastspeech2_conformer_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_conformer_baker_ckpt_0.5.zip)|||
FastSpeech2| AISHELL-3 |[fastspeech2-aishell3](https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/examples/aishell3/tts3)|[fastspeech2_nosil_aishell3_ckpt_0.4.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/fastspeech2/fastspeech2_nosil_aishell3_ckpt_0.4.zip)|||

@@ -223,22 +223,28 @@ CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path}
## Pretrained Model
Pretrained SpeedySpeech model trained on audio with no silence at the edges:
- [speedyspeech_nosil_baker_ckpt_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip)
- [speedyspeech_csmsc_ckpt_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip)
The static model can be downloaded here:
- [speedyspeech_nosil_baker_static_0.5.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_static_0.5.zip)
- [speedyspeech_csmsc_static_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_static_0.2.0.zip)
The ONNX model can be downloaded here:
- [speedyspeech_csmsc_onnx_0.2.0.zip](https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_onnx_0.2.0.zip)
Model | Step | eval/loss | eval/l1_loss | eval/duration_loss | eval/ssim_loss
:-------------:| :------------:| :-----: | :-----: | :--------:|:--------:
default| 1(gpu) x 11400|0.83655|0.42324|0.03211| 0.38119
default| 1(gpu) x 30600|0.79532|0.400246|0.030259| 0.36482
The SpeedySpeech checkpoint contains the files listed below.
```text
speedyspeech_nosil_baker_ckpt_0.5
speedyspeech_csmsc_ckpt_0.2.0
├── default.yaml # default config used to train speedyspeech
├── feats_stats.npy # statistics used to normalize spectrogram when training speedyspeech
├── phone_id_map.txt # phone vocabulary file when training speedyspeech
├── snapshot_iter_11400.pdz # model parameters and optimizer states
├── snapshot_iter_30600.pdz # model parameters and optimizer states
└── tone_id_map.txt # tone vocabulary file when training speedyspeech
```
You can use the following script to synthesize `${BIN_DIR}/../sentences.txt` using the pretrained SpeedySpeech and Parallel WaveGAN models.
@@ -249,9 +255,9 @@ FLAGS_allocator_strategy=naive_best_fit \
FLAGS_fraction_of_gpu_memory_to_use=0.01 \
python3 ${BIN_DIR}/../synthesize_e2e.py \
--am=speedyspeech_csmsc \
--am_config=speedyspeech_nosil_baker_ckpt_0.5/default.yaml \
--am_ckpt=speedyspeech_nosil_baker_ckpt_0.5/snapshot_iter_11400.pdz \
--am_stat=speedyspeech_nosil_baker_ckpt_0.5/feats_stats.npy \
--am_config=speedyspeech_csmsc_ckpt_0.2.0/default.yaml \
--am_ckpt=speedyspeech_csmsc_ckpt_0.2.0/snapshot_iter_30600.pdz \
--am_stat=speedyspeech_csmsc_ckpt_0.2.0/feats_stats.npy \
--voc=pwgan_csmsc \
--voc_config=pwg_baker_ckpt_0.4/pwg_default.yaml \
--voc_ckpt=pwg_baker_ckpt_0.4/pwg_snapshot_iter_400000.pdz \
@@ -260,6 +266,6 @@ python3 ${BIN_DIR}/../synthesize_e2e.py \
--text=${BIN_DIR}/../sentences.txt \
--output_dir=exp/default/test_e2e \
--inference_dir=exp/default/inference \
--phones_dict=speedyspeech_nosil_baker_ckpt_0.5/phone_id_map.txt \
--tones_dict=speedyspeech_nosil_baker_ckpt_0.5/tone_id_map.txt
--phones_dict=speedyspeech_csmsc_ckpt_0.2.0/phone_id_map.txt \
--tones_dict=speedyspeech_csmsc_ckpt_0.2.0/tone_id_map.txt
```

@@ -0,0 +1,32 @@
train_output_path=$1
stage=0
stop_stage=0
# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!
# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
python3 ${BIN_DIR}/../ort_predict.py \
--inference_dir=${train_output_path}/inference_onnx \
--am=speedyspeech_csmsc \
--voc=hifigan_csmsc \
--test_metadata=dump/test/norm/metadata.jsonl \
--output_dir=${train_output_path}/onnx_infer_out \
--device=cpu \
--cpu_threads=2
fi
# e2e, synthesize from text
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
python3 ${BIN_DIR}/../ort_predict_e2e.py \
--inference_dir=${train_output_path}/inference_onnx \
--am=speedyspeech_csmsc \
--voc=hifigan_csmsc \
--output_dir=${train_output_path}/onnx_infer_out_e2e \
--text=${BIN_DIR}/../csmsc_test.txt \
--phones_dict=dump/phone_id_map.txt \
--tones_dict=dump/tone_id_map.txt \
--device=cpu \
--cpu_threads=2
fi

@@ -0,0 +1 @@
../../tts3/local/paddle2onnx.sh

@@ -40,3 +40,25 @@ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# inference with static model
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} || exit -1
fi
# paddle2onnx, please make sure the static models are in ${train_output_path}/inference first
# we have only tested the following models so far
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# install paddle2onnx
version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}')
if [[ -z "$version" || ${version} != '0.9.4' ]]; then
pip install paddle2onnx==0.9.4
fi
./local/paddle2onnx.sh ${train_output_path} inference inference_onnx speedyspeech_csmsc
./local/paddle2onnx.sh ${train_output_path} inference inference_onnx hifigan_csmsc
fi
# inference with onnxruntime, use speedyspeech + hifigan by default
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# install onnxruntime
version=$(echo `pip list |grep "onnxruntime"` |awk -F" " '{print $2}')
if [[ -z "$version" || ${version} != '1.10.0' ]]; then
pip install onnxruntime==1.10.0
fi
./local/ort_predict.sh ${train_output_path}
fi

@@ -3,7 +3,7 @@ train_output_path=$1
stage=0
stop_stage=0
# only support default_fastspeech2 + hifigan/mb_melgan now!
# only support default_fastspeech2/speedyspeech + hifigan/mb_melgan now!
# synthesize from metadata
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then

@@ -19,4 +19,5 @@ paddle2onnx \
--model_filename ${model}.pdmodel \
--params_filename ${model}.pdiparams \
--save_file ${train_output_path}/${output_dir}/${model}.onnx \
--opset_version 11 \
--enable_dev_version ${enable_dev_version}
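After export, it can be worth sanity-checking the resulting graph before running inference. A minimal sketch (not part of the recipe), assuming the command above wrote `speedyspeech_csmsc.onnx` under `${train_output_path}/inference_onnx`; the path below is illustrative:
```python
import onnx

# hypothetical path; substitute ${train_output_path}/${output_dir}/${model}.onnx
model = onnx.load("exp/default/inference_onnx/speedyspeech_csmsc.onnx")
onnx.checker.check_model(model)  # raises ValidationError if the graph is malformed
print(model.opset_import)        # should report opset 11, matching --opset_version
```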

@@ -341,7 +341,7 @@ def stft(x: np.ndarray,
hop_length (Optional[int], optional): Number of steps to advance between adjacent windows. Defaults to None.
win_length (Optional[int], optional): The size of window. Defaults to None.
window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
dtype (type, optional): Data type of STFT results. Defaults to np.complex64.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
@@ -509,7 +509,7 @@ def melspectrogram(x: np.ndarray,
fmin (float, optional): Minimum frequency in Hz. Defaults to 50.0.
fmax (Optional[float], optional): Maximum frequency in Hz. Defaults to None.
window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.
to_db (bool, optional): Enable db scale. Defaults to True.
@@ -564,7 +564,7 @@ def spectrogram(x: np.ndarray,
window_size (int, optional): Size of FFT and window length. Defaults to 512.
hop_length (int, optional): Number of steps to advance between adjacent windows. Defaults to 320.
window (str, optional): A string of window specification. Defaults to "hann".
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to "reflect".
power (float, optional): Exponent for the magnitude melspectrogram. Defaults to 2.0.

@@ -42,7 +42,7 @@ class Spectrogram(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
dtype (str, optional): Data type of input and window. Defaults to 'float32'.
"""
@@ -99,7 +99,7 @@ class MelSpectrogram(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
@@ -176,7 +176,7 @@ class LogMelSpectrogram(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
@@ -257,7 +257,7 @@ class MFCC(nn.Layer):
win_length (Optional[int], optional): The window length of the short time FFT. If `None`, it is set to same as `n_fft`. Defaults to None.
window (str, optional): The window function applied to the signal before the Fourier transform. Supported window functions: 'hamming', 'hann', 'kaiser', 'gaussian', 'exponential', 'triang', 'bohman', 'blackman', 'cosine', 'tukey', 'taylor'. Defaults to 'hann'.
power (float, optional): Exponent for the magnitude spectrogram. Defaults to 2.0.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\_length` at the center of `t`-th frame. Defaults to True.
center (bool, optional): Whether to pad `x` to make that the :math:`t \times hop\\_length` at the center of `t`-th frame. Defaults to True.
pad_mode (str, optional): Choose padding pattern when `center` is `True`. Defaults to 'reflect'.
n_mels (int, optional): Number of mel bins. Defaults to 64.
f_min (float, optional): Minimum frequency in Hz. Defaults to 50.0.
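For context, these docstrings belong to the spectrogram feature layers shared across recipes. A hedged usage sketch; the import path is an assumption about where these layers live in this repo, so adjust it to your checkout:
```python
import paddle
from paddleaudio.features import LogMelSpectrogram  # assumed import path

wav = paddle.randn([1, 16000])  # (batch, samples): 1 s of audio at 16 kHz
extractor = LogMelSpectrogram(sr=16000, n_fft=512, hop_length=160, n_mels=64)
logmel = extractor(wav)         # (batch, n_mels, frames); center=True pads the edges
print(logmel.shape)
```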

@@ -43,13 +43,13 @@ pretrained_models = {
# speedyspeech
"speedyspeech_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_nosil_baker_ckpt_0.5.zip',
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/speedyspeech/speedyspeech_csmsc_ckpt_0.2.0.zip',
'md5':
'9edce23b1a87f31b814d9477bf52afbc',
'6f6fa967b408454b6662c8c00c0027cb',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_11400.pdz',
'snapshot_iter_30600.pdz',
'speech_stats':
'feats_stats.npy',
'phones_dict':

@@ -38,9 +38,7 @@ def get_predictor(args, filed='am'):
config.enable_use_gpu(100, 0)
elif args.device == "cpu":
config.disable_gpu()
# This line must be commented for fastspeech2, if not, it will OOM
if model_name != 'fastspeech2':
config.enable_memory_optim()
config.enable_memory_optim()
predictor = inference.create_predictor(config)
return predictor
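With the fastspeech2 special case gone, `enable_memory_optim()` is applied unconditionally. A minimal sketch of the surrounding Paddle Inference setup, with hypothetical model paths (the real script derives them from `--inference_dir` and `--am`):
```python
from paddle import inference

# hypothetical static-model paths exported by the earlier stages
config = inference.Config("inference/speedyspeech_csmsc.pdmodel",
                          "inference/speedyspeech_csmsc.pdiparams")
config.disable_gpu()           # mirrors args.device == "cpu"
config.enable_memory_optim()   # now enabled for every model, fastspeech2 included
predictor = inference.create_predictor(config)
print(predictor.get_input_names())  # speedyspeech expects ['phones', 'tones']
```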

@@ -70,8 +70,15 @@ def ort_predict(args):
# am warmup
for T in [27, 38, 54]:
data = np.random.randint(1, 266, size=(T, ))
am_sess.run(None, {"text": data})
am_input_feed = {}
if am_name == 'fastspeech2':
phone_ids = np.random.randint(1, 266, size=(T, ))
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
phone_ids = np.random.randint(1, 92, size=(T, ))
tone_ids = np.random.randint(1, 5, size=(T, ))
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
am_sess.run(None, input_feed=am_input_feed)
# voc warmup
for T in [227, 308, 544]:
@@ -81,14 +88,20 @@ def ort_predict(args):
N = 0
T = 0
am_input_feed = {}
for example in test_dataset:
utt_id = example['utt_id']
phone_ids = example["text"]
if am_name == 'fastspeech2':
phone_ids = example["text"]
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
phone_ids = example["phones"]
tone_ids = example["tones"]
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
with timer() as t:
mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})
mel = am_sess.run(output_names=None, input_feed=am_input_feed)
mel = mel[0]
wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
N += len(wav[0])
T += t.elapse
speed = len(wav[0]) / t.elapse
@@ -110,9 +123,7 @@ def parse_args():
'--am',
type=str,
default='fastspeech2_csmsc',
choices=[
'fastspeech2_csmsc',
],
choices=['fastspeech2_csmsc', 'speedyspeech_csmsc'],
help='Choose acoustic model type of tts task.')
# voc
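The branching above exists because the two acoustic models expose different ONNX inputs: fastspeech2 takes a single `text` tensor, while speedyspeech takes `phones` plus `tones`. A standalone sketch of the speedyspeech feed (the path is hypothetical; the vocabulary bounds follow the warmup code above):
```python
import numpy as np
import onnxruntime as ort

# hypothetical path; the script resolves it from --inference_dir and --am
sess = ort.InferenceSession("inference_onnx/speedyspeech_csmsc.onnx",
                            providers=["CPUExecutionProvider"])
T = 27
feed = {
    "phones": np.random.randint(1, 92, size=(T, )),  # phone ids
    "tones": np.random.randint(1, 5, size=(T, )),    # tone ids
}
mel = sess.run(None, input_feed=feed)[0]  # log-mel frames, passed to the vocoder next
```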

@@ -68,39 +68,58 @@ def ort_predict(args):
# vocoder
voc_sess = get_sess(args, filed='voc')
# frontend warmup
# Loading model cost 0.5+ seconds
if args.lang == 'zh':
frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True)
else:
print("lang should in be 'zh' here!")
# am warmup
for T in [27, 38, 54]:
data = np.random.randint(1, 266, size=(T, ))
am_sess.run(None, {"text": data})
am_input_feed = {}
if am_name == 'fastspeech2':
phone_ids = np.random.randint(1, 266, size=(T, ))
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
phone_ids = np.random.randint(1, 92, size=(T, ))
tone_ids = np.random.randint(1, 5, size=(T, ))
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
am_sess.run(None, input_feed=am_input_feed)
# voc warmup
for T in [227, 308, 544]:
data = np.random.rand(T, 80).astype("float32")
voc_sess.run(None, {"logmel": data})
voc_sess.run(None, input_feed={"logmel": data})
print("warm up done!")
# frontend warmup
# Loading model cost 0.5+ seconds
if args.lang == 'zh':
frontend.get_input_ids("你好,欢迎使用飞桨框架进行深度学习研究!", merge_sentences=True)
else:
print("lang should in be 'zh' here!")
N = 0
T = 0
merge_sentences = True
get_tone_ids = False
am_input_feed = {}
if am_name == 'speedyspeech':
get_tone_ids = True
for utt_id, sentence in sentences:
with timer() as t:
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
else:
print("lang should in be 'zh' here!")
# merge_sentences=True here, so we only use the first item of phone_ids
phone_ids = phone_ids[0].numpy()
mel = am_sess.run(output_names=None, input_feed={'text': phone_ids})
if am_name == 'fastspeech2':
am_input_feed.update({'text': phone_ids})
elif am_name == 'speedyspeech':
tone_ids = tone_ids[0].numpy()
am_input_feed.update({'phones': phone_ids, 'tones': tone_ids})
mel = am_sess.run(output_names=None, input_feed=am_input_feed)
mel = mel[0]
wav = voc_sess.run(output_names=None, input_feed={'logmel': mel})
@@ -125,9 +144,7 @@ def parse_args():
'--am',
type=str,
default='fastspeech2_csmsc',
choices=[
'fastspeech2_csmsc',
],
choices=['fastspeech2_csmsc', 'speedyspeech_csmsc'],
help='Choose acoustic model type of tts task.')
parser.add_argument(
"--phones_dict", type=str, default=None, help="phone vocabulary file.")

@@ -68,13 +68,15 @@ def evaluate(args):
# but still not stopping in the end (NOTE by yuantian01 Feb 9 2022)
if am_name == 'tacotron2':
merge_sentences = True
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
N = 0
T = 0
for utt_id, sentence in sentences:
with timer() as t:
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence,

@@ -667,8 +667,8 @@ class FastSpeech2(nn.Layer):
use_teacher_forcing(bool, optional): Whether to use teacher forcing.
If true, groundtruth of duration, pitch and energy will be used.
spk_emb(Tensor, optional): speaker embedding vector (spk_embed_dim,). (Default value = None)
spk_id(Tensor, optional(int64), optional): Batch of padded spk ids (1,). (Default value = None)
tone_id(Tensor, optional(int64), optional): Batch of padded tone ids (T,). (Default value = None)
spk_id(Tensor(int64), optional): spk ids (1,). (Default value = None)
tone_id(Tensor(int64), optional): tone ids (T,). (Default value = None)
Returns:
@@ -751,7 +751,6 @@ class FastSpeech2(nn.Layer):
Returns:
"""
if self.tone_embed_integration_type == "add":
# apply projection and then add to hidden states

@@ -11,17 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import paddle
from paddle import nn
from paddlespeech.t2s.modules.nets_utils import initialize
from paddlespeech.t2s.modules.positional_encoding import sinusoid_position_encoding
from paddlespeech.t2s.modules.predictor.length_regulator import LengthRegulator
from paddlespeech.t2s.modules.transformer.embedding import ScaledPositionalEncoding
class ResidualBlock(nn.Layer):
def __init__(self, channels, kernel_size, dilation, n=2):
def __init__(self,
channels: int=128,
kernel_size: int=3,
dilation: int=3,
n: int=2):
"""SpeedySpeech encoder module.
Args:
channels (int, optional): Feature size of the resiaudl output(and also the input).
kernel_size (int, optional): Kernel size of the 1D convolution.
dilation (int, optional): Dilation of the 1D convolution.
n (int): Number of blocks.
"""
super().__init__()
total_pad = (dilation * (kernel_size - 1))
begin = total_pad // 2
end = total_pad - begin
# remove padding='same' here, because ONNX doesn't support dilation + 'same' padding
blocks = [
nn.Sequential(
nn.Conv1D(
@@ -29,14 +47,20 @@ class ResidualBlock(nn.Layer):
channels,
kernel_size,
dilation=dilation,
padding="same",
data_format="NLC"),
# make sure output T == input T
padding=((0, 0), (0, 0), (begin, end))),
nn.ReLU(),
nn.BatchNorm1D(channels, data_format="NLC"), ) for _ in range(n)
nn.BatchNorm1D(channels), ) for _ in range(n)
]
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor): Batch of input sequences (B, hidden_size, Tmax).
Returns:
Tensor: The residual output (B, hidden_size, Tmax).
"""
return x + self.blocks(x)
@@ -62,7 +86,15 @@ class TextEmbedding(nn.Layer):
tone_vocab_size, tone_embedding_size, tone_padding_idx)
self.concat = concat
def forward(self, text, tone=None):
def forward(self, text: paddle.Tensor, tone: paddle.Tensor=None):
"""Calculate forward propagation.
Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax).
tone(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax).
Returns:
Tensor: Output embeddings (B, Tmax, embedding_size).
"""
text_embed = self.text_embedding(text)
if tone is None:
return text_embed
@@ -75,13 +107,24 @@ class TextEmbedding(nn.Layer):
class SpeedySpeechEncoder(nn.Layer):
"""SpeedySpeech encoder module.
Args:
vocab_size (int): Size of the phone vocabulary.
tone_size (Optional[int]): Number of tones.
hidden_size (int): Number of encoder hidden units.
kernel_size (int): Kernel size of encoder.
dilations (List[int]): Dilations of encoder.
spk_num (Optional[int]): Number of speakers.
"""
def __init__(self,
vocab_size,
tone_size,
hidden_size,
kernel_size,
dilations,
vocab_size: int,
tone_size: int,
hidden_size: int=128,
kernel_size: int=3,
dilations: List[int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
spk_num=None):
super().__init__()
self.embedding = TextEmbedding(
vocab_size,
@@ -109,34 +152,71 @@ class SpeedySpeechEncoder(nn.Layer):
self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
self.postnet2 = nn.Sequential(
nn.ReLU(),
nn.BatchNorm1D(hidden_size, data_format="NLC"),
nn.Linear(hidden_size, hidden_size), )
def forward(self, text, tones, spk_id=None):
nn.BatchNorm1D(hidden_size), )
self.linear = nn.Linear(hidden_size, hidden_size)
def forward(self,
text: paddle.Tensor,
tones: paddle.Tensor,
spk_id: paddle.Tensor=None):
"""Encoder input sequence.
Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax).
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax).
spk_id(Tensor, optional(int64)): Batch of speaker ids (B,).
Returns:
Tensor: Output tensor (B, Tmax, hidden_size).
"""
embedding = self.embedding(text, tones)
if self.spk_emb:
embedding += self.spk_emb(spk_id).unsqueeze(1)
embedding = self.prenet(embedding)
x = self.res_blocks(embedding)
x = self.res_blocks(embedding.transpose([0, 2, 1])).transpose([0, 2, 1])
# (B, T, dim)
x = embedding + self.postnet1(x)
x = self.postnet2(x)
x = self.postnet2(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = self.linear(x)
return x
class DurationPredictor(nn.Layer):
def __init__(self, hidden_size):
def __init__(self, hidden_size: int=128):
super().__init__()
self.layers = nn.Sequential(
ResidualBlock(hidden_size, 4, 1, n=1),
ResidualBlock(hidden_size, 3, 1, n=1),
ResidualBlock(hidden_size, 1, 1, n=1), nn.Linear(hidden_size, 1))
ResidualBlock(hidden_size, 1, 1, n=1), )
self.linear = nn.Linear(hidden_size, 1)
def forward(self, x):
return paddle.squeeze(self.layers(x), -1)
def forward(self, x: paddle.Tensor):
"""Calculate forward propagation.
Args:
x(Tensor): Batch of input sequences (B, Tmax, hidden_size).
Returns:
Tensor: Batch of predicted durations in log domain (B, Tmax).
"""
x = self.layers(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = self.linear(x)
return paddle.squeeze(x, -1)
class SpeedySpeechDecoder(nn.Layer):
def __init__(self, hidden_size, output_size, kernel_size, dilations):
def __init__(self,
hidden_size: int=128,
output_size: int=80,
kernel_size: int=3,
dilations: List[int]=[
1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1
]):
"""SpeedySpeech decoder module.
Args:
hidden_size (int): Number of decoder hidden units.
kernel_size (int): Kernel size of decoder.
output_size (int): Dimension of the outputs.
dilations (List[int]): Dilations of decoder.
"""
super().__init__()
res_blocks = [
ResidualBlock(hidden_size, kernel_size, d, n=2) for d in dilations
@@ -144,14 +224,21 @@ class SpeedySpeechDecoder(nn.Layer):
self.res_blocks = nn.Sequential(*res_blocks)
self.postnet1 = nn.Sequential(nn.Linear(hidden_size, hidden_size))
self.postnet2 = nn.Sequential(
ResidualBlock(hidden_size, kernel_size, 1, n=2),
nn.Linear(hidden_size, output_size))
self.postnet2 = ResidualBlock(hidden_size, kernel_size, 1, n=2)
self.linear = nn.Linear(hidden_size, output_size)
def forward(self, x):
xx = self.res_blocks(x)
"""Decoder input sequence.
Args:
x(Tensor): Input tensor (B, time, hidden_size).
Returns:
Tensor: Output tensor (B, time, output_size).
"""
xx = self.res_blocks(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = x + self.postnet1(xx)
x = self.postnet2(x)
x = self.postnet2(x.transpose([0, 2, 1])).transpose([0, 2, 1])
x = self.linear(x)
return x
@@ -159,17 +246,35 @@ class SpeedySpeech(nn.Layer):
def __init__(
self,
vocab_size,
encoder_hidden_size,
encoder_kernel_size,
encoder_dilations,
duration_predictor_hidden_size,
decoder_hidden_size,
decoder_output_size,
decoder_kernel_size,
decoder_dilations,
tone_size=None,
spk_num=None,
init_type: str="xavier_uniform", ):
encoder_hidden_size: int=128,
encoder_kernel_size: int=3,
encoder_dilations: List[int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
duration_predictor_hidden_size: int=128,
decoder_hidden_size: int=128,
decoder_output_size: int=80,
decoder_kernel_size: int=3,
decoder_dilations: List[
int]=[1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 3, 9, 27, 1, 1],
tone_size: int=None,
spk_num: int=None,
init_type: str="xavier_uniform",
positional_dropout_rate: float=0.1):
"""Initialize SpeedySpeech module.
Args:
vocab_size (int): Size of the phone vocabulary.
encoder_hidden_size (int): Number of encoder hidden units.
encoder_kernel_size (int): Kernel size of encoder.
encoder_dilations (List[int]): Dilations of encoder.
duration_predictor_hidden_size (int): Number of duration predictor hidden units.
decoder_hidden_size (int): Number of decoder hidden units.
decoder_kernel_size (int): Kernel size of decoder.
decoder_dilations (List[int]): Dilations of decoder.
decoder_output_size (int): Dimension of the outputs.
tone_size (Optional[int]): Number of tones.
spk_num (Optional[int]): Number of speakers.
init_type (str): How to initialize transformer parameters.
positional_dropout_rate (float, optional): Dropout rate applied after adding positional encoding.
"""
super().__init__()
# initialize parameters
@@ -181,6 +286,8 @@ class SpeedySpeech(nn.Layer):
duration_predictor = DurationPredictor(duration_predictor_hidden_size)
decoder = SpeedySpeechDecoder(decoder_hidden_size, decoder_output_size,
decoder_kernel_size, decoder_dilations)
self.position_enc = ScaledPositionalEncoding(encoder_hidden_size,
positional_dropout_rate)
self.encoder = encoder
self.duration_predictor = duration_predictor
@@ -190,7 +297,22 @@ class SpeedySpeech(nn.Layer):
nn.initializer.set_global_initializer(None)
def forward(self, text, tones, durations, spk_id: paddle.Tensor=None):
def forward(self,
text: paddle.Tensor,
tones: paddle.Tensor,
durations: paddle.Tensor,
spk_id: paddle.Tensor=None):
"""Calculate forward propagation.
Args:
text(Tensor(int64)): Batch of padded token ids (B, Tmax).
tones(Tensor, optional(int64)): Batch of padded tone ids (B, Tmax).
durations(Tensor(int64)): Batch of padded durations (B, Tmax).
spk_id(Tensor, optional(int64)): Batch of speaker ids (B,).
Returns:
Tensor: Output tensor (B, T_frames, decoder_output_size).
Tensor: Predicted durations (B, Tmax).
"""
# input of embedding must be int64
text = paddle.cast(text, 'int64')
tones = paddle.cast(tones, 'int64')
@@ -198,23 +320,30 @@ class SpeedySpeech(nn.Layer):
spk_id = paddle.cast(spk_id, 'int64')
durations = paddle.cast(durations, 'int64')
encodings = self.encoder(text, tones, spk_id)
pred_durations = self.duration_predictor(encodings.detach())
# expand encodings
durations_to_expand = durations
encodings = self.length_regulator(encodings, durations_to_expand)
encodings = self.position_enc(encodings)
# decode
# remove positional encoding here
_, t_dec, feature_size = encodings.shape
encodings += sinusoid_position_encoding(t_dec, feature_size)
decoded = self.decoder(encodings)
return decoded, pred_durations
def inference(self, text, tones=None, durations=None, spk_id=None):
# text: [T]
# tones: [T]
def inference(self,
text: paddle.Tensor,
tones: paddle.Tensor=None,
durations: paddle.Tensor=None,
spk_id: paddle.Tensor=None):
"""Generate the sequence of features given the sequences of characters.
Args:
text(Tensor(int64)): Input sequence of characters (T,).
tones(Tensor, optional(int64)): Batch of padded tone ids (T, ).
durations(Tensor, optional (int64)): Groundtruth of duration (T,).
spk_id(Tensor, optional(int64), optional): spk ids (1,). (Default value = None)
Returns:
Tensor: logmel (T, decoder_output_size).
"""
# input of embedding must be int64
text = paddle.cast(text, 'int64')
text = text.unsqueeze(0)
@@ -233,10 +362,7 @@ class SpeedySpeech(nn.Layer):
durations_to_expand = durations
encodings = self.length_regulator(
encodings, durations_to_expand, is_inference=True)
shape = paddle.shape(encodings)
t_dec, feature_size = shape[1], shape[2]
encodings += sinusoid_position_encoding(t_dec, feature_size)
encodings = self.position_enc(encodings)
decoded = self.decoder(encodings)
return decoded[0]
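The recurring `transpose([0, 2, 1])` pairs are the heart of this change: the blocks now run `Conv1D`/`BatchNorm1D` in paddle's default NCL layout with explicit padding, instead of `data_format='NLC'` with `padding='same'`, which the ONNX exporter cannot handle for dilated convolutions. A self-contained sketch of the pattern (shapes are illustrative):
```python
import paddle
from paddle import nn

channels, kernel_size, dilation = 128, 3, 9
total_pad = dilation * (kernel_size - 1)
begin, end = total_pad // 2, total_pad - total_pad // 2
conv = nn.Conv1D(channels, channels, kernel_size, dilation=dilation,
                 padding=((0, 0), (0, 0), (begin, end)))  # explicit 'same'-style pad

x = paddle.randn([2, 100, channels])                   # (B, T, C), i.e. NLC
y = conv(x.transpose([0, 2, 1])).transpose([0, 2, 1])  # conv in NCL, back to NLC
assert y.shape == x.shape                              # output T == input T
```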

@@ -86,7 +86,7 @@ class LengthRegulator(nn.Layer):
M[:, i] = m - init
init = m
M = paddle.reshape(M, shape=[t_dec_1, batch_size, t_enc])
M = M[1:, :, :]
M = M[1:t_dec_1, :, :]
M = paddle.transpose(M, (1, 0, 2))
encodings = paddle.matmul(M, encodings)
return encodings
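For intuition: the regulator expands encoder frames by building a 0/1 matrix `M` of shape (B, t_dec, t_enc) and multiplying, and the explicit `1:t_dec_1` bound replaces an open-ended slice that exports poorly to ONNX. A toy illustration of the matmul-based expansion (not the repo's exact construction of `M`):
```python
import paddle

# two encoder frames with durations [2, 3] expand to five decoder frames
encodings = paddle.to_tensor([[[1.0, 1.0],
                               [2.0, 2.0]]])   # (B=1, t_enc=2, C=2)
M = paddle.to_tensor([[[1.0, 0.0],
                       [1.0, 0.0],
                       [0.0, 1.0],
                       [0.0, 1.0],
                       [0.0, 1.0]]])           # (B=1, t_dec=5, t_enc=2)
expanded = paddle.matmul(M, encodings)         # (1, 5, 2): frames repeated by duration
print(expanded.numpy())
```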

@@ -347,7 +347,7 @@ class TransformerEncoder(BaseEncoder):
encoder_type="transformer")
def forward(self, xs, masks):
"""Encode input sequence.
"""Encoder input sequence.
Args:
xs(Tensor): Input tensor (#batch, time, idim).
@@ -355,7 +355,7 @@ class TransformerEncoder(BaseEncoder):
Returns:
Tensor: Output tensor (#batch, time, attention_dim).
Tensor:Mask tensor (#batch, 1, time).
Tensor: Mask tensor (#batch, 1, time).
"""
xs = self.embed(xs)
xs, masks = self.encoders(xs, masks)
