Skip to content

[Ascend] support qwen35 mtp on Ascend-A3#337

Open
wanfengcxz wants to merge 31 commits into
DeepLink-org:mainfrom
wanfengcxz:wq/support_qwen35_mtp
Open

[Ascend] support qwen35 mtp on Ascend-A3#337
wanfengcxz wants to merge 31 commits into
DeepLink-org:mainfrom
wanfengcxz:wq/support_qwen35_mtp

Conversation

@wanfengcxz

Copy link
Copy Markdown
Collaborator

Support qwen35 mtp on Ascend-A3

tangzhiyi11 and others added 30 commits July 3, 2026 06:04
- ascend_cudagraph.py: multi-token decode graph mode support
  (4-tuple graph key with query_len, actual_seq_lengths_q buffers)
- device/__init__.py: add patch_attention_is_tp (draft model TP),
  patch_ray_init (NPU Ray resource), MTP multi-token paths in
  GatedDelta conv1d and sigmoid_gating update kernels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move the Ascend-specific graph alignment, state replay, and sampling fallback into dlinfer so multi-token speculative decode stays stable without expanding lmdeploy core runtime changes.

Made-with: Cursor
Snapshot only the active state-cache rows during speculative replay so Ascend no longer clones the full state pool for rejection recovery.

Made-with: Cursor

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds Ascend-A3 support for Qwen3.5 MTP (speculative decoding) by introducing Ascend-specific rejection sampling and ring-buffer state handling, and by extending the gated-delta / conv-state execution paths (including cudagraph buffer management) to support multi-token decode.

Changes:

  • Add Ascend rejection sampling implementation (Triton + PyTorch fallback) and patch lmdeploy to use it.
  • Add fused recurrent gated-delta-rule kernel and integrate ring-buffer state/conv updates for multi-token decoding.
  • Update Ascend cudagraph buffer logic and MoE comm buffers to handle MTP’s increased per-step token counts.

Reviewed changes

Copilot reviewed 3 out of 10 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
dlinfer/vendor/ascend/triton_ops/reject_sample.py New Ascend rejection sampling implementation with Triton backend and PyTorch reference fallback.
dlinfer/vendor/ascend/triton_ops/fla/fused_recurrent.py New fused recurrent gated-delta-rule Triton implementation for decode/prefill state updates.
dlinfer/vendor/ascend/triton_ops/fla/init.py Expose fused_recurrent_gated_delta_rule from the FLA submodule.
dlinfer/vendor/ascend/triton_ops/causal_conv1d.py Add ring-buffer mode support for conv state read/write and kernel update path.
dlinfer/vendor/ascend/triton_ops/init.py Export new Ascend triton ops (rejection_sample, fused_recurrent_gated_delta_rule).
dlinfer/vendor/ascend/torch_npu_ops.py Remove an unused import.
dlinfer/vendor/ascend/moe.py Fix topk padding behavior to align with padded hidden_states handling.
dlinfer/framework/lmdeploy_ext/device/ascend.py Adjust bad-words processing to avoid negative indices on Ascend gather/scatter.
dlinfer/framework/lmdeploy_ext/device/init.py Patch lmdeploy rejection sampler and extend gated-delta net / Qwen3.5 builder for speculative decoding.
dlinfer/framework/lmdeploy_ext/cudagraph/ascend_cudagraph.py Extend cudagraph buffers and keying to support multi-token decode and DP-global gating.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

max_batches, dtype=torch.int32, device=device
)

input_buffers["kv_seqlens"] = torch.ones(max_batches, dtype=torch.int32)
Comment on lines +111 to +113
input_buffers["q_seqlens"] = (
torch.arange(1, max_batches + 1, dtype=torch.int32) * max_q_seq_len
)
)

else:
input_buffers["q_seqlens"] = torch.arange(1, max_batches + 1, dtype=torch.int32)
Comment on lines +60 to +64
# (which are negative padding values) with 0 before gather/scatter.
valid_bad_words = bad_words.where(mask, 0)
filtered_scores = scores.gather(1, valid_bad_words)
filtered_scores = mask.to(filtered_scores.dtype) * filter_value + filtered_scores
scores.scatter_(1, bad_words, filtered_scores)
scores.scatter_(1, valid_bad_words, filtered_scores)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants