Skip to content

[Pyt][Common] Enabling/Guarding sm120 support (non - attention)#2833

Open
KshitijLakhani wants to merge 29 commits intoNVIDIA:mainfrom
KshitijLakhani:klakhani/maint/sm120-pyt-cpp-support
Open

[Pyt][Common] Enabling/Guarding sm120 support (non - attention)#2833
KshitijLakhani wants to merge 29 commits intoNVIDIA:mainfrom
KshitijLakhani:klakhani/maint/sm120-pyt-cpp-support

Conversation

@KshitijLakhani
Copy link
Copy Markdown
Collaborator

@KshitijLakhani KshitijLakhani commented Apr 3, 2026

Description

This PR is a follow up to : #2693.

PR #2693 aimed to enable/guard PyT attention for sm120
This PR aims to enable/guard non-attention for sm120 (and a small attn related regression fix)

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Runtime/backend guards for SM120 correctness

  • Disabled gated TMA backward kernels to avoid kernel launch failures tied to shared-memory constraints.
  • Forced unfused NVFP4 RHT path to avoid fused RHT shared-memory/resource overreach.
  • Disabled NVFP4 stochastic rounding on backend paths in csrc/quantizer.cpp due to unsupported .rs PTX.
  • Added grouped NVFP4 fallback in cast extension (csrc/extensions/cast.cpp) to use safer per-split processing.
  • Added grouped GEMM runtime guard in gemm/cublaslt_grouped_gemm.cu because cuBLASLt grouped GEMM heuristic returns unsupported (for affected BF16/FP8 cases).

General Bug fix (not SM120 specific)
I stumbled upon this bug specifically when I was testing on SM120, but it is an arch agnostic fix.

  • Fixed MXFP8 CAST_DBIAS shared-memory handoff race
    • Ensured async shared->global source consumption is complete and all warps reach safe reuse point

NVFP4 grouped quantization layout consistency for SM120

  • Aligned grouped NVFP4 metadata with actual SM120 fallback output layout in csrc/quantizer.cpp:
    • default: metadata follows optimize_for_gemm,
    • SM120 grouped fallback (first_dims present): force unswizzled metadata.
  • Propagated grouped layout metadata into split tensor views in grouped_tensor_storage.py so split tensors inherit true grouped layout state.
  • Updated grouped NVFP4 tests in test_nvfp4_group_quantize_graph_safe.py to compare against metadata-selected reference layout and use scoped SM120 tolerance behavior.

** Test changes (SM120 specific)**

  • NVFP4 SR tests: in test_nvfp4_sr_quantize.py, changed SM120 expectation from SR < RN to numerical equivalence (assert_close) because SR is disabled on SM120 backend.
  • FP8 CS numerics: in run_layer_with_overlap.py, added SM120-only looser tolerance for fp8_current_scaling (rtol=0.4, atol=0.25) in deterministic fallback backend scenarios (I borrowed these tolerances from the corresponding distributed test file run_numerics.py)
  • BF16 multi-layer overlap numerics: added narrowly-scoped SM120 tolerance relaxation (rtol=0.05, atol=0.01) for TransformerLayer, multi-layer, overlap_rs_dgrad when deterministic mode routes away from fused attention.
  • THD-vs-dense tolerance and grouped GEMM skips in test_numerics.py, C++ grouped GEMM operator tests, and PyTorch grouped GEMM numerics to match explicit SM120 unsupported/runtime-guarded paths.
  • Skipped SM120 NVFP4 paged-stashing grouped-quantize case due to observed IMA in current kernel assumptions for paged layouts.

SM120 coverage/test harness updates

  • Made custom-recipe grouped-linear shapes 16-aligned on SM120 because the SM120 FP8 GEMM path enforces leading-dimension alignment (lda % 16 == 0) in backward.
  • Narrow distributed debug/tolerance helper updates in run_distributed.py and related tests for observed SM120 outlier behavior.
  • Relaxed one NVFP4 bias-grad check (single element outlier in ffn1.bias.grad exceeded prior absolute tolerances) for SM120 only

Fused attention SM120 regression fix
Reinstated lost SM120 conditionals in fused_attn_f16_arbitrary_seqlen.cu (This was lost when 2677 was merged and conflict resolution.):

  • restored SM120-specific behavior for stats stride selection (use_ragged_stats path),
  • restored SM120-aware output_S shape handling for THD + cuDNN >= 9.6.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Apr 3, 2026
@KshitijLakhani KshitijLakhani changed the title [Pyt][Common Enabling/Guarding sm120 support (non - attention) [Pyt][Common] Enabling/Guarding sm120 support (non - attention) Apr 3, 2026
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch 2 times, most recently from 59ab765 to 5cbb074 Compare April 10, 2026 07:40
…p8::cast_gated_bwd kernel as sm120 shmem requirements are insufficient

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…rted

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…s Flash and not Fused

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…e different

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…MM lda constraints

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…debug test activation comparisons

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from dd7b903 to 3b84fda Compare April 15, 2026 22:44
- Route grouped NVFP4 with first_dims through SM120 fallback split quantize path.
- Ensure grouped tensor swizzle metadata reflects actual runtime layout
- Propagate grouped layout metadata to split tensor views instead of re-deriving from quantizer flags.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Select expected scale reference layout from backend-reported _with_gemm_swizzled_scales.
- Assert grouped/split metadata consistency before validating scales.
- Apply SM120-only tolerance relaxation for scale comparisons and skip unsupported SM120 paged-stashing cas

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from c9fa7e1 to a0afae8 Compare April 21, 2026 23:25
- SM120 backend currently disables NVFP4 stochastic rounding, so SR no longer outperforms RN.
- Update SR assertions to use close-equality on SM120 and keep strict SR<RN checks for sm!=120.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…shape that was lost in an earlier PR's merge conflict

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from 979178e to cb9e0d3 Compare April 22, 2026 06:52
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from b01b227 to ccf0da4 Compare April 22, 2026 07:19
@KshitijLakhani KshitijLakhani marked this pull request as ready for review April 22, 2026 22:32
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 22, 2026

Greptile Summary

This PR enables and guards SM120 (Blackwell GB10x) support across non-attention paths in TransformerEngine: disabling gated TMA backward kernels and fused NVFP4 RHT paths that exceed SM120 shared-memory limits, adding a per-split grouped NVFP4 quantization fallback in cast.cpp, blocking stochastic rounding that depends on unsupported PTX, preventing cuBLASLt grouped GEMM on SM120, propagating the resulting unswizzled scale layout through grouped_tensor_storage.py, and reinstating lost SM120 conditionals in fused_attn_f16_arbitrary_seqlen.cu. A genuine cross-architecture bug fix is also included: an async shared→global race in the MXFP8 CAST_DBIAS kernel is resolved with cp_async_bulk_wait_group_read<0> + __syncthreads barriers.

Confidence Score: 5/5

Safe to merge; all findings are P2 style/naming issues with no functional impact on correctness.

No P0 or P1 defects found. The atol/rtol naming swap in the test tolerance helper is harmless today (symmetric values). The typo in the warning string and the broad tolerance for bias-grad on SM120 are both P2 quality concerns. Core correctness guards and the MXFP8 race-condition fix look correct.

tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py (atol/rtol swap), tests/pytorch/test_fusible_ops.py (very loose atol=0.55)

Important Files Changed

Filename Overview
transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh Fixes MXFP8 CAST_DBIAS shared-memory race by adding cp_async_bulk_wait_group_read and __syncthreads barriers before parity flip
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu Restores SM120-specific conditionals: use_ragged_stats excludes SM120, output_S shape guard adds sm_arch_ != 120 check for THD/cuDNN>=9.6 path
transformer_engine/pytorch/csrc/extensions/cast.cpp Adds SM120 grouped NVFP4 fallback: copies quantizer without swizzled scales, splits input per first_dims, runs per-split quantization; disables stochastic rounding and fused RHT on SM120
transformer_engine/pytorch/csrc/quantizer.cpp Aligns grouped NVFP4 metadata: forces unswizzled layout on SM120 fallback path; disables SR and fused RHT cast on SM120 in quantize_impl
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Propagates actual grouped layout (self._with_gemm_swizzled_scales) into split tensor views instead of reading from quantizer.optimize_for_gemm
tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py Adds SM120 layout-consistency assertions and per-layout scale tolerances; tolerance constant name implies (rtol, atol) order but is unpacked as (atol, rtol)
tests/pytorch/test_fusible_ops.py Applies SM120-specific atol=0.55 for ffn1.bias.grad on NVFP4; tolerance is very loose and could mask future regressions on SM120

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[group_quantize call] --> B{is_sm120_device &&\nfirst_dims present?}
    B -->|Yes - SM120 fallback| C[Copy quantizer\noptimize_for_gemm=false]
    C --> D[get_split_sections D2H]
    D --> E[Build per-split input_list]
    E --> F[get_grouped_outputs\nsplit_into_quantized_tensors]
    F --> G[split_quantize_nvfp4_impl\nper-split, no SR, no fused RHT]
    G --> H[Output: unswizzled scale layout\nwith_gemm_swizzled_scales=false]
    B -->|No - normal path| I[group_quantize_nvfp4_impl\nfused kernel]
    I --> J{optimize_for_gemm?}
    J -->|Yes| K[Output: swizzled scale layout]
    J -->|No| L[Output: unswizzled scale layout]
    H --> M[grouped_tensor_storage.py\npropagate _with_gemm_swizzled_scales\nto split tensors]
    K --> M
    L --> M
Loading

Reviews (7): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment thread transformer_engine/common/common.cu Outdated
Comment on lines +290 to +295
// KL: test function for CC 120
bool is_supported_by_CC_120() {
int deviceComputeCapability = cuda::sm_arch(cuda::current_device());

return deviceComputeCapability == 120;
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Debug/WIP comment and misleading function name

The // KL: test function for CC 120 comment should be removed before merging — it reads as a personal debug note rather than production documentation.

More importantly, the name is_supported_by_CC_120() is semantically inconsistent with is_supported_by_CC_100(). is_supported_by_CC_100 returns >= 100 (meaning "supported by CC 100 or newer"), so by analogy is_supported_by_CC_120 would imply >= 120. However the implementation returns == 120 (exclusively SM120). Every call site uses this to disable a feature on SM120, not to enable something on SM120+. A name like is_exactly_CC_120() or is_CC_120_arch() would prevent future readers from misinterpreting the range semantics.

Comment on lines +49 to +52
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
// sm120 shared memory capapbilities are much smaller than sm100, so we disable TMA kernels on sm120
// KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
// are there any forward only tests we'd like to keep enabled on sm120?
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Leftover commented-out code and unresolved TODO

The original line //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); is commented out rather than deleted, and the accompanying comment leaves an open investigation note ("KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated …"). Production code shouldn't carry stale commented-out expressions or unresolved TODO author notes. The same pattern appears in the backward helper at line ~535. Please remove the commented-out line and convert the open question to a tracked issue.

KshitijLakhani and others added 5 commits April 23, 2026 06:34
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from 440ba8b to 4aed9e9 Compare April 23, 2026 18:45
Comment on lines +26 to +30
def setup_class(cls) -> None:
"""Set up test fixtures"""
# Configure RNG
seed = 42
torch.manual_seed(seed)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 @staticmethod on setup_class will break pytest collection

@staticmethod strips the implicit first-argument binding, so when pytest calls TestGroupedQuantizeFP8CurrentScaling.setup_class() it passes no arguments — but the function signature expects cls. This raises TypeError: setup_class() missing 1 required positional argument: 'cls', causing the entire test class to fail collection. The correct decoration is @classmethod.

Suggested change
def setup_class(cls) -> None:
"""Set up test fixtures"""
# Configure RNG
seed = 42
torch.manual_seed(seed)
@classmethod
def setup_class(cls) -> None:

Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Comment on lines +72 to +93
__device__ __forceinline__ __nv_fp8_e4m3 cast_to_fp8_e4m3_saturate(float val) {
// E4M3 range: [-448, 448]
constexpr float kFP8E4M3Max = 448.0f;

#if __CUDA_ARCH__ >= 890 // Hopper and newer have native FP8
// Use native FP8 conversion with saturation
__nv_fp8_e4m3 result;
asm("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;"
: "=r"(*reinterpret_cast<uint16_t*>(&result))
: "f"(val), "f"(0.0f));
return result;
#else
// Software path with explicit saturation
val = fmaxf(-kFP8E4M3Max, fminf(val, kFP8E4M3Max));
return __nv_fp8_e4m3(val);
#endif
}

/**
* @brief Fast saturate and cast to FP8 E5M2 using hardware intrinsics
*
* @param val Input float value (already scaled)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 PTX x2 instruction writes 16 bits to a 1-byte variable — stack UB

cvt.rn.satfinite.e4m3x2.f32 converts two float32 values into a packed pair of FP8 E4M3 values and writes the result as a 16-bit quantity. The target, however, is *reinterpret_cast<uint16_t*>(&result) where result is a single __nv_fp8_e4m3 (1 byte). This is a 2-byte write past a 1-byte allocation — undefined behaviour that silently corrupts an adjacent stack byte.

The same issue exists in cast_to_fp8_e5m2_saturate below.

Both functions appear unused today (kernels use static_cast<OutputType> via pack_4xfp8), but the UB should be fixed before these helpers are ever called. The minimal fix is to use a 2-element array as backing storage:

// fix: use a 2-element array so the 16-bit write is in-bounds
__nv_fp8_e4m3 cast_to_fp8_e4m3_saturate(float val) {
#if __CUDA_ARCH__ >= 890
  __nv_fp8_e4m3 result[2];
  asm("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;"
      : "=r"(*reinterpret_cast<uint16_t *>(result))
      : "f"(val), "f"(0.0f));
  return result[0];
#else
  ...

Comment on lines +107 to +178
void register_grouped_fp8_quantization_bindings(py::module &m) {
m.def("group_fp8_quantize_rowwise", &group_fp8_quantize_rowwise, py::arg("input"),
py::arg("output"),
R"pbdoc(
Perform grouped FP8 quantization with rowwise layout.

Quantizes multiple tensors from high precision to FP8 using pre-computed
scales. Processes all tensors in a single kernel launch for efficiency.

Args:
input: Input GroupedTensor (high precision: FP32/BF16/FP16)
output: Output GroupedTensor (FP8, must have scales pre-computed)

Returns:
Output GroupedTensor with quantized data

Example:
>>> # After computing scales
>>> output = tex.group_fp8_quantize_rowwise(input_grouped, output_grouped)

Note:
This is part of the three-step FP8 current scaling workflow:
1. Compute amax (tex.group_amax_graph_safe)
2. Compute scales (tex.multi_tensor_compute_scale_and_scale_inv)
3. Quantize (this function)
)pbdoc");

m.def("group_fp8_quantize_columnwise", &group_fp8_quantize_columnwise, py::arg("input"),
py::arg("output"),
R"pbdoc(
Perform grouped FP8 quantization with columnwise (transposed) layout.

Quantizes and transposes multiple tensors simultaneously. Output is in
columnwise format suitable for TN/NT GEMM layouts.

Args:
input: Input GroupedTensor (high precision, rowwise)
output: Output GroupedTensor (FP8, columnwise)

Returns:
Output GroupedTensor with quantized and transposed data

Example:
>>> # Quantize and transpose for columnwise GEMM
>>> output = tex.group_fp8_quantize_columnwise(input_grouped, output_grouped)

Note:
All tensors must be 2D for transpose operation.
)pbdoc");

m.def("group_fp8_quantize_both", &group_fp8_quantize_both, py::arg("input"), py::arg("output"),
R"pbdoc(
Perform grouped FP8 quantization producing both rowwise and columnwise outputs.

Quantizes multiple tensors and produces both layouts simultaneously.
Useful when both layouts are needed (e.g., forward and backward passes).

Args:
input: Input GroupedTensor (high precision)
output: Output GroupedTensor (FP8, must have both buffers allocated)

Returns:
Output GroupedTensor with both rowwise and columnwise data

Example:
>>> # Quantize to both layouts
>>> output = tex.group_fp8_quantize_both(input_grouped, output_grouped)
)pbdoc");
}

} // namespace pytorch
} // namespace transformer_engine
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 register_grouped_fp8_quantization_bindings is never called — bindings are silently absent

register_grouped_fp8_quantization_bindings is defined here and declared in pybind_grouped_fp8.h, but a search of all module-initialization files shows it is never invoked. As a result, tex.group_fp8_quantize_rowwise, tex.group_fp8_quantize_columnwise, and tex.group_fp8_quantize_both do not exist at runtime.

The tests that use these symbols are correctly marked @pytest.mark.xfail(reason="Grouped kernels not yet implemented"), so nothing breaks today. But a caller of grouped_quantize.py's grouped_quantize_current_scaling outside the test harness would receive an AttributeError with no clear explanation.

Please either:

  1. Add the call to the module init (once kernels are ready), or
  2. Add an explicit NotImplementedError guard in _grouped_compute_amax / _grouped_fp8_quantize_rowwise so callers get a clear error instead of an AttributeError.

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-pyt-cpp-support branch from 0b00fef to a95ba1c Compare April 24, 2026 18:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant