diff --git a/examples/templates/inference_vllm_qwen3.yaml b/examples/templates/inference_vllm_qwen3.yaml new file mode 100644 index 0000000..1517732 --- /dev/null +++ b/examples/templates/inference_vllm_qwen3.yaml @@ -0,0 +1,40 @@ +apiVersion: flowmesh/v1 +kind: InferenceTask +metadata: + name: inference-vllm-qwen3 + annotations: + description: Text-generation inference with Qwen3-8B. + +spec: + taskType: inference + resources: + hardware: + cpu: 8 + memory: 16Gi + gpu: + type: any + count: 1 + memory: 24Gi + model: + source: + type: huggingface + identifier: Qwen/Qwen3-8B + revision: main + vllm: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.9 + trust_remote_code: true + data: + type: list + items: + - "Say hi in three words." + - "Name a primary color." + inference: + max_tokens: 16 + temperature: 0.0 + output: + destination: + type: local + artifacts: + - results.json + - logs diff --git a/src/worker/executors/__init__.py b/src/worker/executors/__init__.py index 6221865..1621045 100644 --- a/src/worker/executors/__init__.py +++ b/src/worker/executors/__init__.py @@ -100,7 +100,7 @@ def _import_executor(name: str, module: str) -> type[Executor] | None: "omni_text2general": "OmniText2GeneralExecutor", } -IMPORT_ERRORS: dict[str, str] = dict(_IMPORT_ERRORS) +IMPORT_ERRORS: dict[str, str] = _IMPORT_ERRORS __all__ = [ name diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 8fd9ec2..4074ebe 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -7,6 +7,7 @@ """ import gc +import importlib.util import json import logging import os @@ -22,6 +23,7 @@ from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import TaskSpecStrictBase from shared.utils.parsing import to_bool, to_int +from worker.config import WorkerConfig from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.inference import InferenceMixin @@ -40,6 +42,10 @@ logger = logging.getLogger(__name__) +# find_spec, not import: importing vllm_omni rebinds vllm.v1.request.Request +# and breaks plain-model warmup. +_HAS_OMNI: bool = importlib.util.find_spec("vllm_omni") is not None + class OmniResult(BaseExecutorResult): executor: str @@ -66,6 +72,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._omni_spec: tuple[Any, ...] | None = None self._stage_configs_tmp: Path | None = None + @classmethod + def is_available(cls, config: WorkerConfig) -> bool: + return _HAS_OMNI + def run(self, task: ExecutorTask, out_dir: Path) -> OmniResult: spec = self.require_spec(task, self._TASK_SPEC_TYPE) spec_dict = spec.model_dump(by_alias=True) diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index d1e4d9d..c4d9cae 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -16,40 +16,6 @@ np = None torch = None -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False - -try: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt - - _HAS_OMNI_DIFFUSION = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt - else: - OmniDiffusionSamplingParams = None - OmniTextPrompt = None - _HAS_OMNI_DIFFUSION = False - -try: - from vllm_omni.platforms import current_omni_platform - - _HAS_OMNI_PLATFORM = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.platforms import current_omni_platform - else: - current_omni_platform = None - _HAS_OMNI_PLATFORM = False - from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase @@ -58,7 +24,12 @@ from shared.utils.parsing import to_float, to_int from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase, OmniResult, extract_multimodal_output +from .omni_executor_base import ( + _HAS_OMNI, + OmniExecutorBase, + OmniResult, + extract_multimodal_output, +) logger = logging.getLogger(__name__) EXECUTOR_NAME = "omni_text2audio" @@ -86,7 +57,7 @@ def prepare(self) -> None: raise ExecutionError("omni_text2audio requires torch.") if np is None: raise ExecutionError("omni_text2audio requires numpy.") - if not (_HAS_OMNI and _HAS_OMNI_DIFFUSION): + if not _HAS_OMNI: raise ExecutionError( "vllm_omni is not installed; cannot use omni_text2audio executor." ) @@ -99,6 +70,8 @@ def _run_inner( out_dir: Path, ) -> OmniText2AudioResult: assert isinstance(spec, OmniText2AudioSpecStrict) + from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt + prompts = self._collect_text_inputs(spec, task.task_id) cfg = _bgm_cfg(spec_dict) @@ -229,6 +202,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + model_name = self.resolve_model_identifier( spec_dict, _bgm_cfg(spec_dict), @@ -261,12 +236,15 @@ def _bgm_cfg(spec_dict: dict[str, Any]) -> dict[str, Any]: def _resolve_generator_device() -> str: - if _HAS_OMNI_PLATFORM and current_omni_platform is not None: - device_type = ( - str(getattr(current_omni_platform, "device_type", "")).strip().lower() - ) - if device_type in {"cuda", "cpu", "mps", "xpu", "hpu", "npu"}: - return device_type + if _HAS_OMNI: + from vllm_omni.platforms import current_omni_platform + + if current_omni_platform is not None: + device_type = ( + str(getattr(current_omni_platform, "device_type", "")).strip().lower() + ) + if device_type in {"cuda", "cpu", "mps", "xpu", "hpu", "npu"}: + return device_type if torch is not None and torch.cuda.is_available(): return "cuda" return "cpu" diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index 04cde5c..9d09459 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -15,26 +15,17 @@ SamplingParams = None _HAS_VLLM = False -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False - from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2GeneralSpecStrict from shared.tasks.task_type import TaskType from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list +from worker.config import WorkerConfig from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import ( + _HAS_OMNI, OmniExecutorBase, OmniResult, extract_audio_from_mm, @@ -66,6 +57,10 @@ class OmniText2GeneralExecutor(OmniExecutorBase): supported_task_types = frozenset({TaskType.OMNI_TEXT2GENERAL}) _TASK_SPEC_TYPE = OmniText2GeneralSpecStrict + @classmethod + def is_available(cls, config: WorkerConfig) -> bool: + return _HAS_OMNI and _HAS_VLLM + def prepare(self) -> None: if not _HAS_OMNI: raise ExecutionError( @@ -207,6 +202,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + cfg = _narration_cfg(spec_dict) model_name = self.resolve_model_identifier( spec_dict, diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 75f7de4..0265bac 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -2,18 +2,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any - -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False +from typing import Any from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType @@ -23,7 +12,7 @@ from shared.utils.parsing import as_list from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase, OmniResult +from .omni_executor_base import _HAS_OMNI, OmniExecutorBase, OmniResult logger = logging.getLogger(__name__) EXECUTOR_NAME = "omni_text2image" @@ -110,6 +99,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") model_name = self.resolve_model_identifier( spec_dict, diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 911a115..966637b 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -3,18 +3,7 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING, Any - -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False +from typing import Any from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType @@ -25,6 +14,7 @@ from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import ( + _HAS_OMNI, OmniExecutorBase, OmniResult, extract_audio_from_mm, @@ -124,6 +114,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") model_name = self.resolve_model_identifier( spec_dict, diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py index e2a7cf5..62563bb 100644 --- a/src/worker/executors/vllm_executor.py +++ b/src/worker/executors/vllm_executor.py @@ -11,6 +11,7 @@ import copy import datetime import gc +import importlib.metadata import json import logging import os @@ -157,6 +158,20 @@ def _detect_available_gpus() -> int: pass return 1 + @staticmethod + def _vllm_plugins_allowlist_excluding_omni() -> str: + """Every `vllm.general_plugins` entry point except vllm-omni's, comma-joined. + + Omni's plugin rebinds vllm.v1.request.Request and crashes plain LLM init. + """ + names: list[str] = [] + for ep in importlib.metadata.entry_points(group="vllm.general_plugins"): + module = ep.value.split(":", 1)[0].strip() + if module == "vllm_omni" or module.startswith("vllm_omni."): + continue + names.append(ep.name) + return ",".join(names) + @staticmethod def _compute_safe_utilization(requested: float) -> tuple[float, float | None]: if torch is None or not torch.cuda.is_available(): @@ -372,6 +387,19 @@ def _ensure_llm( "Ignored unrecognized vLLM config fields: %s", list(vllm_cfg.keys()) ) + # Required workaround: vLLM's load_general_plugins() auto-loads the + # vllm_omni_register_models entry point, which imports vllm_omni and + # rebinds vllm.v1.request.Request → OmniRequest, crashing plain-model + # warmup. Revert once vllm_omni ships an OmniRequest that is a drop-in + # for vllm.v1.request.Request. + if "VLLM_PLUGINS" not in os.environ: + os.environ["VLLM_PLUGINS"] = self._vllm_plugins_allowlist_excluding_omni() + else: + logger.debug( + "VLLM_PLUGINS already set (%r); skipping omni exclusion", + os.environ["VLLM_PLUGINS"], + ) + tp_candidates: list[int] = [] seen_tp: set[int] = set() initial_tp = max(1, tensor_parallel_size) @@ -731,7 +759,7 @@ def _flatten_image_embedding_chunks( ) -> torch.Tensor: if not embedding_chunks: raise ExecutionError( - "Loaded image embedding list must be non-empty " f"(task={task_id})." + f"Loaded image embedding list must be non-empty (task={task_id})." ) if group_sizes is not None: if len(embedding_chunks) != len(group_sizes): diff --git a/src/worker/main.py b/src/worker/main.py index 52d2a8b..21b389c 100644 --- a/src/worker/main.py +++ b/src/worker/main.py @@ -1,6 +1,7 @@ import argparse import logging import signal +from collections.abc import Mapping from shared.schemas.worker import WorkerCapabilities from shared.tasks.task_type import TaskType @@ -59,7 +60,7 @@ def initialize_executors( hardware: WorkerHardware, logger: logging.Logger, lifecycle: Lifecycle, - registry: dict[str, type[Executor] | None] | None = None, + registry: Mapping[str, type[Executor] | None] | None = None, import_errors: dict[str, str] | None = None, cuda_available: bool | None = None, enable_mp_executors: bool = True, @@ -167,7 +168,7 @@ def init_executor(key: str, *, gpu_required: bool = False): def build_capabilities( executors: dict[str, Executor], - registry: dict[str, type[Executor] | None] | None = None, + registry: Mapping[str, type[Executor] | None] | None = None, ) -> WorkerCapabilities: registry = registry or EXECUTOR_REGISTRY supported_task_types = frozenset[TaskType]().union( diff --git a/tests/worker/test_executor_registry.py b/tests/worker/test_executor_registry.py index 361d28b..145cf8f 100644 --- a/tests/worker/test_executor_registry.py +++ b/tests/worker/test_executor_registry.py @@ -1,5 +1,6 @@ """Tests for the executor registry and safe import mechanism.""" +from collections.abc import Mapping from pathlib import Path from types import SimpleNamespace @@ -65,7 +66,7 @@ def test_training_executors_are_wrapped_for_isolation(self) -> None: def test_import_executor_does_not_crash(self) -> None: """The registry should load without raising, even when deps are missing.""" # If we got here, the import at module level already succeeded - assert isinstance(EXECUTOR_REGISTRY, dict) + assert isinstance(EXECUTOR_REGISTRY, Mapping) assert isinstance(IMPORT_ERRORS, dict) def test_import_executor_rejects_non_executor( diff --git a/tests/worker/test_vllm_scoped_plugins.py b/tests/worker/test_vllm_scoped_plugins.py new file mode 100644 index 0000000..566a298 --- /dev/null +++ b/tests/worker/test_vllm_scoped_plugins.py @@ -0,0 +1,160 @@ +"""Tests for VLLMExecutor's vllm-omni plugin exclusion via VLLM_PLUGINS.""" + +import importlib.metadata +import os +import pathlib +import subprocess +import sys + +import pytest + +pytest.importorskip("vllm", reason="vllm not installed (needs --extra inference-gpu)") + +from worker.executors import EXECUTOR_REGISTRY +from worker.executors.vllm_executor import VLLMExecutor # noqa: E402 + +_OMNI_EP_NAME = "vllm_omni_register_models" +_OMNI_EP_VALUE = "vllm_omni.engine.arg_utils:register_omni_models_to_vllm" + + +def _ep(name: str, value: str) -> importlib.metadata.EntryPoint: + return importlib.metadata.EntryPoint( + name=name, value=value, group="vllm.general_plugins" + ) + + +def _mock_eps( + monkeypatch: pytest.MonkeyPatch, eps: list[importlib.metadata.EntryPoint] +) -> None: + monkeypatch.setattr( + importlib.metadata, + "entry_points", + lambda group: eps if group == "vllm.general_plugins" else [], + ) + + +class TestPluginAllowlist: + def test_excludes_omni_keeps_others(self, monkeypatch: pytest.MonkeyPatch) -> None: + discovered = [ + _ep(_OMNI_EP_NAME, _OMNI_EP_VALUE), + _ep("some_other_plugin", "some_pkg.plugins:register"), + ] + _mock_eps(monkeypatch, discovered) + result = VLLMExecutor._vllm_plugins_allowlist_excluding_omni() + assert result == "some_other_plugin" + + def test_empty_when_only_omni(self, monkeypatch: pytest.MonkeyPatch) -> None: + _mock_eps(monkeypatch, [_ep(_OMNI_EP_NAME, _OMNI_EP_VALUE)]) + # "" is a valid allowlist (load none); NOT None which would mean load-all. + assert VLLMExecutor._vllm_plugins_allowlist_excluding_omni() == "" + + def test_filters_by_module_not_ep_name( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + discovered = [ + _ep("renamed_entry", "vllm_omni.something.else:fn"), + _ep("keep", "another_pkg:fn"), + ] + _mock_eps(monkeypatch, discovered) + assert VLLMExecutor._vllm_plugins_allowlist_excluding_omni() == "keep" + + def test_real_install_does_not_allowlist_omni(self) -> None: + eps = importlib.metadata.entry_points(group="vllm.general_plugins") + if not any( + ep.value.split(":", 1)[0].strip().startswith("vllm_omni") for ep in eps + ): + pytest.skip("vllm_omni not a registered vllm.general_plugins entry") + names = VLLMExecutor._vllm_plugins_allowlist_excluding_omni().split(",") + assert _OMNI_EP_NAME not in names + + +class TestAllowlistHelperIsPure: + def test_helper_does_not_mutate_env(self) -> None: + before = os.environ.get("VLLM_PLUGINS") + VLLMExecutor._vllm_plugins_allowlist_excluding_omni() + assert os.environ.get("VLLM_PLUGINS") == before + + +_OMNI_MODULE_PREFIX = "vllm_omni" + + +class TestPluginScopingPreventsRebind: + def test_allowlist_keeps_vllm_omni_unloaded(self) -> None: + """VLLM_PLUGINS fix allowlist keeps vllm_omni unloaded and Request unpatched.""" + + def _is_omni_ep(ep: importlib.metadata.EntryPoint) -> bool: + return ep.value.split(":", 1)[0].strip().startswith(_OMNI_MODULE_PREFIX) + + eps = importlib.metadata.entry_points(group="vllm.general_plugins") + if not any(_is_omni_ep(ep) for ep in eps): + pytest.skip("vllm_omni not a registered vllm.general_plugins entry") + + allowlist = VLLMExecutor._vllm_plugins_allowlist_excluding_omni() + script = "\n".join( + [ + "import sys, os", + f"os.environ['VLLM_PLUGINS'] = {allowlist!r}", + "try:", + " from vllm.plugins import load_general_plugins", + " load_general_plugins()", + "except Exception:", + " pass", + "import vllm.v1.request as R", + "print('omni_loaded:', 'vllm_omni' in sys.modules)", + "print('request_module:', R.Request.__module__)", + ] + ) + result = subprocess.run( # nosec B603 — fixed argv list, no shell=True, sys.executable + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0, result.stderr + assert "omni_loaded: False" in result.stdout + assert "request_module: vllm.v1.request" in result.stdout + + +def _subprocess_env() -> dict[str, str]: + """Build an env with src/ on PYTHONPATH so subprocess can import worker.*.""" + src_dir = (pathlib.Path(__file__).parents[2] / "src").as_posix() + existing = os.environ.get("PYTHONPATH", "") + pythonpath = f"{src_dir}:{existing}" if existing else src_dir + return {**os.environ, "PYTHONPATH": pythonpath} + + +class TestLazyOmniRegistry: + """Regression tests: import worker.executors must not pull in vllm_omni.""" + + _OMNI_KEYS = ( + "omni_text2image", + "omni_text2speech", + "omni_text2audio", + "omni_text2general", + ) + + def test_executor_registry_has_omni_keys(self) -> None: + for key in self._OMNI_KEYS: + assert key in EXECUTOR_REGISTRY + + def test_import_does_not_load_vllm_omni(self) -> None: + """import worker.executors alone must not load vllm_omni or patch Request.""" + script = "\n".join( + [ + "import sys", + "import worker.executors", + "import vllm.v1.request as R", + "print('omni_in_modules:', 'vllm_omni' in sys.modules)", + "print('request_class:', R.Request.__name__)", + ] + ) + result = subprocess.run( # nosec B603 — fixed argv list, no shell=True, sys.executable + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=60, + env=_subprocess_env(), + ) + assert result.returncode == 0, result.stderr + assert "omni_in_modules: False" in result.stdout + assert "request_class: Request" in result.stdout