Fix MLX Cast dispatcher crash on Python scalar input#2230
Conversation
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("value", [5, 2.0]) |
There was a problem hiding this comment.
This is not a valid value. We have a strongly typed dtype system in PyTensor so we should be able to exploit/trust it and not have every op doubt its inputs.
Your original issue may come from a real MLX limitation that we should be aware of, or a bug/bad behavior in an upstream dispatch that we should fix, not patch elsewhere.
Shape_i's MLX dispatch returned ``x.shape[op.i]``, which is a plain Python ``int`` because ``mx.array.shape`` is a tuple of Python ints. After the rewriter folds a shape-derived value into a ``Composite``, the inner ``cast(int_value)`` reaches the ``Cast`` dispatcher, whose ``x.astype(...)`` raises ``AttributeError: 'int' object has no attribute 'astype'``. Wrap the ``Shape_i`` result in ``mx.array(..., dtype=int64)`` so downstream ops receive a properly typed array, mirroring ``Shape`` which already wraps its result. Adds an MLX-mode regression test exercising a shape value that flows into a ``Cast``. Fixes pymc-devs#2096.
a3860bd to
d90fb27
Compare
|
Thanks, that is a fair point and you are right. I traced where the bare Python int actually originates: I have reworked the PR to fix it there instead of patching The test is now an end-to-end MLX-mode regression (a shape value flowing into a Minimal reproducer for reference: a = pt.vector("a", dtype="float32")
out = pt.cast(a.shape[0], "float32") + pt.cast(2, "float32")
pytensor.function([a], out, mode="MLX")(np.arange(6, dtype="float32")) |
| fn = function([a], out, mode="MLX") | ||
| result = fn(np.arange(6, dtype="float32")) | ||
|
|
||
| np.testing.assert_allclose(np.asarray(result), 8.0) |
There was a problem hiding this comment.
can'w we use compare_mlx_and_py helper?
| # what ``mx.array.shape[i]`` yields). Downstream ops such as ``Cast`` | ||
| # rely on receiving an array; a Python scalar makes them crash with | ||
| # ``AttributeError: 'int' object has no attribute 'astype'`` (#2096). | ||
| # This mirrors ``Shape``, which already wraps its result in ``mx.array``. |
There was a problem hiding this comment.
Nit: I don't think we need all this comment here
Shorten the Shape_i comment per review and rewrite the shape-cast test to use the compare_mlx_and_py helper (with the full MLX mode, where the issue was reported), dropping the now-unused function import.
|
Thanks for the review. Addressed the two nits in 0173cd6:
The One thing I want to be upfront about, though. While re-verifying I could not reproduce the #2096 crash on current So this change now reads more as a consistency fix (mirroring |
Fixes #2096.
Problem
mlx_funcify_Cast'scastclosure assumes its input always has an.astypemethod and only guards againstValueError. After constant folding, aCastnested inside aCompositecan receive a plain Pythonint/float(e.g. a shape-derived value that the rewriter folded into the graph). Those objects have no.astype, so the closure raised an uncaughtAttributeError: 'int' object has no attribute 'astype'.Fix
Promote inputs that lack
.astypeto anmx.arraybefore casting. This is the minimal change: the existing GPU-unsupported-dtypeValueErrorfallback is left intact, so the dtype-downcast warning behaviour (covered by the existingtest_mlx_float64_*tests) is unchanged.Test
Added
test_cast_python_scalar_inputintests/link/mlx/test_scalar.py, parametrized over a Pythonintandfloat, asserting the dispatcher returns anmx.arrayof the target dtype. Verified RED→GREEN: without the source change the new test fails withAttributeError: '...' object has no attribute 'astype'; with it, it passes. The fulltests/link/mlx/suite is green (145 passed, 4 pre-existing xfailed).Disclosure: I use AI assistance (under my direction) to help with my contributions; I review and verify every change before submitting.