Fix MLX Metal codegen crash on NaN scalar constants#2239
Open
cetagostini wants to merge 1 commit into
Open
Conversation
`mx.compile` inlines size-1 constants directly into the generated Metal source, but Metal has no `nan` literal (it does accept `inf`), so a graph containing a scalar NaN constant aborts kernel compilation. This blocks any normalization (RMSNorm/LayerNorm) input-gradient on MLX, since the `local_sqrt_sqr` rewrite turns `sqr(sqrt(z))` into `switch(z >= 0, z, nan)`. Materialize size-1 NaN constants through an op (`0 / 0`) in `mlx_typify` so MLX emits a valid expression instead of the bare `nan` token. Gated on graph constants only (the `variable` kwarg), so runtime inputs keep the buffer path; larger NaN constants are already passed as buffers and are left untouched. Co-authored-by: Cursor <cursoragent@cursor.com>
ricardoV94
reviewed
Jun 18, 2026
| def mlx_typify_tensor(data, dtype=None, variable=None, **kwargs): | ||
| arr = mx.array(data, dtype=dtype) | ||
| if variable is not None: | ||
| arr = _nan_safe_constant(arr, data) |
Member
There was a problem hiding this comment.
Why not just return _nan_safe_constant(data) (or np.asarray(data for the others), and the helper does the arr = mx.array(data) at the end? Also what is the deal with if variable is not None?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Problem
Compiling the gradient of a normalization expression (RMSNorm/LayerNorm) on the MLX backend aborts the process with a Metal compiler error:
This blocks training any normalized network on MLX, since the input-gradient is on the main backprop path of every transformer.
Reproducer (runs on
main)The raw graph contains no NaN. The MLX optimizer introduces it: the
local_sqrt_sqrrewrite turnssqr(sqrt(z))intoswitch(z >= 0, z, nan)(a defensive branch;z = mean(x**2) + eps > 0here), which inserts a scalarnanconstant into aSwitch(mx.where).Root cause
PyTensor's MLX linker wraps the generated function in
mx.compile, which inlines size-1 constants directly into the generated Metal source. Metal has nonanliteral (it does acceptinf), so the kernel fails to build. Multi-element constants are passed as buffers (fine), and runtime inputs are passed as arguments (fine) — only size-1 constants hit the bad path. The same graph compiles on the C backend, which emits a validNAN, so this is MLX-specific.Handle
The fix lives in
mlx_typify(constant conversion), mirroring the existingnumba_typifyprecedent of massaging constants to satisfy a backend compiler (non-contiguous constants, #2063). When a graph constant is a size-1 float containing NaN, it is built from an op (0 / 0) so MLX emits a valid expression instead of the barenantoken. It is gated on thevariablekwarg, whichfgraph_to_pythonpasses only for graph constants/shared values, so runtime inputs keep the cheap buffer path with no per-call overhead. Larger NaN constants are passed as buffers and are left untouched.Why this is sustainable and clean
It fixes the mechanism (codegen of the constant), not the symptom: the shared
local_sqrt_sqrrewrite is left intact, preserving cross-backend consistency rather than special-casing MLX in a global rewrite. The change is localized to the one layer PyTensor controls when handing values to MLX, and is scoped precisely to the constants MLX actually inlines (size == 1), so the common buffer path is never rewritten. This also generalizes beyond the reported case — it covers NaN constants wherever they appear (including inside nestedOpFromGraph/Compositesubgraphs), since all constants flow throughmlx_typify.Tests
tests/link/mlx/test_scalar.py:test_nan_constant[float32/float16]: a size-1 NaN constant fed to aSwitch(the failing class). RED -> GREEN: aborts the process on unpatched code, passes with the fix.test_nan_array_constant: a multi-element NaN constant still compiles via the buffer path (no materialization).Full MLX suite:
150 passed, 4 xfailed;ruff/pre-commitclean.Disclosure: implemented with AI assistance.
Made with Cursor