From 3796b7eabe620c9712fb65941ac5d7f441b6e79f Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 8 Jun 2026 11:17:08 -0700 Subject: [PATCH 1/6] Add quantization for transformers for qwen0.6B --- qwen3_quantize.py | 256 ++++++++++++++++++++++++ src/winml/modelkit/onnx/__init__.py | 2 + src/winml/modelkit/onnx/qwen_surgery.py | 186 +++++++++++++++++ test_qwen 2.py | 70 +++++++ 4 files changed, 514 insertions(+) create mode 100644 qwen3_quantize.py create mode 100644 src/winml/modelkit/onnx/qwen_surgery.py create mode 100644 test_qwen 2.py diff --git a/qwen3_quantize.py b/qwen3_quantize.py new file mode 100644 index 000000000..655c65e6a --- /dev/null +++ b/qwen3_quantize.py @@ -0,0 +1,256 @@ +"""Qwen3 transformer-only quantization. + +Must be called after the composite Qwen3 model has been built (e.g. by +``test_qwen 2.py``) so that ``decoder_prefill`` / ``decoder_gen`` ONNX files +exist in the winml cache. + +Pipeline: + + 1. Apply ``make_transformer_only`` surgery to each sub-model, producing + ``*_transformer.onnx`` with ``inputs_embeds`` input and + ``output_hidden_states`` output — embeddings and lm_head are stripped + out (ignored, not quantized). + 2. Quantize those transformer-only files via winml-cli's ``quantize_onnx`` + using a calibration reader that runs ``embed_tokens`` in PyTorch on + real text samples. 
+""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Iterator + +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from winml.modelkit.models.winml.composite_model import WinMLCompositeModel +from winml.modelkit.onnx import make_transformer_only +from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx + + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL_ID = "Qwen/Qwen3-0.6B" +DEFAULT_MAX_CACHE = 256 +DEFAULT_PREFILL_SEQ = 64 +DEFAULT_GEN_SEQ = 1 +DEFAULT_NUM_SAMPLES = 16 +DEFAULT_PROMPTS = [ + "Solve: 8 * 7 = ?", + "Translate to French: The weather is nice today.", + "Write a short poem about the ocean.", + "Explain gradient descent in one paragraph.", + "What is the capital of Japan?", + "List three uses of magnesium.", + "Summarize the plot of Hamlet in two sentences.", + "Give a Python one-liner to reverse a string.", +] + + +# --------------------------------------------------------------------------- +# Calibration data reader +# --------------------------------------------------------------------------- + + +class Qwen3TransformerCalibReader: + """Yields calibration feeds for the transformer-only Qwen3 ONNX. + + Runs HF ``embed_tokens`` in PyTorch to produce ``inputs_embeds`` since the + embedding layer was stripped from the ONNX graph. All other inputs + (attention_mask, position_ids, past_{i}_key/value) follow the conventions + used by winml-cli's ``WinMLQwen3Model`` runtime. 
+ """ + + def __init__( + self, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + seq_len: int, + max_cache_len: int, + ) -> None: + self.embed = embed_tokens + self.cfg = config + self.seq_len = seq_len + self.max_cache_len = max_cache_len + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self._samples = list(self._build_samples(token_ids_list)) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _build_samples( + self, token_ids_list: list[torch.Tensor] + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + # Right-truncate / pad to seq_len so we feed the static graph shape. + ids = ids[:, : self.seq_len] + real_len = ids.shape[1] + if real_len < self.seq_len: + pad = torch.zeros( + (1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device + ) + ids = torch.cat([ids, pad], dim=1) + + with torch.no_grad(): + embeds = self.embed(ids).to(torch.float32).cpu().numpy() + + # attention_mask: ones for real prompt positions placed at the + # END of the max_cache buffer (sliding-window cache convention), + # zeros elsewhere. + attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64) + attn_mask[0, -real_len:] = 1 + + # position_ids: 0..seq_len-1 (clamped for padding). 
+ position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :] + + feed: dict[str, np.ndarray] = { + "inputs_embeds": embeds.astype(np.float32), + "attention_mask": attn_mask, + "position_ids": position_ids, + } + kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) + zeros = np.zeros(kv_shape, dtype=np.float32) + for i in range(self.num_layers): + feed[f"past_{i}_key"] = zeros + feed[f"past_{i}_value"] = zeros + yield feed + + def get_next(self) -> dict[str, np.ndarray] | None: + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + self._iter = iter(self._samples) + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def _tokenize_prompts( + tokenizer: Any, prompts: list[str], num_samples: int +) -> list[torch.Tensor]: + # Cycle through prompts up to num_samples; apply chat template like the + # runtime so calibration distribution matches inference inputs. + out: list[torch.Tensor] = [] + for i in range(num_samples): + prompt = prompts[i % len(prompts)] + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + out.append(ids) + return out + + +def quantize_built_model( + model: WinMLCompositeModel, + *, + model_id: str = DEFAULT_MODEL_ID, + max_cache_len: int = DEFAULT_MAX_CACHE, + prefill_seq: int = DEFAULT_PREFILL_SEQ, + num_samples: int = DEFAULT_NUM_SAMPLES, + weight_type: str = "uint8", + activation_type: str = "uint16", +) -> dict[str, Path]: + """Run surgery + transformer-only quantization on an already-built composite. + + Reuses the ONNX files produced by ``WinMLCompositeModel.from_pretrained`` + so this can be called after a build step without re-exporting. 
+ + Returns: mapping of sub-model name → quantized ONNX path. + """ + sub_paths: dict[str, Path] = {} + for name, sub in model.sub_models.items(): + final_path = Path(sub._onnx_path) + # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs + # the uncompiled fp16 graph. ``build.hf`` emits ``{cache_key}_optimized.onnx`` + # alongside it in the same artifacts directory. + if final_path.name.endswith("_model.onnx"): + stem = final_path.name[: -len("_model.onnx")] + optimized = final_path.with_name(f"{stem}_optimized.onnx") + if optimized.exists(): + sub_paths[name] = optimized + continue + print( + f"WARNING: {optimized.name} not found next to {final_path.name}; " + "falling back to the compiled model (surgery will likely fail)." + ) + sub_paths[name] = final_path + + for name, p in sub_paths.items(): + print(f" {name}: {p}") + + print("\n=== Loading HF embed_tokens for calibration ===") + hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + hf_model.eval() + embed_tokens = hf_model.get_input_embeddings() + tokenizer = AutoTokenizer.from_pretrained(model_id) + token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples) + + seq_by_sub = { + "decoder_prefill": prefill_seq, + "decoder_gen": DEFAULT_GEN_SEQ, + } + + quant_paths: dict[str, Path] = {} + for sub_name, fused_path in sub_paths.items(): + if sub_name not in seq_by_sub: + print(f"\n--- Skipping unknown sub-model {sub_name!r} ---") + continue + + seq_len = seq_by_sub[sub_name] + transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx") + quant_path = transformer_path.with_name( + transformer_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + ) + + print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===") + print(f" in : {fused_path}") + print(f" out: {transformer_path}") + make_transformer_only(fused_path, transformer_path) + + print(f"\n=== Quantize (transformer only): {sub_name} ===") + print(f" out: 
{quant_path}") + reader = Qwen3TransformerCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) + cfg = WinMLQuantizationConfig( + samples=num_samples, + weight_type=weight_type, # type: ignore[arg-type] + activation_type=activation_type, # type: ignore[arg-type] + calibration_method="minmax", + calibration_data=reader, + ) + result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg) + if not result.success: + print(" FAILED:") + for err in result.errors: + print(f" {err}") + raise SystemExit(1) + print( + f" ok — {result.nodes_quantized} QDQ nodes inserted in " + f"{result.total_time_seconds:.1f}s" + ) + quant_paths[sub_name] = quant_path + + print("\n=== Done ===") + return quant_paths + diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index a3bc49d51..0287a2ff7 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -19,6 +19,7 @@ from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config from .metadata import capture_metadata, restore_metadata from .persistence import cleanup_onnx, load_onnx, save_onnx +from .qwen_surgery import make_transformer_only from .shape import infer_onnx_shapes, infer_shapes from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size @@ -41,6 +42,7 @@ "is_compiled_onnx", "is_quantized_onnx", "load_onnx", + "make_transformer_only", "remove_optional_from_type_annotation", "restore_metadata", "save_onnx", diff --git a/src/winml/modelkit/onnx/qwen_surgery.py b/src/winml/modelkit/onnx/qwen_surgery.py new file mode 100644 index 000000000..cd49ee5ec --- /dev/null +++ b/src/winml/modelkit/onnx/qwen_surgery.py @@ -0,0 +1,186 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Ad-hoc ONNX surgery to turn a Qwen3 decoder ONNX into a transformer-only graph. + +Applied as a post-export surgery on the fused decoder ONNX produced by +``WinMLQwen3Model`` (``decoder_prefill.onnx`` / ``decoder_gen.onnx``). + +The resulting transformer-only ONNX has: + - ``input_ids`` graph input replaced by ``inputs_embeds`` (FLOAT, + ``[batch, seq, hidden_size]``) — the upstream embedding Gather is + removed. + - ``logits`` graph output replaced by ``output_hidden_states`` + (FLOAT, ``[batch, seq, hidden_size]``) — the final ``lm_head`` MatMul + is removed. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import onnx +from onnx import TensorProto, helper + +from .persistence import load_onnx, save_onnx + + +logger = logging.getLogger(__name__) + + +def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str: + if d.HasField("dim_value"): + return d.dim_value + return d.dim_param or "?" + + +def make_transformer_only( + model_path: str | Path, + output_path: str | Path, + *, + input_ids_name: str = "input_ids", + logits_name: str = "logits", + inputs_embeds_name: str = "inputs_embeds", + output_hidden_states_name: str = "output_hidden_states", +) -> Path: + """Strip the embedding Gather and the lm_head MatMul from a Qwen3 ONNX. + + Args: + model_path: Path to the fused decoder ONNX (logits output, input_ids input). + output_path: Destination for the transformer-only ONNX. + input_ids_name: Name of the input_ids graph input to drop. + logits_name: Name of the logits graph output to drop. + inputs_embeds_name: Display name for the new embeddings input + (used only for logging; the actual tensor keeps its existing + internal name so downstream nodes need no rewiring). + output_hidden_states_name: Display name for the new hidden-state output. + + Returns: + The output path. 
+ """ + model_path = Path(model_path) + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + model = load_onnx(model_path, load_weights=True, validate=False) + graph = model.graph + init_by_name = {init.name: init for init in graph.initializer} + + # -------------------- Embedding removal -------------------- + embed_idx = next( + (i for i, n in enumerate(graph.node) if input_ids_name in n.input), + None, + ) + if embed_idx is None: + msg = f"No node consumes graph input {input_ids_name!r}" + raise RuntimeError(msg) + + embed_node = graph.node[embed_idx] + embed_out_name = embed_node.output[0] + + embed_weight = None + for ipt in embed_node.input: + init = init_by_name.get(ipt) + if init is not None and len(init.dims) == 2: + embed_weight = init + break + if embed_weight is None: + msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}" + raise RuntimeError(msg) + hidden_size = int(embed_weight.dims[1]) + + ids_input = next(i for i in graph.input if i.name == input_ids_name) + batch_dim = _dim(ids_input.type.tensor_type.shape.dim[0]) + seq_dim = _dim(ids_input.type.tensor_type.shape.dim[1]) + + logger.info( + "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]", + embed_node.name, + embed_node.op_type, + embed_out_name, + inputs_embeds_name, + batch_dim, + seq_dim, + hidden_size, + ) + + new_embed_input = helper.make_tensor_value_info( + inputs_embeds_name, + TensorProto.FLOAT, + [batch_dim, seq_dim, hidden_size], + ) + + del graph.node[embed_idx] + graph.input.remove(ids_input) + graph.input.append(new_embed_input) + graph.initializer.remove(embed_weight) + + # Rewire any consumer of the removed embedding output to the new input. 
+ for n in graph.node: + for i, name in enumerate(n.input): + if name == embed_out_name: + n.input[i] = inputs_embeds_name + + # -------------------- lm_head removal -------------------- + lmh_idx = next( + (i for i, n in enumerate(graph.node) if logits_name in n.output), + None, + ) + if lmh_idx is None: + msg = f"No node produces graph output {logits_name!r}" + raise RuntimeError(msg) + + lmh_node = graph.node[lmh_idx] + init_names = {init.name for init in graph.initializer} + hidden_in: str | None = None + weight_in: str | None = None + for ipt in lmh_node.input: + if ipt in init_names: + weight_in = ipt + else: + hidden_in = ipt + if hidden_in is None: + msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})" + raise RuntimeError(msg) + + logger.info( + "Removing lm_head node %r (%s) — exposing %r as new output %r", + lmh_node.name, + lmh_node.op_type, + hidden_in, + output_hidden_states_name, + ) + + logits_output = next(o for o in graph.output if o.name == logits_name) + new_hidden_output = helper.make_tensor_value_info( + output_hidden_states_name, + TensorProto.FLOAT, + [batch_dim, seq_dim, hidden_size], + ) + + del graph.node[lmh_idx] + graph.output.remove(logits_output) + # Put hidden states first so it mirrors the original logits position. + graph.output.insert(0, new_hidden_output) + + # Rename the producer of ``hidden_in`` to emit the new graph output name. 
+ for n in graph.node: + for i, name in enumerate(n.output): + if name == hidden_in: + n.output[i] = output_hidden_states_name + for i, name in enumerate(n.input): + if name == hidden_in: + n.input[i] = output_hidden_states_name + + if weight_in is not None and not any(weight_in in n.input for n in graph.node): + wi = next(init for init in graph.initializer if init.name == weight_in) + graph.initializer.remove(wi) + + save_onnx(model, output_path) + logger.info("Wrote transformer-only ONNX → %s", output_path) + return output_path + + +__all__ = ["make_transformer_only"] diff --git a/test_qwen 2.py b/test_qwen 2.py new file mode 100644 index 000000000..6a52dee72 --- /dev/null +++ b/test_qwen 2.py @@ -0,0 +1,70 @@ +"""E2E test for Qwen3 decoder-only pipeline. + +Uses sub_model_kwargs to set per-component shape_config: + - decoder_prefill: max_cache_len=256, seq_len=64 + - decoder_gen: max_cache_len=256, seq_len=1 + +Set env var ``QUANTIZE=1`` to also run the MOPS-style Step 3: +transformer-only surgery + winml quantize on both sub-models +(embeddings and lm_head are stripped and not quantized). 
+""" + +import os + +from transformers import AutoTokenizer + +from winml.modelkit.config import WinMLBuildConfig +from winml.modelkit.models.winml.composite_model import WinMLCompositeModel + +model_id = "Qwen/Qwen3-0.6B" + +model = WinMLCompositeModel.from_pretrained( + model_id, + task="text-generation", + # config=WinMLBuildConfig(quant=None, compile=None), + config=WinMLBuildConfig(quant=None), + precision="fp16", + device="npu", + ep="qnn", + force_rebuild=False, + sub_model_kwargs={ + "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, + "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, + }, +) + +# Verify ONNX I/O shapes +for name, sub in model.sub_models.items(): + io = sub.io_config + shapes = dict(zip(io["input_names"], io["input_shapes"])) + print(f"\n=== {name} ===") + for k, v in shapes.items(): + print(f" {k}: {v}") + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +prompt = "8 * 7 = ?" +messages = [{"role": "user", "content": prompt}] +text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, +) +model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + +generated_ids = model.generate(**model_inputs) + +output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() +content = tokenizer.decode(output_ids, skip_special_tokens=True) +print("\nAnswer:", content) + +if os.environ.get("QUANTIZE") == "1": + # Reuse the already-built decoder_prefill/decoder_gen ONNX files: + # surgery (strip embed + lm_head) + transformer-only quantize. 
+ print("\n=== QUANTIZE=1 — running transformer-only quantization ===") + from qwen3_quantize import quantize_built_model + + quantize_built_model( + model, + model_id=model_id, + max_cache_len=256, + prefill_seq=64, + ) From 1ee316c8350d9a904e5e06a51dacb1a7186658d0 Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 16 Jun 2026 15:07:46 -0700 Subject: [PATCH 2/6] Quantize transformer-only with fused GQA + GSM8k calibration --- ...e.py => qwen3_transformer_only_quantize.py | 152 ++++---- .../modelkit/models/hf/qwen3_export_ops.py | 211 +++++++++++ .../modelkit/models/hf/qwen3_modeling.py | 237 ++++++++++++ .../models/hf/qwen_transformer_only.py | 354 ++++++++++++++++++ src/winml/modelkit/onnx/__init__.py | 2 - src/winml/modelkit/onnx/qwen_surgery.py | 186 --------- test_qwen 2.py | 70 ---- test_qwen.py | 235 ++++++++++++ 8 files changed, 1100 insertions(+), 347 deletions(-) rename qwen3_quantize.py => qwen3_transformer_only_quantize.py (54%) create mode 100644 src/winml/modelkit/models/hf/qwen3_export_ops.py create mode 100644 src/winml/modelkit/models/hf/qwen3_modeling.py create mode 100644 src/winml/modelkit/models/hf/qwen_transformer_only.py delete mode 100644 src/winml/modelkit/onnx/qwen_surgery.py delete mode 100644 test_qwen 2.py create mode 100644 test_qwen.py diff --git a/qwen3_quantize.py b/qwen3_transformer_only_quantize.py similarity index 54% rename from qwen3_quantize.py rename to qwen3_transformer_only_quantize.py index 655c65e6a..8b4efa9b7 100644 --- a/qwen3_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -1,18 +1,15 @@ -"""Qwen3 transformer-only quantization. +"""Transformer-only w8a16 quantization for Qwen3. -Must be called after the composite Qwen3 model has been built (e.g. by -``test_qwen 2.py``) so that ``decoder_prefill`` / ``decoder_gen`` ONNX files -exist in the winml cache. 
+Targets the transformer-only ONNX produced by +``qwen_transformer_only.install() + test_qwen.py``: -Pipeline: + - **No embedding/lm_head surgery.** The export already excludes both, + so we feed ``WinMLQuantization`` the file directly. + - **Transformer-shaped calibration feeds.** ``input_hidden_states`` (FP32), + ``past_seq_len`` / ``total_seq_len`` (INT32), ``past_keys_{i}`` / + ``past_values_{i}`` (FP16) — names + dtypes match the exported graph. - 1. Apply ``make_transformer_only`` surgery to each sub-model, producing - ``*_transformer.onnx`` with ``inputs_embeds`` input and - ``output_hidden_states`` output — embeddings and lm_head are stripped - out (ignored, not quantized). - 2. Quantize those transformer-only files via winml-cli's ``quantize_onnx`` - using a calibration reader that runs ``embed_tokens`` in PyTorch on - real text samples. +Run via ``test_qwen.py``. """ from __future__ import annotations @@ -26,7 +23,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from winml.modelkit.models.winml.composite_model import WinMLCompositeModel -from winml.modelkit.onnx import make_transformer_only from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx @@ -36,31 +32,28 @@ DEFAULT_MAX_CACHE = 256 DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 -DEFAULT_NUM_SAMPLES = 16 -DEFAULT_PROMPTS = [ - "Solve: 8 * 7 = ?", - "Translate to French: The weather is nice today.", - "Write a short poem about the ocean.", - "Explain gradient descent in one paragraph.", - "What is the capital of Japan?", - "List three uses of magnesium.", - "Summarize the plot of Hamlet in two sentences.", - "Give a Python one-liner to reverse a string.", -] - - -# --------------------------------------------------------------------------- -# Calibration data reader -# --------------------------------------------------------------------------- - - -class Qwen3TransformerCalibReader: - """Yields calibration feeds for the transformer-only Qwen3 ONNX. 
- - Runs HF ``embed_tokens`` in PyTorch to produce ``inputs_embeds`` since the - embedding layer was stripped from the ONNX graph. All other inputs - (attention_mask, position_ids, past_{i}_key/value) follow the conventions - used by winml-cli's ``WinMLQwen3Model`` runtime. +DEFAULT_NUM_SAMPLES = 30 +DEFAULT_CALIB_DATASET = "openai/gsm8k" +DEFAULT_CALIB_DATASET_CONFIG = "main" +DEFAULT_CALIB_SPLIT = "train" +DEFAULT_CALIB_SEED = 42 + + +def _load_gsm8k_prompts(num_samples: int) -> list[str]: + """GSM8K train split, shuffled seed=42 for reproducible calibration.""" + from datasets import load_dataset + + ds = load_dataset(DEFAULT_CALIB_DATASET, DEFAULT_CALIB_DATASET_CONFIG) + split = ds[DEFAULT_CALIB_SPLIT].shuffle(seed=DEFAULT_CALIB_SEED) + return [row["question"] for row in split.select(range(num_samples))] + + +class Qwen3TransformerOnlyCalibReader: + """Yields calibration feeds for the transformer-only ONNX. + + Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), + ``past_seq_len`` (INT32 ``[1,1]``), ``total_seq_len`` (INT32 ``[1]``), + and ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). """ def __init__( @@ -73,7 +66,6 @@ def __init__( max_cache_len: int, ) -> None: self.embed = embed_tokens - self.cfg = config self.seq_len = seq_len self.max_cache_len = max_cache_len self.num_layers = config.num_hidden_layers @@ -85,11 +77,8 @@ def __init__( self._iter: Iterator[dict[str, np.ndarray]] | None = None self.rewind() - def _build_samples( - self, token_ids_list: list[torch.Tensor] - ) -> Iterator[dict[str, np.ndarray]]: + def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[str, np.ndarray]]: for ids in token_ids_list: - # Right-truncate / pad to seq_len so we feed the static graph shape. 
ids = ids[:, : self.seq_len] real_len = ids.shape[1] if real_len < self.seq_len: @@ -101,25 +90,22 @@ def _build_samples( with torch.no_grad(): embeds = self.embed(ids).to(torch.float32).cpu().numpy() - # attention_mask: ones for real prompt positions placed at the - # END of the max_cache buffer (sliding-window cache convention), - # zeros elsewhere. - attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64) - attn_mask[0, -real_len:] = 1 - - # position_ids: 0..seq_len-1 (clamped for padding). - position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :] - feed: dict[str, np.ndarray] = { - "inputs_embeds": embeds.astype(np.float32), - "attention_mask": attn_mask, - "position_ids": position_ids, + "input_hidden_states": embeds.astype(np.float32), + # seqlens_k for GQA = (valid context length - 1), i.e. + # ``embeddings.shape[1] - 1``. We pad to seq_len, so the query + # has seq_len valid positions → past_seq_len = seq_len - 1. + # (Using 0 here declares only 1 valid token while feeding a + # seq_len-token query, which makes the GQA prefill kernel read + # out of bounds → native access violation.) 
+ "past_seq_len": np.array([[self.seq_len - 1]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), } kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) - zeros = np.zeros(kv_shape, dtype=np.float32) + zeros = np.zeros(kv_shape, dtype=np.float16) for i in range(self.num_layers): - feed[f"past_{i}_key"] = zeros - feed[f"past_{i}_value"] = zeros + feed[f"past_keys_{i}"] = zeros + feed[f"past_values_{i}"] = zeros yield feed def get_next(self) -> dict[str, np.ndarray] | None: @@ -132,16 +118,7 @@ def rewind(self) -> None: self._iter = iter(self._samples) -# --------------------------------------------------------------------------- -# Pipeline -# --------------------------------------------------------------------------- - - -def _tokenize_prompts( - tokenizer: Any, prompts: list[str], num_samples: int -) -> list[torch.Tensor]: - # Cycle through prompts up to num_samples; apply chat template like the - # runtime so calibration distribution matches inference inputs. +def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: out: list[torch.Tensor] = [] for i in range(num_samples): prompt = prompts[i % len(prompts)] @@ -166,19 +143,15 @@ def quantize_built_model( weight_type: str = "uint8", activation_type: str = "uint16", ) -> dict[str, Path]: - """Run surgery + transformer-only quantization on an already-built composite. - - Reuses the ONNX files produced by ``WinMLCompositeModel.from_pretrained`` - so this can be called after a build step without re-exporting. + """Quantize the transformer-only ONNX files in-place. - Returns: mapping of sub-model name → quantized ONNX path. + Returns ``{sub_model_name: quantized_path}``. """ + # Locate the un-compiled ONNX for each sub-model (no surgery — file is + # already transformer-only). 
sub_paths: dict[str, Path] = {} for name, sub in model.sub_models.items(): final_path = Path(sub._onnx_path) - # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs - # the uncompiled fp16 graph. ``build.hf`` emits ``{cache_key}_optimized.onnx`` - # alongside it in the same artifacts directory. if final_path.name.endswith("_model.onnx"): stem = final_path.name[: -len("_model.onnx")] optimized = final_path.with_name(f"{stem}_optimized.onnx") @@ -187,7 +160,7 @@ def quantize_built_model( continue print( f"WARNING: {optimized.name} not found next to {final_path.name}; " - "falling back to the compiled model (surgery will likely fail)." + "falling back to the compiled model." ) sub_paths[name] = final_path @@ -199,7 +172,14 @@ def quantize_built_model( hf_model.eval() embed_tokens = hf_model.get_input_embeddings() tokenizer = AutoTokenizer.from_pretrained(model_id) - token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples) + + print( + f"=== Loading {num_samples} GSM8K calibration prompts " + f"({DEFAULT_CALIB_DATASET}/{DEFAULT_CALIB_DATASET_CONFIG}, " + f"split={DEFAULT_CALIB_SPLIT}, seed={DEFAULT_CALIB_SEED}) ===" + ) + prompts = _load_gsm8k_prompts(num_samples) + token_ids_list = _tokenize_prompts(tokenizer, prompts, num_samples) seq_by_sub = { "decoder_prefill": prefill_seq, @@ -213,19 +193,14 @@ def quantize_built_model( continue seq_len = seq_by_sub[sub_name] - transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx") - quant_path = transformer_path.with_name( - transformer_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + quant_path = fused_path.with_name( + fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" ) - print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===") + print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") print(f" in : {fused_path}") - print(f" out: {transformer_path}") - make_transformer_only(fused_path, 
transformer_path) - - print(f"\n=== Quantize (transformer only): {sub_name} ===") print(f" out: {quant_path}") - reader = Qwen3TransformerCalibReader( + reader = Qwen3TransformerOnlyCalibReader( embed_tokens, hf_model.config, token_ids_list, @@ -239,7 +214,7 @@ def quantize_built_model( calibration_method="minmax", calibration_data=reader, ) - result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg) + result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) if not result.success: print(" FAILED:") for err in result.errors: @@ -253,4 +228,3 @@ def quantize_built_model( print("\n=== Done ===") return quant_paths - diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py new file mode 100644 index 000000000..61d45f0ef --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Custom ONNX export ops + the entry point that reshapes HF's Qwen3 modules +for the transformer-only export. + +These reshape the standard HF Qwen3 modules so winml-cli can produce a +QNN-friendly, transformer-only graph: + +- ``LpNormalization`` replaces the eager RMSNorm Mul/Pow/ReduceMean chain. +- ``com.microsoft::GroupQueryAttention`` replaces the eager QKV MatMul + + Softmax + KV-update path (with built-in rotary). +- 1x1 ``Conv`` (NHWC<->NCHW) replaces ``nn.Linear`` for QNN-friendly + projections. + +Everything here operates only on the standard ``transformers.models.qwen3`` +module attributes. 
+""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.onnx import symbolic_helper + + +# ============================================================================= +# Custom ONNX symbolic functions +# ============================================================================= + + +class LpNormOnnxExport(torch.autograd.Function): + """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" + + @staticmethod + def symbolic(g, input, axis, p): # noqa: D401 + output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) + output = g.op( + "onnx::LpNormalization", + input, + axis_i=int(axis), + p_i=int(p), + ) + return output.setType(output_type) + + @staticmethod + def forward(ctx, input, axis, p): # noqa: ARG004 + return input # placeholder — real compute happens in symbolic + + +class GroupQueryAttentionOnnxExport(torch.autograd.Function): + """Fused Q/K/V + KV-cache + rotary → ``com.microsoft::GroupQueryAttention``.""" + + @staticmethod + def symbolic( + g, + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ): + args = [query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache] + attention_output, present_keys, present_values = g.op( + "com.microsoft::GroupQueryAttention", + *args, + do_rotary_i=int(do_rotary), + kv_num_heads_i=int(kv_num_heads), + num_heads_i=int(num_heads), + outputs=3, + ) + + query_sizes = symbolic_helper._get_tensor_sizes(query) + attention_output.setType(query.type().with_sizes(query_sizes)) + present_keys.setType(past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key))) + present_values.setType(past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value))) + return attention_output, present_keys, present_values + + @staticmethod + def forward( + ctx, + query, + key, + value, + past_key, + past_value, 
def apply_transformer_only_export_prep(causal_lm: nn.Module, *, matmul_to_conv: bool = True) -> None:
    """Mutate ``Qwen3ForCausalLM`` in-place into the export topology.

    Binds the winml-owned export behaviour from :mod:`.qwen3_modeling` onto each
    Qwen3 submodule (runs ``prepare_for_onnx_export`` and rebinds ``forward``).
    After this call, ``causal_lm.model(inputs_embeds, past_key_values,
    past_seq_len, total_seq_len)`` runs the transformer-only forward.

    The prep is destructive (RMSNorm weights are rescaled and ``nn.Linear``
    projections are replaced), so it must run at most once per model. A repeat
    call with the same ``matmul_to_conv`` is a no-op; a repeat call with a
    different value raises instead of applying a second, corrupting pass.

    Args:
        causal_lm: A ``transformers.Qwen3ForCausalLM`` instance.
        matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so
            QNN sees them as Conv.

    Raises:
        RuntimeError: If ``causal_lm`` was already prepared with a different
            ``matmul_to_conv`` setting.
    """
    # Idempotence guard. Checked before anything else so that a repeat call
    # never rescales the RMSNorm weights a second time.
    marker = "_winml_export_prep_matmul_to_conv"
    previous = getattr(causal_lm, marker, None)
    if previous is not None:
        if previous != matmul_to_conv:
            msg = (
                "Model was already prepared for transformer-only export with "
                f"matmul_to_conv={previous}; cannot re-prepare with matmul_to_conv={matmul_to_conv}."
            )
            raise RuntimeError(msg)
        return

    from .qwen3_modeling import (
        WinMLQwen3Attention,
        WinMLQwen3DecoderLayer,
        WinMLQwen3MLP,
        WinMLQwen3Model,
        WinMLQwen3RMSNorm,
    )

    def _bind(module: nn.Module, owner: type) -> None:
        # Rebind ``owner.forward`` as a bound method of the live module.
        module.forward = owner.forward.__get__(module, type(module))

    # Identify Qwen3 submodules by their (stock HF) class name so we don't
    # depend on importing ``transformers.models.qwen3`` here.
    def _is(module: nn.Module, name: str) -> bool:
        return type(module).__name__ == name

    # Snapshot the module tree once: the passes below replace and attach
    # submodules, and mutating the tree while walking the live ``modules()``
    # generator is fragile. None of the module kinds matched below is itself
    # replaced, so the snapshot stays valid for every pass.
    modules = list(causal_lm.modules())

    # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm,
    # in input/post_attention layernorms).
    for mod in modules:
        if _is(mod, "Qwen3RMSNorm"):
            WinMLQwen3RMSNorm.prepare_for_onnx_export(mod)
            _bind(mod, WinMLQwen3RMSNorm)

    for mod in modules:
        if _is(mod, "Qwen3Attention"):
            WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv)
            _bind(mod, WinMLQwen3Attention)
        elif _is(mod, "Qwen3MLP"):
            # MLP forward is unchanged; only the projections are swapped to Conv.
            WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv)

    # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``;
    # the export forward invokes ``self.rotary_emb`` on the attention module,
    # so re-attach a reference from the parent model.
    for mod in modules:
        if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"):
            for layer in mod.layers:
                layer.self_attn.rotary_emb = mod.rotary_emb

    for mod in modules:
        if _is(mod, "Qwen3DecoderLayer"):
            _bind(mod, WinMLQwen3DecoderLayer)

    for mod in modules:
        if _is(mod, "Qwen3Model"):
            WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv)
            _bind(mod, WinMLQwen3Model)

    # Record success only after every pass completed.
    setattr(causal_lm, marker, matmul_to_conv)
class WinMLQwen3Attention(nn.Module):
    """Attention export variant — fused ``GroupQueryAttention`` op.

    The methods here are run against a live Qwen3 attention module by the
    export prep (``prepare_for_onnx_export`` is called on it and ``forward``
    is bound onto it), so ``self`` is expected to carry ``q_proj``/``k_proj``/
    ``v_proj``/``o_proj``, ``q_norm``/``k_norm``, ``head_dim``, ``config`` and
    ``rotary_emb``.
    """

    def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None:
        """Optionally swap the four projections to 1x1 Conv and record the layout.

        Args:
            matmul_to_conv: When true, replace ``q_proj``/``k_proj``/``v_proj``/
                ``o_proj`` with ``TransposeConv2d1x1Transpose`` equivalents.
        """
        if matmul_to_conv:
            self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj)
            self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj)
            self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj)
            self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj)
        self._matmul_to_conv = matmul_to_conv  # noqa: SLF001

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None,
        past_seq_len: torch.Tensor | None = None,
        total_seq_len: torch.Tensor | None = None,
        **kwargs: Any,  # noqa: ARG002
    ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]:
        """Project Q/K/V, apply q/k norms and run the fused GQA op.

        Args:
            hidden_states: Residual-stream activations; carries an extra
                leading dim on the Conv path (see ``WinMLQwen3Model.forward``).
            past_key_value: ``(past_keys, past_values)`` cache pair. Required
                despite the ``None`` default.
            past_seq_len: Past sequence length (tensor, or a Python int that is
                converted to an int32 tensor). Required.
            total_seq_len: Total sequence length tensor. Required.
            **kwargs: Ignored; accepted for call-site compatibility.

        Returns:
            ``(attention_output, None, (present_keys, present_values))``.

        Raises:
            ValueError: If any of the KV-cache / seq-len inputs is ``None``.
        """
        # Fail fast with a clear message: the defaults exist only to mirror
        # the keyword-style call sites, the export graph always needs these.
        if past_key_value is None or past_seq_len is None or total_seq_len is None:
            msg = (
                "WinMLQwen3Attention.forward requires past_key_value, past_seq_len "
                "and total_seq_len; the transformer-only export always threads the KV cache."
            )
            raise ValueError(msg)
        past_keys, past_values = past_key_value

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split the channel dim into (heads, head_dim) for the per-head norms.
        input_shape = hidden_states.shape[1:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = self.q_norm(query_states.view(hidden_shape))
        key_states = self.k_norm(key_states.view(hidden_shape))

        # Flatten heads back into the packed layout passed to the fused op.
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        query_dim = num_heads * self.head_dim
        key_dim = num_kv_heads * self.head_dim
        query_states = query_states.reshape(1, -1, query_dim)
        key_states = key_states.reshape(1, -1, key_dim)

        if self._matmul_to_conv:
            # Drop the extra leading dim added for the Conv (NHWC) path.
            value_states = value_states.squeeze(0)

        # GroupQueryAttention requires Q/K/V/past_K/past_V to share dtype.
        # The KV cache is FP16, so cast Q/K/V to the same dtype; otherwise ORT
        # type inference rejects the node.
        kv_dtype = past_keys.dtype
        if query_states.dtype != kv_dtype:
            query_states = query_states.to(kv_dtype)
            key_states = key_states.to(kv_dtype)
            value_states = value_states.to(kv_dtype)

        # Build the position ids on the activations' device; a CPU-only arange
        # mismatches the rotary buffers when the model lives elsewhere.
        position_ids = torch.arange(
            self.config.max_position_embeddings,
            device=value_states.device,
        ).unsqueeze(0)
        cos, sin = self.rotary_emb(value_states, position_ids)
        # Keep only the first half of the last dim of each table.
        cos = cos.squeeze(0)[:, : cos.shape[-1] // 2]
        sin = sin.squeeze(0)[:, : sin.shape[-1] // 2]
        if cos.dtype != kv_dtype:
            cos = cos.to(kv_dtype)
            sin = sin.to(kv_dtype)

        if isinstance(past_seq_len, int):
            # The graph input is declared INT32; a bare torch.tensor(int)
            # would produce int64.
            past_seq_len = torch.tensor(past_seq_len, dtype=torch.int32, device=value_states.device)
        past_seq_len = torch.atleast_2d(past_seq_len)

        attention_output, present_keys, present_values = GroupQueryAttentionOnnxExport.apply(
            query_states,
            key_states,
            value_states,
            past_keys,
            past_values,
            past_seq_len,
            total_seq_len,
            cos,
            sin,
            1,  # do_rotary
            num_kv_heads,
            num_heads,
        )

        # Cast back to the residual-stream dtype so the downstream Conv
        # (o_proj) sees its expected weight dtype.
        if attention_output.dtype != hidden_states.dtype:
            attention_output = attention_output.to(hidden_states.dtype)

        if self._matmul_to_conv:
            attention_output = attention_output.unsqueeze(0)

        attention_output = self.o_proj(attention_output)
        return attention_output, None, (present_keys, present_values)
class WinMLQwen3Model(nn.Module):
    """Model export variant — transformer-only body (no embeddings / lm_head).

    ``forward`` walks ``self.layers`` and applies ``self.norm``; both are read
    from the module this ``forward`` is bound onto.
    """

    def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None:
        """Remember whether the Conv (NHWC) activation layout is in use."""
        self._matmul_to_conv = matmul_to_conv  # noqa: SLF001

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values: list[tuple[torch.Tensor, torch.Tensor]],
        past_seq_len: torch.Tensor,
        total_seq_len: torch.Tensor,
        use_cache: bool = True,
    ) -> tuple[torch.Tensor, tuple[tuple[torch.Tensor, torch.Tensor], ...]]:
        """Run every decoder layer, then the final norm.

        Args:
            inputs_embeds: Pre-computed embeddings fed straight into layer 0.
            past_key_values: One ``(keys, values)`` pair per layer, indexed by
                layer position.
            past_seq_len: Forwarded unchanged to every layer.
            total_seq_len: Forwarded unchanged to every layer.
            use_cache: When true, collect each layer's present KV pair.

        Returns:
            ``(hidden_states, present_kvs)``; ``present_kvs`` is an empty
            tuple when ``use_cache`` is false.
        """
        # Conv path: add a leading dim so activations are 4-D (NHWC).
        activations = inputs_embeds.unsqueeze(0) if self._matmul_to_conv else inputs_embeds

        collected: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer_index, decoder_layer in enumerate(self.layers):
            layer_result = decoder_layer(
                activations,
                past_key_value=past_key_values[layer_index],
                past_seq_len=past_seq_len,
                total_seq_len=total_seq_len,
                use_cache=use_cache,
            )
            activations = layer_result[0]
            if use_cache:
                collected.append(layer_result[1])

        activations = self.norm(activations)
        if self._matmul_to_conv:
            # Undo the leading dim added for the Conv path.
            activations = activations.squeeze(0)
        return activations, tuple(collected)
+# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Parallel ``qwen3`` build path that produces a transformer-only ONNX. + +Opt-in via ``install()`` — calling it hot-patches the WinML registries so +that the next ``WinMLAutoModel.from_pretrained("Qwen/Qwen3-*", task="text-generation")`` +exports two transformer-only ONNX files (a prefill/context graph and an +iteration/decode graph) with this I/O: + + Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), + input_hidden_states (FP32, ``[1, seq_len, hidden]``), + past_seq_len (INT32, ``[1, 1]``), total_seq_len (INT32, ``[1]``) + Outputs: output_hidden_states (FP32), present_keys_{i}, present_values_{i} (FP16) + Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), + ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. + +The original eager-export path in ``qwen.py`` is left intact — only the +qwen3 entries in the registries are replaced. ``install()`` is idempotent. 
class QwenTransformerOnlyDecoderWrapper(nn.Module):
    """Wraps ``Qwen3ForCausalLM`` for transformer-only export.

    The wrapper applies the export prep (LpNorm RMSNorm, GQA op, 1x1
    Conv projections) in ``__init__`` and exposes a positional ``forward``
    whose argument order matches :class:`QwenTransformerOnlyPrefillIOConfig.inputs`.
    Only ``self.model.model`` (the inner ``Qwen3Model``) is invoked at
    export time — embedding lookup and ``lm_head`` stay out of the graph.
    """

    def __init__(self, model: nn.Module, num_layers: int) -> None:
        """Store the model and rewrite it in place into the export topology.

        Args:
            model: The causal-LM module to wrap; it is mutated by the prep.
            num_layers: Number of decoder layers (KV pairs expected by ``forward``).
        """
        super().__init__()
        self.model = model
        self.num_layers = num_layers
        self.config = model.config
        apply_transformer_only_export_prep(model, matmul_to_conv=True)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper:
        """Load the HF model (FP32 unless overridden) and wrap it in eval mode."""
        kwargs.setdefault("torch_dtype", torch.float32)
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs)
        model.config._attn_implementation = "eager"
        wrapper = cls(model, model.config.num_hidden_layers)
        wrapper.eval()
        return wrapper

    def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]:
        """Return the dict's values, in insertion order, as positional export args."""
        return tuple(inputs.values())

    def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Positional inputs (order matches OnnxConfig.inputs):

            past_keys_0, past_values_0, ..., past_keys_{L-1}, past_values_{L-1},
            input_hidden_states, past_seq_len, total_seq_len

        Returns ``(output_hidden_states, present_keys_0, present_values_0, ...)``.

        Raises:
            ValueError: If the number of positional inputs is not ``2 * L + 3``.
        """
        kv_count = 2 * self.num_layers
        expected = kv_count + 3
        # The interface is purely positional, so a wrong count means the
        # inputs are misaligned; surplus args used to be dropped silently.
        if len(args) != expected:
            msg = (
                f"Expected {expected} positional inputs ({kv_count} past KV tensors, "
                f"input_hidden_states, past_seq_len, total_seq_len), got {len(args)}."
            )
            raise ValueError(msg)

        kv_args = args[:kv_count]
        input_hidden_states, past_seq_len, total_seq_len = args[kv_count:]

        # Re-pair the flat (key, value, key, value, ...) run per layer.
        past_key_values = [
            (kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers)
        ]

        hidden_states, present_kvs = self.model.model(
            inputs_embeds=input_hidden_states,
            past_key_values=past_key_values,
            past_seq_len=past_seq_len,
            total_seq_len=total_seq_len,
            use_cache=True,
        )

        # Flatten back to (hidden, k0, v0, k1, v1, ...) for the exporter.
        out: list[torch.Tensor] = [hidden_states]
        for k, v in present_kvs:
            out.extend([k, v])
        return tuple(out)
class _TransformerOnlyKvCacheGenerator(DummyInputGenerator):
    """Generates ``past_keys_{i}`` / ``past_values_{i}`` (FP16).

    Every supported name yields the same all-zero tensor of shape
    ``(batch_size, num_heads, max_cache_len, head_dim)``.
    """

    SUPPORTED_INPUT_NAMES = ()  # built dynamically in __init__

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = 1,
        max_cache_len: int | None = None,
        **kwargs: Any,
    ) -> None:
        self.batch_size = batch_size
        self.num_layers: int = normalized_config.num_layers
        # KV heads, not query heads: this module's normalized config maps
        # ``num_attention_heads`` onto ``num_key_value_heads``.
        self.num_heads: int = normalized_config.num_attention_heads
        self.head_dim: int = normalized_config.head_dim
        # A falsy override (None or 0) falls back to the config value.
        if not max_cache_len:
            max_cache_len = normalized_config.max_cache_len
        self.max_cache_len: int = max_cache_len

        # One key name and one value name per layer, interleaved by layer.
        names: list[str] = []
        for layer in range(self.num_layers):
            names.append(f"past_keys_{layer}")
            names.append(f"past_values_{layer}")
        self.SUPPORTED_INPUT_NAMES = tuple(names)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor:  # noqa: ARG002
        """Return a zero-filled FP16 cache tensor; ``input_name`` is not inspected."""
        return torch.zeros(
            (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=torch.float16,
        )
def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]:
    """Input ordering: past KV pairs, then hidden states, then seq lens.

    Each value maps a dynamic axis index to its symbolic name; the two
    seq-len inputs declare no dynamic axes.
    """
    axes: dict[str, dict[int, str]] = {
        f"{prefix}_{layer}": {2: kv_seq_axis}
        for layer in range(num_layers)
        for prefix in ("past_keys", "past_values")
    }
    axes.update(
        {
            "input_hidden_states": {1: "seq_len"},
            "past_seq_len": {},
            "total_seq_len": {},
        }
    )
    return axes


def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]:
    """Output ordering: hidden states first, then one present KV pair per layer."""
    axes: dict[str, dict[int, str]] = {"output_hidden_states": {1: "seq_len"}}
    axes.update(
        {
            f"{prefix}_{layer}": {2: kv_seq_axis}
            for layer in range(num_layers)
            for prefix in ("present_keys", "present_values")
        }
    )
    return axes
_INSTALLED = False


def install() -> None:
    """Replace the qwen3 entries in WinML registries with the transformer-only variants.

    Idempotent. After this call, building any qwen3 model via
    :class:`~winml.modelkit.models.winml.composite_model.WinMLCompositeModel`
    or :class:`~winml.modelkit.models.auto.WinMLAutoModel` produces
    transformer-only ONNX files.
    """
    global _INSTALLED
    if _INSTALLED:
        return

    prefill_task = "feature-extraction"
    gen_task = "text2text-generation"

    # 1) Per-model build config + wrapper-class lookup live on the parent
    #    ``models.hf`` package as module-level dicts; mutating them is the
    #    documented hook for adding/overriding a model_type.
    from .. import hf as _hf_pkg  # noqa: PLC0415

    _hf_pkg.MODEL_BUILD_CONFIGS["qwen3"] = QWEN_TRANSFORMER_ONLY_CONFIG
    for task in (prefill_task, gen_task):
        _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", task)] = QwenTransformerOnlyDecoderWrapper

    # 2) Optimum OnnxConfig (overwrites existing registration).
    io_configs = (
        (prefill_task, QwenTransformerOnlyPrefillIOConfig),
        (gen_task, QwenTransformerOnlyGenIOConfig),
    )
    for task, io_config in io_configs:
        register_onnx_overwrite("qwen3", task, library_name="transformers")(io_config)

    # 3) Inference specialization (still GenericTask — wrapper returns raw KV).
    for task in (prefill_task, gen_task):
        register_specialization("qwen3", task, "WinMLModelForGenericTask")

    # 4) Composite registry — swap to the transformer-only handle.
    from ..winml.composite_model import COMPOSITE_MODEL_REGISTRY  # noqa: PLC0415

    COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel

    # Flip the flag only once every registry has been updated.
    _INSTALLED = True
    logger.info("qwen_transformer_only: transformer-only export path installed for qwen3.")
- -The resulting transformer-only ONNX has: - - ``input_ids`` graph input replaced by ``inputs_embeds`` (FLOAT, - ``[batch, seq, hidden_size]``) — the upstream embedding Gather is - removed. - - ``logits`` graph output replaced by ``output_hidden_states`` - (FLOAT, ``[batch, seq, hidden_size]``) — the final ``lm_head`` MatMul - is removed. -""" - -from __future__ import annotations - -import logging -from pathlib import Path - -import onnx -from onnx import TensorProto, helper - -from .persistence import load_onnx, save_onnx - - -logger = logging.getLogger(__name__) - - -def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str: - if d.HasField("dim_value"): - return d.dim_value - return d.dim_param or "?" - - -def make_transformer_only( - model_path: str | Path, - output_path: str | Path, - *, - input_ids_name: str = "input_ids", - logits_name: str = "logits", - inputs_embeds_name: str = "inputs_embeds", - output_hidden_states_name: str = "output_hidden_states", -) -> Path: - """Strip the embedding Gather and the lm_head MatMul from a Qwen3 ONNX. - - Args: - model_path: Path to the fused decoder ONNX (logits output, input_ids input). - output_path: Destination for the transformer-only ONNX. - input_ids_name: Name of the input_ids graph input to drop. - logits_name: Name of the logits graph output to drop. - inputs_embeds_name: Display name for the new embeddings input - (used only for logging; the actual tensor keeps its existing - internal name so downstream nodes need no rewiring). - output_hidden_states_name: Display name for the new hidden-state output. - - Returns: - The output path. 
- """ - model_path = Path(model_path) - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - model = load_onnx(model_path, load_weights=True, validate=False) - graph = model.graph - init_by_name = {init.name: init for init in graph.initializer} - - # -------------------- Embedding removal -------------------- - embed_idx = next( - (i for i, n in enumerate(graph.node) if input_ids_name in n.input), - None, - ) - if embed_idx is None: - msg = f"No node consumes graph input {input_ids_name!r}" - raise RuntimeError(msg) - - embed_node = graph.node[embed_idx] - embed_out_name = embed_node.output[0] - - embed_weight = None - for ipt in embed_node.input: - init = init_by_name.get(ipt) - if init is not None and len(init.dims) == 2: - embed_weight = init - break - if embed_weight is None: - msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}" - raise RuntimeError(msg) - hidden_size = int(embed_weight.dims[1]) - - ids_input = next(i for i in graph.input if i.name == input_ids_name) - batch_dim = _dim(ids_input.type.tensor_type.shape.dim[0]) - seq_dim = _dim(ids_input.type.tensor_type.shape.dim[1]) - - logger.info( - "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]", - embed_node.name, - embed_node.op_type, - embed_out_name, - inputs_embeds_name, - batch_dim, - seq_dim, - hidden_size, - ) - - new_embed_input = helper.make_tensor_value_info( - inputs_embeds_name, - TensorProto.FLOAT, - [batch_dim, seq_dim, hidden_size], - ) - - del graph.node[embed_idx] - graph.input.remove(ids_input) - graph.input.append(new_embed_input) - graph.initializer.remove(embed_weight) - - # Rewire any consumer of the removed embedding output to the new input. 
- for n in graph.node: - for i, name in enumerate(n.input): - if name == embed_out_name: - n.input[i] = inputs_embeds_name - - # -------------------- lm_head removal -------------------- - lmh_idx = next( - (i for i, n in enumerate(graph.node) if logits_name in n.output), - None, - ) - if lmh_idx is None: - msg = f"No node produces graph output {logits_name!r}" - raise RuntimeError(msg) - - lmh_node = graph.node[lmh_idx] - init_names = {init.name for init in graph.initializer} - hidden_in: str | None = None - weight_in: str | None = None - for ipt in lmh_node.input: - if ipt in init_names: - weight_in = ipt - else: - hidden_in = ipt - if hidden_in is None: - msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})" - raise RuntimeError(msg) - - logger.info( - "Removing lm_head node %r (%s) — exposing %r as new output %r", - lmh_node.name, - lmh_node.op_type, - hidden_in, - output_hidden_states_name, - ) - - logits_output = next(o for o in graph.output if o.name == logits_name) - new_hidden_output = helper.make_tensor_value_info( - output_hidden_states_name, - TensorProto.FLOAT, - [batch_dim, seq_dim, hidden_size], - ) - - del graph.node[lmh_idx] - graph.output.remove(logits_output) - # Put hidden states first so it mirrors the original logits position. - graph.output.insert(0, new_hidden_output) - - # Rename the producer of ``hidden_in`` to emit the new graph output name. 
- for n in graph.node: - for i, name in enumerate(n.output): - if name == hidden_in: - n.output[i] = output_hidden_states_name - for i, name in enumerate(n.input): - if name == hidden_in: - n.input[i] = output_hidden_states_name - - if weight_in is not None and not any(weight_in in n.input for n in graph.node): - wi = next(init for init in graph.initializer if init.name == weight_in) - graph.initializer.remove(wi) - - save_onnx(model, output_path) - logger.info("Wrote transformer-only ONNX → %s", output_path) - return output_path - - -__all__ = ["make_transformer_only"] diff --git a/test_qwen 2.py b/test_qwen 2.py deleted file mode 100644 index 6a52dee72..000000000 --- a/test_qwen 2.py +++ /dev/null @@ -1,70 +0,0 @@ -"""E2E test for Qwen3 decoder-only pipeline. - -Uses sub_model_kwargs to set per-component shape_config: - - decoder_prefill: max_cache_len=256, seq_len=64 - - decoder_gen: max_cache_len=256, seq_len=1 - -Set env var ``QUANTIZE=1`` to also run the MOPS-style Step 3: -transformer-only surgery + winml quantize on both sub-models -(embeddings and lm_head are stripped and not quantized). 
-""" - -import os - -from transformers import AutoTokenizer - -from winml.modelkit.config import WinMLBuildConfig -from winml.modelkit.models.winml.composite_model import WinMLCompositeModel - -model_id = "Qwen/Qwen3-0.6B" - -model = WinMLCompositeModel.from_pretrained( - model_id, - task="text-generation", - # config=WinMLBuildConfig(quant=None, compile=None), - config=WinMLBuildConfig(quant=None), - precision="fp16", - device="npu", - ep="qnn", - force_rebuild=False, - sub_model_kwargs={ - "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, - "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, - }, -) - -# Verify ONNX I/O shapes -for name, sub in model.sub_models.items(): - io = sub.io_config - shapes = dict(zip(io["input_names"], io["input_shapes"])) - print(f"\n=== {name} ===") - for k, v in shapes.items(): - print(f" {k}: {v}") - -tokenizer = AutoTokenizer.from_pretrained(model_id) - -prompt = "8 * 7 = ?" -messages = [{"role": "user", "content": prompt}] -text = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, -) -model_inputs = tokenizer([text], return_tensors="pt").to(model.device) - -generated_ids = model.generate(**model_inputs) - -output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() -content = tokenizer.decode(output_ids, skip_special_tokens=True) -print("\nAnswer:", content) - -if os.environ.get("QUANTIZE") == "1": - # Reuse the already-built decoder_prefill/decoder_gen ONNX files: - # surgery (strip embed + lm_head) + transformer-only quantize. 
- print("\n=== QUANTIZE=1 — running transformer-only quantization ===") - from qwen3_quantize import quantize_built_model - - quantize_built_model( - model, - model_id=model_id, - max_cache_len=256, - prefill_seq=64, - ) diff --git a/test_qwen.py b/test_qwen.py new file mode 100644 index 000000000..f958c2932 --- /dev/null +++ b/test_qwen.py @@ -0,0 +1,235 @@ +"""E2E test for the transformer-only Qwen3 export path. + +Produces two transformer-only ONNX files whose I/O matches +``qwen3_gqa_fp16_ctx.onnx`` / ``qwen3_gqa_fp16_iter.onnx``: + + decoder_prefill: input_hidden_states [1, 64, 1024] → output_hidden_states + KV + decoder_gen : input_hidden_states [1, 1, 1024] → output_hidden_states + KV + +with FP16 past/present KV named ``past_keys_{i}`` / ``past_values_{i}``, +``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv +projections. + +Important: ``install()`` MUST be called before importing the composite model +machinery so the registry hot-patches take effect. + +Generation (``model.generate(...)``) is NOT supported by this build path — +the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager +I/O signature. Use the eager ``WinMLQwen3Model`` build path for end-to-end +generation. + +Run:: + + python test_qwen_transformer_only.py + +This builds each transformer sub-model and then runs the w8a16 +quantization on the exported transformer ONNX files (no surgery needed — +files are already transformer-only). +""" + +import os +import sys +import pathlib +import subprocess + +# Put the in-repo `src/` ahead of site-packages so `import winml` always +# resolves to the editable source tree — no manual copy-to-venv needed. +_repo_root = pathlib.Path(__file__).resolve().parent +sys.path.insert(0, str(_repo_root / "src")) +sys.path.insert(0, str(_repo_root)) + +model_id = "Qwen/Qwen3-0.6B" +MAX_CACHE = 256 + +# component name -> (HF task, seq_len, artifact prefix). Order matters +# (prefill first). 
The prefix is how the built npu_ctx file is named so the +# parent can verify success by artifact appearance (the build segfaults on +# native QNN/ORT teardown AFTER writing the file, so exit codes are unreliable). +SUB_MODELS = { + "decoder_prefill": ("feature-extraction", 64, "feat_"), + "decoder_gen": ("text2text-generation", 1, "txt2txt_"), +} + +ARTIFACTS_DIR = ( + pathlib.Path.home() / ".cache" / "winml" / "artifacts" / model_id.replace("/", "_") +) + + +def _latest_ctx_mtime(prefix: str) -> float: + """Newest mtime of a ``{prefix}*_optimized_npu_ctx.onnx`` artifact, or 0.""" + files = list(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) + return max((f.stat().st_mtime for f in files), default=0.0) + + +def _build_one(task: str, seq_len: int) -> None: + """Build a SINGLE transformer sub-model in this (fresh) process. + + Invoked as a subprocess by ``main()`` so each sub-model exports in a + clean interpreter — building both in one process leaves PyTorch/ORT + state from the first build that corrupts/kills the second. + """ + from winml.modelkit.models.hf.qwen_transformer_only import install as install_qwen_transformer_only + + install_qwen_transformer_only() + + from winml.modelkit.config import WinMLBuildConfig + from winml.modelkit.models.auto import WinMLAutoModel + + WinMLAutoModel.from_pretrained( + model_id, + task=task, + config=WinMLBuildConfig(quant=None, compile=None), + precision="fp16", + device="npu", + ep="qnn", + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, + ) + # The QNN/ORT teardown segfaults (0xC0000005) on interpreter shutdown + # AFTER the artifact is fully written. Skip the buggy cleanup with a hard + # exit so the parent sees a clean exit code 0. 
+ print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +def _find_optimized(prefix: str) -> pathlib.Path: + """Locate the cached transformer-only ``{prefix}*_optimized.onnx`` file.""" + cands = [ + p for p in ARTIFACTS_DIR.glob(f"{prefix}*_optimized.onnx") + if not p.name.endswith("_optimized_npu_ctx.onnx") + ] + if not cands: + raise FileNotFoundError( + f"No {prefix}*_optimized.onnx in {ARTIFACTS_DIR} — build the sub-model first." + ) + return max(cands, key=lambda p: p.stat().st_mtime) + + +class _SubShim: + """Minimal stand-in exposing the ``_onnx_path`` quant needs.""" + + def __init__(self, onnx_path: pathlib.Path): + self._onnx_path = str(onnx_path) + + +class _ModelShim: + """Minimal stand-in exposing ``sub_models`` for ``quantize_built_model``.""" + + def __init__(self, sub_models: dict): + self.sub_models = sub_models + + +def _run_quant() -> None: + """Quantize the cached transformer ONNX files (no composite/QNN load). + + Runs as its own subprocess so any ORT teardown crash can't poison the + parent. Builds a shim ``model`` whose ``sub_models[name]._onnx_path`` + point straight at the cached ``*_optimized.onnx`` files. + """ + # Dump a native C-stack if the calibration InferenceSession segfaults + # (otherwise the crash is silent — no Python traceback). 
+ import faulthandler + faulthandler.enable() + + from qwen3_transformer_only_quantize import quantize_built_model + + sub_models = { + name: _SubShim(_find_optimized(prefix)) + for name, (_task, _seq, prefix) in SUB_MODELS.items() + } + model = _ModelShim(sub_models) + print("=== Running transformer w8a16 quantization ===", flush=True) + for name, sub in sub_models.items(): + print(f" {name}: {sub._onnx_path}", flush=True) + + try: + quantize_built_model( + model, + model_id=model_id, + max_cache_len=MAX_CACHE, + prefill_seq=64, + ) + except BaseException: + import traceback + print("QUANT FAILED with exception:", flush=True) + traceback.print_exc() + sys.stdout.flush() + sys.stderr.flush() + raise + print("QUANT COMPLETE", flush=True) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +def main() -> None: + # 1) Build each sub-model in its OWN subprocess (fresh state each time). + # Judge success by whether a FRESH npu_ctx artifact appeared, NOT by the + # subprocess exit code: the native QNN/ORT layer segfaults (0xC0000005) + # on teardown AFTER the artifact is fully written to disk. 
+ import time as _time + + for name, (task, seq_len, prefix) in SUB_MODELS.items(): + print(f"\n########## BUILD {name} (task={task}, seq_len={seq_len}) ##########", flush=True) + before = _latest_ctx_mtime(prefix) + start = _time.time() + rc = subprocess.run( + [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), + "--build-sub", task, str(seq_len)], + cwd=str(_repo_root), + ).returncode + + after = _latest_ctx_mtime(prefix) + if after > before and after >= start - 1: + status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" + print(f"########## {name} {status}: fresh {prefix}*_optimized_npu_ctx.onnx ##########", flush=True) + else: + raise SystemExit( + f"Sub-model build failed for {name} (exit {rc}) — " + f"no fresh {prefix}*_optimized_npu_ctx.onnx in {ARTIFACTS_DIR}" + ) + + # 2) Report the built transformer-only ONNX files (no composite/QNN load — + # that creates QNN EP sessions that segfault the parent on teardown). + for name, (_task, _seq, prefix) in SUB_MODELS.items(): + print(f"\n=== {name} ===") + print(f" optimized : {_find_optimized(prefix).name}") + ctx = sorted(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) + if ctx: + print(f" npu_ctx : {ctx[-1].name}") + + # 3) Quantization — run in its OWN subprocess for the same teardown-crash + # isolation. Judge by whether quant files appeared. 
+ print("\n########## QUANTIZE ##########", flush=True) + before = max( + (p.stat().st_mtime for p in ARTIFACTS_DIR.glob("*quant.onnx")), + default=0.0, + ) + qstart = _time.time() + rc = subprocess.run( + [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], + cwd=str(_repo_root), + ).returncode + after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) + after = max((p.stat().st_mtime for p in after_files), default=0.0) + if after > before and after >= qstart - 1: + status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" + print(f"########## QUANTIZE {status} ##########", flush=True) + for p in sorted(after_files, key=lambda x: x.stat().st_mtime)[-len(SUB_MODELS):]: + print(f" {p.name}", flush=True) + else: + raise SystemExit( + f"Quantization failed (exit {rc}) — no fresh *quant.onnx in {ARTIFACTS_DIR}" + ) + + +if __name__ == "__main__": + if len(sys.argv) >= 4 and sys.argv[1] == "--build-sub": + _build_one(sys.argv[2], int(sys.argv[3])) + elif len(sys.argv) >= 2 and sys.argv[1] == "--quant": + _run_quant() + else: + main() + From 78815fd97d7458edc745185d65c8aefdb7b82d67 Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 22 Jun 2026 10:44:41 -0700 Subject: [PATCH 3/6] Fix Qwen3 w8a16 quant: symmetric int8 weights + exclude GQA from QDQ --- qwen3_transformer_only_quantize.py | 33 ++++++++++++++++++++++++++- src/winml/modelkit/quant/config.py | 9 ++++++++ src/winml/modelkit/quant/quantizer.py | 19 ++++++++++++--- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 8b4efa9b7..3ae895ae2 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -133,6 +133,23 @@ def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> l return out +def _gqa_node_names(onnx_path: Path) -> list[str]: + """Return the names of every GroupQueryAttention node in ``onnx_path``. 
+ + These nodes are excluded from quantization so ORT leaves both their + inputs and output in float (``... -> Cast -> GQA -> Cast``), matching + the reference graph which keeps attention entirely out of QDQ. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + return [ + n.name + for n in model.graph.node + if n.op_type == "GroupQueryAttention" and n.name + ] + + def quantize_built_model( model: WinMLCompositeModel, *, @@ -140,7 +157,7 @@ def quantize_built_model( max_cache_len: int = DEFAULT_MAX_CACHE, prefill_seq: int = DEFAULT_PREFILL_SEQ, num_samples: int = DEFAULT_NUM_SAMPLES, - weight_type: str = "uint8", + weight_type: str = "int8", activation_type: str = "uint16", ) -> dict[str, Path]: """Quantize the transformer-only ONNX files in-place. @@ -200,6 +217,11 @@ def quantize_built_model( print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") print(f" in : {fused_path}") print(f" out: {quant_path}") + gqa_nodes = _gqa_node_names(fused_path) + print( + f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " + "quantization (inputs + output stay float, Cast -> GQA -> Cast)" + ) reader = Qwen3TransformerOnlyCalibReader( embed_tokens, hf_model.config, @@ -213,6 +235,15 @@ def quantize_built_model( activation_type=activation_type, # type: ignore[arg-type] calibration_method="minmax", calibration_data=reader, + # w8a16: symmetric int8 weights (zp=0) + asymmetric uint16 + # activations, matching the reference quantization. + weight_symmetric=True, + activation_symmetric=False, + # ORT treats GroupQueryAttention as quantizable and wraps both its + # inputs and output in QDQ. The reference keeps attention entirely + # in float (Cast -> GQA -> Cast), so exclude the GQA nodes from + # quantization so no QDQ is inserted around them. 
+ nodes_to_exclude=gqa_nodes, ) result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) if not result.success: diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index b9709cc0e..6132be599 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -68,6 +68,11 @@ class WinMLQuantizationConfig: # Quantization options per_channel: bool = False symmetric: bool = False + # Optional per-target symmetry overrides. When None, fall back to + # ``symmetric``. Lets w8a16 use symmetric weights (int8, zp=0) together + # with asymmetric activations (uint16). + weight_symmetric: bool | None = None + activation_symmetric: bool | None = None # Output settings save_calibration: bool = False @@ -98,6 +103,8 @@ def to_dict(self) -> dict: "activation_type": self.activation_type, "per_channel": self.per_channel, "symmetric": self.symmetric, + "weight_symmetric": self.weight_symmetric, + "activation_symmetric": self.activation_symmetric, "save_calibration": self.save_calibration, "distribution": self.distribution, "seed": self.seed, @@ -139,6 +146,8 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: activation_type=data.get("activation_type", "uint8"), per_channel=data.get("per_channel", False), symmetric=data.get("symmetric", False), + weight_symmetric=data.get("weight_symmetric"), + activation_symmetric=data.get("activation_symmetric"), save_calibration=data.get("save_calibration", False), distribution=data.get("distribution", "uniform"), seed=data.get("seed"), diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index c562599de..e5fd30df3 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -132,10 +132,23 @@ def quantize_onnx( activation_type = activation_type_map[config.activation_type] calibrate_method = calibration_method_map[config.calibration_method] - # Build extra options + # Build extra options. 
Weight/activation symmetry can be controlled + # independently (e.g. w8a16 = symmetric int8 weights + asymmetric + # uint16 activations); fall back to the single ``symmetric`` flag when + # the per-target override is unset. + weight_symmetric = ( + config.weight_symmetric + if config.weight_symmetric is not None + else config.symmetric + ) + activation_symmetric = ( + config.activation_symmetric + if config.activation_symmetric is not None + else config.symmetric + ) extra_options = { - "ActivationSymmetric": config.symmetric, - "WeightSymmetric": config.symmetric, + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, } # Step 1: Generate QDQ config From 95d45d9ad9a9baab2576e2b88d7c3999a60ca3f4 Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 22 Jun 2026 14:51:23 -0700 Subject: [PATCH 4/6] refactor(qwen): register transformer-only path as a declarative model_type variant --- qwen3_transformer_only_quantize.py | 7 +- src/winml/modelkit/build/hf.py | 4 + src/winml/modelkit/loader/config.py | 13 +++ src/winml/modelkit/loader/hf.py | 13 +++ src/winml/modelkit/models/auto.py | 16 +++- src/winml/modelkit/models/hf/__init__.py | 10 ++ .../models/hf/qwen_transformer_only.py | 93 +++++++++---------- test_qwen.py | 8 +- 8 files changed, 105 insertions(+), 59 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 3ae895ae2..0b90c8bd0 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -1,7 +1,7 @@ """Transformer-only w8a16 quantization for Qwen3. -Targets the transformer-only ONNX produced by -``qwen_transformer_only.install() + test_qwen.py``: +Targets the transformer-only ONNX produced by the +``qwen3_transformer_only`` build variant (see ``test_qwen.py``): - **No embedding/lm_head surgery.** The export already excludes both, so we feed ``WinMLQuantization`` the file directly. 
@@ -24,6 +24,7 @@ from winml.modelkit.models.winml.composite_model import WinMLCompositeModel from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx +from winml.modelkit.quant.config import CalibrationDataReader logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def _load_gsm8k_prompts(num_samples: int) -> list[str]: return [row["question"] for row in split.select(range(num_samples))] -class Qwen3TransformerOnlyCalibReader: +class Qwen3TransformerOnlyCalibReader(CalibrationDataReader): """Yields calibration feeds for the transformer-only ONNX. Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 26356a6eb..dc2661afa 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -91,6 +91,7 @@ def build_hf_model( cache_key: str | None = None, ep: EPNameOrAlias | None = None, device: str | None = None, + model_type: str | None = None, **kwargs: Any, ) -> BuildResult: """Build an ONNX model from a HuggingFace model architecture. @@ -208,6 +209,7 @@ def _name(base: str) -> str: model_id, trust_remote_code, random_init=random_init, + model_type=model_type, ) # ========================================================================= @@ -436,6 +438,7 @@ def _load_model( trust_remote_code: bool, random_init: bool = False, hf_config: Any | None = None, + model_type: str | None = None, ) -> Any: """Load PyTorch model — pretrained or random weights. @@ -511,6 +514,7 @@ def _load_model( task=task, trust_remote_code=effective_trust, hf_config=hf_config, + model_type=model_type, ) return pytorch_model diff --git a/src/winml/modelkit/loader/config.py b/src/winml/modelkit/loader/config.py index cb6cb9af1..b533c1636 100644 --- a/src/winml/modelkit/loader/config.py +++ b/src/winml/modelkit/loader/config.py @@ -218,6 +218,19 @@ def resolve_loader_config( f"attribute. Cannot proceed with config generation." 
) + # Explicit model_type override alongside a model_id: honor the requested + # type so downstream class / build-config / export resolution selects the + # variant (e.g. "qwen3_transformer_only") rather than the architecture's + # native type. The model_type-only path above (AutoConfig.for_model) is + # unaffected because it only runs when model_id is None. + if model_id is not None and model_type is not None and hf_config.model_type != model_type: + logger.info( + "Overriding resolved model_type '%s' -> '%s' (explicit request)", + hf_config.model_type, + model_type, + ) + hf_config.model_type = model_type + # 2. Infer task (depends on: model_type param or hf_config.architectures) if task is None and model_type is not None: supported = get_supported_tasks(model_type, library_name=library_name) diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 5a90b5828..7c40c5fee 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -150,6 +150,7 @@ def load_hf_model( user_script: str | None = None, trust_remote_code: bool = False, hf_config: PretrainedConfig | None = None, + model_type: str | None = None, ) -> tuple[nn.Module, PretrainedConfig, str]: """Load, detect task, and prepare HuggingFace model. @@ -224,6 +225,18 @@ def load_hf_model( trust_remote_code=trust_remote_code, ) + # Explicit model_type override: select a registered build variant (e.g. + # "qwen3_transformer_only") rather than the architecture's native type. + # Mutates the freshly-loaded config only; gated on an explicit request so + # normal loading is unaffected. 
+ if model_type is not None and getattr(hf_config, "model_type", None) != model_type: + logger.info( + "Overriding model_type '%s' -> '%s' (explicit request)", + getattr(hf_config, "model_type", None), + model_type, + ) + hf_config.model_type = model_type + # [2] Task & Model Class Resolution if user_script is not None: resolved_class = _load_class_from_script(user_script, model_class) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 78f944b36..4767b97db 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -247,6 +247,7 @@ def from_pretrained( trust_remote_code: bool = False, shape_config: dict | None = None, no_compile: bool = False, + model_type: str | None = None, **kwargs: Any, ) -> WinMLPreTrainedModel: """Load appropriate WinML model based on task detection. @@ -278,6 +279,10 @@ def from_pretrained( shape_config: Shape overrides passed to generate_build_config(). Valid keys -- text: sequence_length; vision: height, width; audio: feature_size, nb_max_frames, audio_sequence_length. + model_type: Explicit model_type override. When provided alongside a + HF model_id, selects a registered build variant (e.g. + ``"qwen3_transformer_only"``) instead of the architecture's + native model_type. Leave ``None`` for normal auto-detection. **kwargs: Additional arguments Returns: @@ -334,6 +339,11 @@ def from_pretrained( else: _model_type = None + # Explicit override wins so a variant composite (e.g. + # "qwen3_transformer_only") can be selected over the native type. 
+ if model_type is not None: + _model_type = model_type + if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY: from .winml.composite_model import WinMLCompositeModel @@ -368,6 +378,7 @@ def from_pretrained( trust_remote_code=trust_remote_code, ep=kwargs.get("ep"), no_compile=no_compile, + model_type=model_type, ) resolved_task = build_config.loader.task @@ -402,7 +413,9 @@ def from_pretrained( from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust) - model_type = getattr(hf_config, "model_type", "unknown") + # Honor an explicit model_type override; otherwise probe from the config. + if model_type is None: + model_type = getattr(hf_config, "model_type", "unknown") logger.debug("Model type: %s, task: %s", model_type, resolved_task) # ===================================================================== @@ -431,6 +444,7 @@ def from_pretrained( cache_key=cache_key, ep=resolved_ep, device=device, + model_type=model_type, ) onnx_path = result.final_onnx_path diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index c6f4c9520..0d2e538a3 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -56,6 +56,14 @@ from .qwen import QWEN_CONFIG from .qwen import QwenGenIOConfig as _QwenGenIOConfig from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig +from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING +from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG +from .qwen_transformer_only import ( + QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration +) +from .qwen_transformer_only import ( + QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, # triggers registration +) from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration from 
.sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING @@ -92,6 +100,7 @@ **_MARIAN_CLASS_MAPPING, **_MU2_CLASS_MAPPING, **_QWEN_CLASS_MAPPING, + **_QWEN_TO_CLASS_MAPPING, **_SAM2_CLASS_MAPPING, **_SEGFORMER_CLASS_MAPPING, **_SIGLIP_CLASS_MAPPING, @@ -115,6 +124,7 @@ "roberta": ROBERTA_FAMILY_CONFIG, "mu2": MU2_CONFIG, "qwen3": QWEN_CONFIG, + "qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG, "siglip": SIGLIP_CONFIG, "siglip-text-model": SIGLIP_CONFIG, "siglip-vision-model": SIGLIP_CONFIG, diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 8e30b1fb6..614267df4 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -2,12 +2,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Parallel ``qwen3`` build path that produces a transformer-only ONNX. +"""Transformer-only ``qwen3`` build variant, registered as a distinct model_type. -Opt-in via ``install()`` — calling it hot-patches the WinML registries so -that the next ``WinMLAutoModel.from_pretrained("Qwen/Qwen3-*", task="text-generation")`` -exports two transformer-only ONNX files (a prefill/context graph and an -iteration/decode graph) with this I/O: +This module registers a self-contained build path under the model_type +``"qwen3_transformer_only"`` (distinct from the stock ``"qwen3"`` path in +``qwen.py``). Selecting it is explicit — pass ``model_type="qwen3_transformer_only"`` +to ``WinMLAutoModel.from_pretrained(...)`` (or the underlying +``generate_hf_build_config(...)``). Both paths coexist; neither overrides the +other, and there is no import-ordering requirement. 
+ +The variant exports two transformer-only ONNX files (a prefill/context graph +and an iteration/decode graph) with this I/O: Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), input_hidden_states (FP32, ``[1, seq_len, hidden]``), @@ -16,8 +21,9 @@ Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. -The original eager-export path in ``qwen.py`` is left intact — only the -qwen3 entries in the registries are replaced. ``install()`` is idempotent. +Registration happens at import time via decorators and module-level mappings, +mirroring ``qwen.py``. The aggregating ``models.hf`` package imports this +module so the entries land in ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS``. """ from __future__ import annotations @@ -36,6 +42,7 @@ from ...export import register_onnx_overwrite from ...export.config import WinMLExportConfig from ..winml import register_specialization +from ..winml.composite_model import register_composite_model from ..winml.decoder_only import WinMLDecoderOnlyModel from ..winml.kv_cache import WinMLSlidingWindowCache from .qwen3_export_ops import apply_transformer_only_export_prep @@ -43,6 +50,13 @@ logger = logging.getLogger(__name__) +# Distinct model_type for this variant. The underscore form is what the +# exporter sees on ``model.config.model_type`` and what Optimum's TasksManager +# and ``register_specialization`` are keyed on; the hyphenated form is used for +# the ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS`` lookups (those callers +# normalize ``_`` -> ``-``). 
+TRANSFORMER_ONLY_MODEL_TYPE = "qwen3_transformer_only" + # ============================================================================= # Wrapper module @@ -65,6 +79,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.num_layers = num_layers self.config = model.config apply_transformer_only_export_prep(model, matmul_to_conv=True) + # Tag the config so the exporter resolves this variant's OnnxConfig + # (registered under ``TRANSFORMER_ONLY_MODEL_TYPE``) rather than the + # stock qwen3 one. Mirrors the CLIP/zoedepth sub-model precedent. + self.config.model_type = TRANSFORMER_ONLY_MODEL_TYPE @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: @@ -222,6 +240,9 @@ def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") return result +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", library_name="transformers" +) class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): """Prefill (seq=64) — transformer-only I/O.""" @@ -241,6 +262,9 @@ def outputs(self) -> dict[str, dict[int, str]]: return _transformer_only_outputs(self._normalized_config.num_layers) +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", library_name="transformers" +) class QwenTransformerOnlyGenIOConfig(OnnxConfig): """Generation (seq=1) — transformer-only I/O.""" @@ -279,6 +303,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ============================================================================= +@register_composite_model(TRANSFORMER_ONLY_MODEL_TYPE, "text-generation") class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): """Composite handle for the transformer-only Qwen3 build (export only). 
@@ -299,56 +324,28 @@ def get_cache_class(cls) -> type: # ============================================================================= -# install() — hot-patch the registries +# Declarative registration (import-time) # ============================================================================= +# Wrapper-class lookup keyed by (model_type, task). Keys use the hyphenated +# model_type form because ``_get_custom_model_class`` normalizes ``_`` -> ``-`` +# before lookup. Merged into the aggregate mapping by ``models.hf.__init__``. +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("qwen3-transformer-only", "feature-extraction"): QwenTransformerOnlyDecoderWrapper, + ("qwen3-transformer-only", "text2text-generation"): QwenTransformerOnlyDecoderWrapper, +} -_INSTALLED = False - - -def install() -> None: - """Replace the qwen3 entries in WinML registries with the transformer-only variants. - - Idempotent. After this call, building any qwen3 model via - :class:`~winml.modelkit.models.winml.composite_model.WinMLCompositeModel` - or :class:`~winml.modelkit.models.auto.WinMLAutoModel` produces - transformer-only ONNX files. - """ - global _INSTALLED - if _INSTALLED: - return - - # 1) Per-model build config + wrapper-class lookup live on the parent - # ``models.hf`` package as module-level dicts; mutating them is the - # documented hook for adding/overriding a model_type. - from .. import hf as _hf_pkg # noqa: PLC0415 - - _hf_pkg.MODEL_BUILD_CONFIGS["qwen3"] = QWEN_TRANSFORMER_ONLY_CONFIG - _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "feature-extraction")] = QwenTransformerOnlyDecoderWrapper - _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "text2text-generation")] = QwenTransformerOnlyDecoderWrapper - - # 2) Optimum OnnxConfig (overwrites existing registration). 
- register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers")(QwenTransformerOnlyPrefillIOConfig) - register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers")(QwenTransformerOnlyGenIOConfig) - - # 3) Inference specialization (still GenericTask — wrapper returns raw KV). - register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") - register_specialization("qwen3", "text2text-generation", "WinMLModelForGenericTask") - - # 4) Composite registry — swap to the transformer-only handle. - from ..winml.composite_model import COMPOSITE_MODEL_REGISTRY - - COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel - - _INSTALLED = True - logger.info("qwen_transformer_only: transformer-only export path installed for qwen3.") +# Inference specialization (GenericTask — the wrapper returns raw hidden states / KV). +register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask") +register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask") __all__ = [ + "MODEL_CLASS_MAPPING", "QWEN_TRANSFORMER_ONLY_CONFIG", + "TRANSFORMER_ONLY_MODEL_TYPE", "QwenTransformerOnlyDecoderWrapper", "QwenTransformerOnlyGenIOConfig", "QwenTransformerOnlyPrefillIOConfig", "WinMLQwen3TransformerOnlyModel", - "install", ] diff --git a/test_qwen.py b/test_qwen.py index f958c2932..da23f4481 100644 --- a/test_qwen.py +++ b/test_qwen.py @@ -10,9 +10,6 @@ ``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv projections. -Important: ``install()`` MUST be called before importing the composite model -machinery so the registry hot-patches take effect. - Generation (``model.generate(...)``) is NOT supported by this build path — the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager I/O signature. 
Use the eager ``WinMLQwen3Model`` build path for end-to-end @@ -68,16 +65,13 @@ def _build_one(task: str, seq_len: int) -> None: clean interpreter — building both in one process leaves PyTorch/ORT state from the first build that corrupts/kills the second. """ - from winml.modelkit.models.hf.qwen_transformer_only import install as install_qwen_transformer_only - - install_qwen_transformer_only() - from winml.modelkit.config import WinMLBuildConfig from winml.modelkit.models.auto import WinMLAutoModel WinMLAutoModel.from_pretrained( model_id, task=task, + model_type="qwen3_transformer_only", config=WinMLBuildConfig(quant=None, compile=None), precision="fp16", device="npu", From 9cecb03913d9d19a10e1d2c27934414bcdd5ee3f Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 23 Jun 2026 11:52:52 -0700 Subject: [PATCH 5/6] fix(qwen): calibrate transformer-only decode model on real trajectory --- qwen3_transformer_only_quantize.py | 170 +++++++++++++++++++++++++++-- 1 file changed, 163 insertions(+), 7 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 0b90c8bd0..81bcb780f 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -34,6 +34,7 @@ DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 DEFAULT_NUM_SAMPLES = 30 +DEFAULT_DECODE_STEPS = 16 DEFAULT_CALIB_DATASET = "openai/gsm8k" DEFAULT_CALIB_DATASET_CONFIG = "main" DEFAULT_CALIB_SPLIT = "train" @@ -119,6 +120,140 @@ def rewind(self) -> None: self._iter = iter(self._samples) +def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: + """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. + + Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` + (``.key_cache`` / ``.value_cache``), and the newer per-layer + ``DynamicCache`` (``.layers[i].keys`` / ``.values``). 
+ """ + if hasattr(past, "key_cache") and hasattr(past, "value_cache"): + return past.key_cache[i], past.value_cache[i] + if hasattr(past, "layers"): + layer = past.layers[i] + return layer.keys, layer.values + return past[i][0], past[i][1] + + +class Qwen3DecodeTrajectoryCalibReader(CalibrationDataReader): + """Calibrate the iter (seq_len=1) model on REAL decode-step states. + + The naive reader feeds one (repeated) token with a zeroed KV cache and + ``past_seq_len=0`` — a state the model never sees during generation. With + MinMax calibration this collapses the observed activation ranges far below + the real decode distribution, so the resulting w8a16 model degenerates + (e.g. ``Paris -> Parisammedammed...``). + + Instead, drive the HF FP reference model through a real prefill + decode + trajectory and capture, at each decode step, the exact feed the iter ONNX + would receive: the embedding of the *actually generated* token, the real + accumulated KV cache (copied into the fixed ``[1, kv_heads, max_cache, + head_dim]`` FP16 buffer), and the growing ``past_seq_len``. Token + selection uses the HF model's true logits, so the trajectory matches + greedy generation. The QDQ scheme is unchanged — only the calibration + statistics become representative. 
+ """ + + def __init__( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + max_cache_len: int, + decode_steps: int = 16, + ) -> None: + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = max_cache_len + self._samples = list( + self._build_samples( + hf_model, + embed_tokens, + token_ids_list, + prefill_seq=prefill_seq, + decode_steps=decode_steps, + ) + ) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _kv_buffers(self, past: Any, cur_len: int) -> dict[str, np.ndarray]: + """Copy the ``cur_len`` valid KV positions into fixed FP16 buffers.""" + feed: dict[str, np.ndarray] = {} + for i in range(self.num_layers): + k, v = _layer_kv(past, i) + kbuf = np.zeros( + (1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16 + ) + vbuf = np.zeros_like(kbuf) + kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + feed[f"past_keys_{i}"] = kbuf + feed[f"past_values_{i}"] = vbuf + return feed + + def _build_samples( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + decode_steps: int, + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + ids = ids[:, :prefill_seq] # real prompt prefix (no pad-token KV) + cur_len = ids.shape[1] + + # FP prefill once to seed a realistic KV cache + first token. 
+ with torch.no_grad(): + out = hf_model(input_ids=ids, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + + for _ in range(decode_steps): + if cur_len >= self.max_cache_len: + break + # The feed the iter model sees for THIS token: embedding of the + # token to process, the KV of the `cur_len` preceding tokens, + # and seqlens_k = (cur_len + 1) - 1 = cur_len. + with torch.no_grad(): + emb = embed_tokens(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feed: dict[str, np.ndarray] = { + "input_hidden_states": emb.astype(np.float32), + "past_seq_len": np.array([[cur_len]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), + } + feed.update(self._kv_buffers(past, cur_len)) + yield feed + + # Advance the reference model one real decode step. + with torch.no_grad(): + out = hf_model( + input_ids=torch.tensor([[tok]]), + past_key_values=past, + use_cache=True, + ) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + cur_len += 1 + + def get_next(self) -> dict[str, np.ndarray] | None: + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + self._iter = iter(self._samples) + + def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: out: list[torch.Tensor] = [] for i in range(num_samples): @@ -160,6 +295,7 @@ def quantize_built_model( num_samples: int = DEFAULT_NUM_SAMPLES, weight_type: str = "int8", activation_type: str = "uint16", + decode_steps: int = DEFAULT_DECODE_STEPS, ) -> dict[str, Path]: """Quantize the transformer-only ONNX files in-place. 
@@ -223,13 +359,33 @@ def quantize_built_model( f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " "quantization (inputs + output stay float, Cast -> GQA -> Cast)" ) - reader = Qwen3TransformerOnlyCalibReader( - embed_tokens, - hf_model.config, - token_ids_list, - seq_len=seq_len, - max_cache_len=max_cache_len, - ) + if sub_name == "decoder_gen": + # The iter model only sees mid-generation states. Calibrate it on a + # real prefill+decode trajectory (true tokens, accumulated KV, + # growing past_seq_len) instead of one token + zeroed KV, which + # would under-range the MinMax activation scales and collapse + # generation. + print( + f" calibrating on decode trajectory ({decode_steps} steps/prompt, " + f"prefill_seq={prefill_seq})" + ) + reader: CalibrationDataReader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed_tokens, + hf_model.config, + token_ids_list, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + else: + reader = Qwen3TransformerOnlyCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) cfg = WinMLQuantizationConfig( samples=num_samples, weight_type=weight_type, # type: ignore[arg-type] From 08f05d7c399dafe2f60dfccf3cbd3348355ab721 Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 23 Jun 2026 13:18:25 -0700 Subject: [PATCH 6/6] Fixed small bugs --- qwen3_transformer_only_quantize.py | 18 +++- .../modelkit/models/hf/qwen3_export_ops.py | 81 +++----------- .../modelkit/models/hf/qwen3_modeling.py | 101 ++++++++++++++++-- .../models/hf/qwen_transformer_only.py | 2 +- test_qwen.py | 8 +- 5 files changed, 132 insertions(+), 78 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 81bcb780f..559620973 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +import gc from pathlib import 
Path from typing import Any, Iterator @@ -40,6 +41,16 @@ DEFAULT_CALIB_SPLIT = "train" DEFAULT_CALIB_SEED = 42 +# Map an ONNX quantization dtype to the bit-width suffix used in artifact +# filenames (e.g. int8 -> "8", uint16 -> "16"), instead of brittle string +# slicing of the dtype name. +_DTYPE_BITS = { + "int8": "8", + "uint8": "8", + "int16": "16", + "uint16": "16", +} + def _load_gsm8k_prompts(num_samples: int) -> list[str]: """GSM8K train split, shuffled seed=42 for reproducible calibration.""" @@ -348,7 +359,8 @@ def quantize_built_model( seq_len = seq_by_sub[sub_name] quant_path = fused_path.with_name( - fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + fused_path.stem + + f"_w{_DTYPE_BITS[weight_type]}a{_DTYPE_BITS[activation_type]}.quant.onnx" ) print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") @@ -414,5 +426,9 @@ def quantize_built_model( ) quant_paths[sub_name] = quant_path + # Free the FP reference model now that calibration is done. + del hf_model, embed_tokens + gc.collect() + print("\n=== Done ===") return quant_paths diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index 61d45f0ef..5fd3edb68 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -46,7 +46,12 @@ def symbolic(g, input, axis, p): # noqa: D401 @staticmethod def forward(ctx, input, axis, p): # noqa: ARG004 - return input # placeholder — real compute happens in symbolic + # Shape-only tracing placeholder. The real op is emitted by + # ``symbolic`` during ONNX export; ``forward`` exists solely so the + # TorchScript exporter (and Optimum's pre-export dry run) can trace + # output shapes. It returns ``input`` unchanged on purpose and is NOT a + # correct eager RMSNorm — do not call this module for real inference. 
+ return input class GroupQueryAttentionOnnxExport(torch.autograd.Function): @@ -100,6 +105,12 @@ def forward( kv_num_heads, num_heads, ): # noqa: ARG004 + # Shape-only tracing placeholder. The real op is emitted by + # ``symbolic`` during ONNX export; ``forward`` exists solely so the + # TorchScript exporter (and Optimum's pre-export dry run) can trace + # output shapes. It returns the inputs as stand-in present-KV on + # purpose and is NOT correct attention — do not call this module for + # real inference. return query, past_key, past_value # placeholder shapes @@ -136,76 +147,8 @@ def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) -# ============================================================================= -# Apply export prep: bind winml Qwen3 export methods onto a loaded model -# ============================================================================= - - -def apply_transformer_only_export_prep(causal_lm: nn.Module, *, matmul_to_conv: bool = True) -> None: - """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. - - Binds the winml-owned export behaviour from :mod:`.qwen3_modeling` onto each - Qwen3 submodule (runs ``prepare_for_onnx_export`` and rebinds ``forward``). - After this call, ``causal_lm.model(inputs_embeds, past_key_values, - past_seq_len, total_seq_len)`` runs the transformer-only forward. - - Args: - causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. - matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so - QNN sees them as Conv. 
- """ - from .qwen3_modeling import ( - WinMLQwen3Attention, - WinMLQwen3DecoderLayer, - WinMLQwen3MLP, - WinMLQwen3Model, - WinMLQwen3RMSNorm, - ) - - def _bind(module: nn.Module, owner: type) -> None: - module.forward = owner.forward.__get__(module, type(module)) - - # Identify Qwen3 submodules by their (stock HF) class name so we don't - # depend on importing ``transformers.models.qwen3`` here. - def _is(module: nn.Module, name: str) -> bool: - return type(module).__name__ == name - - # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, - # in input/post_attention layernorms). - for mod in causal_lm.modules(): - if _is(mod, "Qwen3RMSNorm"): - WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) - _bind(mod, WinMLQwen3RMSNorm) - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Attention"): - WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - _bind(mod, WinMLQwen3Attention) - elif _is(mod, "Qwen3MLP"): - # MLP forward is unchanged; only the projections are swapped to Conv. - WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - - # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; - # the export forward invokes ``self.rotary_emb`` on the attention module, - # so re-attach a reference from the parent model. 
- for mod in causal_lm.modules(): - if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): - for layer in mod.layers: - layer.self_attn.rotary_emb = mod.rotary_emb - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3DecoderLayer"): - _bind(mod, WinMLQwen3DecoderLayer) - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Model"): - WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - _bind(mod, WinMLQwen3Model) - - __all__ = [ "GroupQueryAttentionOnnxExport", "LpNormOnnxExport", "TransposeConv2d1x1Transpose", - "apply_transformer_only_export_prep", ] diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py index 05a70adfe..d3c538df5 100644 --- a/src/winml/modelkit/models/hf/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -18,7 +18,7 @@ - ``WinMLQwen3DecoderLayer`` / ``WinMLQwen3Model`` -> transformer-only forward that threads the KV cache + seq-len tensors and omits embeddings / lm_head. -``apply_transformer_only_export_prep`` (in ``qwen3_export_ops``) walks a loaded +``apply_transformer_only_export_prep`` (defined below) walks a loaded ``Qwen3ForCausalLM``, calls ``prepare_for_onnx_export`` on each submodule, and binds the matching ``forward`` from these classes onto it. """ @@ -42,15 +42,14 @@ class WinMLQwen3RMSNorm(nn.Module): def prepare_for_onnx_export(self) -> None: # Pre-multiply the gain into the weight (LpNorm has unit gain). + # ``scale`` is shape ``[1]`` and broadcasts over ``self.weight`` + # (shape ``[hidden_size]``), so the result keeps the per-channel + # shape even when the original weights are all ones. 
n = self.weight.numel() scale = torch.sqrt( torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) ) - if torch.any(self.weight.data != torch.ones_like(self.weight)).item(): - new_w = scale * self.weight - else: - new_w = scale - self.weight = nn.Parameter(new_w) + self.weight = nn.Parameter(scale * self.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: out = LpNormOnnxExport.apply(hidden_states, -1, 2) @@ -228,10 +227,100 @@ def forward( return hidden_states, present_kvs +# ============================================================================= +# Apply export prep: bind winml Qwen3 export methods onto a loaded model +# ============================================================================= + + +def apply_transformer_only_export_prep( + causal_lm: nn.Module, *, matmul_to_conv: bool = True +) -> None: + """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. + + Binds the winml-owned export behaviour (the ``WinMLQwen3*`` classes in this + module) onto each Qwen3 submodule (runs ``prepare_for_onnx_export`` and + rebinds ``forward``). After this call, ``causal_lm.model(inputs_embeds, + past_key_values, past_seq_len, total_seq_len)`` runs the transformer-only + forward. + + Args: + causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. + matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so + QNN sees them as Conv. + + Raises: + RuntimeError: If any expected Qwen3 submodule class is not found, + meaning the loaded model does not match the expected topology + (e.g. the stock HF class names changed). + """ + + def _bind(module: nn.Module, owner: type) -> None: + module.forward = owner.forward.__get__(module, type(module)) + + # Identify Qwen3 submodules by their (stock HF) class name so we don't + # depend on importing ``transformers.models.qwen3`` here. 
+ def _is(module: nn.Module, name: str) -> bool: + return type(module).__name__ == name + + patched = { + "Qwen3RMSNorm": 0, + "Qwen3Attention": 0, + "Qwen3MLP": 0, + "Qwen3DecoderLayer": 0, + "Qwen3Model": 0, + } + + # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, + # in input/post_attention layernorms). + for mod in causal_lm.modules(): + if _is(mod, "Qwen3RMSNorm"): + WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + _bind(mod, WinMLQwen3RMSNorm) + patched["Qwen3RMSNorm"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Attention"): + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Attention) + patched["Qwen3Attention"] += 1 + elif _is(mod, "Qwen3MLP"): + # MLP forward is unchanged; only the projections are swapped to Conv. + WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + patched["Qwen3MLP"] += 1 + + # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; + # the export forward invokes ``self.rotary_emb`` on the attention module, + # so re-attach a reference from the parent model. + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): + for layer in mod.layers: + layer.self_attn.rotary_emb = mod.rotary_emb + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3DecoderLayer"): + _bind(mod, WinMLQwen3DecoderLayer) + patched["Qwen3DecoderLayer"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model"): + WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Model) + patched["Qwen3Model"] += 1 + + missing = [name for name, count in patched.items() if count == 0] + if missing: + raise RuntimeError( + "transformer-only export prep found no " + f"{missing} submodule(s) to patch; the loaded model does not match " + "the expected Qwen3 topology (stock HF class names may have changed)." 
+ ) + + __all__ = [ "WinMLQwen3Attention", "WinMLQwen3DecoderLayer", "WinMLQwen3MLP", "WinMLQwen3Model", "WinMLQwen3RMSNorm", + "apply_transformer_only_export_prep", ] diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 614267df4..6ac9d0852 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -45,7 +45,7 @@ from ..winml.composite_model import register_composite_model from ..winml.decoder_only import WinMLDecoderOnlyModel from ..winml.kv_cache import WinMLSlidingWindowCache -from .qwen3_export_ops import apply_transformer_only_export_prep +from .qwen3_modeling import apply_transformer_only_export_prep logger = logging.getLogger(__name__) diff --git a/test_qwen.py b/test_qwen.py index da23f4481..14cf4656d 100644 --- a/test_qwen.py +++ b/test_qwen.py @@ -17,7 +17,7 @@ Run:: - python test_qwen_transformer_only.py + python test_qwen.py This builds each transformer sub-model and then runs the w8a16 quantization on the exported transformer ONNX files (no surgery needed — @@ -85,6 +85,8 @@ def _build_one(task: str, seq_len: int) -> None: print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) sys.stdout.flush() sys.stderr.flush() + # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT + # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. os._exit(0) @@ -155,6 +157,8 @@ def _run_quant() -> None: print("QUANT COMPLETE", flush=True) sys.stdout.flush() sys.stderr.flush() + # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT + # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. 
os._exit(0) @@ -173,6 +177,7 @@ def main() -> None: [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--build-sub", task, str(seq_len)], cwd=str(_repo_root), + timeout=1800, ).returncode after = _latest_ctx_mtime(prefix) @@ -205,6 +210,7 @@ def main() -> None: rc = subprocess.run( [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], cwd=str(_repo_root), + timeout=1800, ).returncode after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) after = max((p.stat().st_mtime for p in after_files), default=0.0)