Skip to content

Accumulate duplicate indices in MLX AdvancedIncSubtensor#2242

Draft
cetagostini wants to merge 3 commits into
pymc-devs:mainfrom
cetagostini:fix-mlx-advinc-duplicate-indices
Draft

Accumulate duplicate indices in MLX AdvancedIncSubtensor#2242
cetagostini wants to merge 3 commits into
pymc-devs:mainfrom
cetagostini:fix-mlx-advinc-duplicate-indices

Conversation

@cetagostini

@cetagostini cetagostini commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

⚠️ Stacking / merge order

This PR is stacked on #2240 and must be merged after it. It branches off fix-mlx-incsubtensor-slice-grad and relies on #2240's refactor that routes AdvancedIncSubtensor1 onto mlx_funcify_AdvancedIncSubtensor. Until #2240 merges, this PR's diff also shows #2240's commits; GitHub will auto-retarget to a clean diff once #2240 lands. The two PRs are independent in purpose (slice-bound coercion vs. duplicate accumulation) but touch the same dispatcher, hence the ordering.

Summary

On the MLX backend, an increment with duplicate integer indices dropped repeated contributions instead of summing them. x[indices] += y desugars to a gather-add-scatter, so each destination is written once and repeated indices overwrite rather than accumulate. The canonical op explicitly avoids this (np.add.at), and every other backend uses a real scatter-add.

This is silent wrong numbers, not a crash — and it's not a corner case: the gradient of advanced indexing / gather is an inc_subtensor with duplicate indices whenever an index repeats (the norm for embedding lookups / repeated gathers). So gradients through embedding lookups on MLX silently computed wrong values.

Reproducer

import numpy as np, pytensor, pytensor.tensor as pt
from pytensor.tensor import subtensor as st

x = pt.vector("x", dtype="float32")
y = pt.vector("y", dtype="float32")
out = st.advanced_inc_subtensor1(x, y, np.array([0, 0, 0, 1]))   # duplicate index 0

x0 = np.zeros(3, "float32"); y0 = np.ones(4, "float32")
print(pytensor.function([x, y], out, mode="FAST_COMPILE")(x0, y0))  # [3. 1. 0.]  ✅ np.add.at
print(pytensor.function([x, y], out, mode="MLX")(x0, y0))           # [1. 1. 0.]  ❌ before this PR

Divergence (x = zeros(3), y = ones(4), idx = [0, 0, 0, 1])

Backend Result Correct?
Python perform (np.add.at, the spec) [3, 1, 0] ✅ reference
Numba (accumulating scatter codegen) [3, 1, 0]
JAX (x.at[idx].add(y)) [3, 1, 0]
PyTorch (index_put_(..., accumulate=True)) [3, 1, 0]
MLX (before) [1, 1, 0] ❌ duplicates dropped
MLX (this PR) [3, 1, 0]

Root cause

In pytensor/link/mlx/dispatch/subtensor.py, the inc closure did x[indices] += y. With fancy (integer-array) indexing this is gather → add → scatter, which writes each destination once. MLX was the outlier; the canonical perform documents exactly this trap:

if self.set_instead_of_inc:
    x[idx] = y
else:
    # In Numpy, `x[idx] += y` doesn't work if the same index is present
    # many times: it does it only once.
    np.add.at(x, idx, y)

Fix

Use MLX's functional scatter-add x.at[indices].add(y) for the inc path — the direct analogue of JAX's x.at[idx].add(y) — so duplicate indices accumulate. The change is confined to mlx_funcify_AdvancedIncSubtensor (which, after #2240, serves both AdvancedIncSubtensor and AdvancedIncSubtensor1). Basic IncSubtensor is untouched (basic indexing has no duplicate semantics).

ignore_duplicates is honored explicitly (three-way branch mirroring the reference perform and PyTorch/Numba):

  • set_instead_of_incx[indices] = y
  • ignore_duplicates=True → write-once x[indices] += y (numpy r[[0,1,0]] += 5 → [5,5,0] semantics; not accumulated, matching the reference/Numba/PyTorch)
  • otherwise → x.at[indices].add(y) (accumulate, np.add.at)

AdvancedIncSubtensor1 has no ignore_duplicates flag, so it always takes the accumulate path — matching its unconditional np.add.at perform. The getattr(op, "ignore_duplicates", False) guard keeps it safe under the shared registration.

cetagostini and others added 3 commits June 18, 2026 14:50
`inc_subtensor`/`set_subtensor` with a slice index failed on the MLX
backend with `ValueError: Slice indices must be integers or None.`. This
broke the gradient of any basic slicing (`x[a:b]`, `x[::2]`, ...), since
the gradient of "read a slice" is an `IncSubtensor` that increments that
slice of a zero tensor, and the slice bounds arrived as `mx.array`
scalars which MLX slices reject.

The forward `mlx_funcify_Subtensor` already coerces index inputs with
`[int(element) for element in ilists]`; the `IncSubtensor` path did not.
Coerce integer index inputs to Python ints in `mlx_funcify_IncSubtensor`,
mirroring the forward op (and matching C/Numba/JAX/PyTorch semantics).

`AdvancedIncSubtensor1` (a vector index that must not be int-coerced) is
moved off the shared `IncSubtensor` dispatcher onto
`mlx_funcify_AdvancedIncSubtensor`, mirroring PyTorch's canonical
basic-vs-advanced grouping.

Disclosure: implemented with AI assistance.
Co-authored-by: Cursor <cursoragent@cursor.com>
Link the documented negative-strided slice gradient limitation to the
upstream report ml-explore/mlx#3716.

Co-authored-by: Cursor <cursoragent@cursor.com>
MLX increment with duplicate integer indices dropped repeated
contributions: `x[indices] += y` desugars to gather-add-scatter and
writes each destination once. Gradients of advanced indexing (e.g.
embedding lookups with repeated token ids) are inc with duplicate
indices, so MLX silently computed wrong values.

Use MLX's functional scatter-add `x.at[indices].add(y)` for the inc
path, mirroring JAX and matching the reference `np.add.at`-based
`perform`. The `ignore_duplicates=True` mode keeps numpy write-once
`x[idx] += y` semantics (matching PyTorch/Numba/reference), and the set
path is unchanged.

Co-authored-by: Cursor <cursoragent@cursor.com>

@ricardoV94 ricardoV94 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marking it as draft so we don't accidentally merge it before the base PRs are in

@ricardoV94 ricardoV94 marked this pull request as draft June 18, 2026 14:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants