Support dy2st for VITS

pull/2883/head
0x45f 3 years ago
parent 71eabceedd
commit 4c30cd6eb1

@ -64,7 +64,6 @@ def evaluate(args):
vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
vits.eval()
VITSInference
vits_inference = VITSInference(vits)
# whether dygraph to static
@ -108,7 +107,8 @@ def evaluate(args):
spk_id = None
if spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id)
wav = vits_inference(text=part_phone_ids, sids=spk_id)
# wav = vits_inference(text=part_phone_ids, sids=spk_id)
wav = vits_inference(part_phone_ids)
if flags == 0:
wav_all = wav
flags = 1

@ -82,8 +82,10 @@ def unconstrained_rational_quadratic_spline(
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
outputs = paddle.zeros(paddle.shape(inputs))
logabsdet = paddle.zeros(paddle.shape(inputs))
# outputs = paddle.zeros(paddle.shape(inputs))
# logabsdet = paddle.zeros(paddle.shape(inputs))
outputs = paddle.zeros(inputs.shape)
logabsdet = paddle.zeros(inputs.shape)
if tails == "linear":
unnormalized_derivatives = F.pad(
unnormalized_derivatives,
@ -140,15 +142,16 @@ def rational_quadratic_spline(
min_bin_width=1e-3,
min_bin_height=1e-3,
min_derivative=1e-3, ):
if paddle.min(inputs) < left or paddle.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain")
# if paddle.min(inputs) < left or paddle.max(inputs) > right:
# raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins")
# if min_bin_width * num_bins > 1.0:
# raise ValueError("Minimal bin width too large for the number of bins")
# if min_bin_height * num_bins > 1.0:
# raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths

Loading…
Cancel
Save