Skip to content

Enforce AdvancedIncSubtensor1 runtime-broadcast check on MLX backend#2241

Open
cetagostini wants to merge 1 commit into
pymc-devs:mainfrom
cetagostini:fix-mlx-advinc1-runtime-broadcast
Open

Enforce AdvancedIncSubtensor1 runtime-broadcast check on MLX backend#2241
cetagostini wants to merge 1 commit into
pymc-devs:mainfrom
cetagostini:fix-mlx-advinc1-runtime-broadcast

Conversation

@cetagostini

@cetagostini cetagostini commented Jun 18, 2026

Copy link
Copy Markdown
Contributor

Summary

AdvancedIncSubtensor1 (x[int_vector] += y / x[int_vector] = y) carries a hard contract: y must already match the indexed shape — runtime broadcasting is not allowed (a statically non-broadcastable dimension that happens to be length 1 at runtime is an error, not a silent broadcast). The canonical op enforces this in both perform and c_code via AdvancedIncSubtensor1._check_runtime_broadcasting, and the JAX and PyTorch dispatchers call it under an isinstance(op, AdvancedIncSubtensor1) guard.

The MLX dispatcher never called it, so MLX silently broadcasts y and returns a result where every other backend (Python/C/Numba/JAX/PyTorch) raises ValueError: Runtime broadcasting not allowed.

Reproducer

import numpy as np, pytensor, pytensor.tensor as pt
from pytensor.tensor import subtensor as pt_subtensor
pytensor.config.floatX = "float32"

x = pt.matrix("x", dtype="float32")
y = pt.matrix("y", dtype="float32")          # static shape (None, None): first dim NOT broadcastable
out = pt_subtensor.advanced_set_subtensor1(x, y, [0, 2])   # AdvancedIncSubtensor1

x_np = np.zeros((4, 3), dtype="float32")
y_np = np.ones((1, 3), dtype="float32")      # runtime length-1 first dim vs 2 indices

# Default backend correctly rejects the runtime broadcast:
pytensor.function([x, y], out, mode="FAST_COMPILE")(x_np, y_np)
# -> ValueError: Runtime broadcasting not allowed. ...

# MLX (before this PR) silently broadcasts instead of raising:
print(pytensor.function([x, y], out, mode="MLX")(x_np, y_np))
# -> array([[1,1,1],[0,0,0],[1,1,1],[0,0,0]])   (no error)

Root cause

In pytensor/link/mlx/dispatch/subtensor.py, mlx_funcify_IncSubtensor (which also serves AdvancedIncSubtensor1) computed the indices and called the increment/set function without ever invoking op._check_runtime_broadcasting(...). JAX does exactly this check; MLX was the outlier.

Fix

Call op._check_runtime_broadcasting(node, x, y, indices) for AdvancedIncSubtensor1, mirroring the JAX dispatcher line-for-line:

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
    indices = indices_from_subtensor(ilist, idx_list)
    if len(indices) == 1:
        indices = indices[0]

    if isinstance(op, AdvancedIncSubtensor1):
        op._check_runtime_broadcasting(node, x, y, indices)

    return mlx_fn(x, indices, y)
  • Scoped to AdvancedIncSubtensor1 only. AdvancedIncSubtensor legitimately allows broadcasting and has no such method (calling it unconditionally would AttributeError); this matches JAX/PyTorch.
  • Covers both set and inc (the check runs before the increment/set function in both branches).
  • mx.array exposes .ndim and a tuple .shape, so _check_runtime_broadcasting runs unmodified.

This is purely a cross-backend-consistency fix — MLX was more permissive than the spec.

Tests

New test_mlx_AdvancedIncSubtensor1_runtime_broadcast, parametrized over advanced_inc_subtensor1 / advanced_set_subtensor1 (mirrors test_jax_AdvancedIncSubtensor1_runtime_broadcast): a correctly sized y runs fine, while a runtime broadcast along the index dim (1, 5) or the buffer dim (20, 1) raises ValueError: Runtime broadcasting not allowed.

`AdvancedIncSubtensor1` (`x[int_vector] += y` / `= y`) requires `y` to
already match the indexed shape; runtime broadcasting is forbidden. The
canonical op enforces this in `perform`/`c_code` via
`_check_runtime_broadcasting`, and the JAX and PyTorch dispatchers call it
under an `isinstance(op, AdvancedIncSubtensor1)` guard. The MLX dispatcher
did not, so MLX silently broadcast `y` and returned a result where every
other backend raises `ValueError: Runtime broadcasting not allowed`.

Call `op._check_runtime_broadcasting(node, x, y, indices)` for
`AdvancedIncSubtensor1` in `mlx_funcify_IncSubtensor`, mirroring JAX
line-for-line. This is a cross-backend-consistency fix (MLX was too
permissive); `AdvancedIncSubtensor` legitimately allows broadcasting and
is correctly left unguarded.

Disclosure: implemented with AI assistance.
Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini requested a review from ricardoV94 June 18, 2026 12:43
@cetagostini cetagostini self-assigned this Jun 18, 2026
a statically non-broadcastable dimension that is length 1 at runtime is an
error, not a silent broadcast.
"""
from pytensor import function

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make import global

out = func(x, y, idxs)
assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1)

f = function([y], out, mode="MLX")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mode should me Mode(linker="mlx", optimizer=None), otherwise your advanced_inc_subtensor path will end up testing the same advancde_inc_subtensor1 op

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants