Skip to content

add MXFP8 pre-swizzling for gfx1250 GEMM#568

Open
matthiasdiener wants to merge 24 commits into
devfrom
mdiener/mxfp8-swizzle
Open

add MXFP8 pre-swizzling for gfx1250 GEMM#568
matthiasdiener wants to merge 24 commits into
devfrom
mdiener/mxfp8-swizzle

Conversation

@matthiasdiener
Copy link
Copy Markdown
Contributor

@matthiasdiener matthiasdiener commented Apr 29, 2026

Description

Fixes https://github.com/ROCm/frameworks-internal/issues/16428

This was lightly tested on gfx1250.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@matthiasdiener matthiasdiener self-assigned this Apr 29, 2026
@matthiasdiener matthiasdiener added the ci-level 1 CI test level 1 label Apr 29, 2026
@matthiasdiener matthiasdiener force-pushed the mdiener/mxfp8-swizzle branch from ddf19da to 313a6b7 Compare May 3, 2026 22:06
@matthiasdiener matthiasdiener requested a review from alextmagro May 4, 2026 16:33
Copy link
Copy Markdown
Contributor

@alextmagro alextmagro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Matthias, a few comments. I also assume you are still planning on adding in the hooks to scale swizzle when we're on gfx1250? I believe there were hooks in all of common, pytorch and jax. These PRs removed them, so would be a partial revert.

#420
#424
#442

asm volatile("ds_swizzle_b32 %0, %1 offset:0x041F\n\t"
"s_waitcnt lgkmcnt(0)" : "=v"(r) : "v"(v));
return r;
return __shfl_xor(v, 1);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need these helper functions now that we're just doing a __shfl_xor?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is only inadvertently part of this PR, it is already part of #571. Will revert here.

Comment thread transformer_engine/common/swizzle/swizzle.cu
const int k = idx % K_scale;

uint8_t val = 127;
if (m < original_M && k < original_K) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we move this check to the hostside, or remove it completely?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved hostside in b55a538

Comment thread tests/cpp/operator/test_swizzle.cu
#include <cstdint>

#include "../common.h"
#include "../util/cuda_runtime.h"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this include needed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed in b55a538

" (got shape=", shape, ")");
#ifdef USE_ROCM
// gfx1250 MX pre-swizzle (Tensile 3D) layout requires M padded to multiple of 4.
// Other ROCm architectures use 128x4 tiles but currently skip padding
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is true regarding us using 128x4 tiles. 128x4 scaling is an upstream requirement. We also have padding expectations in pytorch, jax, and all 3 test dirs have padding that will probably need fixing.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the comment in b55a538

@matthiasdiener matthiasdiener changed the title [proof-of-concept] add MXFP8 pre-swizzling for gfx1250 add MXFP8 pre-swizzling for gfx1250 GEMM May 13, 2026
@matthiasdiener
Copy link
Copy Markdown
Contributor Author

I also assume you are still planning on adding in the hooks to scale swizzle when we're on gfx1250? I believe there were hooks in all of common, pytorch and jax. These PRs removed them, so would be a partial revert.

#420 #424 #442

The hooks should be re-added in 384d590.

@matthiasdiener matthiasdiener requested a review from alextmagro May 14, 2026 20:20
@matthiasdiener matthiasdiener marked this pull request as ready for review May 14, 2026 20:21
// Simple GPU reference kernel for MXFP8 GEMM: D = A * B^T (TN layout)
// A is [M, K] row-major, B is [N, K] row-major, D is [M, N] column-major
// Scales are E8M0, one per group of 32 elements along K.
__global__ void mxfp8_gemm_ref_kernel(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a second mxfp8 reference kernel?

class MxGemmSwizzleGfx1250TestSuite
: public ::testing::TestWithParam<MxGemmParams> {};

TEST_P(MxGemmSwizzleGfx1250TestSuite, TestMxfp8GemmE2E) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is we must swizzle scales for gfx1250. I think ideally we would fuse this with the existing mxfp8 GEMM tests -- pre-1250 we don't swizzle, 1250+ we do.


#ifdef USE_ROCM
// On ROCm, only MXFP8 on gfx1250 needs scale pre-swizzling
if (scaling_mode != NVTE_MXFP8_1D_SCALING || transformer_engine::cuda::sm_arch() != 125) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes we use == 125, sometimes >= 125. Should probably be consistent one or the other.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 1 CI test level 1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants