diff --git a/compier_optimization_plan.md b/compier_optimization_plan.md new file mode 100644 index 0000000000..1c0fe27255 --- /dev/null +++ b/compier_optimization_plan.md @@ -0,0 +1,667 @@ +# Faster-Python Iteration 1: Temporary OSS Compilation Optimization Plan + +Benchmark nomenclature used below: + +- `light`: `test_profile_eiso_stress_like_schedule_infos` +- `heavy`: `test_profile_etti_velocity_then_stress_like_schedule_infos` +- `heavy_IO`: + `test_profile_etti_velocity_then_stress_like_bitcomp_serial_schedule_infos` + (the `heavy` operator plus bitcomp compression and serialization) + +## April 28, 2026 update: impact of the latest two OSS commits + +The latest two OSS commits moved the paired PRO/OSS branch beyond the previous +April 21/22 `~26 s` heavy-compile checkpoint: + +- `fc755479d` (`compiler: Split into EqBlock and Cluster`) + - split the structural Cluster payload into cached `EqBlock` objects while + preserving Cluster identity semantics + - made `IREq` hashing/equality include IR metadata via `_hashable_content`, + with the reusable `as_hashable` helper, so EqBlock caching does not merge + equations that only differ in `ispace`, `conditionals`, `implicit_dims`, or + `operation` + - measured impact after correctness fixes: + - stress-only: about `11.0 s` + - heavy `velocity+stress`: about `24.4 s` + - heavy `optimize_kernels`: about `11.3-11.4 s` + - validation at this point included the targeted EqBlock repros, nearby + equation/visitor/CSE tests, and a full OSS suite: + `3549 passed, 5 skipped, 4 xfailed, 1 xpassed` + +- `771f807ab` (`compiler: Stash hash were essential for compilation performance`) + - added `cached_hash`, which stashes immutable-object `__hash__` results in + `_mhash` + - applied it to the hottest hash sites from the profiling investigation: + support-space objects (`Interval`, `IntervalGroup`, `IterationInterval`, + `IterationSpace`, `DataSpace`, and `IterationDirection`) plus Cluster queue + keys (`Prefix`) and `ClusterGroup` + - removed the generic `Space.__hash__` path and made subclasses hash their + concrete payloads directly, avoiding shared base/subclass hash-cache + ambiguity + - measured impact after this commit: + - stress-only: `10.45-10.57 s`, with `optimize_kernels 3.30-3.43 s` + - heavy `velocity+stress`: `22.82-22.89 s`, with + `optimize_kernels 10.43-10.50 s` + - validation at this point included `test_lower_clusters.py + test_ir.py`, + the four EqBlock/equation repros, and the two benchmark probes above; a + full OSS suite has not yet been rerun after this second commit + +Current practical checkpoint for the heavy benchmark is therefore now about +`22.8-22.9 s`, not the older `25.8-26.0 s` plateau. Relative to the previous +April 22 plateau, the latest two commits are worth roughly `3 s` on the heavy +compile, with the larger second-step gain coming from cached support-space and +Cluster queue hashing. + +## Historical April 21 status after caching/reuse, fusion, and derivative-topofuse work + +More recent paired PRO/OSS reruns on April 21, 2026 with PRO +`faster-python-1` at `b770aaee`, OSS `/home/fl1612/devito-faster-python-1` at +`30715c026`, `devitopro-cuda:latest`, `--taskset 0-15`, `--deviceid 3`, and +the two `schedule_infos` probes in +`devitopro/tests/test_kernelopt_nogil_tmp.py` reproduced the current practical +checkpoint as: + +- `test_profile_eiso_stress_like_schedule_infos`: `5.81-5.83 s` +- `test_profile_etti_velocity_then_stress_like_schedule_infos`: `25.81-26.21 s` + +At that point, the paired `velocity+stress` checkpoint for this branch family +was still about `26 s`; the older `32.00 s` value below is only a historical +April 10 milestone, and the April 28 section above supersedes both numbers for +the current branch. + +## April 22, 2026 plateau note on the paired PRO/OSS branches + +Fresh clean-`HEAD` reruns on April 22, 2026 with PRO `faster-python-1` at +`b770aaee`, OSS `/home/fl1612/devito-faster-python-1` at `c44c30339`, +`devitopro-cuda:latest`, `--taskset 0-15`, `--deviceid 3`, and the same +temporary PRO harness confirm that the practical checkpoint for the current +pair is still: + +- light probe: still about `5.8 s` +- heavy `test_profile_etti_velocity_then_stress_like_schedule_infos`: + still about `25.8-26.0 s` (`25.80 s`, `25.99 s` in fresh reruns) + +Latest findings from the PRO-side deep profiling on this same pair: + +- isolated `search_rotation` and `search_shm` postponement tweaks were both + flat/noisy when rerun independently and were reverted +- more invasive scheduler and node-caching refactors were also dropped after + either code-complexity growth or measurable regressions +- explicit per-kernel accounting on the heavy benchmark shows the dominant + remaining cost is the initial stress kernel lineage, not the velocity kernel: + - helper lineage: `0.402522 s` + - velocity lineage: `0.585283 s` + - stress lineage: `11.393894 s` +- splitting the first `optimize()` pass by initial kernel yields: + - helper (`r0..r5`): `0.315156 s` + - velocity (`v_x,v_y,v_z`): `0.426755 s` + - stress (`tau_*`): `10.524947 s` +- the stress-kernel `optimize()` time is currently dominated by: + - `action_apply`: `5.570871 s` + - `minimize_barrier_likelihood`: `2.337799 s` + - `reschedule`: `2.169853 s` + - with `68` action steps, `69` schedule calls, and an average per-schedule + state of about `229.7` IDG nodes and `30.4` queued actions +- current IET-side coarse buckets on the heavy probe are still: + - `make_parallel ~= 1.99 s` + - `place_definitions ~= 0.87 s` + - `_place_transfers ~= 0.86 s` + - `linearization ~= 0.37 s` + +Current interpretation: +the current pair looks close to a local plateau on this benchmark family. The +next step should therefore be to move up in benchmark complexity rather than +keep forcing increasingly marginal scheduler micro-optimizations on the current +`velocity+stress` case. + +First result after moving up in benchmark complexity on the PRO scratch harness: + +- new `heavy_IO` benchmark: + `test_profile_etti_velocity_then_stress_like_bitcomp_serial_schedule_infos` +- construction: + extend the current `heavy` `velocity+stress` operator with one compressed, + serialized saved `TimeFunction` per `tau_*` component and an extra + `Eq(tau_*save, tau_*.forward)` per component +- first pinned compile result: + - total compile: `30.44 s` + - `lowering.Clusters`: `18.74 s` + - `optimize_kernels`: `12.55 s` + - `lowering.IET`: `9.39 s` + - `specializing.IET`: `8.49 s` + - main new IET-side buckets: + `lower_async_objs ~= 1.29 s`, + `place_definitions ~= 1.27 s`, + `linearization ~= 1.12 s`, + `_place_transfers ~= 0.92 s` + +Interpretation: +the first heavier benchmark does become meaningfully slower, but the extra cost +lands primarily in IET / serialization-lowering rather than in the already +stress-dominated `kernelopt` slice. + +Older measured compile times on April 10, 2026 with PRO `faster-python-1`, +OSS `faster-python-1`, `devitopro-cuda:latest`, `--taskset 0-15`, and +`--deviceid 0`: + +- `test_profile_etti_stress_like_schedule_infos`: `13.63 s` +- `test_profile_etti_velocity_then_stress_like_schedule_infos`: `32.00 s` + +Compared with the March 30, 2026 no-cache baseline on the same probe family: + +- stress-like: `29.99 s -> 13.63 s` (`-16.36 s`, about `54.5%`) +- velocity+stress: `98.49 s -> 32.00 s` (`-66.49 s`, about `67.5%`) + +Compared with the later retrieve-accesses-only replay on the same branch family: + +- stress-like: `29.32 s -> 13.63 s` (`-15.69 s`) +- velocity+stress: `93.16 s -> 32.00 s` (`-61.16 s`) + +Compared with the pre-`TimedAccess` landed branch state: + +- stress-like: `24.65 s -> 13.63 s` (`-11.02 s`) +- velocity+stress: `79.68 s -> 32.00 s` (`-47.68 s`) + +Compared with the pre-space/rebuild branch (`3f3017e46`, +`compiler: Augment caching and memoization`): + +- stress-like: `27.45 s -> 13.63 s` (`-13.82 s`) +- velocity+stress: `84.56 s -> 32.00 s` (`-52.56 s`) + +The current landed branch now covers: + +- narrow helper memoization and finite-difference evaluation caching + (`3f3017e46`) +- `Scope`/access-inventory caching and lazy function-view reuse (`3f3017e46`) +- conservative space-object caching and no-op `Cluster.rebuild()` reuse + (`e16f222e1`) +- cached `TimedAccess` construction and reuse of its per-instance distance cache + across repeated `Scope` builds (`fd850927a`) +- synthetic `Scope.from_scopes(...)` construction from cached access summaries, + plus fusion-hazard analysis over those synthetic scopes rather than fresh + `Scope(exprs0 + exprs1)` rescans +- bounded derivative-driven topofusion in `lower_index_derivatives`, using the + maximum nested `IndexDerivative.depth` as an upper bound on the number of + productive `toposort='nofuse'` rounds before the final plain `fuse(False)` + +Current profiled bottlenecks on the landed branch: + +- stress-like: + `lowering.Clusters ~= 8.45 s`, `lowering.IET ~= 3.64 s`, + `optimize_kernels ~= 6.12 s`, `fuse ~= 0.53 s` +- velocity+stress: + `lowering.Clusters ~= 23.82 s`, `lowering.IET ~= 5.64 s`, + `specializing.Clusters ~= 19.24 s`, `fuse ~= 2.71 s` +- hottest Cluster-side buckets on velocity+stress: + `optimize_kernels ~= 18.17 s`, with the remaining clean OSS cluster-side + work still dominated by fusion/topofusion rather than the derivative + lowering wrapper itself +- hottest IET-side buckets on velocity+stress: + `make_parallel ~= 1.59 s`, `place_definitions ~= 1.33 s`, + `_place_transfers ~= 0.86 s`, `linearization ~= 0.37 s`, + `_generate_macros ~= 0.24 s`, `minimize_symbols ~= 0.25 s`, + `optimize_halospots ~= 0.22 s` + +Validation status of the latest derivative-topofuse heuristic: + +- targeted OSS sensitivity checks around `test_unexpansion.py::{test_v3,test_v4, + test_v5}` passed +- the previously failing PRO CUDA regression in compressed layered MPI + serialization turned out to be unrelated to compilation changes; it was a + `NVIDIA_VISIBLE_DEVICES`/implicit `deviceid` correctness bug and is now fixed +- current PRO `tests/test_gpu_lang.py::TestKernelOptDefault:: + test_flip_for_canonical_ordering` is failing on the `faster-python-1` + PRO/OSS pair, but the failure hits the baseline `op0.apply(...)` path with an + undefined `npthreads0` symbol in generated CUDA, so it currently looks like + an OSS-side issue unrelated to the derivative-topofuse / `dsequences()` + changes +- a full fresh OSS + PRO sweep has not yet been rerun after the current + derivative-topofuse heuristic + +This temporary note captures the main OSS-side compilation optimizations explored in +iteration 0. The list below is intentionally in ascending order of complexity: +smaller and safer caching/micro-optimization ideas come first, while broader +algorithmic and threading changes come later. + +1. ~~Cache tiny pure helper results and other stable scalar metadata first.~~ + + Completed in a narrow form in `3f3017e46` + (`compiler: Augment caching and memoization`) via cached + `IndexDerivative.pivot`, memoized `Derivative._eval_fd`, and shared + numeric-weight reuse. + + Performance: + this helper bucket was not isolated cleanly from point 5 in the squashed + branch, but it is part of the landed `29.32 s -> 27.45 s` and + `93.16 s -> 84.56 s` move. + + Rationale: + these changes are local, easy to reason about, and usually do not alter the + structure of the compiler pipeline. + +2. ~~Preserve identity on no-op symbolic and visitor rewrites.~~ + + Completed in a narrow compiler-local form in `e16f222e1` + (`compiler: Augment caching and tweak memoization heuristics`) via + `Cluster.rebuild()` returning `self` when all effective rebuild inputs are + already identical objects. + + Performance: + not isolated cleanly from point 9 below; the combined landed diff moved the + probes from `27.45 s -> 24.65 s` and `84.56 s -> 79.68 s`. + + Rationale: + this is still fairly contained work, but it starts touching generic traversal + machinery that is used in many places. + +3. Specialize traversal-heavy symbol discovery before changing higher-level + algorithms. + + Relevant iteration-0 commits: + `e450f0546` (`ir: trim findsymbols stack overhead`), + `ceeb42689` (`ir: specialize findsymbols traversal`), + `ff8d9efcc` (`symbolics: trim IET traversal overhead`). + + Rationale: + these changes stay within existing traversal semantics, but they attack some + of the hottest generic walks in lowering. + +4. Reuse already computed inventories in IET cleanup and callable deduplication. + + Relevant iteration-0 commits: + `f91d9256f` (`CODEX: ITER 6`, better caller tracking and cheaper param drops), + `b05dd2084` (`iet: reuse symbol inventory in parameter updates`), + `c26dbc3e6` (`WIP`, shared DataManager inventory collection and `reuse_efuncs` + caches), + `96ff77a94` (`iet: Prune reuse_efuncs by name family`). + + Current replay status: + still WIP and intentionally not landed. + + April 7, 2026 replay findings on the current iteration-1 branch: + rebuilding the non-WIP subset (`b05dd2084` + `f91d9256f` + `96ff77a94`) + was correct on targeted `test_iet.py` / DSE checks, but the payoff was small + relative to the extra engine/utils complexity. + + Performance: + `b05dd2084` alone was flat-to-worse on the probes (`23.16 s -> 23.16 s` and + `72.34 s -> 73.66 s`). + Adding the two non-WIP `engine.py` follow-ups improved that to + `23.16 s -> 22.01 s` and `72.34 s -> 72.03 s`. + The subset was therefore dropped rather than landed: the light probe moved + nicely, but the heavy probe improved by only about `0.31 s`. + + Rationale: + this is the first bucket that spans multiple IET passes and shared helper + caches, so it is more invasive than the previous purely local fast paths. + +5. ~~Cheapen `Scope` construction and pairwise dependence pre-checks used by + fusion/topofusion.~~ + + Completed in `3f3017e46` via memoized `retrieve_accesses`, lazy cached + `IREq` read/write inventories, and reuse of cached function views in + `Scope`, `Cluster.traffic`, `Expression`, and `Operator`. + + Performance: + the landed cache/memoization batch moved the probes from + `29.32 s -> 27.45 s` and `93.16 s -> 84.56 s`. A narrower mid-iteration + replay of the `Scope`/access portion alone had already reached roughly + `27.65 s` and `86.21 s`. A later `TimedAccess` follow-up in `fd850927a` + moved the landed branch further from `24.65 s -> 23.16 s` and + `79.68 s -> 72.34 s`, for a total point-5-aligned move of roughly + `29.32 s -> 23.16 s` and `93.16 s -> 72.34 s`. + + Rationale: + these changes keep the same broad fusion algorithm, but they start replacing + repeated rescans with cached summaries and synthetic scopes. + +6. ~~Replace repeated generic fusion-hazard walks with focused hazard summaries, + and tighten derivative-driven rescans.~~ + + Relevant iteration-0 commits: + `8c2e76a99` (`CODEX: ITER 5`, `fusion_hazards` summary), + `024de93a2` (`clusters: Cheapen derivative topofusion hazards`), + `0abbe2cb9` (`clusters: Restrict derivative nofuse rescans`). + + Completed on the current branch in a simpler form than the original iteration-0 + patches: fusion hazard analysis now reuses the already-cached per-ClusterGroup + `Scope` inventories and synthesizes cross-scope dependences via + `Scope.from_scopes(...)`, instead of repeatedly constructing fresh + `Scope(exprs0 + exprs1)` objects from raw expressions. The derivative side is + also now bounded: `lower_index_derivatives` runs at most `max_depth` + `toposort='nofuse'` rounds, where `max_depth` is the maximum nested + `IndexDerivative.depth` across the input clusters, and then finishes with the + usual plain `fuse(False)`. + + Performance: + compared with the pre-fusion landed state, the probes moved from + `23.16 s -> 13.63 s` and `72.34 s -> 32.00 s`. + + Rationale: + this turned out to be the dominant remaining algorithmic win after the earlier + caching groundwork was in place. The essential gain is sparing repeated + expression rescans during fusion/topofusion legality checks. + + Deferred April 17, 2026 follow-up: + while profiling the PRO heavy `velocity_then_stress` compile on the paired + OSS/PRO `faster-python-1` worktrees, `minimize_barrier_likelihood` + consistently spent about `2.5-2.6 s` inside `fuse(toposort=True)`, with + `_build_dag` and `_fusion_hazards` dominating that cost. A trial OSS patch + in `Fusion._build_dag` skipped `_fusion_hazards` for unfenced ClusterGroup + pairs whose scopes cannot possibly interact + (`cg0.scope.writes.keys().isdisjoint(cg1.scope.functions)` and vice versa). + + Measured effect: + `_fusion_hazards` calls dropped from about `47k` to about `5.7k`, and the + barrier-minimization slice improved by about `0.12-0.17 s`, but the end-to-end + heavy compile-time win was noisy and marginal. Focused OSS topofusion/barrier + tests passed, but the change was still deferred rather than landed. + + Why deferred: + this is exactly the kind of fast path that is easy to justify locally but + hard to value globally. The measured win is real but small, and fusion/toposort + is regression-prone enough that carrying extra control-flow in this area + should require a clearer compile-time payoff. + + If revisited later: + keep the prefilter in `_build_dag`, not inside `_fusion_hazards`. + Moving it into `_fusion_hazards` would still pay the function-call and + memoization overhead that the experiment was specifically avoiding, while + `fenced` is a `_build_dag` scheduling concern rather than a property of the + pairwise hazard relation itself. + +7. Add concurrency inside expression lowering only after the single-threaded + fast paths are understood. + + Relevant iteration-0 commits: + `cd8bbec49` (`equations: Thread per-expression lowering`), + `0f8d775c3` (`operator: Thread expression evaluation`). + + Rationale: + threading can move the needle, but it also introduces option plumbing, + scheduling questions, and failure modes that are harder to debug than the + earlier single-threaded wins. + +8. Add concurrency inside fusion/toposort last. + + Relevant iteration-0 commits: + `e94ee8b52` (`CODEX: ITER 7`, `fuse-workers` and threaded DAG row building). + + Rationale: + this depends on the earlier `Scope` and hazard-summary work, and it sits in a + particularly regression-prone area of the compiler. + +9. ~~Treat aggressive object/space caching as a late experiment, not an + initial iteration-1 target.~~ + + Completed in a conservative form in `e16f222e1` via cached + `Interval`/`IterationSpace`-family objects, immutable/hashable `Properties`, + `Prefix._preprocess_args`, and the no-op `Cluster.rebuild()` fast path + above. + + Performance: + compared with the pre-space/rebuild branch, the landed diff moved the probes + from `27.45 s -> 24.65 s` and `84.56 s -> 79.68 s`. + + Rationale: + iteration 0 showed that this class of optimization can improve compile-time + behavior, but it also showed that the semantic risk is high enough that it + should not be part of the first iteration-1 subset. + + April 28, 2026 landed follow-up: + the current branch now extends this bucket with the EqBlock/Cluster split in + `fc755479d` and cached immutable-object hashes in `771f807ab`. This is the + first object-caching follow-up in a while that clearly moved the main heavy + benchmark rather than only shaving noise: the heavy `velocity+stress` probe + moved from the previous `25.8-26.0 s` plateau through about `24.4 s` after + EqBlock caching, then to `22.8-22.9 s` after cached support-space and + Cluster queue hashes. The `cached_hash` result also confirms that repeated + hashing was a real compile-time cost, not just profiling noise. + +Regression-fix commits such as `cc6ee524a`, `6bc7ea1fd`, `9014e0ad0`, and +`d8981b0de` are intentionally not part of the ordered list above. They matter +for keeping iteration 0 green, but they are correctness follow-ups rather than +the primary optimization ideas to replay in iteration 1. + +April 22, 2026 IET / bitcomp+serialization (`heavy_IO`) follow-up: + +- New `heavy_IO` PRO scratch benchmark: + start from the current `heavy` `velocity_then_stress` case and add one + bitcomp+serialized saved `TimeFunction` per `tau_*` component. + +- Paired clean baseline: + about `30.3-31.2 s` total compile, with `optimize_kernels` still around + `12.5-12.7 s` and the extra cost landing primarily in `lowering.IET` + (`~9.4-9.7 s`). + +- Profiling conclusions: + `lower_async_objs` scanning is not the dominant new cost; the more relevant + IET-side work is in `update_args` and in the second `place_definitions` + pass triggered after `pthreadify`. + +- Reverted experiment 1: + simplify `engine.py:update_args` by collapsing the separate + `FindSymbols('basics')` / `FindSymbols('symbolics')` scans and computing + `drop_params` directly by index. + + Result: + the compile-time probe looked mildly positive/noisy, but narrow + compressed-layer runtime tests failed with the same + `nbytes_avail_mapper` / `deviceid=-1` breakage, so this is not safe as-is. + +- Reverted experiment 2: + after `pthreadify`, rerun `place_definitions` only on callables touched by + async lowering rather than across the whole graph. + + Result: + this was the strongest local compile-time signal in the new `heavy_IO` + benchmark: + the second-epoch `place_definitions` visits dropped from `31` to `5`, and + the heavy compile moved into roughly the `30.2-30.4 s` band. + However, the same compressed-layer runtime tests failed, so the idea was + reverted as well. + +- Current recommendation: + treat the async/definitions area as the right place to look for the new + heavier benchmark, but do not carry either of the above optimizations + without a stronger correctness story. The paired OSS worktree should stay at + clean `HEAD`. + +- April 22 late follow-up: + a narrower post-`pthreadify` rerun of `place_definitions` does look viable + after all. Instead of revisiting the whole graph, the current worktree now + reruns the pass only on async-owned callables: the transformed + `ThreadCallable`s, the helper callables (`activate*`, `init_sdata*`, + `shutdown*`), and callers that reference those helpers. This is implemented + by allowing `Graph.apply(..., targets=...)` and passing the selected names + from `pthreadify`. + + Validation: + the CPU layered async cases + `tests/test_layered_funcs.py::TestSerialization::test_diskhost[...]` with + `buf-async-degree=1` still pass on the paired worktrees. + The higher-degree `buf-async-degree=4` variant remains baseline-red because + of the pre-existing `npthreads0` codegen issue, so it is not a useful gate. + + Performance: + on the `heavy_IO` bitcomp+serialization benchmark, the second-epoch + `place_definitions` visits shrink from `31` down to `5`, and the local IET + bucket improves from about `1.80 s` to about `1.57-1.66 s`. + End-to-end compile time is a small but repeatable win on the latest paired + reruns, moving from roughly `30.54 s` to about `30.24-30.49 s`. + +April 30, 2026 IET memoization / no-op rebuild follow-up: + +- Baseline before this IET-focused patch series: + - `heavy`: `21.99 s`, `21.86 s`, `22.26 s`; average `22.04 s`. + `lowering.IET`: `5.55 s`, `5.49 s`, `5.89 s`; average `5.64 s`. + - `heavy_IO`: `26.38 s`, `26.42 s`, `26.51 s`; average `26.44 s`. + `lowering.IET`: `8.98 s`, `8.96 s`, `9.54 s`; average `9.16 s`. + +- Current simplified patch: + - memoize public `create_call_graph`, with callers passing + `as_hashable(self.efuncs)` / `as_hashable(efuncs)` rather than using a + private cached helper; + - memoize public `abstract_efunc`; + - memoize public `abstract_objects` directly. `rg` across OSS and PRO shows + no caller passes an explicit `sregistry`, so the old optional parameter was + removed and the function now always uses its local `SymbolRegistry`; + - simplify IET `reuse_if_unchanged` by using `Node._same_arg` instead of a + duplicate local kwarg comparison helper. + +- Dropped follow-up: + a generic `memoized_func` key-path optimization was tested but left out of the + patch. It appeared mildly positive in one set of runs, but was not necessary + for the main IET win and is too broad for this focused change. + +- Current measured performance with the simplified patch and unchanged + `memoized_func`: + - `heavy`: `21.63 s`, `21.53 s`, `21.57 s`; average `21.58 s`. + - `heavy_IO`: `25.32 s`, `25.40 s`, `25.35 s`; average `25.36 s`. + - net improvement versus the pre-patch reference is about `0.46 s` on + `heavy` and about `1.08 s` on `heavy_IO`. + +- Validation: + targeted OSS IET/tool tests passed: + `/app/devitopro/submodules/devito/tests/test_iet.py`, + `/app/devitopro/submodules/devito/tests/test_visitors.py`, + `/app/devitopro/submodules/devito/tests/test_tools.py` (`72 passed`). + +- Interpretation: + the durable win is in the IET callable-deduplication/reuse path, especially + repeated call-graph creation and repeated abstraction of structurally stable + callables. Dropping `abstract_objects` caching regressed `heavy_IO` back to + roughly `25.5 s`, so that cache is worth keeping now that the unused + `sregistry` parameter has been removed. + +May 4, 2026 benchmark refresh after the no-op IET transform and visitor-cache +follow-ups: + +- Setup: + PRO `faster-python-1` worktree with paired OSS `faster-python-1`, CUDA docker + image `devitopro-cuda:latest`, GPU device `3`, launcher pinned with + `taskset 0-15`. The three schedule-info probes were run in one pytest-docker + invocation. + +- `stress-only` (`test_profile_etti_stress_like_schedule_infos`): + - total compile: `10.06 s`; + - `lowering.Clusters`: `5.52 s`; + - `specializing.Clusters`: `4.18 s`; + - `optimize_kernels`: `3.39 s`; + - `lowering.IET`: `3.23 s`; + - `specializing.IET`: `3.00 s`; + - IET notable buckets: `make_parallel 1.59 s`, + `_place_transfers 0.70 s`, `place_definitions 0.29 s`; + - kernelopt `fuse`: `0.60 s`. + +- `heavy` velocity+stress + (`test_profile_etti_velocity_then_stress_like_schedule_infos`): + - total compile: `21.63 s`; + - `lowering.Clusters`: `14.69 s`; + - `specializing.Clusters`: `11.17 s`; + - `optimize_kernels`: `10.08 s`; + - `lowering.IET`: `5.02 s`; + - `specializing.IET`: `4.43 s`; + - IET notable buckets: `make_parallel 1.47 s`, + `_place_transfers 1.37 s`, `place_definitions 0.65 s`, + `linearization 0.28 s`; + - kernelopt `fuse`: `1.82 s`. + +- `heavy_IO` velocity+stress plus bitcomp+serialization + (`test_profile_etti_velocity_then_stress_like_bitcomp_serial_schedule_infos`): + - total compile: `25.26 s`; + - `lowering.Clusters`: `15.58 s`; + - `specializing.Clusters`: `11.63 s`; + - `optimize_kernels`: `10.53 s`; + - `lowering.IET`: `7.59 s`; + - `specializing.IET`: `6.85 s`; + - IET notable buckets: `make_parallel 1.59 s`, + `place_definitions 1.52 s`, `lower_async_objs 1.16 s`, + `process 0.73 s`, `_place_transfers 0.54 s`, + `linearization 0.47 s`; + - kernelopt `fuse`: `1.95 s`. + +- Interpretation: + the three current probes are still in the expected post-IET-cache band: + about `10.1 s` for stress-only, `21.5-21.7 s` for `heavy`, and + `25.0-25.3 s` for `heavy_IO`. The `FindNodes` visitor cache reduced direct + repeated visitor cost in profiling, but it remains a small/noise-level + end-to-end compile-time effect. The dominant open costs are still + `optimize_kernels`/cluster specialization and, for `heavy_IO`, the IET + async/definitions path. + +May 4, 2026 IET `reuse_efuncs` drill-down: + +- The expensive IET buckets in `heavy_IO` (`make_parallel`, + `place_definitions`, `_place_transfers`, `lower_async_objs`, and `process`) + are mostly paying common `Graph.apply` post-processing cost rather than pass + body cost. A temporary graph-phase profile showed: + - `Graph.apply` total: about `7.33 s` across `25` calls; + - `reuse_efuncs`: about `3.93 s` across `5` calls; + - pass bodies: about `2.17 s`; + - `update_args`: about `0.85 s`. + +- Inside `reuse_efuncs`, the hot path is abstraction/signature generation: + - before the new signature cache: `reuse_efuncs ~3.93 s`, + `abstract_efunc ~1.91 s`, `_signature ~1.75 s`; + - with IET `Node._signature()` memoized per node: `reuse_efuncs` drops to + about `3.62-3.69 s`, and `_signature` drops to about `1.41-1.44 s`. + +- The tested signature-cache patch was deliberately narrow: + IET `Node` overrode `_signature()` with `@memoized_meth` and delegated to + `Signer._signature()`, caching the SHA1 signature on the immutable-ish IET + node instance without caching the full CIR string. + +- Direct multiplicity check on `heavy_IO` showed why the patch is not a + meaningful end-to-end win: + - `_signature()` calls: `180`; + - unique IET nodes: `150`; + - repeated calls on the same node: only `30`; + - call histogram: `121` nodes called once, `28` nodes called twice, `1` node + called three times. + +- The remaining `abstract_efunc` body cost is still substantial. A temporary + body-level profile of `heavy_IO` showed about `150` misses and `30` hits + across the five `reuse_efuncs` calls. Miss cost split roughly as: + - `Uxreplace`: `0.63 s`; + - `abstract_objects`: `0.63 s`; + - `FindSymbols('basics|symbolics|dimensions')`: `0.23 s`. + +- Dropped variants: + - IET `Node._signature()` memoization was dropped after the multiplicity + check. There are not enough repeated calls on the same node to justify even + this small cache as a production change; + - filtering identity mappings out of `abstract_objects` was slower in + practice; `abstract_objects` increased from about `0.63 s` to about + `1.62 s` in the instrumented run, because rebuilding the mapper dominated; + - returning raw CIR from IET `Node._signature()` instead of the SHA1 digest + was also rejected. It retains large strings and made the instrumented + profile noisier/worse, without a clear wall-time win. + +- Validation and benchmark signal from the rejected signature-cache patch: + - targeted OSS IET/visitor tests still pass: + `/app/devitopro/submodules/devito/tests/test_iet.py` and + `/app/devitopro/submodules/devito/tests/test_visitors.py` + (`42 passed`); + - the earlier `heavy 22.25 s` combined-run sample was confirmed noisy and + should be ignored. + +- May 4 rerun, three combined invocations before and after the signature-cache + patch, same setup (`devitopro-cuda:latest`, GPU `3`, `taskset 0-15`): + - without signature cache: + `stress-only 10.02/10.03/10.00 s` (avg `10.02 s`), + `heavy 21.29/21.27/21.27 s` (avg `21.28 s`), + `heavy_IO 24.80/24.66/24.63 s` (avg `24.70 s`); + - with signature cache: + `stress-only 10.02/10.03/9.98 s` (avg `10.01 s`), + `heavy 21.36/21.40/21.29 s` (avg `21.35 s`), + `heavy_IO 24.55/24.49/24.37 s` (avg `24.47 s`). + +- Interpretation: + memoizing IET node signatures is not worth keeping. The end-to-end signal is + neutral for `stress-only`, neutral/slightly negative for `heavy`, and only + mildly positive for `heavy_IO` (`~0.23 s`). The direct multiplicity check + shows the cache surface is tiny: only `30/180` calls are repeated on the same + node. The next meaningful IET win is unlikely to come from the individual + pass bodies. It would need to reduce repeated `abstract_efunc` misses, likely + by making `reuse_efuncs` more incremental/cache-aware across successive + `Graph.apply` calls. diff --git a/conftest.py b/conftest.py index b2d49697fb..e1de26f7aa 100644 --- a/conftest.py +++ b/conftest.py @@ -183,6 +183,7 @@ def parallel(item, m): raise ValueError(f"Can't run test: unexpected mode `{m}`") env_vars = {'DEVITO_MPI': scheme} + timeout = item.get_closest_marker("parallel").kwargs.get('timeout', 300) pyversion = sys.executable testname = get_testname(item) @@ -197,7 +198,7 @@ def parallel(item, m): # OpenMPI requires an explicit flag for oversubscription. We need it as some # of the MPI tests will spawn lots of processes if mpi_distro == 'OpenMPI': - call = [mpi_exec, '--oversubscribe', '--timeout', '300'] + args + call = [mpi_exec, '--oversubscribe', '--timeout', str(timeout)] + args else: call = [mpi_exec] + args @@ -228,7 +229,7 @@ def pytest_configure(config): """Register an additional marker.""" config.addinivalue_line( "markers", - "parallel(mode): mark test to run in parallel" + "parallel(mode, timeout=300): mark test to run in parallel" ) config.addinivalue_line( "markers", diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 1a17ba6da1..7ad821c1ce 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -496,6 +496,7 @@ def parse_product_arch(): def get_visible_devices(): device_vars = ( 'CUDA_VISIBLE_DEVICES', + 'NVIDIA_VISIBLE_DEVICES', 'ROCR_VISIBLE_DEVICES', 'HIP_VISIBLE_DEVICES' ) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 22984af6fb..cf514ee8d6 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -6,7 +6,9 @@ import sympy -from devito.tools import Pickable, as_mapper, as_tuple, frozendict, is_integer +from devito.tools import ( + Pickable, as_mapper, as_tuple, frozendict, is_integer, memoized_func +) from devito.types.dimension import Dimension from devito.types.utils import DimensionTuple from devito.warnings import warn @@ -546,6 +548,7 @@ def _evaluate(self, **kwargs): def _eval_deriv(self): return self._eval_fd(self.expr) + @memoized_func(scope='build') def _eval_fd(self, expr, **kwargs): """ Evaluate the finite-difference approximation of the Derivative. diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 503393b089..63c3a700a3 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -976,6 +976,10 @@ def compare(self, other): def base(self): return self.expr.func(*[a for a in self.expr.args if a is not self.weights]) + @cached_property + def pivot(self): + return self.base.subs({d: 0 for d in self.dimensions}) + @property def weights(self): return self._weights diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index 30199fb3d8..bdf3199b0d 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -170,14 +170,15 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici # `coefficients` method (`taylor` or `symbolic`) if weights is None: weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0) - if isinstance(weights, Iterable) and len(weights) != len(indices): + _, wdim, _ = process_weights(weights, expr, dim) + elif isinstance(weights, Iterable) and len(weights) != len(indices): warning(f"Number of weights ({len(weights)}) does not match " f"number of indices ({len(indices)}), reverting to Taylor") scale = False + wdim = None weights = fd_weights_registry['taylor'](expr, deriv_order, indices, x0) # Did fd_weights_registry return a new Function/Expression instead of a values? - _, wdim, _ = process_weights(weights, expr, dim) if wdim is not None: weights = [weights._subs(wdim, i) for i in range(len(indices))] diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 69e66ce4e6..ee04cdfb7d 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -228,10 +228,14 @@ def make_stencil_dimension(expr, _min, _max): @cacheit -def numeric_weights(function, deriv_order, indices, x0): +def _numeric_weights(deriv_order, indices, x0): return finite_diff_weights(deriv_order, indices, x0)[-1][-1] +def numeric_weights(function, deriv_order, indices, x0): + return _numeric_weights(deriv_order, indices, x0) + + fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights, 'symbolic': numeric_weights} # Backward compat for 'symbolic' coeff_priority = {'taylor': 1, 'standard': 1} diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 3f87496c62..e899367173 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -21,7 +21,10 @@ from devito.tools import ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype from devito.types.basic import AbstractFunction -__all__ = ['BasePrinter', 'ccode'] +__all__ = ['BasePrinter', 'ccode', 'get_printer'] + +_preset_dtypes = (np.float32, np.float64, np.complex64, np.complex128) +_printer_registry = {} class BasePrinter(CodePrinter): @@ -449,15 +452,33 @@ def _print_Fallback(self, expr): sympy.printing.str.StrPrinter._print_Add = BasePrinter._print_Add -def ccode(expr, printer=None, **settings): +def get_printer(printer, dtype=None): + try: + registry = _printer_registry[printer] + except KeyError: + default = printer() + registry = {None: default, default.dtype: default} + for i in _preset_dtypes: + registry.setdefault(i, printer(settings={'dtype': i})) + _printer_registry[printer] = registry + + try: + return registry[dtype] + except KeyError: + handle = printer(settings={'dtype': dtype}) + registry[dtype] = handle + return handle + + +def ccode(expr, printer=None, dtype=None): """Generate C++ code from an expression. Parameters ---------- expr : expr-like The expression to be printed. - settings : dict - Options for code printing. + dtype : data-type, optional + Data type used by the printer. Returns ------- @@ -468,4 +489,4 @@ def ccode(expr, printer=None, **settings): if printer is None: from devito.passes.iet.languages.C import CPrinter printer = CPrinter - return printer(settings=settings).doprint(expr, None) + return get_printer(printer, dtype).doprint(expr, None) diff --git a/devito/ir/clusters/analysis.py b/devito/ir/clusters/analysis.py index 5ebae71b0f..f78f1ee456 100644 --- a/devito/ir/clusters/analysis.py +++ b/devito/ir/clusters/analysis.py @@ -101,7 +101,7 @@ def _callback(self, clusters, dim, prefix): is_parallel_atomic = False scope = Scope(flatten(c.exprs for c in clusters)) - for dep in scope.d_all_gen(): + for dep in scope.d_all_gen(writes=scope.writes_tensor): test00 = dep.is_indep(dim) and not dep.is_storage_related(dim) test01 = all(dep.is_reduce_atmost(i) for i in prev) if test00 and test01: @@ -112,10 +112,6 @@ def _callback(self, clusters, dim, prefix): is_parallel_indep &= (dep.distance_mapper.get(dim.root) == 0) continue - if dep.function in scope.initialized: - # False alarm, the dependence is over a locally-defined symbol - continue - if dep.is_reduction: is_parallel_atomic = True continue diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index cbf206b3ff..717b508b92 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -8,13 +8,15 @@ from devito.ir.support import ( PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext, DataSpace, Forward, Guards, Interval, IntervalGroup, IterationSpace, PrefetchUpdate, Properties, Scope, WaitLock, WithLock, - detect_accesses, detect_io, maximum, minimum, normalize_properties, normalize_syncs, - null_ispace, tailor_properties, update_properties + detect_accesses, maximum, minimum, normalize_properties, normalize_syncs, null_ispace, + tailor_properties, update_properties ) from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import estimate_cost -from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype +from devito.tools import ( + CacheInstances, as_tuple, cached_hash, filter_ordered, flatten, infer_dtype +) from devito.types import ( CriticalRegion, Fence, Indexed, PhaseMarker, TensorMove, ThreadArrive, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence @@ -23,110 +25,45 @@ __all__ = ["Cluster", "ClusterGroup"] -class Cluster: +class EqBlock(CacheInstances): """ - A Cluster is an ordered sequence of expressions in an IterationSpace. - - Parameters - ---------- - exprs : expr-like or list of expr-like - An ordered sequence of expressions computing a tensor. - ispace : IterationSpace, optional - The Cluster iteration space. - guards : dict, optional - Mapper from Dimensions to expr-like, representing the conditions under - which the Cluster should be computed. - properties : dict, optional - Mapper from Dimensions to Property, describing the Cluster properties - such as its parallel Dimensions. - syncs : dict, optional - Mapper from Dimensions to lists of SyncOps, that is ordered sequences of - synchronization operations that must be performed in order to compute the - Cluster asynchronously. - halo_scheme : HaloScheme, optional - The halo exchanges required by the Cluster. + A sequence of equations with associated metadata. """ + @classmethod + def _preprocess_args(cls, exprs, ispace=null_ispace, guards=None, + properties=None, syncs=None, halo_scheme=None): + exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) + guards = Guards(guards or {}) + properties = Properties(properties or {}) + syncs = normalize_syncs(syncs or {}) + + return (exprs, ispace, guards, properties, syncs, halo_scheme), {} + def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None, syncs=None, halo_scheme=None): - self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) + self._exprs = exprs self._ispace = ispace - self._guards = Guards(guards or {}) - self._syncs = normalize_syncs(syncs or {}) - - properties = Properties(properties or {}) - properties = tailor_properties(properties, ispace) - self._properties = update_properties(properties, self.exprs) - + self._guards = guards + self._syncs = syncs self._halo_scheme = halo_scheme - def __repr__(self): - return "Cluster([{}])".format(('\n' + ' '*9).join(f'{i}' for i in self.exprs)) - - @classmethod - def from_clusters(cls, *clusters): - """ - Build a new Cluster from a sequence of pre-existing Clusters with - compatible IterationSpace. - """ - assert len(clusters) > 0 - root = clusters[0] - - if len(clusters) == 1: - return root - - if not all(root.ispace.is_compatible(c.ispace) for c in clusters): - raise ValueError("Cannot build a Cluster from Clusters with " - "incompatible IterationSpace") - if not all(root.guards == c.guards for c in clusters): - raise ValueError("Cannot build a Cluster from Clusters with " - "non-homogeneous guards") - - writes = set().union(*[c.scope.writes for c in clusters]) - reads = set().union(*[c.scope.reads for c in clusters]) - if any(f._mem_shared for f in writes & reads): - raise ValueError("Cannot build a Cluster from Clusters with " - "read-write conflicts on shared-memory Functions") - - exprs = chain(*[c.exprs for c in clusters]) - ispace = IterationSpace.union(*[c.ispace for c in clusters]) - - guards = root.guards - - properties = reduce_properties(clusters) - - try: - syncs = normalize_syncs(*[c.syncs for c in clusters]) - except ValueError as e: - raise ValueError( - "Cannot build a Cluster from Clusters with " - "non-compatible synchronization operations" - ) from e - - halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) - - return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) + properties = tailor_properties(properties, ispace) + self._properties = update_properties(properties, self._exprs) - def rebuild(self, *args, **kwargs): - """ - Build a new Cluster from the attributes given as keywords. All other - attributes are taken from ``self``. - """ - # Shortcut for backwards compatibility - if args: - if len(args) != 1: - raise ValueError("rebuild takes at most one positional argument (exprs)") - if kwargs.get('exprs'): - raise ValueError("`exprs` provided both as arg and kwarg") - kwargs['exprs'] = args[0] + def __eq__(self, other): + return (type(self) is type(other) and + self.exprs == other.exprs and + self.ispace == other.ispace and + self.guards == other.guards and + self.properties == other.properties and + self.syncs == other.syncs and + self.halo_scheme == other.halo_scheme) - return self.__class__(exprs=kwargs.get('exprs', self.exprs), - ispace=kwargs.get('ispace', self.ispace), - guards=kwargs.get('guards', self.guards), - properties=kwargs.get('properties', self.properties), - syncs=kwargs.get('syncs', self.syncs), - halo_scheme=kwargs.get('halo_scheme', self.halo_scheme)) + def __hash__(self): + return hash((self.exprs, self.ispace, self.guards, self.properties, + self.syncs, self.halo_scheme)) @property def exprs(self): @@ -382,8 +319,8 @@ def dtype(self): performing integer arithmetic are ignored, assuming that they are only carrying out array index calculations. - If two expressions perform calculations with different precision, the - data type with highest precision is returned. + If two expressions perform calculations with different precision, + the data type with highest precision is returned. """ dtypes = set() for i in self.exprs: @@ -399,8 +336,8 @@ def dtype(self): @cached_property def dspace(self): """ - Derive the DataSpace of the Cluster from its expressions, IterationSpace, - and Guards. + Derive the DataSpace of the Cluster from its expressions, + IterationSpace, and Guards. """ accesses = detect_accesses(self.exprs) @@ -491,7 +428,8 @@ def traffic(self): ----- If a Function is both read and written, then it is counted twice. """ - reads, writes = detect_io(self.exprs, relax=True) + reads = flatten(i.read_functions_relaxed for i in self.exprs) + writes = flatten(i.write_functions_relaxed for i in self.exprs) accesses = [(i, 'r') for i in reads] + [(i, 'w') for i in writes] # Ordering isn't important at this point, so returning an unordered @@ -525,6 +463,156 @@ def traffic(self): return ret +class Cluster: + + """ + A context-sensitive sequence of equations. + + The structural payload (equations, IterationSpace, ...) lives in the + underlying EqBlock. A Cluster, unlike EqBlock, deliberately keeps identity + semantics because its position in a sequence of Clusters does matter. It + follows that two Cluster instances may share the same EqBlock, but they + remain distinct: Clusters intentionally use object identity for equality + and hashing, so only references to the same Cluster object compare equal. + + Parameters + ---------- + exprs : expr-like or list of expr-like + An ordered sequence of expressions computing a tensor. + ispace : IterationSpace, optional + The Cluster iteration space. + guards : dict, optional + Mapper from Dimensions to expr-like, representing the conditions under + which the Cluster should be computed. + properties : dict, optional + Mapper from Dimensions to Property, describing the Cluster properties + such as its parallel Dimensions. + syncs : dict, optional + Mapper from Dimensions to lists of SyncOps, that is ordered sequences of + synchronization operations that must be performed in order to compute the + Cluster asynchronously. + halo_scheme : HaloScheme, optional + The halo exchanges required by the Cluster. + """ + + def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None, + syncs=None, halo_scheme=None): + self._block = EqBlock(exprs, ispace, guards, properties, syncs, halo_scheme) + + def __repr__(self): + return "Cluster([{}])".format(('\n' + ' '*9).join(f'{i}' for i in self.exprs)) + + def __getattr__(self, name): + try: + block = object.__getattribute__(self, '_block') + except AttributeError: + raise AttributeError(name) from None + return getattr(block, name) + + @property + def exprs(self): + return self._block.exprs + + @property + def ispace(self): + return self._block.ispace + + @property + def guards(self): + return self._block.guards + + @property + def properties(self): + return self._block.properties + + @property + def syncs(self): + return self._block.syncs + + @property + def halo_scheme(self): + return self._block.halo_scheme + + @classmethod + def from_clusters(cls, *clusters): + """ + Build a new Cluster from a sequence of pre-existing Clusters with + compatible IterationSpace. + """ + assert len(clusters) > 0 + root = clusters[0] + + if len(clusters) == 1: + return root + + if not all(root.ispace.is_compatible(c.ispace) for c in clusters): + raise ValueError("Cannot build a Cluster from Clusters with " + "incompatible IterationSpace") + if not all(root.guards == c.guards for c in clusters): + raise ValueError("Cannot build a Cluster from Clusters with " + "non-homogeneous guards") + + writes = set().union(*[c.scope.writes for c in clusters]) + reads = set().union(*[c.scope.reads for c in clusters]) + if any(f._mem_shared for f in writes & reads): + raise ValueError("Cannot build a Cluster from Clusters with " + "read-write conflicts on shared-memory Functions") + + exprs = chain(*[c.exprs for c in clusters]) + ispace = IterationSpace.union(*[c.ispace for c in clusters]) + + guards = root.guards + + properties = reduce_properties(clusters) + + try: + syncs = normalize_syncs(*[c.syncs for c in clusters]) + except ValueError as e: + raise ValueError( + "Cannot build a Cluster from Clusters with " + "non-compatible synchronization operations" + ) from e + + halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) + + return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) + + def rebuild(self, *args, **kwargs): + """ + Build a new Cluster from the attributes given as keywords. All other + attributes are taken from ``self``. + """ + # Shortcut for backwards compatibility + if args: + if len(args) != 1: + raise ValueError("rebuild takes at most one positional argument (exprs)") + if kwargs.get('exprs'): + raise ValueError("`exprs` provided both as arg and kwarg") + kwargs['exprs'] = args[0] + + exprs = kwargs.get('exprs', self.exprs) + ispace = kwargs.get('ispace', self.ispace) + guards = kwargs.get('guards', self.guards) + properties = kwargs.get('properties', self.properties) + syncs = kwargs.get('syncs', self.syncs) + halo_scheme = kwargs.get('halo_scheme', self.halo_scheme) + + if exprs is self.exprs and \ + ispace is self.ispace and \ + guards is self.guards and \ + properties is self.properties and \ + syncs is self.syncs and \ + halo_scheme is self.halo_scheme: + return self + + return self.__class__(exprs=exprs, + ispace=ispace, + guards=guards, + properties=properties, + syncs=syncs, + halo_scheme=halo_scheme) + + class ClusterGroup(tuple): """ @@ -552,6 +640,18 @@ def __new__(cls, clusters, ispace=None): return obj + def __eq__(self, other): + return (isinstance(other, ClusterGroup) and + super().__eq__(other) and + self._ispace == other._ispace) + + def __ne__(self, other): + return not self == other + + @cached_hash + def __hash__(self): + return hash((tuple(self), self._ispace)) + @classmethod def concatenate(cls, *cgroups): return list(chain(*cgroups)) diff --git a/devito/ir/clusters/visitors.py b/devito/ir/clusters/visitors.py index 11bcad5365..da0a62344a 100644 --- a/devito/ir/clusters/visitors.py +++ b/devito/ir/clusters/visitors.py @@ -2,7 +2,7 @@ from itertools import groupby from devito.ir.support import IterationSpace, null_ispace -from devito.tools import flatten, timed_pass +from devito.tools import cached_hash, flatten, timed_pass __all__ = ['Queue', 'cluster_pass'] @@ -113,6 +113,10 @@ def _process_fatd(self, clusters, level, prefix=None, **kwargs): class Prefix(IterationSpace): + @classmethod + def _preprocess_args(cls, ispace, guards, properties, syncs): + return (ispace, guards, properties, syncs), {} + def __init__(self, ispace, guards, properties, syncs): super().__init__(ispace.intervals, ispace.sub_iterators, ispace.directions) @@ -127,6 +131,7 @@ def __eq__(self, other): self.properties == other.properties and self.syncs == other.syncs) + @cached_hash def __hash__(self): return hash((self.intervals, self.sub_iterators, self.directions, self.guards, self.properties, self.syncs)) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 8d72704b79..52ac36473f 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -1,3 +1,4 @@ +from contextlib import suppress from functools import cached_property import numpy as np @@ -6,11 +7,12 @@ from devito.finite_differences.differentiable import diff2sympy from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.ir.support import ( - GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses, - detect_io + GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses +) +from devito.symbolics import IntDiv, limits_mapper, retrieve_accesses, uxreplace +from devito.tools import ( + Pickable, Tag, as_hashable, filter_sorted, frozendict, reuse_if_unchanged ) -from devito.symbolics import IntDiv, limits_mapper, uxreplace -from devito.tools import Pickable, Tag, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min __all__ = [ @@ -72,6 +74,10 @@ def state(self): def operation(self): return self._operation + def _hashable_content(self): + return (super()._hashable_content() + + tuple(as_hashable(getattr(self, i)) for i in self.__rkwargs__)) + @property def is_Reduction(self): return self.operation in (OpInc, OpMin, OpMax, OpMinMax) @@ -80,6 +86,82 @@ def is_Reduction(self): def is_Increment(self): return self.operation is OpInc + @cached_property + def _writes(self): + from devito.symbolics.queries import q_routine + + terminals = set(retrieve_accesses(self.lhs)) + if q_routine(self.rhs): + with suppress(AttributeError): + # Everything except: foreign routines, such as `cos` or `sin` etc. + terminals.update(self.rhs.writes) + + return tuple(terminals) + + @property + def writes(self): + return self._writes + + @cached_property + def reads_explicit(self): + terminals = set(retrieve_accesses(self.rhs, deep=True)) + with suppress(AttributeError): + terminals.update(retrieve_accesses(self.lhs.indices)) + + return tuple(terminals) + + @cached_property + def reads_conditional(self): + accesses = [] + for v in self.conditionals.values(): + accesses.extend(retrieve_accesses(v)) + + return tuple(accesses) + + @cached_property + def _reads(self): + return tuple(set(self.reads_explicit) | set(self.reads_conditional)) + + @property + def reads(self): + return self._reads + + @cached_property + def _read_functions(self): + found = [] + for i in self.reads: + with suppress(AttributeError): + i = i.function + found.append(i) + return tuple(filter_sorted(found)) + + @cached_property + def _write_functions(self): + found = [] + for i in self.writes: + with suppress(AttributeError): + i = i.function + found.append(i) + return tuple(filter_sorted(found)) + + @cached_property + def read_functions(self): + return tuple(i for i in self._read_functions if i.is_Input) + + @cached_property + def write_functions(self): + return tuple(i for i in self._write_functions if i.is_Input) + + @cached_property + def read_functions_relaxed(self): + return tuple(i for i in self._read_functions + if i.is_Input or i.is_AbstractFunction) + + @cached_property + def write_functions_relaxed(self): + return tuple(i for i in self._write_functions + if i.is_Input or i.is_AbstractFunction) + def apply(self, func): """ Apply a callable to `self` and each expr-like attribute carried by `self`, @@ -175,7 +257,7 @@ class LoweredEq(IREq): `LoweredEq.__rkwargs__` must appear in `kwargs`. """ - __rkwargs__ = IREq.__rkwargs__ + ('reads', 'writes') + __rkwargs__ = IREq.__rkwargs__ def __new__(cls, *args, **kwargs): if len(args) == 1 and isinstance(args[0], LoweredEq): @@ -250,20 +332,11 @@ def __new__(cls, *args, **kwargs): expr._ispace = ispace expr._conditionals = conditionals - expr._reads, expr._writes = detect_io(expr) expr._implicit_dims = input_expr.implicit_dims expr._operation = Operation.detect(input_expr) return expr - @property - def reads(self): - return self._reads - - @property - def writes(self): - return self._writes - def xreplace(self, rules): return LoweredEq(self.lhs.xreplace(rules), self.rhs.xreplace(rules), **self.state) @@ -292,6 +365,7 @@ class ClusterizedEq(IREq): These two properties make a ClusterizedEq suitable for use in a Cluster. """ + @reuse_if_unchanged('__rkwargs__') def __new__(cls, *args, **kwargs): if len(args) == 1: # origin: ClusterizedEq(expr, **kwargs) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index d106b5e811..626f2bef16 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -6,7 +6,7 @@ from collections import OrderedDict, namedtuple from collections.abc import Iterable from contextlib import suppress -from functools import cached_property +from functools import cached_property, lru_cache import cgen as c from sympy import IndexedBase, sympify @@ -16,7 +16,7 @@ from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin, OpMinMax from devito.ir.support import ( AFFINE, INBOUND, PARALLEL, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT, SEQUENTIAL, - VECTORIZED, Forward, PrefetchUpdate, Property, WithLock, detect_io + VECTORIZED, Forward, PrefetchUpdate, Property, WithLock ) from devito.symbolics import CallFromPointer, ListInitializer from devito.tools import ( @@ -102,27 +102,36 @@ class Node(Signer): def __new__(cls, *args, **kwargs): obj = super().__new__(cls) - argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) - try: - defaults = dict( - zip(argnames[-len(defaultvalues):], defaultvalues, strict=True) - ) - except TypeError: - # No default kwarg values - defaults = {} - obj._args = {k: v for k, v in zip(argnames[1:], args, strict=False)} + argnames, defaults = _constructor_args(cls) + obj._args = {k: v for k, v in zip(argnames, args, strict=False)} obj._args.update(kwargs.items()) - obj._args.update({k: defaults.get(k) for k in argnames[1:] if k not in obj._args}) + obj._args.update({k: defaults.get(k) for k in argnames if k not in obj._args}) return obj def _rebuild(self, *args, **kwargs): """Reconstruct ``self``.""" handle = self._args.copy() # Original constructor arguments argnames = [i for i in self._traversable if i not in kwargs] - handle.update(OrderedDict([(k, v) for k, v in zip(argnames, args, strict=False)])) - handle.update(kwargs) + updates = OrderedDict([(k, v) for k, v in zip(argnames, args, strict=False)]) + updates.update(kwargs) + + if updates and all(self._same_arg(k, v) for k, v in updates.items()): + return self + + handle.update(updates) return type(self)(**handle) + def _same_arg(self, key, value): + with suppress(AttributeError): + if _same_as_before(getattr(self, key), value): + return True + + with suppress(KeyError): + if _same_as_before(self._args[key], value): + return True + + return False + @cached_property def ccode(self): """ @@ -452,7 +461,7 @@ def rhs(self): @cached_property def reads(self): """The Functions read by the Expression.""" - return detect_io(self.expr, relax=True)[0] + return self.expr.read_functions_relaxed @cached_property def write(self): @@ -1558,9 +1567,6 @@ def DummyExpr(*args, init=False): return Expression(DummyEq(*args), init=init) -BlankLine = CBlankLine() - - # Nodes required for distributed-memory halo exchange @@ -1635,3 +1641,54 @@ def functions(self): Iteration/Expression tree. ``local`` is a boolean indicating whether the definition of the callable is known or not. """ + + +# *** Utils + + +@lru_cache(maxsize=None) +def _constructor_args(cls): + """ + Return cached constructor argument names and default values for an IET type. + + IET node construction records the original constructor arguments in + ``_args``. This helper avoids repeating ``inspect.getfullargspec`` for every + node instance of the same class. + """ + argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) + if defaultvalues is None: + defaults = {} + else: + defaults = dict(zip(argnames[-len(defaultvalues):], defaultvalues, strict=True)) + + return tuple(argnames[1:]), defaults + + +def _same_as_before(old, new): + """ + Return True if ``new`` preserves the object identity structure of ``old``. + + This intentionally does not use equality for arbitrary objects. It only + recurses through common containers and otherwise requires object identity, + which keeps no-op rebuild detection compatible with IET mapper semantics. + """ + if old is new: + return True + + if isinstance(old, (tuple, list)) and isinstance(new, (tuple, list)): + return len(old) == len(new) and all( + _same_as_before(i, j) for i, j in zip(old, new, strict=True) + ) + + if type(old) is not type(new): + return False + + if isinstance(old, dict): + return old.keys() == new.keys() and all( + _same_as_before(v, new[k]) for k, v in old.items() + ) + + return False + + +BlankLine = CBlankLine() diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 2744da24d6..caa7b137af 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -7,6 +7,7 @@ import ctypes from collections import OrderedDict from collections.abc import Callable, Generator, Iterable, Iterator, Sequence +from contextlib import suppress from itertools import chain, groupby from typing import Any, Generic, TypeVar @@ -15,9 +16,10 @@ from sympy.core.function import Application from devito.exceptions import CompilationError +from devito.ir.cgen.printer import get_printer from devito.ir.iet.nodes import ( BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, - Section + Section, _same_as_before as same_as_before ) from devito.ir.support.space import Backward from devito.symbolics import ( @@ -26,7 +28,7 @@ from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import ( GenericVisitor, as_tuple, c_restrict_void_p, filter_ordered, filter_sorted, flatten, - is_external_ctype, sorted_priority + is_external_ctype, memoized_weak_meth, sorted_priority ) from devito.types import ( ArrayObject, CompositeObject, DeviceMap, Dimension, IndexedData, Pointer @@ -255,8 +257,8 @@ def __init__(self, *args, printer=None, **kwargs): printer = CPrinter self.printer = printer - def ccode(self, expr, **kwargs): - return self.printer(settings=kwargs).doprint(expr, None) + def ccode(self, expr, dtype=None): + return get_printer(self.printer, dtype).doprint(expr, None) @property def _qualifiers_mapper(self): @@ -1113,12 +1115,17 @@ def _defines_aliases(n): def __init__(self, mode: str = 'symbolics') -> None: super().__init__() + self.mode = mode modes = mode.split('|') if len(modes) == 1: self.rule = self.rules[mode] else: self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes]) + @memoized_weak_meth(key=lambda i: i.mode, freeze=tuple, thaw=list) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def _post_visit(self, ret): return sorted(filter_ordered(ret, key=id), key=str) @@ -1165,8 +1172,13 @@ class FindNodes(LazyVisitor[Node, list[Node], None]): def __init__(self, match: type, mode: str = 'type') -> None: super().__init__() self.match = match + self.mode = mode self.rule = self.rules[mode] + @memoized_weak_meth(key=lambda i: (i.match, i.mode), freeze=tuple, thaw=list) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]: if self.rule(self.match, o): yield o @@ -1187,6 +1199,11 @@ def __init__(self, match: type, start: Node, stop: Node | None = None) -> None: self.start = start self.stop = stop + def visit(self, o, *args, **kwargs): + # `start` and `stop` are part of this visitor's state. + return GenericVisitor.visit(self, o, *args, **kwargs) + + def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]: yield from () return flag # noqa: B901 @@ -1234,8 +1251,13 @@ class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]) def __init__(self, cls: type[ApplicationType] = Application): super().__init__() + self.cls = cls self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic) + @memoized_weak_meth(key=lambda i: i.cls, freeze=frozenset, thaw=set) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def _post_visit(self, ret): return set(ret) @@ -1319,6 +1341,13 @@ def __init__(self, mapper, nested=False): self.mapper = mapper self.nested = nested + def visit(self, o, *args, **kwargs): + # Subclasses may implement mapper-independent transformations. + if type(self) is Transformer and not self.mapper: + return o + + return super().visit(o, *args, **kwargs) + def transform(self, o, handle, **kwargs): if handle is None: # None -> drop `o` @@ -1332,12 +1361,12 @@ def transform(self, o, handle, **kwargs): else: children = o.children children = (tuple(handle) + children[0],) + tuple(children[1:]) - return o._rebuild(*children, **o.args_frozen) + return reuse_if_unchanged(o, *children, **o.args_frozen) else: # Replace `o` with `handle` if self.nested: children = [self._visit(i, **kwargs) for i in handle.children] - return handle._rebuild(*children, **handle.args_frozen) + return reuse_if_unchanged(handle, *children, **handle.args_frozen) else: return handle @@ -1346,7 +1375,12 @@ def visit_object(self, o, **kwargs): def visit_tuple(self, o, **kwargs): visited = tuple(self._visit(i, **kwargs) for i in o) - return tuple(i for i in visited if i is not None) + processed = tuple(i for i in visited if i is not None) + + if same_as_before(o, processed): + return o + + return processed visit_list = visit_tuple @@ -1357,7 +1391,7 @@ def visit_Node(self, o, **kwargs): children = [self._visit(i, **kwargs) for i in o.children] if o._traversable and not any(children) and any(o.children): return None - return o._rebuild(*children, **o.args_frozen) + return reuse_if_unchanged(o, *children, **o.args_frozen) def visit_Operator(self, o, **kwargs): raise ValueError("Cannot apply a Transformer visitor to an Operator directly") @@ -1374,8 +1408,14 @@ class Uxreplace(Transformer): The substitution rules. """ + def visit(self, o, *args, **kwargs): + if not self.mapper: + return o + + return super().visit(o, *args, **kwargs) + def visit_Expression(self, o): - return o._rebuild(expr=uxreplace(o.expr, self.mapper)) + return reuse_if_unchanged(o, expr=uxreplace(o.expr, self.mapper)) def _visit_Iteration_common(self, o): nodes = self._visit(o.nodes) @@ -1392,8 +1432,8 @@ def visit_Iteration(self, o): nodes, dimension, limits, pragmas, uindices = \ self._visit_Iteration_common(o) - return o._rebuild(nodes=nodes, dimension=dimension, limits=limits, - pragmas=pragmas, uindices=uindices) + return reuse_if_unchanged(o, nodes=nodes, dimension=dimension, limits=limits, + pragmas=pragmas, uindices=uindices) def visit_PragmaIteration(self, o): nodes, dimension, limits, pragmas, uindices = \ @@ -1420,7 +1460,7 @@ def visit_Return(self, o): def visit_Callable(self, o): body = self._visit(o.body) parameters = [self.mapper.get(i, i) for i in o.parameters] - return o._rebuild(body=body, parameters=parameters) + return reuse_if_unchanged(o, body=body, parameters=parameters) def visit_Call(self, o): arguments = [] @@ -1431,47 +1471,47 @@ def visit_Call(self, o): arguments.append(uxreplace(i, self.mapper)) if o.retobj is not None: retobj = uxreplace(o.retobj, self.mapper) - return o._rebuild(arguments=arguments, retobj=retobj) + return reuse_if_unchanged(o, arguments=arguments, retobj=retobj) else: - return o._rebuild(arguments=arguments) + return reuse_if_unchanged(o, arguments=arguments) def visit_Lambda(self, o): body = self._visit(o.body) parameters = [self.mapper.get(i, i) for i in o.parameters] - return o._rebuild(body=body, parameters=parameters) + return reuse_if_unchanged(o, body=body, parameters=parameters) def visit_Conditional(self, o): condition = uxreplace(o.condition, self.mapper) then_body = self._visit(o.then_body) else_body = self._visit(o.else_body) - return o._rebuild(condition=condition, then_body=then_body, - else_body=else_body) + return reuse_if_unchanged(o, condition=condition, then_body=then_body, + else_body=else_body) def visit_Switch(self, o): condition = uxreplace(o.condition, self.mapper) nodes = self._visit(o.nodes) default = self._visit(o.default) - return o._rebuild(condition=condition, nodes=nodes, default=default) + return reuse_if_unchanged(o, condition=condition, nodes=nodes, default=default) def visit_PointerCast(self, o): function = self.mapper.get(o.function, o.function) obj = self.mapper.get(o.obj, o.obj) - return o._rebuild(function=function, obj=obj) + return reuse_if_unchanged(o, function=function, obj=obj) def visit_Dereference(self, o): pointee = self.mapper.get(o.pointee, o.pointee) pointer = self.mapper.get(o.pointer, o.pointer) - return o._rebuild(pointee=pointee, pointer=pointer) + return reuse_if_unchanged(o, pointee=pointee, pointer=pointer) def visit_Pragma(self, o): arguments = [uxreplace(i, self.mapper) for i in o.arguments] - return o._rebuild(arguments=arguments) + return reuse_if_unchanged(o, arguments=arguments) def visit_PragmaTransfer(self, o): function = uxreplace(o.function, self.mapper) arguments = [uxreplace(i, self.mapper) for i in o.arguments] if o.imask is None: - return o._rebuild(function=function, arguments=arguments) + return reuse_if_unchanged(o, function=function, arguments=arguments) # An `imask` may be None, a list of symbols/numbers, or a list of # 2-tuples representing ranges @@ -1483,25 +1523,26 @@ def visit_PragmaTransfer(self, o): uxreplace(j, self.mapper))) except TypeError: imask.append(uxreplace(v, self.mapper)) - return o._rebuild(function=function, imask=imask, arguments=arguments) + return reuse_if_unchanged(o, function=function, imask=imask, + arguments=arguments) def visit_ParallelTree(self, o): prefix = self._visit(o.prefix) body = self._visit(o.body) nthreads = self.mapper.get(o.nthreads, o.nthreads) - return o._rebuild(prefix=prefix, body=body, nthreads=nthreads) + return reuse_if_unchanged(o, prefix=prefix, body=body, nthreads=nthreads) def visit_HaloSpot(self, o): hs = o.halo_scheme fmapper = {self.mapper.get(k, k): v for k, v in hs.fmapper.items()} halo_scheme = hs.build(fmapper, hs.honored) body = self._visit(o.body) - return o._rebuild(halo_scheme=halo_scheme, body=body) + return reuse_if_unchanged(o, halo_scheme=halo_scheme, body=body) def visit_While(self, o, **kwargs): condition = uxreplace(o.condition, self.mapper) body = self._visit(o.body) - return o._rebuild(condition=condition, body=body) + return reuse_if_unchanged(o, condition=condition, body=body) visit_ThreadedProdder = visit_Call @@ -1510,8 +1551,8 @@ def visit_KernelLaunch(self, o): grid = self.mapper.get(o.grid, o.grid) block = self.mapper.get(o.block, o.block) stream = self.mapper.get(o.stream, o.stream) - return o._rebuild(grid=grid, block=block, stream=stream, - arguments=arguments) + return reuse_if_unchanged(o, grid=grid, block=block, stream=stream, + arguments=arguments) # Utils @@ -1519,6 +1560,20 @@ def visit_KernelLaunch(self, o): blankline = c.Line("") +def reuse_if_unchanged(o, *children, **kwargs): + if children: + same_children = all( + same_as_before(i, j) for i, j in zip(o.children, children, strict=True) + ) + if not same_children: + return o._rebuild(*children, **kwargs) + + if kwargs and not all(o._same_arg(k, v) for k, v in kwargs.items()): + return o._rebuild(*children, **kwargs) + + return o + + def printAST(node, verbose=True): return PrintAST(verbose=verbose)._visit(node) diff --git a/devito/ir/stree/algorithms.py b/devito/ir/stree/algorithms.py index d4a761dfc8..68fc697d3e 100644 --- a/devito/ir/stree/algorithms.py +++ b/devito/ir/stree/algorithms.py @@ -111,7 +111,7 @@ def stree_build(clusters, profiler=None, **kwargs): else: parent = tip - NodeExprs(exprs, c.ispace, c.dspace, c.ops, c.traffic, parent) + NodeExprs(exprs, c.ispace, c.ops, c.traffic, parent) # Nest within a NodeSection if possible if profiler is None or \ diff --git a/devito/ir/stree/tree.py b/devito/ir/stree/tree.py index e033c9fd15..96e498396d 100644 --- a/devito/ir/stree/tree.py +++ b/devito/ir/stree/tree.py @@ -115,11 +115,10 @@ class NodeExprs(ScheduleTree): is_Exprs = True - def __init__(self, exprs, ispace, dspace, ops, traffic, parent=None): + def __init__(self, exprs, ispace, ops, traffic, parent=None): super().__init__(parent) self.exprs = exprs self.ispace = ispace - self.dspace = dspace self.ops = ops self.traffic = traffic diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 7939ee8fe8..d9cebed3fe 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Iterable from contextlib import suppress -from functools import cached_property +from functools import cached_property, wraps from itertools import chain, product import sympy @@ -10,12 +10,12 @@ from devito.ir.support.utils import AccessMode, extrema from devito.ir.support.vector import LabeledVector, Vector from devito.symbolics import ( - compare_ops, q_affine, q_comp_acc, q_constant, q_routine, retrieve_indexed, - retrieve_terminals, search, uxreplace + compare_ops, q_affine, q_comp_acc, q_constant, retrieve_accesses, + retrieve_indexed ) from devito.tools import ( - CacheInstances, Tag, as_mapper, as_tuple, filter_sorted, flatten, is_integer, - memoized_generator, memoized_meth, smart_gt, smart_lt + CacheInstances, Tag, as_mapper, as_tuple, cached_hash, filter_sorted, flatten, + is_integer, memoized_func, memoized_generator, memoized_meth, smart_gt, smart_lt ) from devito.types import ( ComponentAccess, CriticalRegion, Dimension, DimensionTuple, Fence, Function, Symbol, @@ -200,7 +200,7 @@ def is_scalar(self): return self.rank == 0 -class TimedAccess(IterationInstance, AccessMode): +class TimedAccess(IterationInstance, AccessMode, CacheInstances): """ A TimedAccess ties together an IterationInstance and an AccessMode. @@ -218,6 +218,12 @@ class TimedAccess(IterationInstance, AccessMode): on the values of the index functions and the access mode (read, write). """ + @classmethod + def _preprocess_args(cls, access, mode, timestamp, ispace=None): + if ispace is None: + ispace = null_ispace + return (access, mode, timestamp, ispace), {} + def __new__(cls, access, mode, timestamp, ispace=None): obj = super().__new__(cls, access) AccessMode.__init__(obj, mode=mode) @@ -247,6 +253,7 @@ def __eq__(self, other): self.access == other.access and self.ispace == other.ispace) + @cached_hash def __hash__(self): return hash((self.access, self.mode, self.timestamp, self.ispace)) @@ -320,6 +327,21 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp + def rebuild(self, **kwargs): + access = kwargs.get('access', self.access) + mode = kwargs.get('mode', self.mode) + timestamp = kwargs.get('timestamp', self.timestamp) + ispace = kwargs.get('ispace', self.ispace) + + if access is self.access and \ + mode is self.mode and \ + timestamp is self.timestamp and \ + ispace is self.ispace: + return self + + return TimedAccess(access, mode, timestamp, ispace) + + @memoized_meth def distance(self, other, logical=False): """ Compute the distance from ``self`` to ``other``. @@ -853,18 +875,91 @@ class Scope(CacheInstances): # Describes a rule for dependencies Rule = Callable[[TimedAccess, TimedAccess], bool] + def normalize_input(func): + + @wraps(func) + def wrapper(self, *args, writes=None, **kwargs): + mapped = {} + for k in as_tuple(writes or self.writes): + v = self.getwrites(k) + if v: + mapped[k] = v + return func(self, *args, writes=mapped, **kwargs) + + return wrapper + + @classmethod + @memoized_func(scope='build') + def from_scopes(cls, scope0, scope1): + """ + Build a synthetic Scope out of two existing Scopes by reusing their + cached reads and writes rather than rediscovering accesses from the + underlying expressions. + + This is used to analyze cross-scope dependences cheaply, for example in + loop-fusion hazard checks. Return None if the two Scopes cannot induce + any cross-scope dependences. + """ + offset = len(scope0.exprs) + + targets = ( + set(scope0.writes) & scope1.functions + ) | ( + set(scope1.writes) & scope0.functions + ) + if not targets: + return None + + def is_cross(source, sink): + t0 = source.timestamp + t1 = sink.timestamp + return t0 < offset <= t1 or t1 < offset <= t0 + + reads = {} + writes = {} + + for f in targets: + shifted = tuple( + i.rebuild(timestamp=i.timestamp + offset) + for i in scope1.getreads(f) + ) + accesses = scope0.getreads(f) + if shifted: + accesses = accesses + shifted if accesses else shifted + if accesses: + reads[f] = accesses + + shifted = tuple( + i.rebuild(timestamp=i.timestamp + offset) + for i in scope1.getwrites(f) + ) + accesses = scope0.getwrites(f) + if shifted: + accesses = accesses + shifted if accesses else shifted + if accesses: + writes[f] = accesses + + return cls((), rules=is_cross, reads=reads.items(), writes=writes.items()) + @classmethod def _preprocess_args(cls, exprs: Expr | Iterable[Expr], **kwargs) -> tuple[tuple, dict]: + for i in ('reads', 'writes'): + with suppress(KeyError): + kwargs[i] = tuple(kwargs[i]) + return (as_tuple(exprs),), kwargs def __init__(self, exprs: tuple[Expr], - rules: Rule | tuple[Rule] | None = None) -> None: + rules: Rule | tuple[Rule] | None = None, + reads=None, writes=None) -> None: """ A Scope enables data dependence analysis on a totally ordered sequence of expressions. """ self.exprs = exprs + self._reads = dict(reads) if reads is not None else None + self._writes = dict(writes) if writes is not None else None # A set of rules to drive the collection of dependencies self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment] @@ -876,13 +971,7 @@ def writes_gen(self): Generate all write accesses. """ for i, e in enumerate(self.exprs): - terminals = retrieve_accesses(e.lhs) - if q_routine(e.rhs): - with suppress(AttributeError): - # Everything except: foreign routines, such as `cos` or `sin` etc. - terminals.update(e.rhs.writes) - - for j in terminals: + for j in e.writes: mode = 'WR' if e.is_Reduction else 'W' yield TimedAccess(j, mode, i, e.ispace) @@ -909,8 +998,17 @@ def writes(self): """ Create a mapper from functions to write accesses. """ + if self._writes is not None: + return self._writes + return as_mapper(self.writes_gen(), key=lambda i: i.function) + @cached_property + def writes_tensor(self): + initialized = frozenset(e.lhs.function for e in self.exprs + if not e.is_Reduction and e.is_scalar) + return frozenset(self.writes) - initialized + @memoized_generator def reads_explicit_gen(self): """ @@ -919,11 +1017,7 @@ def reads_explicit_gen(self): expressions. """ for i, e in enumerate(self.exprs): - # Reads - terminals = retrieve_accesses(e.rhs, deep=True) - with suppress(AttributeError): - terminals.update(retrieve_accesses(e.lhs.indices)) - for j in terminals: + for j in e.reads_explicit: mode = 'RR' if j.function is e.lhs.function and e.is_Reduction else 'R' yield TimedAccess(j, mode, i, e.ispace) @@ -932,9 +1026,8 @@ def reads_explicit_gen(self): yield TimedAccess(e.lhs, 'RR', i, e.ispace) # Look up ConditionalDimensions - for v in e.conditionals.values(): - for j in retrieve_accesses(v): - yield TimedAccess(j, 'R', -1, e.ispace) + for j in e.reads_conditional: + yield TimedAccess(j, 'R', -1, e.ispace) @memoized_generator def reads_implicit_gen(self): @@ -1008,21 +1101,22 @@ def reads_smart_gen(self, f): the iteration symbols. """ if isinstance(f, (Function, Temp, TempArray, TBArray)): - for i in chain(self.reads_explicit_gen(), self.reads_synchro_gen()): - if f is i.function: - for j in extrema(i.access): - yield TimedAccess(j, i.mode, i.timestamp, i.ispace) + for i in self.getreads(f): + for j in extrema(i.access): + yield TimedAccess(j, i.mode, i.timestamp, i.ispace) else: - for i in self.reads_gen(): - if f is i.function: - yield i + for i in self.getreads(f): + yield i @cached_property def reads(self): """ Create a mapper from functions to read accesses. """ + if self._reads is not None: + return self._reads + return as_mapper(self.reads_gen(), key=lambda i: i.function) @cached_property @@ -1033,9 +1127,9 @@ def read_only(self): return set(self.reads) - set(self.writes) @cached_property - def initialized(self): - return frozenset(e.lhs.function for e in self.exprs - if not e.is_Reduction and e.is_scalar) + def has_barrier(self): + """True if the Scope contains a fence-like control-flow object.""" + return any(isinstance(e.rhs, (Fence, CriticalRegion)) for e in self.exprs) def getreads(self, function): return as_tuple(self.reads.get(function)) @@ -1095,11 +1189,17 @@ def a_query(self, timestamps=None, modes=None): if a.timestamp in timestamps and a.mode in modes) @memoized_generator - def d_flow_gen(self): - """Generate the flow (or "read-after-write") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_flow_gen(self, writes=None): + """ + Generate the flow (or "read-after-write") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): + reads = tuple(self.reads_smart_gen(k)) for w in v: - for r in self.reads_smart_gen(k): + for r in reads: if any(not rule(w, r) for rule in self.rules): continue @@ -1126,11 +1226,17 @@ def d_flow(self): return DependenceGroup(self.d_flow_gen()) @memoized_generator - def d_anti_gen(self, depcls=Dependence): - """Generate the anti (or "write-after-read") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_anti_gen(self, depcls=Dependence, writes=None): + """ + Generate the anti (or "write-after-read") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): + reads = tuple(self.reads_smart_gen(k)) for w in v: - for r in self.reads_smart_gen(k): + for r in reads: if any(not rule(r, w) for rule in self.rules): continue @@ -1165,11 +1271,16 @@ def d_anti_logical(self): return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence)) @memoized_generator - def d_output_gen(self): - """Generate the output (or "write-after-write") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_output_gen(self, writes=None): + """ + Generate the output (or "write-after-write") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): for w1 in v: - for w2 in self.writes.get(k, []): + for w2 in v: if any(not rule(w2, w1) for rule in self.rules): continue @@ -1193,9 +1304,15 @@ def d_output(self): """Output (or "write-after-write") dependences.""" return DependenceGroup(self.d_output_gen()) - def d_all_gen(self): - """Generate all flow, anti and output dependences.""" - return chain(self.d_flow_gen(), self.d_anti_gen(), self.d_output_gen()) + def d_all_gen(self, writes=None): + """ + Generate all flow, anti and output dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + return chain(self.d_flow_gen(writes=writes), + self.d_anti_gen(writes=writes), + self.d_output_gen(writes=writes)) @cached_property def d_all(self): @@ -1380,24 +1497,6 @@ def is_regular(self): def vinf(entries): return Vector(*(entries + [S.Infinity])) - -def retrieve_accesses(exprs, **kwargs): - """ - Like retrieve_terminals, but ensure that if a ComponentAccess is found, - the ComponentAccess itself is returned, while the wrapped Indexed is discarded. - """ - kwargs['mode'] = 'unique' - - compaccs = search(exprs, ComponentAccess) - if not compaccs: - return retrieve_terminals(exprs, **kwargs) - - subs = {i: Symbol(f'dummy{n}') for n, i in enumerate(compaccs)} - exprs1 = uxreplace(exprs, subs) - - return compaccs | retrieve_terminals(exprs1, **kwargs) - set(subs.values()) - - def disjoint_test(e0, e1, d, it): """ A rudimentary test to check if two accesses `e0` and `e1` along `d` within diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index b8a335b1f4..cd9a5292b2 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -272,31 +272,34 @@ class Guards(frozendict): def get(self, d, v=true): return super().get(d, v) + def _rebuild(self, mapper): + return self if mapper == self else Guards(mapper) + def andg(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self try: m[d] = simplify_and(m[d], guard) except KeyError: m[d] = guard - return Guards(m) + return self._rebuild(m) def xandg(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self try: m[d] = And(m[d], guard) except KeyError: m[d] = guard - return Guards(m) + return self._rebuild(m) def pairwise_or(self, d, *guards): m = dict(self) @@ -311,17 +314,17 @@ def pairwise_or(self, d, *guards): else: m[d] = g - return Guards(m) + return self._rebuild(m) def impose(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self m[d] = guard - return Guards(m) + return self._rebuild(m) def popany(self, dims): m = dict(self) @@ -329,12 +332,12 @@ def popany(self, dims): for d in as_tuple(dims): m.pop(d, None) - return Guards(m) + return self._rebuild(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return Guards(m) + return self._rebuild(m) def as_map(self, d, cls): if cls not in (Le, Lt, Ge, Gt): diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index 9e787a8b9e..a835bb3f07 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -199,19 +199,27 @@ class Properties(frozendict): A mapper {Dimension -> {properties}}. """ + def __init__(self, *args, **kwargs): + mapper = dict(*args, **kwargs) + mapper = {d: frozenset(as_tuple(v)) for d, v in mapper.items()} + super().__init__(mapper) + @property def dimensions(self): return tuple(self) + def _rebuild(self, mapper): + return self if mapper == self else Properties(mapper) + def add(self, dims, properties=None): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | set(as_tuple(properties)) - return Properties(m) + return self._rebuild(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return Properties(m) + return self._rebuild(m) def drop(self, dims=None, properties=None): if dims is None: @@ -222,7 +230,7 @@ def drop(self, dims=None, properties=None): m.pop(d, None) else: m[d] = self[d] - set(as_tuple(properties)) - return Properties(m) + return self._rebuild(m) def parallelize(self, dims): m = dict(self) @@ -231,13 +239,13 @@ def parallelize(self, dims): v.difference_update({PARALLEL_IF_PVT, PARALLEL_IF_ATOMIC, SEQUENTIAL}) v.add(PARALLEL) m[d] = v - return Properties(m) + return self._rebuild(m) def affine(self, dims): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {AFFINE} - return Properties(m) + return self._rebuild(m) def sequentialize(self, dims=None): if dims is None: @@ -245,13 +253,13 @@ def sequentialize(self, dims=None): m = dict(self) for d in as_tuple(dims): m[d] = normalize_properties(set(self.get(d, [])), {SEQUENTIAL}) - return Properties(m) + return self._rebuild(m) def prefetchable(self, dims, v=PREFETCHABLE): m = dict(self) for d in as_tuple(dims): m[d] = self.get(d, set()) | {v} - return Properties(m) + return self._rebuild(m) def block(self, dims, kind='default'): if kind == 'default': @@ -263,7 +271,7 @@ def block(self, dims, kind='default'): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {p} - return Properties(m) + return self._rebuild(m) def inbound(self, dims): return self.add(dims, INBOUND) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 1c760128b5..7c9f970108 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -8,8 +8,8 @@ from devito.ir.support.utils import maximum, minimum from devito.ir.support.vector import Vector, vmax, vmin from devito.tools import ( - Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, flatten, frozendict, - is_integer, toposort + CacheInstances, Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, + cached_hash, flatten, frozendict, is_integer, toposort ) from devito.types import Dimension, ModuloDimension @@ -53,6 +53,7 @@ def __eq__(self, o): is_compatible = __eq__ + @cached_hash def __hash__(self): return hash(self.dim.name) @@ -88,7 +89,7 @@ def negate(self): translate = negate -class NullInterval(AbstractInterval): +class NullInterval(AbstractInterval, CacheInstances): """ A degenerate iterated closed interval on Z. @@ -96,9 +97,14 @@ class NullInterval(AbstractInterval): is_Null = True + @classmethod + def _preprocess_args(cls, dim, stamp=S0): + return (dim, stamp), {} + def __repr__(self): return f"{self.dim}[Null]{self.stamp}" + @cached_hash def __hash__(self): return hash(self.dim) @@ -120,7 +126,7 @@ def switch(self, d): return NullInterval(d, self.stamp) -class Interval(AbstractInterval): +class Interval(AbstractInterval, CacheInstances): """ Interval(dim, lower, upper) @@ -134,6 +140,18 @@ class Interval(AbstractInterval): is_Defined = True + @classmethod + def _preprocess_args(cls, dim, lower=0, upper=0, stamp=S0): + try: + lower = int(lower) + except TypeError: + assert isinstance(lower, Expr) + try: + upper = int(upper) + except TypeError: + assert isinstance(upper, Expr) + return (dim, lower, upper, stamp), {} + def __init__(self, dim, lower=0, upper=0, stamp=S0): super().__init__(dim, stamp) @@ -151,6 +169,7 @@ def __init__(self, dim, lower=0, upper=0, stamp=S0): def __repr__(self): return f"{self.dim}[{self.lower},{self.upper}]{self.stamp}" + @cached_hash def __hash__(self): return hash((self.dim, self.offsets)) @@ -304,12 +323,18 @@ def expand(self): ) -class IntervalGroup(Ordering): +class IntervalGroup(Ordering, CacheInstances): """ A sequence of Intervals equipped with set-like operations. """ + @classmethod + def _preprocess_args(cls, items=None, relations=None, mode='total'): + items = as_tuple(items) + relations = tuple(tuple(i) for i in as_tuple(relations)) + return (items,), {'relations': relations, 'mode': mode} + @classmethod def reorder(cls, items, relations): if not all(isinstance(i, AbstractInterval) for i in items): @@ -335,13 +360,14 @@ def simplify_relations(cls, relations, items, mode): return super().simplify_relations(relations, items, mode) def __eq__(self, o): - return len(self) == len(o) and all(i == j for i, j in zip(self, o, strict=True)) + return isinstance(o, IntervalGroup) and super().__eq__(o) def __contains__(self, d): return any(i.dim is d for i in self) + @cached_hash def __hash__(self): - return hash(tuple(self)) + return hash((tuple(self), self.relations, self.mode)) def __repr__(self): return "IntervalGroup[{}]".format(', '.join([repr(i) for i in self])) @@ -598,6 +624,7 @@ def __eq__(self, other): def __repr__(self): return self._name + @cached_hash def __hash__(self): return hash(self._name) @@ -618,6 +645,11 @@ class IterationInterval(Interval): An Interval associated with metadata. """ + @classmethod + def _preprocess_args(cls, interval, sub_iterators=(), direction=Forward): + sub_iterators = tuple(filter_ordered(as_tuple(sub_iterators))) + return (interval, sub_iterators, direction), {} + def __init__(self, interval, sub_iterators=(), direction=Forward): super().__init__(interval.dim, *interval.offsets, stamp=interval.stamp) self.sub_iterators = sub_iterators @@ -631,6 +663,7 @@ def __eq__(self, other): return False return self.direction is other.direction and super().__eq__(other) + @cached_hash def __hash__(self): return hash((self.dim, self.offsets, self.direction)) @@ -665,9 +698,6 @@ def __repr__(self): def __eq__(self, other): return isinstance(other, Space) and self.intervals == other.intervals - def __hash__(self): - return hash(self.intervals) - def __len__(self): return len(self.intervals) @@ -731,8 +761,9 @@ def __eq__(self, other): self.intervals == other.intervals and self.parts == other.parts) + @cached_hash def __hash__(self): - return hash((super().__hash__(), self.parts)) + return hash((self.intervals, self.parts)) @classmethod def union(cls, *others): @@ -768,8 +799,7 @@ def reset(self): return DataSpace(intervals, parts) - -class IterationSpace(Space): +class IterationSpace(Space, CacheInstances): """ Represent an iteration space as a Space with additional metadata and operations. @@ -785,23 +815,29 @@ class IterationSpace(Space): A mapper from Dimensions in ``intervals`` to IterationDirections. """ - def __init__(self, intervals, sub_iterators=None, directions=None): - super().__init__(intervals) + @classmethod + def _preprocess_args(cls, intervals, sub_iterators=None, directions=None): + if not isinstance(intervals, IntervalGroup): + intervals = IntervalGroup(as_tuple(intervals)) - # Normalize sub-iterators sub_iterators = sub_iterators or {} sub_iterators = {d: tuple(filter_ordered(as_tuple(v))) - for d, v in sub_iterators.items() if d in self.intervals} - sub_iterators.update({i.dim: () for i in self.intervals + for d, v in sub_iterators.items() if d in intervals} + sub_iterators.update({i.dim: () for i in intervals if i.dim not in sub_iterators}) - self._sub_iterators = frozendict(sub_iterators) - # Normalize directions directions = directions or {} - directions = {d: v for d, v in directions.items() if d in self.intervals} - directions.update({i.dim: Any for i in self.intervals + directions = {d: v for d, v in directions.items() if d in intervals} + directions.update({i.dim: Any for i in intervals if i.dim not in directions}) - self._directions = frozendict(directions) + + return (intervals, frozendict(sub_iterators), frozendict(directions)), {} + + def __init__(self, intervals, sub_iterators=None, directions=None): + super().__init__(intervals) + + self._sub_iterators = sub_iterators + self._directions = directions def __repr__(self): ret = ', '.join([f"{repr(i)}{repr(self.directions[i.dim])}" @@ -822,8 +858,9 @@ def __lt__(self, other): """ return len(self.itintervals) < len(other.itintervals) + @cached_hash def __hash__(self): - return hash((super().__hash__(), self.sub_iterators, self.directions)) + return hash((self.intervals, self.sub_iterators, self.directions)) def __contains__(self, d): try: diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index 644bab5d4c..5f2ee39af7 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -3,8 +3,8 @@ from itertools import product from devito.finite_differences import IndexDerivative -from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search -from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split +from devito.symbolics import retrieve_indexed, search +from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, split from devito.types import ( Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove ) @@ -14,7 +14,6 @@ 'IMask', 'Stencil', 'detect_accesses', - 'detect_io', 'erange', 'extrema', 'maximum', @@ -217,70 +216,6 @@ def detect_accesses(exprs): return mapper -def detect_io(exprs, relax=False): - """ - ``{exprs} -> ({reads}, {writes})`` - - Parameters - ---------- - exprs : expr-like or list of expr-like - The searched expressions. - relax : bool, optional - If False, as by default, collect all Input objects, such as - Constants and Functions. Otherwise, also collect AbstractFunctions. - """ - exprs = as_tuple(exprs) - if relax is False: - rule = lambda i: i.is_Input - else: - rule = lambda i: i.is_Input or i.is_AbstractFunction - - # Don't forget the nasty case with indirections on the LHS: - # >>> u[t, a[x]] = f[x] -> (reads={a, f}, writes={u}) - - roots = [] - for i in exprs: - try: - roots.append(i.rhs) - roots.extend(list(i.lhs.indices)) - roots.extend(list(i.conditionals.values())) - except AttributeError: - # E.g., CallFromPointer - roots.append(i) - - reads = [] - terminals = flatten(retrieve_terminals(i, deep=True) for i in roots) - for i in terminals: - candidates = set(i.free_symbols) - with suppress(AttributeError): - candidates.update({i.function}) - for j in candidates: - try: - if rule(j): - reads.append(j) - except AttributeError: - pass - - writes = [] - for i in exprs: - try: - f = i.lhs.function - except AttributeError: - continue - try: - if rule(f): - writes.append(f) - except AttributeError: - # We only end up here after complex IET transformations which make - # use of composite types - assert isinstance(i.lhs, CallFromPointer) - f = i.lhs.base.function - if rule(f): - writes.append(f) - - return tuple(filter_sorted(reads)), tuple(filter_sorted(writes)) - - def pull_dims(exprs, flag=True): """ Extract all Dimensions from one or more expressions. If `flag=True` diff --git a/devito/operator/operator.py b/devito/operator/operator.py index c0bb6145a6..f29d1d0369 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -32,14 +32,14 @@ from devito.operator.registry import operator_selector from devito.parameters import configuration from devito.passes import ( - Graph, error_mapper, generate_implicit, generate_macros, is_on_device, lower_dtypes, - lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate + Graph, error_mapper, finalize_args, generate_implicit, generate_macros, is_on_device, + lower_dtypes, lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate ) from devito.symbolics import estimate_cost, subs_op_args from devito.tools import ( DAG, CacheInstances, MemoryEstimate, OrderedSet, ReducerMap, Signer, as_mapper, - as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, split, - timed_pass, timed_region + as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, + memoized_func, split, timed_pass, timed_region ) from devito.types import Buffer, Evaluable, device_layer, disk_layer, host_layer from devito.types.dimension import Thickness @@ -173,7 +173,11 @@ def __new__(cls, expressions, **kwargs): # Lower to a JIT-compilable object with timed_region('op-compile') as r: - op = cls._build(expressions, **kwargs) + try: + op = cls._build(expressions, **kwargs) + finally: + CacheInstances.clear_caches() + memoized_func.clear_build_caches() op._profiler.py_timers.update(r.timings) # Emit info about how long it took to perform the lowering @@ -249,15 +253,12 @@ def _build(cls, expressions, **kwargs): op._state = cls._initialize_state(**kwargs) # Produced by the various compilation passes - op._reads = filter_sorted(flatten(e.reads for e in irs.expressions)) - op._writes = filter_sorted(flatten(e.writes for e in irs.expressions)) + op._reads = filter_sorted(flatten(e.read_functions for e in irs.expressions)) + op._writes = filter_sorted(flatten(e.write_functions for e in irs.expressions)) op._dimensions = set().union(*[e.dimensions for e in irs.expressions]) op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler - # Clear build-scoped instance caches - CacheInstances.clear_caches() - return op def __init__(self, *args, **kwargs): @@ -507,6 +508,9 @@ def _lower_iet(cls, uiet, **kwargs): # Target-independent optimizations minimize_symbols(graph) + # Finalize helper signatures after all IET transformations have settled. + finalize_args(graph) + return graph.root, graph # Read-only properties exposed to the outside world @@ -1389,11 +1393,13 @@ def _physical_deviceid(self): if isinstance(self.platform, Device): # Get the physical device ID (as CUDA_VISIBLE_DEVICES may be set) logical_deviceid = self.get('deviceid', -1) + visible_device_var, visible_devices = get_visible_devices() if logical_deviceid < 0: rank = self.comm.Get_rank() if self.comm != MPI.COMM_NULL else 0 - logical_deviceid = rank - - visible_device_var, visible_devices = get_visible_devices() + if visible_devices is None: + logical_deviceid = rank + else: + logical_deviceid = rank % len(visible_devices) if visible_devices is None: return logical_deviceid elif len(visible_devices) == 1: diff --git a/devito/passes/clusters/__init__.py b/devito/passes/clusters/__init__.py index c41a628e06..e27d2b755d 100644 --- a/devito/passes/clusters/__init__.py +++ b/devito/passes/clusters/__init__.py @@ -3,6 +3,7 @@ from .cse import * # noqa from .aliases import * # noqa from .factorization import * # noqa +from .fusion import * # noqa from .blocking import * # noqa from .asynchrony import * # noqa from .implicit import * # noqa diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 194d26523d..b8c4108b83 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -5,8 +5,8 @@ from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue -from devito.passes.clusters.misc import fuse -from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace +from devito.passes.clusters.fusion import fuse +from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, search, uxreplace from devito.tools import infer_dtype, timed_pass from devito.types import Eq, Inc, Indexed, Symbol @@ -15,13 +15,16 @@ @timed_pass() def lower_index_derivatives(clusters, mode=None, **kwargs): + max_depth = _max_index_derivative_depth(clusters) clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs) if not weights: return clusters if mode != 'noop': - clusters = fuse(clusters, toposort='maximal') + for _ in range(max_depth): + clusters = fuse(clusters, toposort='nofuse') + clusters = fuse(clusters, toposort=False) # At this point we can detect redundancies induced by inner derivatives that # previously were just not detectable via e.g. plain CSE. For example, if @@ -258,3 +261,16 @@ def callback(self, clusters, prefix, subs0=None, seen=None): seen.update(processed) return processed + + +# *** Utils + + +def _max_index_derivative_depth(clusters): + max_depth = 0 + + for c in clusters: + for i in search(c.exprs, IndexDerivative): + max_depth = max(max_depth, i.depth) + + return max_depth diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py new file mode 100644 index 0000000000..c8b6ac2cb1 --- /dev/null +++ b/devito/passes/clusters/fusion.py @@ -0,0 +1,335 @@ +from collections import Counter, defaultdict +from itertools import groupby + +from devito.finite_differences import IndexDerivative +from devito.ir.clusters import Cluster, ClusterGroup, Queue +from devito.ir.support import ( + InitArray, PrefetchUpdate, ReleaseLock, Scope, SyncArray, WaitLock, WithLock +) +from devito.symbolics import search +from devito.tools import ( + DAG, as_tuple, flatten, frozendict, memoized_func, memoized_meth, timed_pass +) + +__all__ = ['fuse'] + + +# No hazard: fusion may proceed. +NO_HAZARD = None +# Ordering hazard: preserve program order and forbid fusion. +EDGE = 'edge' +# Prefix anti-dependence: break the execution flow across the pair. +BREAK = 'break' + + +@memoized_func(scope='build') +def _fusion_hazards(scope0, scope1, prefix): + scope = Scope.from_scopes(scope0, scope1) + if scope is None: + return NO_HAZARD + + anti = False + for i in scope.d_anti_gen(): + if i.cause & prefix: + return BREAK + anti = True + + if anti: + return EDGE + + for i in scope.d_flow_gen(): + if not (i.cause & prefix): + return EDGE + + for _ in scope.d_output_gen(): + return EDGE + + return NO_HAZARD + + +class Fusion(Queue): + + """ + Fuse Clusters with compatible IterationSpace. + """ + + _q_guards_in_key = True + _q_syncs_in_key = True + + def __init__(self, toposort, options=None): + options = options or {} + + self.toposort = toposort + self.fusetasks = options.get('fuse-tasks', False) + + super().__init__() + + def process(self, clusters): + cgroups = [ClusterGroup(c, c.ispace) for c in clusters] + cgroups = self._process_fdta(cgroups, 1) + clusters = ClusterGroup.concatenate(*cgroups) + return clusters + + def callback(self, cgroups, prefix): + # Toposort to maximize fusion + if self.toposort: + clusters = self._toposort(cgroups, prefix) + if self.toposort == 'nofuse': + return [clusters] + else: + clusters = ClusterGroup(cgroups) + + # Fusion + processed = [] + for _, group in groupby(clusters, key=self._key): + g = list(group) + + for maybe_fusible in self._apply_heuristics(g): + try: + # Perform fusion + processed.append(Cluster.from_clusters(*maybe_fusible)) + except ValueError: + # We end up here if, for example, some Clusters have same + # iteration Dimensions but different (partial) orderings + processed.extend(maybe_fusible) + + # Maximize effectiveness of topo-sorting at next stage by only + # grouping together Clusters characterized by data dependencies + if self.toposort and prefix: + dag = self._build_dag(processed, prefix) + mapper = dag.connected_components(enumerated=True) + groups = groupby(processed, key=mapper.get) + return [ClusterGroup(tuple(g), prefix) for _, g in groups] + else: + return [ClusterGroup(processed, prefix)] + + class Key(tuple): + + """ + A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that + two Clusters (ClusterGroups) are topo-fusible if and only if their Key is + identical. + + A Key contains elements that can logically be split into two groups -- the + `strict` and the `weak` components of the Key. Two Clusters (ClusterGroups) + having same `strict` but different `weak` parts are, by definition, not + fusible; however, since at least their `strict` parts match, they can at + least be topologically reordered. + """ + + def __new__(cls, itintervals, guards, syncs, weak): + strict = [itintervals, guards, syncs] + obj = super().__new__(cls, strict + weak) + + obj.itintervals = itintervals + obj.guards = guards + obj.syncs = syncs + + obj.strict = tuple(strict) + obj.weak = tuple(weak) + + return obj + + @memoized_meth + def _key(self, c): + itintervals = frozenset(c.ispace.itintervals) + guards = c.guards if any(c.guards) else None + + # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and + # WithLocks, but not with any other SyncOps + mapper = defaultdict(set) + for d, v in c.syncs.items(): + for s in v: + if isinstance(s, PrefetchUpdate): + continue + elif isinstance(s, WaitLock) and not self.fusetasks: + # NOTE: A mix of Clusters w/ and w/o WaitLocks can safely + # be fused, as in the worst case scenario the WaitLocks + # get "hoisted" above the first Cluster in the sequence + continue + elif isinstance(s, (InitArray, SyncArray, WaitLock, ReleaseLock)): + mapper[d].add(type(s)) + elif isinstance(s, WithLock) and self.fusetasks: + # NOTE: Different WithLocks aren't fused unless the user + # explicitly asks for it + mapper[d].add(type(s)) + else: + mapper[d].add(s) + if d in mapper: + mapper[d] = frozenset(mapper[d]) + syncs = frozendict(mapper) + + # Clusters representing HaloTouches should get merged, if possible + weak = [c.is_halo_touch] + + # If there are writes to thread-shared object, make it part of the key. + # This will promote fusion of non-adjacent Clusters writing to (some + # form of) shared memory, which in turn will minimize the number of + # necessary barriers. Same story for reads from thread-shared objects + weak.extend([ + any(f._mem_shared for f in c.scope.writes), + any(f._mem_shared for f in c.scope.reads) + ]) + weak.append(c.properties.is_core_init()) + + # Prefetchable Clusters should get merged, if possible + weak.append(c.is_glb_load_to_mem_shared) + + # Promoting adjacency of IndexDerivatives will maximize their reuse + weak.append(any(search(c.exprs, IndexDerivative))) + + # Promote adjacency of Clusters with same guard + weak.append(c.guards) + + key = self.Key(itintervals, guards, syncs, weak) + + return key + + def _apply_heuristics(self, clusters): + # We know at this point that `clusters` are potentially fusible since + # they have same `_key`, but should we actually fuse them? In most cases + # yes, but there are exceptions... + + # 1) Consider the following scenario with three Clusters: + # c0[no syncs] + # c1[WaitLock] + # c2[no syncs] + # Then we return two groups [[c0], [c1, c2]] rather than a single group + # [[c0, c1, c2]] because this way c0 can be computed without having to + # wait on a lock for a longer period + processed = [] + + group = [] + flag = False # True -> need to dump before creating a new group + + def dump(): + processed.append(tuple(group)) + group[:] = [] + + for c in clusters: + if any(isinstance(i, WaitLock) for i in flatten(c.syncs.values())): + if flag: + dump() + flag = False + else: + flag = True + group.append(c) + dump() + + # 2) Don't group HaloTouch's + groups, processed = processed, [] + for group in groups: + for flag, minigroup in groupby(group, key=lambda c: c.is_wild): + if flag: + processed.extend([(c,) for c in minigroup]) + else: + processed.append(tuple(minigroup)) + + return processed + + def _toposort(self, cgroups, prefix): + # Are there any ClusterGroups that could potentially be topologically + # reordered? If not, do not waste time + counter = Counter(self._key(cg).strict for cg in cgroups) + if len(counter.most_common()) == 1 or \ + not any(v > 1 for it, v in counter.most_common()): + return ClusterGroup(cgroups, prefix) + + dag = self._build_dag(cgroups, prefix) + + def choose_element(queue, scheduled): + if not scheduled: + return queue.pop() + + k = self._key(scheduled[-1]) + m = {i: self._key(i) for i in queue} + + # Process the `strict` part of the key + candidates = [i for i in queue if m[i].itintervals == k.itintervals] + + compatible = [i for i in candidates if m[i].guards == k.guards] + candidates = compatible or candidates + + compatible = [i for i in candidates if m[i].syncs == k.syncs] + candidates = compatible or candidates + + # Process the `weak` part of the key + for i in range(len(k.weak), -1, -1): + choosable = [e for e in candidates if m[e].weak[:i] == k.weak[:i]] + try: + # Ensure stability + e = min(choosable, key=lambda i: cgroups.index(i)) + except ValueError: + continue + queue.remove(e) + return e + + # Fallback + e = min(queue, key=lambda i: cgroups.index(i)) + queue.remove(e) + return e + + return ClusterGroup(dag.topological_sort(choose_element), prefix) + + def _build_dag(self, cgroups, prefix): + """ + A DAG representing the data dependences across the ClusterGroups within + a given scope. + """ + prefix = frozenset(i.dim for i in as_tuple(prefix)) + + dag = DAG(nodes=cgroups) + for n, cg0 in enumerate(cgroups): + # Track whether there is any fence between `cg0` and the current `cg1`. + fenced = cg0.scope.has_barrier + + for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): + fenced = fenced or cg1.scope.has_barrier + + hazard = _fusion_hazards(cg0.scope, cg1.scope, prefix) + if not (hazard or fenced): + continue + + # Anti-dependences along `prefix` break the execution flow + # (intuitively, "the loop nests are to be kept separated") + # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` + # * All ClusterGroups after `cg1` cannot precede `cg1` + if hazard == BREAK: + for cg2 in cgroups[n:n1]: + dag.add_edge(cg2, cg1) + for cg2 in cgroups[n1+1:]: + dag.add_edge(cg1, cg2) + break + elif fenced or hazard == EDGE: + # Any anti- and iaw-dependences impose that `cg1` follows `cg0` + # and forbid any sort of fusion. Fences have the same effect + dag.add_edge(cg0, cg1) + + return dag + + +@timed_pass() +def fuse(clusters, toposort=False, options=None): + """ + Clusters fusion. + + If `toposort=True`, then the Clusters are reordered to maximize the likelihood + of fusion; the new ordering is computed such that all data dependencies are + honored. + + If `toposort='maximal'`, then `toposort` is performed, iteratively, multiple + times to actually maximize Clusters fusion. Hence, this is more aggressive than + `toposort=True`. + """ + if toposort != 'maximal': + return Fusion(toposort, options).process(clusters) + + nxt = clusters + while True: + nxt = fuse(clusters, toposort='nofuse', options=options) + if all(c0 is c1 for c0, c1 in zip(clusters, nxt, strict=True)): + break + clusters = nxt + clusters = fuse(clusters, toposort=False, options=options) + + return clusters diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 494ebe7490..0cd255ce97 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -1,18 +1,15 @@ -from collections import Counter, defaultdict from itertools import groupby, product -from devito.finite_differences import IndexDerivative -from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass +from devito.ir.clusters import Queue, cluster_pass from devito.ir.support import ( - SEPARABLE, SEQUENTIAL, InitArray, PrefetchUpdate, ReleaseLock, Scope, SyncArray, - WaitLock, WithLock + SEPARABLE, SEQUENTIAL, Scope ) from devito.passes.clusters.utils import in_critical_region -from devito.symbolics import pow_to_mul, search -from devito.tools import DAG, Stamp, as_tuple, flatten, frozendict, timed_pass +from devito.symbolics import pow_to_mul +from devito.tools import Stamp, flatten, frozendict, timed_pass from devito.types import Hyperplane -__all__ = ['Lift', 'fission', 'fuse', 'optimize_hyperplanes', 'optimize_pows'] +__all__ = ['Lift', 'fission', 'optimize_hyperplanes', 'optimize_pows'] class Lift(Queue): @@ -107,309 +104,12 @@ def callback(self, clusters, prefix): return lifted + processed -class Fusion(Queue): - - """ - Fuse Clusters with compatible IterationSpace. - """ - - _q_guards_in_key = True - _q_syncs_in_key = True - - def __init__(self, toposort, options=None): - options = options or {} - - self.toposort = toposort - self.fusetasks = options.get('fuse-tasks', False) - - super().__init__() - - def process(self, clusters): - cgroups = [ClusterGroup(c, c.ispace) for c in clusters] - cgroups = self._process_fdta(cgroups, 1) - clusters = ClusterGroup.concatenate(*cgroups) - return clusters - - def callback(self, cgroups, prefix): - # Toposort to maximize fusion - if self.toposort: - clusters = self._toposort(cgroups, prefix) - if self.toposort == 'nofuse': - return [clusters] - else: - clusters = ClusterGroup(cgroups) - - # Fusion - processed = [] - for _, group in groupby(clusters, key=self._key): - g = list(group) - - for maybe_fusible in self._apply_heuristics(g): - try: - # Perform fusion - processed.append(Cluster.from_clusters(*maybe_fusible)) - except ValueError: - # We end up here if, for example, some Clusters have same - # iteration Dimensions but different (partial) orderings - processed.extend(maybe_fusible) - - # Maximize effectiveness of topo-sorting at next stage by only - # grouping together Clusters characterized by data dependencies - if self.toposort and prefix: - dag = self._build_dag(processed, prefix) - mapper = dag.connected_components(enumerated=True) - groups = groupby(processed, key=mapper.get) - return [ClusterGroup(tuple(g), prefix) for _, g in groups] - else: - return [ClusterGroup(processed, prefix)] - - class Key(tuple): - - """ - A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that - two Clusters (ClusterGroups) are topo-fusible if and only if their Key is - identical. - - A Key contains elements that can logically be split into two groups -- the - `strict` and the `weak` components of the Key. Two Clusters (ClusterGroups) - having same `strict` but different `weak` parts are, by definition, not - fusible; however, since at least their `strict` parts match, they can at - least be topologically reordered. - """ - - def __new__(cls, itintervals, guards, syncs, weak): - strict = [itintervals, guards, syncs] - obj = super().__new__(cls, strict + weak) - - obj.itintervals = itintervals - obj.guards = guards - obj.syncs = syncs - - obj.strict = tuple(strict) - obj.weak = tuple(weak) - - return obj - - def _key(self, c): - itintervals = frozenset(c.ispace.itintervals) - guards = c.guards if any(c.guards) else None - - # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and - # WithLocks, but not with any other SyncOps - mapper = defaultdict(set) - for d, v in c.syncs.items(): - for s in v: - if isinstance(s, PrefetchUpdate): - continue - elif isinstance(s, WaitLock) and not self.fusetasks: - # NOTE: A mix of Clusters w/ and w/o WaitLocks can safely - # be fused, as in the worst case scenario the WaitLocks - # get "hoisted" above the first Cluster in the sequence - continue - elif isinstance(s, (InitArray, SyncArray, WaitLock, ReleaseLock)): - mapper[d].add(type(s)) - elif isinstance(s, WithLock) and self.fusetasks: - # NOTE: Different WithLocks aren't fused unless the user - # explicitly asks for it - mapper[d].add(type(s)) - else: - mapper[d].add(s) - if d in mapper: - mapper[d] = frozenset(mapper[d]) - syncs = frozendict(mapper) - - # Clusters representing HaloTouches should get merged, if possible - weak = [c.is_halo_touch] - - # If there are writes to thread-shared object, make it part of the key. - # This will promote fusion of non-adjacent Clusters writing to (some - # form of) shared memory, which in turn will minimize the number of - # necessary barriers. Same story for reads from thread-shared objects - weak.extend([ - any(f._mem_shared for f in c.scope.writes), - any(f._mem_shared for f in c.scope.reads) - ]) - weak.append(c.properties.is_core_init()) - - # Prefetchable Clusters should get merged, if possible - weak.append(c.is_glb_load_to_mem_shared) - - # Promoting adjacency of IndexDerivatives will maximize their reuse - weak.append(any(search(c.exprs, IndexDerivative))) - - # Promote adjacency of Clusters with same guard - weak.append(c.guards) - - key = self.Key(itintervals, guards, syncs, weak) - - return key - - def _apply_heuristics(self, clusters): - # We know at this point that `clusters` are potentially fusible since - # they have same `_key`, but should we actually fuse them? In most cases - # yes, but there are exceptions... - - # 1) Consider the following scenario with three Clusters: - # c0[no syncs] - # c1[WaitLock] - # c2[no syncs] - # Then we return two groups [[c0], [c1, c2]] rather than a single group - # [[c0, c1, c2]] because this way c0 can be computed without having to - # wait on a lock for a longer period - processed = [] - - group = [] - flag = False # True -> need to dump before creating a new group - - def dump(): - processed.append(tuple(group)) - group[:] = [] - - for c in clusters: - if any(isinstance(i, WaitLock) for i in flatten(c.syncs.values())): - if flag: - dump() - flag = False - else: - flag = True - group.append(c) - dump() - - # 2) Don't group HaloTouch's - - groups, processed = processed, [] - for group in groups: - for flag, minigroup in groupby(group, key=lambda c: c.is_wild): - if flag: - processed.extend([(c,) for c in minigroup]) - else: - processed.append(tuple(minigroup)) - - return processed - - def _toposort(self, cgroups, prefix): - # Are there any ClusterGroups that could potentially be topologically - # reordered? If not, do not waste time - counter = Counter(self._key(cg).strict for cg in cgroups) - if len(counter.most_common()) == 1 or \ - not any(v > 1 for it, v in counter.most_common()): - return ClusterGroup(cgroups, prefix) - - dag = self._build_dag(cgroups, prefix) - - def choose_element(queue, scheduled): - if not scheduled: - return queue.pop() - - k = self._key(scheduled[-1]) - m = {i: self._key(i) for i in queue} - - # Process the `strict` part of the key - candidates = [i for i in queue if m[i].itintervals == k.itintervals] - - compatible = [i for i in candidates if m[i].guards == k.guards] - candidates = compatible or candidates - - compatible = [i for i in candidates if m[i].syncs == k.syncs] - candidates = compatible or candidates - - # Process the `weak` part of the key - for i in range(len(k.weak), -1, -1): - choosable = [e for e in candidates if m[e].weak[:i] == k.weak[:i]] - try: - # Ensure stability - e = min(choosable, key=lambda i: cgroups.index(i)) - except ValueError: - continue - queue.remove(e) - return e - - # Fallback - e = min(queue, key=lambda i: cgroups.index(i)) - queue.remove(e) - return e - - return ClusterGroup(dag.topological_sort(choose_element), prefix) - - def _build_dag(self, cgroups, prefix): - """ - A DAG representing the data dependences across the ClusterGroups within - a given scope. - """ - prefix = {i.dim for i in as_tuple(prefix)} - - dag = DAG(nodes=cgroups) - for n, cg0 in enumerate(cgroups): - - def is_cross(source, sink): - # True if a cross-ClusterGroup dependence, False otherwise - t0 = source.timestamp - t1 = sink.timestamp - v = len(cg0.exprs) # noqa: B023 - return t0 < v <= t1 or t1 < v <= t0 - - for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): - - # A Scope to compute all cross-ClusterGroup anti-dependences - scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross) - - # Anti-dependences along `prefix` break the execution flow - # (intuitively, "the loop nests are to be kept separated") - # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` - # * All ClusterGroups after `cg1` cannot precede `cg1` - if any(i.cause & prefix for i in scope.d_anti_gen()): - for cg2 in cgroups[n:cgroups.index(cg1)]: - dag.add_edge(cg2, cg1) - for cg2 in cgroups[cgroups.index(cg1)+1:]: - dag.add_edge(cg1, cg2) - break - - # Any anti- and iaw-dependences impose that `cg1` follows `cg0` - # and forbid any sort of fusion. Fences have the same effect - elif ( - any(scope.d_anti_gen()) or - any(i.is_iaw for i in scope.d_output_gen()) or - any(c.is_fence for c in flatten(cgroups[n:n1+1])) - ) or any(not (i.cause and i.cause & prefix) for i in scope.d_flow_gen()) \ - or any(scope.d_output_gen()): - dag.add_edge(cg0, cg1) - - return dag - - -@timed_pass() -def fuse(clusters, toposort=False, options=None): - """ - Clusters fusion. - - If `toposort=True`, then the Clusters are reordered to maximize the likelihood - of fusion; the new ordering is computed such that all data dependencies are - honored. - - If `toposort='maximal'`, then `toposort` is performed, iteratively, multiple - times to actually maximize Clusters fusion. Hence, this is more aggressive than - `toposort=True`. - """ - if toposort != 'maximal': - return Fusion(toposort, options).process(clusters) - - nxt = clusters - while True: - nxt = fuse(clusters, toposort='nofuse', options=options) - if all(c0 is c1 for c0, c1 in zip(clusters, nxt, strict=True)): - break - clusters = nxt - clusters = fuse(clusters, toposort=False, options=options) - - return clusters - - @cluster_pass(mode='all') def optimize_pows(cluster, *args): """ Convert integer powers into Muls, such as ``a**2 => a*a``. """ - return cluster.rebuild(exprs=[pow_to_mul(e) for e in cluster.exprs]) + return cluster.rebuild(exprs=pow_to_mul(cluster.exprs)) class Fission(Queue): diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 2f1cce8f10..9fc20bfcee 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -444,7 +444,7 @@ def _inject_definitions(self, iet, storage): return processed, flatten(efuncs) - @iet_pass + @iet_pass(updates_args=True) def place_definitions(self, iet, globs=None, **kwargs): """ Create a new IET where all symbols have been declared, allocated, and diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 9b936fba76..3655e4a413 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -15,7 +15,10 @@ from devito.mpi.routines import Gather, HaloUpdate, HaloWait, MPIMsg, Scatter from devito.passes import needs_transfer from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search -from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass +from devito.tools import ( + DAG, as_hashable, as_tuple, filter_ordered, memoized_func, sorted_priority, + timed_pass +) from devito.types import ( Array, Auto, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol, @@ -25,7 +28,7 @@ from devito.types.dense import DiscreteFunction from devito.types.dimension import AbstractIncrDimension, BlockDimension -__all__ = ['Graph', 'iet_pass', 'iet_visit'] +__all__ = ['Graph', 'finalize_args', 'iet_pass', 'iet_visit'] class Byproduct: @@ -102,7 +105,7 @@ def sync_mapper(self): A mapper {Iteration -> SyncSpot} describing the Iterations, if any, living an asynchronous region, across all Callables in the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) mapper = MapNodes(SyncSpot, (Iteration, Call)).visit(self.root) @@ -125,16 +128,29 @@ def sync_mapper(self): return found - def apply(self, func, **kwargs): + def apply(self, func, *, updates_args=False, **kwargs): """ - Apply `func` to all nodes in the Graph. This changes the state of the Graph. + Apply ``func`` to all nodes in the Graph. + + Parameters + ---------- + updates_args : bool, optional + If True, reconcile Callable parameters and Call arguments before + the graph walk and after each changed node. This is only needed by + passes whose transformation logic depends on already-updated + signatures while the pass is still running. Otherwise, argument + reconciliation is intentionally deferred to ``finalize_args``. """ - dag = create_call_graph(self.root.name, self.efuncs) + if updates_args: + _update_args(self) + + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) # Apply `func` efuncs = dict(self.efuncs) for i in dag.topological_sort(): efunc, metadata = func(efuncs[i], **kwargs) + new_efuncs = metadata.get('efuncs', []) self.includes.extend(as_tuple(metadata.get('includes'))) self.headers.extend(as_tuple(metadata.get('headers'))) @@ -151,17 +167,14 @@ def apply(self, func, **kwargs): except KeyError: pass - if efunc is efuncs[i]: + if efunc is efuncs[i] and not new_efuncs: continue - new_efuncs = metadata.get('efuncs', []) - efuncs[i] = efunc efuncs.update(dict([(i.name, i) for i in new_efuncs])) - # Update the parameters / arguments lists since `func` may have - # introduced or removed objects - efuncs = update_args(efunc, efuncs, dag) + if updates_args: + efuncs = _update_args_efunc(efunc, efuncs, dag) # Minimize code size if len(efuncs) > len(self.efuncs): @@ -184,7 +197,7 @@ def visit(self, func, **kwargs): from nodes to info. Unlike `apply`, `visit` does not change the state of the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) toposort = dag.topological_sort() mapper = dict([(i, func(self.efuncs[i], **kwargs)) for i in toposort]) @@ -206,13 +219,46 @@ def filter(self, key): ) -def iet_pass(func): +@timed_pass(name='finalize_args') +def finalize_args(graph): + """ + Finalize Callable parameter lists and Call argument lists across ``graph``. + + IET passes may temporarily leave helper signatures stale while introducing + or eliminating symbols. This pass reconciles the whole call graph once, + after lowering has settled. + """ + _update_args(graph) + + +def _update_args(graph): + dag = create_call_graph(graph.root.name, as_hashable(graph.efuncs)) + + efuncs = graph.efuncs + for i in dag.topological_sort(): + efuncs = _update_args_efunc(efuncs[i], efuncs, dag) + + graph.efuncs = efuncs + + +def iet_pass(func=None, *, updates_args=False): + """ + Decorate an IET pass. + + ``updates_args=True`` is an opt-in for passes that must observe up-to-date + Callable/Call signatures before and during their own graph walk. Most + passes should leave it False and rely on ``finalize_args`` at the end of + IET lowering. + """ + if func is None: + return partial(iet_pass, updates_args=updates_args) + if isinstance(func, tuple): assert len(func) == 2 and func[0] is iet_visit call = lambda graph: graph.visit func = func[1] else: - call = lambda graph: graph.apply + call = lambda graph: partial(graph.apply, updates_args=updates_args) @wraps(func) def wrapper(*args, **kwargs): @@ -231,6 +277,7 @@ def wrapper(*args, **kwargs): # Instance method case self, graph = args return maybe_timed(call(graph), func.__name__)(partial(func, self), **kwargs) + return wrapper @@ -238,11 +285,14 @@ def iet_visit(func): return iet_pass((iet_visit, func)) +@memoized_func(scope='build') def create_call_graph(root, efuncs): """ Create a Call graph -- a Direct Acyclic Graph with edges from callees to callers. """ + efuncs = dict(efuncs) + dag = DAG(nodes=[root]) queue = [root] @@ -438,7 +488,7 @@ def reuse_efuncs(root, efuncs, sregistry=None): # assuming that `bar0` and `bar1` are compatible, we first process the # `bar`'s to obtain `[foo0(u(x)): bar0(u), foo1(u(x)): bar0(u)]`, # and finally `foo0(u(x)): bar0(u)` - dag = create_call_graph(root.name, efuncs) + dag = create_call_graph(root.name, as_hashable(efuncs)) mapper = {} for i in dag.topological_sort(): @@ -480,6 +530,7 @@ def reuse_efuncs(root, efuncs, sregistry=None): return retval +@memoized_func(scope='build') def abstract_efunc(efunc): """ Abstract `efunc` applying a set of rules: @@ -492,7 +543,7 @@ def abstract_efunc(efunc): """ functions = FindSymbols('basics|symbolics|dimensions').visit(efunc) - mapper = abstract_objects(functions) + mapper = abstract_objects(tuple(functions)) efunc = Uxreplace(mapper).visit(efunc) efunc = efunc._rebuild(name='foo') @@ -500,7 +551,8 @@ def abstract_efunc(efunc): return efunc -def abstract_objects(objects0, sregistry=None): +@memoized_func(scope='build') +def abstract_objects(objects0): """ Proxy for `abstract_object`. """ @@ -519,7 +571,7 @@ def abstract_objects(objects0, sregistry=None): # Build abstraction mappings mapper = {} - sregistry = sregistry or SymbolRegistry() + sregistry = SymbolRegistry() for i in objects: abstract_object(i, mapper, sregistry) @@ -690,7 +742,7 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='nthreads')) -def update_args(root, efuncs, dag): +def _update_args_efunc(root, efuncs, dag): """ Re-derive the parameters of `root` and apply the changes in cascade through the `efuncs`. @@ -780,6 +832,14 @@ def _filter(v, efunc=None): mapper = {c: c._rebuild(arguments=_filter(c.arguments)) for c in FindNodes(Call).visit(efuncs[n]) if c.name == root.name} - efuncs[n] = Transformer(mapper).visit(efuncs[n]) + if not mapper: + continue + + efunc = Transformer(mapper).visit(efuncs[n]) + if efunc is efuncs[n]: + continue + + efuncs[n] = efunc + efuncs = _update_args_efunc(efunc, efuncs, dag) return efuncs diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index aca2485444..78ec84987c 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -46,7 +46,7 @@ def linearize(graph, **kwargs): linearization(graph, key=key, tracker=tracker, **kwargs) -@iet_pass +@iet_pass(updates_args=True) def linearization(iet, key=None, tracker=None, **kwargs): """ Carry out the actual work of `linearize`. diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 3a3d354905..4e2c9a32ad 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -285,18 +285,17 @@ def _mark_overlappable(iet): # Comp/comm overlaps is legal only if the OWNED regions can grow # arbitrarily, which means all of the dependencies must be carried # along a non-halo Dimension - for dep in scope.d_all_gen(): - if dep.function in hs.functions: - cause = dep.cause & hs.dimensions - if any(dep.distance_mapper[d] is S.Infinity for d in cause): - # E.g., dependencies across PARALLEL iterations - # for x - # for y - # ... = ... f[x, y-1] ... - # for y - # f[x, y] = ... - test = False - break + for dep in scope.d_all_gen(writes=hs.functions): + cause = dep.cause & hs.dimensions + if any(dep.distance_mapper[d] is S.Infinity for d in cause): + # E.g., dependencies across PARALLEL iterations + # for x + # for y + # ... = ... f[x, y-1] ... + # for y + # f[x, y] = ... + test = False + break else: test = True @@ -493,7 +492,7 @@ def rule1(dep, loc_indices): for d, v in loc_indices.items()) for f, v in hsf.fmapper.items(): - for dep in scope.d_flow.project(f): + for dep in scope.d_flow_gen(writes=f): if not rule0(dep) and not rule1(dep, v.loc_indices): return False diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 57d9314e16..85ddfcaeab 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -64,6 +64,9 @@ def uxreplace(expr, rule): Finally, `uxreplace` supports Reconstructable objects, that is, it searches for replacement opportunities inside the Reconstructable's `__rkwargs__`. """ + if not rule: + return expr + return _uxreplace(expr, rule)[0] @@ -129,13 +132,15 @@ def _(iterable, rule): ax, flag = _uxreplace(a, rule) ret.append(ax) changed |= flag - return iterable.__class__(ret), changed + return (iterable.__class__(ret), True) if changed else (iterable, False) @_uxreplace_dispatch.register(EnrichedTuple) def _(iterable, rule): retval, changed = _uxreplace_dispatch(tuple(iterable), rule) - return iterable.__class__(*retval, getters=iterable.getters), changed + if changed: + return iterable.__class__(*retval, getters=iterable.getters), True + return iterable, False @_uxreplace_dispatch.register(dict) @@ -146,7 +151,7 @@ def _(mapper, rule): vx, flag = _uxreplace_dispatch(v, rule) ret[k] = vx changed |= flag - return ret, changed + return (ret, True) if changed else (mapper, False) @singledispatch @@ -282,10 +287,18 @@ def subs_if_composite(expr, subs): Indexed"). Instead, if `subs` consists of just "primitive" expressions, then resort to the much faster `uxreplace`. """ - if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): + if not subs: + return expr + + if type(expr) is tuple: + return reuse_if_untouched(expr, (subs_if_composite(e, subs) for e in expr)) + elif type(expr) is list: + return reuse_if_untouched(expr, [subs_if_composite(e, subs) for e in expr]) + elif all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): return uxreplace(expr, subs) else: - return expr.subs(subs) + processed = expr.subs(subs) + return expr if processed == expr else processed def xreplace_indices(exprs, mapper, key=None): @@ -304,14 +317,25 @@ def xreplace_indices(exprs, mapper, key=None): callable, apply the replacement to a symbol S if and only if ``key(S)`` gives True. """ - handle = flatten(retrieve_indexed(i) for i in as_tuple(exprs)) + exprs0 = as_tuple(exprs) + + handle = flatten(retrieve_indexed(i) for i in exprs0) if isinstance(key, Iterable): handle = [i for i in handle if i.base.label in key] elif callable(key): handle = [i for i in handle if key(i)] - mapper = dict(zip(handle, [i.xreplace(mapper) for i in handle], strict=True)) - replaced = [uxreplace(i, mapper) for i in as_tuple(exprs)] - return replaced if isinstance(exprs, Iterable) else replaced[0] + mapper = {i: v for i in handle if (v := i.xreplace(mapper)) != i} + if not mapper: + return exprs + + replaced = [uxreplace(i, mapper) for i in exprs0] + + if isinstance(exprs, Iterable): + if len(replaced) == len(exprs0) and all(i is j for i, j in zip(replaced, exprs0)): + return exprs + return replaced + else: + return replaced[0] def _eval_numbers(expr, args): @@ -344,7 +368,9 @@ def flatten_args(args, op, ignore=None): def pow_to_mul(expr): - if q_leaf(expr) or isinstance(expr, Basic): + if type(expr) in (tuple, list): + return reuse_if_untouched(expr, (pow_to_mul(i) for i in expr)) + elif q_leaf(expr) or isinstance(expr, Basic): return expr elif expr.is_Pow: base, exp = expr.as_base_exp() @@ -359,7 +385,7 @@ def pow_to_mul(expr): elif (int(exp) - exp != 0): # Fractional powers also remain untouched, # but at least we traverse the base looking for other Pows - return expr.func(pow_to_mul(base), exp, evaluate=False) + return reuse_if_untouched(expr, (pow_to_mul(base), exp), evaluate=False) elif exp > 0: return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False) elif exp < 0: @@ -383,7 +409,7 @@ def pow_to_mul(expr): except ValueError: pass - return expr.func(*args, evaluate=False) + return reuse_if_untouched(expr, args, evaluate=False) def indexify(expr): @@ -429,10 +455,21 @@ def normalize_args(args): def reuse_if_untouched(expr, args, evaluate=False): """ - Reconstruct `expr` iff any of the provided `args` is different than - the corresponding arg in `expr.args`. + Reconstruct `expr` iff any of the provided `args` is different from + the corresponding arg in `expr.args`, or from the corresponding item + for plain tuples/lists. """ - if all(a is b for a, b in zip(expr.args, args, strict=False)): + args = tuple(args) + + if type(expr) is tuple: + if len(args) == len(expr) and all(a is b for a, b in zip(expr, args)): + return expr + return args + elif type(expr) is list: + if len(args) == len(expr) and all(a is b for a, b in zip(expr, args)): + return expr + return list(args) + elif all(a is b for a, b in zip(expr.args, args, strict=False)): return expr else: return expr.func(*args, evaluate=evaluate) diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 9c30948064..55064cbc23 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -8,13 +8,14 @@ from devito.symbolics.queries import ( q_derivative, q_dimension, q_function, q_indexed, q_leaf, q_symbol, q_terminal ) -from devito.tools import as_tuple +from devito.tools import as_tuple, memoized_func __all__ = [ 'retrieve_derivatives', 'retrieve_dimensions', 'retrieve_function_carriers', 'retrieve_functions', + 'retrieve_accesses', 'retrieve_indexed', 'retrieve_symbols', 'retrieve_terminals', @@ -140,10 +141,19 @@ def retrieve_indexed(exprs, mode='all', deep=False): def retrieve_functions(exprs, mode='all', deep=False): """Shorthand to retrieve the DiscreteFunctions in `exprs`.""" - indexeds = search(exprs, q_indexed, mode, 'dfs', deep) + query = lambda i: q_function(i) or q_indexed(i) + found = search(exprs, query, 'all', 'dfs', deep) + + functions = modes[mode]() + indexed_functions = set() + + for i in found: + if q_function(i): + functions.add(i) if mode == 'unique' else functions.append(i) + else: + indexed_functions.add(i.function) - functions = search(exprs, q_function, mode, 'dfs', deep) - functions.update({i.function for i in indexeds}) + functions.update(indexed_functions) return functions @@ -177,6 +187,26 @@ def retrieve_terminals(exprs, mode='all', deep=False): return search(exprs, q_terminal, mode, 'dfs', deep) +@memoized_func(scope='build') +def retrieve_accesses(exprs, deep=False): + """ + Like retrieve_terminals, but ensure that if a ComponentAccess is found, + the ComponentAccess itself is returned, while the wrapped Indexed is discarded. + """ + from devito.symbolics.manipulation import uxreplace + from devito.types import ComponentAccess, Symbol + + compaccs = search(exprs, ComponentAccess) + if not compaccs: + return frozenset(retrieve_terminals(exprs, mode='unique', deep=deep)) + + subs = {i: Symbol(f'dummy{n}') for n, i in enumerate(compaccs)} + exprs1 = uxreplace(exprs, subs) + + return frozenset(compaccs | retrieve_terminals(exprs1, mode='unique', deep=deep) - + set(subs.values())) + + def retrieve_dimensions(exprs, mode='all', deep=False): """Shorthand to retrieve the dimensions in ``exprs``.""" return search(exprs, q_dimension, mode, 'dfs', deep) diff --git a/devito/tools/abc.py b/devito/tools/abc.py index 2e489ac3c0..814eabd7f2 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -1,5 +1,7 @@ from hashlib import sha1 +from devito.tools.memoization import cached_hash + __all__ = ['Pickable', 'Reconstructable', 'Signer', 'Singleton', 'Stamp', 'Tag'] @@ -34,6 +36,7 @@ def __gt__(self, other): def __ge__(self, other): return self.val >= other.val + @cached_hash def __hash__(self): return hash((self.name, self.val)) diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index d875878d02..2bf901c3f4 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -14,6 +14,7 @@ __all__ = [ 'DAG', 'Bunch', + 'DefaultFrozenDict', 'DefaultOrderedDict', 'EnrichedTuple', 'MemoryEstimate', @@ -672,6 +673,42 @@ def __hash__(self): return self._hash +class DefaultFrozenDict(frozendict): + """ + An immutable mapper that returns a configured default value for missing + keys when accessed via ``obj[key]``. + + Unlike :class:`collections.defaultdict`, the mapping remains immutable and + missing-key access does not mutate internal state. The ``get`` method + preserves the standard dictionary semantics, defaulting to ``None`` unless + the caller provides an explicit fallback. + """ + + _sentinel = object() + + def __init__(self, *args, default=_sentinel, **kwargs): + self._default = default + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + try: + return self._dict[key] + except KeyError: + if self._default is self._sentinel: + raise + + if callable(self._default): + return self._default() + else: + return self._default + + def get(self, key, default=None): + return self._dict.get(key, default) + + def copy(self, **add_or_replace): + return self.__class__(self, default=self._default, **add_or_replace) + + class MemoryEstimate(frozendict): """ An immutable mapper for a memory estimate, providing the estimated memory diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index c10f5ea092..cb36870e37 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -1,9 +1,37 @@ from collections.abc import Callable, Hashable -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from itertools import tee from typing import TypeVar +from weakref import WeakKeyDictionary -__all__ = ['CacheInstances', 'memoized_func', 'memoized_generator', 'memoized_meth'] +__all__ = [ + 'CacheInstances', + 'cached_hash', + 'memoized_func', + 'memoized_generator', + 'memoized_meth', + 'memoized_weak_meth', + 'reuse_if_unchanged' +] + + +def cached_hash(func): + """ + Cache an immutable object's ``__hash__`` return value in ``_mhash``. + + Warning: avoid explicitly calling a superclass' cached ``__hash__`` on a + subclass instance, as that would stash the superclass hash in ``_mhash``. + """ + @wraps(func) + def wrapper(self): + try: + return self._mhash + except AttributeError: + ret = func(self) + self._mhash = ret + return ret + + return wrapper class memoized_func: @@ -19,9 +47,22 @@ class memoized_func: https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize """ - def __init__(self, func): + # Long-lived caches for process-global helpers, such as arch discovery. + _scope_persistent = 'persistent' + # Build-scoped caches that may retain compiler inputs during Operator construction. + _scope_build = 'build' + _scoped_caches = {} + + def __new__(cls, func=None, *, scope=None): + if func is None: + return lambda f: cls(f, scope=scope) + return super().__new__(cls) + + def __init__(self, func, *, scope=None): self.func = func + self.scope = scope or self._scope_persistent self.cache = {} + self._scoped_caches.setdefault(self.scope, set()).add(self) def __call__(self, *args, **kw): if not isinstance(args, Hashable): @@ -44,6 +85,18 @@ def __get__(self, obj, objtype): """Support instance methods.""" return partial(self.__call__, obj) + def clear(self): + self.cache.clear() + + @classmethod + def clear_scoped_caches(cls, scope): + for cache in cls._scoped_caches.get(scope, ()): + cache.clear() + + @classmethod + def clear_build_caches(cls): + cls.clear_scoped_caches(cls._scope_build) + class memoized_meth: """ @@ -86,11 +139,19 @@ def __call__(self, *args, **kw): cache = obj.__cache_meth except AttributeError: cache = obj.__cache_meth = {} - key = (self.func, args[1:], frozenset(kw.items())) + if kw: + key = (self.func, args[1:], frozenset(kw.items())) + else: + key = (self.func, args[1:]) + try: res = cache[key] except KeyError: res = cache[key] = self.func(*args, **kw) + except TypeError: + # Uncacheable, e.g. an unhashable item within ``args``. + return self.func(*args, **kw) + return res @@ -128,6 +189,54 @@ def __call__(self, *args, **kwargs): return result +def memoized_weak_meth(*, key=None, freeze=None, thaw=None): + """ + Cache a method result against its first argument using weak references. + + This is useful for visitors operating on temporary IR roots: the cache can + be shared across short-lived visitor instances without keeping those roots + alive. Only calls without extra arguments are cached; all other calls fall + back to the wrapped method. + + Parameters + ---------- + key : callable, optional + A callable receiving ``self`` and returning a hashable cache partition. + freeze : callable, optional + Convert the method result before storing it in the cache. + thaw : callable, optional + Convert the cached value before returning it to the caller. + """ + def decorator(func): + caches = {} + + @wraps(func) + def wrapper(self, o, *args, **kwargs): + if args or kwargs: + return func(self, o, *args, **kwargs) + + try: + partition = key(self) if key is not None else None + cache = caches.setdefault(partition, WeakKeyDictionary()) + ret = cache[o] + except KeyError: + ret = func(self, o) + if freeze is not None: + ret = freeze(ret) + cache[o] = ret + except TypeError: + return func(self, o) + + if thaw is not None: + return thaw(ret) + + return ret + + return wrapper + + return decorator + + # Describes the type of a subclass of CacheInstances InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True) @@ -154,6 +263,9 @@ def __init__(cls: type[InstanceType], *args) -> None: # type: ignore def __call__(cls: type[InstanceType], # type: ignore *args, **kwargs) -> InstanceType: + if cls._instance_cache_size == 0: + return super().__call__(*args, **kwargs) + args, kwargs = cls._preprocess_args(*args, **kwargs) return cls._instance_cache(*args, **kwargs) @@ -173,7 +285,7 @@ class CacheInstances(metaclass=CacheInstancesMeta): """ _instance_cache: Callable | None = None - _instance_cache_size: int = 128 + _instance_cache_size: int = 8192 @classmethod def _preprocess_args(cls, *args, **kwargs): @@ -189,3 +301,35 @@ def clear_caches() -> None: Clears all IR instance caches. """ CacheInstancesMeta.clear_caches() + + +def reuse_if_unchanged(fields): + """ + Decorator for wrapper-style constructors that should return the original + object when called as ``Cls(existing_obj, **same_metadata)``. + + The wrapped callable is assumed to be a classmethod-like constructor + receiving ``cls`` as first argument. The fast path triggers only when: + + * the constructor is called with exactly one positional argument; + * that argument is already an exact instance of ``cls``; + * any explicitly provided metadata fields are the same objects as the + corresponding attributes on the input object. + """ + def decorator(func): + @wraps(func) + def wrapper(cls, *args, **kwargs): + if len(args) == 1: + input_obj = args[0] + if type(input_obj) is cls: + names = getattr(cls, fields) if isinstance(fields, str) else fields + for name in names: + if name in kwargs and kwargs[name] is not getattr(input_obj, name, None): + break + else: + return input_obj + return func(cls, *args, **kwargs) + + return wrapper + + return decorator diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 91b5bcdbf7..470be7e79e 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -1,6 +1,6 @@ import types from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from functools import reduce, wraps from itertools import chain, combinations, groupby, product, zip_longest from operator import attrgetter, mul @@ -10,6 +10,7 @@ __all__ = [ 'all_equal', + 'as_hashable', 'as_list', 'as_mapper', 'as_set', @@ -87,6 +88,28 @@ def as_tuple(item, type=None, length=None): return t +def as_hashable(item): + """ + Convert common containers into a hashable representation. + + Unknown unhashable objects fall back to identity, avoiding false cache hits. + """ + if isinstance(item, Mapping): + items = ((as_hashable(k), as_hashable(v)) for k, v in item.items()) + return tuple(sorted(items, key=repr)) + if isinstance(item, (tuple, list)): + return tuple(as_hashable(i) for i in item) + if isinstance(item, (set, frozenset)): + return tuple(sorted((as_hashable(i) for i in item), key=repr)) + + try: + hash(item) + except TypeError: + return (type(item), id(item)) + else: + return item + + def as_mapper(iterable, key=None, get=None): """ Rearrange an iterable into a dictionary of lists in which keys are diff --git a/devito/types/caching.py b/devito/types/caching.py index 742c0b3d33..9fbbcaa638 100644 --- a/devito/types/caching.py +++ b/devito/types/caching.py @@ -4,7 +4,7 @@ import sympy from sympy.core import cache -from devito.tools import safe_dict_copy +from devito.tools import memoized_func, safe_dict_copy __all__ = ['CacheManager', 'Cached', 'Uncached', '_SymbolCache'] @@ -175,6 +175,10 @@ def clear(cls, force=True): # SymPy 1.14 and later pass + # Drop compiler-scoped Python memoization that may still hold strong + # references to symbolic objects pending collection. + memoized_func.clear_build_caches() + # Take a copy of the dictionary so we can safely iterate over it # even if another thread is making changes cache_copied = safe_dict_copy(_SymbolCache) diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 18d0d3609c..4206f70057 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -10,7 +10,7 @@ from devito import Constant, Eq, Function, Grid, Operator, configuration, exp, log, sin from devito.arch.compiler import CustomCompiler, GNUCompiler from devito.exceptions import InvalidOperator -from devito.ir.cgen.printer import BasePrinter +from devito.ir.cgen.printer import BasePrinter, get_printer from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C import CBB, CPrinter from devito.passes.iet.languages.openacc import AccBB, AccPrinter @@ -204,6 +204,19 @@ def test_math_functions(dtype: np.dtype[np.inexact], assert call_str in str(op) +def test_printer_registry() -> None: + default = get_printer(CPrinter) + + assert get_printer(CPrinter) is default + assert get_printer(CPrinter, np.float32) is default + + float64 = get_printer(CPrinter, np.float64) + assert get_printer(CPrinter, np.float64) is float64 + + float16 = get_printer(CPrinter, np.float16) + assert get_printer(CPrinter, np.float16) is float16 + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: """ diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e84d0df5d8..4bde13a7fb 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -82,6 +82,8 @@ class TestDeviceID: @pytest.mark.parametrize('env_variables', [{"CUDA_VISIBLE_DEVICES": "1"}, {"CUDA_VISIBLE_DEVICES": "1,2"}, {"CUDA_VISIBLE_DEVICES": "1,0"}, + {"NVIDIA_VISIBLE_DEVICES": "1"}, + {"NVIDIA_VISIBLE_DEVICES": "1,2"}, {"ROCR_VISIBLE_DEVICES": "1"}, {"HIP_VISIBLE_DEVICES": " 1"}]) def test_visible_devices(self, env_variables): diff --git a/tests/test_iet.py b/tests/test_iet.py index e8e8f8444f..0c3f944a32 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -14,7 +14,8 @@ ElementalFunction, FindSymbols, Iteration, KernelLaunch, Lambda, List, Switch, Transformer, filter_iterations, make_efunc, retrieve_iteration_tree ) -from devito.passes.iet.engine import Graph +from devito.passes.iet import engine as iet_engine +from devito.passes.iet.engine import Graph, iet_pass from devito.passes.iet.languages.C import CDataManager from devito.symbolics import ( FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, Macro, SizeOf, @@ -539,6 +540,26 @@ def test_complex_array(): "float _Complex **restrict a_vec __attribute__ ((aligned (64)));" +def test_iet_pass_does_not_update_args(monkeypatch): + x = Symbol(name='x') + y = Symbol(name='y') + + foo = Callable('foo', DummyExpr(x, y), 'void', parameters=(x, y)) + graph = Graph(foo) + + @iet_pass + def inject_expr(iet): + body = iet.body._rebuild(body=iet.body.body + (DummyExpr(x, x),)) + return iet._rebuild(body=body), {} + + monkeypatch.setattr(iet_engine, '_update_args', + lambda *args, **kwargs: pytest.fail("_update_args called")) + + inject_expr(graph) + + assert graph.root.parameters is foo.parameters + + def test_special_array_definition(): class MyArray(Array): diff --git a/tests/test_ir.py b/tests/test_ir.py index 16440ec54a..7611d39dfc 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -7,6 +7,7 @@ Constant, Dimension, Eq, Function, Grid, Inc, Operator, SubDimension, TimeFunction, switchconfig ) +from devito.ir.clusters import Cluster, ClusterGroup from devito.ir.cgen import ccode from devito.ir.equations import LoweredEq from devito.ir.equations.algorithms import dimension_sort @@ -17,7 +18,8 @@ ) from devito.ir.support.guards import GuardOverflow from devito.ir.support.space import ( - Backward, Forward, Interval, IntervalGroup, IterationSpace, NullInterval + Backward, Forward, Interval, IntervalGroup, IterationInterval, IterationSpace, + NullInterval, null_ispace ) from devito.symbolics import DefFunction, FieldFromPointer from devito.tools import prod @@ -140,6 +142,12 @@ def test_vector_cmp(self, v_num, v_literal): assert v2 <= vs3 assert vs3 > v2 + def test_timedaccess_cached(self, fc, x, y): + ta0 = TimedAccess(fc[x, y], 'R', 0) + ta1 = TimedAccess(fc[x, y], 'R', 0, null_ispace) + + assert ta0 is ta1 + def test_iteration_instance_arithmetic(self, x, y, ii_num, ii_literal): """ Test arithmetic operations involving objects of type IterationInstance. @@ -359,6 +367,60 @@ def x(self, grid): def y(self, grid): return grid.dimensions[1] + def test_null_interval_cache_identity(self, x): + i0 = NullInterval(x) + i1 = NullInterval(x) + + assert i0 is i1 + + def test_interval_cache_identity(self, x): + i0 = Interval(x, -2, 2) + i1 = Interval(x, -2, 2) + + assert i0 is i1 + + def test_iteration_interval_cache_identity(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + i0 = IterationInterval(Interval(x), (xi,), Forward) + i1 = IterationInterval(Interval(x), (xi,), Forward) + + assert i0 is i1 + + def test_iteration_interval_cache_distinguishes_sub_iterators(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + i0 = IterationInterval(Interval(x), (xi,), Forward) + i1 = IterationInterval(Interval(x), (), Forward) + + assert i0 is not i1 + + def test_interval_group_cache_identity(self, x, y): + ig0 = IntervalGroup([Interval(x, -2, 2), Interval(y, -1, 1)], + relations=((x, y),), mode='partial') + ig1 = IntervalGroup((Interval(x, -2, 2), Interval(y, -1, 1)), + relations=((x, y),), mode='partial') + + assert ig0 is ig1 + + def test_iteration_space_cache_identity(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + ispace0 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + ispace1 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + + assert ispace0 is ispace1 + assert isinstance(ispace0[x], IterationInterval) + assert ispace0[x] is ispace1[x] + + def test_iteration_space_cache_distinguishes_sub_iterators(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + ispace0 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + ispace1 = IterationSpace([Interval(x)], directions={x: Forward}) + + assert ispace0 is not ispace1 + def test_intervals_intersection(self, x, y): nullx = NullInterval(x) @@ -788,6 +850,18 @@ def test_indirect_access(self): v = scope.d_flow.pop() assert v.function is s1 + def test_ireq_function_views_indirect_indices(self): + grid = Grid(shape=(4,)) + x, = grid.dimensions + + u = Function(name='u', grid=grid) + f = Function(name='f', grid=grid) + a = Function(name='a', grid=grid) + + expr = LoweredEq(Eq(u, f[a[x]])) + + assert set(expr.read_functions) == {f, a} + def test_array_shared(self): grid = Grid(shape=(4, 4)) x, y = grid.dimensions @@ -1088,6 +1162,25 @@ def test_dimension_sort(self, expr, expected): assert list(dimension_sort(expr)) == eval(expected) +class TestClusterGroup: + + def test_eq_hash_include_ispace(self): + grid = Grid(shape=(4,)) + x, = grid.dimensions + + f = Function(name='f', grid=grid) + cluster = Cluster(Eq(f[x], 1)) + + ispace0 = IterationSpace([Interval(x, 0, 0)], directions={x: Forward}) + ispace1 = IterationSpace([Interval(x, 0, 0)], directions={x: Backward}) + + cgroup0 = ClusterGroup((cluster,), ispace0) + cgroup1 = ClusterGroup((cluster,), ispace1) + + assert cgroup0 != cgroup1 + assert len({cgroup0, cgroup1}) == 2 + + class TestGuards: def test_guard_overflow(self): diff --git a/tests/test_mpi.py b/tests/test_mpi.py index 7746e06155..465a937f6e 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -3294,7 +3294,8 @@ def run_adjoint_F(self, nd): @pytest.mark.parametrize('nd', [1, 2, 3]) @pytest.mark.parallel(mode=[(4, 'basic'), (4, 'diag'), (4, 'overlap'), - (4, 'overlap2'), (4, 'full')]) + (4, 'overlap2'), (4, 'full')], + timeout=600) def test_adjoint_F(self, nd, mode): self.run_adjoint_F(nd) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 49f134148f..1f500e939d 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -18,7 +18,7 @@ INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite, FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue, SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions, - retrieve_indexed, uxreplace + retrieve_indexed, subs_if_composite, uxreplace, xreplace_indices ) from devito.tools import CustomDtype, as_tuple from devito.types import ( @@ -841,6 +841,18 @@ def test_is_on_grid(): assert all(uu._grid_map == {} for uu in retrieve_functions(u.subs({x: x0}).evaluate)) +def test_retrieve_functions_mixed_carriers(): + grid = Grid((10,)) + x = grid.dimensions[0] + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + expr = f + FIndexed(g.base, x) + + assert retrieve_functions(expr, mode='unique') == {f, g} + + @pytest.mark.parametrize('expr,expected', [ ('f[x+2]*g[x+4] + f[x+3]*g[x+5] + f[x+4] + f[x+1]', ['f[x+2]', 'g[x+4]', 'f[x+3]', 'g[x+5]', 'f[x+1]', 'f[x+4]']), @@ -898,6 +910,55 @@ def test_expressions(self, expr, subs, expected): assert uxreplace(eval(expr), eval(subs)) == eval(expected) + def test_uxreplace_reuses_empty_substitution(self): + grid = Grid(shape=(4, 4)) + f = Function(name='f', grid=grid) + expr = f.indexify() + 1 + + assert uxreplace(expr, {}) is expr + + def test_subs_if_composite_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert subs_if_composite(exprs, {}) is exprs + assert subs_if_composite(exprs, {g[x, y]: f[x, y]}) is exprs + assert subs_if_composite(exprs, {g[x, y] + 1: f[x, y]}) is exprs + + processed = subs_if_composite(exprs, {f[x, y]: g[x, y]}) + + assert processed is not exprs + assert processed[0] is not exprs[0] + + def test_pow_to_mul_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + f = Function(name='f', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert pow_to_mul(exprs) is exprs + assert pow_to_mul([exprs[0]])[0] is exprs[0] + + processed = pow_to_mul((Eq(f[x, y], f[x, y]**2),)) + + assert processed is not exprs + + def test_xreplace_indices_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + z = Dimension(name='z') + f = Function(name='f', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert xreplace_indices(exprs, {z: z + 1}) is exprs + assert xreplace_indices(exprs, {x: x + 1}) is not exprs + def test_custom_reconstructable(self): class MyDefFunction(DefFunction): diff --git a/tests/test_tools.py b/tests/test_tools.py index 0b06883e78..e1943fe802 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -7,8 +7,9 @@ from devito import Eq, Operator, switchenv from devito.tools import ( + DefaultFrozenDict, CacheInstances, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, filter_ordered, - toposort, transitive_closure + memoized_meth, memoized_weak_meth, toposort, transitive_closure ) from devito.types.basic import Symbol @@ -61,6 +62,93 @@ def test_transitive_closure(): assert mapper == {a: d, b: d, c: d, f: e} +def test_memoized_meth(): + + class Obj: + + def __init__(self): + self.calls = 0 + + @memoized_meth + def f(self, x=None): + self.calls += 1 + return x + + obj = Obj() + + assert obj.f(1) == 1 + assert obj.f(1) == 1 + assert obj.calls == 1 + + assert obj.f(x=2) == 2 + assert obj.f(x=2) == 2 + assert obj.calls == 2 + + assert obj.f([3]) == [3] + assert obj.f([3]) == [3] + assert obj.calls == 4 + + +def test_memoized_weak_meth(): + + class Root: + pass + + class Obj: + + def __init__(self, mode): + self.mode = mode + self.calls = 0 + + @memoized_weak_meth(key=lambda i: i.mode, freeze=tuple, thaw=list) + def f(self, root): + self.calls += 1 + return [self.mode] + + root = Root() + obj0 = Obj('a') + obj1 = Obj('a') + obj2 = Obj('b') + + ret = obj0.f(root) + ret.append('mutated') + + assert obj1.f(root) == ['a'] + assert obj0.calls == 1 + assert obj1.calls == 0 + + assert obj2.f(root) == ['b'] + assert obj2.calls == 1 + + assert obj0.f([]) == ['a'] + assert obj0.f([]) == ['a'] + assert obj0.calls == 3 + + +def test_default_frozen_dict(): + mapper = DefaultFrozenDict({'a': 'b'}, default='c') + + assert mapper['a'] == 'b' + assert mapper['d'] == 'c' + assert mapper.get('d') is None + assert mapper.get('d', 'e') == 'e' + + copied = mapper.copy(c='d') + assert copied['c'] == 'd' + assert copied['e'] == 'c' + + +def test_default_frozen_dict_factory(): + mapper = DefaultFrozenDict(default=lambda: []) + + v0 = mapper[a] + v1 = mapper[b] + + assert v0 == [] + assert v1 == [] + assert v0 is not v1 + + def test_loops_in_transitive_closure(): a = Symbol('a') b = Symbol('b') @@ -212,6 +300,30 @@ def __init__(self, value: int): cache_size = Object._instance_cache.cache_info()[-1] assert cache_size == 0 + def test_uncached_subclass_bypasses_parent_preprocess(self): + """ + Tests that an uncached subclass does not inherit its parent's + preprocessing contract. + """ + class Parent(CacheInstances): + @classmethod + def _preprocess_args(cls, value): + return (value + 1,), {} + + def __init__(self, value: int): + self.value = value + + class Child(Parent): + _instance_cache_size = 0 + + def __init__(self, left: int, right: int): + self.value = (left, right) + + obj0 = Child(1, 2) + obj1 = Child(1, 2) + + assert obj0.value == (1, 2) + assert obj0 is not obj1 def test_switchenv(): # Save previous environment diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 06eb933351..b5d12f81d5 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -6,8 +6,8 @@ from devito.ir.equations import DummyEq from devito.ir.iet import ( Block, Call, Callable, Conditional, Expression, FindApplications, FindNodes, - FindSections, FindSymbols, IsPerfectIteration, Iteration, MapNodes, Transformer, - printAST + FindSections, FindSymbols, FindWithin, IsPerfectIteration, Iteration, MapNodes, + Transformer, Uxreplace, printAST ) from devito.types import Array, SpaceDimension, Symbol @@ -210,6 +210,15 @@ def test_find_sections(exprs, block1, block2, block3): assert len(found[2]) == 1 +def test_find_within_not_cached_like_findnodes(block3): + expr0 = FindWithin(Expression, block3.nodes[0], block3.nodes[1]).visit(block3) + expr1 = FindWithin(Expression, block3.nodes[1], block3.nodes[2]).visit(block3) + + assert len(expr0) == 3 + assert len(expr1) == 3 + assert expr0 != expr1 + + def test_is_perfect_iteration(block1, block2, block3, block4): checker = IsPerfectIteration() @@ -249,6 +258,14 @@ def test_transformer_wrap(exprs, block1, block2, block3): assert "a[i] = a[i] + b[i] + 5.0F;" in newcode +def test_transformer_reuses_untouched_node(block1): + assert Transformer({}).visit(block1) is block1 + + +def test_uxreplace_reuses_untouched_node(block1): + assert Uxreplace({}).visit(block1) is block1 + + def test_transformer_replace(exprs, block1, block2, block3): """Basic transformer test that replaces an expression""" line1 = '// Replaced expression'