From 0df2f90bb38b819b435d7c2562caae5fa58b2f1d Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Mon, 22 Jun 2026 15:31:30 -0700 Subject: [PATCH] eval_into asynchronous kernels Gather kernels using an `eval_into(result, expression)` builder pattern. The kernels can be streamed asynchronously or merged into one large kernel. --- examples/hasegawa-wakatani/hw.cxx | 13 +- include/bout/fieldops.hxx | 271 ++++++++++++++++++++++++++-- tests/unit/field/test_field2d.cxx | 24 +++ tests/unit/field/test_field3d.cxx | 27 +++ tests/unit/field/test_fieldperp.cxx | 25 +++ 5 files changed, 341 insertions(+), 19 deletions(-) diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index c3f8717597..a675120dcf 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -1,5 +1,7 @@ #include +#include +#include #include #include #include @@ -109,10 +111,13 @@ class HW : public PhysicsModel { nonzonal_phi -= averageY(DC(phi)); } - ddt(n) = - -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi); - - ddt(vort) = -bracket(phi, vort, bm) + alpha * (nonzonal_phi - nonzonal_n); + // Two kernels can be evaluated asynchronously + eval_into(ddt(n), // Density equation + -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) + - kappa * DDZ(phi)) + .eval_into(ddt(vort), // Vorticity equation + -bracket(phi, vort, bm) + alpha * (nonzonal_phi - nonzonal_n)) + .stream(); return 0; } diff --git a/include/bout/fieldops.hxx b/include/bout/fieldops.hxx index 53e28042c9..bf11d7b2bc 100644 --- a/include/bout/fieldops.hxx +++ b/include/bout/fieldops.hxx @@ -12,7 +12,10 @@ #include #include #include +#include #include +#include +#include #if BOUT_HAS_CUDA #include @@ -298,7 +301,160 @@ struct StreamsRAII { StreamsRAII& operator=(StreamsRAII&&) = delete; }; inline struct StreamsRAII streams; + +struct BorrowedStreams { + std::vector borrowed; + + cudaStream_t acquire() { + auto stream = streams.get(); + borrowed.push_back(stream); + return stream; + } + + void synchronize() { + for (auto& stream : borrowed) { + cudaStreamSynchronize(stream); + } + } + + ~BorrowedStreams() { + for (auto& stream : borrowed) { + streams.put(stream); + } + } + + BorrowedStreams() = default; + BorrowedStreams(const BorrowedStreams&) = delete; + BorrowedStreams(BorrowedStreams&&) = delete; + BorrowedStreams& operator=(const BorrowedStreams&) = delete; + BorrowedStreams& operator=(BorrowedStreams&&) = delete; +}; +#endif + +template +void launchExprView(BoutReal* out, const ExprView& expr_view +#if BOUT_HAS_CUDA && defined(__CUDACC__) + , + cudaStream_t stream +#endif +) { + if (expr_view.size() == 0) { + return; + } + +#if BOUT_HAS_CUDA && defined(__CUDACC__) + int blocks = (expr_view.size() + THREADS - 1) / THREADS; + evaluatorExpr<<>>(out, expr_view); +#else + int e = expr_view.size(); + for (int i = 0; i < e; ++i) { + const int idx = expr_view.regionIdx(i); + out[idx] = expr_view(idx); + } +#endif +} + +template +void launchExprAsync(BoutReal* out, const Expr& expr +#if BOUT_HAS_CUDA && defined(__CUDACC__) + , + cudaStream_t stream +#endif +) { + launchExprView(out, static_cast(expr) +#if BOUT_HAS_CUDA && defined(__CUDACC__) + , + stream +#endif + ); +} + +template +void launchExprSync(BoutReal* out, const Expr& expr) { +#if BOUT_HAS_CUDA && defined(__CUDACC__) + auto stream = streams.get(); + launchExprAsync(out, expr, stream); + cudaStreamSynchronize(stream); + streams.put(stream); +#else + launchExprAsync(out, expr); #endif +} + +namespace bout::detail { + +template +inline constexpr bool is_eval_result_v = + std::is_same_v, Field2D> || std::is_same_v, Field3D> + || std::is_same_v, FieldPerp>; + +template +inline constexpr bool is_eval_compatible_v = + (std::is_same_v, Field3D> && is_expr_field3d_v) + || (std::is_same_v, Field2D> && is_expr_field2d_v) + || (std::is_same_v, FieldPerp> && is_expr_fieldperp_v); + +template +inline constexpr bool is_materialized_eval_expr_v = + std::is_same_v, Field3D> + || std::is_same_v, Field2D> + || std::is_same_v, FieldPerp>; + +template +void resetEvalResult(Result& result, const Expr& expr) { + using ResultType = std::decay_t; + + if constexpr (std::is_same_v) { + result = Field3D{expr.getMesh(), expr.getLocation(), expr.getDirections(), + expr.getRegionID()}; + } else if constexpr (std::is_same_v) { + result = Field2D{expr.getMesh(), expr.getLocation(), expr.getDirections(), + expr.getRegionID()}; + } else if constexpr (std::is_same_v) { + result = FieldPerp{expr.getMesh(), expr.getLocation(), expr.getIndex(), + expr.getDirections(), expr.getRegionID()}; + } else { + static_assert(is_eval_result_v, "Unsupported eval_into result type"); + } +} + +template +void prepareEvalResult(Result& result, const Expr& expr) { + if (!result.isAllocated() || result.getMesh() != expr.getMesh()) { + resetEvalResult(result, expr); + } + + if constexpr (std::is_same_v, Field3D>) { + result.clearParallelSlices(); + result.setRegion(expr.getRegionID()); + } + + result.setLocation(expr.getLocation()); + result.setDirections(expr.getDirections()); + + if constexpr (std::is_same_v, FieldPerp>) { + result.setIndex(expr.getIndex()); + } + + result.allocate(); +} + +template +BoutReal* evalResultData(Result& result) { + return static_cast::View>(result).data; +} + +template +void executeEvalTask(Result& result, const Expr& expr) { + if constexpr (is_materialized_eval_expr_v) { + result = expr; + } else { + prepareEvalResult(result, expr); + launchExprSync(evalResultData(result), expr); + } +} + +} // namespace bout::detail template auto reduceExpr(const ExprView& expr_view) -> typename Reducer::State { @@ -356,6 +512,8 @@ struct BinaryExpr { : lhs(lhs), rhs(rhs), indices(indices), f(f), mesh(mesh), location(location), directions(directions), regionID(regionID), yindex(yindex) {} + BinaryExpr(const BinaryExpr&) = default; + BinaryExpr(BinaryExpr&&) = default; BinaryExpr& operator=(const BinaryExpr&) = delete; BinaryExpr& operator=(BinaryExpr&&) = delete; @@ -402,21 +560,7 @@ struct BinaryExpr { operator View() { return View{lhs, rhs, &indices[0], indices.size(), f}; } operator View() const { return View{lhs, rhs, &indices[0], indices.size(), f}; } - void evaluate(BoutReal* data) const { -#if BOUT_HAS_CUDA && defined(__CUDACC__) - cudaStream_t stream = streams.get(); - int blocks = (size() + THREADS - 1) / THREADS; - evaluatorExpr<<>>(&data[0], static_cast(*this)); - cudaStreamSynchronize(stream); - streams.put(stream); -#else - int e = size(); - for (int i = 0; i < e; ++i) { - int idx = regionIdx(i); - data[idx] = operator()(idx); // single‐pass fusion - } -#endif - } + void evaluate(BoutReal* data) const { launchExprSync(&data[0], *this); } Mesh* getMesh() const { return mesh; } CELL_LOC getLocation() const { return location; } @@ -425,4 +569,101 @@ struct BinaryExpr { int getIndex() const { return yindex.value_or(-1); } }; +template +struct EvalTask { + Result* result; + std::decay_t expr; +}; + +template +struct EvalBuilder { + std::tuple tasks; + + template + auto eval_into(Result& result, Expr&& expr) && { + using ExprType = std::decay_t; + static_assert(bout::detail::is_eval_result_v, + "eval_into only supports Field2D, Field3D, and FieldPerp results"); + static_assert(bout::detail::is_eval_compatible_v, + "eval_into result type does not match the expression family"); + + using Task = EvalTask, ExprType>; + return EvalBuilder{std::tuple_cat( + std::move(tasks), std::make_tuple(Task{&result, std::forward(expr)}))}; + } + + template + auto eval_into(Result& result, Expr&& expr) const& { + using ExprType = std::decay_t; + static_assert(bout::detail::is_eval_result_v, + "eval_into only supports Field2D, Field3D, and FieldPerp results"); + static_assert(bout::detail::is_eval_compatible_v, + "eval_into result type does not match the expression family"); + + using Task = EvalTask, ExprType>; + return EvalBuilder{ + std::tuple_cat(tasks, std::make_tuple(Task{&result, std::forward(expr)}))}; + } + + // Prototype entry point: this currently shares the stream execution path + // until a fused multi-output kernel is added. + void merge() && { stream_impl(); } + void merge() const& { stream_impl(); } + + void stream() && { stream_impl(); } + void stream() const& { stream_impl(); } + +private: + void stream_impl() const { +#if BOUT_HAS_CUDA && defined(__CUDACC__) + std::apply( + [](auto&... task) { + (([&] { + if constexpr (!bout::detail::is_materialized_eval_expr_v< + decltype(task.expr)>) { + bout::detail::prepareEvalResult(*task.result, task.expr); + } + }()), + ...); + }, + tasks); + + BorrowedStreams borrowed_streams; + std::apply( + [&](auto&... task) { + (([&] { + if constexpr (bout::detail::is_materialized_eval_expr_v< + decltype(task.expr)>) { + *task.result = task.expr; + } else { + launchExprAsync(bout::detail::evalResultData(*task.result), task.expr, + borrowed_streams.acquire()); + } + }()), + ...); + }, + tasks); + borrowed_streams.synchronize(); +#else + std::apply( + [](auto&... task) { + ((bout::detail::executeEvalTask(*task.result, task.expr)), ...); + }, + tasks); +#endif + } +}; + +template +auto eval_into(Result& result, Expr&& expr) { + using ExprType = std::decay_t; + static_assert(bout::detail::is_eval_result_v, + "eval_into only supports Field2D, Field3D, and FieldPerp results"); + static_assert(bout::detail::is_eval_compatible_v, + "eval_into result type does not match the expression family"); + + using Task = EvalTask, ExprType>; + return EvalBuilder{std::make_tuple(Task{&result, std::forward(expr)})}; +} + #endif // BOUT_FIELDSOPS_HXX diff --git a/tests/unit/field/test_field2d.cxx b/tests/unit/field/test_field2d.cxx index a91acd4a40..02021039a4 100644 --- a/tests/unit/field/test_field2d.cxx +++ b/tests/unit/field/test_field2d.cxx @@ -1560,4 +1560,28 @@ TEST_F(Field2DTest, DC) { EXPECT_EQ(DC(field), field); } +TEST_F(Field2DTest, EvalIntoMergeChainsBinaryExprs) { + Field2D a{ + mesh_staggered, CELL_XLOW, {YDirectionType::Aligned, ZDirectionType::Average}}; + Field2D b{a}; + Field2D c{a}; + Field2D d{a}; + + a = 1.0; + b = 2.0; + c = 3.0; + d = 4.0; + + Field2D first; + Field2D second; + + eval_into(first, a + b * c).eval_into(second, b + d).merge(); + + EXPECT_TRUE(IsFieldEqual(first, 7.0)); + EXPECT_TRUE(IsFieldEqual(second, 6.0)); + EXPECT_EQ(first.getLocation(), a.getLocation()); + EXPECT_EQ(first.getDirectionY(), a.getDirectionY()); + EXPECT_EQ(first.getDirectionZ(), a.getDirectionZ()); +} + #pragma GCC diagnostic pop diff --git a/tests/unit/field/test_field3d.cxx b/tests/unit/field/test_field3d.cxx index 905b182018..997b85bb90 100644 --- a/tests/unit/field/test_field3d.cxx +++ b/tests/unit/field/test_field3d.cxx @@ -2590,5 +2590,32 @@ TEST_F(Field3DTest, Field3DParallel) { EXPECT_TRUE(IsFieldEqual(field3, 6.0)); } +TEST_F(Field3DTest, EvalIntoStreamChainsBinaryExprs) { + Field3D a{ + mesh_staggered, CELL_XLOW, {YDirectionType::Aligned, ZDirectionType::Average}}; + Field3D b{a}; + Field3D c{a}; + Field3D d{a}; + + a = 1.0; + b = 2.0; + c = 3.0; + d = 4.0; + + Field3D first; + Field3D second; + first = 0.0; + first.splitParallelSlices(); + + eval_into(first, a + b * c).eval_into(second, b + d).stream(); + + EXPECT_TRUE(IsFieldEqual(first, 7.0)); + EXPECT_TRUE(IsFieldEqual(second, 6.0)); + EXPECT_EQ(first.getLocation(), a.getLocation()); + EXPECT_EQ(first.getDirectionY(), a.getDirectionY()); + EXPECT_EQ(first.getDirectionZ(), a.getDirectionZ()); + EXPECT_FALSE(first.hasParallelSlices()); +} + // Restore compiler warnings #pragma GCC diagnostic pop diff --git a/tests/unit/field/test_fieldperp.cxx b/tests/unit/field/test_fieldperp.cxx index 46f07d589f..e62b0022af 100644 --- a/tests/unit/field/test_fieldperp.cxx +++ b/tests/unit/field/test_fieldperp.cxx @@ -1899,4 +1899,29 @@ TEST_F(FieldPerpTest, Inequality) { EXPECT_FALSE(field1 == field5); } +TEST_F(FieldPerpTest, EvalIntoStreamChainsBinaryExprs) { + FieldPerp a{ + mesh_staggered, CELL_XLOW, 3, {YDirectionType::Aligned, ZDirectionType::Average}}; + FieldPerp b{a}; + FieldPerp c{a}; + FieldPerp d{a}; + + a = 1.0; + b = 2.0; + c = 3.0; + d = 4.0; + + FieldPerp first; + FieldPerp second; + + eval_into(first, a + b * c).eval_into(second, b + d).stream(); + + EXPECT_TRUE(IsFieldEqual(first, 7.0)); + EXPECT_TRUE(IsFieldEqual(second, 6.0)); + EXPECT_EQ(first.getLocation(), a.getLocation()); + EXPECT_EQ(first.getIndex(), a.getIndex()); + EXPECT_EQ(first.getDirectionY(), a.getDirectionY()); + EXPECT_EQ(first.getDirectionZ(), a.getDirectionZ()); +} + #pragma GCC diagnostic pop