diff --git a/src/winml/modelkit/analyze/analyzer.py b/src/winml/modelkit/analyze/analyzer.py index 60d8dac01..779b4a52e 100644 --- a/src/winml/modelkit/analyze/analyzer.py +++ b/src/winml/modelkit/analyze/analyzer.py @@ -140,10 +140,7 @@ def _build_runtime_debug_details_summary( level_bucket[node_stable_key] = candidate_entry continue - if ( - existing_entry.case_indices is None - and candidate_entry.case_indices is not None - ): + if existing_entry.case_indices is None and candidate_entry.case_indices is not None: existing_entry.case_indices = candidate_entry.case_indices if existing_entry.table_path is None and candidate_entry.table_path is not None: @@ -798,7 +795,7 @@ def analyze_from_proto( if device is not None and device.lower() == "auto": from ..sysinfo import resolve_device - resolved, _ = resolve_device("auto") + resolved, _ = resolve_device("auto", ep=None) device_to_use = resolved.upper() logger.info("Device 'auto' resolved to: %s", device_to_use) else: diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index c3ffc660d..186d9fdb6 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -554,7 +554,7 @@ def build( from ..sysinfo import resolve_eps as _resolve_eps try: - resolved_device, _ = _resolve_device(device=device) + resolved_device, _ = _resolve_device(device=device, ep=None) except ValueError as e: raise click.UsageError(str(e)) from e device = resolved_device @@ -579,6 +579,7 @@ def build( trust_remote_code=trust_remote_code, device=device, precision=precision, + ep=ep, ) if not quant: config_or_configs.quant = None diff --git a/src/winml/modelkit/commands/eval.py b/src/winml/modelkit/commands/eval.py index 414153b09..d7c6aa448 100644 --- a/src/winml/modelkit/commands/eval.py +++ b/src/winml/modelkit/commands/eval.py @@ -364,7 +364,7 @@ def _resolve_device(cfg: WinMLEvaluationConfig) -> None: console = Console(stderr=True) console.print("[bold]Detecting available devices...[/bold]") - resolved, _ = resolve_device(cfg.device) + resolved, _ = resolve_device(cfg.device, ep=None) cfg.device = resolved console.print(f"[dim]Using device:[/dim] {resolved}") diff --git a/src/winml/modelkit/compiler/stages/compile.py b/src/winml/modelkit/compiler/stages/compile.py index 4bc1c28c4..da761882f 100644 --- a/src/winml/modelkit/compiler/stages/compile.py +++ b/src/winml/modelkit/compiler/stages/compile.py @@ -176,7 +176,7 @@ def _compile_multiple(self, context: CompileContext) -> None: sess_options = context.shared_session_options if sess_options is None: register_execution_providers(ort=True) - resolved_device, _ = resolve_device(context.config.get("device", "auto")) + resolved_device, _ = resolve_device(context.config.get("device", "auto"), ep=None) ep = normalize_ep_name(ep_config.provider) or resolve_eps(resolved_device)[0] device_type = DEVICE_TO_DEVICE_TYPE.get(resolved_device.upper()) diff --git a/src/winml/modelkit/sysinfo/device.py b/src/winml/modelkit/sysinfo/device.py index 6f6fdb5b9..05e714102 100644 --- a/src/winml/modelkit/sysinfo/device.py +++ b/src/winml/modelkit/sysinfo/device.py @@ -145,9 +145,9 @@ def _get_available_eps() -> frozenset[EPName]: def resolve_device( - device: str = "auto", + device: str, *, - ep: EPNameOrAlias | None = None, + ep: EPNameOrAlias | None, ) -> tuple[str, list[str]]: """Resolve target device with EP availability cross-check. @@ -233,7 +233,7 @@ def resolve_eps(resolved_device: str) -> list[EPName]: def resolve_check_device_ep( - *, device: str = "auto", ep: EPNameOrAlias | None = None + *, device: str, ep: EPNameOrAlias | None ) -> tuple[str, list[str], list[EPName]]: """Resolve or check that the requested device and/or EP combination is valid, raising if not. diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index 00f54fc23..ffa29bb7c 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -1711,3 +1711,34 @@ def test_returns_compiled_path_when_file_exists( # current_path should be updated to compiled_path assert result == compiled_path + + +class TestBuildEpResolution: + """--ep forwarding into config generation + the compile EP-availability gate.""" + + def _base_args(self, cfg: str, tmp_path: Path) -> list[str]: + return ["-c", cfg, "-m", "microsoft/resnet-50", "-o", str(tmp_path / "out")] + + def test_ep_forwarded_to_generate_build_config( + self, tmp_path: Path, mock_run_single_build: MagicMock + ): + """On the auto-config path (-m, no -c), --ep reaches generate_build_config. + + Regression: the build command dropped --ep when auto-generating a config, + so the requested EP never influenced the generated config (it failed or + analyzed/compiled for the wrong EP). + """ + fake_cfg = MagicMock() + fake_cfg.compile = None # no compile -> EP-availability gate is skipped + with ( + patch("winml.modelkit.config.generate_build_config", return_value=fake_cfg) as mock_gen, + patch( + "winml.modelkit.commands.build._validate_loader_tasks_for_model", + return_value=None, + ), + ): + result = _invoke( + ["-m", "microsoft/resnet-50", "--ep", "openvino", "-o", str(tmp_path / "out")] + ) + assert result.exit_code == 0, result.output + assert mock_gen.call_args.kwargs["ep"] == "openvino" diff --git a/tests/unit/sysinfo/test_device.py b/tests/unit/sysinfo/test_device.py index aa730ea75..5062abd63 100644 --- a/tests/unit/sysinfo/test_device.py +++ b/tests/unit/sysinfo/test_device.py @@ -83,7 +83,7 @@ def test_no_npu_no_gpu(self) -> None: def test_returns_empty_when_enumeration_fails(self) -> None: """If EP enumeration raises, return empty tuple (no devices visible). - ``resolve_device("auto")`` is responsible for the CPU fallback when no + ``resolve_device("auto", ep=None)`` is responsible for the CPU fallback when no devices are reachable; ``_get_available_devices`` only reports what is actually registered. """ @@ -214,7 +214,7 @@ def test_resolve_device_auto_npu_with_ep(self) -> None: "cpu": ("CPUExecutionProvider",), } ): - device, available = resolve_device("auto") + device, available = resolve_device("auto", ep=None) assert device == "npu" assert available == ["npu", "gpu", "cpu"] @@ -227,7 +227,7 @@ def test_resolve_device_auto_npu_without_ep(self) -> None: "cpu": ("CPUExecutionProvider",), } ): - device, available = resolve_device("auto") + device, available = resolve_device("auto", ep=None) assert device == "gpu" assert available == ["gpu", "cpu"] @@ -235,7 +235,7 @@ def test_resolve_device_auto_npu_without_ep(self) -> None: def test_resolve_device_auto_cpu_fallback(self) -> None: """Auto mode: only CPU EP registered -> returns "cpu".""" with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}): - device, available = resolve_device("auto") + device, available = resolve_device("auto", ep=None) assert device == "cpu" assert available == ["cpu"] @@ -248,7 +248,7 @@ def test_resolve_device_explicit_valid(self) -> None: "cpu": ("CPUExecutionProvider",), } ): - device, available = resolve_device("gpu") + device, available = resolve_device("gpu", ep=None) assert device == "gpu" assert available == ["gpu", "cpu"] @@ -256,7 +256,7 @@ def test_resolve_device_explicit_valid(self) -> None: def test_resolve_device_explicit_invalid(self) -> None: """Unrecognized device "tpu" -> raises ValueError.""" with pytest.raises(ValueError, match="Unknown device 'tpu'"): - resolve_device("tpu") + resolve_device("tpu", ep=None) def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None: """Error message must name the compatible EPs so users know what to install.""" @@ -268,7 +268,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None: ), pytest.raises(ValueError) as exc_info, ): - resolve_device("npu") + resolve_device("npu", ep=None) message = str(exc_info.value) assert "no compatible EP" in message @@ -278,7 +278,7 @@ def test_resolve_device_explicit_no_ep_error_names_missing_eps(self) -> None: def test_resolve_device_case_insensitive(self) -> None: """Device argument should be case-insensitive.""" with _patch_device_ep_map({"cpu": ("CPUExecutionProvider",)}): - device, _ = resolve_device("CPU") + device, _ = resolve_device("CPU", ep=None) assert device == "cpu" @@ -293,7 +293,7 @@ def test_resolve_device_no_eps_raises(self) -> None: _patch_device_ep_map({}), pytest.raises(RuntimeError, match="No execution providers detected"), ): - resolve_device("auto") + resolve_device("auto", ep=None) class TestResolveDeviceWithEp: