feat(megatron): Add support for Gated Delta Net (GDN) & Kimi Delta Attention (KDA)#676
Draft
clairesonglee wants to merge 30 commits into
Draft
feat(megatron): Add support for Gated Delta Net (GDN) & Kimi Delta Attention (KDA)#676clairesonglee wants to merge 30 commits into
clairesonglee wants to merge 30 commits into
Conversation
- 300M pure GDN pretrain config aligned with FLA (micro_batch=128, global=1024) - 300M model architecture config (hidden=1024, layers=24, ffn=4096) - use_short_conv in language_model.yaml - Updated KDA pretrain config Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com> # Conflicts: # examples/megatron/configs/MI300X/mamba_370M-pretrain.yaml # examples/megatron/configs/MI300X/zebra_llama_1B-pretrain.yaml # examples/megatron/configs/MI300X/zebra_llama_3B-pretrain.yaml # examples/megatron/configs/MI300X/zebra_llama_8B-pretrain.yaml # primus/backends/megatron/core/models/hybrid/__init__.py # primus/backends/megatron/core/models/hybrid/hybrid_block.py # primus/backends/megatron/core/models/hybrid/hybrid_mamba_mla_layer_specs.py # primus/configs/models/megatron/language_model.yaml # primus/configs/models/megatron/mamba_370M.yaml # primus/configs/models/megatron/mamba_base.yaml # primus/configs/models/megatron/zebra_llama_1B.yaml # primus/configs/models/megatron/zebra_llama_3B.yaml # primus/configs/models/megatron/zebra_llama_8B.yaml # primus/core/utils/import_utils.py # primus/modules/trainer/lightmegatron/pre_trainer.py # primus/modules/trainer/megatron/pre_trainer.py # primus/modules/trainer/megatron/trainer.py
- Introduced `bash-docker.sh` for running the Primus environment with necessary device and network configurations. - Added `GDN_FLA_PARITY.md` to document changes for achieving GDN training parity with Flash Linear Attention (FLA). - Created `README_GDN.md` for a comprehensive guide on running the 300M pure GDN pretraining recipe in Primus, including setup and performance metrics. - Updated `README.md` to include details about Zebra-Llama models and their configurations. - Enhanced Megatron patch script to apply necessary modifications for GDN training. - Adjusted training configurations to ensure alignment with FLA performance metrics.
- Introduced `KDA_FLA_PARITY.md` to document changes for achieving Kimi Delta Attention (KDA) training parity with Flash Linear Attention (FLA). - Created `README_KDA.md` for a comprehensive guide on running the 300M pure KDA pretraining recipe in Primus, including setup and performance metrics. - Added new training configuration file `zebra_llama_300M_kda_pure-pretrain.yaml` aligned with FLA specifications. - Enhanced Megatron patch script to apply necessary modifications for KDA training. - Updated model code and runtime configurations to ensure alignment with FLA performance metrics, including fused operations and normalization techniques. - Implemented diagnostic features for capturing activations and iter-1 batch tokens for comparison with FLA outputs.
- Introduced `KDA_FLA_PARITY.md` to document changes for achieving Kimi Delta Attention (KDA) training parity with Flash Linear Attention (FLA). - Created `README_KDA.md` for a comprehensive guide on running the 300M pure KDA pretraining recipe in Primus, including setup and performance metrics. - Added new training configuration file `zebra_llama_300M_kda_pure-pretrain.yaml` aligned with FLA specifications. - Enhanced Megatron patch script to apply necessary modifications for KDA training. - Updated model code and runtime configurations to ensure alignment with FLA performance metrics, including optimizations for fused operations and memory management.
774acb4 to
c301288
Compare
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| sys.exit(main()) |
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| sys.exit(main()) |
| with open(marker, "w") as fh: | ||
| fh.write(_msg + "\n") | ||
| fh.write(f"pid={os.getpid()} ts={time.time()}\n") | ||
| except Exception: |
| eps=self.config.layernorm_epsilon, | ||
| ) | ||
| self._use_fla_fused_gated_norm = True | ||
| except ImportError: |
| # `return_dict` for any code paths that consult config defaults. | ||
| try: | ||
| hf_model.config.return_dict = True | ||
| except Exception: |
| for h in mg_hooks: | ||
| try: | ||
| h.remove() | ||
| except Exception: |
| for h in hf_hooks: | ||
| try: | ||
| h.remove() | ||
| except Exception: |
| for h in mg_io_hooks: | ||
| try: | ||
| h.remove() | ||
| except Exception: |
| for h in hf_io_hooks: | ||
| try: | ||
| h.remove() | ||
| except Exception: |
The upstream Megatron-LM MambaStackSubmodules dataclass does not have a moe_layer field. Only HybridStackSubmodules (Primus) supports it. This caused a RuntimeError at import time when the spec module was loaded for any hybrid model training. Co-authored-by: Cursor <cursoragent@cursor.com>
…paths Set mock_data: true, train_data_path: null, and load: null across all zebra_llama pretrain configs so they can run without depending on /home/vanbhati@amd.com data or checkpoint paths. Co-authored-by: Cursor <cursoragent@cursor.com>
Hybrid models have parameters that don't all participate in backward, causing Megatron's DDP buffer per_param_grad_ready_counts assertion to fail when overlap_grad_reduce is enabled. Set overlap_grad_reduce, overlap_param_gather, and gradient_accumulation_fusion to false. Co-authored-by: Cursor <cursoragent@cursor.com>
ddc56fb to
a6cc909
Compare
… overlap flags - Set train_iters to 50 across all 12 zebra_llama configs for quick perf benchmarking - 300M configs: reduce batch sizes (mbs=8, gbs=64) for 8-GPU runs - 1B GDN Pure / KDA Pure: restore original mbs=16, gbs=128 - 1B GDN Pure 100B: restore original mbs=64, gbs=512 - 300M configs: disable overlap_grad_reduce, overlap_param_gather, gradient_accumulation_fusion Co-authored-by: Cursor <cursoragent@cursor.com>
a6cc909 to
c791b85
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.