Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/winml/modelkit/analyze/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,7 @@ def _build_runtime_debug_details_summary(
level_bucket[node_stable_key] = candidate_entry
continue

if (
existing_entry.case_indices is None
and candidate_entry.case_indices is not None
):
if existing_entry.case_indices is None and candidate_entry.case_indices is not None:
existing_entry.case_indices = candidate_entry.case_indices

if existing_entry.table_path is None and candidate_entry.table_path is not None:
Expand Down Expand Up @@ -798,7 +795,7 @@ def analyze_from_proto(
if device is not None and device.lower() == "auto":
from ..sysinfo import resolve_device

resolved, _ = resolve_device("auto")
resolved, _ = resolve_device("auto", ep=None)
device_to_use = resolved.upper()
logger.info("Device 'auto' resolved to: %s", device_to_use)
else:
Expand Down
3 changes: 2 additions & 1 deletion src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def build(
from ..sysinfo import resolve_eps as _resolve_eps

try:
resolved_device, _ = _resolve_device(device=device)
resolved_device, _ = _resolve_device(device=device, ep=None)
except ValueError as e:
raise click.UsageError(str(e)) from e
device = resolved_device
Expand All @@ -579,6 +579,7 @@ def build(
trust_remote_code=trust_remote_code,
device=device,
precision=precision,
ep=ep,
)
if not quant:
config_or_configs.quant = None
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _resolve_device(cfg: WinMLEvaluationConfig) -> None:

console = Console(stderr=True)
console.print("[bold]Detecting available devices...[/bold]")
resolved, _ = resolve_device(cfg.device)
resolved, _ = resolve_device(cfg.device, ep=None)
cfg.device = resolved
console.print(f"[dim]Using device:[/dim] {resolved}")

Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/compiler/stages/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _compile_multiple(self, context: CompileContext) -> None:
sess_options = context.shared_session_options
if sess_options is None:
register_execution_providers(ort=True)
resolved_device, _ = resolve_device(context.config.get("device", "auto"))
resolved_device, _ = resolve_device(context.config.get("device", "auto"), ep=None)
ep = normalize_ep_name(ep_config.provider) or resolve_eps(resolved_device)[0]
device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper())

Expand Down
6 changes: 3 additions & 3 deletions src/winml/modelkit/sysinfo/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def _get_available_eps() -> frozenset[EPName]:


def resolve_device(
device: str = "auto",
device: str,
*,
ep: EPNameOrAlias | None = None,
ep: EPNameOrAlias | None,
) -> tuple[str, list[str]]:
"""Resolve target device with EP availability cross-check.

Expand Down Expand Up @@ -233,7 +233,7 @@ def resolve_eps(resolved_device: str) -> list[EPName]:


def resolve_check_device_ep(
*, device: str = "auto", ep: EPNameOrAlias | None = None
*, device: str, ep: EPNameOrAlias | None
) -> tuple[str, list[str], list[EPName]]:
"""Resolve or check that the requested device and/or EP combination is valid, raising if not.

Expand Down
31 changes: 31 additions & 0 deletions tests/unit/commands/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,3 +1711,34 @@ def test_returns_compiled_path_when_file_exists(

# current_path should be updated to compiled_path
assert result == compiled_path


class TestBuildEpResolution:
    """Checks that ``--ep`` is forwarded into config generation + the compile EP-availability gate."""

    def _base_args(self, cfg: str, tmp_path: Path) -> list[str]:
        """Return CLI arguments passing ``cfg`` via ``-c``, the resnet-50 model, and an ``out`` dir."""
        output_dir = tmp_path / "out"
        return ["-c", cfg, "-m", "microsoft/resnet-50", "-o", str(output_dir)]

    def test_ep_forwarded_to_generate_build_config(
        self, tmp_path: Path, mock_run_single_build: MagicMock
    ):
        """``--ep`` must arrive at ``generate_build_config`` on the auto-config path.

        The auto-config path is the one taken with ``-m`` and without ``-c``.

        Regression: the build command used to drop ``--ep`` while auto-generating
        a config, so the requested EP had no effect on the generated config (the
        build failed, or analyzed/compiled for the wrong EP).
        """
        # ``mock_run_single_build`` is requested as a fixture only; the test body
        # never references it directly.
        stub_config = MagicMock()
        # With no compile section the EP-availability gate is skipped.
        stub_config.compile = None

        cli_args = ["-m", "microsoft/resnet-50", "--ep", "openvino", "-o", str(tmp_path / "out")]

        generate_patch = patch(
            "winml.modelkit.config.generate_build_config",
            return_value=stub_config,
        )
        validate_patch = patch(
            "winml.modelkit.commands.build._validate_loader_tasks_for_model",
            return_value=None,
        )
        with generate_patch as generate_mock, validate_patch:
            outcome = _invoke(cli_args)

        assert outcome.exit_code == 0, outcome.output
        assert generate_mock.call_args.kwargs["ep"] == "openvino"
18 changes: 9 additions & 9 deletions tests/unit/sysinfo/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_no_npu_no_gpu(self) -> None:
def test_returns_empty_when_enumeration_fails(self) -> None:
"""If EP enumeration raises, return empty tuple (no devices visible).

``resolve_device("auto")`` is responsible for the CPU fallback when no
``resolve_device("auto", ep=None)`` is responsible for the CPU fallback when no
devices are reachable; ``_get_available_devices`` only reports what is
actually registered.
"""
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_resolve_device_auto_npu_with_ep(self) -> None:
"cpu": ("CPUExecutionProvider",),
}
):
device, available = resolve_device("auto")
device, available = resolve_device("auto", ep=None)

assert device == "npu"
assert available == ["npu", "gpu", "cpu"]
Expand All @@ -227,15 +227,15 @@ def test_resolve_device_auto_npu_without_ep(self) -> None:
"cpu": ("CPUExecutionProvider",),
}
):
device, available = resolve_device("auto")
device, available = resolve_device("auto", ep=None)

assert device == "gpu"
assert available == ["gpu", "cpu"]

def test_resolve_device_auto_cpu_fallback(self) -> None:
"""Auto mode: only CPU EP registered -> returns "cpu"."""
with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}):
device, available = resolve_device("auto")
device, available = resolve_device("auto", ep=None)

assert device == "cpu"
assert available == ["cpu"]
Expand All @@ -248,15 +248,15 @@ def test_resolve_device_explicit_valid(self) -> None:
"cpu": ("CPUExecutionProvider",),
}
):
device, available = resolve_device("gpu")
device, available = resolve_device("gpu", ep=None)

assert device == "gpu"
assert available == ["gpu", "cpu"]

def test_resolve_device_explicit_invalid(self) -> None:
"""Unrecognized device "tpu" -> raises ValueError."""
with pytest.raises(ValueError, match="Unknown device 'tpu'"):
resolve_device("tpu")
resolve_device("tpu", ep=None)

def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None:
"""Error message must name the compatible EPs so users know what to install."""
Expand All @@ -268,7 +268,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None:
),
pytest.raises(ValueError) as exc_info,
):
resolve_device("npu")
resolve_device("npu", ep=None)

message = str(exc_info.value)
assert "no compatible EP" in message
Expand All @@ -278,7 +278,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None:
def test_resolve_device_case_insensitive(self) -> None:
"""Device argument should be case-insensitive."""
with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}):
device, _ = resolve_device("CPU")
device, _ = resolve_device("CPU", ep=None)

assert device == "cpu"

Expand All @@ -293,7 +293,7 @@ def test_resolve_device_no_eps_raises(self) -> None:
_patch_device_ep_map({}),
pytest.raises(RuntimeError, match="No execution providers detected"),
):
resolve_device("auto")
resolve_device("auto", ep=None)


class TestResolveDeviceWithEp:
Expand Down
Loading