|
|
|
@ -82,8 +82,10 @@ def unconstrained_rational_quadratic_spline(
|
|
|
|
|
inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
|
|
|
|
|
outside_interval_mask = ~inside_interval_mask
|
|
|
|
|
|
|
|
|
|
outputs = paddle.zeros(paddle.shape(inputs))
|
|
|
|
|
logabsdet = paddle.zeros(paddle.shape(inputs))
|
|
|
|
|
# outputs = paddle.zeros(paddle.shape(inputs))
|
|
|
|
|
# logabsdet = paddle.zeros(paddle.shape(inputs))
|
|
|
|
|
outputs = paddle.zeros(inputs.shape)
|
|
|
|
|
logabsdet = paddle.zeros(inputs.shape)
|
|
|
|
|
if tails == "linear":
|
|
|
|
|
unnormalized_derivatives = F.pad(
|
|
|
|
|
unnormalized_derivatives,
|
|
|
|
@ -140,15 +142,16 @@ def rational_quadratic_spline(
|
|
|
|
|
min_bin_width=1e-3,
|
|
|
|
|
min_bin_height=1e-3,
|
|
|
|
|
min_derivative=1e-3, ):
|
|
|
|
|
if paddle.min(inputs) < left or paddle.max(inputs) > right:
|
|
|
|
|
raise ValueError("Input to a transform is not within its domain")
|
|
|
|
|
|
|
|
|
|
# if paddle.min(inputs) < left or paddle.max(inputs) > right:
|
|
|
|
|
# raise ValueError("Input to a transform is not within its domain")
|
|
|
|
|
|
|
|
|
|
num_bins = unnormalized_widths.shape[-1]
|
|
|
|
|
|
|
|
|
|
if min_bin_width * num_bins > 1.0:
|
|
|
|
|
raise ValueError("Minimal bin width too large for the number of bins")
|
|
|
|
|
if min_bin_height * num_bins > 1.0:
|
|
|
|
|
raise ValueError("Minimal bin height too large for the number of bins")
|
|
|
|
|
# if min_bin_width * num_bins > 1.0:
|
|
|
|
|
# raise ValueError("Minimal bin width too large for the number of bins")
|
|
|
|
|
# if min_bin_height * num_bins > 1.0:
|
|
|
|
|
# raise ValueError("Minimal bin height too large for the number of bins")
|
|
|
|
|
|
|
|
|
|
widths = F.softmax(unnormalized_widths, axis=-1)
|
|
|
|
|
widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
|
|
|
|
|