Skip to content

Add reinterpreted_batch_ndims to Independent#75

Open
gvcallen wants to merge 2 commits into
lockwo:mainfrom
gvcallen:multivariate_independent
Open

Add reinterpreted_batch_ndims to Independent#75
gvcallen wants to merge 2 commits into
lockwo:mainfrom
gvcallen:multivariate_independent

Conversation

@gvcallen
Copy link
Copy Markdown
Contributor

@gvcallen gvcallen commented Apr 6, 2026

Adds reinterpreted_batch_ndims to Independent, to support distributions that have been batched via eqx.filter_vmap (e.g., MultivariateNormalTri) that don't support handle batched arrays natively.

Although distreqx has expressly dropped the concept of batch dimensions throughout the codebase (which I strongly agree was the right choice), I do believe Independent still requires a reinterpreted_batch_ndims options, with motivation in the example below.

Without this PR, passing a vmapped distribution that doesn't natively support broadcasing to Independent crashes on methods like log_prob, because the inner computations weren't being appropriately mapped. This change allows users to explicitly define the number of mapped axes, letting Independent correctly unroll the vmap layers and reduce the results to a single scalar. The previous behaviour is left unaffected with the default of reinterpreted_batch_ndims = 0.

locs = jnp.zeros((20, 10, 3))
scales_tri = jnp.stack([jnp.tri(3)]*20*10, axis=0).reshape(20, 10, 3, 3)
xs = jnp.ones((20, 10, 3))

# Create a batch of mvn's
mvns = eqx.filter_vmap(eqx.filter_vmap(dist.MultivariateNormalTri))(locs, scales_tri)

# Calling log_prob (expectedly) does not work because MultivariateNormalTri assumes k x k
# log_prob = mvn.log_prob(locs) # error

# We can manually vmap the computation, but then we get a vector of log-probs.
# We could add a sum, but it is better to encapsulate this in a single distribution
def simple_log_prob(d, x): return d.log_prob(x)
log_probs = eqx.filter_vmap(eqx.filter_vmap(simple_log_prob))(mvns, xs)
assert log_probs.shape == (20, 10,)

# If we attempt to use "Independent" to reinterpret the batch dims, we get a crash,
# because the computation dimensions are not being reinterpreted
reinterp_mvn = dist.Independent(mvns)
# reinterp_mvn.log_prob(xs) # error

# With the propose changes, we can successfully encapsulate the independent MVNs.
# We could also pass e.g. 1 here and manually vmap the last dim if desired.
reinterp_mvn = dist.Independent(mvns, reinterpreted_batch_ndims=2)
log_prob = reinterp_mvn.log_prob(xs)
assert jnp.isscalar(log_prob)

- `reinterpreted_batch_ndims`: Number of batch dimensions to reinterpret
as event dimensions. Defaults to 0, which preserves standard broadcasting
behavior for natively batched distributions (e.g., `Normal`).
**Note:** If you are passing a distribution that does not natively broadcast
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this note render correctly? I think we can make nice note interface. Here is an AI summary of the interfaces

Solid block (!!!) — always visible:                                                                                                
  !!! note        
                                                                                                                                     
      Body must be indented 4 spaces and separated from the
      `!!!` line by a blank line.                                                                                                    
                                 
  Collapsible block (???) — collapsed by default; user clicks to open:                                                               
  ??? note                                                                                                                           
          
      Same indentation rules.                                                                                                        
                                                                                                                                     
  Open-by-default collapsible (???+):
  ???+ warning                                                                                                                       
                  
      Starts expanded but is still collapsible.                                                                                      
                                               
  Custom title — quoted string after the type:                                                                                       
  !!! warning "Bijectors are applied in reverse order"                                                                               
                                                      
      Given a sequence `[f, g]`, the `Chain` bijector computes `f(g(x))`... 

samples, log_prob = self.distribution.sample_and_log_prob(key)
log_prob = _reduce_helper(log_prob)
return samples, log_prob
if self.reinterpreted_batch_ndims == 0:
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should have a specific flag that goes to the reduce helper, these means it should reduce over all dimensions basically, should it be instead a required argument (that then becomes all axis rather than just 0)? Just to simplify code?

total_batches = math.prod(bshape)
keys = jax.random.split(key, total_batches).reshape(*bshape)

def _single_sample_and_log_prob(d, k):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this patterns comes up several times, I think if we make a vmap helper it could reduce LoC? e.g.

  def _vmap_method(self, fn,):                                                                                                 
      for _ in range(self.reinterpreted_batch_ndims):                                                                              
          fn = eqx.filter_vmap(fn)                                                                                                   
      return fn(self.distribution,)


  def _vmap_and_sum(self, fn,):                                                                                                
      out = self._vmap_method(fn,)                                                                                           
      return jnp.sum(out, axis=tuple(range(self.reinterpreted_batch_ndims)))     

then each of these can just call vmap and sum

# on the raw, un-wrapped vmapped inner distributions.
d1_rndims = dist1.reinterpreted_batch_ndims
d2_rndims = dist2.reinterpreted_batch_ndims
p_base_shape = dist1.event_shape[d1_rndims:] # fmt: skip
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is fmt skip for here?

# Safely extract the base event shapes without triggering unsafe traces
# on the raw, un-wrapped vmapped inner distributions.
d1_rndims = dist1.reinterpreted_batch_ndims
d2_rndims = dist2.reinterpreted_batch_ndims
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if these don't match this has an asymmetric issue? should we enforce this eto be the same

Comment thread tests/independent_test.py
)
return batch_mvn, locs, scales_tri

# --- 1. Basic & Legacy Tests ---
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

un-needed comments

Comment thread tests/independent_test.py
self.assertIsInstance(model, Independent)

def test_legacy_broadcasting_behavior(self):
"""Tests the reinterpreted_batch_ndims=0 fallback for standard distributions."""
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

personally, I don't think we need to have much backwards compatible, we can break things if the new API is better.

Comment thread tests/independent_test.py
def assertion_fn(self, rtol=1e-5):
return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)

def _create_vmapped_mvn(self, M=20, N=10, D=3):
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should check across multiple distributions, e.g. bernoulli/normal/etc to make sure nothing unexpected happens

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants