Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
434 changes: 434 additions & 0 deletions qwen3_transformer_only_quantize.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/winml/modelkit/build/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def build_hf_model(
cache_key: str | None = None,
ep: EPNameOrAlias | None = None,
device: str | None = None,
model_type: str | None = None,
**kwargs: Any,
) -> BuildResult:
"""Build an ONNX model from a HuggingFace model architecture.
Expand Down Expand Up @@ -208,6 +209,7 @@ def _name(base: str) -> str:
model_id,
trust_remote_code,
random_init=random_init,
model_type=model_type,
)

# =========================================================================
Expand Down Expand Up @@ -436,6 +438,7 @@ def _load_model(
trust_remote_code: bool,
random_init: bool = False,
hf_config: Any | None = None,
model_type: str | None = None,
) -> Any:
"""Load PyTorch model — pretrained or random weights.

Expand Down Expand Up @@ -511,6 +514,7 @@ def _load_model(
task=task,
trust_remote_code=effective_trust,
hf_config=hf_config,
model_type=model_type,
)
return pytorch_model

Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/loader/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,19 @@ def resolve_loader_config(
f"attribute. Cannot proceed with config generation."
)

# Explicit model_type override alongside a model_id: honor the requested
# type so downstream class / build-config / export resolution selects the
# variant (e.g. "qwen3_transformer_only") rather than the architecture's
# native type. The model_type-only path above (AutoConfig.for_model) is
# unaffected because it only runs when model_id is None.
if model_id is not None and model_type is not None and hf_config.model_type != model_type:
logger.info(
"Overriding resolved model_type '%s' -> '%s' (explicit request)",
hf_config.model_type,
model_type,
)
hf_config.model_type = model_type

# 2. Infer task (depends on: model_type param or hf_config.architectures)
if task is None and model_type is not None:
supported = get_supported_tasks(model_type, library_name=library_name)
Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/loader/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def load_hf_model(
user_script: str | None = None,
trust_remote_code: bool = False,
hf_config: PretrainedConfig | None = None,
model_type: str | None = None,
) -> tuple[nn.Module, PretrainedConfig, str]:
"""Load, detect task, and prepare HuggingFace model.

Expand Down Expand Up @@ -224,6 +225,18 @@ def load_hf_model(
trust_remote_code=trust_remote_code,
)

# Explicit model_type override: select a registered build variant (e.g.
# "qwen3_transformer_only") rather than the architecture's native type.
# Mutates the freshly-loaded config only; gated on an explicit request so
# normal loading is unaffected.
if model_type is not None and getattr(hf_config, "model_type", None) != model_type:
logger.info(
"Overriding model_type '%s' -> '%s' (explicit request)",
getattr(hf_config, "model_type", None),
model_type,
)
hf_config.model_type = model_type

# [2] Task & Model Class Resolution
if user_script is not None:
resolved_class = _load_class_from_script(user_script, model_class)
Expand Down
16 changes: 15 additions & 1 deletion src/winml/modelkit/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def from_pretrained(
trust_remote_code: bool = False,
shape_config: dict | None = None,
no_compile: bool = False,
model_type: str | None = None,
**kwargs: Any,
) -> WinMLPreTrainedModel:
"""Load appropriate WinML model based on task detection.
Expand Down Expand Up @@ -278,6 +279,10 @@ def from_pretrained(
shape_config: Shape overrides passed to generate_build_config().
Valid keys -- text: sequence_length; vision: height, width;
audio: feature_size, nb_max_frames, audio_sequence_length.
model_type: Explicit model_type override. When provided alongside a
HF model_id, selects a registered build variant (e.g.
``"qwen3_transformer_only"``) instead of the architecture's
native model_type. Leave ``None`` for normal auto-detection.
**kwargs: Additional arguments

Returns:
Expand Down Expand Up @@ -334,6 +339,11 @@ def from_pretrained(
else:
_model_type = None

# Explicit override wins so a variant composite (e.g.
# "qwen3_transformer_only") can be selected over the native type.
if model_type is not None:
_model_type = model_type

if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY:
from .winml.composite_model import WinMLCompositeModel

Expand Down Expand Up @@ -368,6 +378,7 @@ def from_pretrained(
trust_remote_code=trust_remote_code,
ep=kwargs.get("ep"),
no_compile=no_compile,
model_type=model_type,
)

resolved_task = build_config.loader.task
Expand Down Expand Up @@ -402,7 +413,9 @@ def from_pretrained(
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust)
model_type = getattr(hf_config, "model_type", "unknown")
# Honor an explicit model_type override; otherwise probe from the config.
if model_type is None:
model_type = getattr(hf_config, "model_type", "unknown")
logger.debug("Model type: %s, task: %s", model_type, resolved_task)

# =====================================================================
Expand Down Expand Up @@ -431,6 +444,7 @@ def from_pretrained(
cache_key=cache_key,
ep=resolved_ep,
device=device,
model_type=model_type,
)
onnx_path = result.final_onnx_path

Expand Down
10 changes: 10 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@
from .qwen import QWEN_CONFIG
from .qwen import QwenGenIOConfig as _QwenGenIOConfig
from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig
from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING
from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG
from .qwen_transformer_only import (
QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration
)
from .qwen_transformer_only import (
QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, # triggers registration
)
from .roberta import ROBERTA_FAMILY_CONFIG
from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration
from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING
Expand Down Expand Up @@ -92,6 +100,7 @@
**_MARIAN_CLASS_MAPPING,
**_MU2_CLASS_MAPPING,
**_QWEN_CLASS_MAPPING,
**_QWEN_TO_CLASS_MAPPING,
**_SAM2_CLASS_MAPPING,
**_SEGFORMER_CLASS_MAPPING,
**_SIGLIP_CLASS_MAPPING,
Expand All @@ -115,6 +124,7 @@
"roberta": ROBERTA_FAMILY_CONFIG,
"mu2": MU2_CONFIG,
"qwen3": QWEN_CONFIG,
"qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG,
"siglip": SIGLIP_CONFIG,
"siglip-text-model": SIGLIP_CONFIG,
"siglip-vision-model": SIGLIP_CONFIG,
Expand Down
154 changes: 154 additions & 0 deletions src/winml/modelkit/models/hf/qwen3_export_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Custom ONNX export ops and module replacements that reshape HF's Qwen3
modules for the transformer-only export.

These reshape the standard HF Qwen3 modules so winml-cli can produce a
QNN-friendly, transformer-only graph:

- ``LpNormalization`` replaces the eager RMSNorm Mul/Pow/ReduceMean chain.
- ``com.microsoft::GroupQueryAttention`` replaces the eager QKV MatMul +
Softmax + KV-update path (with built-in rotary).
- 1x1 ``Conv`` (NHWC<->NCHW) replaces ``nn.Linear`` for QNN-friendly
projections.

Everything here operates only on the standard ``transformers.models.qwen3``
module attributes.
"""

from __future__ import annotations

import torch
import torch.nn as nn
from torch.onnx import symbolic_helper


# =============================================================================
# Custom ONNX symbolic functions
# =============================================================================


class LpNormOnnxExport(torch.autograd.Function):
    """Export shim that lowers the RMSNorm body to ONNX ``LpNormalization``.

    ``symbolic`` emits the real operator while the model is being exported;
    ``axis`` and ``p`` are forwarded to it as integer attributes (the RMSNorm
    use case is p=2 along the last dim). ``forward`` is a pass-through that
    exists only so tracing sees an output with the input's shape.
    """

    @staticmethod
    def forward(ctx, input, axis, p):  # noqa: ARG004
        # Tracing stand-in, not a normalization. The TorchScript exporter (and
        # Optimum's pre-export dry run) only need an output of the right shape,
        # so the input is handed back untouched. This is deliberately NOT a
        # correct eager RMSNorm — never rely on it for real inference.
        return input

    @staticmethod
    def symbolic(g, input, axis, p):  # noqa: D401
        # The emitted node produces a tensor with the input's type and sizes.
        sizes = symbolic_helper._get_tensor_sizes(input)
        result_type = input.type().with_sizes(sizes)
        node = g.op("onnx::LpNormalization", input, axis_i=int(axis), p_i=int(p))
        return node.setType(result_type)


class GroupQueryAttentionOnnxExport(torch.autograd.Function):
    """Export shim for ``com.microsoft::GroupQueryAttention``.

    ``symbolic`` emits a single fused node covering Q/K/V attention, the
    KV-cache update and rotary embedding. ``forward`` is a shape-only
    placeholder used while tracing; it performs no attention at all.
    """

    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        past_key,
        past_value,
        seqlens_k,
        total_sequence_length,
        cos_cache,
        sin_cache,
        do_rotary,
        kv_num_heads,
        num_heads,
    ):  # noqa: ARG004
        # Tracing stand-in only. The TorchScript exporter (and Optimum's
        # pre-export dry run) need three outputs with the right shapes, so the
        # query and the *unchanged* past KV tensors are returned as stand-ins
        # for the attention output and present KV. This is NOT correct
        # attention and the cache never advances — do not use it for real
        # inference.
        return query, past_key, past_value  # placeholder shapes

    @staticmethod
    def symbolic(
        g,
        query,
        key,
        value,
        past_key,
        past_value,
        seqlens_k,
        total_sequence_length,
        cos_cache,
        sin_cache,
        do_rotary,
        kv_num_heads,
        num_heads,
    ):
        attn_out, present_k, present_v = g.op(
            "com.microsoft::GroupQueryAttention",
            query,
            key,
            value,
            past_key,
            past_value,
            seqlens_k,
            total_sequence_length,
            cos_cache,
            sin_cache,
            do_rotary_i=int(do_rotary),
            kv_num_heads_i=int(kv_num_heads),
            num_heads_i=int(num_heads),
            outputs=3,
        )

        # Each output is stamped with the type and sizes of the input it
        # mirrors: attention output <- query, present K/V <- past K/V.
        mirrored = (
            (attn_out, query),
            (present_k, past_key),
            (present_v, past_value),
        )
        for produced, reference in mirrored:
            reference_sizes = symbolic_helper._get_tensor_sizes(reference)
            produced.setType(reference.type().with_sizes(reference_sizes))
        return attn_out, present_k, present_v

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Warning: Stale KV cache in eager mode

GroupQueryAttentionOnnxExport.forward returns (query, past_key, past_value), so present_keys/present_values are the old, un-updated tensors. In eager execution this silently produces a KV cache that never advances. Raising NotImplementedError here would be safer than a silently wrong placeholder.



# =============================================================================
# 1x1 Conv replacement for nn.Linear
# =============================================================================


class TransposeConv2d1x1Transpose(nn.Module):
    """Stand-in for ``nn.Linear`` that runs the projection as a 1x1 ``Conv2d``.

    The 4-D input is permuted NHWC -> NCHW, convolved with the Linear weight
    viewed as a ``(out, in, 1, 1)`` kernel, then permuted back to NHWC. The
    optional bias is added after the final permute, so it broadcasts over the
    last (channel) dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        weight: torch.nn.Parameter,
        bias: torch.nn.Parameter | None = None,
    ) -> None:
        super().__init__()
        # A Linear weight is laid out (out, in); a 1x1 conv kernel only needs
        # two trailing singleton dims, so a view of the same data suffices.
        conv_kernel = weight.data.view(out_channels, in_channels, 1, 1)
        self.weight = nn.Parameter(conv_kernel)
        self.bias = bias

    @classmethod
    def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose:
        """Build the conv-based replacement from an existing ``nn.Linear``."""
        n_in, n_out = linear.in_features, linear.out_features
        return cls(n_in, n_out, linear.weight, linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projection to an NHWC tensor and return an NHWC tensor."""
        channels_first = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        projected = torch.nn.functional.conv2d(channels_first, self.weight)
        channels_last = projected.permute(0, 2, 3, 1)  # NCHW -> NHWC
        if self.bias is None:
            return channels_last
        return channels_last + self.bias


__all__ = [
"GroupQueryAttentionOnnxExport",
"LpNormOnnxExport",
"TransposeConv2d1x1Transpose",
]
Loading