
feat(megatron): Add support for Gated Delta Net (GDN) & Kimi Delta Attention (KDA) #676

Draft
clairesonglee wants to merge 30 commits into main from clairlee/kda-optimized-training-patch

Conversation

@clairesonglee

Collaborator

No description provided.

Comment thread tools/chat_zebra_llama.py Fixed
Comment thread tools/chat_zebra_llama.py Fixed
Comment thread tools/convert_zebra_llama_to_hf.py Fixed
Comment thread primus/backends/megatron/core/models/hybrid/hybrid_block.py Fixed
Comment thread tools/modeling_zebra_llama.py Fixed
Comment thread tools/modeling_zebra_llama.py Fixed
Comment thread tools/modeling_zebra_llama.py Fixed
Comment thread tools/convert_zebra_llama_to_hf.py Fixed
Comment thread primus/backends/megatron/core/models/hybrid/hybrid_block.py Fixed
vanshbhatia-amd and others added 8 commits April 24, 2026 19:22
- 300M pure GDN pretrain config aligned with FLA (micro_batch=128, global=1024)
- 300M model architecture config (hidden=1024, layers=24, ffn=4096)
- use_short_conv in language_model.yaml
- Updated KDA pretrain config

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>

# Conflicts:
#	examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml
#	examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml
#	examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml
#	examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml
#	primus/backends/megatron/core/models/hybrid/__init__.py
#	primus/backends/megatron/core/models/hybrid/hybrid_block.py
#	primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py
#	primus/configs/models/megatron/language_model.yaml
#	primus/configs/models/megatron/mamba_370M.yaml
#	primus/configs/models/megatron/mamba_base.yaml
#	primus/configs/models/megatron/zebra_llama_1B.yaml
#	primus/configs/models/megatron/zebra_llama_3B.yaml
#	primus/configs/models/megatron/zebra_llama_8B.yaml
#	primus/core/utils/import_utils.py
#	primus/modules/trainer/lightmegatron/pre_trainer.py
#	primus/modules/trainer/megatron/pre_trainer.py
#	primus/modules/trainer/megatron/trainer.py
- Introduced `bash-docker.sh` for running the Primus environment with necessary device and network configurations.
- Added `GDN_FLA_PARITY.md` to document changes for achieving GDN training parity with Flash Linear Attention (FLA).
- Created `README_GDN.md` for a comprehensive guide on running the 300M pure GDN pretraining recipe in Primus, including setup and performance metrics.
- Updated `README.md` to include details about Zebra-Llama models and their configurations.
- Enhanced Megatron patch script to apply necessary modifications for GDN training.
- Adjusted training configurations to ensure alignment with FLA performance metrics.
- Introduced `KDA_FLA_PARITY.md` to document changes for achieving Kimi Delta Attention (KDA) training parity with Flash Linear Attention (FLA).
- Created `README_KDA.md` for a comprehensive guide on running the 300M pure KDA pretraining recipe in Primus, including setup and performance metrics.
- Added new training configuration file `zebra_llama_300M_kda_pure-pretrain.yaml` aligned with FLA specifications.
- Enhanced Megatron patch script to apply necessary modifications for KDA training.
- Updated model code and runtime configurations to ensure alignment with FLA performance metrics, including fused operations and normalization techniques.
- Implemented diagnostic features for capturing activations and iter-1 batch tokens for comparison with FLA outputs.
- Introduced `KDA_FLA_PARITY.md` to document changes for achieving Kimi Delta Attention (KDA) training parity with Flash Linear Attention (FLA).
- Created `README_KDA.md` for a comprehensive guide on running the 300M pure KDA pretraining recipe in Primus, including setup and performance metrics.
- Added new training configuration file `zebra_llama_300M_kda_pure-pretrain.yaml` aligned with FLA specifications.
- Enhanced Megatron patch script to apply necessary modifications for KDA training.
- Updated model code and runtime configurations to ensure alignment with FLA performance metrics, including optimizations for fused operations and memory management.
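The activation-capture diagnostics listed above imply a layer-by-layer comparison between the Primus run and the FLA run. A minimal sketch of such a comparison follows, assuming each run has dumped its activations as a mapping from layer name to a flat list of floats. The function name `compare_activations`, the layer names, the tolerance, and the sample values are all hypothetical and are not taken from this PR.

```python
from typing import Dict, List, Tuple


def compare_activations(
    ours: Dict[str, List[float]],
    ref: Dict[str, List[float]],
    atol: float = 1e-3,  # hypothetical tolerance; tune it for the dtype in use
) -> List[Tuple[str, float]]:
    """Return (layer, max_abs_diff) for layers whose outputs differ by more than atol."""
    mismatches = []
    # Only layers present in both dumps can be compared.
    for name in sorted(ours.keys() & ref.keys()):
        a, b = ours[name], ref[name]
        if len(a) != len(b):
            # A length mismatch is reported as an infinite difference.
            mismatches.append((name, float("inf")))
            continue
        max_diff = max((abs(x - y) for x, y in zip(a, b)), default=0.0)
        if max_diff > atol:
            mismatches.append((name, max_diff))
    return mismatches


# Illustrative values only, not measurements from this PR.
primus_acts = {"layer0.norm": [0.10, 0.20], "layer0.mixer": [1.00, 2.00]}
fla_acts = {"layer0.norm": [0.10, 0.20], "layer0.mixer": [1.00, 2.50]}
report = compare_activations(primus_acts, fla_acts)
```

Walking the layers in a fixed order makes it easy to spot the first layer where the two runs diverge, which is usually the most useful starting point for a parity investigation.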
clairesonglee force-pushed the clairlee/kda-optimized-training-patch branch from 774acb4 to c301288 on June 3, 2026 12:04


if __name__ == "__main__":
    sys.exit(main())


    with open(marker, "w") as fh:
        fh.write(_msg + "\n")
        fh.write(f"pid={os.getpid()} ts={time.time()}\n")
except Exception:

        eps=self.config.layernorm_epsilon,
    )
    self._use_fla_fused_gated_norm = True
except ImportError:

# `return_dict` for any code paths that consult config defaults.
try:
    hf_model.config.return_dict = True
except Exception:

for h in mg_hooks:
    try:
        h.remove()
    except Exception:

for h in hf_hooks:
    try:
        h.remove()
    except Exception:

for h in mg_io_hooks:
    try:
        h.remove()
    except Exception:

for h in hf_io_hooks:
    try:
        h.remove()
    except Exception:
The upstream Megatron-LM MambaStackSubmodules dataclass does not have
a moe_layer field. Only HybridStackSubmodules (Primus) supports it.
This caused a RuntimeError at import time when the spec module was
loaded for any hybrid model training.
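One way to guard against this kind of mismatch is to pass a spec dataclass only the fields it declares. The sketch below uses two stand-in dataclasses rather than the real Megatron-LM or Primus classes. The class names, the field list, and the helper `supported_kwargs` are assumptions for illustration, and the actual fix in this PR may differ.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class UpstreamStackSubmodules:
    """Stand-in for an upstream spec dataclass that has no `moe_layer` field."""
    mamba_layer: Optional[Any] = None
    attention_layer: Optional[Any] = None
    mlp_layer: Optional[Any] = None


@dataclass
class HybridStackSubmodulesSketch(UpstreamStackSubmodules):
    """Stand-in for an extended dataclass that adds `moe_layer`."""
    moe_layer: Optional[Any] = None


def supported_kwargs(cls: type, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that `cls` declares as dataclass fields."""
    names = {f.name for f in fields(cls)}
    return {k: v for k, v in kwargs.items() if k in names}


# Placeholder values; a real spec would hold module specs here.
spec_kwargs = {"mamba_layer": "mamba", "mlp_layer": "mlp", "moe_layer": "moe"}

upstream = UpstreamStackSubmodules(
    **supported_kwargs(UpstreamStackSubmodules, spec_kwargs)
)
hybrid = HybridStackSubmodulesSketch(
    **supported_kwargs(HybridStackSubmodulesSketch, spec_kwargs)
)
```

Filtering silently drops `moe_layer` for the upstream stand-in, so a real implementation would probably want to log or raise when a dropped field is actually required by the model.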

Co-authored-by: Cursor <cursoragent@cursor.com>
…paths

Set mock_data: true, train_data_path: null, and load: null across all
zebra_llama pretrain configs so they can run without depending on
/home/vanbhati@amd.com data or checkpoint paths.
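A check along these lines could catch user-specific paths before a config is committed. It is a minimal sketch that assumes the config has already been loaded into a plain nested dict. The helper `find_user_paths`, the `/home/` prefix heuristic, and the sample path are hypothetical.

```python
from typing import Any, Dict, List


def find_user_paths(cfg: Dict[str, Any], prefix: str = "/home/") -> List[str]:
    """Return dotted keys whose string values start with a user-specific prefix."""
    hits: List[str] = []

    def walk(node: Any, path: str) -> None:
        if isinstance(node, dict):
            for key, value in node.items():
                walk(value, f"{path}.{key}" if path else str(key))
        elif isinstance(node, list):
            for index, value in enumerate(node):
                walk(value, f"{path}[{index}]")
        elif isinstance(node, str) and node.startswith(prefix):
            hits.append(path)

    walk(cfg, "")
    return hits


# The portable settings described in the commit message above.
portable = {"mock_data": True, "train_data_path": None, "load": None}

# A hypothetical non-portable config for contrast.
non_portable = {
    "mock_data": False,
    "train_data_path": "/home/someuser/data/corpus",
    "load": None,
}
```

Calling `find_user_paths(non_portable)` flags `train_data_path`, while the portable config produces an empty list.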

Co-authored-by: Cursor <cursoragent@cursor.com>
Hybrid models have parameters that don't all participate in backward,
causing Megatron's DDP buffer per_param_grad_ready_counts assertion
to fail when overlap_grad_reduce is enabled. Set overlap_grad_reduce,
overlap_param_gather, and gradient_accumulation_fusion to false.
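The three flags can be forced off in one place, as in the sketch below. The flag names come from the commit message above. The flat-dict config layout and the helpers `apply_hybrid_overrides` and `still_enabled` are assumptions for illustration, not part of Primus or Megatron-LM.

```python
from typing import Any, Dict, List

# Flags the commit message says must be disabled for hybrid models.
HYBRID_SAFE_OVERRIDES: Dict[str, bool] = {
    "overlap_grad_reduce": False,
    "overlap_param_gather": False,
    "gradient_accumulation_fusion": False,
}


def apply_hybrid_overrides(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of cfg with the overlap and fusion flags forced off."""
    merged = dict(cfg)  # shallow copy so the caller's dict is not mutated
    merged.update(HYBRID_SAFE_OVERRIDES)
    return merged


def still_enabled(cfg: Dict[str, Any]) -> List[str]:
    """List the flags from HYBRID_SAFE_OVERRIDES that are still truthy in cfg."""
    return sorted(k for k in HYBRID_SAFE_OVERRIDES if cfg.get(k))


# Hypothetical starting config with one of the flags enabled.
base_cfg = {"overlap_grad_reduce": True, "train_iters": 50}
safe_cfg = apply_hybrid_overrides(base_cfg)
```

A check such as `still_enabled(cfg)` could run at startup so that a hybrid model with one of these flags left on fails with a clear message rather than with the DDP buffer assertion.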

Co-authored-by: Cursor <cursoragent@cursor.com>
clairesonglee force-pushed the clairlee/kda-optimized-training-patch branch from ddc56fb to a6cc909 on June 8, 2026 08:51
Comment thread tools/megatron_forward_zebra_llama.py Fixed
… overlap flags

- Set train_iters to 50 across all 12 zebra_llama configs for quick perf benchmarking
- 300M configs: reduce batch sizes (mbs=8, gbs=64) for 8-GPU runs
- 1B GDN Pure / KDA Pure: restore original mbs=16, gbs=128
- 1B GDN Pure 100B: restore original mbs=64, gbs=512
- 300M configs: disable overlap_grad_reduce, overlap_param_gather, gradient_accumulation_fusion
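The batch sizes above determine how many micro-batches each rank accumulates per optimizer step: the global batch size equals the micro-batch size times the data-parallel size times the number of micro-batches. The sketch below works this out for the three settings listed, assuming pure data parallelism over 8 GPUs with no tensor or pipeline parallelism. That assumption is stated for the 300M runs but is only a guess for the 1B configs. The helper name `num_microbatches` is hypothetical.

```python
def num_microbatches(
    global_batch_size: int, micro_batch_size: int, data_parallel_size: int
) -> int:
    """Micro-batches accumulated per rank for each optimizer step."""
    samples_per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError(
            f"gbs={global_batch_size} is not divisible by "
            f"mbs * dp = {samples_per_pass}"
        )
    return global_batch_size // samples_per_pass


# (mbs, gbs) pairs from the commit message above.
configs = {
    "300M": (8, 64),
    "1B GDN Pure / KDA Pure": (16, 128),
    "1B GDN Pure 100B": (64, 512),
}

DP_SIZE = 8  # assumption: 8 GPUs, data parallel only

accum_steps = {
    name: num_microbatches(gbs, mbs, DP_SIZE)
    for name, (mbs, gbs) in configs.items()
}
```

Under that assumption every listed setting gives one micro-batch per step, so no gradient accumulation takes place.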

Co-authored-by: Cursor <cursoragent@cursor.com>
clairesonglee force-pushed the clairlee/kda-optimized-training-patch branch from a6cc909 to c791b85 on June 8, 2026 14:03
    fh.write("sys.path[:6]:\n")
    for p in sys.path[:6]:
        fh.write(f" {p}\n")
except Exception:


3 participants