
[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443

Open
denera wants to merge 37 commits into NVIDIA:main from denera:common/tp-overlap-cublasmp

Conversation


@denera denera commented Dec 2, 2025

Description

This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera self-assigned this Dec 2, 2025
@denera denera force-pushed the common/tp-overlap-cublasmp branch 2 times, most recently from 908bbc2 to 69cf235 on December 2, 2025 20:12
Comment thread transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.

nit: cuBLASMp

@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
ExtComm comm);

void ub_barrier(ExtComm comm);

int64_t get_nccl_comm_ptr(std::string comm_name) {
NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");

This error message could be more descriptive - e.g. something like "chosen backend for the communication-computation overlap (cuBLASMp) requires NCCL communicator, but the passed ProcessGroup uses a different backend."

@denera denera force-pushed the common/tp-overlap-cublasmp branch from 4596411 to b4ad546 on December 16, 2025 19:04
@denera denera marked this pull request as ready for review December 16, 2025 22:58

greptile-apps Bot commented Dec 16, 2025

Greptile Summary

This PR integrates the cuBLASMp backend into the Comm+GEMM overlap API, adding alternative constructors to CommOverlapCore, CommOverlapBase, and CommOverlapP2PBase, then dispatching to the cuBLASMp kernels (nvte_all_gather_gemm, nvte_gemm_reduce_scatter, nvte_gemm_all_reduce) at each overlap call site. Framework-level wiring (PyTorch initialize_ub, JAX collective_gemm_bootstrap) and test infrastructure are updated accordingly.

  • Several critical runtime issues flagged in prior review rounds remain unaddressed in this diff: ncclCommInitAll is called in a distributed multi-process context where it is single-process-only (stack overflow when numranks > 1), the "intra" NCCL communicator key is never inserted for the single-TP-domain case (runtime crash on get_nccl_comm("intra")), and ncclCommGetUniqueId is a non-standard NCCL API not in the public headers.
  • ProcessGroupNCCL.hpp is #included in common.h without an #ifdef NVTE_WITH_CUBLASMP guard, pulling an NCCL-specific PyTorch header into every PyTorch extension build regardless of feature selection.

Confidence Score: 2/5

Not safe to merge — multiple critical runtime issues in the cuBLASMp NCCL initialization path remain unaddressed from prior review rounds.

Three P0/P1 findings from previous rounds are still present in the diff: (1) ncclCommInitAll called with numranks > 1 in a multi-process setting corrupts the stack, (2) the "intra" NCCL communicator key is never inserted for single-TP-domain deployments causing a guaranteed runtime crash, and (3) ncclCommGetUniqueId is not part of the public NCCL API. These collectively make the cuBLASMp code path non-functional in any real distributed scenario.

transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (ncclCommInitAll, missing "intra" key, ncclCommGetUniqueId), transformer_engine/pytorch/csrc/common.h (unguarded ProcessGroupNCCL.hpp include), transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (tp_rank vs domain ID)

Important Files Changed

Filename Overview
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp Adds cuBLASMp NCCL communicator initialization; multiple critical issues remain (ncclCommInitAll misuse, missing "intra" key in single-domain case, non-standard ncclCommGetUniqueId API) flagged in prior review rounds.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp Adds cuBLASMp constructor and dispatch methods (cublasmp_ag_gemm, cublasmp_gemm_rs, cublasmp_gemm_ar); destructor guards correctly gate Userbuffers teardown; overall logic is sound.
transformer_engine/pytorch/csrc/common.h ProcessGroupNCCL.hpp added without #ifdef NVTE_WITH_CUBLASMP guard, pulling a NCCL-specific PyTorch header into all builds regardless of feature flags.
transformer_engine/pytorch/module/base.py with_cublasmp parameter added to initialize_ub; bulk-overlap silently falls back to Userbuffers without warning (previously flagged); local_rank correctly represents TP-group rank.
transformer_engine/jax/csrc/extensions/cgemm_helper.cpp use_cublasmp flag wired through plan registry; use_cublasmp correctly added to plan_id hash; tp_rank passed as get_tp_domain_id() (wrong: should be local rank within TP group) — previously flagged.
examples/jax/collective_gemm/run_test_cgemm.sh Backend loop structure significantly improved; prior issues (premature rm, bash array comma, concurrent launches, discarded --use-cublasmp flag) largely addressed; a redundant wait at the end remains.
tests/pytorch/distributed/run_gemm_with_overlap.py Reference computation refactored to per-rank local comparisons avoiding global gather; cuBLASMp path wired through all GEMM variants; atol loosened to 0.002 for non-quantized runs.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[initialize_ub / collective_gemm_bootstrap] --> B{with_cublasmp?}
    B -- No --> C[CommOverlapHelper\nUserbuffers bootstrap\nnccl_comms empty]
    B -- Yes --> D[CommOverlapHelper\nUserbuffers bootstrap +\nncclCommInitAll world +\nncclCommInitRank intra]
    D --> E{intra_domain_group\npresent?}
    E -- No --> F[nccl_comms: world only\n'intra' key MISSING]
    E -- Yes --> G[nccl_comms: world + intra]
    C --> H[CommOverlapBase / CommOverlapP2PBase\nUserbuffers constructor]
    G --> I[get_nccl_comm intra]
    F --> I
    I -- key found --> J[CommOverlapBase / CommOverlapP2PBase\ncuBLASMp constructor\nnvte_comm_gemm_ctx_create]
    I -- key missing --> K[NVTE_ERROR runtime crash]
    J --> L{Overlap method}
    L -- AG --> M[cublasmp_ag_gemm\nnvte_all_gather_gemm]
    L -- RS --> N[cublasmp_gemm_rs\nnvte_gemm_reduce_scatter]
    L -- AR --> O[cublasmp_gemm_ar\nnvte_gemm_all_reduce]

Reviews (19): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."


@greptile-apps greptile-apps Bot left a comment


Additional Comments (8)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (link)

    logic: Variable shadowing bug: k is assigned k * _tp_size where k appears on both sides. Should be k = k_local * _tp_size.

  2. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (link)

    logic: Invalid reinterpret_cast: cannot cast an int* (pointer) to int (value). Should be reinterpret_cast<void**>(&handler._device_barrier).

  3. transformer_engine/pytorch/csrc/extensions.h, line 517 (link)

    syntax: Stray character `a` that will cause a compilation failure.

  4. transformer_engine/pytorch/csrc/extensions.h, line 537-540 (link)

    logic: Constructor parameter mismatch: CommOverlapBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  5. transformer_engine/pytorch/csrc/extensions.h, line 563-566 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_rank, ...). Order of tp_rank and tp_size is swapped.

  6. transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 196-199 (link)

    logic: Constructor parameter mismatch: CommOverlapP2PBase constructor expects (nccl_comm_ptr, tp_rank, tp_size, ...) but called with (nccl_comm_ptr, tp_size, tp_domain_id, ...). Should use tp_rank instead of tp_domain_id.

    Should this use tp_rank (local_device_id_within_tp_domain) or node_id (tp_domain_id)?

  7. tests/pytorch/distributed/run_gemm_with_overlap.py, line 416-418 (link)

    style: Unconditional initialization: local_kernel2_t_shape is only used when ub_obj2 is not None, but it's always initialized here. This creates an unused variable in most cases.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  8. tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (link)

    style: Incorrect initialization: ref2_g should only be assigned when ub_obj2 is not None, but it's unconditionally assigned to (0,) which doesn't match the expected tensor type.
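The shadowing pattern flagged in item 1 is easy to reproduce; a minimal Python sketch (variable names are illustrative, not the actual TE code):

```python
tp_size = 4

# Buggy shape: reusing `k` on both sides silently changes its meaning from
# "per-rank local K" to "global K", and any earlier rescaling would compound.
k = 128
k = k * tp_size

# Clearer shape: keep the per-rank value under its own name, as the review suggests.
k_local = 128
k_global = k_local * tp_size

assert k == 512 and k_global == 512
```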

18 files reviewed, 8 comments


@denera denera force-pushed the common/tp-overlap-cublasmp branch from 147036f to c5471f8 on December 17, 2025 02:15
denera and others added 6 commits December 17, 2025 02:16
…rk extensions

Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from c5471f8 to d79bf21 on December 17, 2025 02:16

@greptile-apps greptile-apps Bot left a comment


Additional Comments (3)

  1. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (link)

    logic: return value from nvte_comm_gemm_ctx_create() is not assigned to _cublasmp_ctx, causing null pointer when destructor calls nvte_comm_gemm_ctx_destroy(_cublasmp_ctx) on line 173

  2. transformer_engine/jax/cpp_extensions/gemm.py, line 819 (link)

    style: parameter order changed - collective_op moved after transpose_batch_sequence, sequence_dim, is_outer. Verify this matches the calling convention and doesn't break compatibility. Was this parameter reordering intentional, and have all call sites been verified?

  3. transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 89-91 (link)

    logic: _tp_size field is not initialized in the cuBlasMp constructor, but it's used in cublasmp_ag_gemm() and cublasmp_gemm_rs() methods (lines 321, 332). Add initialization

18 files reviewed, 3 comments


Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 364b416 to ee517d3 on December 17, 2025 02:50

@greptile-apps greptile-apps Bot left a comment


Additional Comments (6)

  1. transformer_engine/pytorch/module/base.py, line 415-417 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  2. transformer_engine/pytorch/module/base.py, line 387-389 (link)

    logic: Parameter order is incorrect - the C++ constructor signature is (helper, tp_rank, tp_size, ...) but Python is passing (helper, tp_size, local_rank, ...). This swaps tp_rank and tp_size, causing incorrect initialization.

  3. tests/pytorch/distributed/run_gemm_with_overlap.py, line 340-344 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  4. tests/pytorch/distributed/run_gemm_with_overlap.py, line 355-359 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  5. tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.

  6. tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (link)

    logic: Parameter order is incorrect - C++ signature is (helper, tp_rank, tp_size, ...) but passing (helper, tp_size, tp_rank, ...). Swap the second and third parameters.
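Since the same tp_rank/tp_size transposition shows up at six call sites, passing these as keyword arguments would make the swap impossible at each one. A sketch with a hypothetical stand-in for the binding (the real tex.CommOverlap signature is only paraphrased here from the review):

```python
# Hypothetical stand-in for the binding's leading scalar arguments,
# which (per the review) are ordered (tp_rank, tp_size).
def comm_overlap(helper, tp_rank, tp_size):
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank={tp_rank} is not a valid rank for tp_size={tp_size}")
    return {"tp_rank": tp_rank, "tp_size": tp_size}

# Keyword arguments pin each value to its parameter regardless of call-site order:
cfg = comm_overlap(None, tp_size=8, tp_rank=3)
assert cfg == {"tp_rank": 3, "tp_size": 8}

# The positional swap the review flags would raise immediately here,
# because comm_overlap(None, 8, 3) binds tp_rank=8 against tp_size=3.
```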

19 files reviewed, 6 comments


Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from 5cb8204 to 51b64fb on December 17, 2025 03:36
Comment thread transformer_engine/common/CMakeLists.txt Outdated
Comment thread transformer_engine/jax/cpp_extensions/gemm.py Outdated
cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority,
cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/,
cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag);
if (cgemm_helper.use_cublasmp) {


cgemm_helper is undefined — compilation error

cgemm_helper is not declared anywhere in get_executor(). The local variables in scope are cgemm_config (of type CgemmConfig) and comm_handler (of type CommunicatorHandler). The intended field is cgemm_config.use_cublasmp, which is the field that was added to CgemmConfig in this PR.

Suggested change:
- if (cgemm_helper.use_cublasmp) {
+ if (cgemm_config.use_cublasmp) {

Comment on lines +75 to +79
NVTE_CHECK_NCCL(ncclCommInitAll(&nccl_world, numranks, nullptr));
nccl_comms.insert({"world", nccl_world});

if (intra_domain_group.has_value()) {
// Use the global rank of the local rank 0 process as the unique ID for the intra-node communicator


ncclCommInitAll is for single-process multi-GPU — buffer overflow when numranks > 1

ncclCommInitAll(ncclComm_t *comms, int ndev, const int *devlist) is designed to initialise ndev communicators in a single process that owns all ndev GPUs. The first argument must point to an array of at least ndev ncclComm_t objects.

Here nccl_world is a single ncclComm_t on the stack, but numranks (the world size) is passed as ndev. When numranks > 1, NCCL will write numranks communicator handles starting at &nccl_world, overflowing into adjacent stack memory.

In a multi-process distributed setting each process must call ncclGetUniqueId (on rank 0), broadcast the ncclUniqueId, and then call ncclCommInitRank on every rank:

ncclUniqueId world_id;
if (myrank == 0) ncclGetUniqueId(&world_id);
// broadcast world_id to all ranks via the intra-domain torch PG or MPI
ncclComm_t nccl_world;
NVTE_CHECK_NCCL(ncclCommInitRank(&nccl_world, numranks, world_id, myrank));
nccl_comms.insert({"world", nccl_world});

…llation)

Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from a25e667 to 6c6cc4d on March 16, 2026 17:14
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
BACKENDS=("userbuffers", "cublasmp")


Bash array comma syntax creates malformed "userbuffers," element

In bash, array elements are separated by whitespace — commas are treated as literal characters, not separators. This declaration:

BACKENDS=("userbuffers", "cublasmp")

creates a two-element array where the first element is "userbuffers," (with a trailing comma) and the second is "cublasmp". Downstream, LOG_FILE and grep patterns expand to ${TEST_NAME}_gpu_0_userbuffers,.log (comma included), which may confuse any tooling that parses the filenames. The fix is to use whitespace separation:

Suggested change:
- BACKENDS=("userbuffers", "cublasmp")
+ BACKENDS=("userbuffers" "cublasmp")
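A quick way to see the difference in a terminal (standalone bash sketch, unrelated to the test script itself):

```shell
#!/usr/bin/env bash
# Comma is a literal character inside bash array literals; whitespace is the
# only element separator bash recognises.
BAD=("userbuffers", "cublasmp")
GOOD=("userbuffers" "cublasmp")

echo "BAD[0]='${BAD[0]}'"     # the trailing comma survives into the element
echo "GOOD[0]='${GOOD[0]}'"
```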

Comment on lines +108 to +117
)
if [ "$BACKEND" == "cublasmp" ]; then
pytest_args+=("--use-cublasmp")
fi

pytest_args=(
"-s"
"-c $TE_PATH/tests/jax/pytest.ini"
"-vs"
)


--use-cublasmp flag silently discarded — cuBLASMp backend never actually tested

pytest_args+=("--use-cublasmp") on line 110 appends to pytest_args before it is initialised. Then line 113 unconditionally resets pytest_args to a fresh array, discarding the previously appended flag. The result is that pytest is always invoked without --use-cublasmp, so the "cublasmp" loop iteration runs identical Userbuffers tests to the first iteration — the cuBLASMp code path is never exercised.

The fix is to initialise pytest_args before the conditional, then append the backend-specific flag:

      pytest_args=(
        "-s"
        "-c" "$TE_PATH/tests/jax/pytest.ini"
        "-vs"
      )
      if [ "$BACKEND" == "cublasmp" ]; then
        pytest_args+=("--use-cublasmp")
      fi

denera added 2 commits March 16, 2026 17:57
Signed-off-by: Alp Dener <adener@nvidia.com>
Comment on lines +439 to +448
ub_obj = tex.CommOverlap(
helper,
local_rank,
tp_size,
num_comm_sm=num_sm,
atomic_gemm=atomic_gemm,
)
else:
ub_obj = tex.CommOverlap(
shape, # Communication buffer shape


Silent cuBLASMp fallback for bulk overlap goes unannounced

When with_cublasmp=True is passed but the configured method is "bulk", the condition with_cublasmp and method != "bulk" is False and the code silently constructs a regular Userbuffers CommOverlap object instead. The user has explicitly opted into cuBLASMp but will unknowingly receive the Userbuffers backend for every bulk-overlap buffer — with no warning anywhere in the output.

The C++ side already documents the incompatibility (NVTE_CHECK(!_with_cublasmp, "Bulk overlap is not supported with cuBlasMp")), so the Python side should at least surface this to the caller:

else:
    if with_cublasmp and method != "bulk":
        ub_obj = tex.CommOverlap(
            helper,
            local_rank,
            tp_size,
            num_comm_sm=num_sm,
            atomic_gemm=atomic_gemm,
        )
    else:
        if with_cublasmp and method == "bulk":
            warnings.warn(
                "cuBLASMp does not support bulk overlap; falling back to "
                "Userbuffers backend for bulk-overlap buffers.",
                stacklevel=2,
            )
        ub_obj = tex.CommOverlap(
            shape,
            ...
        )

Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
@denera denera force-pushed the common/tp-overlap-cublasmp branch from ebe7679 to 5a8c7ae on March 17, 2026 20:02
@denera denera added 2.15.0 and removed 2.14.0 labels Mar 23, 2026
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
denera and others added 2 commits April 24, 2026 16:45
…fer requirement

Fix several issues preventing delayed-scaling FP8 collective GEMM tests
from passing with the cuBLASMp backend:

- Clean up stale NCCL unique ID files between test runs using a
  sync_global_devices barrier so crashed runs don't poison subsequent ones
- Use NumPy instead of JAX ops in process-0-only result checks to avoid
  multi-process XLA compilation deadlocks
- Expose nvte_built_with_cublasmp() to Python and add runtime skip logic
  in conftest.py and run_test_cgemm.sh
- Add cuBLASMp RS output path in gemm.cpp (cuBLASMp writes reduce-scattered
  result directly into D, unlike Userbuffers which uses an intermediate ubuf)

Also document on gemm() and collective_gemm_bootstrap() that XLA command
buffers must be disabled when using collective GEMM with communication
overlap, since both Userbuffers and cuBLASMp use internal CUDA streams
for NCCL collectives that break CUDA graph capture.

Signed-off-by: adener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>