[Pyt][Common] Enabling/Guarding sm120 support (non-attention) #2833
KshitijLakhani wants to merge 29 commits into NVIDIA:main
Conversation
Force-pushed 59ab765 to 5cbb074
…p8::cast_gated_bwd kernel as sm120 shmem requirements are insufficient
…rted
…s Flash and not Fused
…e different
…MM lda constraints
…debug test activation comparisons
Force-pushed dd7b903 to 3b84fda
- Route grouped NVFP4 with first_dims through SM120 fallback split quantize path.
- Ensure grouped tensor swizzle metadata reflects actual runtime layout.
- Propagate grouped layout metadata to split tensor views instead of re-deriving from quantizer flags.
- Select expected scale reference layout from backend-reported _with_gemm_swizzled_scales.
- Assert grouped/split metadata consistency before validating scales.
- Apply SM120-only tolerance relaxation for scale comparisons and skip unsupported SM120 paged-stashing cas…
Force-pushed c9fa7e1 to a0afae8
- SM120 backend currently disables NVFP4 stochastic rounding, so SR no longer outperforms RN.
- Update SR assertions to use close-equality on SM120 and keep strict SR < RN checks for sm != 120.
…shape that was lost in an earlier PR's merge conflict
Force-pushed 979178e to cb9e0d3
Force-pushed b01b227 to ccf0da4
for more information, see https://pre-commit.ci
…tn backend
Greptile Summary

This PR enables and guards SM120 (Blackwell GB10x) support across non-attention paths in TransformerEngine: disabling gated TMA backward kernels and fused NVFP4 RHT paths that exceed SM120 shared-memory limits, and adding a per-split grouped NVFP4 quantization fallback in …

Confidence Score: 5/5

Safe to merge; all findings are P2 style/naming issues with no functional impact on correctness. No P0 or P1 defects found. The atol/rtol naming swap in the test tolerance helper is harmless today (symmetric values). The typo in the warning string and the broad tolerance for bias-grad on SM120 are both P2 quality concerns. Core correctness guards and the MXFP8 race-condition fix look correct.

Important Files Changed: tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py (atol/rtol swap), tests/pytorch/test_fusible_ops.py (very loose atol=0.55)
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[group_quantize call] --> B{is_sm120_device &&\nfirst_dims present?}
B -->|Yes - SM120 fallback| C[Copy quantizer\noptimize_for_gemm=false]
C --> D[get_split_sections D2H]
D --> E[Build per-split input_list]
E --> F[get_grouped_outputs\nsplit_into_quantized_tensors]
F --> G[split_quantize_nvfp4_impl\nper-split, no SR, no fused RHT]
G --> H[Output: unswizzled scale layout\nwith_gemm_swizzled_scales=false]
B -->|No - normal path| I[group_quantize_nvfp4_impl\nfused kernel]
I --> J{optimize_for_gemm?}
J -->|Yes| K[Output: swizzled scale layout]
J -->|No| L[Output: unswizzled scale layout]
H --> M[grouped_tensor_storage.py\npropagate _with_gemm_swizzled_scales\nto split tensors]
K --> M
L --> M
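The routing in the flowchart above can be sketched in Python. This is a minimal sketch with hypothetical names mirroring the flowchart nodes, not the actual TransformerEngine API:

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in for the quantizer state shown in the flowchart.
@dataclass(frozen=True)
class Quantizer:
    optimize_for_gemm: bool = True

def group_quantize(quantizer, is_sm120_device, first_dims):
    """Route grouped NVFP4 quantization per the flowchart: SM120 with
    first_dims present takes the per-split fallback with unswizzled scales."""
    if is_sm120_device and first_dims is not None:
        # SM120 fallback: copy the quantizer with optimize_for_gemm=False,
        # then quantize per split (no SR, no fused RHT).
        q = replace(quantizer, optimize_for_gemm=False)
        return {"path": "split_quantize_nvfp4_impl",
                "with_gemm_swizzled_scales": q.optimize_for_gemm}
    # Normal path: fused grouped kernel; scale layout follows the flag.
    return {"path": "group_quantize_nvfp4_impl",
            "with_gemm_swizzled_scales": quantizer.optimize_for_gemm}
```

Note that on the fallback path the swizzle flag comes out False regardless of what the caller requested, which is exactly why the metadata propagation fix in grouped_tensor_storage.py matters.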
Reviews (7): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
// KL: test function for CC 120
bool is_supported_by_CC_120() {
  int deviceComputeCapability = cuda::sm_arch(cuda::current_device());

  return deviceComputeCapability == 120;
}
Debug/WIP comment and misleading function name
The // KL: test function for CC 120 comment should be removed before merging — it reads as a personal debug note rather than production documentation.
More importantly, the name is_supported_by_CC_120() is semantically inconsistent with is_supported_by_CC_100(). is_supported_by_CC_100 returns >= 100 (meaning "supported by CC 100 or newer"), so by analogy is_supported_by_CC_120 would imply >= 120. However the implementation returns == 120 (exclusively SM120). Every call site uses this to disable a feature on SM120, not to enable something on SM120+. A name like is_exactly_CC_120() or is_CC_120_arch() would prevent future readers from misinterpreting the range semantics.
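The range-semantics mismatch the review describes can be shown in a few lines of Python (helper names taken from the diff and the reviewer's suggested rename; the real code is C++):

```python
def is_supported_by_CC_100(cc: int) -> bool:
    # Existing helper: "supported by CC 100 or newer" -- inclusive lower bound.
    return cc >= 100

def is_exactly_CC_120(cc: int) -> bool:
    # What is_supported_by_CC_120() actually implements today: an exact
    # match, used at call sites to *disable* features on SM120 only.
    return cc == 120

def expected_by_analogy(cc: int) -> bool:
    # What a reader would infer "is_supported_by_CC_120" means, by analogy
    # with is_supported_by_CC_100 -- which diverges on newer architectures.
    return cc >= 120
```

The divergence only shows up on hypothetical future architectures (cc > 120), which is precisely when a misread guard would silently disable or enable the wrong paths.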
//const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
// sm120 shared memory capabilities are much smaller than sm100, so we disable TMA kernels on sm120
// KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated -
// are there any forward only tests we'd like to keep enabled on sm120?
There was a problem hiding this comment.
Leftover commented-out code and unresolved TODO
The original line //const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); is commented out rather than deleted, and the accompanying comment leaves an open investigation note ("KL: It is possible that for fwd, the limits are not exceeded for sm120. To be investigated …"). Production code shouldn't carry stale commented-out expressions or unresolved TODO author notes. The same pattern appears in the backward helper at line ~535. Please remove the commented-out line and convert the open question to a tracked issue.
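Once the stale line is deleted, the intended guard reduces to a small predicate. A Python sketch of that logic (the real code is C++; the helper name and arguments here are hypothetical):

```python
def use_tma_kernels(cols: int, cc: int) -> bool:
    """Enable TMA kernels only when the column count is 32-aligned, the
    arch is CC 100 or newer, and we are not on SM120, whose shared-memory
    capacity is insufficient for these kernels (per the review discussion)."""
    is_cc_100_plus = cc >= 100
    is_sm120 = cc == 120  # SM120-only exclusion
    return (cols % 32 == 0) and is_cc_100_plus and not is_sm120
```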
Force-pushed 440ba8b to 4aed9e9
def setup_class(cls) -> None:
    """Set up test fixtures"""
    # Configure RNG
    seed = 42
    torch.manual_seed(seed)
@staticmethod on setup_class will break pytest collection
@staticmethod strips the implicit first-argument binding, so when pytest calls TestGroupedQuantizeFP8CurrentScaling.setup_class() it passes no arguments — but the function signature expects cls. This raises TypeError: setup_class() missing 1 required positional argument: 'cls', causing the entire test class to fail collection. The correct decoration is @classmethod.
Suggested change:

@classmethod
def setup_class(cls) -> None:
    """Set up test fixtures"""
    # Configure RNG
    seed = 42
    torch.manual_seed(seed)
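A minimal reproduction of why the decoration matters, in plain Python and independent of pytest (class names here are illustrative only):

```python
class WithStatic:
    @staticmethod
    def setup_class(cls) -> None:
        # @staticmethod strips the implicit binding, so a no-argument call
        # leaves `cls` unfilled and raises TypeError.
        cls.seed = 42

class WithClassmethod:
    @classmethod
    def setup_class(cls) -> None:
        # @classmethod binds `cls` implicitly, so the no-argument call works.
        cls.seed = 42

# Simulate pytest's invocation: TestClass.setup_class() with no arguments.
WithClassmethod.setup_class()

try:
    WithStatic.setup_class()
    static_raised = False
except TypeError:
    static_raised = True
```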
__device__ __forceinline__ __nv_fp8_e4m3 cast_to_fp8_e4m3_saturate(float val) {
  // E4M3 range: [-448, 448]
  constexpr float kFP8E4M3Max = 448.0f;

#if __CUDA_ARCH__ >= 890  // Hopper and newer have native FP8
  // Use native FP8 conversion with saturation
  __nv_fp8_e4m3 result;
  asm("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;"
      : "=r"(*reinterpret_cast<uint16_t*>(&result))
      : "f"(val), "f"(0.0f));
  return result;
#else
  // Software path with explicit saturation
  val = fmaxf(-kFP8E4M3Max, fminf(val, kFP8E4M3Max));
  return __nv_fp8_e4m3(val);
#endif
}

/**
 * @brief Fast saturate and cast to FP8 E5M2 using hardware intrinsics
 *
 * @param val Input float value (already scaled)
PTX x2 instruction writes 16 bits to a 1-byte variable — stack UB
cvt.rn.satfinite.e4m3x2.f32 converts two float32 values into a packed pair of FP8 E4M3 values and writes the result as a 16-bit quantity. The target, however, is *reinterpret_cast<uint16_t*>(&result) where result is a single __nv_fp8_e4m3 (1 byte). This is a 2-byte write past a 1-byte allocation — undefined behaviour that silently corrupts an adjacent stack byte.
The same issue exists in cast_to_fp8_e5m2_saturate below.
Both functions appear unused today (kernels use static_cast<OutputType> via pack_4xfp8), but the UB should be fixed before these helpers are ever called. The minimal fix is to use a 2-element array as backing storage:
// fix: use a 2-element array so the 16-bit write is in-bounds
__nv_fp8_e4m3 cast_to_fp8_e4m3_saturate(float val) {
#if __CUDA_ARCH__ >= 890
__nv_fp8_e4m3 result[2];
asm("cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;"
: "=r"(*reinterpret_cast<uint16_t *>(result))
: "f"(val), "f"(0.0f));
return result[0];
#else
  ...

void register_grouped_fp8_quantization_bindings(py::module &m) {
  m.def("group_fp8_quantize_rowwise", &group_fp8_quantize_rowwise, py::arg("input"),
        py::arg("output"),
        R"pbdoc(
        Perform grouped FP8 quantization with rowwise layout.

        Quantizes multiple tensors from high precision to FP8 using pre-computed
        scales. Processes all tensors in a single kernel launch for efficiency.

        Args:
            input: Input GroupedTensor (high precision: FP32/BF16/FP16)
            output: Output GroupedTensor (FP8, must have scales pre-computed)

        Returns:
            Output GroupedTensor with quantized data

        Example:
            >>> # After computing scales
            >>> output = tex.group_fp8_quantize_rowwise(input_grouped, output_grouped)

        Note:
            This is part of the three-step FP8 current scaling workflow:
            1. Compute amax (tex.group_amax_graph_safe)
            2. Compute scales (tex.multi_tensor_compute_scale_and_scale_inv)
            3. Quantize (this function)
        )pbdoc");

  m.def("group_fp8_quantize_columnwise", &group_fp8_quantize_columnwise, py::arg("input"),
        py::arg("output"),
        R"pbdoc(
        Perform grouped FP8 quantization with columnwise (transposed) layout.

        Quantizes and transposes multiple tensors simultaneously. Output is in
        columnwise format suitable for TN/NT GEMM layouts.

        Args:
            input: Input GroupedTensor (high precision, rowwise)
            output: Output GroupedTensor (FP8, columnwise)

        Returns:
            Output GroupedTensor with quantized and transposed data

        Example:
            >>> # Quantize and transpose for columnwise GEMM
            >>> output = tex.group_fp8_quantize_columnwise(input_grouped, output_grouped)

        Note:
            All tensors must be 2D for transpose operation.
        )pbdoc");

  m.def("group_fp8_quantize_both", &group_fp8_quantize_both, py::arg("input"), py::arg("output"),
        R"pbdoc(
        Perform grouped FP8 quantization producing both rowwise and columnwise outputs.

        Quantizes multiple tensors and produces both layouts simultaneously.
        Useful when both layouts are needed (e.g., forward and backward passes).

        Args:
            input: Input GroupedTensor (high precision)
            output: Output GroupedTensor (FP8, must have both buffers allocated)

        Returns:
            Output GroupedTensor with both rowwise and columnwise data

        Example:
            >>> # Quantize to both layouts
            >>> output = tex.group_fp8_quantize_both(input_grouped, output_grouped)
        )pbdoc");
}

}  // namespace pytorch
}  // namespace transformer_engine
register_grouped_fp8_quantization_bindings is never called — bindings are silently absent
register_grouped_fp8_quantization_bindings is defined here and declared in pybind_grouped_fp8.h, but a search of all module-initialization files shows it is never invoked. As a result, tex.group_fp8_quantize_rowwise, tex.group_fp8_quantize_columnwise, and tex.group_fp8_quantize_both do not exist at runtime.
The tests that use these symbols are correctly marked @pytest.mark.xfail(reason="Grouped kernels not yet implemented"), so nothing breaks today. But a caller of grouped_quantize.py's grouped_quantize_current_scaling outside the test harness would receive an AttributeError with no clear explanation.
Please either:
- Add the call to the module init (once kernels are ready), or
- Add an explicit NotImplementedError guard in _grouped_compute_amax / _grouped_fp8_quantize_rowwise so callers get a clear error instead of an AttributeError.
Force-pushed 0b00fef to a95ba1c
Description
This PR is a follow-up to #2693.
PR #2693 aimed to enable/guard PyT attention for sm120.
This PR aims to enable/guard non-attention paths for sm120 (plus a small attention-related regression fix).
Type of change
Changes
Runtime/backend guards for SM120 correctness
- csrc/quantizer.cpp: … due to unsupported .rs PTX.
- csrc/extensions/cast.cpp: … to use safer per-split processing.
- gemm/cublaslt_grouped_gemm.cu: … because the cuBLASLt grouped GEMM heuristic returns unsupported (for affected BF16/FP8 cases).
I stumbled upon this bug specifically when I was testing on SM120, but it is an arch agnostic fix.
NVFP4 grouped quantization layout consistency for SM120
- csrc/quantizer.cpp: …
- grouped_tensor_storage.py: … so split tensors inherit true grouped layout state.
- test_nvfp4_group_quantize_graph_safe.py: … to compare against metadata-selected reference layout and use scoped SM120 tolerance behavior.

Test changes (SM120 specific)
- test_nvfp4_sr_quantize.py: changed SM120 expectation from SR < RN to numerical equivalence (assert_close) because SR is disabled on the SM120 backend.
- run_layer_with_overlap.py: added SM120-only looser tolerance for fp8_current_scaling (rtol=0.4, atol=0.25) in deterministic fallback backend scenarios (tolerances borrowed from the corresponding distributed test file run_numerics.py).
- test_numerics.py, C++ grouped GEMM operator tests, and PyTorch grouped GEMM numerics: updated to match explicit SM120 unsupported/runtime-guarded paths.

SM120 coverage/test harness updates

- … (lda % 16 == 0) in backward.
- run_distributed.py and related tests: … for observed SM120 outlier behavior.
- … (ffn1.bias.grad exceeded prior absolute tolerances) for SM120 only.
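The arch-scoped SR assertion change described above can be sketched as a small helper (hypothetical name; the real assertions live in test_nvfp4_sr_quantize.py and use torch.testing.assert_close):

```python
def check_sr_vs_rn(sr_err: float, rn_err: float, cc: int, tol: float = 1e-6):
    """On SM120 the backend disables stochastic rounding, so SR and RN
    quantization errors should coincide; on other arches SR must strictly
    beat round-to-nearest."""
    if cc == 120:
        assert abs(sr_err - rn_err) <= tol, "SM120: expected SR ~= RN"
    else:
        assert sr_err < rn_err, "non-SM120: expected SR < RN"
```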
Reinstated lost SM120 conditionals in fused_attn_f16_arbitrary_seqlen.cu (these were lost during conflict resolution when 2677 was merged):

Checklist: