Support dy2st for VITS

pull/2883/head
0x45f 3 years ago
parent 71eabceedd
commit 4c30cd6eb1

@ -64,7 +64,6 @@ def evaluate(args):
vits = VITS(idim=vocab_size, odim=odim, **config["model"]) vits = VITS(idim=vocab_size, odim=odim, **config["model"])
vits.set_state_dict(paddle.load(args.ckpt)["main_params"]) vits.set_state_dict(paddle.load(args.ckpt)["main_params"])
vits.eval() vits.eval()
VITSInference
vits_inference = VITSInference(vits) vits_inference = VITSInference(vits)
# whether dygraph to static # whether dygraph to static
@ -108,7 +107,8 @@ def evaluate(args):
spk_id = None spk_id = None
if spk_num is not None: if spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id) spk_id = paddle.to_tensor(args.spk_id)
wav = vits_inference(text=part_phone_ids, sids=spk_id) # wav = vits_inference(text=part_phone_ids, sids=spk_id)
wav = vits_inference(part_phone_ids)
if flags == 0: if flags == 0:
wav_all = wav wav_all = wav
flags = 1 flags = 1

@ -82,8 +82,10 @@ def unconstrained_rational_quadratic_spline(
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask outside_interval_mask = ~inside_interval_mask
outputs = paddle.zeros(paddle.shape(inputs)) # outputs = paddle.zeros(paddle.shape(inputs))
logabsdet = paddle.zeros(paddle.shape(inputs)) # logabsdet = paddle.zeros(paddle.shape(inputs))
outputs = paddle.zeros(inputs.shape)
logabsdet = paddle.zeros(inputs.shape)
if tails == "linear": if tails == "linear":
unnormalized_derivatives = F.pad( unnormalized_derivatives = F.pad(
unnormalized_derivatives, unnormalized_derivatives,
@ -140,15 +142,16 @@ def rational_quadratic_spline(
min_bin_width=1e-3, min_bin_width=1e-3,
min_bin_height=1e-3, min_bin_height=1e-3,
min_derivative=1e-3, ): min_derivative=1e-3, ):
if paddle.min(inputs) < left or paddle.max(inputs) > right:
raise ValueError("Input to a transform is not within its domain") # if paddle.min(inputs) < left or paddle.max(inputs) > right:
# raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1] num_bins = unnormalized_widths.shape[-1]
if min_bin_width * num_bins > 1.0: # if min_bin_width * num_bins > 1.0:
raise ValueError("Minimal bin width too large for the number of bins") # raise ValueError("Minimal bin width too large for the number of bins")
if min_bin_height * num_bins > 1.0: # if min_bin_height * num_bins > 1.0:
raise ValueError("Minimal bin height too large for the number of bins") # raise ValueError("Minimal bin height too large for the number of bins")
widths = F.softmax(unnormalized_widths, axis=-1) widths = F.softmax(unnormalized_widths, axis=-1)
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths widths = min_bin_width + (1 - min_bin_width * num_bins) * widths

Loading…
Cancel
Save