[TTS]remove pad op in static model by replace F.pad with nn.Pad1D and nn.Pad2D (#3002)

* remove pad op in static model by replace F.pad with nn.Pad1D and nn.Pad2D

* fix variable names

* add note
pull/3008/head
TianYuan 1 year ago committed by GitHub
parent 864f9a1949
commit 528ae58a67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1 @@
../../tts3/local/paddle2onnx.sh

@ -39,3 +39,34 @@ fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
CUDA_VISIBLE_DEVICES=${gpus} ./local/inference.sh ${train_output_path} ${add_blank}|| exit -1
fi
# # not ready yet for operator missing in Paddle2ONNX
# # paddle2onnx, please make sure the static models are in ${train_output_path}/inference first
# # we have only tested the following models so far
# if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# # install paddle2onnx
# version=$(echo `pip list |grep "paddle2onnx"` |awk -F" " '{print $2}')
# if [[ -z "$version" || ${version} != '1.0.0' ]]; then
# pip install paddle2onnx==1.0.0
# fi
# ./local/paddle2onnx.sh ${train_output_path} inference inference_onnx vits_csmsc
# fi
# # inference with onnxruntime
# if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# ./local/ort_predict.sh ${train_output_path}
# fi
# # not ready yet for operator missing in Paddle-Lite
# # must run after stage 3 (which stage generated static models)
# if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
# # NOTE by yuantian 2022.11.21: please compile develop version of Paddle-Lite to export and run TTS models,
# # cause TTS models are supported by https://github.com/PaddlePaddle/Paddle-Lite/pull/9587
# # and https://github.com/PaddlePaddle/Paddle-Lite/pull/9706
# ./local/export2lite.sh ${train_output_path} inference pdlite vits_csmsc x86
# fi
# if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
# CUDA_VISIBLE_DEVICES=${gpus} ./local/lite_predict.sh ${train_output_path} || exit -1
# fi

@ -279,6 +279,10 @@ class VITSGenerator(nn.Layer):
from paddlespeech.t2s.models.vits.monotonic_align import maximum_path
self.maximum_path = maximum_path
self.pad1d = nn.Pad1D(
padding=[1, 0],
mode='constant',
data_format='NLC', )
def forward(
self,
@ -685,5 +689,6 @@ class VITSGenerator(nn.Layer):
'''
path = paddle.cast(path, dtype='float32')
path = path - F.pad(path, [0, 0, 1, 0, 0, 0])[:, :-1]
pad_tmp = self.pad1d(path)[:, :-1]
path = path - pad_tmp
return path.unsqueeze(1).transpose([0, 1, 3, 2]) * mask

@ -18,6 +18,7 @@ This code is based on https://github.com/bayesiains/nflows.
"""
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
from paddlespeech.t2s.modules.nets_utils import paddle_gather
@ -87,9 +88,9 @@ def unconstrained_rational_quadratic_spline(
outputs = paddle.zeros(inputs.shape)
logabsdet = paddle.zeros(inputs.shape)
if tails == "linear":
unnormalized_derivatives = F.pad(
unnormalized_derivatives,
pad=[0] * (len(unnormalized_derivatives.shape) - 1) * 2 + [1, 1])
# 注意 padding 的参数顺序
pad2d = nn.Pad2D(padding=[1, 1, 0, 0], mode='constant')
unnormalized_derivatives = pad2d(unnormalized_derivatives)
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
@ -142,6 +143,10 @@ def rational_quadratic_spline(
# for dygraph to static
# if paddle.min(inputs) < left or paddle.max(inputs) > right:
# raise ValueError("Input to a transform is not within its domain")
pad1d = nn.Pad1D(
padding=[1, 0],
mode='constant',
data_format='NCL', )
num_bins = unnormalized_widths.shape[-1]
# for dygraph to static
@ -153,11 +158,8 @@ def rational_quadratic_spline(
widths = F.softmax(unnormalized_widths, axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
cumwidths = paddle.cumsum(widths, axis=-1)
cumwidths = F.pad(
cumwidths,
pad=[0] * (len(cumwidths.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumwidths = pad1d(cumwidths.unsqueeze(0)).squeeze()
cumwidths = (right - left) * cumwidths + left
cumwidths[..., 0] = left
cumwidths[..., -1] = right
@ -168,11 +170,7 @@ def rational_quadratic_spline(
heights = F.softmax(unnormalized_heights, axis=-1)
heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
cumheights = paddle.cumsum(heights, axis=-1)
cumheights = F.pad(
cumheights,
pad=[0] * (len(cumheights.shape) - 1) * 2 + [1, 0],
mode="constant",
value=0.0)
cumheights = pad1d(cumheights.unsqueeze(0)).squeeze()
cumheights = (top - bottom) * cumheights + bottom
cumheights[..., 0] = bottom
cumheights[..., -1] = top

Loading…
Cancel
Save