Accumulate duplicate indices in MLX AdvancedIncSubtensor#2242
Draft
cetagostini wants to merge 3 commits into
Draft
Accumulate duplicate indices in MLX AdvancedIncSubtensor#2242cetagostini wants to merge 3 commits into
cetagostini wants to merge 3 commits into
Conversation
`inc_subtensor`/`set_subtensor` with a slice index failed on the MLX backend with `ValueError: Slice indices must be integers or None.`. This broke the gradient of any basic slicing (`x[a:b]`, `x[::2]`, ...), since the gradient of "read a slice" is an `IncSubtensor` that increments that slice of a zero tensor, and the slice bounds arrived as `mx.array` scalars which MLX slices reject. The forward `mlx_funcify_Subtensor` already coerces index inputs with `[int(element) for element in ilists]`; the `IncSubtensor` path did not. Coerce integer index inputs to Python ints in `mlx_funcify_IncSubtensor`, mirroring the forward op (and matching C/Numba/JAX/PyTorch semantics). `AdvancedIncSubtensor1` (a vector index that must not be int-coerced) is moved off the shared `IncSubtensor` dispatcher onto `mlx_funcify_AdvancedIncSubtensor`, mirroring PyTorch's canonical basic-vs-advanced grouping. Disclosure: implemented with AI assistance. Co-authored-by: Cursor <cursoragent@cursor.com>
Link the documented negative-strided slice gradient limitation to the upstream report ml-explore/mlx#3716. Co-authored-by: Cursor <cursoragent@cursor.com>
MLX increment with duplicate integer indices dropped repeated contributions: `x[indices] += y` desugars to gather-add-scatter and writes each destination once. Gradients of advanced indexing (e.g. embedding lookups with repeated token ids) are inc with duplicate indices, so MLX silently computed wrong values. Use MLX's functional scatter-add `x.at[indices].add(y)` for the inc path, mirroring JAX and matching the reference `np.add.at`-based `perform`. The `ignore_duplicates=True` mode keeps numpy write-once `x[idx] += y` semantics (matching PyTorch/Numba/reference), and the set path is unchanged. Co-authored-by: Cursor <cursoragent@cursor.com>
ricardoV94
approved these changes
Jun 18, 2026
ricardoV94
left a comment
Member
There was a problem hiding this comment.
Marking it as draft so we don't accidentally merge it before the base PRs are in
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR is stacked on #2240 and must be merged after it. It branches off
fix-mlx-incsubtensor-slice-gradand relies on #2240's refactor that routesAdvancedIncSubtensor1ontomlx_funcify_AdvancedIncSubtensor. Until #2240 merges, this PR's diff also shows #2240's commits; GitHub will auto-retarget to a clean diff once #2240 lands. The two PRs are independent in purpose (slice-bound coercion vs. duplicate accumulation) but touch the same dispatcher, hence the ordering.Summary
On the MLX backend, an increment with duplicate integer indices dropped repeated contributions instead of summing them.
x[indices] += ydesugars to a gather-add-scatter, so each destination is written once and repeated indices overwrite rather than accumulate. The canonical op explicitly avoids this (np.add.at), and every other backend uses a real scatter-add.This is silent wrong numbers, not a crash — and it's not a corner case: the gradient of advanced indexing /
gatheris aninc_subtensorwith duplicate indices whenever an index repeats (the norm for embedding lookups / repeated gathers). So gradients through embedding lookups on MLX silently computed wrong values.Reproducer
Divergence (x = zeros(3), y = ones(4), idx = [0, 0, 0, 1])
perform(np.add.at, the spec)[3, 1, 0][3, 1, 0]x.at[idx].add(y))[3, 1, 0]index_put_(..., accumulate=True))[3, 1, 0][1, 1, 0][3, 1, 0]Root cause
In
pytensor/link/mlx/dispatch/subtensor.py, the inc closure didx[indices] += y. With fancy (integer-array) indexing this is gather → add → scatter, which writes each destination once. MLX was the outlier; the canonicalperformdocuments exactly this trap:Fix
Use MLX's functional scatter-add
x.at[indices].add(y)for the inc path — the direct analogue of JAX'sx.at[idx].add(y)— so duplicate indices accumulate. The change is confined tomlx_funcify_AdvancedIncSubtensor(which, after #2240, serves bothAdvancedIncSubtensorandAdvancedIncSubtensor1). BasicIncSubtensoris untouched (basic indexing has no duplicate semantics).ignore_duplicatesis honored explicitly (three-way branch mirroring the referenceperformand PyTorch/Numba):set_instead_of_inc→x[indices] = yignore_duplicates=True→ write-oncex[indices] += y(numpyr[[0,1,0]] += 5 → [5,5,0]semantics; not accumulated, matching the reference/Numba/PyTorch)x.at[indices].add(y)(accumulate,np.add.at)AdvancedIncSubtensor1has noignore_duplicatesflag, so it always takes the accumulate path — matching its unconditionalnp.add.atperform. Thegetattr(op, "ignore_duplicates", False)guard keeps it safe under the shared registration.