RMS Norm Optimization #583

Open: aris134 wants to merge 12 commits into dev from amartin/rmsnorm

Conversation

aris134 (Contributor) commented May 12, 2026

Description

Fixes # (16527)

RMSNorm falls back to the general kernel implementation for several DeepSeek and Qwen shapes, which hurts performance. This PR registers those shapes with the tuned kernel cache and adds a performance benchmark for RMSNorm.

Additionally, a fallback warning is now printed the first time a tuned config is not found for a requested kernel. For example:

in function getKernel: Falling back to general normalization kernel because no tuned kernel is available for this config. hidden_size=128, wtype=bf16, itype=bf16, otype=bf16, ctype=fp32
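The warn-on-first-fallback behavior described above could be sketched roughly as follows. This is an illustration only, not the code in this PR: `NormConfig`, `KernelRegistry`, `describe`, and `get_kernel` are hypothetical names, the config is keyed by a string rather than the packed integer key the real code uses, and the sketch assumes the warning is emitted once per distinct config (a single global flag would be the other reading of "the first time").

```cpp
#include <cstdint>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_set>

// Hypothetical config descriptor; the fields mirror the warning text only.
struct NormConfig {
  uint64_t hidden_size;
  std::string wtype, itype, otype, ctype;
};

// Formats a config the same way the example warning prints it.
inline std::string describe(const NormConfig& c) {
  std::ostringstream os;
  os << "hidden_size=" << c.hidden_size << ", wtype=" << c.wtype
     << ", itype=" << c.itype << ", otype=" << c.otype
     << ", ctype=" << c.ctype;
  return os.str();
}

enum class KernelKind { Tuned, General };

// Hypothetical registry: tuned configs are looked up first, and a miss
// falls back to the general kernel with a one-time warning per config.
class KernelRegistry {
 public:
  void register_tuned(const NormConfig& c) {
    std::lock_guard<std::mutex> lock(mu_);
    tuned_.insert(describe(c));
  }

  KernelKind get_kernel(const NormConfig& c) {
    const std::string key = describe(c);
    std::lock_guard<std::mutex> lock(mu_);
    if (tuned_.count(key) != 0) return KernelKind::Tuned;
    // insert().second is true only the first time this key is seen,
    // so repeated misses for the same config stay silent.
    if (warned_.insert(key).second) {
      ++warnings_emitted_;
      std::cerr << "Falling back to general normalization kernel because "
                   "no tuned kernel is available for this config. "
                << key << "\n";
    }
    return KernelKind::General;
  }

  // Exposed so the one-time behavior can be checked in a test.
  int warnings_emitted() const { return warnings_emitted_; }

 private:
  std::mutex mu_;
  std::unordered_set<std::string> tuned_;
  std::unordered_set<std::string> warned_;
  int warnings_emitted_ = 0;
};
```

Requesting the same untuned config repeatedly (for example `hidden_size=128` with bf16 types) returns `KernelKind::General` every time but writes to `std::cerr` only on the first request, which keeps logs readable in training loops that call the kernel on every step.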

E2E TFLOP/s/GPU for proxy models (previous -> current with RMSNorm tuning):

Qwen:
bf16: 369.4 -> 374.7
fp8: 352.1 -> 358.2

DeepSeek:
bf16: 501.4 -> 529.4
fp8: 463.9 -> 511.4

Also added matching tuned configs for LayerNorm.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

aris134 requested a review from alextmagro May 12, 2026 12:13
aris134 self-assigned this May 12, 2026
aris134 marked this pull request as ready for review May 12, 2026 19:15
prop.multiProcessorCount, zero_centered_gamma, stream);
}

HIP_CHECK(hipStreamSynchronize(stream));

Collaborator

Is synchronization needed before warmup?

Contributor Author

Good point. These are in fact redundant since the warmup already calls a device-wide sync anyway. Removed in 4256e3c

#include <typeindex>
#include <unordered_map>
#include <vector>
#include <unordered_set>

Collaborator

nit: move it after unordered_map

Contributor Author

Done in 2f9ff47

bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING,
bool training = true, bool gamma_in_weight_dtype = false);

inline DType decode_itype(uint64_t general_key) {

Collaborator

This code is fragile because the encoding could change. At a minimum, add comments here and at the encoding block stating that the two must match.

Contributor Author

Good point. I updated this in d548d54 to make the coupling between encoding/decoding explicit by introducing shared norm_key bit-layout constants and using them in both get_key() and the decode helpers. I also added comments documenting that the layouts must remain in sync, so future changes to the packed key format are less likely to silently diverge.
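As a rough illustration of the shared-constants approach described in this reply, the sketch below derives both the encoder and the decoders from one set of shift and mask constants. The names (`norm_key::kTypeBits`, `encode_types`, the `DType` values) and the bit positions are hypothetical and do not reflect the actual layout used in the PR; only `decode_itype` is a name that appears in the diff, and its signature here is an assumption.

```cpp
#include <cstdint>

// Hypothetical dtype enum; the real DType has more members.
enum class DType : uint8_t { kFloat32 = 0, kFloat16 = 1, kBFloat16 = 2 };

// Single source of truth for the packed key layout (illustrative values).
// Both the encoder and the decoders below use these constants, so a
// layout change cannot update one side without the other.
namespace norm_key {
constexpr unsigned kTypeBits = 5;
constexpr uint64_t kTypeMask = (uint64_t{1} << kTypeBits) - 1;
constexpr unsigned kWTypeShift = 0;
constexpr unsigned kITypeShift = kWTypeShift + kTypeBits;
constexpr unsigned kOTypeShift = kITypeShift + kTypeBits;
constexpr unsigned kCTypeShift = kOTypeShift + kTypeBits;
static_assert(kCTypeShift + kTypeBits <= 64,
              "type fields must fit in a 64-bit key");
}  // namespace norm_key

// Packs the four dtypes into the low bits of the key.
inline uint64_t encode_types(DType w, DType i, DType o, DType c) {
  using namespace norm_key;
  return ((uint64_t(w) & kTypeMask) << kWTypeShift) |
         ((uint64_t(i) & kTypeMask) << kITypeShift) |
         ((uint64_t(o) & kTypeMask) << kOTypeShift) |
         ((uint64_t(c) & kTypeMask) << kCTypeShift);
}

// Extracts one field using the same shift and mask as the encoder.
inline DType decode_field(uint64_t key, unsigned shift) {
  return static_cast<DType>((key >> shift) & norm_key::kTypeMask);
}

inline DType decode_wtype(uint64_t key) {
  return decode_field(key, norm_key::kWTypeShift);
}
inline DType decode_itype(uint64_t key) {
  return decode_field(key, norm_key::kITypeShift);
}
inline DType decode_otype(uint64_t key) {
  return decode_field(key, norm_key::kOTypeShift);
}
inline DType decode_ctype(uint64_t key) {
  return decode_field(key, norm_key::kCTypeShift);
}
```

A round-trip check such as `decode_itype(encode_types(w, i, o, c)) == i` makes a cheap unit test for keeping the two sides in sync.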

aris134 requested a review from ipanfilo May 15, 2026 16:40
HIP_CHECK(hipEventDestroy(stop));
HIP_CHECK(hipStreamDestroy(stream));

size_t bytes_read =

Contributor

nit: line splits aren't needed here

static void BM_NormBackward(benchmark::State& state) {
const size_t N = state.range(0);
const size_t H = state.range(1);
const float epsilon = 1e-5f;

Contributor

epsilon is the same in forward and backward, so it could probably be a global constant.

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);

REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 7, 8, 4);

Contributor

Does BYTES_PER_LDG=8 outperform 16 for this config? If so, I wonder if the configs around it would perform better that way too.

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16);

REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16);

Contributor

In BWD you have 7 warps set, but here you have 4. Is this optimal?

REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4);
REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4);

REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 7168, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4);

Contributor

Same here, is 4 better than 7?

(uint64_t(NormStage)) << 22 | (uint64_t(NormBackend) << 24) |
(uint64_t(zero_centered_gamma) << 26) | (uint64_t(mode) << 27) |
(uint64_t(training) << 37) | (uint64_t(gamma_in_weight_dtype) << 38);
uint64_t general_key =

Contributor

I get the motivation behind this change, but this affects upstream code. I feel like we're more likely to miss a key change from upstream if we have diverged here.

3 participants