Numba: Fuse Elemwise + Reduction#2227
Draft
ricardoV94 wants to merge 5 commits into
Draft
Conversation
ricardoV94
commented
Jun 15, 2026
|
|
||
| if not idx_groups: | ||
| # Find reductions to fuse: an Elemwise output whose sole client is an | ||
| # eligible CAReduce. Outputs that are write targets are excluded (an |
Member
Author
There was a problem hiding this comment.
Why can't they be duplicated the same?
ricardoV94
commented
Jun 15, 2026
| default_mode = get_default_mode() | ||
| # Exclude the Numba reduction fusion so the outer reduction op stays | ||
| # visible in the toposort for the structural assertions below. | ||
| default_mode = get_default_mode().excluding("fuse_indexed_into_elemwise") |
Member
Author
There was a problem hiding this comment.
rename tag to fused_elemwise
The fused loop writes the elementwise result to the write buffer, never to the inplaced input, but the inner fgraph kept the inplace Elemwise: the Python-mode fallback (OpFromGraph.perform) would destroy that input without the outer destroy map declaring it, losing the ordering constraint for other readers of the destroyed buffer. The JIT path was unaffected (write buffers shadow the inplace pattern in make_outputs). Write-and-direct duplication now runs before the strip and preserves the inplace pattern, so an inplace on an output that stays materialized (the write consuming a duplicate) still survives the fusion.
An output consumed by several eligible CAReduces previously disqualified itself entirely (the detection required exactly one reduce client). Peel one extra reduction per rewrite pass onto a duplicate output until each reduction has its own copy, so e.g. [sum(f), max(f), prod(f), f] becomes a single FusedElemwise with three fused reductions.
sum(x[idx]) had no Elemwise for FuseElemwise to anchor on, so the gather materialized and the reduction stayed external. A new pre-rewrite wraps such reductions in an identity Elemwise (covering the bare AdvancedSubtensor1, axis-swap DimShuffle and flattened-ND-index Reshape forms), letting gather, identity and reduction collapse into one fused loop.
777ed4c to
605789e
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Extends #2015 to also fuse CAReduce with Elemwise or IndexedELemwise (now renamed FusedElemwise).
This sometimes actually causes regressions without numba/llvmlite#895 as it prevents auto-vectorization numba was doing when the Ops weren't fused.
So I did the only logical thing and patched llvmlite from within PyTensor.
Results on the same benchmark we focused on for IndexedElemwise: https://ricardov94.github.io/pymc-model-catalogue/experiments.html#base=fuse_reduction_curated_base&compare=reduction_fusion_vs_nonfused
This is one of the ways in which JAX sometimes beats us on the CPU