Fix MLX Convolve1d crash when Blockwise broadcasts the kernel#2231
Fix MLX Convolve1d crash when Blockwise broadcasts the kernel#2231gaoflow wants to merge 1 commit into
Conversation
When convolve1d is batched via Blockwise, a kernel shared across the batch is broadcast to a leading-1 dim (e.g. (1, K)) rather than being vmapped away, so the core thunk received a 2-D kernel and mx.convolve raised '[convolve] Inputs must be 1D'. Flatten the kernel back to its core (K,) shape before convolving. Closes pymc-devs#2092
|
This is a bug, but the real fix is to make a real vectorize for the mlx Blockwise dispatch. jax has a |
|
IIRC this was special cased when mlx have vmap for convolve1d |
|
Confirmed: #2233 fixes the case this PR was patching. I checked out
So the full vectorize is the right fix and makes the narrow Convolve1d workaround here unnecessary. Happy to close this in favor of #2233. One thing worth carrying over: the regression test from this PR, |
Fixes #2092.
Problem
mlx_funcify_Convolve1d'sconv1dthunk callsmx.convolve(data, kernel, ...), which requires both inputs to be 1-D. Whenconvolve1dis batched throughBlockwisewith a single kernel shared across the batch, PyTensor broadcasts the kernel to a leading-1 dim. The MLXBlockwisedispatchervmaps over the real batch dim of the data (sodataarrives correctly as(N,)) but treats the kernel's leading-1 dim as a broadcast (in_axes=None), so the core thunk receives a(1, K)kernel andmx.convolveraisesValueError: [convolve] Inputs must be 1D.Empirically, instrumenting the thunk for the reproducer shows
data.shape == (32,)andkernel.shape == (1, 5)— only the kernel is malformed.Fix
Flatten the kernel back to its core
(K,)shape insideconv1dwhen it arrives with extra leading dims. Under the op's core signature(n0),(k0),()->(o0)the kernel is conceptually 1-D per core call, so any leading dim is a broadcast artifact andreshape(-1)is safe.Test
Added
tests/link/mlx/test_signal_conv.pywith the batched-kernel-broadcast case (and a plain-vector case as a control), both parametrized overfull/validand checked against the Python linker via the existingcompare_mlx_and_pyhelper. Verified RED→GREEN: without the source change the batched-kernel tests fail with[convolve] Inputs must be 1Dwhile the vector tests pass; with it all pass. The fulltests/link/mlx/suite is green (149 passed, 4 pre-existing xfailed).Disclosure: I use AI assistance (under my direction) to help with my contributions; I review and verify every change before submitting.