DFlash: enable eval_from_cache + per-position (per-k) train/eval metrics#82

Merged
yubofredwang merged 2 commits into main from yubo/dflash-eval-per-pos
Apr 23, 2026
Conversation

@yubofredwang
Collaborator

Summary

  • Per-position train metrics — `DFlashModel.forward` now returns `loss_per_position` and `acc_per_position` (shape `[block_size]`). `DFlashTrainer._aggregate_metrics` slices off the anchor slot and emits `train/ploss_i`, `train/acc_i` (for i = 0..B-2, so `acc_0` is the first predicted token — matches Eagle3 semantics) plus `train/simulated_acc_len` (cumulative product of per-position accs).
  • Real eval path — replaces the stubbed `eval_from_cache` (which returned `{}`) with `eval_forward` / `eval_from_cache` / `_aggregate_eval_metrics` that mirror `Eagle3Trainer`. Eval loss is decay-weighted with `self.loss_decay_gamma` so `eval/avg_loss` is directly comparable to `train/avg_loss`; also emits `eval/avg_acc`, `eval/simulated_acc_len`, and per-i `eval/ploss_i`, `eval/acc_i`. The stale "eval hangs in colocate/SGLang mode" comment is gone — the controller already calls `eval_from_cache` every `eval_interval` and all current DFlash training configs are non-colocated.
  • Hard-pin on DFlash configs — `configs/dflash_qwen3_8b_repro.yaml` and `configs/sglang_qwen3_8b_dflash.yaml` set `mooncake.enable_hard_pin: true`. With force-delete (landed in Refactor Mooncake Store: force delete + hard pin #73), the TTL path is no longer our cleanup path; hard-pin makes `remove(force=True)` the only way an object leaves the store. Requires `mooncake-transfer-engine >= 0.3.10.post1`, already pinned in `pyproject.toml`.
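The aggregation described above can be sketched as follows. This is a minimal illustration, not the trainer's actual code: the function name, the `gamma=None` default, and the `train/*` key layout are assumptions; only the slicing/cumprod/decay-weighting semantics come from the PR description.

```python
import torch

def aggregate_per_position(loss_pp, acc_pp, gamma=None):
    """Turn per-position tensors (anchor slot at index 0) into scalar metrics.

    Hypothetical sketch: slices off the anchor slot so index 0 is the first
    predicted token (Eagle3 semantics), emits per-i keys, a cumulative-product
    simulated acceptance length, and a decay-weighted average loss.
    """
    pred_loss = loss_pp[1:]  # drop the anchor slot
    pred_acc = acc_pp[1:]
    metrics = {f"train/ploss_{i}": v.item() for i, v in enumerate(pred_loss)}
    metrics.update({f"train/acc_{i}": v.item() for i, v in enumerate(pred_acc)})
    # Position i is "accepted" only if every earlier position was accepted,
    # so the expected accepted length is the sum of cumulative products.
    metrics["train/simulated_acc_len"] = torch.cumprod(pred_acc, dim=0).sum().item()
    # Decay weights exp(-i/gamma); uniform weights when no gamma is given.
    if gamma:
        k = torch.arange(pred_loss.numel())
        weights = torch.exp(-k.float() / gamma)
    else:
        weights = torch.ones_like(pred_loss)
    metrics["train/avg_loss"] = (pred_loss * weights).sum().item() / weights.sum().item()
    return metrics

m = aggregate_per_position(torch.tensor([0.0, 1.0, 2.0, 3.0]),
                           torch.tensor([1.0, 0.9, 0.8, 0.7]))
```

With the toy inputs above, `simulated_acc_len` is 0.9 + 0.9·0.8 + 0.9·0.8·0.7 ≈ 2.124 and the unweighted `avg_loss` is 2.0.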

Why

The linked wandb decagon runs log per-TTT-position metrics and the eval suite, but DFlash runs were showing only scalar `train/avg_loss` / `train/avg_acc` with no eval panel — making it impossible to diagnose which predicted position is failing when a run diverges (historically around steps 100–500 on m27). This brings DFlash to Eagle3 parity in wandb.

Test plan

  • `pytest tests/test_dflash.py` → 62 passed. Updated 8 call sites to unpack the new 4-tuple return from `DFlashModel.forward`; added a shape assertion on the new per-position tensors.
  • `MooncakeConfig.from_flat_args` loads both DFlash configs with `enable_hard_pin=True`.
  • Dry-run one training step + one eval step on a tiny DFlash config to confirm `train/acc_i`, `train/ploss_i`, `train/simulated_acc_len`, and the `eval/*` keys appear in wandb (in-progress).
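The call-site update in the test plan follows the pattern below. This is a self-contained sketch with a stand-in forward; the pre-existing return values and all names here are assumptions for illustration, not the real `DFlashModel` signature.

```python
import torch

BLOCK_SIZE = 4  # hypothetical block size for the sketch

def fake_dflash_forward():
    """Stand-in for the expanded DFlashModel.forward return.

    The first two elements (scalar loss/acc) are assumed; the PR adds the
    two per-position tensors of shape [block_size] at the end.
    """
    loss = torch.tensor(1.25)
    acc = torch.tensor(0.5)
    loss_per_position = torch.zeros(BLOCK_SIZE)
    acc_per_position = torch.zeros(BLOCK_SIZE)
    return loss, acc, loss_per_position, acc_per_position

# Each updated call site now unpacks four values instead of two,
# plus the new shape assertion on the per-position tensors.
loss, acc, loss_pp, acc_pp = fake_dflash_forward()
assert loss_pp.shape == (BLOCK_SIZE,)
assert acc_pp.shape == (BLOCK_SIZE,)
```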

Why: the DFlash trainer was logging only scalar train/avg_loss and
train/avg_acc, and eval_from_cache was stubbed to return {}, so DFlash
runs lacked the per-position breakdown, simulated_acc_len, and eval
metrics that Eagle3 runs provide and that we rely on in wandb for
diagnosing draft quality across positions in the proposed block.

Changes
- torchspec/models/dflash.py: DFlashModel.forward now additionally
  returns loss_per_position / acc_per_position (shape [block_size]),
  computed under no_grad using the existing binary_eval_mask (no decay
  bias) and the already-computed loss_per_token / correct. Index 0 is
  the anchor slot (always zero count) and is sliced off downstream.
- torchspec/training/dflash_trainer.py:
  - _forward / _train_step propagate the per-position tensors.
  - _aggregate_metrics emits train/ploss_i, train/acc_i (i=0..B-2,
    re-indexed so acc_0 = first predicted token, matching Eagle3)
    and train/simulated_acc_len = cumulative product of per-position
    accs.
  - Replaces the eval stub with real eval_forward, eval_from_cache,
    and _aggregate_eval_metrics mirroring Eagle3Trainer. Eval loss is
    decay-weighted using self.loss_decay_gamma so eval/avg_loss is
    directly comparable to train/avg_loss; eval also emits
    eval/avg_acc, eval/simulated_acc_len, eval/ploss_i, eval/acc_i.
  - Stale "eval hangs in colocate/SGLang mode" comment removed; the
    current controller/loop pipeline invokes eval_from_cache every
    eval_interval steps and is non-colocate in the DFlash training
    configs.
- tests/test_dflash.py: updated 8 call sites to unpack the new 4-tuple
  return from DFlashModel.forward; added a shape assertion on the new
  per-position tensors.
- configs/dflash_qwen3_8b_repro.yaml, configs/sglang_qwen3_8b_dflash.yaml:
  set mooncake.enable_hard_pin: true so batch_remove(force=True) is
  the sole deletion path and the master-side TTL is bypassed; requires
  mooncake-transfer-engine >= 0.3.10.post1 (already pinned).
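The per-position reduction described in the torchspec/models/dflash.py item above can be sketched as a masked mean over the batch dimension. Function and argument names are assumptions; the zero-count anchor slot at index 0 matches the note that it is sliced off downstream.

```python
import torch

def per_position_stats(loss_per_token, correct, eval_mask):
    """Reduce token-level loss/correctness to per-position averages.

    Sketch under assumed shapes [batch, block_size]; eval_mask is a binary
    mask (no decay bias). Position 0 is the anchor slot: its count is zero,
    so it averages to 0 here and gets sliced off by the trainer.
    """
    with torch.no_grad():  # metrics only; no gradient through this path
        counts = eval_mask.sum(dim=0).clamp(min=1)  # avoid div-by-zero on empty slots
        loss_pp = (loss_per_token * eval_mask).sum(dim=0) / counts
        acc_pp = (correct * eval_mask).sum(dim=0) / counts
    return loss_pp, acc_pp

# Tiny example: batch of 2, block_size of 3, anchor column fully masked out.
loss_per_token = torch.tensor([[0.0, 1.0, 2.0], [0.0, 3.0, 4.0]])
correct = torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 1.0]])
eval_mask = torch.tensor([[0.0, 1.0, 1.0], [0.0, 1.0, 1.0]])
loss_pp, acc_pp = per_position_stats(loss_per_token, correct, eval_mask)
```

Averaging with true per-position counts (rather than a fixed batch size) is what keeps the breakdown meaningful when masks or chunk lengths are sparse, as the second commit message notes.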

Tests: pytest tests/test_dflash.py -> 62 passed.
Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
Copilot AI review requested due to automatic review settings April 23, 2026 02:07
Contributor

Copilot AI left a comment

Pull request overview

This PR brings DFlash training/eval logging up to parity with Eagle3 by adding per-position metrics to the DFlash model/trainer and implementing a real eval_from_cache path, plus enabling Mooncake hard-pin in DFlash-related configs.

Changes:

  • Extend DFlashModel.forward to also return per-position loss/accuracy tensors ([block_size]) and propagate them through DFlash training metric aggregation.
  • Implement DFlashTrainer.eval_forward, eval_from_cache, and _aggregate_eval_metrics to compute eval metrics from the CPU eval cache and emit per-position eval stats.
  • Enable mooncake.enable_hard_pin: true in DFlash YAML configs and document the dependency requirement.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.

Summary per file:

  • torchspec/training/dflash_trainer.py — adds per-position train metrics plus a full eval-from-cache implementation and eval metric aggregation.
  • torchspec/models/dflash.py — returns per-position loss/accuracy tensors from the model forward pass.
  • tests/test_dflash.py — updates tests to unpack the expanded forward return tuple and asserts per-position tensor shapes.
  • configs/sglang_qwen3_8b_dflash.yaml — enables Mooncake hard-pin for this DFlash config.
  • configs/dflash_qwen3_8b_repro.yaml — enables Mooncake hard-pin for this DFlash repro config.

Comment thread on torchspec/training/dflash_trainer.py (outdated), lines +349 to +359:

```python
    # so weights become exp(-i/gamma).
    weights = torch.exp(-k.float() / gamma)
else:
    weights = torch.ones_like(pred_loss_pp)
weighted_avg_loss = (pred_loss_pp * weights).sum().item() / weights.sum().item()

metrics: dict = {
    "eval/avg_loss": weighted_avg_loss,
    "eval/avg_acc": pred_acc_pp.mean().item(),
    "eval/simulated_acc_len": simulated_acc_len,
}
```

Comment thread on configs/sglang_qwen3_8b_dflash.yaml, lines +79 to +80:

```yaml
global_segment_size: 16GB
local_buffer_size: 4GB
# Hard-pin: master-side TTL is disabled; we rely on our explicit
# batch_remove(force=True) (see mooncake/eagle_store.py). Requires
# mooncake-transfer-engine >= 0.3.10.post1.
```
Preserve per-position counts so train and eval aggregate DFlash metrics with the correct token weighting. This keeps ploss/acc breakdowns and eval scalars aligned with the actual objective when masks or chunk lengths are sparse.

Signed-off-by: Yubo Wang <yubowang2019@gmail.com>
@yubofredwang yubofredwang merged commit 1aae91a into main Apr 23, 2026
2 checks passed
@yubofredwang yubofredwang deleted the yubo/dflash-eval-per-pos branch April 23, 2026 03:42