Enforce AdvancedIncSubtensor1 runtime-broadcast check on MLX backend#2241
Open
cetagostini wants to merge 1 commit into
Open
Enforce AdvancedIncSubtensor1 runtime-broadcast check on MLX backend#2241cetagostini wants to merge 1 commit into
cetagostini wants to merge 1 commit into
Conversation
`AdvancedIncSubtensor1` (`x[int_vector] += y` / `= y`) requires `y` to already match the indexed shape; runtime broadcasting is forbidden. The canonical op enforces this in `perform`/`c_code` via `_check_runtime_broadcasting`, and the JAX and PyTorch dispatchers call it under an `isinstance(op, AdvancedIncSubtensor1)` guard. The MLX dispatcher did not, so MLX silently broadcast `y` and returned a result where every other backend raises `ValueError: Runtime broadcasting not allowed`. Call `op._check_runtime_broadcasting(node, x, y, indices)` for `AdvancedIncSubtensor1` in `mlx_funcify_IncSubtensor`, mirroring JAX line-for-line. This is a cross-backend-consistency fix (MLX was too permissive); `AdvancedIncSubtensor` legitimately allows broadcasting and is correctly left unguarded. Disclosure: implemented with AI assistance. Co-authored-by: Cursor <cursoragent@cursor.com>
ricardoV94
reviewed
Jun 18, 2026
| a statically non-broadcastable dimension that is length 1 at runtime is an | ||
| error, not a silent broadcast. | ||
| """ | ||
| from pytensor import function |
ricardoV94
reviewed
Jun 18, 2026
| out = func(x, y, idxs) | ||
| assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1) | ||
|
|
||
| f = function([y], out, mode="MLX") |
Member
There was a problem hiding this comment.
mode should me Mode(linker="mlx", optimizer=None), otherwise your advanced_inc_subtensor path will end up testing the same advancde_inc_subtensor1 op
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
AdvancedIncSubtensor1(x[int_vector] += y/x[int_vector] = y) carries a hard contract:ymust already match the indexed shape — runtime broadcasting is not allowed (a statically non-broadcastable dimension that happens to be length 1 at runtime is an error, not a silent broadcast). The canonical op enforces this in bothperformandc_codeviaAdvancedIncSubtensor1._check_runtime_broadcasting, and the JAX and PyTorch dispatchers call it under anisinstance(op, AdvancedIncSubtensor1)guard.The MLX dispatcher never called it, so MLX silently broadcasts
yand returns a result where every other backend (Python/C/Numba/JAX/PyTorch) raisesValueError: Runtime broadcasting not allowed.Reproducer
Root cause
In
pytensor/link/mlx/dispatch/subtensor.py,mlx_funcify_IncSubtensor(which also servesAdvancedIncSubtensor1) computed the indices and called the increment/set function without ever invokingop._check_runtime_broadcasting(...). JAX does exactly this check; MLX was the outlier.Fix
Call
op._check_runtime_broadcasting(node, x, y, indices)forAdvancedIncSubtensor1, mirroring the JAX dispatcher line-for-line:AdvancedIncSubtensor1only.AdvancedIncSubtensorlegitimately allows broadcasting and has no such method (calling it unconditionally wouldAttributeError); this matches JAX/PyTorch.setandinc(the check runs before the increment/set function in both branches).mx.arrayexposes.ndimand a tuple.shape, so_check_runtime_broadcastingruns unmodified.This is purely a cross-backend-consistency fix — MLX was more permissive than the spec.
Tests
New
test_mlx_AdvancedIncSubtensor1_runtime_broadcast, parametrized overadvanced_inc_subtensor1/advanced_set_subtensor1(mirrorstest_jax_AdvancedIncSubtensor1_runtime_broadcast): a correctly sizedyruns fine, while a runtime broadcast along the index dim(1, 5)or the buffer dim(20, 1)raisesValueError: Runtime broadcasting not allowed.