Optimizations for MXFP8/NVFP4 dequantize kernels #2865

Open

YigongQin wants to merge 10 commits into NVIDIA:main from YigongQin:yigongq/bwd-dequantize-optim

Conversation

@YigongQin YigongQin commented Apr 10, 2026

Description

  • Handle empty tensors in dequantize for CUDA graph compatibility
  • Add swizzled scale support to the NVFP4 dequantize kernel, reusing the existing MXFP8 swizzle index computation
  • Add C++ unit tests for both NVFP4 and MXFP8 dequantization (including swizzled scale variants)
  • Fix to_cpu() and set_scale() in test infrastructure to correctly sync amax/scale for NVTE_NVFP4_1D_SCALING mode

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Handle empty tensors in dequantize for CUDA graph compatibility — Early return when input has zero elements, avoiding kernel launches on empty tensors.
  • Add GEMM-swizzled scale support to NVFP4 dequantize kernel — Template the kernel with WITH_GEMM_SWIZZLED_SCALES to support reading scales from swizzled layout, reusing the MXFP8 swizzle index computation.
  • Add GEMM-swizzled scale support to MXFP8 dequantize kernel — Extend the MXFP8 dequantize kernel to handle swizzled scale inputs.
  • Add C++ unit tests for NVFP4 dequantization — 21 tests for compact scales + 21 tests for swizzled scales, covering multiple sizes and output dtypes (fp32, bf16, fp16).
  • Add C++ unit tests for MXFP8 dequantization with swizzled scales — New swizzled test suite for MXFP8.
  • Fix to_cpu() to sync amax/scale for NVFP4 tensors — Previously only synced for NVTE_DELAYED_TENSOR_SCALING, causing the CPU reference to use stale amax=0.
  • Fix set_scale() to work for NVFP4 tensors — Same condition fix, enabling the scale to be properly uploaded to GPU before quantization.
  • Fix swizzled test ordering — Move from_cpu() before the FP4 data copy to prevent from_cpu() from overwriting the copied data with zeros.
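The swizzle-index reuse described above can be illustrated with a small Python model. This is a sketch, not the PR's code: it assumes 128x4 scale-factor tiles with a 32x16 intra-tile interleave (the layout documented for block-scaled GEMMs), and the name swizzled_scale_idx is hypothetical — the actual C++ helper is gemm_swizzled_scale_idx and may differ in detail.

```python
# Hypothetical model of a GEMM-swizzled scale-factor layout.
# Assumes 128x4 scale tiles with a 32x16 intra-tile interleave;
# the real gemm_swizzled_scale_idx helper may differ in detail.

TILE_ROWS = 128  # rows of scale factors per tile
TILE_COLS = 4    # columns of scale factors per tile (TILE_DIM_X)

def divup(a, b):
    return (a + b - 1) // b

def swizzled_scale_idx(row, col, num_col_tiles):
    """Map a compact (row, col) scale coordinate to its swizzled offset."""
    tile = (row // TILE_ROWS) * num_col_tiles + (col // TILE_COLS)
    r, c = row % TILE_ROWS, col % TILE_COLS
    # Intra-tile interleave: 32-row groups, 16 entries per group.
    intra = (r % 32) * 16 + (r // 32) * 4 + c
    return tile * (TILE_ROWS * TILE_COLS) + intra

# Round-trip check on a padded 128x4 scale block.
rows, cols = 64, 2                      # e.g. a 64x64 MXFP8 tensor -> 64x2 scales
prows = divup(rows, TILE_ROWS) * TILE_ROWS
pcols = divup(cols, TILE_COLS) * TILE_COLS
compact = {(r, c): r * cols + c for r in range(rows) for c in range(cols)}
swizzled = [0] * (prows * pcols)
for (r, c), v in compact.items():
    swizzled[swizzled_scale_idx(r, c, pcols // TILE_COLS)] = v
ok = all(swizzled[swizzled_scale_idx(r, c, pcols // TILE_COLS)] == v
         for (r, c), v in compact.items())
print(ok)  # True -> the mapping is a bijection over the valid region
```

The round trip only exercises the valid 64x2 region; the padded remainder of the 128x4 tile is untouched, which is the region the zero-out discussion later in this thread is about.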

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from f5e7375 to 39c0fb1 Compare April 10, 2026 22:04

zianglih commented Apr 14, 2026

The following relevant unit tests passed on SM100 (with the changes that drop the optimize_for_gemm = False overrides):

python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

@zianglih zianglih force-pushed the yigongq/bwd-dequantize-optim branch from ddab15d to 3a4afdd Compare April 14, 2026 18:46
@zianglih (Contributor) commented:
After this PR, the forward pass is around 3%-4% faster for DeepSeek-shape MoE layers:

# With the optimization
NVTE_BACKWARD_OVERRIDE=dequantized python benchmarks/linear/benchmark_grouped_linear.py --recipe mxfp8 --fwd-only
       m     k     n recipe  num_gemms  grouped_fwd_time_ms
0  16384  7168  2048  mxfp8          4             0.272829
1  32768  7168  2048  mxfp8          4             0.509788
2  65536  7168  2048  mxfp8          4             0.948633
3  98304  7168  2048  mxfp8          4             1.391146
0  16384  7168  2048  mxfp8          8             0.303238
1  32768  7168  2048  mxfp8          8             0.533896
2  65536  7168  2048  mxfp8          8             1.003446
3  98304  7168  2048  mxfp8          8             1.470030

# Without the optimization
git restore --source 77b8681de5cf -- transformer_engine/pytorch/module
NVTE_BACKWARD_OVERRIDE=dequantized python benchmarks/linear/benchmark_grouped_linear.py --recipe mxfp8 --fwd-only
       m     k     n recipe  num_gemms  grouped_fwd_time_ms
0  16384  7168  2048  mxfp8          4             0.282720
1  32768  7168  2048  mxfp8          4             0.526736
2  65536  7168  2048  mxfp8          4             0.982166
3  98304  7168  2048  mxfp8          4             1.451485
0  16384  7168  2048  mxfp8          8             0.313753
1  32768  7168  2048  mxfp8          8             0.551043
2  65536  7168  2048  mxfp8          8             1.040773
3  98304  7168  2048  mxfp8          8             1.527951
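As a sanity check on the quoted 3%-4%, the speedup for the first row of the tables above (m=16384, num_gemms=4) works out as:

```python
# First-row timings from the tables above, in ms.
before, after = 0.282720, 0.272829   # without vs. with the optimization
speedup = (before - after) / before
print(f"{speedup:.1%}")  # prints 3.5%
```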

@YigongQin YigongQin marked this pull request as ready for review April 15, 2026 16:49

greptile-apps Bot commented Apr 15, 2026

Greptile Summary

This PR adds GEMM-swizzled scale support to both the NVFP4 and MXFP8 dequantize kernels, allowing the dequantize path to work with the same swizzled layout used by GEMM operations. An early return on empty tensors (via input.numel() == 0) in dequantize_helper enables CUDA graph compatibility and lets the Python modules drop the workaround that forced compact format for dequantized backward mode. The test infrastructure is also fixed to properly manage amax/scale GPU memory for NVTE_NVFP4_1D_SCALING tensors, and comprehensive C++ unit tests are added for both quantization formats.

Confidence Score: 5/5

Safe to merge; all remaining findings are informational or pre-existing.

No P0 or P1 bugs introduced by this PR. The swizzle index computation is mathematically correct for both MXFP8 (num_scale_tiles_X = DIVUP(cols,128)) and NVFP4 (num_scale_tiles_X = DIVUP(Mread,4)), consistently using TILE_DIM_X=4 from the shared gemm_swizzled_scale_idx helper. The empty-tensor early return, amax/scale GPU memory management, and test ordering fixes are all correct. The only flagged issue (duplicate MXFP8 branch in test_common.cu) is pre-existing dead code already noted in prior review rounds.

tests/cpp/test_common.cu — pre-existing duplicate NVTE_MXFP8_1D_SCALING branch (lines ~196–222) is still dead code from a previous PR; not introduced here but worth a cleanup pass.

Important Files Changed

  • transformer_engine/common/cast/dispatch/dequantize.cuh - Adds early return for zero-element inputs before the scaling-mode switch; clean and correct.
  • transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh - Templates kernel with WITH_GEMM_SWIZZLED_SCALES; num_scale_tiles_X = DIVUP(Mread, 4) correctly mirrors the MXFP8 tile decomposition (TILE_DIM_X=4).
  • transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh - Removes the hard assertion against swizzled scales and templates the kernel; num_scale_tiles_X = DIVUP(cols, 128) for rowwise / DIVUP(rows, 128) for colwise is consistent with the swizzle layout.
  • tests/cpp/test_common.cu - Fixes to_cpu/set_scale/set_amax for NVFP4 mode; a duplicate NVTE_MXFP8_1D_SCALING branch at line 196 (pre-existing dead code) remains unaddressed.
  • tests/cpp/test_common.h - Destructor now frees amax/scale GPU buffers; move semantics are handled correctly because the underlying TE tensor zeros moved-from pointers (same pattern as data/scale_inv pointers already freed here).
  • tests/cpp/operator/test_dequantize_nvfp4.cu - New test file covering compact and swizzled NVFP4 dequantization; empty-tensor (rows=0) cases correctly guarded, swizzle ordering (from_cpu before data copy) is correct.
  • tests/cpp/operator/test_dequantize_mxfp8.cu - Adds swizzled-scale test suite for MXFP8 and inserts zero-row dimension pairs into tensor_dims for empty tensor coverage.
  • transformer_engine/pytorch/module/grouped_linear.py - Removes the m_split==0 workaround and the optimize_for_gemm=False override for dequantized backward; both are now handled by the kernel-level empty-tensor guard and swizzled-scale support.
  • transformer_engine/pytorch/module/linear.py - Removes optimize_for_gemm=False override for dequantized-backward MXFP8/NVFP4; safe now that dequantize supports swizzled scales.
  • transformer_engine/pytorch/module/layernorm_linear.py - Same optimize_for_gemm cleanup as linear.py; correct.
  • transformer_engine/pytorch/ops/basic/basic_linear.py - Same optimize_for_gemm cleanup as linear.py; correct.
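The DIVUP tile counts cited above can be spot-checked with a quick sketch (divup mirrors the DIVUP macro; the MXFP8 arithmetic assumes one scale per 32 elements and 4 scale columns per tile, per the summary):

```python
def divup(a, b):
    """Ceiling division, as the DIVUP macro in the TE codebase."""
    return (a + b - 1) // b

# MXFP8 rowwise: one scale per 32 elements, 4 scale columns per tile,
# so tiles along X = DIVUP(cols/32, 4) == DIVUP(cols, 128).
cols = 7168  # the k dimension from the benchmark above
assert divup(divup(cols, 32), 4) == divup(cols, 128)
print(divup(cols, 128))  # prints 56
```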

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_dequantize"] --> B{"input.numel() == 0"}
    B -->|Yes| C["Early return\nCUDA graph safe"]
    B -->|No| D{"scaling_mode"}

    D -->|MXFP8| E{"with_gemm_swizzled_scales"}
    D -->|NVFP4| F{"with_gemm_swizzled_scales"}
    D -->|DELAYED| G["fp8::dequantize"]

    E -->|false| H["mxfp8_kernel false\nscale_idx = Y x stride + X"]
    E -->|true| I["mxfp8_kernel true\ngemm_swizzled_scale_idx\nnum_tiles = DIVUP cols 128"]

    F -->|false| J["fp4_kernel false\nscale_idx = x + y x stride"]
    F -->|true| K["fp4_kernel true\ngemm_swizzled_scale_idx\nnum_tiles = DIVUP Mread 4"]

    H --> L["High-precision output"]
    I --> L
    J --> L
    K --> L
    G --> L
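The dispatch in the flowchart can be paraphrased as a small Python model; function and attribute names (dequantize_helper, with_gemm_swizzled_scales) are taken from the summary text, and the return strings are purely illustrative:

```python
# Hedged Python model of the dispatch logic summarized in the flowchart.
# The real code lives in transformer_engine/common/cast/dispatch/dequantize.cuh;
# names and return values here are illustrative only.

def dequantize_helper(inp, out):
    if inp.numel() == 0:
        return "early-return"           # CUDA-graph-safe: no kernel launch
    if inp.scaling_mode == "MXFP8":
        return ("mxfp8_swizzled" if inp.with_gemm_swizzled_scales
                else "mxfp8_compact")
    if inp.scaling_mode == "NVFP4":
        return ("nvfp4_swizzled" if inp.with_gemm_swizzled_scales
                else "nvfp4_compact")
    return "fp8_dequantize"             # e.g. delayed tensor scaling

class T:
    """Minimal stand-in for a TE tensor."""
    def __init__(self, n, mode="MXFP8", swizzled=False):
        self._n, self.scaling_mode, self.with_gemm_swizzled_scales = n, mode, swizzled
    def numel(self):
        return self._n

print(dequantize_helper(T(0), None))                   # early-return
print(dequantize_helper(T(4096, "NVFP4", True), None)) # nvfp4_swizzled
```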

Reviews (8): Last reviewed commit: "Apply suggestions from code review"

Comment thread tests/cpp/operator/test_dequantize_nvfp4.cu Outdated
}
}

std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
@zhongbozhu (Collaborator) commented Apr 15, 2026:

There is one edge case:

For MXFP8, when the input shape is, say, 64x64, it produces a scaling-factor shape of 64x2, which is then zero-padded to 128x4. We should be able to inject very large random values into the padded region at allocation time (since we allocate with torch.empty, not torch.zeros) and check whether the dequantize results are affected. If things work as expected, this line will be triggered

// Zero out swizzled scales if padding is needed
and the dequantize numerics won't be affected.

For NVFP4, I think optimize-for-GEMM (i.e., the swizzle fusion) is actually not enabled; does the same apply to the zero-out edge-case handling logic?

NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format.");
So there shouldn't be any unswizzle logic needed here?

A Collaborator replied:

For NVFP4, I believe currently only device-init grouped quantize with RHT has the swizzle fusion feature, so the scaling factor zero-out is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.

@YigongQin (Author) replied:

For both MXFP8 and NVFP4, the unit test logic is: (1) generate compact scales (or obtain them from quantization); (2) call nvte_swizzle_scaling_factors to swizzle the compact scales; (3) compare the results of nvte_dequantize with compact scales and with swizzled scales. Quantize-with-swizzle fusion is never enabled for either MXFP8 or NVFP4.
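That three-step recipe can be sketched in miniature in pure Python. Everything below is a stand-in: the 128x4 tile, the intra-tile interleave, and swizzle_idx are assumptions for illustration, not nvte_swizzle_scaling_factors. The point is only that garbage in the padded scale region never influences the dequantized output, because dequantize reads only valid coordinates.

```python
import random

# Miniature stand-in for the test flow: (1) compact scales, (2) swizzle,
# (3) compare dequantize results read through either layout.
# Tile sizes and the interleave are assumptions for illustration.

BLOCK = 32            # elements sharing one scale (MXFP8-style 1D scaling)
TILE_R, TILE_C = 128, 4

def divup(a, b):
    return (a + b - 1) // b

def swizzle_idx(r, c, num_col_tiles):
    tile = (r // TILE_R) * num_col_tiles + (c // TILE_C)
    rr, cc = r % TILE_R, c % TILE_C
    return tile * TILE_R * TILE_C + (rr % 32) * 16 + (rr // 32) * 4 + cc

rows, cols = 64, 64
srows, scols = rows, cols // BLOCK                 # 64x2 compact scales
prows = divup(srows, TILE_R) * TILE_R              # padded to 128
pcols = divup(scols, TILE_C) * TILE_C              # padded to 4

data = [[random.random() for _ in range(cols)] for _ in range(rows)]
scales = [[2.0 ** random.randint(-2, 2) for _ in range(scols)] for _ in range(srows)]

# Swizzled buffer starts full of garbage (mimics torch.empty); the valid
# region is overwritten, while the padding keeps its garbage values.
swz = [random.uniform(1e30, 1e31) for _ in range(prows * pcols)]
for r in range(srows):
    for c in range(scols):
        swz[swizzle_idx(r, c, pcols // TILE_C)] = scales[r][c]

def dequant(scale_at):
    return [[data[r][c] * scale_at(r, c // BLOCK) for c in range(cols)]
            for r in range(rows)]

out_compact = dequant(lambda r, c: scales[r][c])
out_swizzled = dequant(lambda r, c: swz[swizzle_idx(r, c, pcols // TILE_C)])
print(out_compact == out_swizzled)  # True: padding garbage is never read
```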

Comment on lines -1713 to -1719
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_override == "dequantized" and (
fp8_recipe.mxfp8() or fp8_recipe.nvfp4()
):
input_quantizer.optimize_for_gemm = False
if grad_output_quantizer is not None:
grad_output_quantizer.optimize_for_gemm = False
@timmoon10 (Collaborator) commented Apr 23, 2026:

I'm of two minds about this:

  • Logically, GEMM-optimized data is not guaranteed to support anything except GEMMs. Even if MXFP8 and NVFP4 dequant happen to support it, these are custom optimizations. Future recipes cannot be expected to support dequantizing GEMM-optimized data by default.
  • It's a little pedantic to have edge-case logic that won't be triggered by any of our existing use-cases. Given how subtle this is, I worry about it becoming stale and distracting.

I think for now, this change is fine. However, if we encounter problems in a future recipe, we should reimplement it properly:

# LOGICALLY WRONG!
# Fails if we add a new recipe
if recipe.backward_override == "dequantized" and recipe.future_recipe():
    input_quantizer.optimize_for_gemm = False

# LOGICALLY RIGHT!
# Automatically handles new recipes
if recipe.backward_override == "dequantized" and not (
    recipe.float8_per_tensor_scaling()
    or recipe.float8_block_scaling()
    or recipe.mxfp8()
    or recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False

CC @ptrendx @ksivaman @zhongbozhu

Comment thread tests/cpp/test_common.h Outdated
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from 0eda58a to 1bf24be Compare April 23, 2026 21:47
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from 666c496 to 80484a9 Compare April 23, 2026 21:53
@zhongbozhu (Collaborator) commented:

/te-ci pytorch L1

Comment thread tests/cpp/test_common.h
Comment thread tests/cpp/test_common.cu Outdated
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
@timmoon10 (Collaborator) commented:

/te-ci core

@timmoon10 (Collaborator) left a review:

LGTM, pending CI. These kernels will be very useful.

5 participants