Skip to content

Migrate MoE dispatch / combine primitives into MSCCL++#796

Open
seagater wants to merge 36 commits intomainfrom
qinghuazhou/expert_parallel
Open

Migrate MoE dispatch / combine primitives into MSCCL++#796
seagater wants to merge 36 commits intomainfrom
qinghuazhou/expert_parallel

Conversation

@seagater
Copy link
Copy Markdown
Contributor

@seagater seagater commented May 4, 2026

A port of DeepEP's MoE dispatch / combine primitives into MSCCL++, targeting:

High-Throughput (HT) mode from DeepEP, branch chhwang/dev-atomic-add-cleanup — which already swaps NVSHMEM for mscclpp::PortChannel / mscclpp::MemoryChannel.
Low-Latency (LL) mode from nccl/contrib/nccl_ep

Status

Feature Status
Buffer construction + IPC + sync ✅ ported (NVLink + RDMA)
get_dispatch_layout ✅ ported
intranode_dispatch (NVLink) ✅ validated (8 ranks, 1 node)
intranode_combine (NVLink) ✅ validated (8 ranks, 1 node)
internode_dispatch (NVLink+RDMA) ✅ validated (16 ranks, 2xH100x8)
internode_combine (NVLink+RDMA) ✅ validated (16 ranks, 2xH100x8)
low_latency_dispatch (RDMA+IPC) ✅validated (8 ranks intra-node; 16 ranks 2xH100x8)
low_latency_combine (RDMA+IPC) ✅ validated (8 ranks intra-node; 16 ranks 2xH100x8)
Multi-ProxyService sharding ✅ env-tunable, arch-aware default
Connection::atomicAdd API ✅ cherry-picked into mscclpp
Python frontend mscclpp.ext.ep ✅ wraps HT + LL paths
pybind11 module mscclpp_ep_cpp ✅ builds conditionally

See more details in src/ext/ep/README.md

chhwang and others added 30 commits April 20, 2026 18:32
Port DeepEP's high-throughput MoE dispatch/combine kernels onto MSCCL++
as an optional build target `mscclpp_ep_cpp`, gated by -DMSCCLPP_BUILD_EXT_EP
(OFF by default). Sources are lifted from DeepEP branch
`chhwang/dev-atomic-add-cleanup` and rebased onto upstream MSCCL++ APIs;
the NVSHMEM / IBGDA dependencies are replaced with `PortChannel` +
`MemoryChannel` + the new `Connection::atomicAdd` primitive.

Scope
-----
Intranode (NVLink-only):
  * `Buffer` ctor/dtor: cudaMalloc nvl workspace, export IPC handle,
    allocate FIFO + peer-pointer tables, start `ProxyService`.
  * `sync()`: import peer IPC handles, upload peer pointer table,
    build `MemoryDevice2DeviceSemaphore` + `MemoryChannel` per peer.
  * `get_dispatch_layout`, `intranode_dispatch`, `intranode_combine`
    ported verbatim (torch::Tensor ABI preserved).

Internode HT (NVLink + RDMA):
  * `sync()` RDMA branch: cudaMalloc RDMA buffer + `bootstrap->barrier()`
    (replacing NVSHMEM symmetric-heap allocation); register with
    `all_transport`, exchange via `sendMemory`/`recvMemory`, build 12 IB
    QPs/peer + 16 semaphores/peer + 16 port channels/peer.
  * Full `internode.cu` port (notify_dispatch / dispatch / cached_notify
    / combine / get_dispatch_layout). The 4 raw `ChannelTrigger` atomic
    sites are rewritten to call the new
    `PortChannelDeviceHandle::atomicAdd(offset, value)` API; the single
    `nvshmem_fence()` is replaced with `__threadfence_system()` (remote
    visibility guaranteed by the subsequent port-channel barrier).
  * `internode_dispatch` / `internode_combine` host code ported, with
    the torch tensor marshalling and CPU spin-wait on mapped counters.

Low-latency (pure RDMA):
  * Not ported. `low_latency_dispatch`, `low_latency_combine`,
    `clean_low_latency_buffer`, `get_next_low_latency_combine_buffer`
    throw `std::runtime_error`; the Python frontend refuses to
    construct a Buffer with `low_latency_mode=True`.

Python layer
------------
* New pybind11 + libtorch Python extension `mscclpp_ep_cpp` (separate
  from the nanobind `_mscclpp` because the EP ABI carries
  `torch::Tensor` / `at::cuda::CUDAStream`).
* `mscclpp.ext.ep.Buffer` mirrors `deep_ep.Buffer`; exchanges device
  IDs, IPC handles and the bootstrap UniqueId over the user's
  `torch.distributed` process group before calling `sync()`.
* `mscclpp.ext` auto-imports `ep` if the extension is built.

Build
-----
* `src/ext/ep/CMakeLists.txt`: finds Python + Torch; warns and skips if
  `CMAKE_PREFIX_PATH` doesn't point at `torch.utils.cmake_prefix_path`.
  Falls back to Torch's bundled pybind11 if a standalone pybind11 is not
  installed. Links `libtorch_python` explicitly (without it, `import
  mscclpp_ep_cpp` fails with `undefined symbol: THPDtypeType`).
* Top-level `CMakeLists.txt` exposes the `MSCCLPP_BUILD_EXT_EP` option
  (default OFF).

Tests
-----
* `test/python/ext/ep/test_ep_smoke.py`: skipped if the extension isn't
  built. Covers Config round-trip, low-latency size hint, and the LL
  construction guard. Multi-rank functional tests still to do on H100.

Notes
-----
* Builds against the preceding "atomic add" commit which adds
  `Connection::atomicAdd` and `PortChannelDeviceHandle::atomicAdd` to
  upstream MSCCL++.
* Intranode path verified end-to-end (build + import + smoke tests).
* Internode HT is code-complete but requires real IB hardware to
  validate; see `src/ext/ep/README.md` for the detailed port plan and
  remaining LL migration.
Port DeepEP's pure-RDMA low-latency (LL) MoE kernels from
csrc/kernels/internode_ll.cu (branch chhwang/dev-atomic-add-cleanup)
into the MSCCL++ EP extension. NVSHMEM / IBGDA device primitives are
replaced with MSCCL++ PortChannelDeviceHandle operations:

  nvshmemx_barrier_all_block()            -> port-channel signal+wait ring
  nvshmemi_ibgda_put_nbi_warp(...)        -> lane-0 PortChannel.put(...)
  nvshmemi_ibgda_amo_nonfetch_add(...)    -> lane-0 PortChannel.atomicAdd(...)

The atomicAdd path relies on the MSCCL++ Connection::atomicAdd /
PortChannelDeviceHandle::atomicAdd API cherry-picked from branch
chhwang/new-atomic-add; the LL dispatch path uses a signed delta
(-num_tokens_sent - 1) which the new int64_t signature supports.

Changes:
* New file src/ext/ep/kernels/internode_ll.cu (~530 lines) with the
  three kernels clean_low_latency_buffer, dispatch<kUseFP8,...>,
  combine<...> plus their launchers. rdma_buffer_ptr is threaded
  through the launchers so the kernel can translate virtual addresses
  into registered-memory offsets expected by MSCCL++.
* kernels/api.cuh: replace the single stub signature with full LL
  launcher prototypes.
* buffer.cc: replace the four LL throw-stubs
  (clean_low_latency_buffer, low_latency_dispatch,
  low_latency_combine, get_next_low_latency_combine_buffer) with
  torch-Tensor implementations ported from DeepEP/csrc/deep_ep.cpp.
* Drop src/ext/ep/internode_stub.cc and its CMake entry.
* python/mscclpp/ext/ep/buffer.py: remove the low_latency_mode=True
  NotImplementedError guard; update docstring.
* test/python/ext/ep/test_ep_smoke.py: rename
  test_low_latency_rejected -> test_low_latency_buffer_construct
  to reflect that LL construction is now accepted.
* src/ext/ep/README.md: update status matrix, document the
  NVSHMEM -> MSCCL++ translation table, and list the known
  limitations.

This is a structural port: the kernels compile, link, and pass the
single-rank smoke tests, but end-to-end behaviour on multi-node H100
is not yet validated. Two known caveats:

  1. Performance will NOT match IBGDA because MSCCL++ port channels
     use a CPU proxy; this port is for functional parity, not latency.
  2. Buffer::sync() in LL mode only connects peers that share the
     same local GPU id (DeepEP convention), so the LL kernels assume
     a one-GPU-per-node topology (num_ranks == num_rdma_ranks).
     Multi-GPU-per-node LL layouts will need a follow-up in sync().

Tested:
  cmake --build build -j --target mscclpp_ep_cpp   # builds clean
  pytest test/python/ext/ep/test_ep_smoke.py        # 3 passed
Three issues blocked end-to-end intranode validation across multiple
ranks. This commit fixes them and adds a 2/4/8-rank functional test.

1. Combine receiver: OOB __shared__ read

   In the combine receiver warp, the wait loop evaluated
   `channel_tail_idx[recv_lane_id] <= expected_head` before the
   `expected_head >= 0` guard. `channel_tail_idx` is a shared array
   of size `kNumRanks`, but the loop runs on all 32 lanes of a warp,
   so lanes with `recv_lane_id >= kNumRanks` indexed out of bounds.
   compute-sanitizer reported "Invalid __shared__ read of size 4
   bytes" at combine<bf16,2,768>+0xdd0, surfaced asynchronously as
   cudaErrorIllegalAddress at the kernel launch site. Swap the
   operands so the rank-bounds check short-circuits the shared read.

2. Python bindings: UniqueId ABI

   `mscclpp::UniqueId` is a `std::array<uint8_t, N>` which pybind11
   auto-converts to a Python `list`, silently overriding any
   `py::class_<UniqueId>` wrapper. Expose `create_unique_id` /
   `connect` as lambdas that produce/consume `py::bytes` and memcpy
   into a local `UniqueId`. Also coerce `bytes`->`bytearray` at the
   Python call site for `sync()` whose signature expects
   `pybind11::bytearray`.

3. Python frontend: communicator required for NVL-only sync

   `Buffer::sync()` uses `communicator->connect(ipc_config, ...)` on
   the pure-NVLink path, so the communicator must be initialized
   even when `num_rdma_ranks == 1` and `low_latency_mode == False`.
   Always broadcast the unique id and call `runtime.connect()`
   before `sync()`.

Validation on a single H100x8 node via torchrun:
- 2 ranks: dispatch 195 tokens, combine diff=0
- 4 ranks: dispatch 371 tokens, combine diff=0
- 8 ranks: dispatch 456 tokens, combine diff=0

Test harness added at test/python/ext/ep/test_intranode_multirank.py.
The `internode` kernels index device-side port channel handles as
`port_channel_handles[channel_id * num_ranks + peer_rank]`, where
`peer_rank` is a global rank in [0, num_ranks). `Buffer::sync` was
building that table by iterating `std::unordered_map<int, MemoryId>`
(and similarly for connections/semaphores), which yields hash order
rather than ascending rank order. Once the cross-node fan-out grew
beyond a single peer, a local rank's trigger for peer `r` landed on
the semaphore/memory pair of a different peer, so RDMA puts and
atomic tail updates went to the wrong destination and the forwarder
spun on a tail counter that never advanced.

Changes:
  - Build `sema_ids` and `port_channel_handles` by iterating
    `for (int r = 0; r < num_ranks; ++r)` and looking up the
    connection / memory id for rank `r`, skipping ranks excluded by
    low-latency mode (inserting a placeholder handle so the stride
    stays `num_ranks`).
  - Tag the RDMA-phase `sendMemory`/`recvMemory`/`connect` calls with
    `kRdmaTag = 1` so they do not collide with NVL-phase tag-0
    traffic between the same pair of ranks.
  - Drop an unused `r` local in the NVL setup loop.

With this fix and a matched `libmscclpp.so` on both nodes, the
2-node x 8-GPU internode HT dispatch path completes successfully
(`[dispatch] OK`). Combine is still under investigation.

Also adds `test/python/ext/ep/test_internode_multirank.py`, a
torchrun-based 2-node functional test that exercises
`get_dispatch_layout` -> `internode_dispatch` -> `internode_combine`
and validates per-source-rank token values end-to-end.
Two issues prevented internode HT combine from completing on 2x8 H100:

1. Wrong prefix matrices passed to internode_combine. Combine runs in the
   reverse direction of dispatch, so it must consume the receiver-side
   matrices returned by dispatch (recv_rdma_channel_prefix_matrix,
   recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix), not the
   sender-side rdma_channel_prefix_matrix / gbl_channel_prefix_matrix.
   This matches DeepEP's deep_ep/buffer.py::internode_combine handle
   unpacking. Without the fix the NVL forwarder's 'NVL check' timed out
   because token_start_idx/token_end_idx were computed against the wrong
   per-channel layout.

2. Cross-rank race between dispatch and combine. Even with the correct
   matrices, launching combine immediately after dispatch deadlocked the
   forwarder NVL check (tail stuck one short of expected_head) because
   peers still had in-flight dispatch proxy traffic while fast ranks had
   already started combine. A torch.cuda.synchronize() + dist.barrier()
   between the two calls makes the test pass deterministically on 16
   ranks (combine diff == 0, max|expected| up to 60.0).

The barrier in the test is a workaround; the real fix belongs in
Buffer::internode_dispatch / Buffer::internode_combine so the
dispatch->combine handoff fully fences outstanding proxy work across
ranks. Marked with an XXX comment in the test.
Refresh status docs and comments now that internode HT dispatch and
combine have been validated end-to-end on 2 nodes x 8 H100 GPUs via
test/python/ext/ep/test_internode_multirank.py (all 16 ranks recover
their per-rank token payloads with zero diff).

- src/ext/ep/README.md: consolidate the previously duplicated README
  into a single document; mark intranode and internode HT dispatch and
  combine as validated in the status table; add a 'Running the tests'
  section with torchrun examples for both the intranode and the 2x8
  internode setups; record the dispatch->combine
  torch.cuda.synchronize() + dist.barrier() requirement under Known
  limitations; mark Phase 2 DONE and keep Phase 3 (LL) as structural
  port, untested.

- python/mscclpp/ext/ep/buffer.py: update the module docstring and the
  Buffer constructor docstring to say internode HT is validated and
  clarify that LL mode is untested on multi-node hardware.

- src/ext/ep/buffer.cc: drop the stale 'NVSHMEM support not yet ported'
  and 'low-latency paths still stubbed' comments. mscclpp_ep does not
  use NVSHMEM at all (PortChannel/MemoryChannel replace it), and the LL
  paths are a structural port that is present but untested, not stubbed.
  Note validation on 2x H100x8 in the internode section header.
- Buffer::sync no longer drops non-same-GPU-id peers in low_latency_mode.
  DeepEP's original filter was safe because its LL path used NVSHMEM; this
  port drives LL via PortChannel so the kernel indexes
  port_channel_handles[local_expert*num_ranks + dst_rank] for every
  dst_rank. All peers now get a real memory/connection/semaphore/port
  channel entry.
- Add test/python/ext/ep/test_low_latency_multirank.py (LL dispatch+combine
  functional round-trip, BF16 only). Works cross-node in DeepEP's
  1-GPU-per-node topology.
- Known limitation documented in src/ext/ep/README.md and the test docstring:
  intra-node 8-GPU LL currently hangs because every peer transfer routes
  through the CPU proxy over IB loopback between distinct HCAs on the same
  host, and (separately) CudaIpcConnection::atomicAdd is a 64-bit op which
  mis-aligns the 32-bit rdma_recv_count slots when used for same-node
  peers. Proper fix needs a mixed-transport LL variant (MemoryChannel for
  same-node, PortChannel for cross-node) or 64-bit counters.
Gated behind MSCCLPP_EP_BENCH=1 to keep correctness runs fast. Reports
per-iter latency (max across ranks, CUDA-event timed) and aggregate
effective bandwidth (sum across ranks, dispatch+combine payload bytes).
Tunable via MSCCLPP_EP_BENCH_WARMUP / _ITERS / _TOKENS / _HIDDEN.

Bench reuses the Buffer allocated for the correctness phase and
self-skips if the requested hidden exceeds the per-peer NVL/RDMA budget.
Previously the optional benchmark measured full round-trip latency. Split
it to time dispatch alone (N iters) and combine alone (N iters reusing
one dispatch output), reporting per-phase latency (max across ranks) and
aggregate effective bandwidth (sum across ranks).

Applies to intranode HT, internode HT, and the (currently unreachable on
intra-node 8-GPU) LL test. Internode HT keeps the sync+barrier guard
between dispatch and combine but excludes it from either phase's timing.
…o int64

The low-latency dispatch/combine kernels signal recv counts via MSCCL++
PortChannel.atomicAdd, which lowers to IB IBV_WR_ATOMIC_FETCH_AND_ADD.
That opcode requires the remote address to be 8-byte aligned, but
LowLatencyLayout packed the per-expert signaling slots as int32. Odd
slots landed at offset %8 == 4; the NIC silently dropped those atomics
and the target rank spun forever in recv_hook (observed: even->odd
direction works, odd->even does not, across all tested topologies
including 2-rank intra-node, 8-rank intra-node, and 2-node 1-GPU-each).

Widen dispatch_rdma_recv_count_buffer / combine_rdma_recv_flag_buffer to
int64_t, update clean kernel + kernel signatures + next_clean pointers
accordingly, and add int64_t overloads for st_na_release /
ld_acquire_sys_global in utils.cuh.

Also drop the bogus self CUDA-IPC connection in Buffer::sync() that was
previously skewing the cross-rank buildAndAddSemaphore handshake order;
the kernel's same-rank branch uses a direct warp copy and never touches
the self port-channel slot (filled with a zero-initialized placeholder
so the [local_expert*num_ranks + dst_rank] indexing still holds).
Dropping the self ipc_cfg connection caused cudaErrorInvalidResourceHandle
on multi-node launches. Keep the self connection (needed by other code
paths that assume every rank is in the connections map) but continue to
skip the self slot in the semaphore + port-channel construction loops so
the kernel's [local_expert*num_ranks + dst_rank] indexing hits only peer
handles; the self slot is a zero-initialized placeholder since the
kernel's same-rank branch uses a direct warp copy.
The prior commit skipped r==rank in the semaphore and port-channel
build loops on the theory that the self-slot handshake skew was the
cause of LL direction asymmetry. That was wrong (the real bug was
int32 atomic alignment), and skipping self breaks other code paths
that assume every rank slot is represented -- cross-node HT and LL
failed with cudaErrorInvalidResourceHandle at the first barrier after
Buffer init. Restore the self-inclusive loop.
When all ranks live on the same host (num_rdma_ranks == 1), the LL
kernels now bypass PortChannel/IB-loopback entirely. In Buffer::sync()
we additionally:
  - allGather IPC handles for each rank's rdma_buffer_ptr and
    cudaIpcOpenMemHandle them into peer_rdma_bases[]
  - build per-peer MemoryChannels over CUDA IPC connections (tag=2)
    used only for the LL barrier ring

The three LL kernels (clean / dispatch / combine) gain a kIpcPath
template parameter and two extra args (peer_rdma_bases,
memory_channel_handles). At each peer op:
  - put -> peer-mapped warp copy over NVLink
  - atomicAdd-like flag store -> single-writer st_na_release on peer ptr
  - signal/wait barrier -> MemoryChannel signal/wait

Cross-node LL (num_rdma_ranks > 1) is untouched; the IPC setup block is
a no-op. The host launch wrappers select the variant via use_ipc_path.
Each local expert sends one copy per dispatched token back to its owner,
so the bytes actually on the wire during combine match dispatch. The
previous num_tokens×hidden under-counted by ~num_topk×, making combine
BW look artificially low next to dispatch.
- Report both per-rank and aggregate BW to align with NCCL-EP's ep_bench
  (which reports per-rank GB/s).
- Accept MSCCLPP_EP_LL_TOKENS/HIDDEN/TOPK/EXPERTS_PER_RANK env overrides
  so we can match external benchmark problem sizes (NCCL-EP LL defaults
  are num_tokens=128, hidden=7168, top_k=8).
Same alignment with NCCL-EP ep_bench as the LL test: report both
per-rank (agg/num_ranks) and aggregate throughput.
LL dispatch/combine are latency-bound at typical problem sizes: for
num_experts=32 the previous grid was cell_div(32,3)=11 blocks, i.e. 8%
of a 132-SM H100. The recv-side bodies already stride tokens by sm_id,
so extra blocks parallelize token work linearly. Extra blocks past
num_experts are gated out of the send/count phases by the existing
'responsible_expert_idx < num_experts' check.

Cap at the device's SM count (cooperative launch + launch_bounds(960,1)
allow one block per SM).
On the PortChannel (cross-node) path the extra blocks don't help: the
dispatch recv loop strides tokens per-warp-group (not per-SM), and the
additional blocks instead add cooperative-grid sync overhead and
increase concurrent host-proxy FIFO traffic. Measured cross-node
dispatch regressed from 1013us to 3063us when the unconditional grid
bump was active.

Keep the scaled grid for the IPC path (intra-node), where combine-recv
and dispatch token striding scale with sm_id and the 1.2-1.3x speedup
reproduces.
The LL combine benchmark was cloning the ~58 MB dispatch recv buffer
('recv_x.clone()') on every timed iteration, adding ~20 us of D2D
memcpy per sample and masking kernel-level changes. It also called
torch.empty() for the output inside the loop. Both now live outside
the timed region; the kernel is invoked against a persistent bench_out
and the recv_x produced by the most recent dispatch.
NCCL-EP's LL dispatch/combine kernel uses (numWarpGroups=1,
numWarpsPerGroup=32) when num_experts <= device_num_sms, giving each
SM ownership of a single expert and 32 warps to cooperate on its
recv-side per-(expert, src_rank) work. We were using (3, 10) — 3
experts per SM, 10 warps per (expert, rank) pair — which left a
significant amount of recv-side parallelism on the table because each
warp had to walk ~3x more tokens sequentially.

Switching to (1, 32) for both dispatch and combine matches NCCL-EP's
structure for typical EP sizes (num_experts in {32, 64, 256}) where
num_experts <= 132 SMs.

The static_assert kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup
still holds (9 <= 32) and the wider block also lets the staging loop
process the hidden-dim with one int4 per thread (hidden_bf16_int4=896
fits easily in 992 working threads).
Cross-node LL regressed when (1, 32) was applied uniformly: dispatch
1031us -> 1570us, combine 2553us -> 3484us. Larger grid means more
concurrent putWithSignal calls onto the host-proxy FIFO and a costlier
cg::this_grid().sync() between phases, both of which dominate the IB
path even though more SMs help the recv-side compute.

Make (kNumWarpGroups, kNumWarpsPerGroup) path-dependent: (1, 32) when
use_ipc_path, (3, 10) otherwise. Restores cross-node performance and
keeps the intra-node win.
- Add MSCCLPP_EP_BENCH_EXPERTS / _TOPK env knobs so the bench phase can
  match NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8). The
  functional check above continues to use the smaller (num_ranks*4
  experts, topk=4) configuration.

- Switch BW accounting from recv_tokens*hidden to bench_tokens*hidden,
  matching NCCL-EP's `RDMA_send` per-rank byte count. The previous
  formula counted DeepEP's expanded recv layout (one row per
  (token,src_rank) pair), inflating reported GB/s ~5x and making
  cross-stack comparisons misleading.
Same change as the intra-node bench (commit 4ed6f22), applied to the
cross-node test:

- Add MSCCLPP_EP_BENCH_EXPERTS / _TOPK env knobs so the bench phase can
  match NCCL-EP's `ep_bench -a ht` defaults (256 experts, top-8).
- Switch BW accounting from recv_tokens*hidden to bench_tokens*hidden,
  matching NCCL-EP's `RDMA_send` per-rank byte count.
Each mscclpp::ProxyService spawns one host-side proxy thread that
drains its FIFO and posts IB work requests. With LL combine pushing
~1k put + 60 atomicAdd FIFO entries per iter, that single thread is
the wall-clock bottleneck on cross-node runs.

Split the channel set across kNumProxyServices=4 separate services
so the host-side dispatch parallelism scales linearly. SemaphoreIds
and MemoryIds are scoped to a ProxyService, so:

- addMemory() is broadcast to every service in the same global order
  so a single MemoryId still identifies the memory everywhere.
- Each (peer_rank, channel_idx) is assigned to one proxy_idx via
  round-robin; the resulting PortChannel is built on that proxy and
  inherits its FIFO. The kernel is unchanged: the flat handle array
  routes the right way automatically.

No kernel-level changes, no tuning of QP count, no new env knobs.
… 1 on Blackwell)

Override at runtime with MSCCLPP_EP_NUM_PROXIES.
N=8 is the knee on H100+IB; N>=12 collapses from CPU oversubscription.
Intra-node LL is unchanged.
Add dist.barrier() + dist.destroy_process_group() in a finally block so
non-zero ranks don't poll the TCPStore after rank 0 (the store server)
exits, which produced noisy 'recvValue failed / Connection was likely
closed' stack traces from ProcessGroupNCCL's HeartbeatMonitor.

Also pass device_id to init_process_group in the internode test to
silence 'Guessing device ID based on global rank' warnings.
Aligns with NCCL-EP's ep_bench convention (BW computed from average time
across ranks). Previously we reported only the max time and computed BW
per-rank, which made our numbers more pessimistic than NCCL-EP's.
…rdma)

For HT intra/internode benches, compute per-rank avg total_send/rdma_send
and total_recv/rdma_recv token counts (matching NCCL-EP ep_bench
accounting) and print send-side and recv-side BW split into total / nvl
/ rdma columns. Combine reverses send<->recv. Byte-count line mirrors
NCCL-EP's '(per rank avg)' summary so numbers are directly comparable.
Set TORCH_NCCL_ENABLE_MONITORING=0 before importing torch.distributed.
The barrier+destroy_process_group finally block (afbdcd6) suffices
under torchrun, but under mpirun rank 0 (the TCPStore server) can exit
before non-zero ranks finish teardown, and the background heartbeat
thread polls the store and logs 'recvValue failed / Connection was
likely closed'. Disabling the monitor outright is safe for short-lived
bench runs.
…h NCCL-EP

Previously total_send_tokens was Sigma over dst_rank of num_tokens_per_rank
which over-counts intra-node fan-out. NCCL-EP's ep_bench collapses
multiple destinations on the same node into one count; on a single-node
run that means total_send_tokens = number of tokens with at least one
valid expert. Switching to is_token_in_rank.any(dim=1).sum() makes the
send-side BW comparable to NCCL-EP's send: total_bw / nvl_bw line.
Add optional out_packed_recv_x / out_src_info / out_layout_range /
out_count parameters to Buffer::low_latency_dispatch so callers can
hoist the four recv-side allocations out of a hot loop, mirroring the
existing out= path on low_latency_combine.

The bench in test_low_latency_multirank.py preallocates these tensors
once and passes them on every iter so the timed loop reflects kernel
cost, not torch.empty + caching-allocator overhead.
@seagater seagater requested review from a team and Copilot May 4, 2026 16:50
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR ports DeepEP/NCCL-EP Mixture-of-Experts (MoE) expert-parallel dispatch/combine functionality into MSCCL++ by adding a new optional EP extension (CUDA kernels + pybind11 module + Python frontend) and extending MSCCL++ PortChannel/Connection plumbing with a remote atomicAdd trigger needed by the EP low-latency path.

Changes:

  • Add src/ext/ep EP extension (kernels, Buffer/Config/EventHandle, pybind11 module, build system integration) and a Python wrapper under python/mscclpp/ext/ep.
  • Add a new Connection::atomicAdd API and PortChannel FIFO trigger type (type==0) with unit tests covering IPC/IB/Ethernet modes.
  • Add Python tests for intranode/internode HT and LL multi-rank round-trips plus a single-rank smoke test.

Reviewed changes

Copilot reviewed 37 out of 38 changed files in this pull request and generated 22 comments.

Show a summary per file
File Description
test/python/ext/ep/test_low_latency_multirank.py New multi-rank LL functional test and optional micro-benchmark.
test/python/ext/ep/test_intranode_multirank.py New multi-rank intranode (HT/NVLink) functional test and optional micro-benchmark.
test/python/ext/ep/test_internode_multirank.py New multi-rank internode (HT/NVLink+RDMA) functional test and optional micro-benchmark.
test/python/ext/ep/test_ep_smoke.py Single-rank CI-friendly smoke tests for the EP extension.
test/mp_unit/port_channel_tests.cu Adds a concurrent PortChannel atomicAdd device test kernel and test cases.
test/mp_unit/mp_unit_tests.hpp Declares the new PortChannelOneToOneTest::testAtomicAdd helper.
src/ext/ep/README.md EP extension overview, build/run instructions, and validation notes.
src/ext/ep/kernels/utils.cuh Adds DeepEP-derived low-level device utilities used by EP kernels.
src/ext/ep/kernels/runtime.cu Adds intranode barrier launcher used by EP intranode kernels.
src/ext/ep/kernels/launch.cuh Adds cooperative launch helpers and SWITCH_* macros for EP kernels.
src/ext/ep/kernels/intranode_kernel.cu Adds intranode dispatch/combine kernels (HT/NVLink).
src/ext/ep/kernels/internode_ll.cu Adds internode low-latency dispatch/combine kernels (PortChannel/MemoryChannel-based).
src/ext/ep/kernels/exception.cuh Adds EP-local CUDA error/assert helpers.
src/ext/ep/kernels/configs.cuh Adds EP kernel constants and CUDA type includes.
src/ext/ep/kernels/buffer.cuh Adds device-side buffer view helpers for EP kernels.
src/ext/ep/kernels/api.cuh Declares host-callable EP kernel entrypoints (intranode/internode/LL).
src/ext/ep/event.hpp Adds an EP EventHandle wrapper for stream/event hand-off.
src/ext/ep/config.hpp Adds EP Config and low-latency RDMA layout + size hint helper.
src/ext/ep/CMakeLists.txt Adds conditional EP build (Torch/pybind11) and installs outputs.
src/ext/ep/buffer.hpp Declares the EP Buffer C++ API exposed to Python.
src/ext/ep/bindings.cpp Defines the mscclpp_ep_cpp pybind11 module.
src/ext/CMakeLists.txt Wires MSCCLPP_BUILD_EXT_EP to include the EP subdirectory.
src/core/port_channel.cc Teaches ProxyService to handle type==0 triggers as atomicAdd.
src/core/include/context.hpp Adds proxy-side atomic stream/context members and CudaIpcStream::atomicAdd.
src/core/include/connection.hpp Adds virtual atomicAdd to connection implementations.
src/core/context.cc Adds CudaIpcStream cleanup and notes about not syncing proxy atomic stream in sync().
src/core/connection.cc Implements Connection::atomicAdd for CUDA IPC, IB, and Ethernet transports.
src/core/atomicadd_kernel.cu Adds a GPU kernel used for proxy-driven 64-bit atomic add on CUDA IPC.
python/mscclpp/ext/ep/buffer.py Adds the torch.distributed-aware Python EP Buffer wrapper.
python/mscclpp/ext/ep/init.py Exposes EP Buffer/Config/EventHandle from the Python package.
python/mscclpp/ext/init.py Optionally imports mscclpp.ext.ep when the extension is built.
include/mscclpp/port_channel_device.hpp Adds PortChannelDeviceHandle::atomicAdd trigger helper.
include/mscclpp/gpu.hpp Extends HIP compatibility typedefs/macros for driver-like context calls.
include/mscclpp/fifo_device.hpp Documents reserving trigger type==0 for atomic add.
include/mscclpp/core.hpp Adds public Connection::atomicAdd API declaration.
CMakeLists.txt Adds top-level MSCCLPP_BUILD_EXT_EP option.

Comment on lines +19 to +25
#define SWITCH_RANKS(case_macro) \
switch (num_ranks) { \
case 2: case_macro(2); \
case 4: case_macro(4); \
case 8: case_macro(8); \
default: EP_HOST_ASSERT(false and "Unsupported ranks"); \
} while (false)
Comment on lines +1 to +9
#pragma once

#include "exception.cuh"

#define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \
{ \
constexpr int kLoopStride = 32 * (UNROLL_FACTOR); \
typename std::remove_reference<decltype(LD_FUNC((SRC) + 0))>::type unrolled_values[(UNROLL_FACTOR)]; \
auto __src = (SRC); \
Comment on lines +679 to +684
int min_head = std::numeric_limits<int>::max();
#pragma unroll
for (int i = 1; i < num_recv_warps; ++ i) if (not warp_retired[i])
min_head = min(min_head, warp_channel_head_idx[i][recv_lane_id]);
if (min_head != std::numeric_limits<int>::max() and min_head > last_head)
st_relaxed_sys_global(channel_head_idx_ptr, last_head = min_head);
Comment thread src/ext/ep/CMakeLists.txt
Comment on lines +107 to +146
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cc *.cpp *.cu)

if(MSCCLPP_USE_ROCM)
file(GLOB_RECURSE CU_SOURCES *.cu)
set_source_files_properties(${CU_SOURCES} PROPERTIES LANGUAGE CXX)
endif()

add_library(mscclpp_ep SHARED ${SOURCES})

target_include_directories(mscclpp_ep PRIVATE
include
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/src/core/include
${PROJECT_SOURCE_DIR}/src/ext/include
${GPU_INCLUDE_DIRS}
)

target_link_libraries(mscclpp_ep PUBLIC mscclpp)
target_link_libraries(mscclpp_ep PRIVATE ${GPU_LIBRARIES} Threads::Threads)

set_target_properties(mscclpp_ep PROPERTIES
LINKER_LANGUAGE CXX
POSITION_INDEPENDENT_CODE 1
VERSION ${MSCCLPP_VERSION}
SOVERSION ${MSCCLPP_SOVERSION})

if(MSCCLPP_USE_CUDA)
target_compile_definitions(mscclpp_ep PRIVATE MSCCLPP_USE_CUDA)
elseif(MSCCLPP_USE_ROCM)
target_compile_definitions(mscclpp_ep PRIVATE MSCCLPP_USE_ROCM)
foreach(arch ${MSCCLPP_GPU_ARCHS})
target_compile_options(mscclpp_ep PRIVATE --offload-arch=${arch})
endforeach()
endif()

install(TARGETS mscclpp_ep
LIBRARY DESTINATION ${INSTALL_PREFIX}/lib)
Comment on lines +1 to +3
"""Multi-rank low-latency functional test for mscclpp_ep.

Launch with (intra-node, 8 GPUs):
Comment on lines +1 to +4
#pragma once

#include "configs.cuh"
#include "exception.cuh"
Comment on lines +1 to +4
#pragma once

#include <string>
#include <exception>
Comment on lines +1 to +5
#include "configs.cuh"
#include "buffer.cuh"
#include "exception.cuh"
#include "launch.cuh"
#include "utils.cuh"
Comment thread src/ext/ep/buffer.hpp Outdated
Comment on lines +1 to +3
#pragma once

// Forcibly disable NDEBUG
Comment on lines +5 to +9
These tests only exercise single-rank / pure-Python code paths so they can
run in CI without multi-GPU resources. Multi-rank dispatch/combine tests
belong in ``test/python/ext/ep/test_intranode.py`` and are left as TODO
until the Python frontend is validated on H100.

@Binyang2014
Copy link
Copy Markdown
Contributor

@copilot pls review this PR carefully, focus on any potential memory allocation issue, resource lifecycle, point out possible performance bottleneck and give the suggestion

seagater and others added 2 commits May 5, 2026 19:24
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants