pull/2883/head
TianYuan 3 years ago
parent d46ca0866a
commit 90e135fa7a

@ -71,11 +71,6 @@ def evaluate(args):
vits_inference = VITSInference(vits) vits_inference = VITSInference(vits)
# whether dygraph to static # whether dygraph to static
if args.inference_dir: if args.inference_dir:
# acoustic model
# vits = jit.to_static(
# vits, input_spec=[InputSpec([-1], dtype=paddle.int64)])
# jit.save(vits, os.path.join(inference_dir, args.am))
# vits = jit.load(os.path.join(inference_dir, args.am))
vits_inference = am_to_static( vits_inference = am_to_static(
am_inference=vits_inference, am_inference=vits_inference,
am=args.am, am=args.am,
@ -108,8 +103,8 @@ def evaluate(args):
for i in range(len(phone_ids)): for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i] part_phone_ids = phone_ids[i]
spk_id = None spk_id = None
if am_dataset in {"aishell3", "vctk" if am_dataset in {"aishell3",
} and spk_num is not None: "vctk"} and spk_num is not None:
spk_id = paddle.to_tensor(args.spk_id) spk_id = paddle.to_tensor(args.spk_id)
wav = vits_inference(part_phone_ids, spk_id) wav = vits_inference(part_phone_ids, spk_id)
else: else:

@ -81,9 +81,9 @@ def unconstrained_rational_quadratic_spline(
min_derivative=1e-3, ): min_derivative=1e-3, ):
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
outside_interval_mask = ~inside_interval_mask outside_interval_mask = ~inside_interval_mask
# for dygraph to static
# outputs = paddle.zeros(paddle.shape(inputs)) # 这里用 paddle.shape(x) 然后调用 zeros 会得到一个全 -1 shape 的 var
# logabsdet = paddle.zeros(paddle.shape(inputs)) # 如果用 x.shape 的话可以保留确定的维度
outputs = paddle.zeros(inputs.shape) outputs = paddle.zeros(inputs.shape)
logabsdet = paddle.zeros(inputs.shape) logabsdet = paddle.zeros(inputs.shape)
if tails == "linear": if tails == "linear":
@ -93,12 +93,9 @@ def unconstrained_rational_quadratic_spline(
constant = np.log(np.exp(1 - min_derivative) - 1) constant = np.log(np.exp(1 - min_derivative) - 1)
unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., 0] = constant
unnormalized_derivatives[..., -1] = constant unnormalized_derivatives[..., -1] = constant
# import pdb # for dygraph to static
# pdb.set_trace() tmp = inputs[outside_interval_mask]
# print("inputs:",inputs) outputs[outside_interval_mask] = tmp
# print("outside_interval_mask:",outside_interval_mask)
a = inputs[outside_interval_mask]
outputs[outside_interval_mask] = a
logabsdet[outside_interval_mask] = 0 logabsdet[outside_interval_mask] = 0
else: else:
raise RuntimeError("{} tails are not implemented.".format(tails)) raise RuntimeError("{} tails are not implemented.".format(tails))
@ -142,12 +139,12 @@ def rational_quadratic_spline(
min_bin_width=1e-3, min_bin_width=1e-3,
min_bin_height=1e-3, min_bin_height=1e-3,
min_derivative=1e-3, ): min_derivative=1e-3, ):
# for dygraph to static
# if paddle.min(inputs) < left or paddle.max(inputs) > right: # if paddle.min(inputs) < left or paddle.max(inputs) > right:
# raise ValueError("Input to a transform is not within its domain") # raise ValueError("Input to a transform is not within its domain")
num_bins = unnormalized_widths.shape[-1] num_bins = unnormalized_widths.shape[-1]
# for dygraph to static
# if min_bin_width * num_bins > 1.0: # if min_bin_width * num_bins > 1.0:
# raise ValueError("Minimal bin width too large for the number of bins") # raise ValueError("Minimal bin width too large for the number of bins")
# if min_bin_height * num_bins > 1.0: # if min_bin_height * num_bins > 1.0:

Loading…
Cancel
Save