Add RationalQuadraticSpline#71
Open
gvcallen wants to merge 1 commit into
Open
Conversation
lockwo
reviewed
May 6, 2026
| bin_slope = bin_height / bin_width | ||
|
|
||
| z = (x - x_pos_bin[0]) / bin_width | ||
| z = jnp.clip(z, 0.0, 1.0) |
Owner
There was a problem hiding this comment.
this function has some good comments in the distrax codebase, we should copy them over
| """A rational-quadratic spline bijector. | ||
|
|
||
| Implements the spline bijector introduced by: | ||
| > Durkan et al., Neural Spline Flows, https://arxiv.org/abs/1906.04032, 2019. |
Owner
There was a problem hiding this comment.
we can do the special bib notation here for mkdocs (e.g. ??? cite "References")
| ): | ||
| """Initializes a RationalQuadraticSpline bijector. | ||
|
|
||
| Args: |
Owner
There was a problem hiding this comment.
should match docstring style (e.g. **Arguments**)
| unnormalized_bin_heights = params[ | ||
| ..., self.num_bins : 2 * self.num_bins # noqa: E203 | ||
| ] | ||
| unnormalized_knot_slopes = params[..., 2 * self.num_bins :] # noqa: E203 |
Owner
There was a problem hiding this comment.
I thought I fixed the black/flake interaction for these, does it error if you remove the # noqa: E203 ?
| from distreqx.bijectors import RationalQuadraticSpline | ||
|
|
||
|
|
||
| class RationalQuadraticSplineTest(TestCase): |
Owner
There was a problem hiding this comment.
There's a few more low hanging fruit from the distrax tests, e.g. https://github.com/google-deepmind/distrax/blob/main/distrax/_src/bijectors/rational_quadratic_spline_test.py#L153 that are worth porting over
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds the rational quadratic spline bijector, very useful for complex normalizing flows. Reference distrax code is here