pull/2883/head
TianYuan 3 years ago
parent d46ca0866a
commit 90e135fa7a

@ -71,11 +71,6 @@ def evaluate(args):
vits_inference = VITSInference(vits)
# whether dygraph to static
if args.inference_dir:
# acoustic model
# vits = jit.to_static(
# vits, input_spec=[InputSpec([-1], dtype=paddle.int64)])
# jit.save(vits, os.path.join(inference_dir, args.am))
# vits = jit.load(os.path.join(inference_dir, args.am))
vits_inference = am_to_static(
am_inference=vits_inference,
am=args.am,
@ -108,8 +103,8 @@ def evaluate(args):
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
spk_id = None
if am_dataset in {"aishell3", "vctk"
} and spk_num is not None:
if am_dataset in {"aishell3",
"vctk"} and spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id)
wav = vits_inference(part_phone_ids, spk_id)
else:

@ -81,9 +81,9 @@ def unconstrained_rational_quadratic_spline(
min_derivative=1e-3, ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask
# outputs = paddle.zeros(paddle.shape(inputs))
# logabsdet = paddle.zeros(paddle.shape(inputs))
# for dygraph to static
# 这里用 paddle.shape(x) 然后调用 zeros 会得到一个全 -1 shape 的 var
# 如果用 x.shape 的话可以保留确定的维度
outputs = paddle.zeros(inputs.shape)
logabsdet = paddle.zeros(inputs.shape)
if tails == "linear":
@ -93,12 +93,9 @@ def unconstrained_rational_quadratic_spline(
constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant
# import pdb
# pdb.set_trace()
# print("inputs:",inputs)
# print("outside_interval_mask:",outside_interval_mask)
a = inputs[outside_interval_mask]
outputs[outside_interval_mask] = a
# for dygraph to static
tmp = inputs[outside_interval_mask]
outputs[outside_interval_mask] = tmp
logabsdet[outside_interval_mask] = 0
else:
raise RuntimeError("{} tails are not implemented.".format(tails))
@ -142,12 +139,12 @@ def rational_quadratic_spline(
min_bin_width=1e-3,
min_bin_height=1e-3,
min_derivative=1e-3, ):
# for dygraph to static
# if paddle.min(inputs) < left or paddle.max(inputs) > right:
# raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1]
# for dygraph to static
# if min_bin_width * num_bins > 1.0:
# raise ValueError("Minimal bin width too large for the number of bins")
# if min_bin_height * num_bins > 1.0:

Loading…
Cancel
Save