Skip to content

Fix MLX Convolve1d crash when Blockwise broadcasts the kernel#2231

Open
gaoflow wants to merge 1 commit into
pymc-devs:mainfrom
gaoflow:fix-2092-mlx-convolve1d-broadcast-kernel
Open

Fix MLX Convolve1d crash when Blockwise broadcasts the kernel#2231
gaoflow wants to merge 1 commit into
pymc-devs:mainfrom
gaoflow:fix-2092-mlx-convolve1d-broadcast-kernel

Conversation

@gaoflow

@gaoflow gaoflow commented Jun 15, 2026

Copy link
Copy Markdown

Fixes #2092.

Problem

mlx_funcify_Convolve1d's conv1d thunk calls mx.convolve(data, kernel, ...), which requires both inputs to be 1-D. When convolve1d is batched through Blockwise with a single kernel shared across the batch, PyTensor broadcasts the kernel to a leading-1 dim. The MLX Blockwise dispatcher vmaps over the real batch dim of the data (so data arrives correctly as (N,)) but treats the kernel's leading-1 dim as a broadcast (in_axes=None), so the core thunk receives a (1, K) kernel and mx.convolve raises ValueError: [convolve] Inputs must be 1D.

Empirically, instrumenting the thunk for the reproducer shows data.shape == (32,) and kernel.shape == (1, 5) — only the kernel is malformed.

Fix

Flatten the kernel back to its core (K,) shape inside conv1d when it arrives with extra leading dims. Under the op's core signature (n0),(k0),()->(o0) the kernel is conceptually 1-D per core call, so any leading dim is a broadcast artifact and reshape(-1) is safe.

Test

Added tests/link/mlx/test_signal_conv.py with the batched-kernel-broadcast case (and a plain-vector case as a control), both parametrized over full/valid and checked against the Python linker via the existing compare_mlx_and_py helper. Verified RED→GREEN: without the source change the batched-kernel tests fail with [convolve] Inputs must be 1D while the vector tests pass; with it all pass. The full tests/link/mlx/ suite is green (149 passed, 4 pre-existing xfailed).


Disclosure: I use AI assistance (under my direction) to help with my contributions; I review and verify every change before submitting.

When convolve1d is batched via Blockwise, a kernel shared across the batch
is broadcast to a leading-1 dim (e.g. (1, K)) rather than being vmapped away,
so the core thunk received a 2-D kernel and mx.convolve raised
'[convolve] Inputs must be 1D'. Flatten the kernel back to its core (K,)
shape before convolving.

Closes pymc-devs#2092
@jessegrabowski

Copy link
Copy Markdown
Member

This is a bug, but the real fix is to make a real vectorize for the mlx Blockwise dispatch. jax has a jnp.vectorize syntactic sugar that handles dimension alignment and chained vmap. We need the same thing for mlx. I opened #2233 which is a more general fix, can you clone that branch and confirm it fixes your problem?

@ricardoV94

Copy link
Copy Markdown
Member

IIRC this was special cased when mlx have vmap for convolve1d

@gaoflow

gaoflow commented Jun 17, 2026

Copy link
Copy Markdown
Author

Confirmed: #2233 fixes the case this PR was patching. I checked out mlx-blockwise-upgrade and ran the batched-kernel-broadcast repro from here (a (4, 32) signal matrix convolved with a single shared (5,) kernel):

mode output shape matches numpy
full (4, 36) yes
valid (4, 28) yes

So the full vectorize is the right fix and makes the narrow Convolve1d workaround here unnecessary. Happy to close this in favor of #2233.

One thing worth carrying over: the regression test from this PR, test_convolve1d_batched_kernel_broadcast (vector kernel shared across a batch of signals, both full and valid), exercises exactly the broadcast path #2233 generalizes. If it's useful I can open a small follow-up adding it to the convolve1d tests once #2233 lands, so the specific shape that triggered #2092 stays covered.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: MLX Convolve1d dispatch crashes when Blockwise broadcasts the kernel

3 participants