Fold variadic Add/Mul in JAX and PyTorch backends #2238

Open
cetagostini wants to merge 1 commit into pymc-devs:main from cetagostini:fix-jax-pytorch-variadic-fold

Conversation

@cetagostini

Contributor

Follow-up to #2235 (the MLX variadic fix). In that PR's discussion, @ricardoV94 noted the stacked jnp.sum strategy "is [not] what you would want (in either jax or mlx)". This applies the same fold to the JAX and PyTorch backends.

Problem

Both backends lower a variadic elementwise Add/Mul (3+ inputs) by stacking the operands and reducing with the nfunc_variadic reducer (sum/prod):

# jax/dispatch/scalar.py (pytorch is analogous)
def jax_func(*args):
    return jax_variadic_func(jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0)

They broadcast first, so unlike MLX they don't crash — but reducing a stacked array upcasts the dtype, diverging from the declared output dtype and the Python/C reference:

| dtype | PyTensor reference (py) | current stack+sum | this PR (fold) |
| --- | --- | --- | --- |
| bool | bool [T, T, T] (OR) | int32/int64 [2, 2, 3] | bool [T, T, T] |
| int8 | int8 | int32/int64 | int8 |
| float32 | float32 | float32 | float32 |

(Measured on jax 0.4.23 and torch 2.9.0.)
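
The upcast can be illustrated without either backend installed. The sketch below is a NumPy stand-in, not the backend code: it assumes `np.sum` promotes bool and small-integer inputs in a comparable way, and the helper name `stack_and_sum` is hypothetical. The widened integer dtype is platform dependent, so it is not the measurement reported in the table.

```python
import numpy as np


def stack_and_sum(*args):
    """NumPy stand-in for the stack-then-reduce lowering (hypothetical helper)."""
    # Broadcast the operands to a common shape, stack them on a new leading
    # axis, then reduce over that axis.
    return np.sum(np.stack(np.broadcast_arrays(*args), axis=0), axis=0)


a = np.array([True, True, True])
b = np.array([True, False, True])
c = np.array([False, True, True])

bool_out = stack_and_sum(a, b, c)
int8_out = stack_and_sum(*(x.astype(np.int8) for x in (a, b, c)))
float32_out = stack_and_sum(*(x.astype(np.float32) for x in (a, b, c)))

# The bool inputs are counted rather than OR-ed, and the integer result is
# wider than int8; only the float32 case keeps its input dtype.
print(bool_out.dtype, bool_out.tolist())
print(int8_out.dtype, float32_out.dtype)
```
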

Fix

Fold the binary op (jnp.add/jnp.multiply, torch.add/torch.multiply) left-to-right with functools.reduce. The binary op broadcasts and preserves dtype, so the result matches the reference exactly — including bool OR semantics. This also matches the C/Numba backends (which emit a + b + c) and the MLX fix in #2235, and removes the rank-raising stack allocation. The fold reuses the same already-tested binary op the 2-input case dispatches to, so it's behaviour-preserving for the common numeric case.
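
As a rough illustration of the fold, here is a minimal sketch using NumPy binary ufuncs in place of `jnp.add`/`torch.add`. This is an assumption made for portability, and `fold_binary` is a hypothetical name, not the dispatcher in this PR. It uses the broadcast shapes from the tests, (3,4)/(1,4)/(3,1).

```python
from functools import reduce

import numpy as np


def fold_binary(binary_op, *args):
    """Left fold: ((a op b) op c) ... using a broadcasting binary op."""
    return reduce(binary_op, args)


# Three int8 operands with mutually broadcastable shapes.
a = np.ones((3, 4), dtype=np.int8)
b = np.full((1, 4), 2, dtype=np.int8)
c = np.full((3, 1), 3, dtype=np.int8)

int8_sum = fold_binary(np.add, a, b, c)
int8_prod = fold_binary(np.multiply, a, b, c)

# For bool inputs the binary add acts as a logical OR and stays bool.
flags = [
    np.array([True, True, True]),
    np.array([True, False, True]),
    np.array([False, True, True]),
]
bool_sum = fold_binary(np.add, *flags)

print(int8_sum.dtype, int8_sum.shape, bool_sum.dtype, bool_sum.tolist())
```

No stacked intermediate of rank one higher than the inputs is allocated here; each step produces only an array of the broadcast shape.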

Tests

Added to tests/link/jax/test_scalar.py and tests/link/pytorch/test_elemwise.py (reusing compare_jax_and_py / compare_pytorch_and_py):

  • test_*_variadic_broadcast[add|mul] — 3-input op with broadcast shapes (3,4)/(1,4)/(3,1).
  • test_*_variadic_add_dtype[bool|int8] — locks dtype preservation with an explicit dtype-checking assert_fn.
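
A dtype-checking assertion of the kind the second bullet describes might look like the sketch below. The name `assert_equal_with_dtype` is hypothetical, and how it is passed to `compare_jax_and_py` / `compare_pytorch_and_py` is not shown because their signatures are not given here.

```python
import numpy as np


def assert_equal_with_dtype(actual, expected):
    """Hypothetical assert_fn: require identical dtype as well as equal values."""
    actual = np.asarray(actual)
    expected = np.asarray(expected)
    # A value-only comparison would pass silently when the result is upcast,
    # so the dtype is checked explicitly first.
    assert actual.dtype == expected.dtype, (
        f"dtype mismatch: {actual.dtype} != {expected.dtype}"
    )
    np.testing.assert_array_equal(actual, expected)
```

The explicit dtype check matters because `np.testing.assert_array_equal` alone compares values, so an int8 result that came back as int64 with the same values would not be flagged.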

Verified RED -> GREEN: the dtype tests fail on main (upcast) and pass with the fold; the broadcast tests pass both ways (these backends already broadcast). The jax-scalar and pytorch-elemwise suites are otherwise green (3 pre-existing betaincinv/gammaincinv/gammainccinv NotImplementedErrors are unrelated to this change).


Disclosure: I used AI assistance (under my direction) while preparing this change; I reviewed and verified every line, test, and result.

Made with Cursor

Variadic Add/Mul (3+ inputs) lowered as a reduce over a stacked array via
the nfunc_variadic reducer (jnp.sum / torch.sum). Stacking + sum upcasts
bool/int dtypes (bool -> int, int8 -> int32/int64), diverging from the
declared output dtype and the Python/C reference.

Fold the binary op instead: it broadcasts and preserves dtype, matching
the C and Numba backends (a + b + c) and the MLX fix in pymc-devs#2235.

Follow-up to pymc-devs#2235.

Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini requested a review from ricardoV94 June 17, 2026 20:41
@cetagostini cetagostini self-assigned this Jun 17, 2026
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
# even though the base Op from `func_name` is specified as a binary Op.
# This happens with `Add`, which can work as a `Sum` for multiple scalars.
jax_variadic_func = getattr(jnp, op.nfunc_variadic, None)

Member


I think this op.nfunc_variadic was used only for these dispatchers, in which case we can remove it from the Ops
