diff --git a/paddlespeech/t2s/exps/vits/synthesize_e2e.py b/paddlespeech/t2s/exps/vits/synthesize_e2e.py index dac459a5b..9768a16ef 100644 --- a/paddlespeech/t2s/exps/vits/synthesize_e2e.py +++ b/paddlespeech/t2s/exps/vits/synthesize_e2e.py @@ -71,11 +71,6 @@ def evaluate(args): vits_inference = VITSInference(vits) # whether dygraph to static if args.inference_dir: - # acoustic model - # vits = jit.to_static( - # vits, input_spec=[InputSpec([-1], dtype=paddle.int64)]) - # jit.save(vits, os.path.join(inference_dir, args.am)) - # vits = jit.load(os.path.join(inference_dir, args.am)) vits_inference = am_to_static( am_inference=vits_inference, am=args.am, @@ -108,8 +103,8 @@ def evaluate(args): for i in range(len(phone_ids)): part_phone_ids = phone_ids[i] spk_id = None - if am_dataset in {"aishell3", "vctk" - } and spk_num is not None: + if am_dataset in {"aishell3", + "vctk"} and spk_num is not None: spk_id = paddle.to_tensor(args.spk_id) wav = vits_inference(part_phone_ids, spk_id) else: diff --git a/paddlespeech/t2s/models/vits/transform.py b/paddlespeech/t2s/models/vits/transform.py index 360453cca..ea333dcff 100644 --- a/paddlespeech/t2s/models/vits/transform.py +++ b/paddlespeech/t2s/models/vits/transform.py @@ -81,9 +81,9 @@ def unconstrained_rational_quadratic_spline( min_derivative=1e-3, ): inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) outside_interval_mask = ~inside_interval_mask - - # outputs = paddle.zeros(paddle.shape(inputs)) - # logabsdet = paddle.zeros(paddle.shape(inputs)) + # for dygraph to static + # 这里用 paddle.shape(x) 然后调用 zeros 会得到一个全 -1 shape 的 var + # 如果用 x.shape 的话可以保留确定的维度 outputs = paddle.zeros(inputs.shape) logabsdet = paddle.zeros(inputs.shape) if tails == "linear": @@ -93,12 +93,9 @@ def unconstrained_rational_quadratic_spline( constant = np.log(np.exp(1 - min_derivative) - 1) unnormalized_derivatives[..., 0] = constant unnormalized_derivatives[..., -1] = constant - # import pdb - # pdb.set_trace() - # print("inputs:",inputs) - # print("outside_interval_mask:",outside_interval_mask) - a = inputs[outside_interval_mask] - outputs[outside_interval_mask] = a + # for dygraph to static + tmp = inputs[outside_interval_mask] + outputs[outside_interval_mask] = tmp logabsdet[outside_interval_mask] = 0 else: raise RuntimeError("{} tails are not implemented.".format(tails)) @@ -142,12 +139,12 @@ def rational_quadratic_spline( min_bin_width=1e-3, min_bin_height=1e-3, min_derivative=1e-3, ): - + # for dygraph to static # if paddle.min(inputs) < left or paddle.max(inputs) > right: # raise ValueError("Input to a transform is not within its domain") num_bins = unnormalized_widths.shape[-1] - + # for dygraph to static # if min_bin_width * num_bins > 1.0: # raise ValueError("Minimal bin width too large for the number of bins") # if min_bin_height * num_bins > 1.0: