RMS Norm Optimization #583
… missing configs for layer norm
prop.multiProcessorCount, zero_centered_gamma, stream);
}

HIP_CHECK(hipStreamSynchronize(stream));
Is synchronization needed before warmup?
Good point. These are in fact redundant, since the warmup already performs a device-wide sync. Removed in 4256e3c.
#include <typeindex>
#include <unordered_map>
#include <vector>
#include <unordered_set>
nit: move it after unordered_map
bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true, bool gamma_in_weight_dtype = false);

inline DType decode_itype(uint64_t general_key) {
This code is fragile because the encoding could change. At least put comments here and at the encoding block saying that they must match.
Good point. I updated this in d548d54 to make the coupling between encoding/decoding explicit by introducing shared norm_key bit-layout constants and using them in both get_key() and the decode helpers. I also added comments documenting that the layouts must remain in sync, so future changes to the packed key format are less likely to silently diverge.
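To illustrate the idea of shared bit-layout constants, here is a minimal sketch, not the code from d548d54. The namespace `norm_key_sketch`, the `Field` struct, and the helpers `pack`, `unpack`, and `make_key` are hypothetical names. The shifts for stage, backend, and `zero_centered_gamma` follow the quoted `get_key()` snippet; the field widths and the whole `kIType` field are assumptions made for illustration.

```cpp
#include <cstdint>

namespace norm_key_sketch {

// One packed field: bit offset and bit width inside the 64-bit key.
struct Field {
  unsigned shift;
  unsigned width;
};

// Hypothetical layout. Shifts 22, 24, and 26 mirror the quoted get_key()
// snippet; the widths and the itype field are assumptions, not the real format.
constexpr Field kIType{0, 5};
constexpr Field kStage{22, 2};
constexpr Field kBackend{24, 2};
constexpr Field kZeroCenteredGamma{26, 1};

constexpr std::uint64_t mask(Field f) {
  return (std::uint64_t{1} << f.width) - 1;
}

// Encoder side: place a value into its field.
constexpr std::uint64_t pack(Field f, std::uint64_t value) {
  return (value & mask(f)) << f.shift;
}

// Decoder side: read the value back through the same constants, so a layout
// change only has to be made in one place.
constexpr std::uint64_t unpack(Field f, std::uint64_t key) {
  return (key >> f.shift) & mask(f);
}

constexpr std::uint64_t make_key(std::uint64_t itype, std::uint64_t stage,
                                 std::uint64_t backend, bool zero_centered_gamma) {
  return pack(kIType, itype) | pack(kStage, stage) | pack(kBackend, backend) |
         pack(kZeroCenteredGamma, zero_centered_gamma ? 1u : 0u);
}

// Compile-time round-trip check: encoder and decoder cannot silently diverge.
static_assert(unpack(kStage, make_key(3, 1, 2, true)) == 1,
              "stage must survive an encode/decode round trip");

}  // namespace norm_key_sketch
```

With this shape, a decode helper such as `decode_itype` would reduce to a single `unpack` call on the matching field, and a `static_assert` round trip catches a mismatch at build time rather than at kernel lookup.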
HIP_CHECK(hipEventDestroy(stop));
HIP_CHECK(hipStreamDestroy(stream));

size_t bytes_read =
nit: line splits aren't needed here
static void BM_NormBackward(benchmark::State& state) {
const size_t N = state.range(0);
const size_t H = state.range(1);
const float epsilon = 1e-5f;
epsilon is the same for forward and backward, so it can probably be a global const.
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 8, 4);
Does BYTES_PER_LDG=8 outperform 16 for this config? If so, I wonder if the configs around it would perform better that way too.
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
BWD you have 7 warps set, but here you have 4. Is this optimal?
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);
Same here, is 4 better than 7?
(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
uint64_t general_key =
I get the motivation behind this change, but this affects upstream code. I feel like we're more likely to miss a key change from upstream if we have diverged here.
Description
Fixes # (16527)
RMSNorm falls back to the general kernel implementation on several DeepSeek and Qwen shapes, causing poor performance. These shapes have been registered with the tuned kernel cache, and a performance benchmark for RMSNorm has been added.
Additionally, a fallback warning is printed the first time a tuned config is not found for a requested kernel. For example:
E2E TFLOPS/s/GPU for proxy models (Previous -> Current with RMSNorm tuning):
Qwen:
bf16: 369.4 -> 374.7
fp8: 352.1 -> 358.2
DeepSeek:
bf16: 501.4 -> 529.4
fp8: 463.9 -> 511.4
Also added matching tuned configs for LayerNorm.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: