diff --git a/examples/csmsc/vits/local/paddle2onnx.sh b/examples/csmsc/vits/local/paddle2onnx.sh new file mode 120000 index 00000000..87c46634 --- /dev/null +++ b/examples/csmsc/vits/local/paddle2onnx.sh @@ -0,0 +1 @@ +../../tts3/local/paddle2onnx.sh \ No newline at end of file diff --git a/examples/csmsc/vits/run.sh b/examples/csmsc/vits/run.sh index ac190bfa..f2c5d452 100755 --- a/examples/csmsc/vits/run.sh +++ b/examples/csmsc/vits/run.sh @@ -39,3 +39,34 @@ fi if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} ${add_blank}|| exit -1 fi + +# # not ready yet for operator missing in Paddle2ONNX +# # paddle2onnx, please make sure the static models are in ${train_output_path}/inference first +# # we have only tested the following models so far +# if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then +# # install paddle2onnx +# version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}') +# if [[ -z "$version" || ${version} != '1.0.0' ]]; then +# pip install paddle2onnx==1.0.0 +# fi +# ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx vits_csmsc +# fi + +# # inference with onnxruntime +# if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then +# ./local/ort_predict.sh ${train_output_path} +# fi + +# # not ready yet for operator missing in Paddle-Lite +# # must run after stage 3 (which stage generated static models) +# if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then +# # NOTE by yuantian 2022.11.21: please compile develop version of Paddle-Lite to export and run TTS models, +# # cause TTS models are supported by https://github.com/PaddlePaddle/Paddle-Lite/pull/9587 +# # and https://github.com/PaddlePaddle/Paddle-Lite/pull/9706 +# ./local/export2lite.sh ${train_output_path} inference pdlite vits_csmsc x86 +# fi + +# if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then +# CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1 +# fi + diff --git a/paddlespeech/t2s/models/vits/generator.py b/paddlespeech/t2s/models/vits/generator.py index 7ecc5161..fbd2d665 100644 --- a/paddlespeech/t2s/models/vits/generator.py +++ b/paddlespeech/t2s/models/vits/generator.py @@ -279,6 +279,10 @@ class VITSGenerator(nn.Layer): from paddlespeech.t2s.models.vits.monotonic_align import maximum_path self.maximum_path = maximum_path + self.pad1d = nn.Pad1D( + padding=[1, 0], + mode='constant', + data_format='NLC', ) def forward( self, @@ -685,5 +689,6 @@ class VITSGenerator(nn.Layer): ''' path = paddle.cast(path, dtype='float32') - path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1] + pad_tmp = self.pad1d(path)[:, :-1] + path = path - pad_tmp return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask diff --git a/paddlespeech/t2s/models/vits/transform.py b/paddlespeech/t2s/models/vits/transform.py index ea333dcf..61bd5ee2 100644 --- a/paddlespeech/t2s/models/vits/transform.py +++ b/paddlespeech/t2s/models/vits/transform.py @@ -18,6 +18,7 @@ This code is based on https://github.com/bayesiains/nflows. """ import numpy as np import paddle +from paddle import nn from paddle.nn import functional as F from paddlespeech.t2s.modules.nets_utils import paddle_gather @@ -87,9 +88,9 @@ def unconstrained_rational_quadratic_spline( outputs = paddle.zeros(inputs.shape) logabsdet = paddle.zeros(inputs.shape) if tails == "linear": - unnormalized_derivatives = F.pad( - unnormalized_derivatives, - pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1]) + # 注意 padding 的参数顺序 + pad2d = nn.Pad2D(padding=[1, 1, 0, 0], mode='constant') + unnormalized_derivatives = pad2d(unnormalized_derivatives) constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., -1] = constant @@ -142,6 +143,10 @@ def rational_quadratic_spline( # for dygraph to static # if paddle.min(inputs) < left or paddle.max(inputs) > right: # raise ValueError("Input to a transform is not within its domain") + pad1d = nn.Pad1D( + padding=[1, 0], + mode='constant', + data_format='NCL', ) num_bins = unnormalized_widths.shape[-1] # for dygraph to static @@ -153,11 +158,8 @@ def rational_quadratic_spline( widths = F.softmax(unnormalized_widths, axis=-1) widths = min_bin_width + (1 - min_bin_width * num_bins) * widths cumwidths = paddle.cumsum(widths, axis=-1) - cumwidths = F.pad( - cumwidths, - pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0], - mode="constant", - value=0.0) + + cumwidths = pad1d(cumwidths.unsqueeze(0)).squeeze() cumwidths = (right - left) * cumwidths + left cumwidths[..., 0] = left cumwidths[..., -1] = right @@ -168,11 +170,7 @@ def rational_quadratic_spline( heights = F.softmax(unnormalized_heights, axis=-1) heights = min_bin_height + (1 - min_bin_height * num_bins) * heights cumheights = paddle.cumsum(heights, axis=-1) - cumheights = F.pad( - cumheights, - pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0], - mode="constant", - value=0.0) + cumheights = pad1d(cumheights.unsqueeze(0)).squeeze() cumheights = (top - bottom) * cumheights + bottom cumheights[..., 0] = bottom cumheights[..., -1] = top