[Draft] [PyTorch] Add distributed Muon optimizer #2920
vcherepanov-nv wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds a distributed Muon optimizer.

Confidence Score: 5/5 — Safe to merge; all new findings are P2 suggestions and the core distributed math is correct. The optimizer's distributed normalization, transpose handling, Nesterov/HeavyBall update, and weight-decay branches are all correct and consistent with the reference implementation in the test. Previously flagged P1s are either fixed (closure/enable_grad) or noted in prior threads. The only new findings are P2: a documentation gap about rank-symmetric gradient availability and incomplete scale-mode test coverage. Neither blocks correctness in the intended tensor-parallel use case.

Important Files Changed: transformer_engine/pytorch/optimizers/muon.py — collective-deadlock documentation; tests/pytorch/distributed/run_muon_optimizer.py — scale_mode coverage gap
Sequence Diagram

sequenceDiagram
participant Caller
participant MuonOptimizer
participant _orthogonalize
participant _distributed_normalize_p2_
participant newton_schulz
Caller->>MuonOptimizer: step()
loop for each param with grad
MuonOptimizer->>MuonOptimizer: apply weight decay (decoupled or L2)
MuonOptimizer->>MuonOptimizer: momentum_buffer.lerp_(grad, 1-β)
MuonOptimizer->>MuonOptimizer: compute nesterov/non-nesterov update
MuonOptimizer->>_orthogonalize: update, partition_dim, ...
_orthogonalize->>_orthogonalize: clone + optional transpose
_orthogonalize->>_distributed_normalize_p2_: orth_grad
_distributed_normalize_p2_-->>_distributed_normalize_p2_: dist.all_reduce(norm_sq)
_distributed_normalize_p2_->>_orthogonalize: x /= global_norm
_orthogonalize->>newton_schulz: orth_grad, CusolverMpCtx
newton_schulz-->>newton_schulz: distributed NS iterations
newton_schulz->>_orthogonalize: orth_grad (orthogonalized)
_orthogonalize->>_orthogonalize: optional un-transpose + scale
_orthogonalize->>MuonOptimizer: orth_update
MuonOptimizer->>MuonOptimizer: p.add_(orth_update, alpha=-lr)
end
MuonOptimizer->>Caller: loss
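The newton_schulz stage in the diagram above can be sketched in single-process form. This is a hedged sketch, not the PR's implementation: the quintic coefficients (3.4445, -4.7750, 2.0315) are the constants commonly used in Muon write-ups and are an assumption here, and the PR's version additionally distributes the matmuls across ranks.

```python
import numpy as np

def newton_schulz(g: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Approximately orthogonalize g by pushing its singular values toward 1."""
    # Assumed quintic iteration coefficients from common Muon implementations.
    a, b, c = 3.4445, -4.7750, 2.0315
    # Normalize so the spectral norm is <= 1 and the iteration converges.
    x = g / (np.linalg.norm(g) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the wide orientation so x @ x.T is small
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x
    return x.T if transposed else x

rng = np.random.default_rng(0)
u = newton_schulz(rng.standard_normal((16, 32)))
# Singular values should now sit roughly near 1 (the iteration oscillates
# in a band around 1 rather than converging exactly).
print(np.linalg.svd(u, compute_uv=False))
```

The single-rank loop makes it clear why the distributed version must all-reduce the Frobenius norm before iterating: each rank holds only a shard, so the normalization in `_distributed_normalize_p2_` needs the global norm.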
Reviews (2): Last reviewed commit: "Fix Muon closure and reference test"
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            loss = closure()
Closure called inside @torch.no_grad(), preventing gradient computation

closure() is invoked while torch.no_grad() is active, so the autograd graph is never recorded: any loss.backward() call inside the closure will fail or produce no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure in with torch.enable_grad():.
Suggested change:

    # before: closure() runs under step()'s @torch.no_grad()
    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

    # after: re-enable autograd just for the closure
    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> torch.Tensor:
    global_shape = [grad.size(0), grad.size(1)]
    global_shape[partition_dim] *= world_size
Reference global_shape incorrectly scales an already-full tensor
_reference_orthogonalize receives the full matrix (shape full_shape) but then multiplies global_shape[partition_dim] by world_size a second time. For partition_dim=1 with world_size=2 and full_shape=(96, 128) this gives global_shape=[96, 256], so get_muon_scale_factor returns max(96,256)^0.5 = 16. The optimizer, operating on the shard (96, 64), correctly reconstructs global_shape=[96, 128] and computes max(96,128)^0.5 ≈ 11.3. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.
The global_shape[partition_dim] *= world_size line should be removed since the input is already the full matrix.
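The √2 discrepancy is plain arithmetic; a short check of the numbers quoted above (the max(m, n) ** 0.5 scale formula is taken from the review's figures, not from the PR source):

```python
# Each of 2 ranks holds a (96, 64) shard along dim 1; the full matrix is (96, 128).
world_size = 2

# Optimizer: reconstructs the global shape from its shard -> correct.
scale_optimizer = max(96, 64 * world_size) ** 0.5   # sqrt(128) ~= 11.31

# Reference: already has the full matrix but scales dim 1 again -> wrong.
scale_reference = max(96, 128 * world_size) ** 0.5  # sqrt(256) = 16.0

print(scale_optimizer, scale_reference, scale_reference / scale_optimizer)
```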
    if mode == "unit_rms_norm":
        return (size_out / size_in) ** 0.5
unit_rms_norm mode can divide by zero when size_in == 0
(size_out / size_in) ** 0.5 raises ZeroDivisionError when size_in is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.
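One possible guard, as a hedged sketch: the function name muon_scale_factor and the "spectral" branch are illustrative stand-ins, not the PR's actual get_muon_scale_factor signature or mode set.

```python
def muon_scale_factor(size_out: int, size_in: int, mode: str) -> float:
    """Scale factor applied to the orthogonalized update (illustrative)."""
    # Guard both dimensions, not just the partitioned one.
    if size_out <= 0 or size_in <= 0:
        raise ValueError(
            f"both dimensions must be positive, got ({size_out}, {size_in})"
        )
    if mode == "unit_rms_norm":
        return (size_out / size_in) ** 0.5
    if mode == "spectral":  # assumed alternative mode, for illustration only
        return max(size_out, size_in) ** 0.5
    raise ValueError(f"unknown scale_mode: {mode!r}")

print(muon_scale_factor(96, 24, "unit_rms_norm"))  # sqrt(96/24) = 2.0
```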
    if group["nesterov"]:
        update = grad.lerp(momentum_buffer, group["momentum"])
    else:
        update = momentum_buffer
Non-Nesterov update is an alias to momentum_buffer, not a copy
update = momentum_buffer holds a reference. If _orthogonalize ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. _orthogonalize currently clones the input immediately so this is safe today, but a defensive .clone() or comment would make the intent explicit.
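The hazard is ordinary Python reference semantics; a tiny illustration of why a defensive clone matters if any downstream code ever mutates the update in place:

```python
import torch

momentum_buffer = torch.ones(4)

update = momentum_buffer          # alias: both names share one storage
update.mul_(0.5)                  # in-place edit silently corrupts the buffer
print(momentum_buffer)            # no longer all ones

momentum_buffer = torch.ones(4)
update = momentum_buffer.clone()  # defensive copy
update.mul_(0.5)
print(momentum_buffer)            # buffer unchanged
```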
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Description
Add a distributed Muon optimizer, based on newton_schulz orthogonalization
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: