[TTS]vits dygraph to static (#2883)

Co-authored-by: 0x45f <wangzhen45@baidu.com>
Committed by TianYuan 1 year ago via GitHub
parent 11bc392617
commit 84f751f529

@@ -18,5 +18,6 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
         --phones_dict=dump/phone_id_map.txt \
         --output_dir=${train_output_path}/test_e2e \
         --text=${BIN_DIR}/../sentences.txt \
-        --add-blank=${add_blank}
+        --add-blank=${add_blank} #\
+        # --inference_dir=${train_output_path}/inference
 fi

@@ -452,9 +452,19 @@ def am_to_static(am_inference,
     elif am_name == 'tacotron2':
         am_inference = jit.to_static(
             am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
 
-    paddle.jit.save(am_inference, os.path.join(inference_dir, am))
-    am_inference = paddle.jit.load(os.path.join(inference_dir, am))
+    elif am_name == 'vits':
+        if am_dataset in {"aishell3", "vctk"} and speaker_dict is not None:
+            am_inference = jit.to_static(
+                am_inference,
+                input_spec=[
+                    InputSpec([-1], dtype=paddle.int64),
+                    InputSpec([1], dtype=paddle.int64),
+                ])
+        else:
+            am_inference = jit.to_static(
+                am_inference, input_spec=[InputSpec([-1], dtype=paddle.int64)])
+
+    jit.save(am_inference, os.path.join(inference_dir, am))
+    am_inference = jit.load(os.path.join(inference_dir, am))
     return am_inference
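Once exported, the static program can be reloaded without any of the dygraph model code. A minimal sketch, assuming the export above wrote the model under `inference/vits_csmsc` (the path and phone-id range are illustrative):

    import paddle

    # paddle.jit.load restores the program saved by jit.save above
    am = paddle.jit.load("inference/vits_csmsc")
    phone_ids = paddle.randint(0, 100, shape=[32], dtype='int64')  # dummy ids
    wav = am(phone_ids)  # matches InputSpec([-1], dtype=paddle.int64)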
@@ -465,8 +475,8 @@ def voc_to_static(voc_inference,
         voc_inference, input_spec=[
             InputSpec([-1, 80], dtype=paddle.float32),
         ])
-    paddle.jit.save(voc_inference, os.path.join(inference_dir, voc))
-    voc_inference = paddle.jit.load(os.path.join(inference_dir, voc))
+    jit.save(voc_inference, os.path.join(inference_dir, voc))
+    voc_inference = jit.load(os.path.join(inference_dir, voc))
     return voc_inference
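The vocoder takes a [-1, 80] float32 mel spectrogram, so the exported graph accepts any number of frames. A hedged usage sketch (the vocoder name and path are assumptions):

    import paddle

    voc = paddle.jit.load("inference/hifigan_csmsc")  # illustrative path
    mel = paddle.rand([100, 80], dtype='float32')     # [n_frames, 80]
    wav = voc(mel)                                    # matches InputSpec([-1, 80])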

@@ -20,14 +20,15 @@ import yaml
 from timer import timer
 from yacs.config import CfgNode
 
+from paddlespeech.t2s.exps.syn_utils import am_to_static
 from paddlespeech.t2s.exps.syn_utils import get_frontend
 from paddlespeech.t2s.exps.syn_utils import get_sentences
 from paddlespeech.t2s.models.vits import VITS
+from paddlespeech.t2s.models.vits import VITSInference
 from paddlespeech.t2s.utils import str2bool
 
 
 def evaluate(args):
     # Init body.
     with open(args.config) as f:
         config = CfgNode(yaml.safe_load(f))
@@ -41,6 +42,9 @@ def evaluate(args):
     # frontend
     frontend = get_frontend(lang=args.lang, phones_dict=args.phones_dict)
 
+    # acoustic model
+    am_name = args.am[:args.am.rindex('_')]
+    am_dataset = args.am[args.am.rindex('_') + 1:]
 
     spk_num = None
     if args.speaker_dict is not None:
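The `--am` value packs the model name and training corpus into one underscore-joined string; `rindex` splits on the last underscore:

    am = 'vits_aishell3'                  # e.g. the value of --am
    am_name = am[:am.rindex('_')]         # 'vits'
    am_dataset = am[am.rindex('_') + 1:]  # 'aishell3'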
@@ -64,6 +68,15 @@ def evaluate(args):
     vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
     vits.eval()
+    vits_inference = VITSInference(vits)
+
+    # whether to convert dygraph to static graph
+    if args.inference_dir:
+        vits_inference = am_to_static(
+            am_inference=vits_inference,
+            am=args.am,
+            inference_dir=args.inference_dir,
+            speaker_dict=args.speaker_dict)
 
     output_dir = Path(args.output_dir)
     output_dir.mkdir(parents=True, exist_ok=True)
     merge_sentences = False
@@ -90,10 +103,12 @@
             for i in range(len(phone_ids)):
                 part_phone_ids = phone_ids[i]
                 spk_id = None
-                if spk_num is not None:
+                if am_dataset in {"aishell3",
+                                  "vctk"} and spk_num is not None:
                     spk_id = paddle.to_tensor(args.spk_id)
-                out = vits.inference(text=part_phone_ids, sids=spk_id)
-                wav = out["wav"]
+                    wav = vits_inference(part_phone_ids, spk_id)
+                else:
+                    wav = vits_inference(part_phone_ids)
                 if flags == 0:
                     wav_all = wav
                     flags = 1
@@ -155,6 +170,11 @@ def parse_args():
         type=str2bool,
         default=True,
         help="whether to add blank between phones")
+    parser.add_argument(
+        '--am',
+        type=str,
+        default='vits_csmsc',
+        help='Choose acoustic model type of tts task.')
 
     args = parser.parse_args()
     return args

@@ -35,9 +35,10 @@ def piecewise_rational_quadratic_transform(
         inverse=False,
         tails=None,
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     if tails is None:
         spline_fn = rational_quadratic_spline
         spline_kwargs = {}
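The 1e-3 literals inline the module-level defaults the signature used to reference; since the commit flags this as a dygraph-to-static fix rather than a behavior change, the old constants presumably held the same values:

    # former module-level defaults, inlined above (values inferred)
    DEFAULT_MIN_BIN_WIDTH = 1e-3
    DEFAULT_MIN_BIN_HEIGHT = 1e-3
    DEFAULT_MIN_DERIVATIVE = 1e-3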
@@ -74,14 +75,17 @@ def unconstrained_rational_quadratic_spline(
         inverse=False,
         tails="linear",
         tail_bound=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
     inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
     outside_interval_mask = ~inside_interval_mask
 
-    outputs = paddle.zeros(paddle.shape(inputs))
-    logabsdet = paddle.zeros(paddle.shape(inputs))
+    # for dygraph to static:
+    # calling zeros() on paddle.shape(x) yields a var whose shape is all -1,
+    # while x.shape keeps the dimensions that are statically known
+    outputs = paddle.zeros(inputs.shape)
+    logabsdet = paddle.zeros(inputs.shape)
 
     if tails == "linear":
         unnormalized_derivatives = F.pad(
             unnormalized_derivatives,
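Per the comment above, the two shape APIs diverge only under tracing; a minimal sketch of their dygraph behavior:

    import paddle

    x = paddle.rand([4, 80])
    print(x.shape)          # plain Python list: [4, 80]
    print(paddle.shape(x))  # Tensor holding [4, 80], evaluated at runtime

    # under paddle.jit.to_static, paddle.shape(x) stays dynamic (all dims -1),
    # while x.shape keeps the statically known dimensions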
@@ -89,8 +93,9 @@ def unconstrained_rational_quadratic_spline(
         constant = np.log(np.exp(1 - min_derivative) - 1)
         unnormalized_derivatives[..., 0] = constant
         unnormalized_derivatives[..., -1] = constant
 
-        outputs[outside_interval_mask] = inputs[outside_interval_mask]
+        # for dygraph to static
+        tmp = inputs[outside_interval_mask]
+        outputs[outside_interval_mask] = tmp
         logabsdet[outside_interval_mask] = 0
     else:
         raise RuntimeError("{} tails are not implemented.".format(tails))
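Splitting the masked copy into a gather followed by a scatter is what the converter can handle; a standalone sketch of the same pattern (assuming boolean-index get/set item, which Paddle 2.x supports):

    import paddle

    inputs = paddle.to_tensor([-2.0, 0.5, 3.0])
    outputs = paddle.zeros(inputs.shape)
    mask = paddle.to_tensor([True, False, True])

    # gather into a temporary first, then scatter it back
    tmp = inputs[mask]
    outputs[mask] = tmp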
@@ -130,18 +135,20 @@ def rational_quadratic_spline(
         right=1.0,
         bottom=0.0,
         top=1.0,
-        min_bin_width=DEFAULT_MIN_BIN_WIDTH,
-        min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
-        min_derivative=DEFAULT_MIN_DERIVATIVE, ):
-    if paddle.min(inputs) < left or paddle.max(inputs) > right:
-        raise ValueError("Input to a transform is not within its domain")
+        # for dygraph-to-static
+        min_bin_width=1e-3,
+        min_bin_height=1e-3,
+        min_derivative=1e-3, ):
+    # for dygraph to static
+    # if paddle.min(inputs) < left or paddle.max(inputs) > right:
+    #     raise ValueError("Input to a transform is not within its domain")
 
     num_bins = unnormalized_widths.shape[-1]
 
-    if min_bin_width * num_bins > 1.0:
-        raise ValueError("Minimal bin width too large for the number of bins")
-    if min_bin_height * num_bins > 1.0:
-        raise ValueError("Minimal bin height too large for the number of bins")
+    # for dygraph to static
+    # if min_bin_width * num_bins > 1.0:
+    #     raise ValueError("Minimal bin width too large for the number of bins")
+    # if min_bin_height * num_bins > 1.0:
+    #     raise ValueError("Minimal bin height too large for the number of bins")
 
     widths = F.softmax(unnormalized_widths, axis=-1)
     widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
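These checks branch on tensor values, which a traced program cannot evaluate, hence the commenting-out. If the validation were worth keeping in eager mode, one alternative (not part of this commit) would be to guard it with `paddle.in_dynamic_mode()`:

    import paddle

    def check_domain(inputs, left=0.0, right=1.0):
        # run the data-dependent check only outside static-graph tracing
        if paddle.in_dynamic_mode():
            if paddle.min(inputs) < left or paddle.max(inputs) > right:
                raise ValueError("Input to a transform is not within its domain")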

@@ -532,3 +532,14 @@ class VITS(nn.Layer):
                 module.weight[module._padding_idx] = 0
 
         self.apply(_reset_parameters)
+
+
+class VITSInference(nn.Layer):
+    def __init__(self, model):
+        super().__init__()
+        self.acoustic_model = model
+
+    def forward(self, text, sids=None):
+        out = self.acoustic_model.inference(text, sids=sids)
+        wav = out['wav']
+        return wav
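`VITSInference` exposes `VITS.inference()` through `forward()`, which is what `paddle.jit.to_static` traces. A minimal export sketch, assuming a trained `vits` model in scope (single-speaker case):

    import paddle
    from paddle.static import InputSpec

    # wrap the dygraph model so forward() returns the waveform tensor directly
    vits_inference = VITSInference(vits)
    static_vits = paddle.jit.to_static(
        vits_inference,
        input_spec=[InputSpec([-1], dtype=paddle.int64)])
    paddle.jit.save(static_vits, "inference/vits_csmsc")  # illustrative path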
