Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/templates/inference_vllm_qwen3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Example InferenceTask template: text-generation inference with Qwen3-8B,
# configured through the vllm settings below.
apiVersion: flowmesh/v1
kind: InferenceTask
metadata:
name: inference-vllm-qwen3
annotations:
description: Text-generation inference with Qwen3-8B.

spec:
taskType: inference
# Resources requested for the task: CPU, host memory, and one GPU of any type.
resources:
hardware:
cpu: 8
memory: 16Gi
gpu:
type: any
count: 1
memory: 24Gi
# Model is pulled from Hugging Face at the given revision.
model:
source:
type: huggingface
identifier: Qwen/Qwen3-8B
revision: main
# Engine options for vLLM.
vllm:
tensor_parallel_size: 1
gpu_memory_utilization: 0.9
trust_remote_code: true
# Prompts are supplied inline as a list of strings.
data:
type: list
items:
- "Say hi in three words."
- "Name a primary color."
# Generation settings applied to the prompts above.
inference:
max_tokens: 16
temperature: 0.0
# Output goes to a local destination; the listed artifacts are kept.
output:
destination:
type: local
artifacts:
- results.json
- logs
2 changes: 1 addition & 1 deletion src/worker/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def _import_executor(name: str, module: str) -> type[Executor] | None:
"omni_text2general": "OmniText2GeneralExecutor",
}

IMPORT_ERRORS: dict[str, str] = dict(_IMPORT_ERRORS)
IMPORT_ERRORS: dict[str, str] = _IMPORT_ERRORS

__all__ = [
name
Expand Down
10 changes: 10 additions & 0 deletions src/worker/executors/omni_executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import gc
import importlib.util
import json
import logging
import os
Expand All @@ -22,6 +23,7 @@
from shared.schemas.result import BaseExecutorResult
from shared.tasks.specs import TaskSpecStrictBase
from shared.utils.parsing import to_bool, to_int
from worker.config import WorkerConfig

from .base_executor import ExecutionError, Executor, ExecutorTask
from .mixins.inference import InferenceMixin
Expand All @@ -40,6 +42,10 @@

logger = logging.getLogger(__name__)

# Availability flag for the optional vllm_omni package.
# Probed with find_spec, not import: importing vllm_omni rebinds
# vllm.v1.request.Request and breaks plain-model warmup.
# A true value only means a module spec was found; the package itself has not
# been imported at this point.
_HAS_OMNI: bool = importlib.util.find_spec("vllm_omni") is not None


class OmniResult(BaseExecutorResult):
executor: str
Expand All @@ -66,6 +72,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._omni_spec: tuple[Any, ...] | None = None
self._stage_configs_tmp: Path | None = None

# Report whether this executor can be offered: true when the module-level
# _HAS_OMNI flag is set. The `config` argument is not consulted here.
@classmethod
def is_available(cls, config: WorkerConfig) -> bool:
return _HAS_OMNI

def run(self, task: ExecutorTask, out_dir: Path) -> OmniResult:
spec = self.require_spec(task, self._TASK_SPEC_TYPE)
spec_dict = spec.model_dump(by_alias=True)
Expand Down
62 changes: 20 additions & 42 deletions src/worker/executors/omni_text2audio_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,40 +16,6 @@
np = None
torch = None

try:
from vllm_omni.entrypoints.omni import Omni

_HAS_OMNI = True
except Exception:
if TYPE_CHECKING:
from vllm_omni.entrypoints.omni import Omni
else:
Omni = None
_HAS_OMNI = False

try:
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt

_HAS_OMNI_DIFFUSION = True
except Exception:
if TYPE_CHECKING:
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
else:
OmniDiffusionSamplingParams = None
OmniTextPrompt = None
_HAS_OMNI_DIFFUSION = False

try:
from vllm_omni.platforms import current_omni_platform

_HAS_OMNI_PLATFORM = True
except Exception:
if TYPE_CHECKING:
from vllm_omni.platforms import current_omni_platform
else:
current_omni_platform = None
_HAS_OMNI_PLATFORM = False

from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
Expand All @@ -58,7 +24,12 @@
from shared.utils.parsing import to_float, to_int

from .base_executor import ExecutionError, ExecutorTask
from .omni_executor_base import OmniExecutorBase, OmniResult, extract_multimodal_output
from .omni_executor_base import (
_HAS_OMNI,
OmniExecutorBase,
OmniResult,
extract_multimodal_output,
)

logger = logging.getLogger(__name__)
EXECUTOR_NAME = "omni_text2audio"
Expand Down Expand Up @@ -86,7 +57,7 @@ def prepare(self) -> None:
raise ExecutionError("omni_text2audio requires torch.")
if np is None:
raise ExecutionError("omni_text2audio requires numpy.")
if not (_HAS_OMNI and _HAS_OMNI_DIFFUSION):
if not _HAS_OMNI:
raise ExecutionError(
"vllm_omni is not installed; cannot use omni_text2audio executor."
)
Expand All @@ -99,6 +70,8 @@ def _run_inner(
out_dir: Path,
) -> OmniText2AudioResult:
assert isinstance(spec, OmniText2AudioSpecStrict)
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt

prompts = self._collect_text_inputs(spec, task.task_id)

cfg = _bgm_cfg(spec_dict)
Expand Down Expand Up @@ -229,6 +202,8 @@ def _run_inner(
# ── model ────────────────────────────────────────────────────────────

def _ensure_omni(self, spec_dict: dict[str, Any]) -> None:
from vllm_omni.entrypoints.omni import Omni

model_name = self.resolve_model_identifier(
spec_dict,
_bgm_cfg(spec_dict),
Expand Down Expand Up @@ -261,12 +236,15 @@ def _bgm_cfg(spec_dict: dict[str, Any]) -> dict[str, Any]:


def _resolve_generator_device() -> str:
    """Return the device type string to use, with best-effort fallbacks.

    Preference order:
      1. The ``device_type`` reported by vllm_omni's current platform, when it
         is one of the recognised names.
      2. ``"cuda"`` when torch is present and reports CUDA as available.
      3. ``"cpu"``.

    Returns:
        A lowercase device type string.
    """
    if _HAS_OMNI:
        # Imported lazily so vllm_omni is only loaded when an omni executor
        # actually needs it. _HAS_OMNI only says the package is discoverable,
        # so an import that fails at this point falls through to the
        # torch-based checks instead of propagating out of this helper.
        try:
            from vllm_omni.platforms import current_omni_platform
        except Exception:
            logger.debug(
                "vllm_omni.platforms import failed; falling back to torch/cpu",
                exc_info=True,
            )
            current_omni_platform = None

        if current_omni_platform is not None:
            device_type = (
                str(getattr(current_omni_platform, "device_type", "")).strip().lower()
            )
            if device_type in {"cuda", "cpu", "mps", "xpu", "hpu", "npu"}:
                return device_type
    if torch is not None and torch.cuda.is_available():
        return "cuda"
    return "cpu"
Expand Down
19 changes: 8 additions & 11 deletions src/worker/executors/omni_text2general_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,17 @@
SamplingParams = None
_HAS_VLLM = False

try:
from vllm_omni.entrypoints.omni import Omni

_HAS_OMNI = True
except Exception:
if TYPE_CHECKING:
from vllm_omni.entrypoints.omni import Omni
else:
Omni = None
_HAS_OMNI = False

from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
from shared.tasks.specs import TaskSpecStrictBase
from shared.tasks.specs.omni import OmniText2GeneralSpecStrict
from shared.tasks.task_type import TaskType
from shared.utils.parsing import as_list, to_bool, to_float, to_int, to_int_list
from worker.config import WorkerConfig

from .base_executor import ExecutionError, ExecutorTask
from .omni_executor_base import (
_HAS_OMNI,
OmniExecutorBase,
OmniResult,
extract_audio_from_mm,
Expand Down Expand Up @@ -66,6 +57,10 @@ class OmniText2GeneralExecutor(OmniExecutorBase):
supported_task_types = frozenset({TaskType.OMNI_TEXT2GENERAL})
_TASK_SPEC_TYPE = OmniText2GeneralSpecStrict

# Report whether this executor can be offered: requires both module-level
# flags, _HAS_OMNI and _HAS_VLLM, to be true. The `config` argument is not
# consulted here.
@classmethod
def is_available(cls, config: WorkerConfig) -> bool:
return _HAS_OMNI and _HAS_VLLM

def prepare(self) -> None:
if not _HAS_OMNI:
raise ExecutionError(
Expand Down Expand Up @@ -207,6 +202,8 @@ def _run_inner(
# ── model ────────────────────────────────────────────────────────────

def _ensure_omni(self, spec_dict: dict[str, Any]) -> None:
from vllm_omni.entrypoints.omni import Omni

cfg = _narration_cfg(spec_dict)
model_name = self.resolve_model_identifier(
spec_dict,
Expand Down
17 changes: 4 additions & 13 deletions src/worker/executors/omni_text2image_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,7 @@

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any

try:
from vllm_omni.entrypoints.omni import Omni

_HAS_OMNI = True
except Exception:
if TYPE_CHECKING:
from vllm_omni.entrypoints.omni import Omni
else:
Omni = None
_HAS_OMNI = False
from typing import Any

from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
Expand All @@ -23,7 +12,7 @@
from shared.utils.parsing import as_list

from .base_executor import ExecutionError, ExecutorTask
from .omni_executor_base import OmniExecutorBase, OmniResult
from .omni_executor_base import _HAS_OMNI, OmniExecutorBase, OmniResult

logger = logging.getLogger(__name__)
EXECUTOR_NAME = "omni_text2image"
Expand Down Expand Up @@ -110,6 +99,8 @@ def _run_inner(
# ── model ────────────────────────────────────────────────────────────

def _ensure_omni(self, spec_dict: dict[str, Any]) -> None:
from vllm_omni.entrypoints.omni import Omni

cfg = self.omni_cfg(spec_dict, "omni:image generation", "omni_text2image")
model_name = self.resolve_model_identifier(
spec_dict,
Expand Down
16 changes: 4 additions & 12 deletions src/worker/executors/omni_text2speech_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,7 @@
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any

try:
from vllm_omni.entrypoints.omni import Omni

_HAS_OMNI = True
except Exception:
if TYPE_CHECKING:
from vllm_omni.entrypoints.omni import Omni
else:
Omni = None
_HAS_OMNI = False
from typing import Any

from shared.schemas.artifact import ArtifactRef
from shared.schemas.governance import SpanType
Expand All @@ -25,6 +14,7 @@

from .base_executor import ExecutionError, ExecutorTask
from .omni_executor_base import (
_HAS_OMNI,
OmniExecutorBase,
OmniResult,
extract_audio_from_mm,
Expand Down Expand Up @@ -124,6 +114,8 @@ def _run_inner(
# ── model ────────────────────────────────────────────────────────────

def _ensure_omni(self, spec_dict: dict[str, Any]) -> None:
from vllm_omni.entrypoints.omni import Omni

cfg = self.omni_cfg(spec_dict, "omni:tts", "omni_text2speech")
model_name = self.resolve_model_identifier(
spec_dict,
Expand Down
30 changes: 29 additions & 1 deletion src/worker/executors/vllm_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import copy
import datetime
import gc
import importlib.metadata
import json
import logging
import os
Expand Down Expand Up @@ -157,6 +158,20 @@ def _detect_available_gpus() -> int:
pass
return 1

@staticmethod
def _vllm_plugins_allowlist_excluding_omni() -> str:
"""Every `vllm.general_plugins` entry point except vllm-omni's, comma-joined.

Omni's plugin rebinds vllm.v1.request.Request and crashes plain LLM init.
"""
names: list[str] = []
for ep in importlib.metadata.entry_points(group="vllm.general_plugins"):
module = ep.value.split(":", 1)[0].strip()
if module == "vllm_omni" or module.startswith("vllm_omni."):
continue
names.append(ep.name)
return ",".join(names)

@staticmethod
def _compute_safe_utilization(requested: float) -> tuple[float, float | None]:
if torch is None or not torch.cuda.is_available():
Expand Down Expand Up @@ -372,6 +387,19 @@ def _ensure_llm(
"Ignored unrecognized vLLM config fields: %s", list(vllm_cfg.keys())
)

# Required workaround: vLLM's load_general_plugins() auto-loads the
# vllm_omni_register_models entry point, which imports vllm_omni and
# rebinds vllm.v1.request.Request → OmniRequest, crashing plain-model
# warmup. Revert once vllm_omni ships an OmniRequest that is a drop-in
# for vllm.v1.request.Request.
if "VLLM_PLUGINS" not in os.environ:
os.environ["VLLM_PLUGINS"] = self._vllm_plugins_allowlist_excluding_omni()
else:
logger.debug(
"VLLM_PLUGINS already set (%r); skipping omni exclusion",
os.environ["VLLM_PLUGINS"],
)

tp_candidates: list[int] = []
seen_tp: set[int] = set()
initial_tp = max(1, tensor_parallel_size)
Expand Down Expand Up @@ -731,7 +759,7 @@ def _flatten_image_embedding_chunks(
) -> torch.Tensor:
if not embedding_chunks:
raise ExecutionError(
"Loaded image embedding list must be non-empty " f"(task={task_id})."
f"Loaded image embedding list must be non-empty (task={task_id})."
)
if group_sizes is not None:
if len(embedding_chunks) != len(group_sizes):
Expand Down
5 changes: 3 additions & 2 deletions src/worker/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
import signal
from collections.abc import Mapping

from shared.schemas.worker import WorkerCapabilities
from shared.tasks.task_type import TaskType
Expand Down Expand Up @@ -59,7 +60,7 @@ def initialize_executors(
hardware: WorkerHardware,
logger: logging.Logger,
lifecycle: Lifecycle,
registry: dict[str, type[Executor] | None] | None = None,
registry: Mapping[str, type[Executor] | None] | None = None,
import_errors: dict[str, str] | None = None,
cuda_available: bool | None = None,
enable_mp_executors: bool = True,
Expand Down Expand Up @@ -167,7 +168,7 @@ def init_executor(key: str, *, gpu_required: bool = False):

def build_capabilities(
executors: dict[str, Executor],
registry: dict[str, type[Executor] | None] | None = None,
registry: Mapping[str, type[Executor] | None] | None = None,
) -> WorkerCapabilities:
registry = registry or EXECUTOR_REGISTRY
supported_task_types = frozenset[TaskType]().union(
Expand Down
Loading