[Common] Comm+GEMM overlap API updated to support cuBlasMp backend (incl. framework API) #2443
denera wants to merge 37 commits into NVIDIA:main
Conversation
Force-pushed from 908bbc2 to 69cf235
```cpp
@@ -17,6 +18,12 @@

#define NVTE_COMM_OVERLAP_MAX_STREAMS 3

/* \brief Check if TE is built with cuBlasMp.

@@ -526,6 +514,11 @@ class CommOverlapHelper : torch::CustomClassHolder {
                     ExtComm comm);

  void ub_barrier(ExtComm comm);

  int64_t get_nccl_comm_ptr(std::string comm_name) {
    NVTE_CHECK(backend_is_nccl, "Cannot get nccComm_t ptr if backend is not NCCL.");
```
This error message could be more descriptive, e.g. something like "the chosen backend for the communication-computation overlap (cuBLASMp) requires an NCCL communicator, but the passed ProcessGroup uses a different backend."
Force-pushed from 4596411 to b4ad546
Greptile Summary

This PR integrates the cuBLASMp backend into the Comm+GEMM overlap API, adding alternative cuBLASMp constructors to CommOverlapBase and CommOverlapP2PBase.
Confidence Score: 2/5

Not safe to merge: multiple critical runtime issues in the cuBLASMp NCCL initialization path remain unaddressed from prior review rounds. Three P0/P1 findings from previous rounds are still present in the diff: (1) `ncclCommInitAll` called with `numranks > 1` in a multi-process setting corrupts the stack, (2) the "intra" NCCL communicator key is never inserted for single-TP-domain deployments, causing a guaranteed runtime crash, and (3) `ncclCommGetUniqueId` is not part of the public NCCL API. These collectively make the cuBLASMp code path non-functional in any real distributed scenario.

Important Files Changed:
- transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (`ncclCommInitAll`, missing "intra" key, `ncclCommGetUniqueId`)
- transformer_engine/pytorch/csrc/common.h (unguarded ProcessGroupNCCL.hpp include)
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp (`tp_rank` vs domain ID)
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[initialize_ub / collective_gemm_bootstrap] --> B{with_cublasmp?}
    B -- No --> C[CommOverlapHelper\nUserbuffers bootstrap\nnccl_comms empty]
    B -- Yes --> D[CommOverlapHelper\nUserbuffers bootstrap +\nncclCommInitAll world +\nncclCommInitRank intra]
    D --> E{intra_domain_group\npresent?}
    E -- No --> F[nccl_comms: world only\n'intra' key MISSING]
    E -- Yes --> G[nccl_comms: world + intra]
    C --> H[CommOverlapBase / CommOverlapP2PBase\nUserbuffers constructor]
    G --> I[get_nccl_comm intra]
    F --> I
    I -- key found --> J[CommOverlapBase / CommOverlapP2PBase\ncuBLASMp constructor\nnvte_comm_gemm_ctx_create]
    I -- key missing --> K[NVTE_ERROR runtime crash]
    J --> L{Overlap method}
    L -- AG --> M[cublasmp_ag_gemm\nnvte_all_gather_gemm]
    L -- RS --> N[cublasmp_gemm_rs\nnvte_gemm_reduce_scatter]
    L -- AR --> O[cublasmp_gemm_ar\nnvte_gemm_all_reduce]
```
Reviews (19). Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm…"
Additional Comments (8)

- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 335 (link) logic: Variable shadowing bug: `k` is assigned `k * _tp_size`, where `k` appears on both sides. Should be `k = k_local * _tp_size`.
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, line 135 (link) logic: Invalid reinterpret_cast: cannot cast an `int*` (pointer) to `int` (value). Should be `reinterpret_cast<void**>(&handler._device_barrier)`.
- transformer_engine/pytorch/csrc/extensions.h, line 517 (link) syntax: Stray character `a` that will cause a compilation failure.
- transformer_engine/pytorch/csrc/extensions.h, lines 537-540 (link) logic: Constructor parameter mismatch: the `CommOverlapBase` constructor expects `(nccl_comm_ptr, tp_rank, tp_size, ...)` but is called with `(nccl_comm_ptr, tp_size, tp_rank, ...)`. The order of `tp_rank` and `tp_size` is swapped.
- transformer_engine/pytorch/csrc/extensions.h, lines 563-566 (link) logic: Constructor parameter mismatch: the `CommOverlapP2PBase` constructor expects `(nccl_comm_ptr, tp_rank, tp_size, ...)` but is called with `(nccl_comm_ptr, tp_size, tp_rank, ...)`. The order of `tp_rank` and `tp_size` is swapped.
- transformer_engine/jax/csrc/extensions/cgemm_helper.cpp, lines 196-199 (link) logic: Constructor parameter mismatch: the `CommOverlapP2PBase` constructor expects `(nccl_comm_ptr, tp_rank, tp_size, ...)` but is called with `(nccl_comm_ptr, tp_size, tp_domain_id, ...)`. Should use `tp_rank` instead of `tp_domain_id`. Should this use `tp_rank` (local_device_id_within_tp_domain) or `node_id` (tp_domain_id)?
- tests/pytorch/distributed/run_gemm_with_overlap.py, lines 416-418 (link) style: Unconditional initialization: `local_kernel2_t_shape` is only used when `ub_obj2 is not None`, but it is always initialized here, creating an unused variable in most cases.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 490 (link) style: Incorrect initialization: `ref2_g` should only be assigned when `ub_obj2 is not None`, but it is unconditionally assigned `(0,)`, which does not match the expected tensor type.
18 files reviewed, 8 comments
Force-pushed from 147036f to c5471f8
…rk extensions Signed-off-by: Alp Dener <adener@nvidia.com>
…entirely Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
c5471f8 to
d79bf21
Compare
Additional Comments (3)

- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, line 87 (link) logic: The return value of `nvte_comm_gemm_ctx_create()` is not assigned to `_cublasmp_ctx`, causing a null pointer when the destructor calls `nvte_comm_gemm_ctx_destroy(_cublasmp_ctx)` on line 173.
- transformer_engine/jax/cpp_extensions/gemm.py, line 819 (link) style: Parameter order changed: `collective_op` moved after `transpose_batch_sequence`, `sequence_dim`, and `is_outer`. Verify this matches the calling convention and does not break compatibility. Was this parameter reordering intentional, and have all call sites been verified?
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp, lines 89-91 (link) logic: The `_tp_size` field is not initialized in the cuBlasMp constructor, but it is used in `cublasmp_ag_gemm()` and `cublasmp_gemm_rs()` (lines 321 and 332). Add initialization.
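The first finding above is a classic dropped-return-value bug: the factory's handle is discarded, so teardown dereferences null. A minimal Python sketch of the same bug class, with hypothetical names standing in for `_cublasmp_ctx` / `nvte_comm_gemm_ctx_create()` (this is not the Transformer Engine API):

```python
def create_ctx():
    """Stand-in for a context factory like nvte_comm_gemm_ctx_create()."""
    return object()


class Overlap:
    def __init__(self, buggy=True):
        self._ctx = None
        if buggy:
            create_ctx()              # bug: handle discarded, _ctx stays None
        else:
            self._ctx = create_ctx()  # fix: keep the handle for teardown

    def destroy(self):
        # Stand-in for nvte_comm_gemm_ctx_destroy(_cublasmp_ctx)
        if self._ctx is None:
            raise RuntimeError("destroying a null context")


ok = Overlap(buggy=False)
ok.destroy()  # fine: the handle was kept
```

In C++ the failure is typically a crash inside the destructor rather than a clean exception, which is why the reviewer flags it as a logic error rather than a style nit.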
18 files reviewed, 3 comments
Force-pushed from 364b416 to ee517d3
Additional Comments (6)

- transformer_engine/pytorch/module/base.py, lines 415-417 (link) logic: Parameter order is incorrect: the C++ constructor signature is `(helper, tp_rank, tp_size, ...)` but Python passes `(helper, tp_size, local_rank, ...)`. This swaps `tp_rank` and `tp_size`, causing incorrect initialization.
- transformer_engine/pytorch/module/base.py, lines 387-389 (link) logic: Parameter order is incorrect: the C++ constructor signature is `(helper, tp_rank, tp_size, ...)` but Python passes `(helper, tp_size, local_rank, ...)`. This swaps `tp_rank` and `tp_size`, causing incorrect initialization.
- tests/pytorch/distributed/run_gemm_with_overlap.py, lines 340-344 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, lines 355-359 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 383 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
- tests/pytorch/distributed/run_gemm_with_overlap.py, line 394 (link) logic: Parameter order is incorrect: the C++ signature is `(helper, tp_rank, tp_size, ...)` but the call passes `(helper, tp_size, tp_rank, ...)`. Swap the second and third parameters.
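All six findings above are the same positional-argument swap of `tp_rank` and `tp_size`. A generic Python sketch (hypothetical names, not the Transformer Engine API) of one way this class of bug can be made to fail fast: require keyword arguments and validate the rank against the size at construction time.

```python
def make_overlap(*, tp_rank: int, tp_size: int) -> dict:
    """Hypothetical CommOverlap-style constructor wrapper.

    The bare `*` makes both parameters keyword-only, so a caller cannot
    accidentally pass them in the wrong positional order; the range check
    catches inconsistent values immediately instead of mid-run.
    """
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank {tp_rank} out of range for tp_size {tp_size}")
    return {"tp_rank": tp_rank, "tp_size": tp_size}


cfg = make_overlap(tp_rank=3, tp_size=8)  # keyword order cannot be confused
```

A positional call such as `make_overlap(8, 3)` raises `TypeError` outright, and a swapped-but-keyworded call like `tp_rank=8, tp_size=3` trips the range check, so the bug surfaces at the call site rather than as silently wrong GEMM partitioning.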
19 files reviewed, 6 comments
Signed-off-by: Alp Dener <adener@nvidia.com>
Force-pushed from 5cb8204 to 51b64fb
for more information, see https://pre-commit.ci
```cpp
      cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority,
      cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/,
      cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag);
  if (cgemm_helper.use_cublasmp) {
```

cgemm_helper is undefined: compilation error

`cgemm_helper` is not declared anywhere in `get_executor()`. The local variables in scope are `cgemm_config` (of type `CgemmConfig`) and `comm_handler` (of type `CommunicatorHandler`). The intended field is `cgemm_config.use_cublasmp`, which is the field that was added to `CgemmConfig` in this PR.

```diff
-  if (cgemm_helper.use_cublasmp) {
+  if (cgemm_config.use_cublasmp) {
```
```cpp
  NVTE_CHECK_NCCL(ncclCommInitAll(&nccl_world, numranks, nullptr));
  nccl_comms.insert({"world", nccl_world});

  if (intra_domain_group.has_value()) {
    // Use the global rank of the local rank 0 process as the unique ID for the intra-node communicator
```

ncclCommInitAll is for single-process multi-GPU: buffer overflow when numranks > 1

`ncclCommInitAll(ncclComm_t *comms, int ndev, const int *devlist)` is designed to initialise `ndev` communicators in a single process that owns all `ndev` GPUs. The first argument must point to an array of at least `ndev` `ncclComm_t` objects. Here `nccl_world` is a single `ncclComm_t` on the stack, but `numranks` (the world size) is passed as `ndev`. When `numranks > 1`, NCCL will write `numranks` communicator handles starting at `&nccl_world`, overflowing into adjacent stack memory.

In a multi-process distributed setting each process must call `ncclGetUniqueId` (on rank 0), broadcast the `ncclUniqueId`, and then call `ncclCommInitRank` on every rank:

```cpp
ncclUniqueId world_id;
if (myrank == 0) ncclGetUniqueId(&world_id);
// broadcast world_id to all ranks via the intra-domain torch PG or MPI
ncclComm_t nccl_world;
NVTE_CHECK_NCCL(ncclCommInitRank(&nccl_world, numranks, world_id, myrank));
nccl_comms.insert({"world", nccl_world});
```

…llation) Signed-off-by: Alp Dener <adener@nvidia.com>
Force-pushed from a25e667 to 6c6cc4d
for more information, see https://pre-commit.ci
```shell
    --process-id=$i 2>&1 | tee "$LOG_FILE" &
  PID=$!
  PIDS+=($PID)
BACKENDS=("userbuffers", "cublasmp")
```

Bash array comma syntax creates a malformed "userbuffers," element

In bash, array elements are separated by whitespace; commas are treated as literal characters, not separators. This declaration creates a two-element array where the first element is `"userbuffers,"` (with a trailing comma) and the second is `"cublasmp"`. Downstream, `LOG_FILE` and grep patterns expand to `${TEST_NAME}_gpu_0_userbuffers,.log` (comma included), which may confuse any tooling that parses the filenames. The fix is to use whitespace separation:

```diff
-BACKENDS=("userbuffers", "cublasmp")
+BACKENDS=("userbuffers" "cublasmp")
```
```shell
)
if [ "$BACKEND" == "cublasmp" ]; then
  pytest_args+=("--use-cublasmp")
fi

pytest_args=(
  "-s"
  "-c $TE_PATH/tests/jax/pytest.ini"
  "-vs"
)
```

--use-cublasmp flag silently discarded: cuBLASMp backend never actually tested

`pytest_args+=("--use-cublasmp")` on line 110 appends to `pytest_args` before it is initialised. Then line 113 unconditionally resets `pytest_args` to a fresh array, discarding the previously appended flag. The result is that pytest is always invoked without `--use-cublasmp`, so the "cublasmp" loop iteration runs identical Userbuffers tests to the first iteration and the cuBLASMp code path is never exercised.

The fix is to initialise `pytest_args` before the conditional, then append the backend-specific flag:

```shell
pytest_args=(
  "-s"
  "-c" "$TE_PATH/tests/jax/pytest.ini"
  "-vs"
)
if [ "$BACKEND" == "cublasmp" ]; then
  pytest_args+=("--use-cublasmp")
fi
```

Signed-off-by: Alp Dener <adener@nvidia.com>
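The append-then-reinitialise hazard above is language-agnostic. A minimal Python sketch of the same bug class (names are illustrative, not from the PR's scripts):

```python
def build_args_buggy(use_cublasmp: bool) -> list:
    """Mirrors the broken shell ordering: append first, initialize second."""
    args = []
    if use_cublasmp:
        args.append("--use-cublasmp")  # appended too early...
    args = ["-s", "-vs"]               # ...then discarded by reassignment
    return args


def build_args_fixed(use_cublasmp: bool) -> list:
    """Initialize the base arguments first, then append conditionally."""
    args = ["-s", "-vs"]
    if use_cublasmp:
        args.append("--use-cublasmp")
    return args


print(build_args_buggy(True))  # ['-s', '-vs']  (flag silently lost)
print(build_args_fixed(True))  # ['-s', '-vs', '--use-cublasmp']
```

Bugs of this shape pass every syntax check and only show up as a behavioural difference, which is why the review flags that the cuBLASMp iteration was silently running the wrong tests.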
…ormerEngine into common/tp-overlap-cublasmp
```python
ub_obj = tex.CommOverlap(
    helper,
    local_rank,
    tp_size,
    num_comm_sm=num_sm,
    atomic_gemm=atomic_gemm,
)
else:
    ub_obj = tex.CommOverlap(
        shape,  # Communication buffer shape
```
Silent cuBLASMp fallback for bulk overlap goes unannounced

When `with_cublasmp=True` is passed but the configured method is "bulk", the condition `with_cublasmp and method != "bulk"` is False and the code silently constructs a regular Userbuffers `CommOverlap` object instead. The user has explicitly opted into cuBLASMp but will unknowingly receive the Userbuffers backend for every bulk-overlap buffer, with no warning anywhere in the output.

The C++ side already documents the incompatibility (`NVTE_CHECK(!_with_cublasmp, "Bulk overlap is not supported with cuBlasMp")`), so the Python side should at least surface this to the caller:

```python
else:
    if with_cublasmp and method != "bulk":
        ub_obj = tex.CommOverlap(
            helper,
            local_rank,
            tp_size,
            num_comm_sm=num_sm,
            atomic_gemm=atomic_gemm,
        )
    else:
        if with_cublasmp and method == "bulk":
            warnings.warn(
                "cuBLASMp does not support bulk overlap; falling back to "
                "Userbuffers backend for bulk-overlap buffers.",
                stacklevel=2,
            )
        ub_obj = tex.CommOverlap(
            shape,
            ...
        )
```

Signed-off-by: Alp Dener <adener@nvidia.com>
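The warn-on-fallback pattern suggested above can be exercised in isolation. A self-contained sketch with generic names (a hypothetical `select_backend`, not the Transformer Engine API) showing that the warning is observable by the caller:

```python
import warnings


def select_backend(requested: str, method: str) -> str:
    """Hypothetical selector: cuBLASMp cannot serve 'bulk' overlap, so we
    fall back to Userbuffers and warn instead of failing silently."""
    if requested == "cublasmp" and method == "bulk":
        warnings.warn(
            "cuBLASMp does not support bulk overlap; falling back to "
            "Userbuffers backend.",
            stacklevel=2,
        )
        return "userbuffers"
    return requested


# The caller (or a test) can capture the warning to verify the fallback fired.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    backend = select_backend("cublasmp", "bulk")

print(backend)      # userbuffers
print(len(caught))  # 1
```

Using `warnings.warn` with `stacklevel=2` points the reported location at the caller's line rather than inside the helper, which matches the suggestion in the review comment.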
Signed-off-by: Alp Dener <adener@nvidia.com>
Force-pushed from ebe7679 to 5a8c7ae
Signed-off-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
…l with distributed GEMM as reference compute Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
…passing all tests Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
…fer requirement

Fix several issues preventing delayed-scaling FP8 collective GEMM tests from passing with the cuBLASMp backend:
- Clean up stale NCCL unique ID files between test runs using a sync_global_devices barrier so crashed runs don't poison subsequent ones
- Use NumPy instead of JAX ops in process-0-only result checks to avoid multi-process XLA compilation deadlocks
- Expose nvte_built_with_cublasmp() to Python and add runtime skip logic in conftest.py and run_test_cgemm.sh
- Add cuBLASMp RS output path in gemm.cpp (cuBLASMp writes the reduce-scattered result directly into D, unlike Userbuffers, which uses an intermediate ubuf)

Also document on gemm() and collective_gemm_bootstrap() that XLA command buffers must be disabled when using collective GEMM with communication overlap, since both Userbuffers and cuBLASMp use internal CUDA streams for NCCL collectives that break CUDA graph capture.

Signed-off-by: adener <adener@nvidia.com>
Signed-off-by: Alp Dener <adener@nvidia.com>
for more information, see https://pre-commit.ci
Description
This PR adds support for the NVTE cuBlasMp bindings in the Comm+GEMM overlap API.
Type of change
Checklist: