Skip to content

Numba: Fuse Elemwise + Reduction#2227

Draft
ricardoV94 wants to merge 5 commits into
pymc-devs:mainfrom
ricardoV94:reduction_fusion_with_scope_markers
Draft

Numba: Fuse Elemwise + Reduction#2227
ricardoV94 wants to merge 5 commits into
pymc-devs:mainfrom
ricardoV94:reduction_fusion_with_scope_markers

Conversation

@ricardoV94

Copy link
Copy Markdown
Member

Extends #2015 to also fuse CAReduce with Elemwise or IndexedELemwise (now renamed FusedElemwise).

This sometimes actually causes regressions without numba/llvmlite#895 as it prevents auto-vectorization numba was doing when the Ops weren't fused.

So I did the only logical thing and patched llvmlite from within PyTensor.

Results on the same benchmark we focused on for IndexedElemwise: https://ricardov94.github.io/pymc-model-catalogue/experiments.html#base=fuse_reduction_curated_base&compare=reduction_fusion_vs_nonfused

image

This is one of the ways in which JAX sometimes beats us on the CPU


if not idx_groups:
# Find reductions to fuse: an Elemwise output whose sole client is an
# eligible CAReduce. Outputs that are write targets are excluded (an

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why can't they be duplicated the same?

default_mode = get_default_mode()
# Exclude the Numba reduction fusion so the outer reduction op stays
# visible in the toposort for the structural assertions below.
default_mode = get_default_mode().excluding("fuse_indexed_into_elemwise")

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename tag to fused_elemwise

The fused loop writes the elementwise result to the write buffer, never to
the inplaced input, but the inner fgraph kept the inplace Elemwise: the
Python-mode fallback (OpFromGraph.perform) would destroy that input without
the outer destroy map declaring it, losing the ordering constraint for other
readers of the destroyed buffer. The JIT path was unaffected (write buffers
shadow the inplace pattern in make_outputs).

Write-and-direct duplication now runs before the strip and preserves the
inplace pattern, so an inplace on an output that stays materialized (the
write consuming a duplicate) still survives the fusion.
An output consumed by several eligible CAReduces previously disqualified
itself entirely (the detection required exactly one reduce client). Peel one
extra reduction per rewrite pass onto a duplicate output until each reduction
has its own copy, so e.g. [sum(f), max(f), prod(f), f] becomes a single
FusedElemwise with three fused reductions.
sum(x[idx]) had no Elemwise for FuseElemwise to anchor on, so the gather
materialized and the reduction stayed external. A new pre-rewrite wraps such
reductions in an identity Elemwise (covering the bare AdvancedSubtensor1,
axis-swap DimShuffle and flattened-ND-index Reshape forms), letting gather,
identity and reduction collapse into one fused loop.
@ricardoV94 ricardoV94 force-pushed the reduction_fusion_with_scope_markers branch from 777ed4c to 605789e Compare June 16, 2026 11:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant