From 25ac41a7cc6639ed96840d17a3539b2491416e30 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Fri, 19 Jun 2026 17:00:11 +0700 Subject: [PATCH 1/3] Workaround to support Qwen3-8B Signed-off-by: Zhengyuan Su --- examples/templates/inference_vllm_qwen3.yaml | 42 +++++ src/worker/executors/__init__.py | 105 +++++++----- src/worker/executors/vllm_executor.py | 31 +++- src/worker/main.py | 5 +- tests/worker/test_executor_registry.py | 3 +- tests/worker/test_vllm_scoped_plugins.py | 165 +++++++++++++++++++ 6 files changed, 308 insertions(+), 43 deletions(-) create mode 100644 examples/templates/inference_vllm_qwen3.yaml create mode 100644 tests/worker/test_vllm_scoped_plugins.py diff --git a/examples/templates/inference_vllm_qwen3.yaml b/examples/templates/inference_vllm_qwen3.yaml new file mode 100644 index 0000000..b6cc447 --- /dev/null +++ b/examples/templates/inference_vllm_qwen3.yaml @@ -0,0 +1,42 @@ +apiVersion: flowmesh/v1 +kind: InferenceTask +metadata: + name: inference-vllm-qwen3 + owner: timzsu + annotations: + description: > + Text-generation inference with Qwen3-8B. + +spec: + taskType: inference + resources: + hardware: + cpu: 8 + memory: 16Gi + gpu: + type: any + count: 1 + memory: 24Gi # Qwen3-8B BF16 weights + KV; 8Gi is too small + model: + source: + type: huggingface + identifier: Qwen/Qwen3-8B + revision: main + vllm: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.9 + trust_remote_code: true + data: + type: list + items: + - "Say hi in three words." + - "Name a primary color." + inference: + max_tokens: 16 + temperature: 0.0 + output: + destination: + type: local + artifacts: + - results.json + - logs diff --git a/src/worker/executors/__init__.py b/src/worker/executors/__init__.py index 6221865..5934952 100644 --- a/src/worker/executors/__init__.py +++ b/src/worker/executors/__init__.py @@ -1,4 +1,5 @@ import importlib +from collections.abc import Iterator, Mapping from .base_executor import Executor @@ -41,42 +42,71 @@ def _import_executor(name: str, module: str) -> type[Executor] | None: DiffusersExecutor = _import_executor("DiffusersExecutor", ".diffusers_executor") APIExecutor = _import_executor("APIExecutor", ".api_executor") SSHExecutor = _import_executor("SSHExecutor", ".ssh_executor") -OmniText2ImageExecutor = _import_executor( - "OmniText2ImageExecutor", ".omni_text2image_executor" -) -OmniText2SpeechExecutor = _import_executor( - "OmniText2SpeechExecutor", ".omni_text2speech_executor" -) -OmniText2AudioExecutor = _import_executor( - "OmniText2AudioExecutor", ".omni_text2audio_executor" -) -OmniText2GeneralExecutor = _import_executor( - "OmniText2GeneralExecutor", ".omni_text2general_executor" -) - -EXECUTOR_REGISTRY: dict[str, type[Executor] | None] = { - "vllm": VLLMExecutor, - "vllm_lora": VLLMLoRAExecutor, - "ppo": PPOExecutor, - "dpo": DPOExecutor, - "sft": SFTExecutor, - "lora_sft": LoRASFTExecutor, - "image_classification_training": ImageClassificationTrainingExecutor, - "default": HFTransformersExecutor, - "rag": RAGExecutor, - "agent": AgentExecutor, - "echo": EchoExecutor, - "data_profiling": DataProfilingExecutor, - "data_retrieval": DataRetrievalExecutor, - "diffusers": DiffusersExecutor, - "api": APIExecutor, - "ssh": SSHExecutor, - "omni_text2image": OmniText2ImageExecutor, - "omni_text2speech": OmniText2SpeechExecutor, - "omni_text2audio": OmniText2AudioExecutor, - "omni_text2general": OmniText2GeneralExecutor, +_OMNI_SPECS: dict[str, tuple[str, str]] = { + "omni_text2image": ("OmniText2ImageExecutor", ".omni_text2image_executor"), + "omni_text2speech": ("OmniText2SpeechExecutor", ".omni_text2speech_executor"), + "omni_text2audio": ("OmniText2AudioExecutor", ".omni_text2audio_executor"), + "omni_text2general": ("OmniText2GeneralExecutor", ".omni_text2general_executor"), } + +class _LazyRegistry(Mapping[str, type[Executor] | None]): + """Mapping that defers omni executor imports to first key lookup. + + Backed by a plain dict; omni entries are resolved (their module imported) + on the first __getitem__ / get() access via the Mapping mixin. + """ + + def __init__( + self, + eager: dict[str, type[Executor] | None], + lazy: dict[str, tuple[str, str]], + ) -> None: + self._data: dict[str, type[Executor] | None] = dict(eager) + self._lazy: dict[str, tuple[str, str]] = dict(lazy) + for key in lazy: + self._data[key] = None + + def _resolve(self, key: str) -> type[Executor] | None: + cls_name, module = self._lazy.pop(key) + cls = _import_executor(cls_name, module) + self._data[key] = cls + return cls + + def __getitem__(self, key: str) -> type[Executor] | None: + if key in self._lazy: + return self._resolve(key) + return self._data[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + +EXECUTOR_REGISTRY: Mapping[str, type[Executor] | None] = _LazyRegistry( + { + "vllm": VLLMExecutor, + "vllm_lora": VLLMLoRAExecutor, + "ppo": PPOExecutor, + "dpo": DPOExecutor, + "sft": SFTExecutor, + "lora_sft": LoRASFTExecutor, + "image_classification_training": ImageClassificationTrainingExecutor, + "default": HFTransformersExecutor, + "rag": RAGExecutor, + "agent": AgentExecutor, + "echo": EchoExecutor, + "data_profiling": DataProfilingExecutor, + "data_retrieval": DataRetrievalExecutor, + "diffusers": DiffusersExecutor, + "api": APIExecutor, + "ssh": SSHExecutor, + }, + _OMNI_SPECS, +) + EXECUTOR_CLASS_NAMES: dict[str, str] = { "vllm": "VLLMExecutor", "vllm_lora": "VLLMLoRAExecutor", @@ -100,7 +130,8 @@ def _import_executor(name: str, module: str) -> type[Executor] | None: "omni_text2general": "OmniText2GeneralExecutor", } -IMPORT_ERRORS: dict[str, str] = dict(_IMPORT_ERRORS) +# Live reference so errors from lazy imports appear immediately +IMPORT_ERRORS: dict[str, str] = _IMPORT_ERRORS __all__ = [ name @@ -121,10 +152,6 @@ def _import_executor(name: str, module: str) -> type[Executor] | None: "DiffusersExecutor": DiffusersExecutor, "APIExecutor": APIExecutor, "SSHExecutor": SSHExecutor, - "OmniText2ImageExecutor": OmniText2ImageExecutor, - "OmniText2SpeechExecutor": OmniText2SpeechExecutor, - "OmniText2AudioExecutor": OmniText2AudioExecutor, - "OmniText2GeneralExecutor": OmniText2GeneralExecutor, }.items() if cls is not None ] + ["EXECUTOR_REGISTRY", "IMPORT_ERRORS", "EXECUTOR_CLASS_NAMES"] diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py index e2a7cf5..f064000 100644 --- a/src/worker/executors/vllm_executor.py +++ b/src/worker/executors/vllm_executor.py @@ -11,6 +11,7 @@ import copy import datetime import gc +import importlib.metadata import json import logging import os @@ -157,6 +158,20 @@ def _detect_available_gpus() -> int: pass return 1 + @staticmethod + def _vllm_plugins_allowlist_excluding_omni() -> str: + """Every `vllm.general_plugins` entry point except vllm-omni's, comma-joined. + + Omni's plugin rebinds vllm.v1.request.Request and crashes plain LLM init. + """ + names: list[str] = [] + for ep in importlib.metadata.entry_points(group="vllm.general_plugins"): + module = ep.value.split(":", 1)[0].strip() + if module == "vllm_omni" or module.startswith("vllm_omni."): + continue + names.append(ep.name) + return ",".join(names) + @staticmethod def _compute_safe_utilization(requested: float) -> tuple[float, float | None]: if torch is None or not torch.cuda.is_available(): @@ -372,6 +387,20 @@ def _ensure_llm( "Ignored unrecognized vLLM config fields: %s", list(vllm_cfg.keys()) ) + # Required workaround: vLLM's load_general_plugins() auto-loads the + # vllm_omni_register_models entry point, which imports vllm_omni and + # rebinds vllm.v1.request.Request → OmniRequest, crashing plain-model + # warmup. The lazy __init__.py fix only prevents the direct-import path; + # this allowlist stops the entry-point path. Revert once vllm_omni ships + # an OmniRequest that is a drop-in for vllm.v1.request.Request. + if "VLLM_PLUGINS" not in os.environ: + os.environ["VLLM_PLUGINS"] = self._vllm_plugins_allowlist_excluding_omni() + else: + logger.debug( + "VLLM_PLUGINS already set (%r); skipping omni exclusion", + os.environ["VLLM_PLUGINS"], + ) + tp_candidates: list[int] = [] seen_tp: set[int] = set() initial_tp = max(1, tensor_parallel_size) @@ -731,7 +760,7 @@ def _flatten_image_embedding_chunks( ) -> torch.Tensor: if not embedding_chunks: raise ExecutionError( - "Loaded image embedding list must be non-empty " f"(task={task_id})." + f"Loaded image embedding list must be non-empty (task={task_id})." ) if group_sizes is not None: if len(embedding_chunks) != len(group_sizes): diff --git a/src/worker/main.py b/src/worker/main.py index 52d2a8b..21b389c 100644 --- a/src/worker/main.py +++ b/src/worker/main.py @@ -1,6 +1,7 @@ import argparse import logging import signal +from collections.abc import Mapping from shared.schemas.worker import WorkerCapabilities from shared.tasks.task_type import TaskType @@ -59,7 +60,7 @@ def initialize_executors( hardware: WorkerHardware, logger: logging.Logger, lifecycle: Lifecycle, - registry: dict[str, type[Executor] | None] | None = None, + registry: Mapping[str, type[Executor] | None] | None = None, import_errors: dict[str, str] | None = None, cuda_available: bool | None = None, enable_mp_executors: bool = True, @@ -167,7 +168,7 @@ def init_executor(key: str, *, gpu_required: bool = False): def build_capabilities( executors: dict[str, Executor], - registry: dict[str, type[Executor] | None] | None = None, + registry: Mapping[str, type[Executor] | None] | None = None, ) -> WorkerCapabilities: registry = registry or EXECUTOR_REGISTRY supported_task_types = frozenset[TaskType]().union( diff --git a/tests/worker/test_executor_registry.py b/tests/worker/test_executor_registry.py index 361d28b..145cf8f 100644 --- a/tests/worker/test_executor_registry.py +++ b/tests/worker/test_executor_registry.py @@ -1,5 +1,6 @@ """Tests for the executor registry and safe import mechanism.""" +from collections.abc import Mapping from pathlib import Path from types import SimpleNamespace @@ -65,7 +66,7 @@ def test_training_executors_are_wrapped_for_isolation(self) -> None: def test_import_executor_does_not_crash(self) -> None: """The registry should load without raising, even when deps are missing.""" # If we got here, the import at module level already succeeded - assert isinstance(EXECUTOR_REGISTRY, dict) + assert isinstance(EXECUTOR_REGISTRY, Mapping) assert isinstance(IMPORT_ERRORS, dict) def test_import_executor_rejects_non_executor( diff --git a/tests/worker/test_vllm_scoped_plugins.py b/tests/worker/test_vllm_scoped_plugins.py new file mode 100644 index 0000000..94ea70b --- /dev/null +++ b/tests/worker/test_vllm_scoped_plugins.py @@ -0,0 +1,165 @@ +"""Tests for VLLMExecutor's vllm-omni plugin exclusion via VLLM_PLUGINS.""" + +import importlib.metadata +import os +import subprocess +import sys + +import pytest + +pytest.importorskip("vllm", reason="vllm not installed (needs --extra inference-gpu)") + +from worker.executors.vllm_executor import VLLMExecutor # noqa: E402 + +_OMNI_EP_NAME = "vllm_omni_register_models" +_OMNI_EP_VALUE = "vllm_omni.engine.arg_utils:register_omni_models_to_vllm" + + +def _ep(name: str, value: str) -> importlib.metadata.EntryPoint: + return importlib.metadata.EntryPoint( + name=name, value=value, group="vllm.general_plugins" + ) + + +def _mock_eps( + monkeypatch: pytest.MonkeyPatch, eps: list[importlib.metadata.EntryPoint] +) -> None: + monkeypatch.setattr( + importlib.metadata, + "entry_points", + lambda group: eps if group == "vllm.general_plugins" else [], + ) + + +class TestPluginAllowlist: + def test_excludes_omni_keeps_others(self, monkeypatch: pytest.MonkeyPatch) -> None: + discovered = [ + _ep(_OMNI_EP_NAME, _OMNI_EP_VALUE), + _ep("some_other_plugin", "some_pkg.plugins:register"), + ] + _mock_eps(monkeypatch, discovered) + result = VLLMExecutor._vllm_plugins_allowlist_excluding_omni() + assert result == "some_other_plugin" + + def test_empty_when_only_omni(self, monkeypatch: pytest.MonkeyPatch) -> None: + _mock_eps(monkeypatch, [_ep(_OMNI_EP_NAME, _OMNI_EP_VALUE)]) + # "" is a valid allowlist (load none); NOT None which would mean load-all. + assert VLLMExecutor._vllm_plugins_allowlist_excluding_omni() == "" + + def test_filters_by_module_not_ep_name( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + discovered = [ + _ep("renamed_entry", "vllm_omni.something.else:fn"), + _ep("keep", "another_pkg:fn"), + ] + _mock_eps(monkeypatch, discovered) + assert VLLMExecutor._vllm_plugins_allowlist_excluding_omni() == "keep" + + def test_real_install_does_not_allowlist_omni(self) -> None: + eps = importlib.metadata.entry_points(group="vllm.general_plugins") + if not any( + ep.value.split(":", 1)[0].strip().startswith("vllm_omni") for ep in eps + ): + pytest.skip("vllm_omni not a registered vllm.general_plugins entry") + names = VLLMExecutor._vllm_plugins_allowlist_excluding_omni().split(",") + assert _OMNI_EP_NAME not in names + + +class TestAllowlistHelperIsPure: + def test_helper_does_not_mutate_env(self) -> None: + before = os.environ.get("VLLM_PLUGINS") + VLLMExecutor._vllm_plugins_allowlist_excluding_omni() + assert os.environ.get("VLLM_PLUGINS") == before + + +_OMNI_MODULE_PREFIX = "vllm_omni" + + +class TestPluginScopingPreventsRebind: + def test_allowlist_keeps_vllm_omni_unloaded(self) -> None: + """VLLM_PLUGINS fix allowlist keeps vllm_omni unloaded and Request unpatched.""" + + def _is_omni_ep(ep: importlib.metadata.EntryPoint) -> bool: + return ep.value.split(":", 1)[0].strip().startswith(_OMNI_MODULE_PREFIX) + + eps = importlib.metadata.entry_points(group="vllm.general_plugins") + if not any(_is_omni_ep(ep) for ep in eps): + pytest.skip( + "vllm_omni not a registered vllm.general_plugins entry" + ) # noqa: E501 + + allowlist = VLLMExecutor._vllm_plugins_allowlist_excluding_omni() + script = "\n".join( + [ + "import sys, os", + f"os.environ['VLLM_PLUGINS'] = {allowlist!r}", + "try:", + " from vllm.plugins import load_general_plugins", + " load_general_plugins()", + "except Exception:", + " pass", + "import vllm.v1.request as R", + "print('omni_loaded:', 'vllm_omni' in sys.modules)", + "print('request_module:', R.Request.__module__)", + ] + ) + result = subprocess.run( # nosec B603 — fixed argv list, no shell=True, sys.executable + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=30, + ) + assert result.returncode == 0, result.stderr + assert "omni_loaded: False" in result.stdout + assert "request_module: vllm.v1.request" in result.stdout + + +def _subprocess_env() -> dict[str, str]: + """Build an env with src/ on PYTHONPATH so subprocess can import worker.*.""" + import os + import pathlib + + src_dir = str(pathlib.Path(__file__).parents[2] / "src") + existing = os.environ.get("PYTHONPATH", "") + pythonpath = f"{src_dir}:{existing}" if existing else src_dir + return {**os.environ, "PYTHONPATH": pythonpath} + + +class TestLazyOmniRegistry: + """Regression tests: import worker.executors must not pull in vllm_omni.""" + + _OMNI_KEYS = ( + "omni_text2image", + "omni_text2speech", + "omni_text2audio", + "omni_text2general", + ) + + def test_executor_registry_has_omni_keys(self) -> None: + from worker.executors import EXECUTOR_REGISTRY + + for key in self._OMNI_KEYS: + assert key in EXECUTOR_REGISTRY + + def test_import_does_not_load_vllm_omni(self) -> None: + """import worker.executors alone must not load vllm_omni or patch Request.""" + script = "\n".join( + [ + "import sys", + "import worker.executors", + "import vllm.v1.request as R", + "print('omni_in_modules:', 'vllm_omni' in sys.modules)", + "print('request_class:', R.Request.__name__)", + ] + ) + result = subprocess.run( # nosec B603 — fixed argv list, no shell=True, sys.executable + [sys.executable, "-c", script], + capture_output=True, + text=True, + timeout=60, + env=_subprocess_env(), + ) + assert result.returncode == 0, result.stderr + assert "omni_in_modules: False" in result.stdout + assert "request_class: Request" in result.stdout From 53aed8f96d48831ec262dad960770454dfe821b2 Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Sat, 20 Jun 2026 16:59:08 +0700 Subject: [PATCH 2/3] Address PR comments Signed-off-by: Zhengyuan Su --- examples/templates/inference_vllm_qwen3.yaml | 6 +- src/worker/executors/__init__.py | 103 +++++++----------- src/worker/executors/omni_executor_base.py | 10 ++ .../executors/omni_text2audio_executor.py | 62 ++++------- .../executors/omni_text2general_executor.py | 19 ++-- .../executors/omni_text2image_executor.py | 17 +-- .../executors/omni_text2speech_executor.py | 16 +-- src/worker/executors/vllm_executor.py | 5 +- tests/worker/test_vllm_scoped_plugins.py | 10 +- 9 files changed, 91 insertions(+), 157 deletions(-) diff --git a/examples/templates/inference_vllm_qwen3.yaml b/examples/templates/inference_vllm_qwen3.yaml index b6cc447..1517732 100644 --- a/examples/templates/inference_vllm_qwen3.yaml +++ b/examples/templates/inference_vllm_qwen3.yaml @@ -2,10 +2,8 @@ apiVersion: flowmesh/v1 kind: InferenceTask metadata: name: inference-vllm-qwen3 - owner: timzsu annotations: - description: > - Text-generation inference with Qwen3-8B. + description: Text-generation inference with Qwen3-8B. spec: taskType: inference @@ -16,7 +14,7 @@ spec: gpu: type: any count: 1 - memory: 24Gi # Qwen3-8B BF16 weights + KV; 8Gi is too small + memory: 24Gi model: source: type: huggingface diff --git a/src/worker/executors/__init__.py b/src/worker/executors/__init__.py index 5934952..1621045 100644 --- a/src/worker/executors/__init__.py +++ b/src/worker/executors/__init__.py @@ -1,5 +1,4 @@ import importlib -from collections.abc import Iterator, Mapping from .base_executor import Executor @@ -42,70 +41,41 @@ def _import_executor(name: str, module: str) -> type[Executor] | None: DiffusersExecutor = _import_executor("DiffusersExecutor", ".diffusers_executor") APIExecutor = _import_executor("APIExecutor", ".api_executor") SSHExecutor = _import_executor("SSHExecutor", ".ssh_executor") -_OMNI_SPECS: dict[str, tuple[str, str]] = { - "omni_text2image": ("OmniText2ImageExecutor", ".omni_text2image_executor"), - "omni_text2speech": ("OmniText2SpeechExecutor", ".omni_text2speech_executor"), - "omni_text2audio": ("OmniText2AudioExecutor", ".omni_text2audio_executor"), - "omni_text2general": ("OmniText2GeneralExecutor", ".omni_text2general_executor"), -} - - -class _LazyRegistry(Mapping[str, type[Executor] | None]): - """Mapping that defers omni executor imports to first key lookup. - - Backed by a plain dict; omni entries are resolved (their module imported) - on the first __getitem__ / get() access via the Mapping mixin. - """ - - def __init__( - self, - eager: dict[str, type[Executor] | None], - lazy: dict[str, tuple[str, str]], - ) -> None: - self._data: dict[str, type[Executor] | None] = dict(eager) - self._lazy: dict[str, tuple[str, str]] = dict(lazy) - for key in lazy: - self._data[key] = None - - def _resolve(self, key: str) -> type[Executor] | None: - cls_name, module = self._lazy.pop(key) - cls = _import_executor(cls_name, module) - self._data[key] = cls - return cls - - def __getitem__(self, key: str) -> type[Executor] | None: - if key in self._lazy: - return self._resolve(key) - return self._data[key] - - def __iter__(self) -> Iterator[str]: - return iter(self._data) - - def __len__(self) -> int: - return len(self._data) - - -EXECUTOR_REGISTRY: Mapping[str, type[Executor] | None] = _LazyRegistry( - { - "vllm": VLLMExecutor, - "vllm_lora": VLLMLoRAExecutor, - "ppo": PPOExecutor, - "dpo": DPOExecutor, - "sft": SFTExecutor, - "lora_sft": LoRASFTExecutor, - "image_classification_training": ImageClassificationTrainingExecutor, - "default": HFTransformersExecutor, - "rag": RAGExecutor, - "agent": AgentExecutor, - "echo": EchoExecutor, - "data_profiling": DataProfilingExecutor, - "data_retrieval": DataRetrievalExecutor, - "diffusers": DiffusersExecutor, - "api": APIExecutor, - "ssh": SSHExecutor, - }, - _OMNI_SPECS, +OmniText2ImageExecutor = _import_executor( + "OmniText2ImageExecutor", ".omni_text2image_executor" +) +OmniText2SpeechExecutor = _import_executor( + "OmniText2SpeechExecutor", ".omni_text2speech_executor" ) +OmniText2AudioExecutor = _import_executor( + "OmniText2AudioExecutor", ".omni_text2audio_executor" +) +OmniText2GeneralExecutor = _import_executor( + "OmniText2GeneralExecutor", ".omni_text2general_executor" +) + +EXECUTOR_REGISTRY: dict[str, type[Executor] | None] = { + "vllm": VLLMExecutor, + "vllm_lora": VLLMLoRAExecutor, + "ppo": PPOExecutor, + "dpo": DPOExecutor, + "sft": SFTExecutor, + "lora_sft": LoRASFTExecutor, + "image_classification_training": ImageClassificationTrainingExecutor, + "default": HFTransformersExecutor, + "rag": RAGExecutor, + "agent": AgentExecutor, + "echo": EchoExecutor, + "data_profiling": DataProfilingExecutor, + "data_retrieval": DataRetrievalExecutor, + "diffusers": DiffusersExecutor, + "api": APIExecutor, + "ssh": SSHExecutor, + "omni_text2image": OmniText2ImageExecutor, + "omni_text2speech": OmniText2SpeechExecutor, + "omni_text2audio": OmniText2AudioExecutor, + "omni_text2general": OmniText2GeneralExecutor, +} EXECUTOR_CLASS_NAMES: dict[str, str] = { "vllm": "VLLMExecutor", @@ -130,7 +100,6 @@ def __len__(self) -> int: "omni_text2general": "OmniText2GeneralExecutor", } -# Live reference so errors from lazy imports appear immediately IMPORT_ERRORS: dict[str, str] = _IMPORT_ERRORS __all__ = [ @@ -152,6 +121,10 @@ def __len__(self) -> int: "DiffusersExecutor": DiffusersExecutor, "APIExecutor": APIExecutor, "SSHExecutor": SSHExecutor, + "OmniText2ImageExecutor": OmniText2ImageExecutor, + "OmniText2SpeechExecutor": OmniText2SpeechExecutor, + "OmniText2AudioExecutor": OmniText2AudioExecutor, + "OmniText2GeneralExecutor": OmniText2GeneralExecutor, }.items() if cls is not None ] + ["EXECUTOR_REGISTRY", "IMPORT_ERRORS", "EXECUTOR_CLASS_NAMES"] diff --git a/src/worker/executors/omni_executor_base.py b/src/worker/executors/omni_executor_base.py index 8fd9ec2..4074ebe 100644 --- a/src/worker/executors/omni_executor_base.py +++ b/src/worker/executors/omni_executor_base.py @@ -7,6 +7,7 @@ """ import gc +import importlib.util import json import logging import os @@ -22,6 +23,7 @@ from shared.schemas.result import BaseExecutorResult from shared.tasks.specs import TaskSpecStrictBase from shared.utils.parsing import to_bool, to_int +from worker.config import WorkerConfig from .base_executor import ExecutionError, Executor, ExecutorTask from .mixins.inference import InferenceMixin @@ -40,6 +42,10 @@ logger = logging.getLogger(__name__) +# find_spec, not import: importing vllm_omni rebinds vllm.v1.request.Request +# and breaks plain-model warmup. +_HAS_OMNI: bool = importlib.util.find_spec("vllm_omni") is not None + class OmniResult(BaseExecutorResult): executor: str @@ -66,6 +72,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._omni_spec: tuple[Any, ...] | None = None self._stage_configs_tmp: Path | None = None + @classmethod + def is_available(cls, config: WorkerConfig) -> bool: + return _HAS_OMNI + def run(self, task: ExecutorTask, out_dir: Path) -> OmniResult: spec = self.require_spec(task, self._TASK_SPEC_TYPE) spec_dict = spec.model_dump(by_alias=True) diff --git a/src/worker/executors/omni_text2audio_executor.py b/src/worker/executors/omni_text2audio_executor.py index d1e4d9d..c4d9cae 100644 --- a/src/worker/executors/omni_text2audio_executor.py +++ b/src/worker/executors/omni_text2audio_executor.py @@ -16,40 +16,6 @@ np = None torch = None -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False - -try: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt - - _HAS_OMNI_DIFFUSION = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt - else: - OmniDiffusionSamplingParams = None - OmniTextPrompt = None - _HAS_OMNI_DIFFUSION = False - -try: - from vllm_omni.platforms import current_omni_platform - - _HAS_OMNI_PLATFORM = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.platforms import current_omni_platform - else: - current_omni_platform = None - _HAS_OMNI_PLATFORM = False - from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase @@ -58,7 +24,12 @@ from shared.utils.parsing import to_float, to_int from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase, OmniResult, extract_multimodal_output +from .omni_executor_base import ( + _HAS_OMNI, + OmniExecutorBase, + OmniResult, + extract_multimodal_output, +) logger = logging.getLogger(__name__) EXECUTOR_NAME = "omni_text2audio" @@ -86,7 +57,7 @@ def prepare(self) -> None: raise ExecutionError("omni_text2audio requires torch.") if np is None: raise ExecutionError("omni_text2audio requires numpy.") - if not (_HAS_OMNI and _HAS_OMNI_DIFFUSION): + if not _HAS_OMNI: raise ExecutionError( "vllm_omni is not installed; cannot use omni_text2audio executor." ) @@ -99,6 +70,8 @@ def _run_inner( out_dir: Path, ) -> OmniText2AudioResult: assert isinstance(spec, OmniText2AudioSpecStrict) + from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt + prompts = self._collect_text_inputs(spec, task.task_id) cfg = _bgm_cfg(spec_dict) @@ -229,6 +202,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + model_name = self.resolve_model_identifier( spec_dict, _bgm_cfg(spec_dict), @@ -261,12 +236,15 @@ def _bgm_cfg(spec_dict: dict[str, Any]) -> dict[str, Any]: def _resolve_generator_device() -> str: - if _HAS_OMNI_PLATFORM and current_omni_platform is not None: - device_type = ( - str(getattr(current_omni_platform, "device_type", "")).strip().lower() - ) - if device_type in {"cuda", "cpu", "mps", "xpu", "hpu", "npu"}: - return device_type + if _HAS_OMNI: + from vllm_omni.platforms import current_omni_platform + + if current_omni_platform is not None: + device_type = ( + str(getattr(current_omni_platform, "device_type", "")).strip().lower() + ) + if device_type in {"cuda", "cpu", "mps", "xpu", "hpu", "npu"}: + return device_type if torch is not None and torch.cuda.is_available(): return "cuda" return "cpu" diff --git a/src/worker/executors/omni_text2general_executor.py b/src/worker/executors/omni_text2general_executor.py index 04cde5c..9d09459 100644 --- a/src/worker/executors/omni_text2general_executor.py +++ b/src/worker/executors/omni_text2general_executor.py @@ -15,26 +15,17 @@ SamplingParams = None _HAS_VLLM = False -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False - from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType from shared.tasks.specs import TaskSpecStrictBase from shared.tasks.specs.omni import OmniText2GeneralSpecStrict from shared.tasks.task_type import TaskType from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list +from worker.config import WorkerConfig from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import ( + _HAS_OMNI, OmniExecutorBase, OmniResult, extract_audio_from_mm, @@ -66,6 +57,10 @@ class OmniText2GeneralExecutor(OmniExecutorBase): supported_task_types = frozenset({TaskType.OMNI_TEXT2GENERAL}) _TASK_SPEC_TYPE = OmniText2GeneralSpecStrict + @classmethod + def is_available(cls, config: WorkerConfig) -> bool: + return _HAS_OMNI and _HAS_VLLM + def prepare(self) -> None: if not _HAS_OMNI: raise ExecutionError( @@ -207,6 +202,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + cfg = _narration_cfg(spec_dict) model_name = self.resolve_model_identifier( spec_dict, diff --git a/src/worker/executors/omni_text2image_executor.py b/src/worker/executors/omni_text2image_executor.py index 75f7de4..0265bac 100644 --- a/src/worker/executors/omni_text2image_executor.py +++ b/src/worker/executors/omni_text2image_executor.py @@ -2,18 +2,7 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, Any - -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False +from typing import Any from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType @@ -23,7 +12,7 @@ from shared.utils.parsing import as_list from .base_executor import ExecutionError, ExecutorTask -from .omni_executor_base import OmniExecutorBase, OmniResult +from .omni_executor_base import _HAS_OMNI, OmniExecutorBase, OmniResult logger = logging.getLogger(__name__) EXECUTOR_NAME = "omni_text2image" @@ -110,6 +99,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image") model_name = self.resolve_model_identifier( spec_dict, diff --git a/src/worker/executors/omni_text2speech_executor.py b/src/worker/executors/omni_text2speech_executor.py index 911a115..966637b 100644 --- a/src/worker/executors/omni_text2speech_executor.py +++ b/src/worker/executors/omni_text2speech_executor.py @@ -3,18 +3,7 @@ import logging import os from pathlib import Path -from typing import TYPE_CHECKING, Any - -try: - from vllm_omni.entrypoints.omni import Omni - - _HAS_OMNI = True -except Exception: - if TYPE_CHECKING: - from vllm_omni.entrypoints.omni import Omni - else: - Omni = None - _HAS_OMNI = False +from typing import Any from shared.schemas.artifact import ArtifactRef from shared.schemas.governance import SpanType @@ -25,6 +14,7 @@ from .base_executor import ExecutionError, ExecutorTask from .omni_executor_base import ( + _HAS_OMNI, OmniExecutorBase, OmniResult, extract_audio_from_mm, @@ -124,6 +114,8 @@ def _run_inner( # ── model ──────────────────────────────────────────────────────────── def _ensure_omni(self, spec_dict: dict[str, Any]) -> None: + from vllm_omni.entrypoints.omni import Omni + cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech") model_name = self.resolve_model_identifier( spec_dict, diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py index f064000..a767dd8 100644 --- a/src/worker/executors/vllm_executor.py +++ b/src/worker/executors/vllm_executor.py @@ -390,9 +390,8 @@ def _ensure_llm( # Required workaround: vLLM's load_general_plugins() auto-loads the # vllm_omni_register_models entry point, which imports vllm_omni and # rebinds vllm.v1.request.Request → OmniRequest, crashing plain-model - # warmup. The lazy __init__.py fix only prevents the direct-import path; - # this allowlist stops the entry-point path. Revert once vllm_omni ships - # an OmniRequest that is a drop-in for vllm.v1.request.Request. + # warmup. Revert once vllm_omni ships an OmniRequest that is a drop-in + # for vllm.v1.request.Request. if "VLLM_PLUGINS" not in os.environ: os.environ["VLLM_PLUGINS"] = self._vllm_plugins_allowlist_excluding_omni() else: diff --git a/tests/worker/test_vllm_scoped_plugins.py b/tests/worker/test_vllm_scoped_plugins.py index 94ea70b..5657473 100644 --- a/tests/worker/test_vllm_scoped_plugins.py +++ b/tests/worker/test_vllm_scoped_plugins.py @@ -2,6 +2,7 @@ import importlib.metadata import os +import pathlib import subprocess import sys @@ -85,9 +86,7 @@ def _is_omni_ep(ep: importlib.metadata.EntryPoint) -> bool: eps = importlib.metadata.entry_points(group="vllm.general_plugins") if not any(_is_omni_ep(ep) for ep in eps): - pytest.skip( - "vllm_omni not a registered vllm.general_plugins entry" - ) # noqa: E501 + pytest.skip("vllm_omni not a registered vllm.general_plugins entry") allowlist = VLLMExecutor._vllm_plugins_allowlist_excluding_omni() script = "\n".join( @@ -117,10 +116,7 @@ def _is_omni_ep(ep: importlib.metadata.EntryPoint) -> bool: def _subprocess_env() -> dict[str, str]: """Build an env with src/ on PYTHONPATH so subprocess can import worker.*.""" - import os - import pathlib - - src_dir = str(pathlib.Path(__file__).parents[2] / "src") + src_dir = (pathlib.Path(__file__).parents[2] / "src").as_posix() existing = os.environ.get("PYTHONPATH", "") pythonpath = f"{src_dir}:{existing}" if existing else src_dir return {**os.environ, "PYTHONPATH": pythonpath} From c321535bbd3b06ff10c616b5765611a5115806fb Mon Sep 17 00:00:00 2001 From: Zhengyuan Su Date: Sat, 20 Jun 2026 17:06:04 +0700 Subject: [PATCH 3/3] Minor fix Signed-off-by: Zhengyuan Su --- src/worker/executors/vllm_executor.py | 2 +- tests/worker/test_vllm_scoped_plugins.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/worker/executors/vllm_executor.py b/src/worker/executors/vllm_executor.py index a767dd8..62563bb 100644 --- a/src/worker/executors/vllm_executor.py +++ b/src/worker/executors/vllm_executor.py @@ -390,7 +390,7 @@ def _ensure_llm( # Required workaround: vLLM's load_general_plugins() auto-loads the # vllm_omni_register_models entry point, which imports vllm_omni and # rebinds vllm.v1.request.Request → OmniRequest, crashing plain-model - # warmup. Revert once vllm_omni ships an OmniRequest that is a drop-in + # warmup. Revert once vllm_omni ships an OmniRequest that is a drop-in # for vllm.v1.request.Request. if "VLLM_PLUGINS" not in os.environ: os.environ["VLLM_PLUGINS"] = self._vllm_plugins_allowlist_excluding_omni() diff --git a/tests/worker/test_vllm_scoped_plugins.py b/tests/worker/test_vllm_scoped_plugins.py index 5657473..566a298 100644 --- a/tests/worker/test_vllm_scoped_plugins.py +++ b/tests/worker/test_vllm_scoped_plugins.py @@ -10,6 +10,7 @@ pytest.importorskip("vllm", reason="vllm not installed (needs --extra inference-gpu)") +from worker.executors import EXECUTOR_REGISTRY from worker.executors.vllm_executor import VLLMExecutor # noqa: E402 _OMNI_EP_NAME = "vllm_omni_register_models" @@ -133,8 +134,6 @@ class TestLazyOmniRegistry: ) def test_executor_registry_has_omni_keys(self) -> None: - from worker.executors import EXECUTOR_REGISTRY - for key in self._OMNI_KEYS: assert key in EXECUTOR_REGISTRY