Fix MLX IncSubtensor with slice index (gradient of basic slicing)#2240
Fix MLX IncSubtensor with slice index (gradient of basic slicing)#2240cetagostini wants to merge 2 commits into
Conversation
`inc_subtensor`/`set_subtensor` with a slice index failed on the MLX backend with `ValueError: Slice indices must be integers or None.`. This broke the gradient of any basic slicing (`x[a:b]`, `x[::2]`, ...), since the gradient of "read a slice" is an `IncSubtensor` that increments that slice of a zero tensor, and the slice bounds arrived as `mx.array` scalars which MLX slices reject. The forward `mlx_funcify_Subtensor` already coerces index inputs with `[int(element) for element in ilists]`; the `IncSubtensor` path did not. Coerce integer index inputs to Python ints in `mlx_funcify_IncSubtensor`, mirroring the forward op (and matching C/Numba/JAX/PyTorch semantics). `AdvancedIncSubtensor1` (a vector index that must not be int-coerced) is moved off the shared `IncSubtensor` dispatcher onto `mlx_funcify_AdvancedIncSubtensor`, mirroring PyTorch's canonical basic-vs-advanced grouping. Disclosure: implemented with AI assistance. Co-authored-by: Cursor <cursoragent@cursor.com>
Follow-up: negative-step slice gradient is a separate, pre-existing
|
Link the documented negative-strided slice gradient limitation to the upstream report ml-explore/mlx#3716. Co-authored-by: Cursor <cursoragent@cursor.com>
|
Filed the upstream MLX bug: ml-explore/mlx#3716 — "mx.compile: assigning an elementwise expression to a negative-strided slice writes only one element". The |
| @@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): | |||
|
|
|||
|
|
|||
| @mlx_funcify.register(AdvancedIncSubtensor) | |||
There was a problem hiding this comment.
out of scope if you want to merge this PR already but the dispatch for the general AdvancedIncSubtensor is wrong, as it can have slices as well. Maybe also worth checking AdvancedSubtensor
Summary
inc_subtensor/set_subtensorwith a slice index fails on the MLX backend:This breaks the gradient of any basic slicing (
x[a:b],x[::2], …), because the gradient of "read a slice" is "increment that slice of a zero tensor" — anIncSubtensorwith a slice. The forwardSubtensorworks; only the gradient'sIncSubtensorfails.Reproducer
The interleaved / rotate-half RoPE pattern (
x[..., 0::2],x[..., d//2:]) is the common transformer case that hits this.Root cause
In
pytensor/link/mlx/dispatch/subtensor.py, the forwardmlx_funcify_Subtensorcoerces its index inputs to Python ints:but
mlx_funcify_IncSubtensorpassed the index inputs through unchanged. For anIncSubtensorwhoseidx_listcontains a slice, the slice bounds therefore arrive asmx.arrayscalars, and MLX'sx[slice] += yrejects array-typed slice bounds.Fix
Coerce integer index inputs to Python ints in
mlx_funcify_IncSubtensor, exactly as the forwardmlx_funcify_Subtensoralready does. This matches the canonical semantics defined by the C/Numbaimpland mirrors JAX/PyTorch.AdvancedIncSubtensor1(whose index is an integer vector that must not beint()-coerced) is moved off the sharedIncSubtensordispatcher ontomlx_funcify_AdvancedIncSubtensor, mirroring PyTorch's canonical basic-vs-advanced grouping. This is semantically equivalent (idx_listis always(0,), sox[(idx,)] == x[idx]).Tests
test_mlx_IncSubtensor_slice_gradexercises the gradient of a contiguous slice (x[0:3]) and a strided RoPE-style slice (x[0::2]), asserting the node is anIncSubtensorand comparing against the Python backend. Confirmed RED onmain, GREEN with this PR.xfailstest_mlx_subtensor_with_variables(set_subtensor(x[0, :2], y)with variable inputs) — same root cause, previously mis-attributed to "MLX indexing with tuples not yet supported".mx.compile) is documented with astrictxfailand tracked upstream; details + a pure-MLX reproducer in a follow-up comment.