Skip to content

Multi-GPU training does not data-parallelize: the DDP wrapper is bypassed#49

Open
TonyChen06 wants to merge 1 commit into
SchmiedmayerLab:mainfrom
TonyChen06:fix/ddp-grad-sync
Open

Multi-GPU training does not data-parallelize: the DDP wrapper is bypassed#49
TonyChen06 wants to merge 1 commit into
SchmiedmayerLab:mainfrom
TonyChen06:fix/ddp-grad-sync

Conversation

@TonyChen06

@TonyChen06 TonyChen06 commented Jun 27, 2026

Copy link
Copy Markdown

cc @RealLast

♻️ Current situation & Problem

curriculum_learning.py wraps the model in DDP(model), but every training step computes the loss on the unwrapped module:

loss = self._get_model().compute_loss(batch)   # _get_model() returns self.model.module
loss.backward()

Calling the module instead of the wrapper means DDP's reducer never fires, so gradients are never synchronized across ranks. With DistributedSampler sharding the data, each rank trains an independent replica on its own 1/N of the data, _save_checkpoint keeps only rank 0, and ranks 1…N−1's training is discarded.

⚙️ Release Notes

  • Multi-GPU (DDP) training now synchronizes gradients across ranks (manual all-reduce after backward()); previously each rank trained an independent replica and only rank 0 was saved.

📚 Documentation

A one-line inline comment at the fix site explains why the manual all-reduce is needed (the _get_model() unwrap bypasses DDP). No public-interface change.

✅ Testing

The toy below replicates the trainer's exact pattern — DDP(model) wrapped (as in _initialize_model), loss computed on model.module (as in _get_model().compute_loss(...)), and a different data shard per rank (as DistributedSampler gives) — then logs the gradient each rank holds and how far the two ranks' weights have drifted apart.

import torch, torch.distributed as dist, torch.nn as nn
dist.init_process_group("nccl"); r, W = dist.get_rank(), dist.get_world_size(); torch.cuda.set_device(r)

def trajectory(fix, steps=5):
    torch.manual_seed(0); m = nn.Linear(4, 4).cuda()                 # identical init on both ranks
    ddp = nn.parallel.DistributedDataParallel(m, device_ids=[r]); gm = lambda: ddp.module
    opt = torch.optim.SGD(m.parameters(), lr=0.05)
    out = []
    for s in range(steps):
        g = torch.Generator(device="cuda").manual_seed(1000 + s)
        x = torch.randn(4, 4, generator=g, device="cuda") * 0.5 + r  # a different shard per rank
        y = torch.randn(4, 4, generator=torch.Generator(device="cuda").manual_seed(s), device="cuda")
        opt.zero_grad(); ((gm()(x) - y) ** 2).mean().backward()      # loss on module == the trainer's path
        if fix:                                                       # the proposed fix
            for p in gm().parameters():
                if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
        gw = gm().weight.grad.flatten()[0].clone(); opt.step()
        w = m.weight.flatten()
        gws = [torch.zeros_like(gw) for _ in range(W)]; dist.all_gather(gws, gw)
        ws  = [torch.zeros_like(w)  for _ in range(W)]; dist.all_gather(ws, w)
        out.append((gws[0].item(), gws[1].item(), (ws[0]-ws[1]).abs().max().item()))
    return out

B, F = trajectory(False), trajectory(True)
if r == 0:
    print("Table 1 - gradient on weight[0] after backward, per rank:")
    print(f"{'step':>4} | {'BUGGY g_r0':>11} {'g_r1':>11} {'synced':>7} | {'FIXED g_r0':>11} {'g_r1':>11} {'synced':>7}")
    for i,((b0,b1,_),(f0,f1,_)) in enumerate(zip(B,F),1):
        print(f"{i:>4} | {b0:>11.5f} {b1:>11.5f} {str(abs(b0-b1)<1e-9):>7} | {f0:>11.5f} {f1:>11.5f} {str(abs(f0-f1)<1e-9):>7}")
    print("\nTable 2 - cross-rank model divergence max|W_rank0 - W_rank1| after each step:")
    print(f"{'step':>4} | {'BUGGY':>12} | {'FIXED':>12}")
    for i,((_,_,bd),(_,_,fd)) in enumerate(zip(B,F),1):
        print(f"{i:>4} | {bd:>12.6f} | {fd:>12.6f}")
dist.destroy_process_group()

Output:

Table 1 - gradient on weight[0] after backward, per rank:
step |  BUGGY g_r0        g_r1  synced |  FIXED g_r0        g_r1  synced
   1 |     0.04043     0.19431   False |     0.11737     0.11737    True
   2 |    -0.13683    -0.05685   False |    -0.09659    -0.09659    True
   3 |     0.05426    -0.25685   False |    -0.10196    -0.10196    True
   4 |     0.02004    -0.01721   False |    -0.00963    -0.00963    True
   5 |    -0.21217    -0.45084   False |    -0.34623    -0.34623    True

Table 2 - cross-rank model divergence max|W_rank0 - W_rank1| after each step:
step |        BUGGY |        FIXED
   1 |     0.050678 |     0.000000
   2 |     0.073750 |     0.000000
   3 |     0.093219 |     0.000000
   4 |     0.118547 |     0.000000
   5 |     0.132573 |     0.000000
  • Buggy (loss on model.module): each rank holds different gradients (synced=False), takes a different step, and the replicas drift apart — divergence grows every step and keeps climbing.
  • Fixed (+ all-reduce): gradients are identical across ranks and equal the average of the per-rank grads (step 1: (0.04043 + 0.19431) / 2 = 0.11737), so the replicas stay bit-identical (0.000000).

The same divergence also reproduces on the actual pipeline; the two ranks diverge after the first step under the current loop and stay identical with the fix:

BUGGY (current code: loss on _get_model(), no all-reduce)
step | per-rank grad norm     | grad_synced | weight drift across ranks
  0  | [29.180, 29.762]       |   False     |  0.000e+00
  1  | [30.274, 29.429]       |   False     |  0.000e+00
  2  | [30.679, 29.582]       |   False     |  1.769e-04
  3  | [28.949, 29.850]       |   False     |  3.611e-04

FIXED (+ all-reduce grads after backward)
step | per-rank grad norm     | grad_synced | weight drift across ranks
  0  | [28.969, 28.969]       |   True      |  0.000e+00
  1  | [29.689, 29.689]       |   True      |  0.000e+00
  2  | [29.913, 29.913]       |   True      |  0.000e+00
  3  | [29.223, 29.223]       |   True      |  0.000e+00

Code of Conduct & Contributing Guidelines

By creating and submitting this pull request, you agree to follow our Code of Conduct and Contributing Guidelines:


View with Codesmith Autofix with Codesmith
Need help on this PR? Tag /codesmith with what you need. Autofix is disabled.

Summary by CodeRabbit

  • Bug Fixes
    • Improved multi-device training stability by explicitly synchronizing gradients and averaging them across processes immediately after backpropagation during stage training.
    • This helps ensure consistent parameter updates across multiple devices, leading to more reliable and reproducible results in distributed runs.

@coderabbitai

coderabbitai Bot commented Jun 27, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: 92417b2f-48f9-46e3-9a3e-6cf8c89d22c6

📥 Commits

Reviewing files that changed from the base of the PR and between 26dcf3a and 4a4e983.

📒 Files selected for processing (1)
  • curriculum_learning.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • curriculum_learning.py

📝 Walkthrough

Walkthrough

In _train_stage, after loss.backward(), an explicit gradient averaging step is added: when dist.is_initialized(), the code iterates over the unwrapped model's parameters and calls dist.all_reduce with ReduceOp.AVG on each non-None gradient tensor.

Changes

Distributed Gradient Averaging

Layer / File(s) Summary
Explicit gradient all-reduce after backward
curriculum_learning.py
After loss.backward(), iterates over DDP-unwrapped model parameters and performs dist.all_reduce(..., op=AVG) on each non-None gradient when dist.is_initialized().

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the core fix: DDP gradient synchronization was bypassed during multi-GPU training.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
curriculum_learning.py (1)

1118-1123: 🚀 Performance & Scalability | 🔵 Trivial | 🏗️ Heavy lift

Per-parameter all-reduce is communication-inefficient.

Issuing a separate all_reduce for every parameter creates many small collectives with no compute/communication overlap, which scales poorly on large models. Two options:

  • Coalesce gradients into flat buckets and reduce in fewer calls (e.g., torch._utils._flatten_dense_tensors / torch.distributed.algorithms).
  • Preferably, drive the backward through the DDP wrapper so DDP's bucketed, overlapped reducer is used instead of this manual path. The root cause is that compute_loss runs on the unwrapped _get_model(), which bypasses DDP's autograd hooks; routing the loss computation through self.model (the DDP module) would remove the need for this manual averaging entirely.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@curriculum_learning.py` around lines 1118 - 1123, The manual gradient
averaging in `curriculum_learning.py` is doing a separate `dist.all_reduce` for
each parameter, which is inefficient. Update the training flow around
`compute_loss` and `_get_model()` so the backward pass goes through `self.model`
(the DDP wrapper) instead of the unwrapped model, allowing DDP’s bucketed
reducer to handle synchronization automatically. If that isn’t possible, at
least coalesce parameter grads into fewer buckets before reducing rather than
iterating over `self._get_model().parameters()` one by one.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@curriculum_learning.py`:
- Around line 1119-1122: The gradient synchronization loop in self._get_model()
can hang because dist.all_reduce is called only when p.grad is not None, making
the collective sequence differ across ranks. Update the reduction logic so every
rank iterates over the same deterministic set of trainable parameters and always
participates in the same all_reduce calls, materializing a zero gradient for
missing grads before calling dist.all_reduce.

---

Nitpick comments:
In `@curriculum_learning.py`:
- Around line 1118-1123: The manual gradient averaging in
`curriculum_learning.py` is doing a separate `dist.all_reduce` for each
parameter, which is inefficient. Update the training flow around `compute_loss`
and `_get_model()` so the backward pass goes through `self.model` (the DDP
wrapper) instead of the unwrapped model, allowing DDP’s bucketed reducer to
handle synchronization automatically. If that isn’t possible, at least coalesce
parameter grads into fewer buckets before reducing rather than iterating over
`self._get_model().parameters()` one by one.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro Plus

Run ID: 9ced9a8f-b67e-4c3c-a6c2-8b68a4378f75

📥 Commits

Reviewing files that changed from the base of the PR and between 104013b and 65dd2cb.

📒 Files selected for processing (1)
  • curriculum_learning.py

Comment thread curriculum_learning.py Outdated
The model is DDP-wrapped, but the training loop computes loss on the unwrapped module
(`_get_model().compute_loss()`), which bypasses DDP's reducer -- so gradients are never
synced across ranks. Each rank trains an independent replica on its own data shard and
only rank 0's is saved, so multi-GPU runs do not data-parallelize.

Fix: average gradients across ranks after backward().
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant