Skip to content

Fix MLX Metal codegen crash on NaN scalar constants#2239

Open
cetagostini wants to merge 1 commit into
pymc-devs:mainfrom
cetagostini:fix-mlx-nan-constant-codegen
Open

Fix MLX Metal codegen crash on NaN scalar constants#2239
cetagostini wants to merge 1 commit into
pymc-devs:mainfrom
cetagostini:fix-mlx-nan-constant-codegen

Conversation

@cetagostini

Copy link
Copy Markdown
Contributor

Problem

Compiling the gradient of a normalization expression (RMSNorm/LayerNorm) on the MLX backend aborts the process with a Metal compiler error:

mlx/backend/metal/kernels/ternary_ops.h: error: use of undeclared identifier 'nan'
  auto tmp_J = static_cast<float>(nan);

This blocks training any normalized network on MLX, since the input-gradient is on the main backprop path of every transformer.

Reproducer (runs on main)

import numpy as np
import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"

x = pt.matrix("x")
w = pt.vector("w")
ms = (x * x).mean(axis=-1, keepdims=True)
out = x / pt.sqrt(ms + np.float32(1e-6)) * w   # RMSNorm
gx = pt.grad((out * out).sum(), x)             # gradient through the INPUT

pytensor.function([x, w], gx, mode=None)        # default backend: compiles fine
pytensor.function([x, w], gx, mode="MLX")       # MLX: Metal build FAILS on `nan`

The raw graph contains no NaN. The MLX optimizer introduces it: the local_sqrt_sqr rewrite turns sqr(sqrt(z)) into switch(z >= 0, z, nan) (a defensive branch; z = mean(x**2) + eps > 0 here), which inserts a scalar nan constant into a Switch (mx.where).

Root cause

PyTensor's MLX linker wraps the generated function in mx.compile, which inlines size-1 constants directly into the generated Metal source. Metal has no nan literal (it does accept inf), so the kernel fails to build. Multi-element constants are passed as buffers (fine), and runtime inputs are passed as arguments (fine) — only size-1 constants hit the bad path. The same graph compiles on the C backend, which emits a valid NAN, so this is MLX-specific.

Handle

The fix lives in mlx_typify (constant conversion), mirroring the existing numba_typify precedent of massaging constants to satisfy a backend compiler (non-contiguous constants, #2063). When a graph constant is a size-1 float containing NaN, it is built from an op (0 / 0) so MLX emits a valid expression instead of the bare nan token. It is gated on the variable kwarg, which fgraph_to_python passes only for graph constants/shared values, so runtime inputs keep the cheap buffer path with no per-call overhead. Larger NaN constants are passed as buffers and are left untouched.

Why this is sustainable and clean

It fixes the mechanism (codegen of the constant), not the symptom: the shared local_sqrt_sqr rewrite is left intact, preserving cross-backend consistency rather than special-casing MLX in a global rewrite. The change is localized to the one layer PyTensor controls when handing values to MLX, and is scoped precisely to the constants MLX actually inlines (size == 1), so the common buffer path is never rewritten. This also generalizes beyond the reported case — it covers NaN constants wherever they appear (including inside nested OpFromGraph/Composite subgraphs), since all constants flow through mlx_typify.

Tests

tests/link/mlx/test_scalar.py:

  • test_nan_constant[float32/float16]: a size-1 NaN constant fed to a Switch (the failing class). RED -> GREEN: aborts the process on unpatched code, passes with the fix.
  • test_nan_array_constant: a multi-element NaN constant still compiles via the buffer path (no materialization).

Full MLX suite: 150 passed, 4 xfailed; ruff/pre-commit clean.


Disclosure: implemented with AI assistance.

Made with Cursor

`mx.compile` inlines size-1 constants directly into the generated Metal
source, but Metal has no `nan` literal (it does accept `inf`), so a graph
containing a scalar NaN constant aborts kernel compilation. This blocks any
normalization (RMSNorm/LayerNorm) input-gradient on MLX, since the
`local_sqrt_sqr` rewrite turns `sqr(sqrt(z))` into `switch(z >= 0, z, nan)`.

Materialize size-1 NaN constants through an op (`0 / 0`) in `mlx_typify` so
MLX emits a valid expression instead of the bare `nan` token. Gated on graph
constants only (the `variable` kwarg), so runtime inputs keep the buffer path;
larger NaN constants are already passed as buffers and are left untouched.

Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini self-assigned this Jun 18, 2026
@cetagostini cetagostini requested a review from ricardoV94 June 18, 2026 09:46
def mlx_typify_tensor(data, dtype=None, variable=None, **kwargs):
arr = mx.array(data, dtype=dtype)
if variable is not None:
arr = _nan_safe_constant(arr, data)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just return _nan_safe_constant(data) (or np.asarray(data for the others), and the helper does the arr = mx.array(data) at the end? Also what is the deal with if variable is not None?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants