Skip to content

Fix MLX Cast dispatcher crash on Python scalar input#2230

Open
gaoflow wants to merge 2 commits into
pymc-devs:mainfrom
gaoflow:fix-2096-mlx-cast-python-scalar
Open

Fix MLX Cast dispatcher crash on Python scalar input#2230
gaoflow wants to merge 2 commits into
pymc-devs:mainfrom
gaoflow:fix-2096-mlx-cast-python-scalar

Conversation

@gaoflow

@gaoflow gaoflow commented Jun 15, 2026

Copy link
Copy Markdown

Fixes #2096.

Problem

mlx_funcify_Cast's cast closure assumes its input always has an .astype method and only guards against ValueError. After constant folding, a Cast nested inside a Composite can receive a plain Python int/float (e.g. a shape-derived value that the rewriter folded into the graph). Those objects have no .astype, so the closure raised an uncaught AttributeError: 'int' object has no attribute 'astype'.

Fix

Promote inputs that lack .astype to an mx.array before casting. This is the minimal change: the existing GPU-unsupported-dtype ValueError fallback is left intact, so the dtype-downcast warning behaviour (covered by the existing test_mlx_float64_* tests) is unchanged.

Test

Added test_cast_python_scalar_input in tests/link/mlx/test_scalar.py, parametrized over a Python int and float, asserting the dispatcher returns an mx.array of the target dtype. Verified RED→GREEN: without the source change the new test fails with AttributeError: '...' object has no attribute 'astype'; with it, it passes. The full tests/link/mlx/ suite is green (145 passed, 4 pre-existing xfailed).


Disclosure: I use AI assistance (under my direction) to help with my contributions; I review and verify every change before submitting.

Comment thread tests/link/mlx/test_scalar.py Outdated
)


@pytest.mark.parametrize("value", [5, 2.0])

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a valid value. We have a strongly typed dtype system in PyTensor so we should be able to exploit/trust it and not have every op doubt its inputs.

Your original issue may come from a real MLX limitation that we should be aware of, or a bug/bad behavior in an upstream dispatch that we should fix, not patch elsewhere.

Shape_i's MLX dispatch returned ``x.shape[op.i]``, which is a plain Python
``int`` because ``mx.array.shape`` is a tuple of Python ints. After the
rewriter folds a shape-derived value into a ``Composite``, the inner
``cast(int_value)`` reaches the ``Cast`` dispatcher, whose ``x.astype(...)``
raises ``AttributeError: 'int' object has no attribute 'astype'``.

Wrap the ``Shape_i`` result in ``mx.array(..., dtype=int64)`` so downstream
ops receive a properly typed array, mirroring ``Shape`` which already wraps
its result. Adds an MLX-mode regression test exercising a shape value that
flows into a ``Cast``.

Fixes pymc-devs#2096.
@gaoflow gaoflow force-pushed the fix-2096-mlx-cast-python-scalar branch from a3860bd to d90fb27 Compare June 16, 2026 10:33
@gaoflow

gaoflow commented Jun 16, 2026

Copy link
Copy Markdown
Author

Thanks, that is a fair point and you are right. I traced where the bare Python int actually originates: Shape_i's MLX dispatch returns x.shape[op.i], which is a plain Python int (an mx.array.shape is a tuple of Python ints). Once the rewriter folds that shape value into a Composite, the inner cast(...) receives the bare int and crashes. Its sibling Shape already wraps its result in mx.array, so Shape_i was simply inconsistent.

I have reworked the PR to fix it there instead of patching Cast: Shape_i now returns mx.array(x.shape[op.i], dtype=int64), so Cast (and any other downstream op) can trust its input as you suggested. The Cast change is reverted.

The test is now an end-to-end MLX-mode regression (a shape value flowing into a Cast) rather than calling the Cast closure with a hand-made scalar. It fails on main with the original AttributeError and passes with the fix; the full MLX suite stays green (144 passed, 4 xfailed).

Minimal reproducer for reference:

a = pt.vector("a", dtype="float32")
out = pt.cast(a.shape[0], "float32") + pt.cast(2, "float32")
pytensor.function([a], out, mode="MLX")(np.arange(6, dtype="float32"))

Comment thread tests/link/mlx/test_shape.py Outdated
fn = function([a], out, mode="MLX")
result = fn(np.arange(6, dtype="float32"))

np.testing.assert_allclose(np.asarray(result), 8.0)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can'w we use compare_mlx_and_py helper?

Comment thread pytensor/link/mlx/dispatch/shape.py Outdated
# what ``mx.array.shape[i]`` yields). Downstream ops such as ``Cast``
# rely on receiving an array; a Python scalar makes them crash with
# ``AttributeError: 'int' object has no attribute 'astype'`` (#2096).
# This mirrors ``Shape``, which already wraps its result in ``mx.array``.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I don't think we need all this comment here

Shorten the Shape_i comment per review and rewrite the shape-cast test
to use the compare_mlx_and_py helper (with the full MLX mode, where the
issue was reported), dropping the now-unused function import.
@gaoflow

gaoflow commented Jun 17, 2026

Copy link
Copy Markdown
Author

Thanks for the review. Addressed the two nits in 0173cd6:

  • Trimmed the Shape_i comment down to a single sentence.
  • Switched the test to compare_mlx_and_py. I pass mlx_mode="MLX" so it still exercises the full rewrite pipeline (the issue's configuration) rather than the trimmed test mode, and dropped the now-unused function import.

The test_scalar.py case you flagged is already gone: when this PR moved from patching Cast to wrapping Shape_i at the source, that scalar test was removed, so the branch now only touches shape.py and test_shape.py. That also matches your point that Cast shouldn't have to doubt a typed input.

One thing I want to be upfront about, though. While re-verifying I could not reproduce the #2096 crash on current main with the current mlx, with or without this Shape_i wrap. Both a minimal cast(a.shape[0], "float32") graph and an accumulator-cast gradient (((alpha-1)*log(val) - alpha*val).sum() then grad) run cleanly under mode="MLX" and return the right value. The Cast dispatcher still calls x.astype with only a ValueError guard, so the bare scalar simply isn't reaching it in these cases — it looks mlx-version sensitive.

So this change now reads more as a consistency fix (mirroring Shape, which already wraps in mx.array) than as something I can demonstrate fixing a live crash. Could you confirm whether #2096 still reproduces for you, ideally with the exact graph that triggers it (the issue mentions a gammaln/Composite chain)? If you can hand me a repro that still crashes on main, I'll trace which dispatch actually emits the bare Python scalar and fix that producer with a faithful red→green test. If it's no longer reproducible, I'm happy to either keep this as the consistency tidy-up or close it, whichever you prefer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: MLX Cast dispatcher crashes when input is a Python int/float

2 participants