Optimizations for MXFP8/NVFP4 dequantize kernels #2865

YigongQin wants to merge 10 commits into NVIDIA:main
Conversation
The following relevant unit tests passed on SM100 (with the drop …)
After this PR, fwd is around 3-4% faster for the DeepSeek-shape MoE.
**Greptile Summary**

This PR adds GEMM-swizzled scale support to both the NVFP4 and MXFP8 dequantize kernels, allowing the dequantize path to work with the same swizzled layout used by GEMM operations. An early return on empty tensors (via the `input.numel() == 0` check) keeps the path CUDA-graph safe.

**Confidence Score: 5/5.** Safe to merge; all remaining findings are informational or pre-existing. No P0 or P1 bugs are introduced by this PR.

- The swizzle index computation is mathematically correct for both MXFP8 (`num_scale_tiles_X = DIVUP(cols, 128)`) and NVFP4 (`num_scale_tiles_X = DIVUP(Mread, 4)`), consistently using `TILE_DIM_X = 4` from the shared `gemm_swizzled_scale_idx` helper.
- The empty-tensor early return, amax/scale GPU memory management, and test-ordering fixes are all correct.
- The only flagged issue is in `tests/cpp/test_common.cu`: the duplicate `NVTE_MXFP8_1D_SCALING` branch (lines ~196–222) is pre-existing dead code from a previous PR, already noted in prior review rounds. It is not introduced here, but worth a cleanup pass.
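As a quick illustration of the tile-count arithmetic the summary cites, here is a minimal Python sketch. `DIVUP` and the recipe names mirror the review text; `num_scale_tiles_x` is a hypothetical helper for illustration, not the TE source:

```python
def DIVUP(x, y):
    # Ceiling division, as used throughout the kernel code.
    return (x + y - 1) // y

def num_scale_tiles_x(recipe, cols=None, m_read=None):
    # Hypothetical stand-in mirroring the summary's two formulas.
    if recipe == "mxfp8":
        return DIVUP(cols, 128)   # MXFP8: DIVUP(cols, 128)
    if recipe == "nvfp4":
        return DIVUP(m_read, 4)   # NVFP4: DIVUP(Mread, 4)
    raise ValueError(f"unknown recipe: {recipe}")

print(num_scale_tiles_x("mxfp8", cols=256))   # → 2
print(num_scale_tiles_x("nvfp4", m_read=10))  # → 3
```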
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_dequantize"] --> B{"input.numel() == 0"}
    B -->|Yes| C["Early return\nCUDA graph safe"]
    B -->|No| D{"scaling_mode"}
    D -->|MXFP8| E{"with_gemm_swizzled_scales"}
    D -->|NVFP4| F{"with_gemm_swizzled_scales"}
    D -->|DELAYED| G["fp8::dequantize"]
    E -->|false| H["mxfp8_kernel false\nscale_idx = Y x stride + X"]
    E -->|true| I["mxfp8_kernel true\ngemm_swizzled_scale_idx\nnum_tiles = DIVUP cols 128"]
    F -->|false| J["fp4_kernel false\nscale_idx = x + y x stride"]
    F -->|true| K["fp4_kernel true\ngemm_swizzled_scale_idx\nnum_tiles = DIVUP Mread 4"]
    H --> L["High-precision output"]
    I --> L
    J --> L
    K --> L
    G --> L
```
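To make the `gemm_swizzled_scale_idx` step concrete, here is a hedged Python sketch of one plausible swizzled-index computation. The 128x4 scale-tile shape matches the flowchart; the `(row % 32) * 16 + (row // 32) * 4 + col` intra-tile order is an assumption about the GEMM block-scale layout, not something read from this PR:

```python
def gemm_swizzled_scale_idx(row, col, padded_cols):
    # Hypothetical sketch, not the TE kernel: map a (row, col) scale
    # coordinate to a linear offset in a GEMM-swizzled layout built from
    # 128x4 tiles laid out row-major across the padded scale matrix.
    TILE_DIM_Y, TILE_DIM_X = 128, 4
    num_tiles_x = padded_cols // TILE_DIM_X
    tile_y, tile_x = row // TILE_DIM_Y, col // TILE_DIM_X
    r, c = row % TILE_DIM_Y, col % TILE_DIM_X
    tile_id = tile_y * num_tiles_x + tile_x
    # Assumed intra-tile order: 32-row groups interleaved with the 4 columns.
    return tile_id * (TILE_DIM_Y * TILE_DIM_X) + (r % 32) * 16 + (r // 32) * 4 + c

# Sanity check: the mapping is a bijection over a 128x8 (two-tile) grid.
idxs = {gemm_swizzled_scale_idx(r, c, 8) for r in range(128) for c in range(8)}
assert len(idxs) == 1024 and min(idxs) == 0 and max(idxs) == 1023
```

The bijectivity check is the useful property here: every compact scale lands in exactly one swizzled slot, so dequantize can read swizzled scales without collisions.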
Reviews (8). Last reviewed commit: "Apply suggestions from code review".
> `std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {`
There is one edge case:

For MXFP8, when the input shape is, say, 64x64, it produces a scaling-factor shape of 64x2, which is then zero-padded to 128x4. We should be able to inject some very large random values into the padded region during allocation (since we allocate with `torch.empty` rather than `torch.zeros`) and detect whether the dequantize results are affected. If things work as expected, this line will be triggered and the dequantize numerics won't be affected.

For NVFP4, I think optimize-for-GEMM (i.e., swizzle fusion) is actually not enabled, and the same goes for the zero-out edge-case handling logic? So there shouldn't be any unswizzle logic needed here?
For NVFP4, I believe currently only device-init grouped quantize with RHT has the swizzle fusion feature, so the scaling factor zero-out is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.
For both MXFP8 and NVFP4, the unit-test logic is:

1. generate compact scales (or obtain them from quantization);
2. call `nvte_swizzle_scaling_factors` to swizzle the compact scales;
3. compare the results of `nvte_dequantize` with compact scales against swizzled scales.

Quantize with swizzle fusion is never enabled for either MXFP8 or NVFP4.
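The padded-region robustness check discussed above can be sketched in plain NumPy. Everything here is a stand-in for illustration: `dequantize_row` is a toy per-row dequantize, not the TE kernel. The point is only that garbage in the 128x4 padding is never read as long as indexing stays inside the valid 64x2 region:

```python
import numpy as np

rng = np.random.default_rng(0)

# A 64x64 input yields a 64x2 compact scale matrix (one scale per 32-element block).
compact = rng.uniform(0.5, 2.0, size=(64, 2)).astype(np.float32)

# Padded 128x4 buffer filled with huge garbage, as if allocated with torch.empty.
padded = np.full((128, 4), 1e30, dtype=np.float32)
padded[:64, :2] = compact

def dequantize_row(data_row, scales, row):
    # Toy dequantize: each scale covers one 32-element block of a 64-element row.
    out = np.empty_like(data_row, dtype=np.float32)
    for block in range(2):
        out[block * 32:(block + 1) * 32] = (
            data_row[block * 32:(block + 1) * 32] * scales[row, block]
        )
    return out

data = rng.standard_normal(64).astype(np.float32)
ref = dequantize_row(data, compact, 0)
got = dequantize_row(data, padded, 0)
assert np.allclose(ref, got)  # the garbage in the padded region is never read
```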
```python
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_recipe.backward_override == "dequantized" and (
    fp8_recipe.mxfp8() or fp8_recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False
    if grad_output_quantizer is not None:
        grad_output_quantizer.optimize_for_gemm = False
```
I'm of two minds about this:

- Logically, GEMM-optimized data is not guaranteed to support anything except GEMMs. Even if the MXFP8 and NVFP4 dequant kernels happen to support it, these are custom optimizations. Future recipes cannot be expected to support dequantizing GEMM-optimized data by default.
- It's a little pedantic to have edge-case logic that won't be triggered by any of our existing use-cases. Given how subtle this is, I worry about it becoming stale and distracting.

I think for now, this change is fine. However, if we encounter problems in a future recipe, we should reimplement it properly:

```python
# LOGICALLY WRONG!
# Fails if we add a new recipe
if recipe.backward_override == "dequantized" and recipe.future_recipe():
    input_quantizer.optimize_for_gemm = False

# LOGICALLY RIGHT!
# Automatically handles new recipes
if recipe.backward_override == "dequantized" and not (
    recipe.float8_per_tensor_scaling()
    or recipe.float8_block_scaling()
    or recipe.mxfp8()
    or recipe.nvfp4()
):
    input_quantizer.optimize_for_gemm = False
```

Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
Signed-off-by: Ziang Li <ziangli@umich.edu> Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
/te-ci pytorch L1
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
/te-ci core
timmoon10 left a comment:
LGTM, pending CI. These kernels will be very useful.