[Ascend] support qwen35 mtp on Ascend-A3#337
Open
wanfengcxz wants to merge 31 commits into
Open
Conversation
- ascend_cudagraph.py: multi-token decode graph mode support (4-tuple graph key with query_len, actual_seq_lengths_q buffers) - device/__init__.py: add patch_attention_is_tp (draft model TP), patch_ray_init (NPU Ray resource), MTP multi-token paths in GatedDelta conv1d and sigmoid_gating update kernels Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move the Ascend-specific graph alignment, state replay, and sampling fallback into dlinfer so multi-token speculative decode stays stable without expanding lmdeploy core runtime changes. Made-with: Cursor
Snapshot only the active state-cache rows during speculative replay so Ascend no longer clones the full state pool for rejection recovery. Made-with: Cursor
There was a problem hiding this comment.
Pull request overview
This PR adds Ascend-A3 support for Qwen3.5 MTP (speculative decoding) by introducing Ascend-specific rejection sampling and ring-buffer state handling, and by extending the gated-delta / conv-state execution paths (including cudagraph buffer management) to support multi-token decode.
Changes:
- Add Ascend rejection sampling implementation (Triton + PyTorch fallback) and patch lmdeploy to use it.
- Add fused recurrent gated-delta-rule kernel and integrate ring-buffer state/conv updates for multi-token decoding.
- Update Ascend cudagraph buffer logic and MoE comm buffers to handle MTP’s increased per-step token counts.
Reviewed changes
Copilot reviewed 3 out of 10 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| dlinfer/vendor/ascend/triton_ops/reject_sample.py | New Ascend rejection sampling implementation with Triton backend and PyTorch reference fallback. |
| dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py | New fused recurrent gated-delta-rule Triton implementation for decode/prefill state updates. |
| dlinfer/vendor/ascend/triton_ops/fla/init.py | Expose fused_recurrent_gated_delta_rule from the FLA submodule. |
| dlinfer/vendor/ascend/triton_ops/causal_conv1d.py | Add ring-buffer mode support for conv state read/write and kernel update path. |
| dlinfer/vendor/ascend/triton_ops/init.py | Export new Ascend triton ops (rejection_sample, fused_recurrent_gated_delta_rule). |
| dlinfer/vendor/ascend/torch_npu_ops.py | Remove an unused import. |
| dlinfer/vendor/ascend/moe.py | Fix topk padding behavior to align with padded hidden_states handling. |
| dlinfer/framework/lmdeploy_ext/device/ascend.py | Adjust bad-words processing to avoid negative indices on Ascend gather/scatter. |
| dlinfer/framework/lmdeploy_ext/device/init.py | Patch lmdeploy rejection sampler and extend gated-delta net / Qwen3.5 builder for speculative decoding. |
| dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py | Extend cudagraph buffers and keying to support multi-token decode and DP-global gating. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| max_batches, dtype=torch.int32, device=device | ||
| ) | ||
|
|
||
| input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32) |
Comment on lines
+111
to
+113
| input_buffers["q_seqlens"] = ( | ||
| torch.arange(1, max_batches + 1, dtype=torch.int32) * max_q_seq_len | ||
| ) |
| ) | ||
|
|
||
| else: | ||
| input_buffers["q_seqlens"] = torch.arange(1, max_batches + 1, dtype=torch.int32) |
Comment on lines
+60
to
+64
| # (which are negative padding values) with 0 before gather/scatter. | ||
| valid_bad_words = bad_words.where(mask, 0) | ||
| filtered_scores = scores.gather(1, valid_bad_words) | ||
| filtered_scores = mask.to(filtered_scores.dtype) * filter_value + filtered_scores | ||
| scores.scatter_(1, bad_words, filtered_scores) | ||
| scores.scatter_(1, valid_bad_words, filtered_scores) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Support qwen35 mtp on Ascend-A3