Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ concurrency:
jobs:
lint:
runs-on: windows-latest
# Bumped from 5: combined mypy on 23 packages cold-starts at ~3-4 min on
# Bumped from 5: combined mypy on 24 packages cold-starts at ~3-4 min on
# Windows runners; the original 5-min ceiling cancelled mid-run.
timeout-minutes: 10

Expand Down Expand Up @@ -65,6 +65,7 @@ jobs:
-p winml.modelkit.onnx
-p winml.modelkit.optim
-p winml.modelkit.optracing
-p winml.modelkit.pattern
-p winml.modelkit.quant
-p winml.modelkit.serve
-p winml.modelkit.session
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _get_node_actual_dtype(self, node: onnx.NodeProto) -> str | None:
return dtype.upper()
return None

def _get_node_shape(self, tensor_name: str) -> list[int] | None:
def _get_node_shape(self, tensor_name: str) -> tuple[int | str | None, ...] | None:
"""Get shape of a tensor.

Args:
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/analyze/core/runtime_checker_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ def _tensor_to_array_with_fallback(tensor: onnx.TensorProto) -> np.ndarray:
type_vars[type_annotation] = dtype
else:
vi = valueinfo.get(inp_name)
shape_seq: list | tuple[int, ...] | None = None
shape_seq: tuple[int | str | None, ...] | None = None
dtype = None
if vi is not None:
shape_seq, dtype = shape_and_dtype_from_valueinfo(vi)
Expand Down
7 changes: 3 additions & 4 deletions src/winml/modelkit/pattern/attention_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@
from .base import (
Pattern,
PatternInputGenerator,
PatternMatchResult,
PatternSchema,
Skeleton,
SkeletonMatchResult,
register_pattern_input_generator,
)
from .match import PatternMatchResult, SkeletonMatchResult
from .op_input_gen import InputShapeConstraint


Expand Down Expand Up @@ -539,7 +538,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]:

def get_input_and_infinite_attribute_combinations(
self,
) -> list[dict[str, InputShapeConstraint]]:
) -> list[dict[str, object]]:
"""Returns input combinations for expanded attention with mask pattern testing.

Provides various 4D input shapes for Q, K, V, and attn_mask tensors.
Expand Down Expand Up @@ -596,7 +595,7 @@ def get_finite_attribute_sets(self) -> dict[str, list]:

def get_input_and_infinite_attribute_combinations(
self,
) -> list[dict[str, InputShapeConstraint]]:
) -> list[dict[str, object]]:
"""Returns input combinations for Transpose+Attention pattern testing.

Provides various 4D input shapes for Q, K, V, and attn_mask tensors.
Expand Down
34 changes: 21 additions & 13 deletions src/winml/modelkit/pattern/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def get_schema(self) -> PatternSchema:
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from typing import Any, cast

import numpy as np
import onnx
Expand Down Expand Up @@ -599,7 +599,10 @@ def _check_skeleton_result_impl(
if type_str:
# Use InputShapeConstraint to create dummy value
type_annotation = SupportedONNXType.from_onnx_type(type_str).annotation
inputs[name] = InputShapeConstraint(info.shape).get_value(type_annotation)
# Matched-tensor shapes are concrete here (dynamic dims resolved).
inputs[name] = InputShapeConstraint(
cast("tuple[int, ...]", info.shape)
).get_value(type_annotation)

# Build is_constant_map from input_infos
is_constant_map = {name: info.is_constant for name, info in input_infos.items()}
Expand Down Expand Up @@ -649,7 +652,7 @@ def _check_skeleton_result_impl(
node_domain = skeleton.node_domains[node_idx]
op_type = skeleton.node_op_types[node_idx]
opset_versions = ONNXDomain.get_model_domain_opset_versions(
skeleton_match_result.model
skeleton_match_result.matcher.model
)
opset_version = opset_versions[node_domain]
op_schema = node_domain.get_op_schema(op_type, opset_version)
Expand Down Expand Up @@ -845,7 +848,7 @@ def get_onnx_model(

# Create nodes
nodes = []
node_output_names = {} # node_idx -> output_name
node_output_names: dict[int, str] = {} # node_idx -> output_name

for node_idx in range(skeleton.n_nodes):
op_type = skeleton.node_op_types[node_idx]
Expand Down Expand Up @@ -1037,7 +1040,7 @@ def _infer_schema_attributes(
class PatternInputGenerator(OpInputGenerator):
"""Input generator that wraps a Pattern for runtime checking."""

pattern: Pattern = None
pattern: Pattern | None = None # subclasses set a real Pattern (asserted in __init__)
registration_name: str

def __init__(
Expand All @@ -1056,7 +1059,9 @@ def __init__(
self.domain_versions = domain_versions
schema = self.pattern.get_schema()
self.op_name = schema.name # compatibility with OpInputGenerator
super().__init__(schema, onnx_types_to_check)
# OpInputGenerator duck-types the schema (OpSchema or PatternSchema); it
# guards OpSchema-specific access with isinstance internally.
super().__init__(cast("OpSchema", schema), onnx_types_to_check)

def _create_model(
self,
Expand Down Expand Up @@ -1086,8 +1091,8 @@ def _create_model(
SupportedONNXType.from_annotation(dtype).onnx_type for dtype in output_dtypes
]

# Use the pattern's get_onnx_model method
return self.pattern.get_onnx_model(
# Use the pattern's get_onnx_model method (pattern is set by the subclass).
return cast("Pattern", self.pattern).get_onnx_model(
inputs=input_kwargs,
attributes=attr_kwargs,
is_constant_map=is_constant_map,
Expand Down Expand Up @@ -1538,7 +1543,7 @@ def _get_registered_edge_info(self, tensor_name: str, consumer_name: str) -> Edg

def _check_constant_constraints(
self,
matched_nodes: list[str],
matched_nodes: list[onnx.NodeProto],
constant_constraints: list[tuple[int, int, np.ndarray]],
) -> bool:
"""Check constant value constraints for a skeleton match.
Expand Down Expand Up @@ -1649,7 +1654,9 @@ def match(self) -> list[PatternMatchResult]:
# Validate each result using pattern's check_skeleton_result
validated_results = []
for result in skeleton_results:
pattern_match_result = result.pattern.check_skeleton_result(result)
# match_skeleton() yields results whose .pattern is a registered ABC
# Pattern (the pydantic PatternModel is only used for serialization).
pattern_match_result = cast("Pattern", result.pattern).check_skeleton_result(result)
if pattern_match_result is not None:
validated_results.append(pattern_match_result)

Expand Down Expand Up @@ -1773,7 +1780,7 @@ def _match_single_skeleton(
# check 3: the mappings must be compatible
valid_merged_mappings = []
for mapping_combination in it.product(*dst_slot_partial_mappings):
merged_mapping = _merge_mappings(mapping_combination)
merged_mapping = _merge_mappings(list(mapping_combination))
if merged_mapping is not None:
# valid mapping
merged_mapping[subgraph_node] = node_name
Expand Down Expand Up @@ -1955,7 +1962,7 @@ def _allocate_graph_node_key(node: Any) -> str:
nonlocal generated_node_key_counter

if node.name and node.name not in used_graph_node_keys:
key = node.name
key: str = node.name
elif node.name:
suffix = 1
key = f"{node.name}__{suffix}"
Expand Down Expand Up @@ -2018,7 +2025,8 @@ def _allocate_graph_node_key(node: Any) -> str:

# Create the new pattern instance
new_pattern = new_pattern_class()
assert skeleton_match.pattern.get_schema() == new_pattern.get_schema(), (
matched_pattern = cast("Pattern", skeleton_match.pattern)
assert matched_pattern.get_schema() == new_pattern.get_schema(), (
f"New pattern {new_pattern_class.__name__} schema does not match "
f"the matched pattern {skeleton_match.pattern.__class__.__name__} schema."
)
Expand Down
4 changes: 2 additions & 2 deletions src/winml/modelkit/pattern/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast


if TYPE_CHECKING:
Expand Down Expand Up @@ -95,7 +95,7 @@ def load_pattern(self) -> Pattern:
except (ImportError, AttributeError):
continue
# Instantiation errors should propagate, not be silently caught
return pattern_cls()
return cast("Pattern", pattern_cls())

msg = f"Failed to load pattern {self.pattern_class} from {self.module}"
logger.error(msg)
Expand Down
8 changes: 4 additions & 4 deletions src/winml/modelkit/pattern/gelu_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_schema(self) -> PatternSchema:


@register_pattern_input_generator
class Gelu1PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")):
class Gelu1PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op)
"""Input generator for GELU activation pattern variant 1."""

pattern = Gelu1Pattern()
Expand Down Expand Up @@ -224,7 +224,7 @@ def get_schema(self) -> PatternSchema:


@register_pattern_input_generator
class Gelu2PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")):
class Gelu2PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op)
"""Input generator for GELU activation pattern variant 2."""

pattern = Gelu2Pattern()
Expand Down Expand Up @@ -331,7 +331,7 @@ def get_schema(self) -> PatternSchema:


@register_pattern_input_generator
class Gelu3PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")):
class Gelu3PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op)
"""Input generator for GELU activation pattern variant 3."""

pattern = Gelu3Pattern()
Expand Down Expand Up @@ -439,7 +439,7 @@ def get_schema(self) -> PatternSchema:


@register_pattern_input_generator
class Gelu4PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")):
class Gelu4PatternInputGenerator(PatternInputGenerator, get_runtime_checker_op("Gelu")): # type: ignore[misc] # dynamic base class (runtime-checker op)
"""Input generator for GELU activation pattern variant 4."""

pattern = Gelu4Pattern()
Expand Down
5 changes: 3 additions & 2 deletions src/winml/modelkit/pattern/gemm_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ def get_schema(self) -> PatternSchema:
class GemmPatternInputGenerator(PatternInputGenerator):
"""PatternInputGenerator for Gemm patterns."""

pattern = ReshapeGemmReshapePattern()
# Typed as the base Pattern so subclasses can set a different concrete pattern.
pattern: Pattern = ReshapeGemmReshapePattern()

def get_finite_attribute_sets(self) -> dict[str, list[Any]]:
"""Return finite attribute sets for ReshapeGemmReshape (none)."""
Expand Down Expand Up @@ -418,7 +419,7 @@ def get_input_and_infinite_attribute_combinations(
for a_shape in a_shapes:
for b_shape in b_shapes:
for c_option in c_options:
combination = {
combination: dict[str, object] = {
"A": InputShapeConstraint(a_shape),
"B": InputShapeConstraint(b_shape),
}
Expand Down
24 changes: 13 additions & 11 deletions src/winml/modelkit/pattern/layernorm_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from abc import abstractmethod
from typing import Any
from typing import Any, cast

import numpy as np
from onnx.defs import OpSchema
Expand All @@ -23,13 +23,12 @@
from .base import (
Pattern,
PatternInputGenerator,
PatternMatchResult,
PatternMismatchedError,
PatternSchema,
Skeleton,
SkeletonMatchResult,
register_pattern_input_generator,
)
from .match import PatternMatchResult, SkeletonMatchResult
from .op_input_gen import get_runtime_checker_op
from .utils import (
get_attribute_proto_value,
Expand Down Expand Up @@ -253,6 +252,8 @@ def _infer_schema_attributes(
if axes_value is None:
raise PatternMismatchedError("ReduceMean missing axes attribute")

if axes_value is None:
raise PatternMismatchedError("ReduceMean axes tensor value is None")
if len(axes_value) != 1:
raise PatternMismatchedError(
f"Only single-axis normalization supported, got axes={axes_value}"
Expand Down Expand Up @@ -495,7 +496,7 @@ def _get_normalized_dim(self, inputs: dict[str, np.ndarray], attributes: dict[st
axis = attributes["axis"]
rank = len(x_shape)
normalized_axis = axis if axis >= 0 else rank + axis
return x_shape[normalized_axis]
return int(x_shape[normalized_axis])

def get_internal_constants_and_attributes(
self,
Expand Down Expand Up @@ -534,7 +535,7 @@ def get_internal_constants_and_attributes(


class LayerNormalizationPatternInputGenerator(
PatternInputGenerator, get_runtime_checker_op("LayerNormalization")
PatternInputGenerator, get_runtime_checker_op("LayerNormalization") # type: ignore[misc] # dynamic base class (runtime-checker op)
):
"""Base PatternInputGenerator for LayerNormalization patterns.

Expand All @@ -547,14 +548,15 @@ def get_finite_attribute_sets(self) -> dict[str, list[Any]]:

def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, Any]]:
"""Return input combinations with broadcast-compatible Scale/B shapes."""
from .op_input_gen import InputValueConstraint
from .op_input_gen import InputShapeConstraint, InputValueConstraint

combinations = super().get_input_and_infinite_attribute_combinations()
# Dynamic base provides the real combinations method at runtime.
combinations = super().get_input_and_infinite_attribute_combinations() # type: ignore[safe-super]

adapted = []
for combo in combinations:
axis = combo["axis"]
x_shape = combo["X"].shape
axis = cast("int", combo["axis"])
x_shape = cast("InputShapeConstraint", combo["X"]).shape
rank = len(x_shape)
normalized_axis = axis if axis >= 0 else rank + axis
normalized_dim = x_shape[normalized_axis]
Expand All @@ -569,8 +571,8 @@ def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, Any]]:
broadcast_shape = [1] * rank
broadcast_shape[normalized_axis] = normalized_dim

scale_value = combo["Scale"].value
bias_value = combo["B"].value
scale_value = cast("InputValueConstraint", combo["Scale"]).value
bias_value = cast("InputValueConstraint", combo["B"]).value
new_scale = np.ones((normalized_dim,), dtype=scale_value.dtype).reshape(
broadcast_shape
)
Expand Down
Loading
Loading