Add reinterpreted_batch_ndims to Independent#75
Conversation
| - `reinterpreted_batch_ndims`: Number of batch dimensions to reinterpret | ||
| as event dimensions. Defaults to 0, which preserves standard broadcasting | ||
| behavior for natively batched distributions (e.g., `Normal`). | ||
| **Note:** If you are passing a distribution that does not natively broadcast |
There was a problem hiding this comment.
Does this note render correctly? I think we can make nice note interface. Here is an AI summary of the interfaces
Solid block (!!!) — always visible:
!!! note
Body must be indented 4 spaces and separated from the
`!!!` line by a blank line.
Collapsible block (???) — collapsed by default; user clicks to open:
??? note
Same indentation rules.
Open-by-default collapsible (???+):
???+ warning
Starts expanded but is still collapsible.
Custom title — quoted string after the type:
!!! warning "Bijectors are applied in reverse order"
Given a sequence `[f, g]`, the `Chain` bijector computes `f(g(x))`...
| samples, log_prob = self.distribution.sample_and_log_prob(key) | ||
| log_prob = _reduce_helper(log_prob) | ||
| return samples, log_prob | ||
| if self.reinterpreted_batch_ndims == 0: |
There was a problem hiding this comment.
I'm not sure we should have a specific flag that goes to the reduce helper, these means it should reduce over all dimensions basically, should it be instead a required argument (that then becomes all axis rather than just 0)? Just to simplify code?
| total_batches = math.prod(bshape) | ||
| keys = jax.random.split(key, total_batches).reshape(*bshape) | ||
|
|
||
| def _single_sample_and_log_prob(d, k): |
There was a problem hiding this comment.
this patterns comes up several times, I think if we make a vmap helper it could reduce LoC? e.g.
def _vmap_method(self, fn,):
for _ in range(self.reinterpreted_batch_ndims):
fn = eqx.filter_vmap(fn)
return fn(self.distribution,)
def _vmap_and_sum(self, fn,):
out = self._vmap_method(fn,)
return jnp.sum(out, axis=tuple(range(self.reinterpreted_batch_ndims)))
then each of these can just call vmap and sum
| # on the raw, un-wrapped vmapped inner distributions. | ||
| d1_rndims = dist1.reinterpreted_batch_ndims | ||
| d2_rndims = dist2.reinterpreted_batch_ndims | ||
| p_base_shape = dist1.event_shape[d1_rndims:] # fmt: skip |
| # Safely extract the base event shapes without triggering unsafe traces | ||
| # on the raw, un-wrapped vmapped inner distributions. | ||
| d1_rndims = dist1.reinterpreted_batch_ndims | ||
| d2_rndims = dist2.reinterpreted_batch_ndims |
There was a problem hiding this comment.
if these don't match this has an asymmetric issue? should we enforce this eto be the same
| ) | ||
| return batch_mvn, locs, scales_tri | ||
|
|
||
| # --- 1. Basic & Legacy Tests --- |
| self.assertIsInstance(model, Independent) | ||
|
|
||
| def test_legacy_broadcasting_behavior(self): | ||
| """Tests the reinterpreted_batch_ndims=0 fallback for standard distributions.""" |
There was a problem hiding this comment.
personally, I don't think we need to have much backwards compatible, we can break things if the new API is better.
| def assertion_fn(self, rtol=1e-5): | ||
| return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol) | ||
|
|
||
| def _create_vmapped_mvn(self, M=20, N=10, D=3): |
There was a problem hiding this comment.
we should check across multiple distributions, e.g. bernoulli/normal/etc to make sure nothing unexpected happens
Adds
reinterpreted_batch_ndimstoIndependent, to support distributions that have been batched via eqx.filter_vmap (e.g.,MultivariateNormalTri) that don't support handle batched arrays natively.Although distreqx has expressly dropped the concept of batch dimensions throughout the codebase (which I strongly agree was the right choice), I do believe
Independentstill requires areinterpreted_batch_ndimsoptions, with motivation in the example below.Without this PR, passing a vmapped distribution that doesn't natively support broadcasing to
Independentcrashes on methods likelog_prob, because the inner computations weren't being appropriately mapped. This change allows users to explicitly define the number of mapped axes, lettingIndependentcorrectly unroll the vmap layers and reduce the results to a single scalar. The previous behaviour is left unaffected with the default ofreinterpreted_batch_ndims = 0.