Skip to content

Add RationalQuadraticSpline#71

Open
gvcallen wants to merge 1 commit into
lockwo:mainfrom
gvcallen:rqs
Open

Add RationalQuadraticSpline#71
gvcallen wants to merge 1 commit into
lockwo:mainfrom
gvcallen:rqs

Conversation

@gvcallen
Copy link
Copy Markdown
Contributor

@gvcallen gvcallen commented Apr 4, 2026

Adds the rational quadratic spline bijector, very useful for complex normalizing flows. Reference distrax code is here

@gvcallen gvcallen changed the title Add RationalQuadraticSpline bijector Add RationalQuadraticSpline Apr 4, 2026
bin_slope = bin_height / bin_width

z = (x - x_pos_bin[0]) / bin_width
z = jnp.clip(z, 0.0, 1.0)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function has some good comments in the distrax codebase, we should copy them over

"""A rational-quadratic spline bijector.

Implements the spline bijector introduced by:
> Durkan et al., Neural Spline Flows, https://arxiv.org/abs/1906.04032, 2019.
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can do the special bib notation here for mkdocs (e.g. ??? cite "References")

):
"""Initializes a RationalQuadraticSpline bijector.

Args:
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should match docstring style (e.g. **Arguments**)

unnormalized_bin_heights = params[
..., self.num_bins : 2 * self.num_bins # noqa: E203
]
unnormalized_knot_slopes = params[..., 2 * self.num_bins :] # noqa: E203
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I fixed the black/flake interaction for these, does it error if you remove the # noqa: E203 ?

from distreqx.bijectors import RationalQuadraticSpline


class RationalQuadraticSplineTest(TestCase):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a few more low hanging fruit from the distrax tests, e.g. https://github.com/google-deepmind/distrax/blob/main/distrax/_src/bijectors/rational_quadratic_spline_test.py#L153 that are worth porting over

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants