Fold variadic Add/Mul in JAX and PyTorch backends#2238
Open
cetagostini wants to merge 1 commit into
Open
Conversation
Variadic Add/Mul (3+ inputs) lowered as a reduce over a stacked array via the nfunc_variadic reducer (jnp.sum / torch.sum). Stacking + sum upcasts bool/int dtypes (bool -> int, int8 -> int32/int64), diverging from the declared output dtype and the Python/C reference. Fold the binary op instead: it broadcasts and preserves dtype, matching the C and Numba backends (a + b + c) and the MLX fix in pymc-devs#2235. Follow-up to pymc-devs#2235. Co-authored-by: Cursor <cursoragent@cursor.com>
ricardoV94
reviewed
Jun 18, 2026
| # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function, | ||
| # even though the base Op from `func_name` is specified as a binary Op. | ||
| # This happens with `Add`, which can work as a `Sum` for multiple scalars. | ||
| jax_variadic_func = getattr(jnp, op.nfunc_variadic, None) |
Member
There was a problem hiding this comment.
I think this op.nfunc_variadic was used only for these dispatchers, in which case we can remove it from the Ops
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Follow-up to #2235 (the MLX variadic fix). In that PR's discussion, @ricardoV94 noted the stacked
jnp.sumstrategy "is [not] what you would want (in either jax or mlx)". This applies the same fold to the JAX and PyTorch backends.Problem
Both backends lower a variadic elementwise
Add/Mul(3+ inputs) by stacking the operands and reducing with thenfunc_variadicreducer (sum/prod):They broadcast first, so unlike MLX they don't crash — but reducing a stacked array upcasts the dtype, diverging from the declared output dtype and the Python/C reference:
py)stack+sumboolbool[T, T, T](OR)int32/int64[2, 2, 3]bool[T, T, T]int8int8int32/int64int8float32float32float32float32(Measured on jax 0.4.23 and torch 2.9.0.)
Fix
Fold the binary op (
jnp.add/jnp.multiply,torch.add/torch.multiply) left-to-right withfunctools.reduce. The binary op broadcasts and preserves dtype, so the result matches the reference exactly — including bool OR semantics. This also matches the C/Numba backends (which emita + b + c) and the MLX fix in #2235, and removes the rank-raisingstackallocation. The fold reuses the same already-tested binary op the 2-input case dispatches to, so it's behaviour-preserving for the common numeric case.Tests
Added to
tests/link/jax/test_scalar.pyandtests/link/pytorch/test_elemwise.py(reusingcompare_jax_and_py/compare_pytorch_and_py):test_*_variadic_broadcast[add|mul]— 3-input op with broadcast shapes(3,4)/(1,4)/(3,1).test_*_variadic_add_dtype[bool|int8]— locks dtype preservation with an explicit dtype-checkingassert_fn.Verified RED -> GREEN: the dtype tests fail on
main(upcast) and pass with the fold; the broadcast tests pass both ways (these backends already broadcast). The jax-scalar and pytorch-elemwise suites are otherwise green (3 pre-existingbetaincinv/gammaincinv/gammainccinvNotImplementedErrors are unrelated to this change).Disclosure: I used AI assistance (under my direction) while preparing this change; I reviewed and verified every line, test, and result.
Made with Cursor