Skip to content

[Draft] [PyTorch] Add distributed Muon optimizer#2920

Open
vcherepanov-nv wants to merge 3 commits intoNVIDIA:mainfrom
vcherepanov-nv:muon
Open

[Draft] [PyTorch] Add distributed Muon optimizer#2920
vcherepanov-nv wants to merge 3 commits intoNVIDIA:mainfrom
vcherepanov-nv:muon

Conversation

@vcherepanov-nv
Copy link
Copy Markdown
Collaborator

Description

Add a distributed Muon optimizer, based on newton_schulz orthogonalization

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add an optimizer class and tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vcherepanov-nv and others added 2 commits April 23, 2026 18:50
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 23, 2026

Greptile Summary

This PR adds a MuonOptimizer for tensor-parallel CUDA parameters, applying SGD-momentum followed by distributed Newton-Schulz orthogonalization over an NCCL process group. The implementation correctly handles distributed normalization (all-reduce of the global L2 norm), partition-dim transposition, Nesterov momentum, and decoupled/L2 weight decay; the closure is now properly wrapped in torch.enable_grad().

Confidence Score: 5/5

Safe to merge; all new findings are P2 suggestions and the core distributed math is correct.

The optimizer's distributed normalization, transpose handling, Nesterov/HeavyBall update, and weight-decay branches are all correct and consistent with the reference implementation in the test. Previously flagged P1s are either fixed (closure/enable_grad) or noted in prior threads. The only new findings are P2: a documentation gap about rank-symmetric gradient availability and incomplete scale-mode test coverage. Neither blocks correctness in the intended tensor-parallel use case.

transformer_engine/pytorch/optimizers/muon.py — collective-deadlock documentation; tests/pytorch/distributed/run_muon_optimizer.py — scale_mode coverage gap

Important Files Changed

Filename Overview
transformer_engine/pytorch/optimizers/muon.py New MuonOptimizer class implementing distributed SGD-momentum + Newton-Schulz orthogonalization. Core logic (distributed normalization, transpose/contiguous handling, Nesterov update, scale factor) looks correct. Closure now properly wrapped in torch.enable_grad(). Previously-flagged concerns remain open (unit_rms_norm ZeroDivision, non-Nesterov momentum alias); new P2 risk: deadlock if p.grad availability is uneven across ranks at step time.
tests/pytorch/distributed/run_muon_optimizer.py Distributed test worker that validates optimizer output against a single-process float32 reference. Reference implementation is consistent with the optimizer (no world_size double-multiplication for global_shape). Only spectral scale mode is tested; shape_scaling and unit_rms_norm are not exercised.
tests/pytorch/distributed/test_muon_optimizer.py pytest harness that launches the worker via torchrun. Correctly parametrizes over dtype, partition_dim, and weight_decay_mode; parses stdout/stderr for pass/fail markers. Clean structure.
transformer_engine/pytorch/optimizers/init.py Adds MuonOptimizer and get_muon_scale_factor to the public optimizers namespace. Trivial one-line addition.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant MuonOptimizer
    participant _orthogonalize
    participant _distributed_normalize_p2_
    participant newton_schulz

    Caller->>MuonOptimizer: step()
    loop for each param with grad
        MuonOptimizer->>MuonOptimizer: apply weight decay (decoupled or L2)
        MuonOptimizer->>MuonOptimizer: momentum_buffer.lerp_(grad, 1-β)
        MuonOptimizer->>MuonOptimizer: compute nesterov/non-nesterov update
        MuonOptimizer->>_orthogonalize: update, partition_dim, ...
        _orthogonalize->>_orthogonalize: clone + optional transpose
        _orthogonalize->>_distributed_normalize_p2_: orth_grad
        _distributed_normalize_p2_-->>_distributed_normalize_p2_: dist.all_reduce(norm_sq)
        _distributed_normalize_p2_->>_orthogonalize: x /= global_norm
        _orthogonalize->>newton_schulz: orth_grad, CusolverMpCtx
        newton_schulz-->>newton_schulz: distributed NS iterations
        newton_schulz->>_orthogonalize: orth_grad (orthogonalized)
        _orthogonalize->>_orthogonalize: optional un-transpose + scale
        _orthogonalize->>MuonOptimizer: orth_update
        MuonOptimizer->>MuonOptimizer: p.add_(orth_update, alpha=-lr)
    end
    MuonOptimizer->>Caller: loss
Loading

Reviews (2): Last reviewed commit: "Fix Muon closure and reference test" | Re-trigger Greptile

Comment on lines +186 to +191
def step(self, closure=None):
"""Perform a single optimization step."""
loss = None
if closure is not None:
loss = closure()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Closure called inside @torch.no_grad(), preventing gradient computation

closure() is invoked while torch.no_grad() is active. Any loss.backward() call inside the closure will silently produce zero/no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure in with torch.enable_grad():.

Suggested change
def step(self, closure=None):
"""Perform a single optimization step."""
loss = None
if closure is not None:
loss = closure()
@torch.no_grad()
def step(self, closure=None):
"""Perform a single optimization step."""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

Comment on lines +28 to +33
scale_mode: str,
extra_scale_factor: float,
eps: float,
) -> torch.Tensor:
global_shape = [grad.size(0), grad.size(1)]
global_shape[partition_dim] *= world_size
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Reference global_shape incorrectly scales an already-full tensor

_reference_orthogonalize receives the full matrix (shape full_shape) but then multiplies global_shape[partition_dim] by world_size a second time. For partition_dim=1 with world_size=2 and full_shape=(96, 128) this gives global_shape=[96, 256], so get_muon_scale_factor returns max(96,256)^0.5 = 16. The optimizer, operating on the shard (96, 64), correctly reconstructs global_shape=[96, 128] and computes max(96,128)^0.5 ≈ 11.3. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.

The global_shape[partition_dim] *= world_size line should be removed since the input is already the full matrix.

Comment on lines +33 to +34
if mode == "unit_rms_norm":
return (size_out / size_in) ** 0.5
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 unit_rms_norm mode can divide by zero when size_in == 0

(size_out / size_in) ** 0.5 raises ZeroDivisionError when size_in is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.

Comment on lines +218 to +221
if group["nesterov"]:
update = grad.lerp(momentum_buffer, group["momentum"])
else:
update = momentum_buffer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Non-Nesterov update is an alias to momentum_buffer, not a copy

update = momentum_buffer holds a reference. If _orthogonalize ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. _orthogonalize currently clones the input immediately so this is safe today, but a defensive .clone() or comment would make the intent explicit.

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant