Fix MLX IncSubtensor with slice index (gradient of basic slicing) #2240

Open
cetagostini wants to merge 2 commits into pymc-devs:main from cetagostini:fix-mlx-incsubtensor-slice-grad

@cetagostini cetagostini commented Jun 18, 2026

Contributor

Summary

inc_subtensor/set_subtensor with a slice index fails on the MLX backend:

ValueError: Slice indices must be integers or None.

This breaks the gradient of any basic slicing (x[a:b], x[::2], …), because the gradient of "read a slice" is "increment that slice of a zero tensor" — an IncSubtensor with a slice. The forward Subtensor works; only the gradient's IncSubtensor fails.
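The "increment that slice of a zero tensor" rule can be checked by hand. Below is a minimal pure-Python sketch that uses lists in place of tensors; `slice_read_grad` is a hypothetical helper written for illustration, not a PyTensor function, and the loss is assumed to be `sum(x[sl] ** 2)`.

```python
def slice_read_grad(n, sl, upstream):
    """Gradient of `x[sl]` w.r.t. a length-`n` vector `x`.

    Starts from zeros and adds each upstream gradient entry at the position
    the slice read from, which is what an increment-into-zeros does.
    """
    grad = [0.0] * n
    positions = range(*sl.indices(n))  # concrete positions selected by the slice
    for pos, g in zip(positions, upstream):
        grad[pos] += g
    return grad


x = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# Loss = sum(x[0:3] ** 2), so the upstream gradient of the slice is 2 * x[0:3].
contiguous = slice_read_grad(len(x), slice(0, 3), [2.0 * v for v in x[0:3]])
print(contiguous)  # [0.0, 2.0, 4.0, 0.0, 0.0, 0.0]

# Strided slice x[0::2]: the gradient lands only on the even positions.
strided = slice_read_grad(len(x), slice(0, None, 2), [2.0 * v for v in x[0::2]])
print(strided)  # [0.0, 0.0, 4.0, 0.0, 8.0, 0.0]
```

The contiguous case gives the same values the default backend prints in the reproducer.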

Reproducer

import numpy as np, pytensor, pytensor.tensor as pt
pytensor.config.floatX = "float32"

x = pt.vector("x")

# forward slice works on MLX:
f = pytensor.function([x], x[0:3], mode="MLX")
print(f(np.arange(6, dtype="float32")))            # OK -> [0. 1. 2.]

# gradient of a slice -> IncSubtensor(zeros[0:3], ...) -> FAILS on MLX (before this PR):
g = pt.grad((x[0:3] ** 2).sum(), x)
pytensor.function([x], g, mode="MLX")              # ValueError: Slice indices must be integers or None.

# same graph on the default backend is fine:
print(pytensor.function([x], g, mode=None)(np.arange(6, dtype="float32")))   # [0. 2. 4. 0. 0. 0.]

The interleaved / rotate-half RoPE pattern (x[..., 0::2], x[..., d//2:]) is the common transformer case that hits this.

Root cause

In pytensor/link/mlx/dispatch/subtensor.py, the forward mlx_funcify_Subtensor coerces its index inputs to Python ints:

indices = indices_from_subtensor([int(element) for element in ilists], op.idx_list)

but mlx_funcify_IncSubtensor passed the index inputs through unchanged. For an IncSubtensor whose idx_list contains a slice, the slice bounds therefore arrive as mx.array scalars, and MLX's x[slice] += y rejects array-typed slice bounds.

Fix

Coerce integer index inputs to Python ints in mlx_funcify_IncSubtensor, exactly as the forward mlx_funcify_Subtensor already does. This matches the canonical semantics defined by the C/Numba impl and mirrors JAX/PyTorch.
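The effect of the coercion can be sketched without MLX. The example below is an illustration under stated assumptions, not the code in this PR: `ScalarLike` is a hypothetical stand-in for a 0-d backend array that supports `int()` but is not accepted as a slice bound, and `coerce_slice_bounds` is a hypothetical helper. Plain Python raises `TypeError` here, whereas MLX raises the `ValueError` quoted above.

```python
class ScalarLike:
    """Hypothetical stand-in for a 0-d backend array.

    It supports int() but not __index__, so Python rejects it as a slice bound.
    """

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return int(self.value)


def coerce_slice_bounds(sl):
    """Return a slice whose start/stop/step are Python ints or None."""

    def to_int(bound):
        return None if bound is None else int(bound)

    return slice(to_int(sl.start), to_int(sl.stop), to_int(sl.step))


base = [0.0] * 6
update = [1.0, 2.0, 3.0]
raw = slice(ScalarLike(0), ScalarLike(3), None)

try:
    base[raw] = update  # rejected: the bounds are not integers
    failed = False
except TypeError:
    failed = True

base[coerce_slice_bounds(raw)] = update  # accepted once the bounds are ints
print(failed, base)  # True [1.0, 2.0, 3.0, 0.0, 0.0, 0.0]
```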

AdvancedIncSubtensor1 (whose index is an integer vector that must not be int()-coerced) is moved off the shared IncSubtensor dispatcher onto mlx_funcify_AdvancedIncSubtensor, mirroring PyTorch's canonical basic-vs-advanced grouping. This is semantically equivalent (idx_list is always (0,), so x[(idx,)] == x[idx]).
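The basic-vs-advanced split can be pictured with `functools.singledispatch`. This is a rough sketch only: `BasicIncOp`, `VectorIncOp`, and `funcify` are hypothetical stand-ins, lists replace tensors, and nothing here is the actual MLX dispatch code. The point is that the basic handler coerces its scalar bounds with `int()`, while the vector handler passes its index through untouched.

```python
from functools import singledispatch


class BasicIncOp:
    """Hypothetical stand-in for an increment op indexed by scalar slice bounds."""


class VectorIncOp:
    """Hypothetical stand-in for an increment op indexed by an integer vector."""


@singledispatch
def funcify(op):
    raise NotImplementedError(f"No dispatch for {type(op).__name__}")


@funcify.register(BasicIncOp)
def _(op):
    def inc(x, y, start, stop):
        out = list(x)
        lo, hi = int(start), int(stop)  # scalar bounds are coerced to Python ints
        for offset, pos in enumerate(range(lo, hi)):
            out[pos] += y[offset]
        return out

    return inc


@funcify.register(VectorIncOp)
def _(op):
    def inc(x, y, idx):
        out = list(x)
        for pos, value in zip(idx, y):  # idx stays a vector; no int() on the whole index
            out[pos] += value
        return out

    return inc


basic_result = funcify(BasicIncOp())([0.0] * 4, [1.0, 2.0], 0, 2)
vector_result = funcify(VectorIncOp())([0.0] * 3, [5.0, 7.0], [2, 0])
print(basic_result)   # [1.0, 2.0, 0.0, 0.0]
print(vector_result)  # [7.0, 0.0, 5.0]
```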

Tests

  • New test_mlx_IncSubtensor_slice_grad exercises the gradient of a contiguous slice (x[0:3]) and a strided RoPE-style slice (x[0::2]), asserting the node is an IncSubtensor and comparing against the Python backend. Confirmed RED on main, GREEN with this PR.
  • The fix also un-xfails test_mlx_subtensor_with_variables (set_subtensor(x[0, :2], y) with variable inputs) — same root cause, previously mis-attributed to "MLX indexing with tuples not yet supported".
  • A separate, pre-existing limitation surfaced during validation: the negative-step slice gradient is wrong under mx.compile. It is documented with a strict xfail and tracked upstream; details and a pure-MLX reproducer are in a follow-up comment.

`inc_subtensor`/`set_subtensor` with a slice index failed on the MLX
backend with `ValueError: Slice indices must be integers or None.`. This
broke the gradient of any basic slicing (`x[a:b]`, `x[::2]`, ...), since
the gradient of "read a slice" is an `IncSubtensor` that increments that
slice of a zero tensor, and the slice bounds arrived as `mx.array`
scalars which MLX slices reject.

The forward `mlx_funcify_Subtensor` already coerces index inputs with
`[int(element) for element in ilists]`; the `IncSubtensor` path did not.
Coerce integer index inputs to Python ints in `mlx_funcify_IncSubtensor`,
mirroring the forward op (and matching C/Numba/JAX/PyTorch semantics).

`AdvancedIncSubtensor1` (a vector index that must not be int-coerced) is
moved off the shared `IncSubtensor` dispatcher onto
`mlx_funcify_AdvancedIncSubtensor`, mirroring PyTorch's canonical
basic-vs-advanced grouping.

Disclosure: implemented with AI assistance.
Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini requested a review from ricardoV94 June 18, 2026 11:54
@cetagostini cetagostini self-assigned this Jun 18, 2026
@cetagostini

Contributor Author

Follow-up: negative-step slice gradient is a separate, pre-existing mx.compile bug (not addressed here)

While validating this fix I found that the gradient of a negative-step slice (e.g. x[::-1]) now runs (before this PR it crashed with the same ValueError), but returns wrong values under the default use_compile=True mode — while being correct with use_compile=False / eager.

I traced it all the way down and reproduced it in pure MLX, with no PyTensor involved:

import mlx.core as mx

def f(x):
    base = mx.zeros_like(x)
    base[::-1] += 2.0 * x[::-1]   # negative-strided in-place scatter; update is an elementwise expr
    return base

x = mx.arange(6, dtype=mx.float32)
print("eager   :", f(x).tolist())             # [0, 2, 4, 6, 8, 10]   correct
print("compiled:", mx.compile(f)(x).tolist()) # [0, 0, 0, 0, 0, 10]   WRONG

Characterization (mlx 0.31.2):

  • A negative-strided in-place scatter base[::-1] += update under mx.compile writes only one element when update is a non-trivial elementwise expression (2.0 * x[::-1], x[::-1] + x[::-1], …).
  • It is correct when update is a bare view (base[::-1] += x[::-1]), when the scatter is positive-strided (base[::2] += 2.0 * x[::2]), and always in eager mode.
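For reference, the expected result of the negative-strided scatter can be derived without MLX. The sketch below is a pure-Python model of `base[::-1] += 2.0 * x[::-1]` on a zero base; `negative_step_scatter` is a hypothetical name used only here, and lists stand in for arrays.

```python
def negative_step_scatter(x):
    """Pure-Python reference for `base[::-1] += 2.0 * x[::-1]` on a zero base."""
    base = [0.0] * len(x)
    update = [2.0 * v for v in x[::-1]]  # elementwise expression on a reversed view
    # Add the update through the same reversed view it will be written to.
    base[::-1] = [b + u for b, u in zip(base[::-1], update)]
    return base


reference = negative_step_scatter([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
print(reference)  # [0.0, 2.0, 4.0, 6.0, 8.0, 10.0]
```

This agrees with the eager output in the reproducer, so the compiled output is the one that deviates.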

This is independent of the slice-bound coercion fixed in this PR (the eager/use_compile=False path is correct), and it was previously masked by the crash. Per "one root cause per PR", it is not fixed here.

It is documented and tracked with a strict xfail so it auto-alerts (xpass) when MLX fixes it upstream:

@pytest.mark.xfail(
    reason="Upstream mx.compile bug: a negative-strided in-place scatter whose "
    "update derives from a negative-strided view returns wrong values "
    "(correct when eager / use_compile=False).",
    strict=True,
)
def test_mlx_IncSubtensor_negative_step_slice_grad():
    x_pt = pt.vector("x", dtype="float32")
    x_np = np.arange(6, dtype=np.float32)
    g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt)
    assert isinstance(g.owner.op, pt_subtensor.IncSubtensor)
    compare_mlx_and_py([x_pt], [g], [x_np])

I'll file an upstream MLX issue and link it here.

Link the documented negative-strided slice gradient limitation to the
upstream report ml-explore/mlx#3716.

Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini

Contributor Author

Filed the upstream MLX bug: ml-explore/mlx#3716, "mx.compile: assigning an elementwise expression to a negative-strided slice writes only one element".

The xfail for test_mlx_IncSubtensor_negative_step_slice_grad now references it inline, so it will turn into an xpass (and fail strict) once MLX ships a fix, prompting us to drop the marker.

@cetagostini

Contributor Author

Follow-up #2242 (duplicate-index accumulation in MLX AdvancedIncSubtensor) is stacked on this PR and depends on the AdvancedIncSubtensor1 dispatch refactor here. Merge order: this PR (#2240) first, then #2242.

@@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):


@mlx_funcify.register(AdvancedIncSubtensor)

Member


Out of scope if you want to merge this PR already, but the dispatch for the general AdvancedIncSubtensor is wrong, as it can have slices as well. Maybe also worth checking AdvancedSubtensor.

@ricardoV94 ricardoV94 added the bug, indexing, and mlx labels Jun 18, 2026