Add full mlx blockwise support #2233

Open
jessegrabowski wants to merge 1 commit into pymc-devs:main from jessegrabowski:mlx-blockwise-upgrade

@jessegrabowski


MLX blockwise operations fail in non-trivial cases because our implementation isn't a full vectorize; it's just a single naive application of mlx.core.vmap. We need to align and broadcast all input shapes, then apply vmap once per non-core dimension. This is what jnp.vectorize does under the hood. I just went ahead and re-implemented jnp.vectorize in MLX, giving us full vectorize support.

Closes #2092, and also addresses additional unreported cases (e.g. when the data has a batch dimension).
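
To make the "align, then vmap once per non-core dimension" idea concrete, here is a minimal pure-Python sketch of the planning step only. It is loosely modeled on the approach described above and is not the code in this PR; the function name `plan_vmap_axes` and its return layout are hypothetical. It takes the batch (non-core) shape of each input and works out which size-1 axes to squeeze, which `in_axes` tuple to pass at each vmap level, and which output dimensions would need to be re-expanded because every input had size 1 there. It does not validate that the non-1 sizes agree.

```python
def plan_vmap_axes(batch_shapes):
    """Plan one vmap level per batch dimension (illustrative helper, hypothetical name).

    batch_shapes: one tuple of ints per input, holding only its batch dims.
    Returns (squeeze_axes, in_axes_per_level, dims_to_expand).
    """
    batch_ndim = max((len(shape) for shape in batch_shapes), default=0)

    rev_filled = []    # left-padded batch shapes, reversed (innermost dim first)
    squeeze_axes = []  # per input: its size-1 batch axes, dropped before vmapping
    for shape in batch_shapes:
        padded = (1,) * (batch_ndim - len(shape)) + tuple(shape)
        rev_filled.append(padded[::-1])
        squeeze_axes.append(tuple(i for i, s in enumerate(shape) if s == 1))

    in_axes_per_level = []  # innermost vmap first
    dims_to_expand = []     # output batch dims where every input had size 1
    for negdim, sizes in enumerate(zip(*rev_filled)):
        in_axes = tuple(None if s == 1 else 0 for s in sizes)
        if all(axis is None for axis in in_axes):
            dims_to_expand.append(batch_ndim - 1 - negdim)
        else:
            in_axes_per_level.append(in_axes)
    return squeeze_axes, in_axes_per_level, dims_to_expand


# Input 0 has batch shape (5, 1), input 1 has batch shape (3,).
squeeze, levels, expand = plan_vmap_axes([(5, 1), (3,)])
print(squeeze)  # which axes to squeeze from each input
print(levels)   # in_axes for each vmap level, innermost first
print(expand)   # output dims to re-insert with size 1
```

Each tuple in `levels` would be handed to one vmap call wrapping the previous function, so the last entry corresponds to the outermost batch dimension.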



# Equivalent blockwise to matmul but with dumb signature
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")


This had a purpose: it's the default signature of a fallback Blockwise.
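
For readers unfamiliar with that naming pattern, the following is a small sketch of how a signature of this shape could be generated from the core ndims of the inputs and outputs. The helper `default_blockwise_signature` is hypothetical and written only to reproduce the string quoted in the test; how PyTensor actually builds its fallback signature is not shown in this thread.

```python
def default_blockwise_signature(input_ndims, output_ndims):
    """Build a gufunc-style signature with a unique name for every core dim.

    Hypothetical helper: dim names are "i<input index><dim index>" and
    "o<output index><dim index>", so no two core dims are tied together.
    """
    def group(prefix, ndims):
        return ",".join(
            "(" + ",".join(f"{prefix}{n}{d}" for d in range(ndim)) + ")"
            for n, ndim in enumerate(ndims)
        )

    return f"{group('i', input_ndims)}->{group('o', output_ndims)}"


# Two matrix inputs and one matrix output, as in the test above.
print(default_blockwise_signature([2, 2], [2]))
```

Because every core dimension gets its own name, such a signature carries no shape constraints between inputs, which is why it is a useful case to keep in the tests.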

rng = np.random.default_rng(42)

# Create a blockwise matmul with no batch dimensions (core operation only)
x = pt.matrix("x")


Looking at it, I don't think we have a single test with non-static shapes?

for arg, batch_shape in zip(args, batch_shapes):
    padded = (1,) * (batch_ndim - len(batch_shape)) + batch_shape
    rev_filled.append(padded[::-1])
    squeeze_axes = tuple(i for i, s in enumerate(batch_shape) if s == 1)


For consistency, Blockwise broadcasting should depend on static shapes; otherwise it will be allowed in some backends but not others, which is already one big source of confusion in PyTensor.

That means you know ahead of time which vmapped axes are None and which are 0, and the dispatch should only verify them, not re-infer them.
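
One way to read this suggestion is sketched below in plain Python. It assumes batch shapes are already padded to a common length, that a static size of `None` means "unknown at compile time", and that only a static size of exactly 1 marks a broadcast axis. Both function names are hypothetical; this is an illustration of the "decide statically, verify at runtime" split, not the dispatch code under review.

```python
def static_in_axes(static_batch_shapes):
    """Decide in_axes from static shapes only (hypothetical helper).

    Per batch dim and per input: None if the static size is 1 (broadcast),
    otherwise 0 (mapped), including when the static size is unknown (None).
    """
    return [
        tuple(None if size == 1 else 0 for size in sizes)
        for sizes in zip(*static_batch_shapes)
    ]


def verify_runtime_shapes(in_axes_per_dim, runtime_batch_shapes):
    """Check runtime sizes against the static plan instead of re-inferring it."""
    for dim, (in_axes, sizes) in enumerate(
        zip(in_axes_per_dim, zip(*runtime_batch_shapes))
    ):
        # All mapped inputs must share one runtime size on this batch dim.
        mapped_sizes = {size for axis, size in zip(in_axes, sizes) if axis == 0}
        if len(mapped_sizes) > 1:
            raise ValueError(
                f"Runtime broadcasting not allowed on batch dim {dim}: sizes {sizes}"
            )
        # Inputs planned as broadcast must really have size 1 at runtime.
        if any(axis is None and size != 1 for axis, size in zip(in_axes, sizes)):
            raise ValueError(f"Static size 1 violated on batch dim {dim}: sizes {sizes}")


plan = static_in_axes([(None, 1), (None, 3)])
verify_runtime_shapes(plan, [(5, 1), (5, 3)])  # consistent with the static plan
```

Under this scheme a runtime shape of `(1, 1)` for the first input would raise, since its first batch dim was not statically known to be 1, so the behaviour does not silently differ between backends.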



Successfully merging this pull request may close these issues.

BUG: MLX Convolve1d dispatch crashes when Blockwise broadcasts the kernel
