From cf77bdbd23a43303e0a4d3234c2d062d14a9a2aa Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jan 2026 02:08:34 -0800 Subject: [PATCH 01/42] first working dispatch and combine primitive for k=1 --- CMakeLists.txt | 2 + csrc/dispatch.h | 2 + csrc/host_ir/evaluator.cpp | 63 +++++ csrc/host_ir/evaluator.h | 2 + csrc/multidevice/communication.cpp | 161 +++++++++++ csrc/multidevice/communication.h | 162 +++++++++++ csrc/multidevice/dispatch_combine.cpp | 267 ++++++++++++++++++ csrc/multidevice/dispatch_combine.h | 51 ++++ .../cpp/test_multidevice_dispatch_combine.cpp | 121 ++++++++ 9 files changed, 831 insertions(+) create mode 100644 csrc/multidevice/dispatch_combine.cpp create mode 100644 csrc/multidevice/dispatch_combine.h create mode 100644 tests/cpp/test_multidevice_dispatch_combine.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 13dd918282b..b325b325d9c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -235,6 +235,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/communication.cpp ${NVFUSER_SRCS_DIR}/multidevice/communicator.cpp ${NVFUSER_SRCS_DIR}/multidevice/cuda_p2p.cpp + ${NVFUSER_SRCS_DIR}/multidevice/dispatch_combine.cpp ${NVFUSER_SRCS_DIR}/multidevice/ipc_handle.cpp ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp @@ -1143,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir_overlap.cpp diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 3bf3b8350ff..01aa278af71 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -118,6 +118,8 @@ class Val; f(Merge); \ f(Partition); \ f(Combine); \ + f(MoEDispatch); \ + f(MoECombine); \ 
f(Swizzle); \ f(Swizzle2D); \ f(Resize); \ diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 2ceedfddc40..a847a9d5f99 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -25,6 +25,7 @@ #include "multidevice/allocation_utils.h" #include "multidevice/communication.h" #include "multidevice/cuda_p2p.h" +#include "multidevice/dispatch_combine.h" #include "multidevice/execution_utils.h" #include "multidevice/symmetric_tensor.h" #include "multidevice/utils.h" @@ -386,6 +387,68 @@ void HostIrEvaluator::handle(P2PCommunication* communication) { } } +void HostIrEvaluator::handle(MoEDispatch* dispatch) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + auto x = getKnownConcreteValue(dispatch->inX()).as(); + auto topk_idx = + getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_weights = + getKnownConcreteValue(dispatch->inTopkWeights()).as(); + auto is_token_in_rank = + getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); + + auto result = dispatchWithCudaBackend( + x, + topk_idx, + topk_weights, + is_token_in_rank, + dispatch->numExperts(), + communicator_, + dispatch->backend()); + + expr_evaluator_.bind(dispatch->outX(), result.recv_x); + expr_evaluator_.bind(dispatch->outTopkIdx(), result.recv_topk_idx); + expr_evaluator_.bind(dispatch->outTopkWeights(), result.recv_topk_weights); + expr_evaluator_.bind(dispatch->outSrcIdx(), result.recv_src_idx); + expr_evaluator_.bind(dispatch->outSrcRank(), result.recv_src_rank); + expr_evaluator_.bind(dispatch->outTokensToRank(), result.n_tokens_to_rank); + expr_evaluator_.bind( + dispatch->outTokensFromRank(), result.n_tokens_from_rank); +} + +void HostIrEvaluator::handle(MoECombine* combine) { + NVF_ERROR( + communicator_ != nullptr && communicator_->is_available(), + "A valid communicator must be provided"); + + auto x = getKnownConcreteValue(combine->inX()).as(); + auto topk_weights = + 
getKnownConcreteValue(combine->inTopkWeights()).as(); + auto src_idx = getKnownConcreteValue(combine->inSrcIdx()).as(); + auto src_rank = getKnownConcreteValue(combine->inSrcRank()).as(); + auto n_tokens_to_rank = + getKnownConcreteValue(combine->inTokensToRank()).as(); + auto n_tokens_from_rank = + getKnownConcreteValue(combine->inTokensFromRank()).as(); + + auto result = combineWithCudaBackend( + x, + topk_weights, + src_idx, + src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + communicator_, + combine->backend()); + + expr_evaluator_.bind(combine->outX(), result.combined_x); + expr_evaluator_.bind( + combine->outTopkWeights(), result.combined_topk_weights); +} + void HostIrEvaluator::handle(Wait* wait) { Expr* expr = wait->communication(); auto* p2p_comm = dynamic_cast(expr); diff --git a/csrc/host_ir/evaluator.h b/csrc/host_ir/evaluator.h index 22833156cab..c1b0a70ef78 100644 --- a/csrc/host_ir/evaluator.h +++ b/csrc/host_ir/evaluator.h @@ -98,6 +98,8 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch { void handle(LaunchKernel*) override; void handle(Communication*) override; void handle(P2PCommunication*) override; + void handle(MoEDispatch*) override; + void handle(MoECombine*) override; void handle(Wait*) override; void handle(kir::ForLoop*) override; void handle(hir::ForLoop*) override; diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 06b4ffa426c..febbd519d10 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -321,6 +321,167 @@ std::string P2PCommunication::toString(int indent_size) const { return toInlineString(indent_size) + "\n"; } +MoEDispatch::MoEDispatch( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_idx, + TensorView* out_topk_weights, + TensorView* out_src_idx, + TensorView* out_src_rank, + TensorView* out_n_tokens_to_rank, + TensorView* out_n_tokens_from_rank, + TensorView* in_x, + TensorView* in_topk_idx, + TensorView* 
in_topk_weights, + TensorView* in_is_token_in_rank, + int64_t num_experts, + CommunicatorBackend backend) + : Expr(passkey) { + addInput(in_x); + addInput(in_topk_idx); + addInput(in_topk_weights); + addInput(in_is_token_in_rank); + addOutput(out_x); + addOutput(out_topk_idx); + addOutput(out_topk_weights); + addOutput(out_src_idx); + addOutput(out_src_rank); + addOutput(out_n_tokens_to_rank); + addOutput(out_n_tokens_from_rank); + addDataAttribute(num_experts); + addDataAttribute(backend); + validate(); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MoEDispatch) + +std::string MoEDispatch::toInlineString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Dispatch " << name() << " (" + << "num_experts=" << numExperts() << ", " + << "backend=" << backend() << ", " + << "in=" << inX() << ", " + << "topk_idx=" << inTopkIdx() << ", " + << "topk_weights=" << inTopkWeights() << ", " + << "is_token_in_rank=" << inIsTokenInRank() << ", " + << "out=" << outX() << ")"; + return ss.str(); +} + +std::string MoEDispatch::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + +void MoEDispatch::validate() { + NVF_CHECK(numExperts() > 0, "num_experts must be positive."); + NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK(inTopkIdx()->isA(), "topk_idx must be a TensorView."); + NVF_CHECK( + inTopkIdx()->getDataType().has_value() && + isIntegralType(*inTopkIdx()->getDataType()), + "topk_idx must be integral."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + isFloatingPointType(*inTopkWeights()->getDataType()), + "topk_weights must be floating point."); + NVF_CHECK( + inIsTokenInRank()->getDataType() == DataType::Bool, + "is_token_in_rank must be Bool."); + NVF_CHECK( + outTopkIdx()->getDataType().has_value() && + isIntegralType(*outTopkIdx()->getDataType()), + "out_topk_idx must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + 
isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); + NVF_CHECK( + outSrcIdx()->getDataType().has_value() && + isIntegralType(*outSrcIdx()->getDataType()), + "out_src_idx must be integral."); + NVF_CHECK( + outSrcRank()->getDataType().has_value() && + isIntegralType(*outSrcRank()->getDataType()), + "out_src_rank must be integral."); + NVF_CHECK( + outTokensToRank()->getDataType().has_value() && + isIntegralType(*outTokensToRank()->getDataType()), + "out_n_tokens_to_rank must be integral."); + NVF_CHECK( + outTokensFromRank()->getDataType().has_value() && + isIntegralType(*outTokensFromRank()->getDataType()), + "out_n_tokens_from_rank must be integral."); +} + +MoECombine::MoECombine( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_weights, + TensorView* in_x, + TensorView* in_topk_weights, + TensorView* in_src_idx, + TensorView* in_src_rank, + TensorView* in_n_tokens_to_rank, + TensorView* in_n_tokens_from_rank, + CommunicatorBackend backend) + : Expr(passkey) { + addInput(in_x); + addInput(in_topk_weights); + addInput(in_src_idx); + addInput(in_src_rank); + addInput(in_n_tokens_to_rank); + addInput(in_n_tokens_from_rank); + addOutput(out_x); + addOutput(out_topk_weights); + addDataAttribute(backend); + validate(); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(MoECombine) + +std::string MoECombine::toInlineString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Combine " << name() << " (" + << "backend=" << backend() << ", " + << "in=" << inX() << ", " + << "src_idx=" << inSrcIdx() << ", " + << "src_rank=" << inSrcRank() << ", " + << "out=" << outX() << ")"; + return ss.str(); +} + +std::string MoECombine::toString(int indent_size) const { + return toInlineString(indent_size) + "\n"; +} + +void MoECombine::validate() { + NVF_CHECK(inX()->isA(), "in_x must be a TensorView."); + NVF_CHECK( + inTopkWeights()->getDataType().has_value() && + 
isFloatingPointType(*inTopkWeights()->getDataType()), + "in_topk_weights must be floating point."); + NVF_CHECK( + inSrcIdx()->getDataType().has_value() && + isIntegralType(*inSrcIdx()->getDataType()), + "in_src_idx must be integral."); + NVF_CHECK( + inSrcRank()->getDataType().has_value() && + isIntegralType(*inSrcRank()->getDataType()), + "in_src_rank must be integral."); + NVF_CHECK( + inTokensToRank()->getDataType().has_value() && + isIntegralType(*inTokensToRank()->getDataType()), + "in_n_tokens_to_rank must be integral."); + NVF_CHECK( + inTokensFromRank()->getDataType().has_value() && + isIntegralType(*inTokensFromRank()->getDataType()), + "in_n_tokens_from_rank must be integral."); + NVF_CHECK( + outTopkWeights()->getDataType().has_value() && + isFloatingPointType(*outTopkWeights()->getDataType()), + "out_topk_weights must be floating point."); +} + namespace { c10::intrusive_ptr postBroadcast( Communication* communication, diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 1a7f1a1cc4c..9c880110b5e 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -174,6 +174,168 @@ class P2PCommunication : public Expr { } }; +// Dispatch represents intra-node MoE token dispatch. It shuffles tokens from +// the local rank to destination ranks based on `is_token_in_rank`. 
+class MoEDispatch : public Expr { + public: + using Expr::Expr; + + MoEDispatch( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_idx, + TensorView* out_topk_weights, + TensorView* out_src_idx, + TensorView* out_src_rank, + TensorView* out_n_tokens_to_rank, + TensorView* out_n_tokens_from_rank, + TensorView* in_x, + TensorView* in_topk_idx, + TensorView* in_topk_weights, + TensorView* in_is_token_in_rank, + int64_t num_experts, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + MoEDispatch(const MoEDispatch& other) = delete; + MoEDispatch& operator=(const MoEDispatch& other) = delete; + MoEDispatch(MoEDispatch&& other) = delete; + MoEDispatch& operator=(MoEDispatch&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "MoEDispatch"; + } + + TensorView* outX() const { + return output(0)->as(); + } + + TensorView* outTopkIdx() const { + return output(1)->as(); + } + + TensorView* outTopkWeights() const { + return output(2)->as(); + } + + TensorView* outSrcIdx() const { + return output(3)->as(); + } + + TensorView* outSrcRank() const { + return output(4)->as(); + } + + TensorView* outTokensToRank() const { + return output(5)->as(); + } + + TensorView* outTokensFromRank() const { + return output(6)->as(); + } + + TensorView* inX() const { + return input(0)->as(); + } + + TensorView* inTopkIdx() const { + return input(1)->as(); + } + + TensorView* inTopkWeights() const { + return input(2)->as(); + } + + TensorView* inIsTokenInRank() const { + return input(3)->as(); + } + + int64_t numExperts() const { + return attribute(0); + } + + CommunicatorBackend backend() const { + return attribute(1); + } + + private: + void validate(); +}; + +// Combine represents intra-node MoE token combine. 
It shuffles tokens back to +// their source ranks using `src_rank` and `src_idx`. +class MoECombine : public Expr { + public: + using Expr::Expr; + + MoECombine( + IrBuilderPasskey passkey, + TensorView* out_x, + TensorView* out_topk_weights, + TensorView* in_x, + TensorView* in_topk_weights, + TensorView* in_src_idx, + TensorView* in_src_rank, + TensorView* in_n_tokens_to_rank, + TensorView* in_n_tokens_from_rank, + CommunicatorBackend backend = CommunicatorBackend::kNccl); + + MoECombine(const MoECombine& other) = delete; + MoECombine& operator=(const MoECombine& other) = delete; + MoECombine(MoECombine&& other) = delete; + MoECombine& operator=(MoECombine&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "MoECombine"; + } + + TensorView* outX() const { + return output(0)->as(); + } + + TensorView* outTopkWeights() const { + return output(1)->as(); + } + + TensorView* inX() const { + return input(0)->as(); + } + + TensorView* inTopkWeights() const { + return input(1)->as(); + } + + TensorView* inSrcIdx() const { + return input(2)->as(); + } + + TensorView* inSrcRank() const { + return input(3)->as(); + } + + TensorView* inTokensToRank() const { + return input(4)->as(); + } + + TensorView* inTokensFromRank() const { + return input(5)->as(); + } + + CommunicatorBackend backend() const { + return attribute(0); + } + + private: + void validate(); +}; + // The method "post" triggers the execution of the communication. This call is // non-blocking. The communication can be posted multiple times. 
// It is assumed that the current device_index (given by diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp new file mode 100644 index 00000000000..7ac888c539a --- /dev/null +++ b/csrc/multidevice/dispatch_combine.cpp @@ -0,0 +1,267 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include "multidevice/dispatch_combine.h" + +#include +#include + +#include + +#include "multidevice/communicator.h" +#include "utils.h" + +namespace nvfuser { +namespace { + +CommunicatorBackend getBackendForDispatch(CommunicatorBackend backend) { + if (backend == CommunicatorBackend::kCuda) { + return CommunicatorBackend::kNccl; + } + return backend; +} + +std::vector toSplitSizes(const at::Tensor& sizes_tensor) { + auto cpu_sizes = sizes_tensor.to(at::kCPU); + auto* ptr = cpu_sizes.data_ptr(); + return std::vector(ptr, ptr + cpu_sizes.numel()); +} + +int64_t sumSplitSizes(const std::vector& splits) { + int64_t total = 0; + for (auto value : splits) { + total += value; + } + return total; +} + +at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { + if (topk.numel() == num_tokens) { + return topk.reshape({num_tokens}); + } + if (topk.dim() == 2 && topk.size(0) == num_tokens && + topk.size(1) == 1) { + return topk.reshape({num_tokens}); + } + NVF_CHECK( + false, + "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], got: ", + topk.sizes()); +} + +void ensureTopk1Assignment(const at::Tensor& is_token_in_rank) { + auto token_counts = is_token_in_rank.to(at::kLong).sum(1); + auto min_val = token_counts.min().item(); + auto max_val = token_counts.max().item(); + NVF_CHECK( + min_val == 1 && max_val == 1, + "Only topk=1 is supported. 
Each token must be assigned to exactly one rank."); +} + +} // namespace + +DispatchResult dispatchWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_idx, + const at::Tensor& topk_weights, + const at::Tensor& is_token_in_rank, + int64_t num_experts, + Communicator* communicator, + CommunicatorBackend backend) { + NVF_CHECK(communicator != nullptr, "Dispatch requires a valid communicator."); + NVF_CHECK(x.is_cuda(), "Dispatch input x must be on CUDA."); + NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); + NVF_CHECK( + is_token_in_rank.is_cuda(), + "Dispatch is_token_in_rank must be on CUDA."); + NVF_CHECK( + is_token_in_rank.dim() == 2, + "is_token_in_rank must be 2D [tokens, ranks], got: ", + is_token_in_rank.sizes()); + NVF_CHECK( + x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); + + const int64_t num_tokens = x.size(0); + const int64_t hidden = x.size(1); + const int64_t world_size = communicator->size(); + const int64_t my_rank = communicator->deviceId(); + NVF_CHECK( + is_token_in_rank.size(1) == world_size, + "is_token_in_rank second dim must match world size."); + NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible."); + + c10::cuda::CUDAGuard device_guard(x.device()); + ensureTopk1Assignment(is_token_in_rank); + + auto topk_idx_flat = flattenTopk(topk_idx, num_tokens); + auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); + + auto rank_for_token = + is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + auto sorted = rank_for_token.sort(); + auto sorted_indices = std::get<1>(sorted); + + auto send_x = x.index_select(0, sorted_indices); + auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + auto send_src_idx = sorted_indices.to(at::kLong); + auto send_src_rank = at::full( + {num_tokens}, + my_rank, + 
at::TensorOptions().dtype(at::kLong).device(x.device())); + send_src_rank = send_src_rank.index_select(0, sorted_indices); + + auto rank_for_token_cpu = rank_for_token.to(at::kCPU); + auto n_tokens_to_rank_cpu = + at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); + auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + + CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + communicator->isBackendAvailable(actual_backend), + "Backend not available for dispatch: ", + actual_backend); + auto* pg = communicator->getWorld(actual_backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + std::vector one_split(world_size, 1); + if (auto work = pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)) { + work->wait(); + } + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + if (auto work = + pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)) { + work->wait(); + } 
+ + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; +} + +CombineResult combineWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_weights, + const at::Tensor& src_idx, + const at::Tensor& src_rank, + const at::Tensor& n_tokens_to_rank, + const at::Tensor& n_tokens_from_rank, + Communicator* communicator, + CommunicatorBackend backend) { + NVF_CHECK(communicator != nullptr, "Combine requires a valid communicator."); + NVF_CHECK(x.is_cuda(), "Combine input x must be on CUDA."); + NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); + NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); + NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); + NVF_CHECK(n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); + NVF_CHECK( + n_tokens_from_rank.is_cuda(), + "Combine n_tokens_from_rank must be CUDA."); + NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); + NVF_CHECK( + src_idx.dim() == 1 && src_rank.dim() == 1, + "src_idx and src_rank must be 1D."); + NVF_CHECK( + n_tokens_to_rank.numel() == communicator->size(), + "n_tokens_to_rank must match world size."); + NVF_CHECK( + n_tokens_from_rank.numel() == communicator->size(), + "n_tokens_from_rank must match world size."); + + c10::cuda::CUDAGuard device_guard(x.device()); + + auto sorted = src_rank.sort(); 
+ auto sorted_indices = std::get<1>(sorted); + auto send_x = x.index_select(0, sorted_indices); + auto send_topk_weights = topk_weights.index_select(0, sorted_indices); + auto send_src_idx = src_idx.index_select(0, sorted_indices); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + communicator->isBackendAvailable(actual_backend), + "Backend not available for combine: ", + actual_backend); + auto* pg = communicator->getWorld(actual_backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + if (auto work = + pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)) { + work->wait(); + } + if (auto work = pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)) { + work->wait(); + } + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; +} + +} // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h new file mode 100644 index 00000000000..0d8f75c9f6d --- /dev/null +++ b/csrc/multidevice/dispatch_combine.h @@ -0,0 +1,51 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. 
+ * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include "multidevice/communicator.h" +#include "visibility.h" + +namespace nvfuser { + +struct DispatchResult { + at::Tensor recv_x; + at::Tensor recv_topk_idx; + at::Tensor recv_topk_weights; + at::Tensor recv_src_idx; + at::Tensor recv_src_rank; + at::Tensor n_tokens_to_rank; + at::Tensor n_tokens_from_rank; +}; + +struct CombineResult { + at::Tensor combined_x; + at::Tensor combined_topk_weights; +}; + +NVF_API DispatchResult dispatchWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_idx, + const at::Tensor& topk_weights, + const at::Tensor& is_token_in_rank, + int64_t num_experts, + Communicator* communicator, + CommunicatorBackend backend); + +NVF_API CombineResult combineWithCudaBackend( + const at::Tensor& x, + const at::Tensor& topk_weights, + const at::Tensor& src_idx, + const at::Tensor& src_rank, + const at::Tensor& n_tokens_to_rank, + const at::Tensor& n_tokens_from_rank, + Communicator* communicator, + CommunicatorBackend backend); + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp new file mode 100644 index 00000000000..be13743c8b8 --- /dev/null +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -0,0 +1,121 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include + +#include "fusion.h" +#include "host_ir/container.h" +#include "host_ir/evaluator.h" +#include "ir/all_nodes.h" +#include "multidevice/communication.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { +namespace hir { + +class DispatchCombineTest : public MultiDeviceTest {}; + +TEST_F(DispatchCombineTest, DispatchCombineTop1) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + constexpr int64_t kNumExpertsPerRank = 2; + const int64_t num_experts = world_size * kNumExpertsPerRank; + constexpr int64_t kNumTokens = 8; + constexpr int64_t kHidden = 4; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in_x = makeSymbolicTensor(2); + auto* in_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* in_topk_weights = makeSymbolicTensor(1); + auto* in_is_token_in_rank = makeSymbolicTensor(2, DataType::Bool); + + auto* recv_x = makeSymbolicTensor(2); + auto* recv_topk_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_topk_weights = makeSymbolicTensor(1); + auto* recv_src_idx = makeSymbolicTensor(1, DataType::Int); + auto* recv_src_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_to_rank = makeSymbolicTensor(1, DataType::Int); + auto* n_tokens_from_rank = makeSymbolicTensor(1, DataType::Int); + + auto* dispatch = IrBuilder::create( + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + in_x, + in_topk_idx, + in_topk_weights, + in_is_token_in_rank, + num_experts, + CommunicatorBackend::kCuda); + + auto* combined_x = makeSymbolicTensor(2); + auto* combined_topk_weights = makeSymbolicTensor(1); + auto* combine = IrBuilder::create( + combined_x, + 
combined_topk_weights, + recv_x, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank, + CommunicatorBackend::kCuda); + + hic->pushBackTopLevelExprs(dispatch); + hic->pushBackTopLevelExprs(combine); + + hic->addInput(in_x); + hic->addInput(in_topk_idx); + hic->addInput(in_topk_weights); + hic->addInput(in_is_token_in_rank); + hic->addOutput(combined_x); + + HostIrEvaluator hie(std::move(hic), communicator_); + + auto float_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kFloat); + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto x = at::arange(kNumTokens * kHidden, float_options) + .reshape({kNumTokens, kHidden}) + + static_cast(my_rank) * 1000.0; + auto topk_idx = + (at::arange(kNumTokens, int_options) + my_rank) % num_experts; + auto topk_weights = at::ones({kNumTokens}, float_options); + + auto token_rank = topk_idx.div(kNumExpertsPerRank, "trunc"); + auto rank_ids = at::arange(world_size, int_options); + auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + + auto outputs = hie.runWithInput( + {{in_x, x}, + {in_topk_idx, topk_idx}, + {in_topk_weights, topk_weights}, + {in_is_token_in_rank, is_token_in_rank}}); + auto combined = outputs.back().as(); + + EXPECT_TRUE(at::allclose(combined, x)) + << "Dispatch/Combine mismatch on rank " << my_rank; +} + +} // namespace hir +} // namespace nvfuser From 66e7811afa48f0ce819a66fd3191a699842d4254 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 21 Jan 2026 05:27:05 -0800 Subject: [PATCH 02/42] add comments and cleanup --- csrc/host_ir/evaluator.cpp | 10 +- csrc/multidevice/communication.h | 16 +- csrc/multidevice/dispatch_combine.cpp | 152 +++++++++--------- csrc/multidevice/dispatch_combine.h | 97 +++++++++-- .../cpp/test_multidevice_dispatch_combine.cpp | 21 ++- 5 files changed, 186 insertions(+), 110 deletions(-) diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index 
a847a9d5f99..5f6bb83227d 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -393,14 +393,13 @@ void HostIrEvaluator::handle(MoEDispatch* dispatch) { "A valid communicator must be provided"); auto x = getKnownConcreteValue(dispatch->inX()).as(); - auto topk_idx = - getKnownConcreteValue(dispatch->inTopkIdx()).as(); + auto topk_idx = getKnownConcreteValue(dispatch->inTopkIdx()).as(); auto topk_weights = getKnownConcreteValue(dispatch->inTopkWeights()).as(); auto is_token_in_rank = getKnownConcreteValue(dispatch->inIsTokenInRank()).as(); - auto result = dispatchWithCudaBackend( + auto result = doMoEDispatch( x, topk_idx, topk_weights, @@ -434,7 +433,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { auto n_tokens_from_rank = getKnownConcreteValue(combine->inTokensFromRank()).as(); - auto result = combineWithCudaBackend( + auto result = doMoECombine( x, topk_weights, src_idx, @@ -445,8 +444,7 @@ void HostIrEvaluator::handle(MoECombine* combine) { combine->backend()); expr_evaluator_.bind(combine->outX(), result.combined_x); - expr_evaluator_.bind( - combine->outTopkWeights(), result.combined_topk_weights); + expr_evaluator_.bind(combine->outTopkWeights(), result.combined_topk_weights); } void HostIrEvaluator::handle(Wait* wait) { diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 9c880110b5e..a3f806b6c64 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -175,7 +175,13 @@ class P2PCommunication : public Expr { }; // Dispatch represents intra-node MoE token dispatch. It shuffles tokens from -// the local rank to destination ranks based on `is_token_in_rank`. +// the local rank to destination ranks based on `in_is_token_in_rank`. +// +// Example shapes (topk=1): +// in_x: [T, H], in_topk_idx: [T] or [T, 1], in_topk_weights: [T] or [T, 1], +// in_is_token_in_rank: [T, R] (one-hot), num_experts = R * experts_per_rank. 
+// Outputs are recv-aligned tensors: out_x/out_topk_*/out_src_* with [T_recv, +// ...] and out_n_tokens_to_rank/out_n_tokens_from_rank with shape [R]. class MoEDispatch : public Expr { public: using Expr::Expr; @@ -266,7 +272,13 @@ class MoEDispatch : public Expr { }; // Combine represents intra-node MoE token combine. It shuffles tokens back to -// their source ranks using `src_rank` and `src_idx`. +// their source ranks using `in_src_rank` and `in_src_idx`. +// +// Example shapes (topk=1): +// in_x: [T_recv, H], in_topk_weights: [T_recv], in_src_idx: [T_recv], +// in_src_rank: [T_recv], in_n_tokens_to_rank: [R], in_n_tokens_from_rank: +// [R]. Outputs are source-aligned: out_x/out_topk_weights with shape [T_src, +// ...]. class MoECombine : public Expr { public: using Expr::Expr; diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 7ac888c539a..738e27765d9 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause */ @@ -19,13 +19,6 @@ namespace nvfuser { namespace { -CommunicatorBackend getBackendForDispatch(CommunicatorBackend backend) { - if (backend == CommunicatorBackend::kCuda) { - return CommunicatorBackend::kNccl; - } - return backend; -} - std::vector toSplitSizes(const at::Tensor& sizes_tensor) { auto cpu_sizes = sizes_tensor.to(at::kCPU); auto* ptr = cpu_sizes.data_ptr(); @@ -40,32 +33,27 @@ int64_t sumSplitSizes(const std::vector& splits) { return total; } -at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { - if (topk.numel() == num_tokens) { - return topk.reshape({num_tokens}); - } - if (topk.dim() == 2 && topk.size(0) == num_tokens && - topk.size(1) == 1) { - return topk.reshape({num_tokens}); +void waitWork(const c10::intrusive_ptr& work) { + if (work) { + work->wait(); } - NVF_CHECK( - false, - "Only topk=1 supported. topk_idx/weights must be shape [T] or [T, 1], got: ", - topk.sizes()); } -void ensureTopk1Assignment(const at::Tensor& is_token_in_rank) { - auto token_counts = is_token_in_rank.to(at::kLong).sum(1); - auto min_val = token_counts.min().item(); - auto max_val = token_counts.max().item(); +at::Tensor flattenTopk(const at::Tensor& topk, int64_t num_tokens) { + const bool is_1d = topk.dim() == 1 && topk.size(0) == num_tokens; + const bool is_2d = + topk.dim() == 2 && topk.size(0) == num_tokens && topk.size(1) == 1; NVF_CHECK( - min_val == 1 && max_val == 1, - "Only topk=1 is supported. Each token must be assigned to exactly one rank."); + is_1d || is_2d, + "Only topk=1 supported. 
topk_idx/weights must be shape [T] or [T, 1], " + "got: ", + topk.sizes()); + return topk.reshape({num_tokens}); } } // namespace -DispatchResult dispatchWithCudaBackend( +DispatchResult doMoEDispatch( const at::Tensor& x, const at::Tensor& topk_idx, const at::Tensor& topk_weights, @@ -78,14 +66,12 @@ DispatchResult dispatchWithCudaBackend( NVF_CHECK(topk_idx.is_cuda(), "Dispatch topk_idx must be on CUDA."); NVF_CHECK(topk_weights.is_cuda(), "Dispatch topk_weights must be on CUDA."); NVF_CHECK( - is_token_in_rank.is_cuda(), - "Dispatch is_token_in_rank must be on CUDA."); + is_token_in_rank.is_cuda(), "Dispatch is_token_in_rank must be on CUDA."); NVF_CHECK( is_token_in_rank.dim() == 2, "is_token_in_rank must be 2D [tokens, ranks], got: ", is_token_in_rank.sizes()); - NVF_CHECK( - x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); + NVF_CHECK(x.dim() == 2, "Dispatch expects x to be 2D [tokens, hidden]."); const int64_t num_tokens = x.size(0); const int64_t hidden = x.size(1); @@ -97,33 +83,49 @@ DispatchResult dispatchWithCudaBackend( NVF_CHECK(num_experts % world_size == 0, "num_experts must be divisible."); c10::cuda::CUDAGuard device_guard(x.device()); - ensureTopk1Assignment(is_token_in_rank); + NVF_CHECK( + [&]() { + auto token_counts = is_token_in_rank.to(at::kLong).sum(1); + auto min_val = token_counts.min().item(); + auto max_val = token_counts.max().item(); + return min_val == 1 && max_val == 1; + }(), + "Only topk=1 is supported. Each token must be assigned to exactly one " + "rank."); auto topk_idx_flat = flattenTopk(topk_idx, num_tokens); auto topk_weights_flat = flattenTopk(topk_weights, num_tokens); - auto rank_for_token = - is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + // Determine destination rank per token (topk=1). + auto rank_for_token = is_token_in_rank.to(at::kLong).argmax(1).to(at::kLong); + // Sort tokens by destination rank for contiguous alltoall slices. 
auto sorted = rank_for_token.sort(); auto sorted_indices = std::get<1>(sorted); + // Reorder payloads so alltoall can send contiguous chunks per rank. auto send_x = x.index_select(0, sorted_indices); auto send_topk_idx = topk_idx_flat.index_select(0, sorted_indices); auto send_topk_weights = topk_weights_flat.index_select(0, sorted_indices); + // Track original token indices and source rank for the combine step. auto send_src_idx = sorted_indices.to(at::kLong); + // All entries are identical, so no relayout is needed. auto send_src_rank = at::full( {num_tokens}, my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - send_src_rank = send_src_rank.index_select(0, sorted_indices); + // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we + // sync/copy here. GPU-initiated comms can avoid this extra sync. auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); - CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + backend == CommunicatorBackend::kNccl, + "Only NCCL backend is supported for MoEDispatch."); + CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), "Backend not available for dispatch: ", @@ -131,43 +133,36 @@ DispatchResult dispatchWithCudaBackend( auto* pg = communicator->getWorld(actual_backend); NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + // Exchange per-rank token counts to build split sizes for alltoall. 
std::vector one_split(world_size, 1); - if (auto work = pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)) { - work->wait(); - } + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + // Convert count tensors to CPU split vectors and size the receive buffers. auto input_splits = toSplitSizes(n_tokens_to_rank); auto output_splits = toSplitSizes(n_tokens_from_rank); auto total_recv = sumSplitSizes(output_splits); + // Allocate receive buffers for payloads and metadata. + // TODO: support preallocated buffers. auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - if (auto work = - pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)) { - work->wait(); - } - + // Alltoall exchange payloads with per-rank splits. 
+ waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; auto local_expert = recv_topk_idx - my_rank * experts_per_rank; auto expert_sorted = local_expert.sort(); @@ -188,7 +183,7 @@ DispatchResult dispatchWithCudaBackend( n_tokens_from_rank}; } -CombineResult combineWithCudaBackend( +CombineResult doMoECombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, @@ -202,10 +197,10 @@ CombineResult combineWithCudaBackend( NVF_CHECK(topk_weights.is_cuda(), "Combine topk_weights must be on CUDA."); NVF_CHECK(src_idx.is_cuda(), "Combine src_idx must be on CUDA."); NVF_CHECK(src_rank.is_cuda(), "Combine src_rank must be on CUDA."); - NVF_CHECK(n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); NVF_CHECK( - n_tokens_from_rank.is_cuda(), - "Combine n_tokens_from_rank must be CUDA."); + n_tokens_to_rank.is_cuda(), "Combine n_tokens_to_rank must be CUDA."); + NVF_CHECK( + n_tokens_from_rank.is_cuda(), "Combine n_tokens_from_rank must be CUDA."); NVF_CHECK(x.dim() == 2, "Combine expects x to be 2D [tokens, hidden]."); NVF_CHECK( src_idx.dim() == 1 && src_rank.dim() == 1, @@ -219,18 +214,23 @@ CombineResult combineWithCudaBackend( c10::cuda::CUDAGuard device_guard(x.device()); + // Sort by source rank so alltoall can send contiguous chunks per rank. 
auto sorted = src_rank.sort(); auto sorted_indices = std::get<1>(sorted); auto send_x = x.index_select(0, sorted_indices); auto send_topk_weights = topk_weights.index_select(0, sorted_indices); auto send_src_idx = src_idx.index_select(0, sorted_indices); + // Split sizes come from dispatch counts. auto input_splits = toSplitSizes(n_tokens_from_rank); auto output_splits = toSplitSizes(n_tokens_to_rank); auto total_recv = sumSplitSizes(output_splits); auto hidden = x.size(1); - CommunicatorBackend actual_backend = getBackendForDispatch(backend); + NVF_CHECK( + backend == CommunicatorBackend::kNccl, + "Only NCCL backend is supported for MoECombine."); + CommunicatorBackend actual_backend = backend; NVF_CHECK( communicator->isBackendAvailable(actual_backend), "Backend not available for combine: ", @@ -238,27 +238,21 @@ CombineResult combineWithCudaBackend( auto* pg = communicator->getWorld(actual_backend); NVF_CHECK(pg != nullptr, "Combine backend is null."); + // Allocate receive buffers and exchange payloads back to source ranks. auto recv_x = at::empty({total_recv, hidden}, x.options()); auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - if (auto work = - pg->alltoall_base(recv_x, send_x, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)) { - work->wait(); - } - if (auto work = pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)) { - work->wait(); - } + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + // Scatter by original token index to restore local order. 
auto combined_x = at::empty({total_recv, hidden}, x.options()); combined_x.index_copy_(0, recv_src_idx, recv_x); - auto combined_topk_weights = - at::empty({total_recv}, topk_weights.options()); + auto combined_topk_weights = at::empty({total_recv}, topk_weights.options()); combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); return CombineResult{combined_x, combined_topk_weights}; diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 0d8f75c9f6d..5714a45a818 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ @@ -15,30 +15,95 @@ namespace nvfuser { struct DispatchResult { - at::Tensor recv_x; - at::Tensor recv_topk_idx; - at::Tensor recv_topk_weights; - at::Tensor recv_src_idx; - at::Tensor recv_src_rank; - at::Tensor n_tokens_to_rank; - at::Tensor n_tokens_from_rank; + at::Tensor recv_x; // Dispatched tokens received on this rank. + at::Tensor recv_topk_idx; // Expert ids aligned with recv_x. + at::Tensor recv_topk_weights; // Gating weights aligned with recv_x. + at::Tensor recv_src_idx; // Source token indices for combine. + at::Tensor recv_src_rank; // Source ranks for combine. + at::Tensor n_tokens_to_rank; // Tokens sent to each rank (this rank's view). + at::Tensor n_tokens_from_rank; // Tokens received from each rank. }; struct CombineResult { - at::Tensor combined_x; - at::Tensor combined_topk_weights; + at::Tensor combined_x; // Combined tokens back in original order. + at::Tensor combined_topk_weights; // Combined gating weights per token. 
}; -NVF_API DispatchResult dispatchWithCudaBackend( - const at::Tensor& x, - const at::Tensor& topk_idx, - const at::Tensor& topk_weights, - const at::Tensor& is_token_in_rank, +// Dispatch MoE tokens to the owning ranks. Only k=1 is supported for now. +// +// Args: +// x: Token embeddings on this rank, shape [T, H]. +// topk_idx: Global expert ids per token (topk=1), shape [T] or [T, 1]. +// topk_weights: Gating weights per token (topk=1), shape [T] or [T, 1]. +// is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. +// num_experts: Total experts across all ranks (must be divisible by R). +// communicator: Communicator for alltoall exchange. +// backend: Communication backend (only NCCL is supported for now). +// +// Returns: +// DispatchResult with recv_* tensors on this rank. +// +// Example: +// // world_size=2, num_experts=4, T=4, H=2, topk=1 +// // Experts are partitioned by rank: +// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// // Rank0 holds tokens 0,1 and rank1 holds tokens 2,3 in x: +// // rank0 x = [x0, x1], rank1 x = [x2, x3] +// // token->rank: [0, 1, 1, 1] (rank0 keeps x0, sends x1; rank1 keeps x2,x3) +// // is_token_in_rank = +// // [[1, 0], +// // [0, 1], +// // [0, 1], +// // [0, 1]] +// // topk_idx = [0, 2, 3, 2] (global expert ids) +// // After dispatch on rank0: +// // recv_x has token {0} +// // recv_topk_idx aligned with recv_x (e.g., [0]) +// // recv_src_idx tells original token positions (e.g., [0]) +// // After dispatch on rank1: +// // recv_x has tokens {1, 2, 3} +// // recv_topk_idx aligned with recv_x (e.g., [2, 3, 2]) +// // recv_src_idx tells original token positions (e.g., [1, 2, 3]) +// auto out = doMoEDispatch( +// x, topk_idx, topk_weights, is_token_in_rank, 4, comm, +// CommunicatorBackend::kNccl); +NVF_API DispatchResult doMoEDispatch( + const at::Tensor& x, // [T, H] + const at::Tensor& topk_idx, // [T] or [T, 1] + const at::Tensor& topk_weights, // [T] or [T, 1] + const at::Tensor& is_token_in_rank, 
// [T, R] int64_t num_experts, Communicator* communicator, CommunicatorBackend backend); -NVF_API CombineResult combineWithCudaBackend( +// Combine dispatched MoE results back to original token order. +// +// Args: +// x: Token embeddings after expert compute, shape [T_recv, H]. +// topk_weights: Gating weights aligned with x, shape [T_recv]. +// src_idx: Original token indices for each row of x, shape [T_recv]. +// src_rank: Original source rank per token, shape [T_recv]. +// n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. +// n_tokens_from_rank: Tokens received from each rank (from dispatch), shape +// [R]. communicator: Communicator for alltoall exchange. backend: +// Communication backend (only NCCL is supported for now). +// +// Returns: +// CombineResult with tokens restored to original order on this rank. +// +// Example: +// // Continuing the dispatch example (experts partitioned by rank): +// // rank0 owns experts {0, 1}, rank1 owns experts {2, 3} +// // After expert compute: +// // rank0 recv_x has token {0} with src_idx = [0], src_rank = [0] +// // rank1 recv_x has tokens {1, 2, 3} with src_idx = [1, 2, 3], +// // src_rank = [0, 1, 1] +// // n_tokens_to_rank and n_tokens_from_rank are [R] counts per rank. +// // Combine scatters results back to original token order per rank. +// auto combined = doMoECombine( +// x, topk_weights, src_idx, src_rank, n_tokens_to_rank, +// n_tokens_from_rank, comm, CommunicatorBackend::kNccl); +NVF_API CombineResult doMoECombine( const at::Tensor& x, const at::Tensor& topk_weights, const at::Tensor& src_idx, diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index be13743c8b8..0d84dbc03e0 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -1,6 +1,6 @@ // clang-format off /* - * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ @@ -32,7 +32,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { const int64_t my_rank = communicator_->deviceId(); constexpr int64_t kNumExpertsPerRank = 2; const int64_t num_experts = world_size * kNumExpertsPerRank; - constexpr int64_t kNumTokens = 8; + constexpr int64_t kNumTokens = 4; constexpr int64_t kHidden = 4; auto hic = std::make_unique(); @@ -64,7 +64,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { in_topk_weights, in_is_token_in_rank, num_experts, - CommunicatorBackend::kCuda); + CommunicatorBackend::kNccl); auto* combined_x = makeSymbolicTensor(2); auto* combined_topk_weights = makeSymbolicTensor(1); @@ -77,7 +77,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, - CommunicatorBackend::kCuda); + CommunicatorBackend::kNccl); hic->pushBackTopLevelExprs(dispatch); hic->pushBackTopLevelExprs(combine); @@ -98,14 +98,21 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { auto x = at::arange(kNumTokens * kHidden, float_options) .reshape({kNumTokens, kHidden}) + static_cast(my_rank) * 1000.0; - auto topk_idx = - (at::arange(kNumTokens, int_options) + my_rank) % num_experts; + auto topk_idx = at::zeros({kNumTokens}, int_options); auto topk_weights = at::ones({kNumTokens}, float_options); - auto token_rank = topk_idx.div(kNumExpertsPerRank, "trunc"); + // Asymmetric example: + // token->rank: [0, 1, 1, 1] so rank0 gets 1 token, rank1 gets 3 tokens. auto rank_ids = at::arange(world_size, int_options); + auto token_rank = at::tensor({0, 1, 1, 1}, int_options); auto is_token_in_rank = token_rank.unsqueeze(1).eq(rank_ids); + // Experts are partitioned by rank. Use rank0 expert0, rank1 experts0/1. 
+ topk_idx.index_put_({0}, 0); + topk_idx.index_put_({1}, kNumExpertsPerRank); + topk_idx.index_put_({2}, kNumExpertsPerRank + 1); + topk_idx.index_put_({3}, kNumExpertsPerRank); + auto outputs = hie.runWithInput( {{in_x, x}, {in_topk_idx, topk_idx}, From dda9aa7c2be35ef1e604fb12b63d8a5278834657 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 22 Jan 2026 09:33:18 -0800 Subject: [PATCH 03/42] add kernel based a2av and cuda backend for d/c --- CMakeLists.txt | 2 + csrc/multidevice/alltoallv.cu | 37 ++ csrc/multidevice/cuda_p2p.cpp | 315 ++++++++++++++++++ csrc/multidevice/cuda_p2p.h | 29 ++ csrc/multidevice/dispatch_combine.cpp | 309 +++++++++++++---- csrc/multidevice/dispatch_combine.h | 4 +- tests/cpp/test_multidevice_alltoallv.cpp | 82 +++++ .../cpp/test_multidevice_dispatch_combine.cpp | 20 +- 8 files changed, 726 insertions(+), 72 deletions(-) create mode 100644 csrc/multidevice/alltoallv.cu create mode 100644 tests/cpp/test_multidevice_alltoallv.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b325b325d9c..ff76e741b4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1144,6 +1144,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_alltoallv.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_dispatch_combine.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_host_ir.cpp @@ -1393,6 +1394,7 @@ list(APPEND NVFUSER_RUNTIME_FILES ${NVFUSER_ROOT}/runtime/mbarrier.cu ${NVFUSER_ROOT}/runtime/memory.cu ${NVFUSER_ROOT}/runtime/multicast.cu + ${NVFUSER_SRCS_DIR}/multidevice/alltoallv.cu ${NVFUSER_ROOT}/runtime/random_numbers.cu ${NVFUSER_ROOT}/runtime/tensor_memory.cu ${NVFUSER_ROOT}/runtime/tensor.cu diff --git a/csrc/multidevice/alltoallv.cu b/csrc/multidevice/alltoallv.cu new file mode 100644 index 00000000000..9725794f838 --- /dev/null 
+++ b/csrc/multidevice/alltoallv.cu @@ -0,0 +1,37 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +extern "C" __global__ void alltoallv_kernel( + const unsigned char* send, + const unsigned long long* recv_ptrs, + const long long* send_offsets, + const long long* send_sizes, + const long long* recv_offsets, + long long world_size, + long long elem_size, + long long max_send_bytes) { + const long long peer = static_cast(blockIdx.y); + if (peer >= world_size) { + return; + } + const long long bytes = send_sizes[peer] * elem_size; + if (bytes == 0) { + return; + } + const long long idx = + static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (idx >= bytes) { + return; + } + const long long send_byte_offset = send_offsets[peer] * elem_size + idx; + const long long recv_byte_offset = recv_offsets[peer] * elem_size + idx; + auto* dst = reinterpret_cast( + static_cast(recv_ptrs[peer])); + dst[recv_byte_offset] = send[send_byte_offset]; +} + diff --git a/csrc/multidevice/cuda_p2p.cpp b/csrc/multidevice/cuda_p2p.cpp index 6ad709fa062..8804c1a7a79 100644 --- a/csrc/multidevice/cuda_p2p.cpp +++ b/csrc/multidevice/cuda_p2p.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include "multidevice/cuda_p2p.h" +#include "nvfuser_resources/alltoallv.h" #include "nvfuser_resources/multicast.h" #include "cuda_utils.h" @@ -34,6 +35,143 @@ P2pProtocol getP2pProtocol() { } namespace { +void launchAlltoallvKernel( + const void* send, + const uint64_t* recv_ptrs, + const int64_t* send_offsets, + const int64_t* send_sizes, + const int64_t* recv_offsets, + int64_t world_size, + int64_t elem_size, + int64_t max_send_bytes, + CUstream stream) { + static CUmodule module = nullptr; + static CUfunction kernel = nullptr; + + if (module == nullptr) { + nvrtcProgram prog; + NVFUSER_NVRTC_SAFE_CALL(nvrtcCreateProgram( + &prog, + 
nvfuser_resources::alltoallv_cu, + "alltoallv.cu", + 0, + nullptr, + nullptr)); + + int major = 0; + int minor = 0; + int device = 0; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDevice(&device)); + cudaDeviceProp prop; + NVFUSER_CUDA_RT_SAFE_CALL(cudaGetDeviceProperties(&prop, device)); + major = prop.major; + minor = prop.minor; + + std::string arch_arg = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + std::vector opts = {arch_arg.c_str(), "--std=c++17"}; + // NVRTC needs CUDA headers to compile alltoallv.cu. + opts.push_back("-I/usr/local/cuda/include"); + opts.push_back("-I/usr/local/cuda/include/cccl"); + + nvrtcResult res = nvrtcCompileProgram(prog, (int)opts.size(), opts.data()); + if (res != NVRTC_SUCCESS) { + size_t logSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.data())); + NVF_ERROR(false, "Alltoallv kernel compilation failed:\n", log.data()); + } + + size_t ptxSize; + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTXSize(prog, &ptxSize)); + std::vector ptx(ptxSize); + NVFUSER_NVRTC_SAFE_CALL(nvrtcGetPTX(prog, ptx.data())); + NVFUSER_NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + CUresult load_result = cuModuleLoadData(&module, ptx.data()); + if (load_result != CUDA_SUCCESS) { + constexpr size_t kLogSize = 8192; + char error_log[kLogSize]; + char info_log[kLogSize]; + CUjit_option options[] = { + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_LOG_VERBOSE}; + void* option_values[] = { + (void*)error_log, + (void*)kLogSize, + (void*)info_log, + (void*)kLogSize, + (void*)1}; + cuModuleLoadDataEx(&module, ptx.data(), 5, options, option_values); + NVF_ERROR( + false, + "Alltoallv kernel module load failed with error: ", + load_result, + "\nInfo Log:\n", + info_log, + "\nError Log:\n", + error_log); + } + + NVFUSER_CUDA_SAFE_CALL( + 
cuModuleGetFunction(&kernel, module, "alltoallv_kernel")); + } + + if (max_send_bytes == 0) { + return; + } + + constexpr int kThreads = 256; + const int64_t blocks_x = (max_send_bytes + kThreads - 1) / kThreads; + void* args_kernel[] = { + const_cast(static_cast(&send)), + const_cast(static_cast(&recv_ptrs)), + const_cast(static_cast(&send_offsets)), + const_cast(static_cast(&send_sizes)), + const_cast(static_cast(&recv_offsets)), + &world_size, + &elem_size, + &max_send_bytes}; + NVFUSER_CUDA_SAFE_CALL(cuLaunchKernel( + kernel, + blocks_x, + static_cast(world_size), + 1, + kThreads, + 1, + 1, + 0, + stream, + args_kernel, + nullptr)); +} + +std::vector serializeInt64Vector(const std::vector& values) { + std::vector bytes(values.size() * sizeof(int64_t)); + std::memcpy(bytes.data(), values.data(), bytes.size()); + return bytes; +} + +std::vector deserializeInt64Vector(const std::vector& bytes) { + NVF_CHECK( + bytes.size() % sizeof(int64_t) == 0, "Invalid int64 byte buffer size."); + const size_t count = bytes.size() / sizeof(int64_t); + std::vector values(count); + std::memcpy(values.data(), bytes.data(), bytes.size()); + return values; +} + +std::string alltoallvCountsKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_counts_" + tag + "_" + std::to_string(rank); +} + +std::string alltoallvBarrierKey(const std::string& tag, int64_t rank) { + return "nvfuser_alltoallv_barrier_" + tag + "_" + std::to_string(rank); +} void launchMulticastKernel( void* dst, @@ -710,4 +848,181 @@ void waitWithCudaBackend( } } +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& send_counts, + const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + const int64_t world_size = comm.size(); + const int64_t my_rank = comm.deviceId(); + NVF_CHECK( + send_counts.is_cuda(), "alltoallv send_counts must be CUDA tensor."); + NVF_CHECK( + send_counts.dim() == 1 && send_counts.numel() == world_size, + "alltoallv send_counts must be 1D 
[R]."); + + auto store = comm.getTcpStore(); + auto send_counts_cpu = send_counts.to(at::kCPU); + auto* send_ptr = send_counts_cpu.data_ptr(); + std::vector send_counts_vec(send_ptr, send_ptr + world_size); + + store->set( + alltoallvCountsKey(tag, my_rank), serializeInt64Vector(send_counts_vec)); + + std::vector> counts_matrix(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + auto bytes = store->get(alltoallvCountsKey(tag, rank)); + counts_matrix[rank] = deserializeInt64Vector(bytes); + NVF_CHECK( + (int64_t)counts_matrix[rank].size() == world_size, + "Invalid alltoallv counts size."); + } + comm.barrier(); + for (int64_t rank = 0; rank < world_size; ++rank) { + store->deleteKey(alltoallvCountsKey(tag, rank)); + } + + std::vector recv_counts_vec(world_size, 0); + for (int64_t sender = 0; sender < world_size; ++sender) { + recv_counts_vec[sender] = counts_matrix[sender][my_rank]; + } + + std::vector send_offsets_vec(world_size, 0); + int64_t prefix = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + send_offsets_vec[rank] = prefix; + prefix += send_counts_vec[rank]; + } + + std::vector recv_offsets_vec(world_size, 0); + for (int64_t peer = 0; peer < world_size; ++peer) { + int64_t offset = 0; + for (int64_t sender = 0; sender < my_rank; ++sender) { + offset += counts_matrix[sender][peer]; + } + recv_offsets_vec[peer] = offset; + } + + int64_t total_recv = 0; + for (auto value : recv_counts_vec) { + total_recv += value; + } + + int64_t max_recv = 0; + int64_t max_send_total = 0; + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t sender = 0; sender < world_size; ++sender) { + total += counts_matrix[sender][rank]; + } + if (total > max_recv) { + max_recv = total; + } + } + + for (int64_t rank = 0; rank < world_size; ++rank) { + int64_t total = 0; + for (int64_t dest = 0; dest < world_size; ++dest) { + total += counts_matrix[rank][dest]; + } + if (total > max_send_total) { + max_send_total = total; + } 
+ } + + int64_t max_send = 0; + for (auto value : send_counts_vec) { + if (value > max_send) { + max_send = value; + } + } + + auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto send_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + send_offsets_cpu.data_ptr(), + send_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_offsets_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_offsets_cpu.data_ptr(), + recv_offsets_vec.data(), + world_size * sizeof(int64_t)); + auto recv_counts_cpu = at::empty({world_size}, cpu_options); + std::memcpy( + recv_counts_cpu.data_ptr(), + recv_counts_vec.data(), + world_size * sizeof(int64_t)); + + AlltoallvMetadata metadata; + metadata.send_counts = send_counts; + metadata.recv_counts = recv_counts_cpu.to(send_counts.device()); + metadata.send_offsets = send_offsets_cpu.to(send_counts.device()); + metadata.recv_offsets = recv_offsets_cpu.to(send_counts.device()); + metadata.total_recv = total_recv; + metadata.max_recv = max_recv; + metadata.max_send_total = max_send_total; + metadata.max_send_bytes = max_send; + metadata.world_size = world_size; + return metadata; +} + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream) { + NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA."); + NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA."); + NVF_CHECK( + (int64_t)recv_ptrs.size() == metadata.world_size, + "recv_ptrs size must match world size."); + + auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options); + auto* ptrs = recv_ptrs_cpu.data_ptr(); + for (int64_t rank = 0; rank < metadata.world_size; ++rank) { + ptrs[rank] = + static_cast(reinterpret_cast(recv_ptrs[rank])); + } + auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device()); + + const int64_t 
elem_stride = + metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1; + NVF_CHECK( + metadata.max_send_total == 0 || + send.numel() % metadata.max_send_total == 0, + "alltoallv send numel must be divisible by max_send_total."); + NVF_CHECK( + metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, + "alltoallv recv numel must be divisible by max_recv."); + + auto send_offsets = metadata.send_offsets; + auto send_counts = metadata.send_counts; + auto recv_offsets = metadata.recv_offsets; + int64_t max_send_bytes = metadata.max_send_bytes; + if (elem_stride > 1) { + send_offsets = metadata.send_offsets * elem_stride; + send_counts = metadata.send_counts * elem_stride; + recv_offsets = metadata.recv_offsets * elem_stride; + max_send_bytes = metadata.max_send_bytes * elem_stride; + } + + launchAlltoallvKernel( + send.data_ptr(), + reinterpret_cast(recv_ptrs_cuda.data_ptr()), + send_offsets.data_ptr(), + send_counts.data_ptr(), + recv_offsets.data_ptr(), + metadata.world_size, + send.element_size(), + max_send_bytes * send.element_size(), + stream); +} + +void alltoallvBarrier(const std::string& tag) { + Communicator& comm = Communicator::getInstance(); + comm.barrier(); +} + } // namespace nvfuser diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 4947e4e6ee1..e9fd5828597 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -9,6 +9,10 @@ #include +#include +#include +#include + #include "multidevice/ipc_handle.h" namespace nvfuser { @@ -43,4 +47,29 @@ void waitWithCudaBackend( CUstream stream, int64_t root); +struct AlltoallvMetadata { + at::Tensor send_counts; // CUDA [R] + at::Tensor recv_counts; // CUDA [R] + at::Tensor send_offsets; // CUDA [R] + at::Tensor recv_offsets; // CUDA [R] + int64_t total_recv = 0; + int64_t max_recv = 0; + int64_t max_send_total = 0; + int64_t max_send_bytes = 0; + int64_t world_size = 0; +}; + +AlltoallvMetadata prepareAlltoallvMetadata( + const at::Tensor& 
send_counts, + const std::string& tag); + +void alltoallvWithCudaBackend( + const at::Tensor& send, + const at::Tensor& recv, + const AlltoallvMetadata& metadata, + const std::vector& recv_ptrs, + CUstream stream); + +void alltoallvBarrier(const std::string& tag); + } // namespace nvfuser diff --git a/csrc/multidevice/dispatch_combine.cpp b/csrc/multidevice/dispatch_combine.cpp index 738e27765d9..cbad812aa06 100644 --- a/csrc/multidevice/dispatch_combine.cpp +++ b/csrc/multidevice/dispatch_combine.cpp @@ -11,9 +11,12 @@ #include #include +#include #include #include "multidevice/communicator.h" +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" #include "utils.h" namespace nvfuser { @@ -114,53 +117,160 @@ DispatchResult doMoEDispatch( my_rank, at::TensorOptions().dtype(at::kLong).device(x.device())); - // For CPU-initiated comms (e.g. NCCL), split metadata must live on CPU, so we - // sync/copy here. GPU-initiated comms can avoid this extra sync. + // Split metadata is exchanged via CPU (TCPStore), so we sync/copy here. 
auto rank_for_token_cpu = rank_for_token.to(at::kCPU); auto n_tokens_to_rank_cpu = at::bincount(rank_for_token_cpu, {}, world_size).to(at::kLong); auto n_tokens_to_rank = n_tokens_to_rank_cpu.to(x.device()); - auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for dispatch: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Dispatch backend is null."); + + auto n_tokens_from_rank = at::empty_like(n_tokens_to_rank); + std::vector one_split(world_size, 1); + waitWork(pg->alltoall_base( + n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); + + auto input_splits = toSplitSizes(n_tokens_to_rank); + auto output_splits = toSplitSizes(n_tokens_from_rank); + auto total_recv = sumSplitSizes(output_splits); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); + auto recv_topk_weights = + at::empty({total_recv}, topk_weights_flat.options()); + auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); + auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_idx, send_topk_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_rank, send_src_rank, output_splits, input_splits)); + + const int64_t experts_per_rank = num_experts / world_size; + auto local_expert = recv_topk_idx - my_rank * experts_per_rank; + auto expert_sorted = local_expert.sort(); + auto expert_order = std::get<1>(expert_sorted); + recv_x = recv_x.index_select(0, expert_order); + 
recv_topk_idx = recv_topk_idx.index_select(0, expert_order); + recv_topk_weights = recv_topk_weights.index_select(0, expert_order); + recv_src_idx = recv_src_idx.index_select(0, expert_order); + recv_src_rank = recv_src_rank.index_select(0, expert_order); + + return DispatchResult{ + recv_x, + recv_topk_idx, + recv_topk_weights, + recv_src_idx, + recv_src_rank, + n_tokens_to_rank, + n_tokens_from_rank}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoEDispatch."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for dispatch: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Dispatch backend is null."); - - // Exchange per-rank token counts to build split sizes for alltoall. - std::vector one_split(world_size, 1); - waitWork(pg->alltoall_base( - n_tokens_from_rank, n_tokens_to_rank, one_split, one_split)); - - // Convert count tensors to CPU split vectors and size the receive buffers. - auto input_splits = toSplitSizes(n_tokens_to_rank); - auto output_splits = toSplitSizes(n_tokens_from_rank); - auto total_recv = sumSplitSizes(output_splits); - - // Allocate receive buffers for payloads and metadata. - // TODO: support preallocated buffers. - auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_idx = at::empty({total_recv}, topk_idx_flat.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights_flat.options()); - auto recv_src_idx = at::empty({total_recv}, send_src_idx.options()); - auto recv_src_rank = at::empty({total_recv}, send_src_rank.options()); - - // Alltoall exchange payloads with per-rank splits. 
- waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_idx, send_topk_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_rank, send_src_rank, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoEDispatch."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_to_rank, "moe_dispatch_counts"); + auto n_tokens_from_rank = metadata.recv_counts; + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + + // Allocate symmetric buffers for send/recv payloads. + auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, num_tokens).copy_(send_x); + auto send_topk_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_idx_flat.scalar_type(), x.device()); + send_topk_idx_sym.narrow(0, 0, num_tokens).copy_(send_topk_idx); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights_flat.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, num_tokens).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, num_tokens).copy_(send_src_idx); + auto send_src_rank_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, send_src_rank.scalar_type(), x.device()); + send_src_rank_sym.narrow(0, 0, num_tokens).copy_(send_src_rank); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_idx_sym = SymmetricTensor::allocate( + {max_recv}, topk_idx_flat.scalar_type(), 
x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights_flat.scalar_type(), x.device()); + auto recv_src_idx_sym = SymmetricTensor::allocate( + {max_recv}, send_src_idx.scalar_type(), x.device()); + auto recv_src_rank_sym = SymmetricTensor::allocate( + {max_recv}, send_src_rank.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_idx_handle(recv_topk_idx_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + SymmetricTensor recv_src_rank_handle(recv_src_rank_sym); + recv_x_handle.setupRemoteHandles("moe_dispatch_recv_x"); + recv_topk_idx_handle.setupRemoteHandles("moe_dispatch_recv_topk_idx"); + recv_topk_weights_handle.setupRemoteHandles("moe_dispatch_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_dispatch_recv_src_idx"); + recv_src_rank_handle.setupRemoteHandles("moe_dispatch_recv_src_rank"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_idx_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + std::vector recv_src_idx_ptrs(world_size); + std::vector recv_src_rank_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_idx_ptrs[rank] = + recv_topk_idx_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + recv_src_rank_ptrs[rank] = + recv_src_rank_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_idx_sym, + recv_topk_idx_sym, + metadata, + recv_topk_idx_ptrs, + stream); + alltoallvWithCudaBackend( + 
send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvWithCudaBackend( + send_src_rank_sym, + recv_src_rank_sym, + metadata, + recv_src_rank_ptrs, + stream); + alltoallvBarrier("moe_dispatch_counts"); + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_idx = recv_topk_idx_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); + auto recv_src_rank = recv_src_rank_sym.narrow(0, 0, total_recv); // Locally reorder by expert id so each rank processes contiguous experts. const int64_t experts_per_rank = num_experts / world_size; @@ -212,6 +322,7 @@ CombineResult doMoECombine( n_tokens_from_rank.numel() == communicator->size(), "n_tokens_from_rank must match world size."); + const int64_t world_size = communicator->size(); c10::cuda::CUDAGuard device_guard(x.device()); // Sort by source rank so alltoall can send contiguous chunks per rank. @@ -222,32 +333,100 @@ CombineResult doMoECombine( auto send_src_idx = src_idx.index_select(0, sorted_indices); // Split sizes come from dispatch counts. 
- auto input_splits = toSplitSizes(n_tokens_from_rank); - auto output_splits = toSplitSizes(n_tokens_to_rank); - auto total_recv = sumSplitSizes(output_splits); - auto hidden = x.size(1); + if (backend == CommunicatorBackend::kNccl) { + NVF_CHECK( + communicator->isBackendAvailable(backend), + "Backend not available for combine: ", + backend); + auto* pg = communicator->getWorld(backend); + NVF_CHECK(pg != nullptr, "Combine backend is null."); + + auto input_splits = toSplitSizes(n_tokens_from_rank); + auto output_splits = toSplitSizes(n_tokens_to_rank); + auto total_recv = sumSplitSizes(output_splits); + auto hidden = x.size(1); + + auto recv_x = at::empty({total_recv, hidden}, x.options()); + auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); + auto recv_src_idx = at::empty({total_recv}, src_idx.options()); + + waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_topk_weights, send_topk_weights, output_splits, input_splits)); + waitWork(pg->alltoall_base( + recv_src_idx, send_src_idx, output_splits, input_splits)); + + auto combined_x = at::empty({total_recv, hidden}, x.options()); + combined_x.index_copy_(0, recv_src_idx, recv_x); + auto combined_topk_weights = + at::empty({total_recv}, topk_weights.options()); + combined_topk_weights.index_copy_(0, recv_src_idx, recv_topk_weights); + + return CombineResult{combined_x, combined_topk_weights}; + } NVF_CHECK( - backend == CommunicatorBackend::kNccl, - "Only NCCL backend is supported for MoECombine."); - CommunicatorBackend actual_backend = backend; - NVF_CHECK( - communicator->isBackendAvailable(actual_backend), - "Backend not available for combine: ", - actual_backend); - auto* pg = communicator->getWorld(actual_backend); - NVF_CHECK(pg != nullptr, "Combine backend is null."); - - // Allocate receive buffers and exchange payloads back to source ranks. 
- auto recv_x = at::empty({total_recv, hidden}, x.options()); - auto recv_topk_weights = at::empty({total_recv}, topk_weights.options()); - auto recv_src_idx = at::empty({total_recv}, src_idx.options()); - - waitWork(pg->alltoall_base(recv_x, send_x, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_topk_weights, send_topk_weights, output_splits, input_splits)); - waitWork(pg->alltoall_base( - recv_src_idx, send_src_idx, output_splits, input_splits)); + backend == CommunicatorBackend::kCuda, + "Only CUDA and NCCL backends are supported for MoECombine."); + + auto metadata = + prepareAlltoallvMetadata(n_tokens_from_rank, "moe_combine_counts"); + const int64_t total_recv = metadata.total_recv; + const int64_t max_recv = metadata.max_recv; + auto hidden = x.size(1); + + // Allocate symmetric buffers for send/recv payloads. + auto send_x_sym = SymmetricTensor::allocate( + {metadata.max_send_total, hidden}, x.scalar_type(), x.device()); + send_x_sym.narrow(0, 0, x.size(0)).copy_(send_x); + auto send_topk_weights_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, topk_weights.scalar_type(), x.device()); + send_topk_weights_sym.narrow(0, 0, x.size(0)).copy_(send_topk_weights); + auto send_src_idx_sym = SymmetricTensor::allocate( + {metadata.max_send_total}, src_idx.scalar_type(), x.device()); + send_src_idx_sym.narrow(0, 0, x.size(0)).copy_(send_src_idx); + + auto recv_x_sym = SymmetricTensor::allocate( + {max_recv, hidden}, x.scalar_type(), x.device()); + auto recv_topk_weights_sym = SymmetricTensor::allocate( + {max_recv}, topk_weights.scalar_type(), x.device()); + auto recv_src_idx_sym = + SymmetricTensor::allocate({max_recv}, src_idx.scalar_type(), x.device()); + + SymmetricTensor recv_x_handle(recv_x_sym); + SymmetricTensor recv_topk_weights_handle(recv_topk_weights_sym); + SymmetricTensor recv_src_idx_handle(recv_src_idx_sym); + recv_x_handle.setupRemoteHandles("moe_combine_recv_x"); + 
recv_topk_weights_handle.setupRemoteHandles("moe_combine_recv_topk_weights"); + recv_src_idx_handle.setupRemoteHandles("moe_combine_recv_src_idx"); + + std::vector recv_x_ptrs(world_size); + std::vector recv_topk_weights_ptrs(world_size); + std::vector recv_src_idx_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_x_ptrs[rank] = recv_x_handle.remoteTensor(rank).data_ptr(); + recv_topk_weights_ptrs[rank] = + recv_topk_weights_handle.remoteTensor(rank).data_ptr(); + recv_src_idx_ptrs[rank] = recv_src_idx_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = + static_cast(at::cuda::getDefaultCUDAStream().stream()); + alltoallvWithCudaBackend( + send_x_sym, recv_x_sym, metadata, recv_x_ptrs, stream); + alltoallvWithCudaBackend( + send_topk_weights_sym, + recv_topk_weights_sym, + metadata, + recv_topk_weights_ptrs, + stream); + alltoallvWithCudaBackend( + send_src_idx_sym, recv_src_idx_sym, metadata, recv_src_idx_ptrs, stream); + alltoallvBarrier("moe_combine_counts"); + + auto recv_x = recv_x_sym.narrow(0, 0, total_recv); + auto recv_topk_weights = recv_topk_weights_sym.narrow(0, 0, total_recv); + auto recv_src_idx = recv_src_idx_sym.narrow(0, 0, total_recv); // Scatter by original token index to restore local order. auto combined_x = at::empty({total_recv, hidden}, x.options()); diff --git a/csrc/multidevice/dispatch_combine.h b/csrc/multidevice/dispatch_combine.h index 5714a45a818..ceb0a2652b4 100644 --- a/csrc/multidevice/dispatch_combine.h +++ b/csrc/multidevice/dispatch_combine.h @@ -38,7 +38,7 @@ struct CombineResult { // is_token_in_rank: One-hot token-to-rank assignment, shape [T, R]. // num_experts: Total experts across all ranks (must be divisible by R). // communicator: Communicator for alltoall exchange. -// backend: Communication backend (only NCCL is supported for now). +// backend: Communication backend (CUDA or NCCL). // // Returns: // DispatchResult with recv_* tensors on this rank. 
@@ -86,7 +86,7 @@ NVF_API DispatchResult doMoEDispatch( // n_tokens_to_rank: Tokens sent to each rank (from dispatch), shape [R]. // n_tokens_from_rank: Tokens received from each rank (from dispatch), shape // [R]. communicator: Communicator for alltoall exchange. backend: -// Communication backend (only NCCL is supported for now). +// Communication backend (CUDA or NCCL). // // Returns: // CombineResult with tokens restored to original order on this rank. diff --git a/tests/cpp/test_multidevice_alltoallv.cpp b/tests/cpp/test_multidevice_alltoallv.cpp new file mode 100644 index 00000000000..02cb21b7892 --- /dev/null +++ b/tests/cpp/test_multidevice_alltoallv.cpp @@ -0,0 +1,82 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2026-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include + +#include "multidevice/cuda_p2p.h" +#include "multidevice/symmetric_tensor.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { +namespace hir { + +class AlltoallvCudaTest : public MultiDeviceTest {}; + +TEST_F(AlltoallvCudaTest, AlltoallvAsymmetric) { + if (!communicator_->is_available() || communicator_->size() < 2) { + GTEST_SKIP() << "This test needs at least 2 ranks."; + } + + const int64_t world_size = communicator_->size(); + const int64_t my_rank = communicator_->deviceId(); + + auto int_options = + at::TensorOptions().device(communicator_->device()).dtype(at::kLong); + + auto count_for = [](int64_t sender, int64_t dest) { + return (sender + dest) % 3 + 1; + }; + auto send_counts = at::empty({world_size}, int_options); + for (int64_t dest = 0; dest < world_size; ++dest) { + send_counts.index_put_({dest}, count_for(my_rank, dest)); + } + + auto metadata = prepareAlltoallvMetadata(send_counts, "test_alltoallv_counts"); + const int64_t max_recv = metadata.max_recv; + const int64_t total_send = send_counts.sum().item(); + auto send_sym = 
SymmetricTensor::allocate( + {metadata.max_send_total}, at::kLong, communicator_->device()); + send_sym.narrow(0, 0, total_send) + .copy_(at::arange(total_send, int_options) + my_rank * 1000); + + auto recv_sym = SymmetricTensor::allocate( + {max_recv}, at::kLong, communicator_->device()); + SymmetricTensor recv_handle(recv_sym); + recv_handle.setupRemoteHandles("test_alltoallv_recv"); + + std::vector recv_ptrs(world_size); + for (int64_t rank = 0; rank < world_size; ++rank) { + recv_ptrs[rank] = recv_handle.remoteTensor(rank).data_ptr(); + } + + auto stream = at::cuda::getDefaultCUDAStream().stream(); + alltoallvWithCudaBackend(send_sym, recv_sym, metadata, recv_ptrs, stream); + alltoallvBarrier("test_alltoallv_counts"); + + auto recv_view = recv_sym.narrow(0, 0, metadata.total_recv); + std::vector expected_vec; + expected_vec.reserve(static_cast(metadata.total_recv)); + for (int64_t sender = 0; sender < world_size; ++sender) { + int64_t offset = 0; + for (int64_t dest = 0; dest < my_rank; ++dest) { + offset += count_for(sender, dest); + } + const int64_t count = count_for(sender, my_rank); + for (int64_t i = 0; i < count; ++i) { + expected_vec.push_back(offset + i + sender * 1000); + } + } + auto expected = at::tensor(expected_vec, int_options); + EXPECT_TRUE(at::equal(recv_view, expected)) + << "Alltoallv mismatch on rank " << my_rank; +} + +} // namespace hir +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_dispatch_combine.cpp b/tests/cpp/test_multidevice_dispatch_combine.cpp index 0d84dbc03e0..1a28c6e18d5 100644 --- a/tests/cpp/test_multidevice_dispatch_combine.cpp +++ b/tests/cpp/test_multidevice_dispatch_combine.cpp @@ -21,15 +21,21 @@ namespace nvfuser { namespace hir { -class DispatchCombineTest : public MultiDeviceTest {}; +class DispatchCombineTest + : public MultiDeviceTest, + public ::testing::WithParamInterface {}; -TEST_F(DispatchCombineTest, DispatchCombineTop1) { +TEST_P(DispatchCombineTest, DispatchCombineTop1) { if 
(!communicator_->is_available() || communicator_->size() < 2) { GTEST_SKIP() << "This test needs at least 2 ranks."; } const int64_t world_size = communicator_->size(); const int64_t my_rank = communicator_->deviceId(); + const auto backend = GetParam(); + if (!communicator_->isBackendAvailable(CommunicatorBackend::kNccl)) { + GTEST_SKIP() << "Backend " << backend << " not available."; + } constexpr int64_t kNumExpertsPerRank = 2; const int64_t num_experts = world_size * kNumExpertsPerRank; constexpr int64_t kNumTokens = 4; @@ -64,7 +70,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { in_topk_weights, in_is_token_in_rank, num_experts, - CommunicatorBackend::kNccl); + backend); auto* combined_x = makeSymbolicTensor(2); auto* combined_topk_weights = makeSymbolicTensor(1); @@ -77,7 +83,7 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { recv_src_rank, n_tokens_to_rank, n_tokens_from_rank, - CommunicatorBackend::kNccl); + backend); hic->pushBackTopLevelExprs(dispatch); hic->pushBackTopLevelExprs(combine); @@ -119,10 +125,14 @@ TEST_F(DispatchCombineTest, DispatchCombineTop1) { {in_topk_weights, topk_weights}, {in_is_token_in_rank, is_token_in_rank}}); auto combined = outputs.back().as(); - EXPECT_TRUE(at::allclose(combined, x)) << "Dispatch/Combine mismatch on rank " << my_rank; } +INSTANTIATE_TEST_SUITE_P( + DispatchCombineBackends, + DispatchCombineTest, + ::testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kCuda)); + } // namespace hir } // namespace nvfuser From 7aa2de86034d62071317abc40b94d90fafe1f253 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 25 Feb 2026 13:35:54 +0200 Subject: [PATCH 04/42] unstable - add nixl backend --- csrc/multidevice/nixl.cpp | 447 ++++++++++++++++++++++++++++ csrc/multidevice/nixl.h | 139 +++++++++ tests/cpp/test_multidevice_nixl.cpp | 289 ++++++++++++++++++ 3 files changed, 875 insertions(+) create mode 100644 csrc/multidevice/nixl.cpp create mode 100644 csrc/multidevice/nixl.h create mode 100644 
tests/cpp/test_multidevice_nixl.cpp diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp new file mode 100644 index 00000000000..3f71c267cfd --- /dev/null +++ b/csrc/multidevice/nixl.cpp @@ -0,0 +1,447 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include "multidevice/nixl.h" + +#include + +#ifdef USE_NIXL +#include +#endif + +namespace nvfuser { + +// =================================================================== +// NixlTransferHandle +// =================================================================== + +class NixlTransferHandleImpl { + public: +#ifdef USE_NIXL + nixl_xfer_req_t xfer_handle{}; + bool prepared = false; + bool posted = false; +#endif +}; + +NixlTransferHandle::NixlTransferHandle() = default; +NixlTransferHandle::~NixlTransferHandle() = default; +NixlTransferHandle::NixlTransferHandle(NixlTransferHandle&&) noexcept = + default; +NixlTransferHandle& NixlTransferHandle::operator=( + NixlTransferHandle&&) noexcept = default; + +bool NixlTransferHandle::isValid() const { + if (!impl_) { + return false; + } +#ifdef USE_NIXL + return impl_->prepared; +#else + return false; +#endif +} + +// =================================================================== +// Tensor validation and descriptor helpers +// =================================================================== + +namespace { + +void validateCudaTensors(const std::vector& tensors) { + NVF_ERROR(!tensors.empty(), "Tensor list must not be empty"); + for (const auto& t : tensors) { + NVF_ERROR(t.is_cuda(), "All tensors must be CUDA tensors"); + NVF_ERROR(t.is_contiguous(), "All tensors must be contiguous"); + } +} + +#ifdef USE_NIXL +nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { + nixl_reg_dlist_t dlist(VRAM, tensors.size()); + for (const auto& t : tensors) { + dlist.addDesc( + 
{reinterpret_cast(t.data_ptr()), + static_cast(t.numel()) * t.element_size(), + static_cast(t.device().index())}); + } + return dlist; +} + +nixl_xfer_dlist_t buildXferDlist(const std::vector& tensors) { + nixl_xfer_dlist_t dlist(VRAM, tensors.size()); + for (const auto& t : tensors) { + dlist.addDesc( + {reinterpret_cast(t.data_ptr()), + static_cast(t.numel()) * t.element_size(), + static_cast(t.device().index())}); + } + return dlist; +} + +nixl_xfer_op_t toNixlXferOp(NixlXferOp op) { + switch (op) { + case NixlXferOp::kRead: + return NIXL_XFER_READ; + case NixlXferOp::kWrite: + return NIXL_XFER_WRITE; + } + std::unreachable(); +} +#endif + +} // namespace + +// =================================================================== +// NixlBackend::Impl +// =================================================================== + +class NixlBackend::Impl { + public: + explicit Impl(Communicator& communicator); + ~Impl(); + + bool isAvailable() const { + return available_; + } + + void registerTensors(const std::vector& tensors); + void deregisterTensors(const std::vector& tensors); + void exchangeMetadata(); + + NixlTransferHandle prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op); + + void postTransfer(NixlTransferHandle& handle); + NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + void waitTransfer(NixlTransferHandle& handle); + + private: +#ifdef USE_NIXL + std::unique_ptr agent_; +#endif + Communicator& communicator_; + bool available_ = false; + bool metadata_exchanged_ = false; +}; + +// ------------------------------------------------------------------- +// Construction / destruction +// ------------------------------------------------------------------- + +NixlBackend::Impl::Impl(Communicator& communicator) + : communicator_(communicator) { +#ifdef USE_NIXL + std::string agent_name = constructAgentName(communicator_.deviceId()); + agent_ = 
std::make_unique(agent_name); + if (!agent_) { + NVF_THROW("Failed to create NIXL agent"); + } + + nixl_b_params_t params; + nixl_status_t status = agent_->loadBackend("UCX", ¶ms); + if (status != NIXL_SUCCESS) { + agent_.reset(); + NVF_THROW("Failed to load UCX backend for NIXL agent"); + return; + } + + available_ = true; +#endif +} + +NixlBackend::Impl::~Impl() { +#ifdef USE_NIXL + agent_.reset(); +#endif +} + +std::string NixlBackend::Impl::constructAgentName(int deviceId){ + return "rank_" + std::to_string(deviceId); +} + +// ------------------------------------------------------------------- +// Memory registration +// ------------------------------------------------------------------- + +void NixlBackend::Impl::registerTensors( + const std::vector& tensors) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + validateCudaTensors(tensors); + + nixl_reg_dlist_t dlist = buildRegDlist(tensors); + nixl_status_t status = agent_->registerMem(dlist); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL registerMem failed with status ", + static_cast(status)); + + metadata_exchanged_ = false; +#else + (void)tensors; + NVF_THROW("NIXL support not compiled"); +#endif +} + +void NixlBackend::Impl::deregisterTensors( + const std::vector& tensors) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + validateCudaTensors(tensors); + + nixl_reg_dlist_t dlist = buildRegDlist(tensors); + nixl_status_t status = agent_->deregisterMem(dlist); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL deregisterMem failed with status ", + static_cast(status)); + + metadata_exchanged_ = false; +#else + (void)tensors; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// ------------------------------------------------------------------- +// Metadata exchange +// ------------------------------------------------------------------- + +void NixlBackend::Impl::exchangeMetadata() { +#ifdef USE_NIXL + NVF_ERROR(available_, 
"NIXL backend is not available"); + + std::string local_md = agent_->getLocalMD(); + auto* store = communicator_.getTcpStore(); + const int64_t my_rank = communicator_.deviceId(); + const int64_t world_size = communicator_.size(); + + std::string key_prefix = "nixl_agent_md_rank_"; + store->set( + key_prefix + std::to_string(my_rank), + std::vector(local_md.begin(), local_md.end())); + + for (int64_t rank = 0; rank < world_size; ++rank) { + if (rank == my_rank) { + continue; + } + auto bytes = store->get(key_prefix + std::to_string(rank)); + std::string remote_md(bytes.begin(), bytes.end()); + nixl_status_t status = agent_->loadRemoteMD(remote_md); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL loadRemoteMD failed for rank ", + rank, + " with status ", + static_cast(status)); + } + + // Barrier before deleting keys so no rank reads a deleted key. + communicator_.barrier(); + + store->deleteKey(key_prefix + std::to_string(my_rank)); + metadata_exchanged_ = true; +#else + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// ------------------------------------------------------------------- +// Transfer preparation +// ------------------------------------------------------------------- + +// Prepare a transfer between local and remote tensor pairs. +// +// The local and remote descriptor lists are built from the tensors' +// data pointers, byte sizes, and CUDA device indices. NIXL pairs +// local_tensors[i] with remote_tensors[i]. 
The direction depends on `op`: +// kRead -- data flows from remote_tensors[i] into local_tensors[i] +// kWrite -- data flows from local_tensors[i] into remote_tensors[i] +// +// Preconditions: +// - exchangeMetadata() has been called since the last registration change +// - local_tensors and remote_tensors have the same length +// - all tensors are contiguous CUDA tensors +// - remote tensors must have been registered on remote_rank's agent +NixlTransferHandle NixlBackend::Impl::prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op) { + NixlTransferHandle handle; +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); + NVF_ERROR( + local_tensors.size() == remote_tensors.size(), + "Local and remote tensor lists must have the same size. Got ", + local_tensors.size(), + " vs ", + remote_tensors.size()); + validateCudaTensors(local_tensors); + validateCudaTensors(remote_tensors); + + std::string remote_agent_name = constructAgentName(remote_rank); + + nixl_xfer_dlist_t local_dlist = buildXferDlist(local_tensors); + nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_tensors); + + auto impl = std::make_unique(); + nixl_status_t status = agent_->prepXferDlist( + toNixlXferOp(op), + local_dlist, + remote_dlist, + remote_agent_name, + impl->xfer_handle); + NVF_ERROR( + status == NIXL_SUCCESS, + "NIXL prepXferDlist failed with status ", + static_cast(status)); + + impl->prepared = true; + handle.impl_ = std::move(impl); +#else + (void)local_tensors; + (void)remote_tensors; + (void)remote_rank; + (void)op; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif + return handle; +} + +// ------------------------------------------------------------------- +// Transfer posting +// ------------------------------------------------------------------- + +void 
NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(handle.isValid(), "Cannot post an invalid transfer handle"); + NVF_ERROR( + !handle.impl_->posted, + "Transfer already posted. Wait for completion before re-posting."); + + nixl_status_t status = agent_->postXferReq(handle.impl_->xfer_handle); + NVF_ERROR( + status == NIXL_SUCCESS || status == NIXL_IN_PROG, + "NIXL postXferReq failed with status ", + static_cast(status)); + + handle.impl_->posted = true; +#else + (void)handle; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// ------------------------------------------------------------------- +// Transfer status / wait +// ------------------------------------------------------------------- + +NixlXferStatus NixlBackend::Impl::getTransferStatus( + const NixlTransferHandle& handle) const { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(handle.isValid(), "Cannot query status of an invalid handle"); + NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + + nixl_status_t status = agent_->getXferStatus(handle.impl_->xfer_handle); + switch (status) { + case NIXL_SUCCESS: + return NixlXferStatus::kDone; + case NIXL_IN_PROG: + return NixlXferStatus::kInProgress; + default: + return NixlXferStatus::kError; + } +#else + (void)handle; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { +#ifdef USE_NIXL + NVF_ERROR(available_, "NIXL backend is not available"); + NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); + NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + + NixlXferStatus xfer_status; + do { + xfer_status = getTransferStatus(handle); + NVF_ERROR( + xfer_status != NixlXferStatus::kError, + "NIXL transfer completed with an error"); + } while (xfer_status == 
NixlXferStatus::kInProgress); + + handle.impl_->posted = false; +#else + (void)handle; + NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); +#endif +} + +// =================================================================== +// NixlBackend singleton + public API +// =================================================================== + +NixlBackend::NixlBackend() + : impl_(std::make_unique(Communicator::getInstance())) {} + +NixlBackend& NixlBackend::getInstance() { + static auto* instance = new NixlBackend(); + return *instance; +} + +void NixlBackend::cleanup() { + impl_.reset(); +} + +bool NixlBackend::isAvailable() const { + return impl_ && impl_->isAvailable(); +} + +void NixlBackend::registerTensors(const std::vector& tensors) { + impl_->registerTensors(tensors); +} + +void NixlBackend::deregisterTensors(const std::vector& tensors) { + impl_->deregisterTensors(tensors); +} + +void NixlBackend::exchangeMetadata() { + impl_->exchangeMetadata(); +} + +NixlTransferHandle NixlBackend::prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op) { + return impl_->prepareTransfer( + local_tensors, remote_tensors, remote_rank, op); +} + +void NixlBackend::postTransfer(NixlTransferHandle& handle) { + impl_->postTransfer(handle); +} + +NixlXferStatus NixlBackend::getTransferStatus( + const NixlTransferHandle& handle) const { + return impl_->getTransferStatus(handle); +} + +void NixlBackend::waitTransfer(NixlTransferHandle& handle) { + impl_->waitTransfer(handle); +} + +} // namespace nvfuser diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h new file mode 100644 index 00000000000..3b1dedf7cb9 --- /dev/null +++ b/csrc/multidevice/nixl.h @@ -0,0 +1,139 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include +#include + +#include "multidevice/communicator.h" +#include "visibility.h" + +namespace nvfuser { + +// Transfer direction. NIXL uses a one-sided model: +// Read = pull remote data into local buffers +// Write = push local data into remote buffers +enum class NixlXferOp { + kRead, + kWrite, +}; + +enum class NixlXferStatus { + kDone, + kInProgress, + kError, +}; + +// ------------------------------------------------------------------- +// NixlTransferHandle: opaque handle for a prepared transfer +// ------------------------------------------------------------------- +// Returned by NixlBackend::prepareTransfer(). Callers hold this handle +// and pass it to postTransfer() / waitTransfer(). The actual NIXL +// transfer handle lives inside the impl; this is just an owning wrapper. +class NixlTransferHandleImpl; + +class NVF_API NixlTransferHandle { + public: + NixlTransferHandle(); + ~NixlTransferHandle(); + NixlTransferHandle(NixlTransferHandle&&) noexcept; + NixlTransferHandle& operator=(NixlTransferHandle&&) noexcept; + + NixlTransferHandle(const NixlTransferHandle&) = delete; + NixlTransferHandle& operator=(const NixlTransferHandle&) = delete; + + bool isValid() const; + + private: + friend class NixlBackend; + std::unique_ptr impl_; +}; + +// ------------------------------------------------------------------- +// NixlBackend: singleton NIXL backend over UCX for GPU tensors +// ------------------------------------------------------------------- +// Singleton - Wraps a nixlAgent with the UCX backend and provides a tensor-level +// API for registering GPU memory and performing RDMA transfers. +// +// Lifecycle: +// 1. getInstance() - creates agent, loads UCX backend +// 2. registerTensors() - register GPU tensors for RDMA access +// 3. exchangeMetadata() - all ranks share their registration info +// 4. 
prepareTransfer() - expensive one-time setup per transfer pattern +// 5. postTransfer() - cheap, non-blocking data movement +// 6. waitTransfer() - block until complete +// +// Thread safety: methods are NOT thread-safe. The caller must +// synchronize if the same NixlBackend is used from multiple threads. +class NixlBackend { + public: + static NixlBackend& getInstance(); + + NixlBackend(const NixlBackend&) = delete; + NixlBackend& operator=(const NixlBackend&) = delete; + ~NixlBackend() = delete; + + // Explicitly tear down the singleton. Must be called before program + // exit (same pattern as Communicator::cleanup). + void cleanup(); + + bool isAvailable() const; + + // ------------------------------------------------------------------ + // Memory registration + // ------------------------------------------------------------------ + + // Register CUDA tensors with the NIXL agent so they can participate + // in RDMA transfers. Tensors must be contiguous and remain alive + // until deregisterTensors() is called. + void registerTensors(const std::vector& tensors); + + void deregisterTensors(const std::vector& tensors); + + // ------------------------------------------------------------------ + // Metadata exchange + // ------------------------------------------------------------------ + // Exchange local agent metadata with all peers through the TCPStore. + // Must be called after registerTensors() and before prepareTransfer() + // whenever the set of registered tensors changes. + void exchangeMetadata(); + + // ------------------------------------------------------------------ + // Transfer lifecycle + // ------------------------------------------------------------------ + + // Prepare a transfer between pairs of tensors. + // local_tensors[i] and remote_tensors[i] must have the same byte size. + // All tensors must be contiguous CUDA tensors and previously registered. + // The returned handle can be posted multiple times (preparation is + // amortized). 
+ NixlTransferHandle prepareTransfer( + const std::vector& local_tensors, + const std::vector& remote_tensors, + int64_t remote_rank, + NixlXferOp op); + + // Post a previously prepared transfer for execution (non-blocking). + void postTransfer(NixlTransferHandle& handle); + + // Poll the status of a posted transfer without blocking. + NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + + // Block until the transfer completes (or errors out). + void waitTransfer(NixlTransferHandle& handle); + + private: + NixlBackend(); + + class Impl; + std::unique_ptr impl_; +}; + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_nixl.cpp b/tests/cpp/test_multidevice_nixl.cpp new file mode 100644 index 00000000000..eb8de2ba3b8 --- /dev/null +++ b/tests/cpp/test_multidevice_nixl.cpp @@ -0,0 +1,289 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include "multidevice/nixl.h" +#include "tests/cpp/multidevice.h" + +namespace nvfuser { + +using NixlTest = MultiDeviceTest; + +// ------------------------------------------------------------------- +// NixlTransferHandle tests +// ------------------------------------------------------------------- + +TEST_F(NixlTest, TransferHandleDefaultConstruction) { + NixlTransferHandle handle; + EXPECT_FALSE(handle.isValid()); +} + +TEST_F(NixlTest, TransferHandleMoveConstruction) { + NixlTransferHandle h1; + EXPECT_FALSE(h1.isValid()); + + NixlTransferHandle h2(std::move(h1)); + EXPECT_FALSE(h2.isValid()); +} + +TEST_F(NixlTest, TransferHandleMoveAssignment) { + NixlTransferHandle h1; + NixlTransferHandle h2; + h2 = std::move(h1); + EXPECT_FALSE(h2.isValid()); +} + +// ------------------------------------------------------------------- +// NixlBackend singleton tests +// ------------------------------------------------------------------- + 
+TEST_F(NixlTest, SingletonIsAccessible) { + NixlBackend& backend = NixlBackend::getInstance(); + // isAvailable() returns true only when USE_NIXL is defined and the + // UCX backend loaded successfully. Either outcome is valid here. + (void)backend.isAvailable(); +} + +// ------------------------------------------------------------------- +// Input validation tests (these exercise the guards in the impl) +// ------------------------------------------------------------------- + +TEST_F(NixlTest, RegisterEmptyTensorListThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + std::vector empty; + EXPECT_THROW(backend.registerTensors(empty), nvfError); +} + +TEST_F(NixlTest, RegisterCpuTensorThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto cpu_tensor = at::randn({64}); + EXPECT_THROW(backend.registerTensors({cpu_tensor}), nvfError); +} + +TEST_F(NixlTest, RegisterNonContiguousTensorThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto t = at::randn({8, 8}, tensor_options_); + auto non_contig = t.transpose(0, 1); + ASSERT_FALSE(non_contig.is_contiguous()); + EXPECT_THROW(backend.registerTensors({non_contig}), nvfError); +} + +TEST_F(NixlTest, DeregisterEmptyTensorListThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + std::vector empty; + EXPECT_THROW(backend.deregisterTensors(empty), nvfError); +} + +// ------------------------------------------------------------------- +// Transfer preparation validation +// ------------------------------------------------------------------- + +TEST_F(NixlTest, PrepareTransferWithoutMetadataExchangeThrows) { + NixlBackend& backend = 
NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto local = at::randn({64}, tensor_options_); + auto remote = at::randn({64}, tensor_options_); + backend.registerTensors({local}); + backend.registerTensors({remote}); + + EXPECT_THROW( + (void)backend.prepareTransfer({toTensorDesc(local)}, {toTensorDesc(remote)}, 0, NixlXferOp::kRead), + nvfError); + + backend.deregisterTensors({local}); + backend.deregisterTensors({remote}); +} + +TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto t1 = at::randn({64}, tensor_options_); + auto t2 = at::randn({64}, tensor_options_); + auto t3 = at::randn({64}, tensor_options_); + backend.registerTensors({t1, t2, t3}); + backend.exchangeMetadata(); + + EXPECT_THROW( + (void)backend.prepareTransfer({toTensorDesc(t1), toTensorDesc(t2)}, {toTensorDesc(t3)}, 0, NixlXferOp::kRead), nvfError); + + backend.deregisterTensors({t1, t2, t3}); +} + +// ------------------------------------------------------------------- +// Post / wait on invalid handles +// ------------------------------------------------------------------- + +TEST_F(NixlTest, PostInvalidHandleThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + NixlTransferHandle invalid_handle; + EXPECT_THROW(backend.postTransfer(invalid_handle), nvfError); +} + +TEST_F(NixlTest, WaitInvalidHandleThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + NixlTransferHandle invalid_handle; + EXPECT_THROW(backend.waitTransfer(invalid_handle), nvfError); +} + +TEST_F(NixlTest, GetStatusInvalidHandleThrows) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + 
GTEST_SKIP() << "NIXL backend not available"; + } + + NixlTransferHandle invalid_handle; + EXPECT_THROW((void)backend.getTransferStatus(invalid_handle), nvfError); +} + +// ------------------------------------------------------------------- +// End-to-end transfer test (requires >= 2 devices with NIXL) +// ------------------------------------------------------------------- + +TEST_F(NixlTest, ReadTransferEndToEnd) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + if (communicator_->size() < 2) { + GTEST_SKIP() << "Need at least 2 devices for transfer test"; + } + + const int64_t rank = communicator_->deviceId(); + const int64_t world_size = communicator_->size(); + const int64_t peer_rank = (rank + 1) % world_size; + constexpr int64_t kSize = 1024; + + // Ring style transfer: each rank reads from its peer's remote tensor to its local . + auto src = at::full({kSize}, static_cast(rank + 1), tensor_options_); + auto dst = at::zeros({kSize}, tensor_options_); + cudaDeviceSynchronize(); + + backend.registerTensors({src, dst}); + backend.exchangeMetadata(); + + // Fetch the remote tensor descriptor from the peer + std::string src_key_prefix = "nixl_test_read_transfer_src_rank_"; + storeTensorDescs(*communicator_, src_key_prefix + std::to_string(rank), {src}); + auto remote_src_descs = fetchTensorDescs(*communicator_, src_key_prefix + std::to_string(peer_rank)); + communicator_->barrier(); + communicator_->getTcpStore()->deleteKey(src_key_prefix + std::to_string(rank)); + auto remote_src_desc = remote_src_descs[0]; // Only one remote tensor is expected + + // Each rank reads from its peer. After the read, local should contain + // the values that the peer stored in *its* remote tensor. 
+ auto handle = backend.prepareTransfer( + {toTensorDesc(dst)}, {remote_src_desc}, peer_rank, NixlXferOp::kRead); + ASSERT_TRUE(handle.isValid()); + + backend.postTransfer(handle); + backend.waitTransfer(handle); + + auto local_cpu = dst.cpu(); + float expected_val = static_cast(peer_rank + 1); + EXPECT_TRUE(at::allclose(local_cpu, at::full({kSize}, expected_val))); + + backend.deregisterTensors({dst, src}); +} + +TEST_F(NixlTest, WriteTransferEndToEnd) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + if (communicator_->size() < 2) { + GTEST_SKIP() << "Need at least 2 devices for transfer test"; + } + + const int64_t rank = communicator_->deviceId(); + const int64_t world_size = communicator_->size(); + const int64_t peer_rank = (rank + 1) % world_size; + constexpr int64_t kSize = 512; + + // Each rank writes its local to the remote of its peer in a ring style + auto src = at::full({kSize}, static_cast(rank + 1), tensor_options_); + auto dst = at::zeros({kSize}, tensor_options_); + cudaDeviceSynchronize(); + + backend.registerTensors({src, dst}); + backend.exchangeMetadata(); + + // Fetch the remote tensor descriptor from the peer + std::string dst_key_prefix = "nixl_test_write_transfer_dst_rank_"; + storeTensorDescs(*communicator_, dst_key_prefix + std::to_string(rank), {dst}); + auto remote_dst_descs = fetchTensorDescs(*communicator_, dst_key_prefix + std::to_string(peer_rank)); + communicator_->barrier(); + communicator_->getTcpStore()->deleteKey(dst_key_prefix + std::to_string(rank)); + auto remote_dst_desc = remote_dst_descs[0]; // Only one remote tensor is expected + + // Each rank writes its local tensor into its peer's remote tensor. 
+ auto handle = backend.prepareTransfer( + {toTensorDesc(src)}, {remote_dst_desc}, peer_rank, NixlXferOp::kWrite); + ASSERT_TRUE(handle.isValid()); + + backend.postTransfer(handle); + backend.waitTransfer(handle); + + // After a barrier, the peer should have written into our remote tensor . + communicator_->barrier(); + + auto remote_cpu = dst.cpu(); + int64_t writer_rank = (rank - 1 + world_size) % world_size; + float expected_val = static_cast(writer_rank + 1); + EXPECT_TRUE(at::allclose(remote_cpu, at::full({kSize}, expected_val))); + + backend.deregisterTensors({src, dst}); +} + +TEST_F(NixlTest, RegisterDeregisterRoundTrip) { + NixlBackend& backend = NixlBackend::getInstance(); + if (!backend.isAvailable()) { + GTEST_SKIP() << "NIXL backend not available"; + } + + auto t1 = at::randn({256}, tensor_options_); + auto t2 = at::randn({128}, tensor_options_); + + backend.registerTensors({t1, t2}); + backend.deregisterTensors({t1, t2}); + + // Re-registering the same tensors should succeed. 
+ backend.registerTensors({t1, t2}); + backend.deregisterTensors({t1, t2}); +} + +} // namespace nvfuser From 9a8a377109a5038c55e4a53d11b7e26ee53cf4b9 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:36:56 +0200 Subject: [PATCH 05/42] unstable --- CMakeLists.txt | 39 ++++++ csrc/multidevice/communicator.cpp | 3 + csrc/multidevice/communicator.h | 11 ++ csrc/multidevice/multidevice.h | 2 +- csrc/multidevice/nixl.cpp | 189 +++++++++++++++++++----------- csrc/multidevice/nixl.h | 105 ++++++++++++++++- 6 files changed, 276 insertions(+), 73 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ff76e741b4c..94b87209c21 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,7 @@ set(NVFUSER_CUTLASS "${NVFUSER_ROOT}/cutlass") set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) +option(NVFUSER_STANDALONE_BUILD_WITH_NIXL "" OFF) option(NVFUSER_EXPLICIT_ERROR_CHECK "" OFF) option(NVFUSER_ENABLE_DEPENDENCY_REPORT "Enable Python-based dependency reporting and log capture" ON) @@ -240,6 +241,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/multidevice/ipc_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/device_mesh.cpp ${NVFUSER_SRCS_DIR}/multidevice/executor.cpp + ${NVFUSER_SRCS_DIR}/multidevice/nixl.cpp ${NVFUSER_SRCS_DIR}/multidevice/execution_utils.cpp ${NVFUSER_SRCS_DIR}/multidevice/propagation.cpp ${NVFUSER_SRCS_DIR}/multidevice/resharding.cpp @@ -586,6 +588,37 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) target_compile_definitions(codegen_internal PRIVATE NVFUSER_BUILD_WITH_UCC) endif() +if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) + # User may need to set NIXL_PREFIX to the NIXL install directory. 
+ find_path(NIXL_INCLUDE_DIR nixl.h + HINTS $ENV{NIXL_PREFIX}/include ENV CPATH + ) + find_library(NIXL_LIBRARY nixl + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + find_library(NIXL_BUILD_LIBRARY nixl_build + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + + if(NOT NIXL_INCLUDE_DIR OR NOT NIXL_LIBRARY) + message(FATAL_ERROR "NIXL not found. Set NIXL_PREFIX to the NIXL install directory.") + endif() + + message(STATUS "Found NIXL: ${NIXL_LIBRARY} (include: ${NIXL_INCLUDE_DIR})") + if(NIXL_BUILD_LIBRARY) + message(STATUS "Found NIXL build lib: ${NIXL_BUILD_LIBRARY}") + endif() + + add_library(__nvfuser_nixl INTERFACE) + target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) + if(NIXL_BUILD_LIBRARY) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_BUILD_LIBRARY}) + endif() + target_link_libraries(codegen_internal PRIVATE __nvfuser_nixl) + target_compile_definitions(codegen_internal PRIVATE USE_NIXL) +endif() + add_dependencies(codegen_internal flatc build_flatbuffer_config) # installing nvfuser headers @@ -1153,6 +1186,7 @@ if(BUILD_TEST) ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_lower_communication_cuda.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_matmul.cpp + ${NVFUSER_ROOT}/tests/cpp/test_multidevice_nixl.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_pipeline.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_sharding.cpp ${NVFUSER_ROOT}/tests/cpp/test_multidevice_stream_parallel_type.cpp @@ -1457,6 +1491,11 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) message(STATUS " UCX_DIR : $ENV{UCX_DIR}") endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") +message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_NIXL : ${NVFUSER_STANDALONE_BUILD_WITH_NIXL}") 
+if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) + message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") + message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") +endif() message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index 208277f98a1..98c36ab8d3b 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -41,6 +41,9 @@ std::ostream& operator<<(std::ostream& out, const CommunicatorBackend& cb) { case CommunicatorBackend::kCuda: out << "CUDA"; break; + case CommunicatorBackend::kNixl: + out << "NIXL"; + break; } return out; } diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index b56e6fee3aa..c4a1eb3d09b 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -11,6 +11,9 @@ #include #include +#include +#include + #ifdef NVFUSER_DISTRIBUTED #include #include @@ -116,6 +119,12 @@ class NVF_API Communicator { return ucc_available_; } else if (backend == CommunicatorBackend::kNccl) { return nccl_available_; + } else if (backend == CommunicatorBackend::kNixl) { +#ifdef USE_NIXL + return true; +#else + return false; +#endif } return false; } @@ -124,6 +133,7 @@ class NVF_API Communicator { return store_.get(); } + private: Communicator( CommunicatorBackend backend = comm_backend_default, @@ -155,4 +165,5 @@ class NVF_API Communicator { std::unordered_map> backends_; }; + } // namespace nvfuser diff --git a/csrc/multidevice/multidevice.h b/csrc/multidevice/multidevice.h index 288a89fe952..7915f5e3d92 100644 --- a/csrc/multidevice/multidevice.h +++ b/csrc/multidevice/multidevice.h @@ -19,5 +19,5 @@ using DeviceType = c10::Device; using Team = std::vector; // Supported backends. 
-enum class CommunicatorBackend { kNccl, kUcc, kCuda }; +enum class CommunicatorBackend { kNccl, kUcc, kCuda, kNixl }; } // namespace nvfuser diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 3f71c267cfd..fa0ac7c7a94 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -6,13 +6,16 @@ */ // clang-format on #include "multidevice/nixl.h" +#include "exceptions.h" +#include +#include +#include #include #ifdef USE_NIXL #include #endif - namespace nvfuser { // =================================================================== @@ -22,10 +25,11 @@ namespace nvfuser { class NixlTransferHandleImpl { public: #ifdef USE_NIXL - nixl_xfer_req_t xfer_handle{}; + // TODO - is it leaking when handleimpl is destroyed ? + nixlXferReqH* xfer_handle = nullptr; +#endif bool prepared = false; bool posted = false; -#endif }; NixlTransferHandle::NixlTransferHandle() = default; @@ -61,8 +65,9 @@ void validateCudaTensors(const std::vector& tensors) { } #ifdef USE_NIXL + nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { - nixl_reg_dlist_t dlist(VRAM, tensors.size()); + nixl_reg_dlist_t dlist(VRAM_SEG, tensors.size()); for (const auto& t : tensors) { dlist.addDesc( {reinterpret_cast(t.data_ptr()), @@ -72,13 +77,10 @@ nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { return dlist; } -nixl_xfer_dlist_t buildXferDlist(const std::vector& tensors) { - nixl_xfer_dlist_t dlist(VRAM, tensors.size()); - for (const auto& t : tensors) { - dlist.addDesc( - {reinterpret_cast(t.data_ptr()), - static_cast(t.numel()) * t.element_size(), - static_cast(t.device().index())}); +nixl_xfer_dlist_t buildXferDlist(const std::vector& descs) { + nixl_xfer_dlist_t dlist(VRAM_SEG, descs.size()); + for (const auto& desc : descs) { + dlist.addDesc({desc.addr, desc.size, desc.dev}); } return dlist; } @@ -86,12 +88,13 @@ nixl_xfer_dlist_t buildXferDlist(const std::vector& tensors) { nixl_xfer_op_t toNixlXferOp(NixlXferOp op) { switch (op) { case 
NixlXferOp::kRead: - return NIXL_XFER_READ; + return NIXL_READ; case NixlXferOp::kWrite: - return NIXL_XFER_WRITE; + return NIXL_WRITE; } - std::unreachable(); + NVF_THROW("Invalid NIXL transfer operation: ", static_cast(op)); } + #endif } // namespace @@ -114,8 +117,8 @@ class NixlBackend::Impl { void exchangeMetadata(); NixlTransferHandle prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op); @@ -124,8 +127,11 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: + std::string constructAgentName(int64_t rank); + #ifdef USE_NIXL std::unique_ptr agent_; + nixlBackendH* backend_ = nullptr; #endif Communicator& communicator_; bool available_ = false; @@ -140,37 +146,73 @@ NixlBackend::Impl::Impl(Communicator& communicator) : communicator_(communicator) { #ifdef USE_NIXL std::string agent_name = constructAgentName(communicator_.deviceId()); - agent_ = std::make_unique(agent_name); - if (!agent_) { - NVF_THROW("Failed to create NIXL agent"); - } + nixlAgentConfig cfg(false); + agent_ = std::make_unique(agent_name, cfg); nixl_b_params_t params; - nixl_status_t status = agent_->loadBackend("UCX", ¶ms); + nixl_status_t status = agent_->createBackend("UCX", params, backend_); if (status != NIXL_SUCCESS) { agent_.reset(); - NVF_THROW("Failed to load UCX backend for NIXL agent"); - return; + NVF_THROW("Failed to create UCX backend for NIXL agent"); + } + + // Probe: verify that VRAM (CUDA GPU memory) is actually usable with + // the UCX backend. Some UCX installations lack CUDA support, causing + // registerMem to silently misclassify VRAM as host memory. We detect + // this by registering a small buffer and asking NIXL to prepare a + // local descriptor list for VRAM -- if no backend claims VRAM, the + // probe fails and we mark the backend as unavailable. 
+ { + auto probe = at::empty( + {1}, + at::TensorOptions().dtype(at::kByte).device( + at::kCUDA, communicator_.deviceId())); + nixl_reg_dlist_t reg_dlist(VRAM_SEG, 1); + reg_dlist.addDesc( + {reinterpret_cast(probe.data_ptr()), + probe.nbytes(), + static_cast(probe.device().index())}); + + nixl_status_t reg_status = agent_->registerMem(reg_dlist); + if (reg_status != NIXL_SUCCESS) { + return; + } + + nixl_xfer_dlist_t xfer_dlist(VRAM_SEG, 1); + xfer_dlist.addDesc( + {reinterpret_cast(probe.data_ptr()), + probe.nbytes(), + static_cast(probe.device().index())}); + + nixlDlistH* dlist_handle = nullptr; + nixl_status_t prep_status = + agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + + if (dlist_handle) { + agent_->releasedDlistH(dlist_handle); + } + agent_->deregisterMem(reg_dlist); + + if (prep_status != NIXL_SUCCESS) { + return; + } } available_ = true; #endif } -NixlBackend::Impl::~Impl() { -#ifdef USE_NIXL - agent_.reset(); -#endif -} +NixlBackend::Impl::~Impl() = default; -std::string NixlBackend::Impl::constructAgentName(int deviceId){ - return "rank_" + std::to_string(deviceId); +std::string NixlBackend::Impl::constructAgentName(int64_t rank){ + return "rank_" + std::to_string(rank); } // ------------------------------------------------------------------- // Memory registration // ------------------------------------------------------------------- +// TODO - consider adding RAII wrapper void NixlBackend::Impl::registerTensors( const std::vector& tensors) { #ifdef USE_NIXL @@ -219,23 +261,31 @@ void NixlBackend::Impl::exchangeMetadata() { #ifdef USE_NIXL NVF_ERROR(available_, "NIXL backend is not available"); - std::string local_md = agent_->getLocalMD(); + nixl_blob_t local_md; + nixl_status_t md_status = agent_->getLocalMD(local_md); + NVF_ERROR( + md_status == NIXL_SUCCESS, + "NIXL getLocalMD failed with status ", + static_cast(md_status)); + auto* store = communicator_.getTcpStore(); - const int64_t my_rank = communicator_.deviceId(); - 
const int64_t world_size = communicator_.size(); + const auto my_rank = communicator_.deviceId(); + const auto world_size = communicator_.size(); - std::string key_prefix = "nixl_agent_md_rank_"; + std::string md_key_prefix = "nixl_agent_md_rank_"; store->set( - key_prefix + std::to_string(my_rank), + md_key_prefix + std::to_string(my_rank), std::vector(local_md.begin(), local_md.end())); for (int64_t rank = 0; rank < world_size; ++rank) { if (rank == my_rank) { continue; } - auto bytes = store->get(key_prefix + std::to_string(rank)); - std::string remote_md(bytes.begin(), bytes.end()); - nixl_status_t status = agent_->loadRemoteMD(remote_md); + // Fetch & load MD + auto bytes = store->get(md_key_prefix + std::to_string(rank)); + nixl_blob_t remote_md(bytes.begin(), bytes.end()); + std::string remote_agent_name; + nixl_status_t status = agent_->loadRemoteMD(remote_md, remote_agent_name); NVF_ERROR( status == NIXL_SUCCESS, "NIXL loadRemoteMD failed for rank ", @@ -247,7 +297,7 @@ void NixlBackend::Impl::exchangeMetadata() { // Barrier before deleting keys so no rank reads a deleted key. communicator_.barrier(); - store->deleteKey(key_prefix + std::to_string(my_rank)); + store->deleteKey(md_key_prefix + std::to_string(my_rank)); metadata_exchanged_ = true; #else NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); @@ -260,20 +310,18 @@ void NixlBackend::Impl::exchangeMetadata() { // Prepare a transfer between local and remote tensor pairs. // -// The local and remote descriptor lists are built from the tensors' -// data pointers, byte sizes, and CUDA device indices. NIXL pairs -// local_tensors[i] with remote_tensors[i]. The direction depends on `op`: -// kRead -- data flows from remote_tensors[i] into local_tensors[i] -// kWrite -- data flows from local_tensors[i] into remote_tensors[i] +// NIXL pairs local_tensors[i] with remote_tensors[i]. 
The direction +// depends on `op`: +// kRead -- data flows from remote into local +// kWrite -- data flows from local into remote // -// Preconditions: -// - exchangeMetadata() has been called since the last registration change -// - local_tensors and remote_tensors have the same length -// - all tensors are contiguous CUDA tensors -// - remote tensors must have been registered on remote_rank's agent +// remote_tensors are LOCAL tensors whose data_ptr identifies the +// corresponding registration slot. The actual remote addresses are +// looked up from the descriptors exchanged during exchangeMetadata(). +// This requires all ranks to register tensors in the same order. NixlTransferHandle NixlBackend::Impl::prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + const std::vector& local_descs, // Local addresses + const std::vector& remote_descs, // Remote tensors (not valid on this rank) int64_t remote_rank, NixlXferOp op) { NixlTransferHandle handle; @@ -281,21 +329,19 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); NVF_ERROR( - local_tensors.size() == remote_tensors.size(), + local_descs.size() == remote_descs.size(), "Local and remote tensor lists must have the same size. 
Got ", - local_tensors.size(), + local_descs.size(), " vs ", - remote_tensors.size()); - validateCudaTensors(local_tensors); - validateCudaTensors(remote_tensors); + remote_descs.size()); std::string remote_agent_name = constructAgentName(remote_rank); - nixl_xfer_dlist_t local_dlist = buildXferDlist(local_tensors); - nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_tensors); + nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); + nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); auto impl = std::make_unique(); - nixl_status_t status = agent_->prepXferDlist( + nixl_status_t status = agent_->createXferReq( toNixlXferOp(op), local_dlist, remote_dlist, @@ -303,14 +349,14 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( impl->xfer_handle); NVF_ERROR( status == NIXL_SUCCESS, - "NIXL prepXferDlist failed with status ", + "NIXL createXferReq failed with status ", static_cast(status)); impl->prepared = true; handle.impl_ = std::move(impl); #else - (void)local_tensors; - (void)remote_tensors; + (void)local_descs; + (void)remote_descs; (void)remote_rank; (void)op; NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); @@ -375,6 +421,7 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); + // TODO - check this spin loop NixlXferStatus xfer_status; do { xfer_status = getTransferStatus(handle); @@ -399,10 +446,12 @@ NixlBackend::NixlBackend() NixlBackend& NixlBackend::getInstance() { static auto* instance = new NixlBackend(); + NVF_CHECK(!instance->cleaned_up_, "NIXL backend has been cleaned up"); return *instance; } void NixlBackend::cleanup() { + cleaned_up_ = true; impl_.reset(); } @@ -411,37 +460,45 @@ bool NixlBackend::isAvailable() const { } void NixlBackend::registerTensors(const std::vector& tensors) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); 
impl_->registerTensors(tensors); } void NixlBackend::deregisterTensors(const std::vector& tensors) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->deregisterTensors(tensors); } void NixlBackend::exchangeMetadata() { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->exchangeMetadata(); } NixlTransferHandle NixlBackend::prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); return impl_->prepareTransfer( - local_tensors, remote_tensors, remote_rank, op); + local_descs, remote_descs, remote_rank, op); } void NixlBackend::postTransfer(NixlTransferHandle& handle) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->postTransfer(handle); } NixlXferStatus NixlBackend::getTransferStatus( const NixlTransferHandle& handle) const { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); return impl_->getTransferStatus(handle); } void NixlBackend::waitTransfer(NixlTransferHandle& handle) { + NVF_CHECK(isAvailable(), "NIXL backend is not available"); impl_->waitTransfer(handle); } -} // namespace nvfuser + +} // namespace nvfuser \ No newline at end of file diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 3b1dedf7cb9..1cc5b84a7b8 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -11,7 +11,9 @@ #include #include #include +#include +#include "exceptions.h" #include "multidevice/communicator.h" #include "visibility.h" @@ -31,6 +33,93 @@ enum class NixlXferStatus { kError, }; +// ------------------------------------------------------------------ +// Todo - those functions should be moved to a more global file +// Helper functions for serializing and deserializing tensors descriptors for TCP store +struct TensorDesc { + uintptr_t addr; + size_t size; + uint32_t dev; +}; 
+static_assert(std::is_trivially_copyable_v, + "TensorDesc must be trivially copyable for serialization"); + +inline TensorDesc toTensorDesc(const at::Tensor& tensor) { + return { + .addr = reinterpret_cast(tensor.data_ptr()), + .size = static_cast(tensor.numel()) * tensor.element_size(), + .dev = static_cast(tensor.device().index()) + }; +} + +inline at::Tensor fromTensorDesc(const TensorDesc& desc) { + /* + Tensors must be valid on this device + */ + return at::from_blob( + reinterpret_cast(desc.addr), + {static_cast(desc.size)}, + at::TensorOptions().device(at::Device(at::kCUDA, desc.dev)).dtype(at::kByte) + ); +} + +inline std::vector serializeTensorsDescs( + const std::vector& descs) { + size_t count = descs.size(); + std::vector buf(sizeof(count) + count * sizeof(TensorDesc)); + std::memcpy(buf.data(), &count, sizeof(count)); + if (count == 0) + return buf; + + std::memcpy( + buf.data() + sizeof(count), + descs.data(), + descs.size() * sizeof(TensorDesc)); + return buf; +} + +inline std::vector deserializeTensorsDescs( + const std::vector& buf) { + NVF_ERROR(buf.size() >= sizeof(size_t), "Invalid serialized descriptor data"); + size_t count; + std::memcpy(&count, buf.data(), sizeof(count)); + NVF_ERROR( + buf.size() == sizeof(count) + count * sizeof(TensorDesc), + "Corrupted serialized descriptor data"); + + std::vector descs(count); + if (count > 0) { + std::memcpy( + descs.data(), + buf.data() + sizeof(count), + count * sizeof(TensorDesc)); + } + return descs; +} + +inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& descs) { + NVF_CHECK(communicator.is_available(), "Communicator is not available"); + communicator.getTcpStore()->set(key, serializeTensorsDescs(descs)); +} + +inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& tensors) { + std::vector descs; + descs.reserve(tensors.size()); + for (const auto& tensor : tensors) { + 
descs.push_back(toTensorDesc(tensor)); + } + storeTensorDescs(communicator, key, descs); +} + +inline std::vector fetchTensorDescs(Communicator& communicator, const std::string& key) { + NVF_CHECK(communicator.is_available(), "Communicator is not available"); + auto bytes = communicator.getTcpStore()->get(key); + return deserializeTensorsDescs(bytes); +} + +// End of Todo - those functions should be moved to a more global file +// ------------------------------------------------------------------ + // ------------------------------------------------------------------- // NixlTransferHandle: opaque handle for a prepared transfer // ------------------------------------------------------------------- @@ -49,7 +138,7 @@ class NVF_API NixlTransferHandle { NixlTransferHandle(const NixlTransferHandle&) = delete; NixlTransferHandle& operator=(const NixlTransferHandle&) = delete; - bool isValid() const; + [[nodiscard]] bool isValid() const; private: friend class NixlBackend; @@ -84,7 +173,7 @@ class NixlBackend { // exit (same pattern as Communicator::cleanup). void cleanup(); - bool isAvailable() const; + [[nodiscard]] bool isAvailable() const; // ------------------------------------------------------------------ // Memory registration @@ -114,9 +203,9 @@ class NixlBackend { // All tensors must be contiguous CUDA tensors and previously registered. // The returned handle can be posted multiple times (preparation is // amortized). - NixlTransferHandle prepareTransfer( - const std::vector& local_tensors, - const std::vector& remote_tensors, + [[nodiscard]] NixlTransferHandle prepareTransfer( + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op); @@ -124,16 +213,20 @@ class NixlBackend { void postTransfer(NixlTransferHandle& handle); // Poll the status of a posted transfer without blocking. 
- NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + [[nodiscard]] NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; // Block until the transfer completes (or errors out). void waitTransfer(NixlTransferHandle& handle); private: NixlBackend(); + bool cleaned_up_ = false; class Impl; std::unique_ptr impl_; }; + + + } // namespace nvfuser From 0f2152890df45889c6664f1a03cb4e7f0a12929c Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:44:15 +0200 Subject: [PATCH 06/42] add python build changes for nixl --- python/setup.py | 3 +++ python/utils.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/python/setup.py b/python/setup.py index 9d340016d5d..aba47bee0f1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -32,6 +32,9 @@ # NVFUSER_BUILD_WITH_UCC # Build nvfuser with UCC support. You may need to specify environment variables of UCC_HOME, UCC_DIR, UCX_HOME, UCX_DIR. # +# NVFUSER_BUILD_WITH_NIXL +# Build nvfuser with NIXL support. You may need to set NIXL_PREFIX to the NIXL install directory. 
+# # NVFUSER_BUILD_WITHOUT_DISTRIBUTED # Build nvfuser without multidevice support # diff --git a/python/utils.py b/python/utils.py index 272d347c23e..303220ca867 100644 --- a/python/utils.py +++ b/python/utils.py @@ -22,6 +22,7 @@ class BuildConfig: no_benchmark: bool = False no_ninja: bool = False build_with_ucc: bool = False + build_with_nixl: bool = False build_with_asan: bool = False build_without_distributed: bool = False explicit_error_check: bool = False @@ -98,6 +99,12 @@ def parse_args(): action="store_true", help="Build nvfuser with UCC support", ) + parser.add_argument( + "--build-with-nixl", + dest="build_with_nixl", + action="store_true", + help="Build nvfuser with NIXL support", + ) parser.add_argument( "--explicit-error-check", dest="explicit_error_check", @@ -200,6 +207,7 @@ def create_build_config(): no_benchmark=args.no_benchmark, no_ninja=args.no_ninja, build_with_ucc=args.build_with_ucc, + build_with_nixl=args.build_with_nixl, build_with_asan=args.build_with_asan, build_without_distributed=args.build_without_distributed, explicit_error_check=args.explicit_error_check, @@ -245,6 +253,8 @@ def override_build_config_from_env(config): config.no_ninja = get_env_flag_bool("NVFUSER_BUILD_NO_NINJA") if "NVFUSER_BUILD_WITH_UCC" in os.environ: config.build_with_ucc = get_env_flag_bool("NVFUSER_BUILD_WITH_UCC") + if "NVFUSER_BUILD_WITH_NIXL" in os.environ: + config.build_with_nixl = get_env_flag_bool("NVFUSER_BUILD_WITH_NIXL") if "NVFUSER_BUILD_WITH_ASAN" in os.environ: config.build_with_asan = get_env_flag_bool("NVFUSER_BUILD_WITH_ASAN") if "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" in os.environ: @@ -442,7 +452,11 @@ def cmake(config, relative_path): logger_level = logger.getEffectiveLevel() logger.setLevel(logging.CRITICAL) - pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + get_pytorch_cmake_prefix() + cmake_prefix_path = get_pytorch_cmake_prefix() + llvm_dir = os.environ.get("LLVM_DIR") + if llvm_dir: + cmake_prefix_path += ";" + llvm_dir + 
pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path logger.setLevel(logger_level) @@ -469,6 +483,7 @@ def on_or_off(flag: bool) -> str: f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", f"-DNVFUSER_BUILD_WITH_ASAN={on_or_off(config.build_with_asan)}", f"-DNVFUSER_STANDALONE_BUILD_WITH_UCC={on_or_off(config.build_with_ucc)}", + f"-DNVFUSER_STANDALONE_BUILD_WITH_NIXL={on_or_off(config.build_with_nixl)}", f"-DNVFUSER_EXPLICIT_ERROR_CHECK={on_or_off(config.explicit_error_check)}", f"-DBUILD_TEST={on_or_off(not config.no_test)}", f"-DBUILD_PYTHON={on_or_off(not config.no_python)}", @@ -480,6 +495,25 @@ def on_or_off(flag: bool) -> str: "-B", cmake_build_dir, ] + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home: + cmd_str.append(f"-DCUDA_TOOLKIT_ROOT_DIR={cuda_home}") + nvcc_path = os.path.join(cuda_home, "bin", "nvcc") + if os.path.isfile(nvcc_path): + cmd_str.append(f"-DCMAKE_CUDA_COMPILER={nvcc_path}") + cudahostcxx = os.environ.get("CUDAHOSTCXX") + if cudahostcxx: + resolved = shutil.which(cudahostcxx) or cudahostcxx + cmd_str.append(f"-DCMAKE_CUDA_HOST_COMPILER={resolved}") + os.environ["CUDAHOSTCXX"] = resolved + cc = os.environ.get("CC") + if cc: + resolved = shutil.which(cc) or cc + cmd_str.append(f"-DCMAKE_C_COMPILER={resolved}") + cxx = os.environ.get("CXX") + if cxx: + resolved = shutil.which(cxx) or cxx + cmd_str.append(f"-DCMAKE_CXX_COMPILER={resolved}") if config.cutlass_max_jobs: cmd_str.append(f"-DCUTLASS_MAX_JOBS={config.cutlass_max_jobs}") if config.nvmmh_include_dir: From 6144827540126d54452e8fd71e15341a3beb6db1 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:44:56 +0200 Subject: [PATCH 07/42] fix typo --- python/utils.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/python/utils.py b/python/utils.py index 303220ca867..908433ec4cb 100644 --- a/python/utils.py +++ b/python/utils.py @@ -495,25 +495,6 @@ def on_or_off(flag: bool) -> str: "-B", cmake_build_dir, ] 
- cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") - if cuda_home: - cmd_str.append(f"-DCUDA_TOOLKIT_ROOT_DIR={cuda_home}") - nvcc_path = os.path.join(cuda_home, "bin", "nvcc") - if os.path.isfile(nvcc_path): - cmd_str.append(f"-DCMAKE_CUDA_COMPILER={nvcc_path}") - cudahostcxx = os.environ.get("CUDAHOSTCXX") - if cudahostcxx: - resolved = shutil.which(cudahostcxx) or cudahostcxx - cmd_str.append(f"-DCMAKE_CUDA_HOST_COMPILER={resolved}") - os.environ["CUDAHOSTCXX"] = resolved - cc = os.environ.get("CC") - if cc: - resolved = shutil.which(cc) or cc - cmd_str.append(f"-DCMAKE_C_COMPILER={resolved}") - cxx = os.environ.get("CXX") - if cxx: - resolved = shutil.which(cxx) or cxx - cmd_str.append(f"-DCMAKE_CXX_COMPILER={resolved}") if config.cutlass_max_jobs: cmd_str.append(f"-DCUTLASS_MAX_JOBS={config.cutlass_max_jobs}") if config.nvmmh_include_dir: From b32587a19354c9b3f36b0048e9eaf8c0d1ba9f31 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 10:52:45 +0200 Subject: [PATCH 08/42] restore main: --- csrc/multidevice/cuda_p2p.h | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/multidevice/cuda_p2p.h b/csrc/multidevice/cuda_p2p.h index 38ae6c549fc..514195c0746 100644 --- a/csrc/multidevice/cuda_p2p.h +++ b/csrc/multidevice/cuda_p2p.h @@ -10,10 +10,6 @@ #include #include -#include -#include -#include - #include "multidevice/ipc_handle.h" namespace nvfuser { From f8a94fcae21d909d4935306e4e22ff1e5af1d4a0 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 18:04:23 +0200 Subject: [PATCH 09/42] fix bug where zero-length buffer was passed to nixl --- csrc/multidevice/communicator.h | 2 +- csrc/multidevice/nixl.cpp | 44 +++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index c4a1eb3d09b..127276f6cb4 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -131,7 +131,7 @@ class 
NVF_API Communicator { c10d::TCPStore* getTcpStore() { return store_.get(); - } +} private: diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index fa0ac7c7a94..f71c70b7f4a 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -67,7 +68,7 @@ void validateCudaTensors(const std::vector& tensors) { #ifdef USE_NIXL nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { - nixl_reg_dlist_t dlist(VRAM_SEG, tensors.size()); + nixl_reg_dlist_t dlist(VRAM_SEG); for (const auto& t : tensors) { dlist.addDesc( {reinterpret_cast(t.data_ptr()), @@ -78,7 +79,7 @@ nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { } nixl_xfer_dlist_t buildXferDlist(const std::vector& descs) { - nixl_xfer_dlist_t dlist(VRAM_SEG, descs.size()); + nixl_xfer_dlist_t dlist(VRAM_SEG); for (const auto& desc : descs) { dlist.addDesc({desc.addr, desc.size, desc.dev}); } @@ -163,30 +164,47 @@ NixlBackend::Impl::Impl(Communicator& communicator) // local descriptor list for VRAM -- if no backend claims VRAM, the // probe fails and we mark the backend as unavailable. 
{ + constexpr int64_t kProbeBytes = 64; auto probe = at::empty( - {1}, + {kProbeBytes}, at::TensorOptions().dtype(at::kByte).device( at::kCUDA, communicator_.deviceId())); - nixl_reg_dlist_t reg_dlist(VRAM_SEG, 1); - reg_dlist.addDesc( - {reinterpret_cast(probe.data_ptr()), - probe.nbytes(), - static_cast(probe.device().index())}); + size_t nbytes = static_cast(probe.nbytes()); + uintptr_t addr = reinterpret_cast(probe.data_ptr()); + uint32_t dev_idx = static_cast(probe.device().index()); + + std::cerr << "[NixlBackend probe] device=" << dev_idx + << " addr=0x" << std::hex << addr << std::dec + << " nbytes=" << nbytes + << " numel=" << probe.numel() + << " element_size=" << probe.element_size() << std::endl; + + NVF_ERROR(nbytes > 0, "NIXL probe: unexpected zero-byte tensor"); + NVF_ERROR(addr != 0, "NIXL probe: null data pointer"); + + nixl_reg_dlist_t reg_dlist(VRAM_SEG); + reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); + + std::cerr << "[NixlBackend probe] reg_dlist desc: addr=0x" << std::hex + << reg_dlist[0].addr << std::dec + << " len=" << reg_dlist[0].len + << " devId=" << reg_dlist[0].devId << std::endl; nixl_status_t reg_status = agent_->registerMem(reg_dlist); + std::cerr << "[NixlBackend probe] registerMem returned " + << reg_status << std::endl; if (reg_status != NIXL_SUCCESS) { return; } - nixl_xfer_dlist_t xfer_dlist(VRAM_SEG, 1); - xfer_dlist.addDesc( - {reinterpret_cast(probe.data_ptr()), - probe.nbytes(), - static_cast(probe.device().index())}); + nixl_xfer_dlist_t xfer_dlist(VRAM_SEG); + xfer_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); nixlDlistH* dlist_handle = nullptr; nixl_status_t prep_status = agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + std::cerr << "[NixlBackend probe] prepXferDlist returned " + << prep_status << std::endl; if (dlist_handle) { agent_->releasedDlistH(dlist_handle); From a6b6f870737b5e9348d75cf312dd2bab6edd69c1 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Thu, 26 Feb 2026 
18:13:55 +0200 Subject: [PATCH 10/42] Reduce probe size to 1 --- csrc/multidevice/nixl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index f71c70b7f4a..37881137e40 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -164,7 +164,7 @@ NixlBackend::Impl::Impl(Communicator& communicator) // local descriptor list for VRAM -- if no backend claims VRAM, the // probe fails and we mark the backend as unavailable. { - constexpr int64_t kProbeBytes = 64; + constexpr int64_t kProbeBytes = 1; auto probe = at::empty( {kProbeBytes}, at::TensorOptions().dtype(at::kByte).device( From 95460af2cd1857ffa3e5a8dea7d3892c37dcdfbc Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 1 Mar 2026 16:09:47 +0200 Subject: [PATCH 11/42] Address PR comments. --- csrc/multidevice/communicator.cpp | 7 ++++++- csrc/multidevice/communicator.h | 7 ++----- csrc/multidevice/nixl.cpp | 26 +++++++++----------------- csrc/multidevice/nixl.h | 11 ----------- 4 files changed, 17 insertions(+), 34 deletions(-) diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index e4c1c1cc584..c021a129670 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -186,7 +186,8 @@ Communicator::Communicator( master_port_( c10d::TCPStoreOptions::kDefaultPort + 42), // to avoid collision ucc_available_(false), - nccl_available_(false) { + nccl_available_(false), + nixl_available_(false) { if (isOptionDisabled(DisableOption::Multidevice)) { TORCH_WARN( "Multi-device support is disabled. 
All communication operations will " @@ -239,6 +240,10 @@ Communicator::Communicator( #ifdef USE_C10D_NCCL nccl_available_ = true; #endif + +#ifdef USE_NIXL + nixl_available_ = true; +#endif } namespace { diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 127276f6cb4..f54d0535434 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -120,11 +120,7 @@ class NVF_API Communicator { } else if (backend == CommunicatorBackend::kNccl) { return nccl_available_; } else if (backend == CommunicatorBackend::kNixl) { -#ifdef USE_NIXL - return true; -#else - return false; -#endif + return nixl_available_; } return false; } @@ -159,6 +155,7 @@ class NVF_API Communicator { int master_port_; bool ucc_available_; bool nccl_available_; + bool nixl_available_; // stores the world's store used for the backend init c10::intrusive_ptr store_; // cache for the created backends. The keys are strings generated from Teams diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 37881137e40..06d4939b5fd 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -26,8 +26,15 @@ namespace nvfuser { class NixlTransferHandleImpl { public: #ifdef USE_NIXL - // TODO - is it leaking when handleimpl is destroyed ? 
+ explicit NixlTransferHandleImpl(nixlAgent* agent) : agent(agent) {} + nixlAgent* agent; nixlXferReqH* xfer_handle = nullptr; + + ~NixlTransferHandleImpl() { + if (xfer_handle) { + agent->releaseXferReq(xfer_handle); + } + } #endif bool prepared = false; bool posted = false; @@ -173,26 +180,13 @@ NixlBackend::Impl::Impl(Communicator& communicator) uintptr_t addr = reinterpret_cast(probe.data_ptr()); uint32_t dev_idx = static_cast(probe.device().index()); - std::cerr << "[NixlBackend probe] device=" << dev_idx - << " addr=0x" << std::hex << addr << std::dec - << " nbytes=" << nbytes - << " numel=" << probe.numel() - << " element_size=" << probe.element_size() << std::endl; - NVF_ERROR(nbytes > 0, "NIXL probe: unexpected zero-byte tensor"); NVF_ERROR(addr != 0, "NIXL probe: null data pointer"); nixl_reg_dlist_t reg_dlist(VRAM_SEG); reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); - std::cerr << "[NixlBackend probe] reg_dlist desc: addr=0x" << std::hex - << reg_dlist[0].addr << std::dec - << " len=" << reg_dlist[0].len - << " devId=" << reg_dlist[0].devId << std::endl; - nixl_status_t reg_status = agent_->registerMem(reg_dlist); - std::cerr << "[NixlBackend probe] registerMem returned " - << reg_status << std::endl; if (reg_status != NIXL_SUCCESS) { return; } @@ -203,8 +197,6 @@ NixlBackend::Impl::Impl(Communicator& communicator) nixlDlistH* dlist_handle = nullptr; nixl_status_t prep_status = agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); - std::cerr << "[NixlBackend probe] prepXferDlist returned " - << prep_status << std::endl; if (dlist_handle) { agent_->releasedDlistH(dlist_handle); @@ -358,7 +350,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); - auto impl = std::make_unique(); + auto impl = std::make_unique(agent_.get()); nixl_status_t status = agent_->createXferReq( toNixlXferOp(op), local_dlist, diff 
--git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 1cc5b84a7b8..e9de9c8384b 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -52,17 +52,6 @@ inline TensorDesc toTensorDesc(const at::Tensor& tensor) { }; } -inline at::Tensor fromTensorDesc(const TensorDesc& desc) { - /* - Tensors must be valid on this device - */ - return at::from_blob( - reinterpret_cast(desc.addr), - {static_cast(desc.size)}, - at::TensorOptions().device(at::Device(at::kCUDA, desc.dev)).dtype(at::kByte) - ); -} - inline std::vector serializeTensorsDescs( const std::vector& descs) { size_t count = descs.size(); From 41ec0ac51decd5d80056dc57a04d840746db888a Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 11:56:41 +0200 Subject: [PATCH 12/42] typos --- csrc/multidevice/communicator.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index f54d0535434..25fed9eebfc 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -11,7 +11,6 @@ #include #include -#include #include #ifdef NVFUSER_DISTRIBUTED @@ -129,7 +128,6 @@ class NVF_API Communicator { return store_.get(); } - private: Communicator( CommunicatorBackend backend = comm_backend_default, From d63ffd7cd02d800c930b0880c3abb7cedbd6655d Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 12:01:32 +0200 Subject: [PATCH 13/42] set getAgentName to inline --- csrc/multidevice/nixl.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 06d4939b5fd..6cbfd737121 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -135,7 +135,7 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: - std::string constructAgentName(int64_t rank); + inline std::string getAgentName(int64_t rank); #ifdef USE_NIXL std::unique_ptr agent_; @@ -153,7 +153,7 @@ class NixlBackend::Impl { 
NixlBackend::Impl::Impl(Communicator& communicator) : communicator_(communicator) { #ifdef USE_NIXL - std::string agent_name = constructAgentName(communicator_.deviceId()); + std::string agent_name = getAgentName(communicator_.deviceId()); nixlAgentConfig cfg(false); agent_ = std::make_unique(agent_name, cfg); @@ -214,7 +214,7 @@ NixlBackend::Impl::Impl(Communicator& communicator) NixlBackend::Impl::~Impl() = default; -std::string NixlBackend::Impl::constructAgentName(int64_t rank){ +std::string NixlBackend::Impl::getAgentName(int64_t rank){ return "rank_" + std::to_string(rank); } @@ -345,7 +345,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( " vs ", remote_descs.size()); - std::string remote_agent_name = constructAgentName(remote_rank); + std::string remote_agent_name = getAgentName(remote_rank); nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); From 86e50288cc42b754656fd90a80627341d1cf4066 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 13:41:15 +0200 Subject: [PATCH 14/42] fix comments in nixl.cpp --- csrc/multidevice/nixl.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 6cbfd737121..7a3b1b173d5 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -325,13 +325,9 @@ void NixlBackend::Impl::exchangeMetadata() { // kRead -- data flows from remote into local // kWrite -- data flows from local into remote // -// remote_tensors are LOCAL tensors whose data_ptr identifies the -// corresponding registration slot. The actual remote addresses are -// looked up from the descriptors exchanged during exchangeMetadata(). -// This requires all ranks to register tensors in the same order. 
NixlTransferHandle NixlBackend::Impl::prepareTransfer( const std::vector& local_descs, // Local addresses - const std::vector& remote_descs, // Remote tensors (not valid on this rank) + const std::vector& remote_descs, // Remote tensors (cannot be dereferenced on this rank) int64_t remote_rank, NixlXferOp op) { NixlTransferHandle handle; From 7283aa89823f4dedc969cf3e1a742fb25e0069d0 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 4 Mar 2026 14:23:42 +0200 Subject: [PATCH 15/42] clean ifdef USE_NIXL statements --- csrc/multidevice/nixl.cpp | 118 +++++++++++++------------------------- 1 file changed, 41 insertions(+), 77 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 7a3b1b173d5..253298bc9bc 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -111,21 +111,19 @@ nixl_xfer_op_t toNixlXferOp(NixlXferOp op) { // NixlBackend::Impl // =================================================================== +#ifdef USE_NIXL + class NixlBackend::Impl { public: - explicit Impl(Communicator& communicator); + static std::unique_ptr create(Communicator& communicator); ~Impl(); - bool isAvailable() const { - return available_; - } - void registerTensors(const std::vector& tensors); void deregisterTensors(const std::vector& tensors); void exchangeMetadata(); NixlTransferHandle prepareTransfer( - const std::vector& local_descs, + const std::vector& local_descs, const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op); @@ -135,14 +133,12 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: + explicit Impl(Communicator& communicator); inline std::string getAgentName(int64_t rank); -#ifdef USE_NIXL std::unique_ptr agent_; nixlBackendH* backend_ = nullptr; -#endif Communicator& communicator_; - bool available_ = false; bool metadata_exchanged_ = false; }; @@ -151,16 +147,21 @@ class NixlBackend::Impl { // ------------------------------------------------------------------- 
NixlBackend::Impl::Impl(Communicator& communicator) - : communicator_(communicator) { -#ifdef USE_NIXL - std::string agent_name = getAgentName(communicator_.deviceId()); + : communicator_(communicator) {} + +std::unique_ptr NixlBackend::Impl::create( + Communicator& communicator) { + std::unique_ptr impl(new Impl(communicator)); + + std::string agent_name = impl->getAgentName(communicator.deviceId()); nixlAgentConfig cfg(false); - agent_ = std::make_unique(agent_name, cfg); + impl->agent_ = std::make_unique(agent_name, cfg); nixl_b_params_t params; - nixl_status_t status = agent_->createBackend("UCX", params, backend_); + nixl_status_t status = + impl->agent_->createBackend("UCX", params, impl->backend_); if (status != NIXL_SUCCESS) { - agent_.reset(); + impl->agent_.reset(); NVF_THROW("Failed to create UCX backend for NIXL agent"); } @@ -175,7 +176,7 @@ NixlBackend::Impl::Impl(Communicator& communicator) auto probe = at::empty( {kProbeBytes}, at::TensorOptions().dtype(at::kByte).device( - at::kCUDA, communicator_.deviceId())); + at::kCUDA, communicator.deviceId())); size_t nbytes = static_cast(probe.nbytes()); uintptr_t addr = reinterpret_cast(probe.data_ptr()); uint32_t dev_idx = static_cast(probe.device().index()); @@ -186,9 +187,9 @@ NixlBackend::Impl::Impl(Communicator& communicator) nixl_reg_dlist_t reg_dlist(VRAM_SEG); reg_dlist.addDesc({addr, nbytes, static_cast(dev_idx)}); - nixl_status_t reg_status = agent_->registerMem(reg_dlist); + nixl_status_t reg_status = impl->agent_->registerMem(reg_dlist); if (reg_status != NIXL_SUCCESS) { - return; + return nullptr; } nixl_xfer_dlist_t xfer_dlist(VRAM_SEG); @@ -196,25 +197,24 @@ NixlBackend::Impl::Impl(Communicator& communicator) nixlDlistH* dlist_handle = nullptr; nixl_status_t prep_status = - agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); + impl->agent_->prepXferDlist(NIXL_INIT_AGENT, xfer_dlist, dlist_handle); if (dlist_handle) { - agent_->releasedDlistH(dlist_handle); + 
impl->agent_->releasedDlistH(dlist_handle); } - agent_->deregisterMem(reg_dlist); + impl->agent_->deregisterMem(reg_dlist); if (prep_status != NIXL_SUCCESS) { - return; + return nullptr; } } - available_ = true; -#endif + return impl; } NixlBackend::Impl::~Impl() = default; -std::string NixlBackend::Impl::getAgentName(int64_t rank){ +std::string NixlBackend::Impl::getAgentName(int64_t rank) { return "rank_" + std::to_string(rank); } @@ -225,8 +225,6 @@ std::string NixlBackend::Impl::getAgentName(int64_t rank){ // TODO - consider adding RAII wrapper void NixlBackend::Impl::registerTensors( const std::vector& tensors) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); validateCudaTensors(tensors); nixl_reg_dlist_t dlist = buildRegDlist(tensors); @@ -237,16 +235,10 @@ void NixlBackend::Impl::registerTensors( static_cast(status)); metadata_exchanged_ = false; -#else - (void)tensors; - NVF_THROW("NIXL support not compiled"); -#endif } void NixlBackend::Impl::deregisterTensors( const std::vector& tensors) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); validateCudaTensors(tensors); nixl_reg_dlist_t dlist = buildRegDlist(tensors); @@ -257,10 +249,6 @@ void NixlBackend::Impl::deregisterTensors( static_cast(status)); metadata_exchanged_ = false; -#else - (void)tensors; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } // ------------------------------------------------------------------- @@ -268,9 +256,6 @@ void NixlBackend::Impl::deregisterTensors( // ------------------------------------------------------------------- void NixlBackend::Impl::exchangeMetadata() { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); - nixl_blob_t local_md; nixl_status_t md_status = agent_->getLocalMD(local_md); NVF_ERROR( @@ -291,7 +276,7 @@ void NixlBackend::Impl::exchangeMetadata() { if (rank == my_rank) { continue; } - // Fetch & load MD + // Fetch & load MD auto bytes = 
store->get(md_key_prefix + std::to_string(rank)); nixl_blob_t remote_md(bytes.begin(), bytes.end()); std::string remote_agent_name; @@ -309,9 +294,6 @@ void NixlBackend::Impl::exchangeMetadata() { store->deleteKey(md_key_prefix + std::to_string(my_rank)); metadata_exchanged_ = true; -#else - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } // ------------------------------------------------------------------- @@ -326,13 +308,10 @@ void NixlBackend::Impl::exchangeMetadata() { // kWrite -- data flows from local into remote // NixlTransferHandle NixlBackend::Impl::prepareTransfer( - const std::vector& local_descs, // Local addresses - const std::vector& remote_descs, // Remote tensors (cannot be dereferenced on this rank) + const std::vector& local_descs, + const std::vector& remote_descs, int64_t remote_rank, NixlXferOp op) { - NixlTransferHandle handle; -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); NVF_ERROR( local_descs.size() == remote_descs.size(), @@ -359,14 +338,8 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( static_cast(status)); impl->prepared = true; + NixlTransferHandle handle; handle.impl_ = std::move(impl); -#else - (void)local_descs; - (void)remote_descs; - (void)remote_rank; - (void)op; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif return handle; } @@ -375,8 +348,6 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( // ------------------------------------------------------------------- void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(handle.isValid(), "Cannot post an invalid transfer handle"); NVF_ERROR( !handle.impl_->posted, @@ -389,10 +360,6 @@ void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { static_cast(status)); handle.impl_->posted = true; -#else - 
(void)handle; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } // ------------------------------------------------------------------- @@ -401,8 +368,6 @@ void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { NixlXferStatus NixlBackend::Impl::getTransferStatus( const NixlTransferHandle& handle) const { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(handle.isValid(), "Cannot query status of an invalid handle"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); @@ -415,15 +380,9 @@ NixlXferStatus NixlBackend::Impl::getTransferStatus( default: return NixlXferStatus::kError; } -#else - (void)handle; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { -#ifdef USE_NIXL - NVF_ERROR(available_, "NIXL backend is not available"); NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); @@ -437,18 +396,23 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { } while (xfer_status == NixlXferStatus::kInProgress); handle.impl_->posted = false; -#else - (void)handle; - NVF_THROW("NIXL support not compiled (USE_NIXL not defined)"); -#endif } +#else // !USE_NIXL + +class NixlBackend::Impl {}; + +#endif // USE_NIXL + // =================================================================== // NixlBackend singleton + public API // =================================================================== -NixlBackend::NixlBackend() - : impl_(std::make_unique(Communicator::getInstance())) {} +NixlBackend::NixlBackend() { +#ifdef USE_NIXL + impl_ = Impl::create(Communicator::getInstance()); +#endif +} NixlBackend& NixlBackend::getInstance() { static auto* instance = new NixlBackend(); @@ -462,7 +426,7 @@ void NixlBackend::cleanup() { } bool NixlBackend::isAvailable() const { - return impl_ && impl_->isAvailable(); + return 
impl_ != nullptr; } void NixlBackend::registerTensors(const std::vector& tensors) { From a085c549d97252419489e38dc69920e5077e5fc5 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 8 Mar 2026 11:58:08 +0200 Subject: [PATCH 16/42] inline exchangeMetadata inside registerTensors --- csrc/multidevice/nixl.cpp | 9 +++------ csrc/multidevice/nixl.h | 19 ++++++------------- tests/cpp/test_multidevice_nixl.cpp | 24 +----------------------- 3 files changed, 10 insertions(+), 42 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 253298bc9bc..d3688fb105f 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -120,7 +120,6 @@ class NixlBackend::Impl { void registerTensors(const std::vector& tensors); void deregisterTensors(const std::vector& tensors); - void exchangeMetadata(); NixlTransferHandle prepareTransfer( const std::vector& local_descs, @@ -133,6 +132,7 @@ class NixlBackend::Impl { void waitTransfer(NixlTransferHandle& handle); private: + void exchangeMetadata(); explicit Impl(Communicator& communicator); inline std::string getAgentName(int64_t rank); @@ -235,6 +235,7 @@ void NixlBackend::Impl::registerTensors( static_cast(status)); metadata_exchanged_ = false; + exchangeMetadata(); } void NixlBackend::Impl::deregisterTensors( @@ -249,6 +250,7 @@ void NixlBackend::Impl::deregisterTensors( static_cast(status)); metadata_exchanged_ = false; + exchangeMetadata(); } // ------------------------------------------------------------------- @@ -439,11 +441,6 @@ void NixlBackend::deregisterTensors(const std::vector& tensors) { impl_->deregisterTensors(tensors); } -void NixlBackend::exchangeMetadata() { - NVF_CHECK(isAvailable(), "NIXL backend is not available"); - impl_->exchangeMetadata(); -} - NixlTransferHandle NixlBackend::prepareTransfer( const std::vector& local_descs, const std::vector& remote_descs, diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index e9de9c8384b..b6df00f4c25 100644 --- 
a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -142,11 +142,10 @@ class NVF_API NixlTransferHandle { // // Lifecycle: // 1. getInstance() - creates agent, loads UCX backend -// 2. registerTensors() - register GPU tensors for RDMA access -// 3. exchangeMetadata() - all ranks share their registration info -// 4. prepareTransfer() - expensive one-time setup per transfer pattern -// 5. postTransfer() - cheap, non-blocking data movement -// 6. waitTransfer() - block until complete +// 2. registerTensors() - register GPU tensors and exchange metadata (collective) +// 3. prepareTransfer() - expensive one-time setup per transfer pattern +// 4. postTransfer() - cheap, non-blocking data movement +// 5. waitTransfer() - block until complete // // Thread safety: methods are NOT thread-safe. The caller must // synchronize if the same NixlBackend is used from multiple threads. @@ -171,18 +170,12 @@ class NixlBackend { // Register CUDA tensors with the NIXL agent so they can participate // in RDMA transfers. Tensors must be contiguous and remain alive // until deregisterTensors() is called. + // Both methods are collective: they exchange agent metadata with all + // peers through the TCPStore, so all ranks must call them together. void registerTensors(const std::vector& tensors); void deregisterTensors(const std::vector& tensors); - // ------------------------------------------------------------------ - // Metadata exchange - // ------------------------------------------------------------------ - // Exchange local agent metadata with all peers through the TCPStore. - // Must be called after registerTensors() and before prepareTransfer() - // whenever the set of registered tensors changes. 
- void exchangeMetadata(); - // ------------------------------------------------------------------ // Transfer lifecycle // ------------------------------------------------------------------ diff --git a/tests/cpp/test_multidevice_nixl.cpp b/tests/cpp/test_multidevice_nixl.cpp index eb8de2ba3b8..f9257a00146 100644 --- a/tests/cpp/test_multidevice_nixl.cpp +++ b/tests/cpp/test_multidevice_nixl.cpp @@ -99,25 +99,6 @@ TEST_F(NixlTest, DeregisterEmptyTensorListThrows) { // Transfer preparation validation // ------------------------------------------------------------------- -TEST_F(NixlTest, PrepareTransferWithoutMetadataExchangeThrows) { - NixlBackend& backend = NixlBackend::getInstance(); - if (!backend.isAvailable()) { - GTEST_SKIP() << "NIXL backend not available"; - } - - auto local = at::randn({64}, tensor_options_); - auto remote = at::randn({64}, tensor_options_); - backend.registerTensors({local}); - backend.registerTensors({remote}); - - EXPECT_THROW( - (void)backend.prepareTransfer({toTensorDesc(local)}, {toTensorDesc(remote)}, 0, NixlXferOp::kRead), - nvfError); - - backend.deregisterTensors({local}); - backend.deregisterTensors({remote}); -} - TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) { NixlBackend& backend = NixlBackend::getInstance(); if (!backend.isAvailable()) { @@ -128,7 +109,6 @@ TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) { auto t2 = at::randn({64}, tensor_options_); auto t3 = at::randn({64}, tensor_options_); backend.registerTensors({t1, t2, t3}); - backend.exchangeMetadata(); EXPECT_THROW( (void)backend.prepareTransfer({toTensorDesc(t1), toTensorDesc(t2)}, {toTensorDesc(t3)}, 0, NixlXferOp::kRead), nvfError); @@ -194,8 +174,7 @@ TEST_F(NixlTest, ReadTransferEndToEnd) { cudaDeviceSynchronize(); backend.registerTensors({src, dst}); - backend.exchangeMetadata(); - + // Fetch the remote tensor descriptor from the peer std::string src_key_prefix = "nixl_test_read_transfer_src_rank_"; storeTensorDescs(*communicator_, 
src_key_prefix + std::to_string(rank), {src}); @@ -240,7 +219,6 @@ TEST_F(NixlTest, WriteTransferEndToEnd) { cudaDeviceSynchronize(); backend.registerTensors({src, dst}); - backend.exchangeMetadata(); // Fetch the remote tensor descriptor from the peer std::string dst_key_prefix = "nixl_test_write_transfer_dst_rank_"; From 13ae58f93227a8d3d3cd29be98ee022d2167ab5d Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 8 Mar 2026 12:31:59 +0200 Subject: [PATCH 17/42] include deviceId (rank) inside TensorDesc --- csrc/multidevice/nixl.cpp | 2 +- csrc/multidevice/nixl.h | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index d3688fb105f..708da13ab54 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -134,7 +134,7 @@ class NixlBackend::Impl { private: void exchangeMetadata(); explicit Impl(Communicator& communicator); - inline std::string getAgentName(int64_t rank); + inline std::string getAgentName(int64_t device_id); std::unique_ptr agent_; nixlBackendH* backend_ = nullptr; diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index b6df00f4c25..89f15397d62 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -39,16 +39,16 @@ enum class NixlXferStatus { struct TensorDesc { uintptr_t addr; size_t size; - uint32_t dev; + uint32_t dev; // deviceId (rank) owning this tensor }; static_assert(std::is_trivially_copyable_v, "TensorDesc must be trivially copyable for serialization"); -inline TensorDesc toTensorDesc(const at::Tensor& tensor) { +inline TensorDesc toTensorDesc(const at::Tensor& tensor, int64_t device_id) { return { .addr = reinterpret_cast(tensor.data_ptr()), .size = static_cast(tensor.numel()) * tensor.element_size(), - .dev = static_cast(tensor.device().index()) + .dev = static_cast(device_id) }; } @@ -95,7 +95,7 @@ inline void storeTensorDescs(Communicator& communicator, const std::string& key, std::vector descs; 
descs.reserve(tensors.size()); for (const auto& tensor : tensors) { - descs.push_back(toTensorDesc(tensor)); + descs.push_back(toTensorDesc(tensor, communicator.deviceId())); } storeTensorDescs(communicator, key, descs); } @@ -188,7 +188,6 @@ class NixlBackend { [[nodiscard]] NixlTransferHandle prepareTransfer( const std::vector& local_descs, const std::vector& remote_descs, - int64_t remote_rank, NixlXferOp op); // Post a previously prepared transfer for execution (non-blocking). From 149c15a7fa2a35c6886a11102dcae5e083ca5b21 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 8 Mar 2026 12:47:16 +0200 Subject: [PATCH 18/42] remove useless handleImpl.isPrepared --- csrc/multidevice/nixl.cpp | 19 +++---------------- csrc/multidevice/nixl.h | 2 -- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 708da13ab54..dccf3f2062c 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -36,7 +36,6 @@ class NixlTransferHandleImpl { } } #endif - bool prepared = false; bool posted = false; }; @@ -47,17 +46,6 @@ NixlTransferHandle::NixlTransferHandle(NixlTransferHandle&&) noexcept = NixlTransferHandle& NixlTransferHandle::operator=( NixlTransferHandle&&) noexcept = default; -bool NixlTransferHandle::isValid() const { - if (!impl_) { - return false; - } -#ifdef USE_NIXL - return impl_->prepared; -#else - return false; -#endif -} - // =================================================================== // Tensor validation and descriptor helpers // =================================================================== @@ -339,7 +327,6 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( "NIXL createXferReq failed with status ", static_cast(status)); - impl->prepared = true; NixlTransferHandle handle; handle.impl_ = std::move(impl); return handle; @@ -350,7 +337,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( // ------------------------------------------------------------------- 
void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { - NVF_ERROR(handle.isValid(), "Cannot post an invalid transfer handle"); + NVF_ERROR(handle.impl_, "Transfer handle is empty - was it moved from?"); NVF_ERROR( !handle.impl_->posted, "Transfer already posted. Wait for completion before re-posting."); @@ -370,7 +357,7 @@ void NixlBackend::Impl::postTransfer(NixlTransferHandle& handle) { NixlXferStatus NixlBackend::Impl::getTransferStatus( const NixlTransferHandle& handle) const { - NVF_ERROR(handle.isValid(), "Cannot query status of an invalid handle"); + NVF_ERROR(handle.impl_, "Transfer handle is empty - was it moved from?"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); nixl_status_t status = agent_->getXferStatus(handle.impl_->xfer_handle); @@ -385,7 +372,7 @@ NixlXferStatus NixlBackend::Impl::getTransferStatus( } void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { - NVF_ERROR(handle.isValid(), "Cannot wait on an invalid handle"); + NVF_ERROR(handle.impl_, "Transfer handle is empty - was it moved from?"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); // TODO - check this spin loop diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 89f15397d62..fdb86211e52 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -127,8 +127,6 @@ class NVF_API NixlTransferHandle { NixlTransferHandle(const NixlTransferHandle&) = delete; NixlTransferHandle& operator=(const NixlTransferHandle&) = delete; - [[nodiscard]] bool isValid() const; - private: friend class NixlBackend; std::unique_ptr impl_; From 1b4178894a15d5da43d4a5fa62c442cd243af4dd Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 8 Mar 2026 15:13:02 +0200 Subject: [PATCH 19/42] add thread yield in wait transfer loop --- csrc/multidevice/nixl.cpp | 4 +++- csrc/multidevice/nixl.h | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 
dccf3f2062c..657cf8deb73 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -375,13 +375,15 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { NVF_ERROR(handle.impl_, "Transfer handle is empty - was it moved from?"); NVF_ERROR(handle.impl_->posted, "Transfer has not been posted yet"); - // TODO - check this spin loop NixlXferStatus xfer_status; do { xfer_status = getTransferStatus(handle); NVF_ERROR( xfer_status != NixlXferStatus::kError, "NIXL transfer completed with an error"); + if (xfer_status == NixlXferStatus::kInProgress) { + std::this_thread::yield(); + } } while (xfer_status == NixlXferStatus::kInProgress); handle.impl_->posted = false; diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index fdb86211e52..077a08f53c2 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -169,7 +169,7 @@ class NixlBackend { // in RDMA transfers. Tensors must be contiguous and remain alive // until deregisterTensors() is called. // Both methods are collective: they exchange agent metadata with all - // peers through the TCPStore, so all ranks must call them together. + // peers through the TCPStore, so all ranks must call them together and in the same order. 
void registerTensors(const std::vector& tensors); void deregisterTensors(const std::vector& tensors); From 2eccaa59ece4f4b336e4c12342d797c0dfb7d1a0 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 8 Mar 2026 15:22:54 +0200 Subject: [PATCH 20/42] remove remote_rank from prepare transfer --- csrc/multidevice/communicator.h | 2 +- csrc/multidevice/nixl.cpp | 38 +++++++++++++++++++++++------ tests/cpp/test_multidevice_nixl.cpp | 35 +++----------------------- 3 files changed, 35 insertions(+), 40 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 25fed9eebfc..61f76783d12 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -126,7 +126,7 @@ class NVF_API Communicator { c10d::TCPStore* getTcpStore() { return store_.get(); -} + } private: Communicator( diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 657cf8deb73..1fe32769df1 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -112,7 +112,6 @@ class NixlBackend::Impl { NixlTransferHandle prepareTransfer( const std::vector& local_descs, const std::vector& remote_descs, - int64_t remote_rank, NixlXferOp op); void postTransfer(NixlTransferHandle& handle); @@ -266,7 +265,6 @@ void NixlBackend::Impl::exchangeMetadata() { if (rank == my_rank) { continue; } - // Fetch & load MD auto bytes = store->get(md_key_prefix + std::to_string(rank)); nixl_blob_t remote_md(bytes.begin(), bytes.end()); std::string remote_agent_name; @@ -300,9 +298,11 @@ void NixlBackend::Impl::exchangeMetadata() { NixlTransferHandle NixlBackend::Impl::prepareTransfer( const std::vector& local_descs, const std::vector& remote_descs, - int64_t remote_rank, NixlXferOp op) { NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); + NVF_ERROR( + !remote_descs.empty(), + "remote_descs must not be empty"); NVF_ERROR( local_descs.size() == remote_descs.size(), "Local and remote tensor lists must have the same size. 
Got ", @@ -310,7 +310,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( " vs ", remote_descs.size()); - std::string remote_agent_name = getAgentName(remote_rank); + std::string remote_agent_name = getAgentName(remote_descs.at(0).dev); nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); @@ -391,7 +391,31 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { #else // !USE_NIXL -class NixlBackend::Impl {}; +class NixlBackend::Impl { + public: + static std::unique_ptr create(Communicator&) { return nullptr; } + void registerTensors(const std::vector&) { + NVF_THROW("NIXL not available"); + } + void deregisterTensors(const std::vector&) { + NVF_THROW("NIXL not available"); + } + NixlTransferHandle prepareTransfer( + const std::vector&, + const std::vector&, + NixlXferOp) { + NVF_THROW("NIXL not available"); + } + void postTransfer(NixlTransferHandle&) { + NVF_THROW("NIXL not available"); + } + NixlXferStatus getTransferStatus(const NixlTransferHandle&) const { + NVF_THROW("NIXL not available"); + } + void waitTransfer(NixlTransferHandle&) { + NVF_THROW("NIXL not available"); + } +}; #endif // USE_NIXL @@ -433,11 +457,9 @@ void NixlBackend::deregisterTensors(const std::vector& tensors) { NixlTransferHandle NixlBackend::prepareTransfer( const std::vector& local_descs, const std::vector& remote_descs, - int64_t remote_rank, NixlXferOp op) { NVF_CHECK(isAvailable(), "NIXL backend is not available"); - return impl_->prepareTransfer( - local_descs, remote_descs, remote_rank, op); + return impl_->prepareTransfer(local_descs, remote_descs, op); } void NixlBackend::postTransfer(NixlTransferHandle& handle) { diff --git a/tests/cpp/test_multidevice_nixl.cpp b/tests/cpp/test_multidevice_nixl.cpp index f9257a00146..71dcdd724d9 100644 --- a/tests/cpp/test_multidevice_nixl.cpp +++ b/tests/cpp/test_multidevice_nixl.cpp @@ -14,30 +14,6 @@ namespace nvfuser { using NixlTest = 
MultiDeviceTest; -// ------------------------------------------------------------------- -// NixlTransferHandle tests -// ------------------------------------------------------------------- - -TEST_F(NixlTest, TransferHandleDefaultConstruction) { - NixlTransferHandle handle; - EXPECT_FALSE(handle.isValid()); -} - -TEST_F(NixlTest, TransferHandleMoveConstruction) { - NixlTransferHandle h1; - EXPECT_FALSE(h1.isValid()); - - NixlTransferHandle h2(std::move(h1)); - EXPECT_FALSE(h2.isValid()); -} - -TEST_F(NixlTest, TransferHandleMoveAssignment) { - NixlTransferHandle h1; - NixlTransferHandle h2; - h2 = std::move(h1); - EXPECT_FALSE(h2.isValid()); -} - // ------------------------------------------------------------------- // NixlBackend singleton tests // ------------------------------------------------------------------- @@ -110,8 +86,9 @@ TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) { auto t3 = at::randn({64}, tensor_options_); backend.registerTensors({t1, t2, t3}); + const int64_t rank = communicator_->deviceId(); EXPECT_THROW( - (void)backend.prepareTransfer({toTensorDesc(t1), toTensorDesc(t2)}, {toTensorDesc(t3)}, 0, NixlXferOp::kRead), nvfError); + (void)backend.prepareTransfer({toTensorDesc(t1, rank), toTensorDesc(t2, rank)}, {toTensorDesc(t3, rank)}, NixlXferOp::kRead), nvfError); backend.deregisterTensors({t1, t2, t3}); } @@ -186,9 +163,7 @@ TEST_F(NixlTest, ReadTransferEndToEnd) { // Each rank reads from its peer. After the read, local should contain // the values that the peer stored in *its* remote tensor. auto handle = backend.prepareTransfer( - {toTensorDesc(dst)}, {remote_src_desc}, peer_rank, NixlXferOp::kRead); - ASSERT_TRUE(handle.isValid()); - + {toTensorDesc(dst, rank)}, {remote_src_desc}, NixlXferOp::kRead); backend.postTransfer(handle); backend.waitTransfer(handle); @@ -230,9 +205,7 @@ TEST_F(NixlTest, WriteTransferEndToEnd) { // Each rank writes its local tensor into its peer's remote tensor. 
auto handle = backend.prepareTransfer( - {toTensorDesc(src)}, {remote_dst_desc}, peer_rank, NixlXferOp::kWrite); - ASSERT_TRUE(handle.isValid()); - + {toTensorDesc(src, rank)}, {remote_dst_desc}, NixlXferOp::kWrite); backend.postTransfer(handle); backend.waitTransfer(handle); From 10d010af80abc036e5eb3a2517b7ec331d188dd2 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Sun, 8 Mar 2026 15:23:42 +0200 Subject: [PATCH 21/42] add nixlbackend::impl when use_nixl is false --- csrc/multidevice/nixl.cpp | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 1fe32769df1..cf86401e57c 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -393,28 +393,21 @@ void NixlBackend::Impl::waitTransfer(NixlTransferHandle& handle) { class NixlBackend::Impl { public: - static std::unique_ptr create(Communicator&) { return nullptr; } - void registerTensors(const std::vector&) { - NVF_THROW("NIXL not available"); - } - void deregisterTensors(const std::vector&) { - NVF_THROW("NIXL not available"); - } + void registerTensors(const std::vector&) {} + void deregisterTensors(const std::vector&) {} + void exchangeMetadata() {} NixlTransferHandle prepareTransfer( const std::vector&, const std::vector&, + int64_t, NixlXferOp) { - NVF_THROW("NIXL not available"); - } - void postTransfer(NixlTransferHandle&) { - NVF_THROW("NIXL not available"); + return {}; } + void postTransfer(NixlTransferHandle&) {} NixlXferStatus getTransferStatus(const NixlTransferHandle&) const { - NVF_THROW("NIXL not available"); - } - void waitTransfer(NixlTransferHandle&) { - NVF_THROW("NIXL not available"); + return NixlXferStatus::kError; } + void waitTransfer(NixlTransferHandle&) {} }; #endif // USE_NIXL From e9062a4a5810e3aaf8e4454e37a44cbfaadbe056 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Mon, 9 Mar 2026 14:33:30 +0200 Subject: [PATCH 22/42] Move exchangeMetadata to private when USE_NIXL is 
false --- csrc/multidevice/nixl.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index cf86401e57c..c9d18619021 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -395,11 +395,9 @@ class NixlBackend::Impl { public: void registerTensors(const std::vector&) {} void deregisterTensors(const std::vector&) {} - void exchangeMetadata() {} NixlTransferHandle prepareTransfer( const std::vector&, const std::vector&, - int64_t, NixlXferOp) { return {}; } @@ -408,6 +406,9 @@ class NixlBackend::Impl { return NixlXferStatus::kError; } void waitTransfer(NixlTransferHandle&) {} + + private: + void exchangeMetadata() {} }; #endif // USE_NIXL From 9047991d3133d7de67c3093f69b43c2d69d75444 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Mon, 9 Mar 2026 15:16:47 +0200 Subject: [PATCH 23/42] fix CI: clang-format, clang-tidy, and trailing newline Made-with: Cursor --- csrc/multidevice/communicator.h | 1 - csrc/multidevice/nixl.cpp | 10 ++---- csrc/multidevice/nixl.h | 56 ++++++++++++++++------------- tests/cpp/test_multidevice_nixl.cpp | 36 +++++++++++++------ 4 files changed, 60 insertions(+), 43 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index 61f76783d12..cd3b0876e6c 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -160,5 +160,4 @@ class NVF_API Communicator { std::unordered_map> backends_; }; - } // namespace nvfuser diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index c9d18619021..771b5adfc5b 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -41,8 +41,7 @@ class NixlTransferHandleImpl { NixlTransferHandle::NixlTransferHandle() = default; NixlTransferHandle::~NixlTransferHandle() = default; -NixlTransferHandle::NixlTransferHandle(NixlTransferHandle&&) noexcept = - default; +NixlTransferHandle::NixlTransferHandle(NixlTransferHandle&&) noexcept = default; 
NixlTransferHandle& NixlTransferHandle::operator=( NixlTransferHandle&&) noexcept = default; @@ -300,9 +299,7 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( const std::vector& remote_descs, NixlXferOp op) { NVF_ERROR(metadata_exchanged_, "exchangeMetadata() must be called first"); - NVF_ERROR( - !remote_descs.empty(), - "remote_descs must not be empty"); + NVF_ERROR(!remote_descs.empty(), "remote_descs must not be empty"); NVF_ERROR( local_descs.size() == remote_descs.size(), "Local and remote tensor lists must have the same size. Got ", @@ -472,5 +469,4 @@ void NixlBackend::waitTransfer(NixlTransferHandle& handle) { impl_->waitTransfer(handle); } - -} // namespace nvfuser \ No newline at end of file +} // namespace nvfuser diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 077a08f53c2..f7be3a28314 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -9,9 +9,9 @@ #include #include +#include #include #include -#include #include "exceptions.h" #include "multidevice/communicator.h" @@ -35,21 +35,22 @@ enum class NixlXferStatus { // ------------------------------------------------------------------ // Todo - those functions should be moved to a more global file -// Helper functions for serializing and deserializing tensors descriptors for TCP store +// Helper functions for serializing and deserializing tensors descriptors for +// TCP store struct TensorDesc { uintptr_t addr; size_t size; uint32_t dev; // deviceId (rank) owning this tensor }; -static_assert(std::is_trivially_copyable_v, - "TensorDesc must be trivially copyable for serialization"); +static_assert( + std::is_trivially_copyable_v, + "TensorDesc must be trivially copyable for serialization"); inline TensorDesc toTensorDesc(const at::Tensor& tensor, int64_t device_id) { return { - .addr = reinterpret_cast(tensor.data_ptr()), - .size = static_cast(tensor.numel()) * tensor.element_size(), - .dev = static_cast(device_id) - }; + .addr = 
reinterpret_cast(tensor.data_ptr()), + .size = static_cast(tensor.numel()) * tensor.element_size(), + .dev = static_cast(device_id)}; } inline std::vector serializeTensorsDescs( @@ -57,9 +58,10 @@ inline std::vector serializeTensorsDescs( size_t count = descs.size(); std::vector buf(sizeof(count) + count * sizeof(TensorDesc)); std::memcpy(buf.data(), &count, sizeof(count)); - if (count == 0) + if (count == 0) { return buf; - + } + std::memcpy( buf.data() + sizeof(count), descs.data(), @@ -79,19 +81,23 @@ inline std::vector deserializeTensorsDescs( std::vector descs(count); if (count > 0) { std::memcpy( - descs.data(), - buf.data() + sizeof(count), - count * sizeof(TensorDesc)); + descs.data(), buf.data() + sizeof(count), count * sizeof(TensorDesc)); } return descs; } -inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& descs) { +inline void storeTensorDescs( + Communicator& communicator, + const std::string& key, + const std::vector& descs) { NVF_CHECK(communicator.is_available(), "Communicator is not available"); communicator.getTcpStore()->set(key, serializeTensorsDescs(descs)); } -inline void storeTensorDescs(Communicator& communicator, const std::string& key, const std::vector& tensors) { +inline void storeTensorDescs( + Communicator& communicator, + const std::string& key, + const std::vector& tensors) { std::vector descs; descs.reserve(tensors.size()); for (const auto& tensor : tensors) { @@ -100,7 +106,9 @@ inline void storeTensorDescs(Communicator& communicator, const std::string& key, storeTensorDescs(communicator, key, descs); } -inline std::vector fetchTensorDescs(Communicator& communicator, const std::string& key) { +inline std::vector fetchTensorDescs( + Communicator& communicator, + const std::string& key) { NVF_CHECK(communicator.is_available(), "Communicator is not available"); auto bytes = communicator.getTcpStore()->get(key); return deserializeTensorsDescs(bytes); @@ -135,12 +143,13 @@ class NVF_API 
NixlTransferHandle { // ------------------------------------------------------------------- // NixlBackend: singleton NIXL backend over UCX for GPU tensors // ------------------------------------------------------------------- -// Singleton - Wraps a nixlAgent with the UCX backend and provides a tensor-level -// API for registering GPU memory and performing RDMA transfers. +// Singleton - Wraps a nixlAgent with the UCX backend and provides a +// tensor-level API for registering GPU memory and performing RDMA transfers. // // Lifecycle: // 1. getInstance() - creates agent, loads UCX backend -// 2. registerTensors() - register GPU tensors and exchange metadata (collective) +// 2. registerTensors() - register GPU tensors and exchange metadata +// (collective) // 3. prepareTransfer() - expensive one-time setup per transfer pattern // 4. postTransfer() - cheap, non-blocking data movement // 5. waitTransfer() - block until complete @@ -169,7 +178,8 @@ class NixlBackend { // in RDMA transfers. Tensors must be contiguous and remain alive // until deregisterTensors() is called. // Both methods are collective: they exchange agent metadata with all - // peers through the TCPStore, so all ranks must call them together and in the same order. + // peers through the TCPStore, so all ranks must call them together and in the + // same order. void registerTensors(const std::vector& tensors); void deregisterTensors(const std::vector& tensors); @@ -192,7 +202,8 @@ class NixlBackend { void postTransfer(NixlTransferHandle& handle); // Poll the status of a posted transfer without blocking. - [[nodiscard]] NixlXferStatus getTransferStatus(const NixlTransferHandle& handle) const; + [[nodiscard]] NixlXferStatus getTransferStatus( + const NixlTransferHandle& handle) const; // Block until the transfer completes (or errors out). 
void waitTransfer(NixlTransferHandle& handle); @@ -205,7 +216,4 @@ class NixlBackend { std::unique_ptr<Impl> impl_; }; - - - } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_nixl.cpp b/tests/cpp/test_multidevice_nixl.cpp index 71dcdd724d9..a6868ef7542 100644 --- a/tests/cpp/test_multidevice_nixl.cpp +++ b/tests/cpp/test_multidevice_nixl.cpp @@ -88,7 +88,11 @@ TEST_F(NixlTest, PrepareTransferMismatchedSizesThrows) { const int64_t rank = communicator_->deviceId(); EXPECT_THROW( - (void)backend.prepareTransfer({toTensorDesc(t1, rank), toTensorDesc(t2, rank)}, {toTensorDesc(t3, rank)}, NixlXferOp::kRead), nvfError); + (void)backend.prepareTransfer( + {toTensorDesc(t1, rank), toTensorDesc(t2, rank)}, + {toTensorDesc(t3, rank)}, + NixlXferOp::kRead), + nvfError); backend.deregisterTensors({t1, t2, t3}); } @@ -145,7 +149,8 @@ TEST_F(NixlTest, ReadTransferEndToEnd) { const int64_t peer_rank = (rank + 1) % world_size; constexpr int64_t kSize = 1024; - // Ring style transfer: each rank reads from its peer's remote tensor to its local <dst>. + // Ring style transfer: each rank reads from its peer's remote tensor to + // its local <dst>. 
auto src = at::full({kSize}, static_cast(rank + 1), tensor_options_); auto dst = at::zeros({kSize}, tensor_options_); cudaDeviceSynchronize(); @@ -154,11 +159,15 @@ // Fetch the remote tensor descriptor from the peer std::string src_key_prefix = "nixl_test_read_transfer_src_rank_"; - storeTensorDescs(*communicator_, src_key_prefix + std::to_string(rank), {src}); - auto remote_src_descs = fetchTensorDescs(*communicator_, src_key_prefix + std::to_string(peer_rank)); + storeTensorDescs( + *communicator_, src_key_prefix + std::to_string(rank), {src}); + auto remote_src_descs = fetchTensorDescs( + *communicator_, src_key_prefix + std::to_string(peer_rank)); communicator_->barrier(); - communicator_->getTcpStore()->deleteKey(src_key_prefix + std::to_string(rank)); + communicator_->getTcpStore()->deleteKey( + src_key_prefix + std::to_string(rank)); - auto remote_src_desc = remote_src_descs[0]; // Only one remote tensor is expected + auto remote_src_desc = + remote_src_descs[0]; // Only one remote tensor is expected // Each rank reads from its peer. After the read, local <dst> should contain // the values that the peer stored in *its* remote tensor. 
@@ -188,7 +197,8 @@ TEST_F(NixlTest, WriteTransferEndToEnd) { const int64_t peer_rank = (rank + 1) % world_size; constexpr int64_t kSize = 512; - // Each rank writes its local <src> to the remote <dst> of its peer in a ring style + // Each rank writes its local <src> to the remote <dst> of its peer in a ring + // style auto src = at::full({kSize}, static_cast(rank + 1), tensor_options_); auto dst = at::zeros({kSize}, tensor_options_); cudaDeviceSynchronize(); @@ -197,11 +207,15 @@ // Fetch the remote tensor descriptor from the peer std::string dst_key_prefix = "nixl_test_write_transfer_dst_rank_"; - storeTensorDescs(*communicator_, dst_key_prefix + std::to_string(rank), {dst}); - auto remote_dst_descs = fetchTensorDescs(*communicator_, dst_key_prefix + std::to_string(peer_rank)); + storeTensorDescs( + *communicator_, dst_key_prefix + std::to_string(rank), {dst}); + auto remote_dst_descs = fetchTensorDescs( + *communicator_, dst_key_prefix + std::to_string(peer_rank)); communicator_->barrier(); - communicator_->getTcpStore()->deleteKey(dst_key_prefix + std::to_string(rank)); + communicator_->getTcpStore()->deleteKey( + dst_key_prefix + std::to_string(rank)); - auto remote_dst_desc = remote_dst_descs[0]; // Only one remote tensor is expected + auto remote_dst_desc = + remote_dst_descs[0]; // Only one remote tensor is expected // Each rank writes its local tensor into its peer's remote tensor. auto handle = backend.prepareTransfer( From b3b5fbd027b58a80ec2271419282fbb2b3773aad Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Mon, 9 Mar 2026 15:52:50 +0200 Subject: [PATCH 24/42] fix ci --- csrc/multidevice/nixl.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index f7be3a28314..419d7c4485a 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -22,12 +22,12 @@ namespace nvfuser { // Transfer direction. 
NIXL uses a one-sided model: // Read = pull remote data into local buffers // Write = push local data into remote buffers -enum class NixlXferOp { +enum class NixlXferOp : std::uint8_t { kRead, kWrite, }; -enum class NixlXferStatus { +enum class NixlXferStatus : std::uint8_t { kDone, kInProgress, kError, @@ -72,7 +72,7 @@ inline std::vector serializeTensorsDescs( inline std::vector deserializeTensorsDescs( const std::vector& buf) { NVF_ERROR(buf.size() >= sizeof(size_t), "Invalid serialized descriptor data"); - size_t count; + size_t count = 0; std::memcpy(&count, buf.data(), sizeof(count)); NVF_ERROR( buf.size() == sizeof(count) + count * sizeof(TensorDesc), From b283c2aa67c55af862ead1a62b69f6232e89bb67 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Mon, 9 Mar 2026 16:11:13 +0200 Subject: [PATCH 25/42] fix CI --- csrc/multidevice/communicator.cpp | 12 +++++++----- csrc/multidevice/multidevice.h | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/csrc/multidevice/communicator.cpp b/csrc/multidevice/communicator.cpp index c021a129670..168da4af93a 100644 --- a/csrc/multidevice/communicator.cpp +++ b/csrc/multidevice/communicator.cpp @@ -9,6 +9,7 @@ #include +#include #include #include #include @@ -124,7 +125,8 @@ bool parseEnv( } // retrieves master port - if ((env = std::getenv("NVFUSER_MASTER_PORT")) != nullptr) { + env = std::getenv("NVFUSER_MASTER_PORT"); + if (env != nullptr) { master_port = std::atoi(env); } else { LOG(INFO) << "The environment variable NVFUSER_MASTER_PORT has not been " @@ -256,10 +258,10 @@ void waitForDebuggerAtRanks( std::cerr << "Process " << pid << " is waiting for the debugger. To continue debugging, " << "start gdb, `attach " << pid - << "`, `set var waiting=false`, and `fini`." << std::endl; + << "`, `set var waiting=false`, and `fini`." << '\n'; while (waiting) { // Please change `waiting` in the debugger. } - std::cerr << "Process " << getpid() << " finished waiting." 
<< std::endl; + std::cerr << "Process " << getpid() << " finished waiting." << '\n'; } if (communicator->is_available()) { @@ -362,7 +364,7 @@ void Communicator::cleanup() { // in different orders between ranks have been causing a hang. std::vector>> keyed_backends(backends_.begin(), backends_.end()); - std::sort(keyed_backends.begin(), keyed_backends.end()); + std::ranges::sort(keyed_backends); for (auto& [key, backend] : keyed_backends) { // Call shutdown before destructing a ProcessGroupNCCL as instructed by // https://github.com/pytorch/pytorch/blob/e62073d7997c9e63896cb5289ffd0874a8cc1838/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1164-L1170. @@ -396,7 +398,7 @@ c10d::Backend* Communicator::getBackendForTeam( #ifdef NVFUSER_DISTRIBUTED backends_[team_key] = [&]() -> c10::intrusive_ptr { // check that the caller's rank belongs to the requested team - auto rank_it = std::find(team.begin(), team.end(), deviceId()); + auto rank_it = std::ranges::find(team, deviceId()); if (rank_it == team.end()) { return nullptr; } diff --git a/csrc/multidevice/multidevice.h b/csrc/multidevice/multidevice.h index 7915f5e3d92..738dc46b1f2 100644 --- a/csrc/multidevice/multidevice.h +++ b/csrc/multidevice/multidevice.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include @@ -19,5 +20,5 @@ using DeviceType = c10::Device; using Team = std::vector; // Supported backends. 
-enum class CommunicatorBackend { kNccl, kUcc, kCuda, kNixl }; +enum class CommunicatorBackend : std::uint8_t { kNccl, kUcc, kCuda, kNixl }; } // namespace nvfuser From c4726d08cbd8fedb002b7f802d7910e31cfb3085 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Mon, 9 Mar 2026 18:37:07 +0200 Subject: [PATCH 26/42] Separate device and rank in tensordesc for more clarity --- csrc/multidevice/communicator.h | 2 -- csrc/multidevice/nixl.cpp | 9 +++++++-- csrc/multidevice/nixl.h | 8 +++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/csrc/multidevice/communicator.h b/csrc/multidevice/communicator.h index cd3b0876e6c..7d1e0b89305 100644 --- a/csrc/multidevice/communicator.h +++ b/csrc/multidevice/communicator.h @@ -11,8 +11,6 @@ #include #include -#include - #ifdef NVFUSER_DISTRIBUTED #include #include diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 771b5adfc5b..1f8275698e1 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -120,7 +121,7 @@ class NixlBackend::Impl { private: void exchangeMetadata(); explicit Impl(Communicator& communicator); - inline std::string getAgentName(int64_t device_id); + inline std::string getAgentName(int64_t rank); std::unique_ptr agent_; nixlBackendH* backend_ = nullptr; @@ -306,8 +307,12 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( local_descs.size(), " vs ", remote_descs.size()); + NVF_ERROR( + std::all_of(remote_descs.begin(), remote_descs.end(), + [&](const TensorDesc& d){ return d.rank == remote_descs[0].rank; }), + "All remote descriptors must belong to the same remote peer"); - std::string remote_agent_name = getAgentName(remote_descs.at(0).dev); + std::string remote_agent_name = getAgentName(remote_descs.at(0).rank); nixl_xfer_dlist_t local_dlist = buildXferDlist(local_descs); nixl_xfer_dlist_t remote_dlist = buildXferDlist(remote_descs); diff --git a/csrc/multidevice/nixl.h 
b/csrc/multidevice/nixl.h index 419d7c4485a..8226ef1b047 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -40,17 +40,19 @@ enum class NixlXferStatus : std::uint8_t { struct TensorDesc { uintptr_t addr; size_t size; - uint32_t dev; // deviceId (rank) owning this tensor + uint32_t dev; // CUDA device index (tensor.device().index()) + int64_t rank; // communicator rank owning this tensor }; static_assert( std::is_trivially_copyable_v, "TensorDesc must be trivially copyable for serialization"); -inline TensorDesc toTensorDesc(const at::Tensor& tensor, int64_t device_id) { +inline TensorDesc toTensorDesc(const at::Tensor& tensor, int64_t rank) { return { .addr = reinterpret_cast(tensor.data_ptr()), .size = static_cast(tensor.numel()) * tensor.element_size(), - .dev = static_cast(device_id)}; + .dev = static_cast(tensor.device().index()), + .rank = rank}; } inline std::vector serializeTensorsDescs( From 7ff4aae401602647e1f93823d2f38f879c77a694 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 11:36:09 +0200 Subject: [PATCH 27/42] fix linter --- csrc/multidevice/nixl.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 1f8275698e1..7fd69218af5 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -308,8 +308,10 @@ NixlTransferHandle NixlBackend::Impl::prepareTransfer( " vs ", remote_descs.size()); NVF_ERROR( - std::all_of(remote_descs.begin(), remote_descs.end(), - [&](const TensorDesc& d){ return d.rank == remote_descs[0].rank; }), + std::all_of( + remote_descs.begin(), + remote_descs.end(), + [&](const TensorDesc& d) { return d.rank == remote_descs[0].rank; }), "All remote descriptors must belong to the same remote peer"); std::string remote_agent_name = getAgentName(remote_descs.at(0).rank); From 073924045e911421737fbd4f7a8c4d4fee241c71 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 12:01:07 +0200 Subject: 
[PATCH 28/42] Update cmake config --- CMakeLists.txt | 27 +---- cmake/DependencyRequirements.cmake | 3 + cmake/deps/handle_nixl.cmake | 71 ++++++++++++ python/tools/check_dependencies.py | 2 + python/tools/prereqs/__init__.py | 2 + python/tools/prereqs/requirements/__init__.py | 2 + python/tools/prereqs/requirements/nixl.py | 102 ++++++++++++++++++ 7 files changed, 187 insertions(+), 22 deletions(-) create mode 100644 cmake/deps/handle_nixl.cmake create mode 100644 python/tools/prereqs/requirements/nixl.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 296be226ae6..a7e2ee25121 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,6 +77,7 @@ include(cmake/deps/handle_torch.cmake) include(cmake/deps/handle_pybind11.cmake) include(cmake/deps/handle_llvm.cmake) include(cmake/deps/handle_nvmmh.cmake) +include(cmake/deps/handle_nixl.cmake) include(cmake/deps/handle_git_submodules.cmake) # Initialize success flag @@ -96,6 +97,7 @@ handle_torch() # Must come AFTER python and cudatoolkit. handle_pybind11() handle_llvm() handle_nvmmh() # Must come AFTER python to query correct site-packages +handle_nixl() # Must come AFTER python and cudatoolkit for CUDA version check if(NVFUSER_ENABLE_DEPENDENCY_REPORT) stop_capture(DEP_LOGS) @@ -585,27 +587,7 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) target_compile_definitions(codegen_internal PRIVATE NVFUSER_BUILD_WITH_UCC) endif() -if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) - # User may need to set NIXL_PREFIX to the NIXL install directory. - find_path(NIXL_INCLUDE_DIR nixl.h - HINTS $ENV{NIXL_PREFIX}/include ENV CPATH - ) - find_library(NIXL_LIBRARY nixl - HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu - ) - find_library(NIXL_BUILD_LIBRARY nixl_build - HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu - ) - - if(NOT NIXL_INCLUDE_DIR OR NOT NIXL_LIBRARY) - message(FATAL_ERROR "NIXL not found. 
Set NIXL_PREFIX to the NIXL install directory.") - endif() - - message(STATUS "Found NIXL: ${NIXL_LIBRARY} (include: ${NIXL_INCLUDE_DIR})") - if(NIXL_BUILD_LIBRARY) - message(STATUS "Found NIXL build lib: ${NIXL_BUILD_LIBRARY}") - endif() - +if(NIXL_FOUND) add_library(__nvfuser_nixl INTERFACE) target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) @@ -1367,7 +1349,8 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_NIXL : ${NVFUSER_STANDALONE_BUILD_WITH_NIXL}") -if(NVFUSER_STANDALONE_BUILD_WITH_NIXL) +message(STATUS " NIXL_FOUND : ${NIXL_FOUND}") +if(NIXL_FOUND) message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") endif() diff --git a/cmake/DependencyRequirements.cmake b/cmake/DependencyRequirements.cmake index 6b941dc4c62..b595d448079 100644 --- a/cmake/DependencyRequirements.cmake +++ b/cmake/DependencyRequirements.cmake @@ -41,5 +41,8 @@ set(NVFUSER_REQUIREMENT_LLVM_VERSION_MIN "18.1") # NVMMH set(NVFUSER_REQUIREMENT_NVMMH_OPTIONAL "TRUE") +# NIXL +set(NVFUSER_REQUIREMENT_NIXL_OPTIONAL "TRUE") + # Git Submodules (required for build) # No version requirement - just checks if submodules are initialized diff --git a/cmake/deps/handle_nixl.cmake b/cmake/deps/handle_nixl.cmake new file mode 100644 index 00000000000..e84dcfefaa7 --- /dev/null +++ b/cmake/deps/handle_nixl.cmake @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +# ------------------------------------------------------------------------------ +# NIXL Handler +# ------------------------------------------------------------------------------ + +macro(handle_nixl) + message("") + message("Finding NIXL...") + + if(NOT NVFUSER_STANDALONE_BUILD_WITH_NIXL) + set(NIXL_FOUND FALSE) + message(STATUS "NIXL disabled (NVFUSER_STANDALONE_BUILD_WITH_NIXL=OFF)") + else() + # User may need to set NIXL_PREFIX to the NIXL install directory. + find_path(NIXL_INCLUDE_DIR nixl.h + HINTS $ENV{NIXL_PREFIX}/include ENV CPATH + ) + find_library(NIXL_LIBRARY nixl + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + find_library(NIXL_BUILD_LIBRARY nixl_build + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + ) + + if(NIXL_INCLUDE_DIR AND NIXL_LIBRARY) + set(NIXL_FOUND TRUE) + message(STATUS "Found NIXL: ${NIXL_LIBRARY} (include: ${NIXL_INCLUDE_DIR})") + if(NIXL_BUILD_LIBRARY) + message(STATUS "Found NIXL build lib: ${NIXL_BUILD_LIBRARY}") + endif() + else() + set(NIXL_FOUND FALSE) + message(WARNING "NIXL not found – building without NIXL support. 
Set NIXL_PREFIX to the NIXL install directory.") + endif() + + # CUDA major version constraint check + if(NIXL_FOUND AND Python_FOUND AND CUDAToolkit_FOUND) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import nixl; print(nixl._pkg.__name__.split('_cu')[-1])" + OUTPUT_VARIABLE nixl_cuda_major + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE nixl_cuda_result + ) + + if(nixl_cuda_result EQUAL 0 AND NOT nixl_cuda_major STREQUAL "") + set(NIXL_CUDA_VERSION "${nixl_cuda_major}") + set(cuda_toolkit_major "${CUDAToolkit_VERSION_MAJOR}") + + if(NOT nixl_cuda_major STREQUAL cuda_toolkit_major) + set(NIXL_CUDA_constraint_status "mismatch") + set(NIXL_CUDA_constraint_found "${nixl_cuda_major}") + set(NIXL_CUDA_constraint_required "${cuda_toolkit_major}") + message(WARNING "NIXL CUDA major version mismatch: NIXL built for CUDA ${nixl_cuda_major}, but CUDAToolkit major is ${cuda_toolkit_major}") + else() + set(NIXL_CUDA_constraint_status "match") + set(NIXL_CUDA_constraint_version "${nixl_cuda_major}") + endif() + else() + set(NIXL_CUDA_constraint_status "not_available") + endif() + else() + set(NIXL_CUDA_constraint_status "not_available") + endif() + endif() + + set_dependency_report_status(NIXL) +endmacro() diff --git a/python/tools/check_dependencies.py b/python/tools/check_dependencies.py index 43628e52314..08032fbd04d 100644 --- a/python/tools/check_dependencies.py +++ b/python/tools/check_dependencies.py @@ -29,6 +29,7 @@ CompilerRequirement, NinjaRequirement, NVMMHRequirement, + NIXLRequirement, GitSubmodulesRequirement, ) @@ -59,6 +60,7 @@ def __init__(self, deps_path: Path): self.requirements.append(Pybind11Requirement(cmake_vars)) self.requirements.append(LLVMRequirement(cmake_vars)) self.requirements.append(NVMMHRequirement(cmake_vars)) + self.requirements.append(NIXLRequirement(cmake_vars)) def _load_cmake_vars(self, deps_path: Path) -> Dict: """Load CMake variables from JSON file""" diff --git a/python/tools/prereqs/__init__.py 
b/python/tools/prereqs/__init__.py index 6627bbd5f46..fc4246afbf2 100644 --- a/python/tools/prereqs/__init__.py +++ b/python/tools/prereqs/__init__.py @@ -71,6 +71,7 @@ CompilerRequirement, GitSubmodulesRequirement, NinjaRequirement, + NIXLRequirement, ) __all__ = [ @@ -102,4 +103,5 @@ "CompilerRequirement", "GitSubmodulesRequirement", "NinjaRequirement", + "NIXLRequirement", ] diff --git a/python/tools/prereqs/requirements/__init__.py b/python/tools/prereqs/requirements/__init__.py index e5c6f0059e6..0a66c4012eb 100644 --- a/python/tools/prereqs/requirements/__init__.py +++ b/python/tools/prereqs/requirements/__init__.py @@ -13,6 +13,7 @@ from .git_submodules import GitSubmodulesRequirement from .ninja import NinjaRequirement from .nvmmh import NVMMHRequirement +from .nixl import NIXLRequirement __all__ = [ # Base classes @@ -30,4 +31,5 @@ "GitSubmodulesRequirement", "NinjaRequirement", "NVMMHRequirement", + "NIXLRequirement", ] diff --git a/python/tools/prereqs/requirements/nixl.py b/python/tools/prereqs/requirements/nixl.py new file mode 100644 index 00000000000..7ec0e484af0 --- /dev/null +++ b/python/tools/prereqs/requirements/nixl.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +"""NIXL dependency requirement with CUDA constraint validation.""" + +from typing import Dict +from .base import BooleanRequirement +from ..colors import colorize + + +class NIXLRequirement(BooleanRequirement): + """ + NIXL check with CUDA major version constraint. 
+ + CMake variables used: + - NIXL_FOUND: Whether NIXL is available + - NVFUSER_REQUIREMENT_NIXL_STATUS: Validation status + - NVFUSER_REQUIREMENT_NIXL_OPTIONAL: Whether NIXL is optional + - NIXL_CUDA_constraint_status: CUDA constraint validation result + - "match": NIXL CUDA major == CUDAToolkit major + - "mismatch": Versions don't match (WARNING) + - "not_available": Unable to determine NIXL CUDA version + - NIXL_CUDA_constraint_version: CUDA major version if match + - NIXL_CUDA_constraint_found: NIXL's CUDA major version (if mismatch) + - NIXL_CUDA_constraint_required: System's CUDA major version (if mismatch) + """ + + def __init__(self, cmake_vars: Dict): + name = "NIXL" + found_var = "NIXL_FOUND" + status_var = "NVFUSER_REQUIREMENT_NIXL_STATUS" + optional_var = "NVFUSER_REQUIREMENT_NIXL_OPTIONAL" + location_var = "NIXL_LIBRARY" + + super().__init__( + name, cmake_vars, found_var, status_var, optional_var, location_var + ) + + self.constraint_status = cmake_vars.get("NIXL_CUDA_constraint_status") + self.constraint_version = cmake_vars.get("NIXL_CUDA_constraint_version") + self.constraint_found = cmake_vars.get("NIXL_CUDA_constraint_found") + self.constraint_required = cmake_vars.get("NIXL_CUDA_constraint_required") + + def format_status_line(self, colors) -> str: + main_line = super().format_status_line(colors) + + constraint_line = self._format_cuda_constraint(colors) + if constraint_line: + return main_line + "\n" + constraint_line + return main_line + + def _format_cuda_constraint(self, colors) -> str: + if not self.constraint_status or self.constraint_status == "not_available": + return "" + + name_padded = f"{'NIXL_CUDA':<15}" + + if self.constraint_status == "match": + cuda_version = self.constraint_version or "unknown" + status_part = colorize(colors.GREEN, "[nvFuser] ✓") + " " + name_padded + version_part = colorize( + colors.CYAN, f"CUDA {cuda_version} (NIXL.CUDA == CUDAToolkit major)" + ) + return f"{status_part} {version_part}" + elif 
self.constraint_status == "mismatch": + nixl_cuda = self.constraint_found or "unknown" + toolkit_cuda = self.constraint_required or "unknown" + status_part = colorize(colors.BOLD_RED, "[nvFuser] ✗") + " " + name_padded + error_part = colorize( + colors.BOLD_RED, + f"mismatch (NIXL: CUDA {nixl_cuda}, CUDAToolkit: CUDA {toolkit_cuda})", + ) + return f"{status_part} {error_part}" + return "" + + def generate_help(self, platform_info): + print("NIXL") + print() + print("Why: NIXL provides high-performance data transfer for multi-device nvFuser.") + print() + print("Install NIXL:") + print() + print(" Recommended: pip installation:") + print() + print(" pip install nixl") + print() + print(" Or build from source and set NIXL_PREFIX to the install directory.") + print() + print(" Note: This is an optional dependency. nvFuser will build without it,") + print(" but multi-device NIXL-based transfers will not be available.") + print() + + if self.constraint_status == "mismatch": + print() + print("IMPORTANT: NIXL CUDA Version Mismatch Detected") + print() + print(" NIXL was built for a different CUDA major version than your") + print(" system's CUDA Toolkit. 
This will cause linking or runtime errors.") + print() + print(" Resolution: Install the NIXL package matching your CUDA version.") + print(" Check system CUDA major version: nvcc --version") + print() From 5560230ef7e9452c9b30a80ccfe4534841a9e625 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 16:46:44 +0200 Subject: [PATCH 29/42] Fix CI --- python/tools/prereqs/requirements/nixl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tools/prereqs/requirements/nixl.py b/python/tools/prereqs/requirements/nixl.py index 7ec0e484af0..85f6a207293 100644 --- a/python/tools/prereqs/requirements/nixl.py +++ b/python/tools/prereqs/requirements/nixl.py @@ -76,7 +76,9 @@ def _format_cuda_constraint(self, colors) -> str: def generate_help(self, platform_info): print("NIXL") print() - print("Why: NIXL provides high-performance data transfer for multi-device nvFuser.") + print( + "Why: NIXL provides high-performance data transfer for multi-device nvFuser." + ) print() print("Install NIXL:") print() From ec01db9a05bebfdbc8e80fdcb73ead18b4dacb64 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 17:41:14 +0200 Subject: [PATCH 30/42] Fix no-headers in NIXL install instructions (Cmake config) --- python/tools/prereqs/requirements/nixl.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/tools/prereqs/requirements/nixl.py b/python/tools/prereqs/requirements/nixl.py index 85f6a207293..e7e8995ad40 100644 --- a/python/tools/prereqs/requirements/nixl.py +++ b/python/tools/prereqs/requirements/nixl.py @@ -80,13 +80,20 @@ def generate_help(self, platform_info): "Why: NIXL provides high-performance data transfer for multi-device nvFuser." ) print() + print(" nvFuser links against the NIXL C++ API (nixl.h / libnixl.so).") + print(" 'pip install nixl' provides the shared library but does NOT install") + print(" the C++ headers. 
You need both headers and libraries for the build.") + print() print("Install NIXL:") print() - print(" Recommended: pip installation:") + print(" Option 1 (recommended for CI): Run the helper script that pip-installs") + print(" nixl for the .so and clones the repo for headers:") print() - print(" pip install nixl") + print(" bash tools/install-nixl.sh") + print(" export NIXL_PREFIX=/tmp/nixl-prefix # or your chosen path") print() - print(" Or build from source and set NIXL_PREFIX to the install directory.") + print(" Option 2: Build NIXL from source and set NIXL_PREFIX to the install") + print(" directory (must contain include/nixl.h and lib/libnixl.so).") print() print(" Note: This is an optional dependency. nvFuser will build without it,") print(" but multi-device NIXL-based transfers will not be available.") From 67a92f056bbf2c39e5b2dcbc759f470409b9004b Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 17:44:38 +0200 Subject: [PATCH 31/42] Replace NVFUSER_STANDALONE_BUILD_WITH_NIXL by NVFUSER_BUILD_WITH_NIXL --- CMakeLists.txt | 4 ++-- cmake/deps/handle_nixl.cmake | 4 ++-- python/utils.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a7e2ee25121..bdcdc35156e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,7 +32,7 @@ set(NVFUSER_CUTLASS "${NVFUSER_ROOT}/cutlass") set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) -option(NVFUSER_STANDALONE_BUILD_WITH_NIXL "" OFF) +option(NVFUSER_BUILD_WITH_NIXL "" OFF) option(NVFUSER_EXPLICIT_ERROR_CHECK "" OFF) option(NVFUSER_ENABLE_DEPENDENCY_REPORT "Enable Python-based dependency reporting and log capture" ON) @@ -1348,7 +1348,7 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) message(STATUS " UCX_DIR : $ENV{UCX_DIR}") endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") -message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_NIXL : 
${NVFUSER_STANDALONE_BUILD_WITH_NIXL}") +message(STATUS " NVFUSER_BUILD_WITH_NIXL : ${NVFUSER_BUILD_WITH_NIXL}") message(STATUS " NIXL_FOUND : ${NIXL_FOUND}") if(NIXL_FOUND) message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") diff --git a/cmake/deps/handle_nixl.cmake b/cmake/deps/handle_nixl.cmake index e84dcfefaa7..f61960e916e 100644 --- a/cmake/deps/handle_nixl.cmake +++ b/cmake/deps/handle_nixl.cmake @@ -10,9 +10,9 @@ macro(handle_nixl) message("") message("Finding NIXL...") - if(NOT NVFUSER_STANDALONE_BUILD_WITH_NIXL) + if(NOT NVFUSER_BUILD_WITH_NIXL) set(NIXL_FOUND FALSE) - message(STATUS "NIXL disabled (NVFUSER_STANDALONE_BUILD_WITH_NIXL=OFF)") + message(STATUS "NIXL disabled (NVFUSER_BUILD_WITH_NIXL=OFF)") else() # User may need to set NIXL_PREFIX to the NIXL install directory. find_path(NIXL_INCLUDE_DIR nixl.h diff --git a/python/utils.py b/python/utils.py index 5d589ebc868..4b8dbc5fb07 100644 --- a/python/utils.py +++ b/python/utils.py @@ -280,7 +280,7 @@ def on_or_off(flag: bool) -> str: f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", f"-DNVFUSER_BUILD_WITH_ASAN={on_or_off(config.build_with_asan)}", f"-DNVFUSER_STANDALONE_BUILD_WITH_UCC={on_or_off(config.build_with_ucc)}", - f"-DNVFUSER_STANDALONE_BUILD_WITH_NIXL={on_or_off(config.build_with_nixl)}", + f"-DNVFUSER_BUILD_WITH_NIXL={on_or_off(config.build_with_nixl)}", f"-DNVFUSER_EXPLICIT_ERROR_CHECK={on_or_off(config.explicit_error_check)}", f"-DBUILD_TEST={on_or_off(not config.no_test)}", f"-DBUILD_PYTHON={on_or_off(not config.no_python)}", From dae35aa0fbde107b61eb41868554feb6e2f56059 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 18:06:31 +0200 Subject: [PATCH 32/42] move nixl linkage from CmakeList to handle_nixl.cmake --- CMakeLists.txt | 11 +---------- cmake/deps/handle_nixl.cmake | 11 +++++++++++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bdcdc35156e..ca5ecfc19aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ 
-587,16 +587,7 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) target_compile_definitions(codegen_internal PRIVATE NVFUSER_BUILD_WITH_UCC) endif() -if(NIXL_FOUND) - add_library(__nvfuser_nixl INTERFACE) - target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) - target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) - if(NIXL_BUILD_LIBRARY) - target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_BUILD_LIBRARY}) - endif() - target_link_libraries(codegen_internal PRIVATE __nvfuser_nixl) - target_compile_definitions(codegen_internal PRIVATE USE_NIXL) -endif() +link_nixl(codegen_internal) add_dependencies(codegen_internal flatc build_flatbuffer_config) diff --git a/cmake/deps/handle_nixl.cmake b/cmake/deps/handle_nixl.cmake index f61960e916e..a666613b03b 100644 --- a/cmake/deps/handle_nixl.cmake +++ b/cmake/deps/handle_nixl.cmake @@ -69,3 +69,14 @@ macro(handle_nixl) set_dependency_report_status(NIXL) endmacro() + +macro(link_nixl target) + if(NIXL_FOUND) + target_include_directories(${target} PRIVATE ${NIXL_INCLUDE_DIR}) + target_link_libraries(${target} PRIVATE ${NIXL_LIBRARY}) + if(NIXL_BUILD_LIBRARY) + target_link_libraries(${target} PRIVATE ${NIXL_BUILD_LIBRARY}) + endif() + target_compile_definitions(${target} PRIVATE USE_NIXL) + endif() +endmacro() From a9fb56dcbb276bac9595d3280d4d39511a211b49 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 18:09:44 +0200 Subject: [PATCH 33/42] move TPL locs to handle_nixl.cmake --- CMakeLists.txt | 7 +------ cmake/deps/handle_nixl.cmake | 6 ++++++ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ca5ecfc19aa..50d1e5ad3be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1339,12 +1339,7 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) message(STATUS " UCX_DIR : $ENV{UCX_DIR}") endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") -message(STATUS " NVFUSER_BUILD_WITH_NIXL : ${NVFUSER_BUILD_WITH_NIXL}") 
-message(STATUS " NIXL_FOUND : ${NIXL_FOUND}") -if(NIXL_FOUND) - message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") - message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") -endif() +message(STATUS " NVFUSER_BUILD_WITH_NIXL : ${NVFUSER_BUILD_WITH_NIXL}") message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") diff --git a/cmake/deps/handle_nixl.cmake b/cmake/deps/handle_nixl.cmake index a666613b03b..60b98112d1d 100644 --- a/cmake/deps/handle_nixl.cmake +++ b/cmake/deps/handle_nixl.cmake @@ -67,6 +67,12 @@ macro(handle_nixl) endif() endif() + message(STATUS " NIXL_FOUND : ${NIXL_FOUND}") + if(NIXL_FOUND) + message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") + message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") + endif() + set_dependency_report_status(NIXL) endmacro() From e364949a849a3557af6a4b69591aca3cb028659f Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 18:26:42 +0200 Subject: [PATCH 34/42] Fix - move nixl linkage to handle_nixl.cmake --- CMakeLists.txt | 12 ++++++++++-- cmake/deps/handle_nixl.cmake | 24 +++++++----------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 50d1e5ad3be..79d3aca75f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -587,7 +587,10 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) target_compile_definitions(codegen_internal PRIVATE NVFUSER_BUILD_WITH_UCC) endif() -link_nixl(codegen_internal) +if(NIXL_FOUND) + target_link_libraries(codegen_internal PRIVATE __nvfuser_nixl) + target_compile_definitions(codegen_internal PRIVATE USE_NIXL) +endif() add_dependencies(codegen_internal flatc build_flatbuffer_config) @@ -1339,7 +1342,12 @@ if(NVFUSER_STANDALONE_BUILD_WITH_UCC) message(STATUS " UCX_DIR : $ENV{UCX_DIR}") endif() message(STATUS " NVFUSER_STANDALONE_BUILD_WITH_UCC : ${NVFUSER_STANDALONE_BUILD_WITH_UCC}") 
-message(STATUS " NVFUSER_BUILD_WITH_NIXL : ${NVFUSER_BUILD_WITH_NIXL}") +message(STATUS " NVFUSER_BUILD_WITH_NIXL : ${NVFUSER_BUILD_WITH_NIXL}") +message(STATUS " NIXL_FOUND : ${NIXL_FOUND}") +if(NIXL_FOUND) + message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") + message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") +endif() message(STATUS " NVFUSER_BUILD_WITH_ASAN : ${NVFUSER_BUILD_WITH_ASAN}") message(STATUS " NVFUSER_DISTRIBUTED : ${NVFUSER_DISTRIBUTED}") message(STATUS " NVFUSER_CPP_STANDARD : ${NVFUSER_CPP_STANDARD}") diff --git a/cmake/deps/handle_nixl.cmake b/cmake/deps/handle_nixl.cmake index 60b98112d1d..bc8bcc13e2a 100644 --- a/cmake/deps/handle_nixl.cmake +++ b/cmake/deps/handle_nixl.cmake @@ -31,6 +31,13 @@ macro(handle_nixl) if(NIXL_BUILD_LIBRARY) message(STATUS "Found NIXL build lib: ${NIXL_BUILD_LIBRARY}") endif() + + add_library(__nvfuser_nixl INTERFACE) + target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) + if(NIXL_BUILD_LIBRARY) + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_BUILD_LIBRARY}) + endif() else() set(NIXL_FOUND FALSE) message(WARNING "NIXL not found – building without NIXL support. 
Set NIXL_PREFIX to the NIXL install directory.") @@ -67,22 +74,5 @@ macro(handle_nixl) endif() endif() - message(STATUS " NIXL_FOUND : ${NIXL_FOUND}") - if(NIXL_FOUND) - message(STATUS " NIXL_INCLUDE_DIR: ${NIXL_INCLUDE_DIR}") - message(STATUS " NIXL_LIBRARY : ${NIXL_LIBRARY}") - endif() - set_dependency_report_status(NIXL) endmacro() - -macro(link_nixl target) - if(NIXL_FOUND) - target_include_directories(${target} PRIVATE ${NIXL_INCLUDE_DIR}) - target_link_libraries(${target} PRIVATE ${NIXL_LIBRARY}) - if(NIXL_BUILD_LIBRARY) - target_link_libraries(${target} PRIVATE ${NIXL_BUILD_LIBRARY}) - endif() - target_compile_definitions(${target} PRIVATE USE_NIXL) - endif() -endmacro() From f74078f05004a83faa22f06c85ac7474c5538519 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 10 Mar 2026 18:32:36 +0200 Subject: [PATCH 35/42] fix linter --- python/tools/prereqs/requirements/nixl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tools/prereqs/requirements/nixl.py b/python/tools/prereqs/requirements/nixl.py index e7e8995ad40..2adfada5117 100644 --- a/python/tools/prereqs/requirements/nixl.py +++ b/python/tools/prereqs/requirements/nixl.py @@ -86,7 +86,9 @@ def generate_help(self, platform_info): print() print("Install NIXL:") print() - print(" Option 1 (recommended for CI): Run the helper script that pip-installs") + print( + " Option 1 (recommended for CI): Run the helper script that pip-installs" + ) print(" nixl for the .so and clones the repo for headers:") print() print(" bash tools/install-nixl.sh") From 9847a112ffa0dfc1832346621ddd5547c1e48842 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 11 Mar 2026 13:06:03 +0200 Subject: [PATCH 36/42] Add NIXL to CI image --- .github/workflows/build.yml | 4 ++- tools/install-nixl.sh | 58 +++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100755 tools/install-nixl.sh diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml 
index d70120e2ba4..df56589467d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,13 +42,15 @@ jobs: working-directory: ${{ env.working_directory }} run: | tools/apt-install-things.sh & - tools/pip-install-things.sh & + (tools/pip-install-things.sh && tools/install-nixl.sh) & wait source tools/setup-env.sh export NVFUSER_BUILD_NO_CUTLASS=true export NVFUSER_BUILD_CPP_STANDARD=23 export NVFUSER_BUILD_ENABLE_PCH=true + export NVFUSER_BUILD_WITH_NIXL=1 + export NIXL_PREFIX=/tmp/nixl-prefix pip install -v -e ./python --no-build-isolation - name: Show ccache statistics if: always() diff --git a/tools/install-nixl.sh b/tools/install-nixl.sh new file mode 100755 index 00000000000..f1e62ed5df8 --- /dev/null +++ b/tools/install-nixl.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +# Install NIXL headers and libraries for Fuser CI compilation. +# +# The pip wheel provides libnixl.so (in a meson-python internal directory) +# but not the C development headers. We clone the NIXL repo to get headers +# and create a NIXL_PREFIX directory that handle_nixl.cmake can discover. +# +# Used by: .github/workflows/build.yml (GitHub Actions compilation check) +# For Blossom GPU CI: NIXL should be pre-installed in the CI Docker image +# (see dev/Dockerfile for reference), or this script can be run as a +# pre-build step if the runner has network access. + +set -e + +NIXL_PREFIX="${NIXL_PREFIX:-/tmp/nixl-prefix}" +NIXL_REPO="https://github.com/ai-dynamo/nixl.git" +NIXL_CLONE_DIR="/tmp/nixl-repo" + +# Use --no-deps to avoid pulling in nixl-cu12's torch/numpy dependencies, +# which would conflict with the torch nightly already installed by +# pip-install-things.sh (this script must run AFTER pip-install-things.sh). +pip install --no-deps nixl nixl-cu12 + +# Locate the mesonpy libs directory where libnixl.so lives. +SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") +# pip install nixl -> nixl-cu12 (default) or nixl-cu13 depending on extras. 
+# The mesonpy.libs dir is named after the actual variant package. +NIXL_PKG_NAME=$(python3 -c "import nixl; print(nixl._pkg.__name__)") +MESONPY_LIBS="${SITE_PACKAGES}/.${NIXL_PKG_NAME}.mesonpy.libs" + +if [ ! -d "$MESONPY_LIBS" ]; then + echo "Error: mesonpy libs directory not found at $MESONPY_LIBS" + exit 1 +fi + +if [ ! -f "$MESONPY_LIBS/libnixl.so" ]; then + echo "Error: libnixl.so not found in $MESONPY_LIBS" + exit 1 +fi + +# Clone NIXL repo (shallow) for C headers. +git clone --depth 1 "$NIXL_REPO" "$NIXL_CLONE_DIR" + +mkdir -p "$NIXL_PREFIX/include" "$NIXL_PREFIX/lib" + +cp "$NIXL_CLONE_DIR"/src/api/cpp/*.h "$NIXL_PREFIX/include/" + +ln -sf "$MESONPY_LIBS/libnixl.so" "$NIXL_PREFIX/lib/libnixl.so" +if [ -f "$MESONPY_LIBS/libnixl_build.so" ]; then + ln -sf "$MESONPY_LIBS/libnixl_build.so" "$NIXL_PREFIX/lib/libnixl_build.so" +fi + +rm -rf "$NIXL_CLONE_DIR" + +echo "NIXL prefix ready at $NIXL_PREFIX" +echo " include: $(ls "$NIXL_PREFIX/include/")" +echo " lib: $(ls -l "$NIXL_PREFIX/lib/")" From 40758d6b2b53dc532013f6c54bc865afeddc0935 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Wed, 11 Mar 2026 18:36:45 +0200 Subject: [PATCH 37/42] remove import nixl from install-nixl.sh --- tools/install-nixl.sh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tools/install-nixl.sh b/tools/install-nixl.sh index f1e62ed5df8..94588033dc3 100755 --- a/tools/install-nixl.sh +++ b/tools/install-nixl.sh @@ -23,14 +23,13 @@ NIXL_CLONE_DIR="/tmp/nixl-repo" pip install --no-deps nixl nixl-cu12 # Locate the mesonpy libs directory where libnixl.so lives. +# We avoid "import nixl" because the native extension may fail to load on +# headless CI runners without GPU/RDMA drivers. SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") -# pip install nixl -> nixl-cu12 (default) or nixl-cu13 depending on extras. -# The mesonpy.libs dir is named after the actual variant package. 
-NIXL_PKG_NAME=$(python3 -c "import nixl; print(nixl._pkg.__name__)") -MESONPY_LIBS="${SITE_PACKAGES}/.${NIXL_PKG_NAME}.mesonpy.libs" +MESONPY_LIBS=$(find "$SITE_PACKAGES" -maxdepth 1 -name ".nixl_cu*.mesonpy.libs" -type d | head -1) -if [ ! -d "$MESONPY_LIBS" ]; then - echo "Error: mesonpy libs directory not found at $MESONPY_LIBS" +if [ -z "$MESONPY_LIBS" ] || [ ! -d "$MESONPY_LIBS" ]; then + echo "Error: nixl mesonpy libs directory not found in $SITE_PACKAGES" exit 1 fi From bf471ea4bfc08996ba0a163910cb681d32a6a3bb Mon Sep 17 00:00:00 2001 From: x41lakazam Date: Thu, 12 Mar 2026 12:10:39 +0200 Subject: [PATCH 38/42] Add transitive shared libs deps for nixl --- tools/install-nixl.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tools/install-nixl.sh b/tools/install-nixl.sh index 94588033dc3..0eb01876018 100755 --- a/tools/install-nixl.sh +++ b/tools/install-nixl.sh @@ -45,10 +45,13 @@ mkdir -p "$NIXL_PREFIX/include" "$NIXL_PREFIX/lib" cp "$NIXL_CLONE_DIR"/src/api/cpp/*.h "$NIXL_PREFIX/include/" -ln -sf "$MESONPY_LIBS/libnixl.so" "$NIXL_PREFIX/lib/libnixl.so" -if [ -f "$MESONPY_LIBS/libnixl_build.so" ]; then - ln -sf "$MESONPY_LIBS/libnixl_build.so" "$NIXL_PREFIX/lib/libnixl_build.so" -fi +# Symlink all shared libraries from the mesonpy libs directory so that +# transitive dependencies of libnixl.so (libserdes.so, libstream.so, +# libnixl_common.so, libetcd-cpp-api-core, etc.) are discoverable by the linker. 
+for so in "$MESONPY_LIBS"/*.so*; do + [ -e "$so" ] || continue + ln -sf "$so" "$NIXL_PREFIX/lib/$(basename "$so")" +done rm -rf "$NIXL_CLONE_DIR" From df6c38419e44b6c124cf8a303f4232d8cb33efb0 Mon Sep 17 00:00:00 2001 From: x41lakazam Date: Thu, 12 Mar 2026 18:48:21 +0200 Subject: [PATCH 39/42] Add nixl*.mesonpy.libs and nixl*.libs as shared lib dirs in CI's install-nixl.sh --- tools/install-nixl.sh | 43 ++++++++++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/tools/install-nixl.sh b/tools/install-nixl.sh index 0eb01876018..ec3fd56ed59 100755 --- a/tools/install-nixl.sh +++ b/tools/install-nixl.sh @@ -22,21 +22,16 @@ NIXL_CLONE_DIR="/tmp/nixl-repo" # pip-install-things.sh (this script must run AFTER pip-install-things.sh). pip install --no-deps nixl nixl-cu12 -# Locate the mesonpy libs directory where libnixl.so lives. +# Locate shared library directories from the nixl pip packages. # We avoid "import nixl" because the native extension may fail to load on # headless CI runners without GPU/RDMA drivers. +# +# meson-python places bundled libs in .nixl_cu*.mesonpy.libs/ +# auditwheel places bundled libs in nixl_cu*.libs/ +# Both patterns are searched for nixl and nixl-cu* packages. SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") -MESONPY_LIBS=$(find "$SITE_PACKAGES" -maxdepth 1 -name ".nixl_cu*.mesonpy.libs" -type d | head -1) - -if [ -z "$MESONPY_LIBS" ] || [ ! -d "$MESONPY_LIBS" ]; then - echo "Error: nixl mesonpy libs directory not found in $SITE_PACKAGES" - exit 1 -fi -if [ ! -f "$MESONPY_LIBS/libnixl.so" ]; then - echo "Error: libnixl.so not found in $MESONPY_LIBS" - exit 1 -fi +FOUND_LIBNIXL=false # Clone NIXL repo (shallow) for C headers. 
git clone --depth 1 "$NIXL_REPO" "$NIXL_CLONE_DIR" @@ -45,14 +40,32 @@ mkdir -p "$NIXL_PREFIX/include" "$NIXL_PREFIX/lib" cp "$NIXL_CLONE_DIR"/src/api/cpp/*.h "$NIXL_PREFIX/include/" -# Symlink all shared libraries from the mesonpy libs directory so that +# Symlink all shared libraries from every nixl-related libs directory so that # transitive dependencies of libnixl.so (libserdes.so, libstream.so, # libnixl_common.so, libetcd-cpp-api-core, etc.) are discoverable by the linker. -for so in "$MESONPY_LIBS"/*.so*; do - [ -e "$so" ] || continue - ln -sf "$so" "$NIXL_PREFIX/lib/$(basename "$so")" +for libs_dir in "$SITE_PACKAGES"/.nixl*.mesonpy.libs "$SITE_PACKAGES"/nixl*.libs; do + [ -d "$libs_dir" ] || continue + echo " Symlinking libs from $libs_dir" + for so in "$libs_dir"/*.so*; do + [ -e "$so" ] || continue + ln -sf "$so" "$NIXL_PREFIX/lib/$(basename "$so")" + done done +# Fallback: search for any remaining nixl-related .so files anywhere in site-packages. +# Some transitive deps (e.g. libetcd-cpp-api-core) may be in package subdirectories. +for so in $(find "$SITE_PACKAGES" -maxdepth 3 -path "*/nixl*/*.so*" -type f 2>/dev/null); do + name=$(basename "$so") + [ -e "$NIXL_PREFIX/lib/$name" ] || ln -sf "$so" "$NIXL_PREFIX/lib/$name" +done + +if [ ! 
-f "$NIXL_PREFIX/lib/libnixl.so" ]; then + echo "Error: libnixl.so not found in any nixl libs directory under $SITE_PACKAGES" + echo "Searched directories:" + ls -d "$SITE_PACKAGES"/.nixl*.mesonpy.libs "$SITE_PACKAGES"/nixl*.libs 2>/dev/null || echo " (none found)" + exit 1 +fi + rm -rf "$NIXL_CLONE_DIR" echo "NIXL prefix ready at $NIXL_PREFIX" From 1c107793518a215858a1ee62a098394df28ed18b Mon Sep 17 00:00:00 2001 From: x41lakazam Date: Mon, 16 Mar 2026 18:14:46 +0200 Subject: [PATCH 40/42] try to make nixl tests work --- .github/workflows/build.yml | 3 +- cmake/deps/handle_nixl.cmake | 13 ++- tools/install-nixl.sh | 216 ++++++++++++++++++++++++++++------- 3 files changed, 185 insertions(+), 47 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index df56589467d..407ddf723a1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,7 +42,7 @@ jobs: working-directory: ${{ env.working_directory }} run: | tools/apt-install-things.sh & - (tools/pip-install-things.sh && tools/install-nixl.sh) & + (tools/pip-install-things.sh && NIXL_BUILD_MODE=pip tools/install-nixl.sh) & wait source tools/setup-env.sh @@ -51,6 +51,7 @@ jobs: export NVFUSER_BUILD_ENABLE_PCH=true export NVFUSER_BUILD_WITH_NIXL=1 export NIXL_PREFIX=/tmp/nixl-prefix + export LD_LIBRARY_PATH=$NIXL_PREFIX/lib:${LD_LIBRARY_PATH:-} pip install -v -e ./python --no-build-isolation - name: Show ccache statistics if: always() diff --git a/cmake/deps/handle_nixl.cmake b/cmake/deps/handle_nixl.cmake index bc8bcc13e2a..28ed7789ab5 100644 --- a/cmake/deps/handle_nixl.cmake +++ b/cmake/deps/handle_nixl.cmake @@ -19,10 +19,14 @@ macro(handle_nixl) HINTS $ENV{NIXL_PREFIX}/include ENV CPATH ) find_library(NIXL_LIBRARY nixl - HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 + $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + $ENV{NIXL_PREFIX}/lib/aarch64-linux-gnu ) 
find_library(NIXL_BUILD_LIBRARY nixl_build - HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + HINTS $ENV{NIXL_PREFIX}/lib $ENV{NIXL_PREFIX}/lib64 + $ENV{NIXL_PREFIX}/lib/x86_64-linux-gnu + $ENV{NIXL_PREFIX}/lib/aarch64-linux-gnu ) if(NIXL_INCLUDE_DIR AND NIXL_LIBRARY) @@ -34,6 +38,11 @@ macro(handle_nixl) add_library(__nvfuser_nixl INTERFACE) target_include_directories(__nvfuser_nixl INTERFACE ${NIXL_INCLUDE_DIR}) + + get_filename_component(NIXL_LIB_DIR "${NIXL_LIBRARY}" DIRECTORY) + target_link_directories(__nvfuser_nixl INTERFACE ${NIXL_LIB_DIR}) + target_link_options(__nvfuser_nixl INTERFACE "LINKER:-rpath-link,${NIXL_LIB_DIR}") + target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_LIBRARY}) if(NIXL_BUILD_LIBRARY) target_link_libraries(__nvfuser_nixl INTERFACE ${NIXL_BUILD_LIBRARY}) diff --git a/tools/install-nixl.sh b/tools/install-nixl.sh index ec3fd56ed59..c5b0b3cae20 100755 --- a/tools/install-nixl.sh +++ b/tools/install-nixl.sh @@ -1,73 +1,201 @@ #!/bin/bash -# Install NIXL headers and libraries for Fuser CI compilation. +# Install NIXL headers and libraries for Fuser CI. # -# The pip wheel provides libnixl.so (in a meson-python internal directory) -# but not the C development headers. We clone the NIXL repo to get headers -# and create a NIXL_PREFIX directory that handle_nixl.cmake can discover. +# Two modes: +# 1. pip (default for GitHub Actions): install pre-built wheels, clone repo +# for headers, symlink shared libs into NIXL_PREFIX. +# 2. source (default when CUDA toolkit is detected): build UCX with CUDA +# transport and NIXL from source so the UCX backend can register VRAM. # -# Used by: .github/workflows/build.yml (GitHub Actions compilation check) -# For Blossom GPU CI: NIXL should be pre-installed in the CI Docker image -# (see dev/Dockerfile for reference), or this script can be run as a -# pre-build step if the runner has network access. 
+# Environment variables: +# NIXL_PREFIX – install prefix (default: /tmp/nixl-prefix) +# NIXL_BUILD_MODE – "pip", "source", or "auto" (default: auto) +# CUDA_HOME – CUDA toolkit root (auto-detected from nvcc) +# +# Used by: +# .github/workflows/build.yml (GitHub Actions compilation check) +# Blossom GPU CI build jobs (needs runtime UCX+CUDA support) +# tools/ci-local-build.sh (local Docker reproduction) set -e NIXL_PREFIX="${NIXL_PREFIX:-/tmp/nixl-prefix}" +NIXL_BUILD_MODE="${NIXL_BUILD_MODE:-auto}" NIXL_REPO="https://github.com/ai-dynamo/nixl.git" NIXL_CLONE_DIR="/tmp/nixl-repo" +UCX_CLONE_DIR="/tmp/ucx-src" +UCX_INSTALL_DIR="/tmp/ucx-install" -# Use --no-deps to avoid pulling in nixl-cu12's torch/numpy dependencies, -# which would conflict with the torch nightly already installed by -# pip-install-things.sh (this script must run AFTER pip-install-things.sh). -pip install --no-deps nixl nixl-cu12 +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- -# Locate shared library directories from the nixl pip packages. -# We avoid "import nixl" because the native extension may fail to load on -# headless CI runners without GPU/RDMA drivers. -# -# meson-python places bundled libs in .nixl_cu*.mesonpy.libs/ -# auditwheel places bundled libs in nixl_cu*.libs/ -# Both patterns are searched for nixl and nixl-cu* packages. 
-SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") +detect_cuda_home() { + if [ -n "$CUDA_HOME" ]; then + return + fi + if command -v nvcc >/dev/null 2>&1; then + CUDA_HOME="$(dirname "$(dirname "$(command -v nvcc)")")" + elif [ -d /usr/local/cuda ]; then + CUDA_HOME=/usr/local/cuda + fi +} + +resolve_build_mode() { + if [ "$NIXL_BUILD_MODE" = "auto" ]; then + detect_cuda_home + if [ -n "$CUDA_HOME" ] && [ -x "$CUDA_HOME/bin/nvcc" ]; then + echo "Auto-detected CUDA at $CUDA_HOME — using source build for UCX+CUDA support" + NIXL_BUILD_MODE="source" + else + echo "No CUDA toolkit with nvcc found — using pip install (compile-only)" + NIXL_BUILD_MODE="pip" + fi + fi +} -FOUND_LIBNIXL=false +# --------------------------------------------------------------------------- +# Mode: pip (headers + pre-built .so, no runtime CUDA guarantee) +# --------------------------------------------------------------------------- -# Clone NIXL repo (shallow) for C headers. -git clone --depth 1 "$NIXL_REPO" "$NIXL_CLONE_DIR" +install_pip() { + echo "=== Installing NIXL via pip ===" -mkdir -p "$NIXL_PREFIX/include" "$NIXL_PREFIX/lib" + pip install --no-deps nixl nixl-cu12 || pip install --no-deps nixl -cp "$NIXL_CLONE_DIR"/src/api/cpp/*.h "$NIXL_PREFIX/include/" + SITE_PACKAGES=$(python3 -c "import site; print(site.getsitepackages()[0])") -# Symlink all shared libraries from every nixl-related libs directory so that -# transitive dependencies of libnixl.so (libserdes.so, libstream.so, -# libnixl_common.so, libetcd-cpp-api-core, etc.) are discoverable by the linker. 
-for libs_dir in "$SITE_PACKAGES"/.nixl*.mesonpy.libs "$SITE_PACKAGES"/nixl*.libs; do + git clone --depth 1 "$NIXL_REPO" "$NIXL_CLONE_DIR" + mkdir -p "$NIXL_PREFIX/include" "$NIXL_PREFIX/lib" + cp "$NIXL_CLONE_DIR"/src/api/cpp/*.h "$NIXL_PREFIX/include/" + + for libs_dir in "$SITE_PACKAGES"/.nixl*.mesonpy.libs "$SITE_PACKAGES"/nixl*.libs; do [ -d "$libs_dir" ] || continue echo " Symlinking libs from $libs_dir" for so in "$libs_dir"/*.so*; do - [ -e "$so" ] || continue - ln -sf "$so" "$NIXL_PREFIX/lib/$(basename "$so")" + [ -e "$so" ] || continue + ln -sf "$so" "$NIXL_PREFIX/lib/$(basename "$so")" done -done + done -# Fallback: search for any remaining nixl-related .so files anywhere in site-packages. -# Some transitive deps (e.g. libetcd-cpp-api-core) may be in package subdirectories. -for so in $(find "$SITE_PACKAGES" -maxdepth 3 -path "*/nixl*/*.so*" -type f 2>/dev/null); do + for so in $(find "$SITE_PACKAGES" -maxdepth 3 -path "*/nixl*/*.so*" -type f 2>/dev/null); do name=$(basename "$so") [ -e "$NIXL_PREFIX/lib/$name" ] || ln -sf "$so" "$NIXL_PREFIX/lib/$name" -done + done + + if [ ! 
-f "$NIXL_PREFIX/lib/libnixl.so" ]; then + echo "Error: libnixl.so not found under $SITE_PACKAGES" + exit 1 + fi + + rm -rf "$NIXL_CLONE_DIR" +} + +# --------------------------------------------------------------------------- +# Mode: source (UCX built with CUDA transport, NIXL built from source) +# --------------------------------------------------------------------------- + +install_source() { + echo "=== Building NIXL from source with UCX+CUDA ===" + detect_cuda_home + + if [ -z "$CUDA_HOME" ]; then + echo "Error: CUDA_HOME not set and nvcc not found" + exit 1 + fi + echo " CUDA_HOME=$CUDA_HOME" + + # --- build dependencies --------------------------------------------------- + apt-get update -qq 2>/dev/null || true + apt-get install -y -qq libtool autoconf automake pkg-config \ + libibverbs-dev librdmacm-dev libnuma-dev 2>/dev/null || true + pip install meson ninja 2>/dev/null || pip3 install meson ninja + + # --- UCX with CUDA -------------------------------------------------------- + echo "--- Building UCX with CUDA support ---" + if [ -d "$UCX_CLONE_DIR" ]; then rm -rf "$UCX_CLONE_DIR"; fi + git clone --depth 1 -b v1.18.x https://github.com/openucx/ucx.git "$UCX_CLONE_DIR" + ( + cd "$UCX_CLONE_DIR" + ./autogen.sh + ./contrib/configure-release \ + --prefix="$UCX_INSTALL_DIR" \ + --with-cuda="$CUDA_HOME" \ + --enable-mt + make -j"$(nproc)" + make install + ) + + export PKG_CONFIG_PATH="$UCX_INSTALL_DIR/lib/pkgconfig:${PKG_CONFIG_PATH:-}" + export LD_LIBRARY_PATH="$UCX_INSTALL_DIR/lib:${LD_LIBRARY_PATH:-}" + + # --- NIXL from source ----------------------------------------------------- + echo "--- Building NIXL from source ---" + if [ -d "$NIXL_CLONE_DIR" ]; then rm -rf "$NIXL_CLONE_DIR"; fi + git clone --depth 1 "$NIXL_REPO" "$NIXL_CLONE_DIR" + ( + cd "$NIXL_CLONE_DIR" + + CUDA_INC="$CUDA_HOME/include" + CUDA_LIB="$CUDA_HOME/lib64" + [ -d "$CUDA_LIB" ] || CUDA_LIB="$CUDA_HOME/lib" + + meson setup builddir \ + --prefix="$NIXL_PREFIX" \ + 
-Ducx_path="$UCX_INSTALL_DIR" \ + -Dcudapath_inc="$CUDA_INC" \ + -Dcudapath_lib="$CUDA_LIB" \ + -Dbuild_tests=false \ + -Dbuild_examples=false \ + -Dbuildtype=release + meson compile -C builddir + meson install -C builddir + ) + + # Copy API headers if not already installed by meson + mkdir -p "$NIXL_PREFIX/include" + cp -n "$NIXL_CLONE_DIR"/src/api/cpp/*.h "$NIXL_PREFIX/include/" 2>/dev/null || true + + # Ensure UCX libs are alongside NIXL libs so everything is on one rpath. + # Also copy UCX transport plugins (libuct_cuda.so etc.) so they're discoverable. + mkdir -p "$NIXL_PREFIX/lib/ucx" + for so in "$UCX_INSTALL_DIR"/lib/*.so*; do + [ -e "$so" ] || continue + ln -sf "$so" "$NIXL_PREFIX/lib/$(basename "$so")" + done + for so in "$UCX_INSTALL_DIR"/lib/ucx/*.so*; do + [ -e "$so" ] || continue + ln -sf "$so" "$NIXL_PREFIX/lib/ucx/$(basename "$so")" + done + + rm -rf "$NIXL_CLONE_DIR" "$UCX_CLONE_DIR" +} + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +resolve_build_mode +echo "NIXL build mode: $NIXL_BUILD_MODE" -if [ ! 
-f "$NIXL_PREFIX/lib/libnixl.so" ]; then - echo "Error: libnixl.so not found in any nixl libs directory under $SITE_PACKAGES" - echo "Searched directories:" - ls -d "$SITE_PACKAGES"/.nixl*.mesonpy.libs "$SITE_PACKAGES"/nixl*.libs 2>/dev/null || echo " (none found)" +case "$NIXL_BUILD_MODE" in + pip) install_pip ;; + source) install_source ;; + *) + echo "Error: unknown NIXL_BUILD_MODE=$NIXL_BUILD_MODE (expected pip, source, or auto)" exit 1 -fi + ;; +esac -rm -rf "$NIXL_CLONE_DIR" +# Ensure LD_LIBRARY_PATH includes NIXL_PREFIX/lib for runtime +export LD_LIBRARY_PATH="$NIXL_PREFIX/lib:${LD_LIBRARY_PATH:-}" +echo "" echo "NIXL prefix ready at $NIXL_PREFIX" -echo " include: $(ls "$NIXL_PREFIX/include/")" -echo " lib: $(ls -l "$NIXL_PREFIX/lib/")" +echo " include: $(ls "$NIXL_PREFIX/include/" 2>/dev/null || echo '(empty)')" +echo " lib: $(ls "$NIXL_PREFIX/lib/" 2>/dev/null || echo '(empty)')" +echo "" +echo "Remember to set:" +echo " export LD_LIBRARY_PATH=$NIXL_PREFIX/lib:\$LD_LIBRARY_PATH" +echo " export UCX_MODULE_DIR=$NIXL_PREFIX/lib/ucx (if UCX was built from source)" From 5eca79f2e65b2daa6360de9ae939d8ba9a128ec2 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Tue, 17 Mar 2026 14:53:52 +0200 Subject: [PATCH 41/42] remove nixl from clang build --- .github/workflows/build.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 407ddf723a1..d70120e2ba4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,16 +42,13 @@ jobs: working-directory: ${{ env.working_directory }} run: | tools/apt-install-things.sh & - (tools/pip-install-things.sh && NIXL_BUILD_MODE=pip tools/install-nixl.sh) & + tools/pip-install-things.sh & wait source tools/setup-env.sh export NVFUSER_BUILD_NO_CUTLASS=true export NVFUSER_BUILD_CPP_STANDARD=23 export NVFUSER_BUILD_ENABLE_PCH=true - export NVFUSER_BUILD_WITH_NIXL=1 - export NIXL_PREFIX=/tmp/nixl-prefix - export 
LD_LIBRARY_PATH=$NIXL_PREFIX/lib:${LD_LIBRARY_PATH:-} pip install -v -e ./python --no-build-isolation - name: Show ccache statistics if: always() From 5a568fc5ea0fcb2b63bd4db432aab8d47a2e4977 Mon Sep 17 00:00:00 2001 From: Eyal Chocron Date: Mon, 4 May 2026 11:34:58 +0300 Subject: [PATCH 42/42] Address PR comments --- csrc/multidevice/nixl.cpp | 3 ++- csrc/multidevice/nixl.h | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/csrc/multidevice/nixl.cpp b/csrc/multidevice/nixl.cpp index 7fd69218af5..201fe418dde 100644 --- a/csrc/multidevice/nixl.cpp +++ b/csrc/multidevice/nixl.cpp @@ -76,7 +76,8 @@ nixl_reg_dlist_t buildRegDlist(const std::vector& tensors) { nixl_xfer_dlist_t buildXferDlist(const std::vector& descs) { nixl_xfer_dlist_t dlist(VRAM_SEG); for (const auto& desc : descs) { - dlist.addDesc({desc.addr, desc.size, desc.dev}); + dlist.addDesc( + {reinterpret_cast(desc.addr), desc.size, desc.local_rank}); } return dlist; } diff --git a/csrc/multidevice/nixl.h b/csrc/multidevice/nixl.h index 8226ef1b047..fdc04aaebc1 100644 --- a/csrc/multidevice/nixl.h +++ b/csrc/multidevice/nixl.h @@ -38,9 +38,9 @@ enum class NixlXferStatus : std::uint8_t { // Helper functions for serializing and deserializing tensors descriptors for // TCP store struct TensorDesc { - uintptr_t addr; - size_t size; - uint32_t dev; // CUDA device index (tensor.device().index()) + void* addr; + int64_t size; + uint32_t local_rank; // CUDA device index (tensor.device().index()) int64_t rank; // communicator rank owning this tensor }; static_assert( @@ -49,9 +49,9 @@ static_assert( inline TensorDesc toTensorDesc(const at::Tensor& tensor, int64_t rank) { return { - .addr = reinterpret_cast(tensor.data_ptr()), - .size = static_cast(tensor.numel()) * tensor.element_size(), - .dev = static_cast(tensor.device().index()), + .addr = tensor.data_ptr(), + .size = tensor.numel() * tensor.element_size(), + .local_rank = static_cast(tensor.device().index()), .rank = rank}; } 
@@ -196,8 +196,8 @@ class NixlBackend { // The returned handle can be posted multiple times (preparation is // amortized). [[nodiscard]] NixlTransferHandle prepareTransfer( - const std::vector& local_descs, - const std::vector& remote_descs, + const std::vector& local_descs, // buffers on this rank + const std::vector& remote_descs, // buffers on the remote peer NixlXferOp op); // Post a previously prepared transfer for execution (non-blocking).