Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/app.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tu pourrais possiblement faire la config de qdrant et OllamaService danss le config file huri.yaml je pense, ce srait plus clean

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

T as raison je vais regarder pour le faire

Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@
from src.core.huri import HuRI
from src.modules.factory import bind_deployment_handles
from src.modules.modules import get_modules
from src.modules.rag.docker_services import OllamaService, QdrantService


def build_app() -> Application:
modules = get_modules()
handles = bind_deployment_handles(modules)

app: Application = HuRI.bind(modules, handles) # type: ignore[attr-defined]
qdrant = QdrantService.bind(port=6333)
ollama = OllamaService.options(num_replicas=1).bind(
model="mistral:7b",
image="ollama/ollama:rocm",
gpu_devices=True,
)

handles = bind_deployment_handles(modules, ollama=ollama, qdrant=qdrant)
app: Application = HuRI.bind(modules, handles)
return app


Expand Down
37 changes: 32 additions & 5 deletions src/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import asyncio
import json
import os
from dataclasses import asdict
from typing import Dict

Expand All @@ -12,15 +13,29 @@
from src.core.dataclasses.config import ClientConfig


USER_ID_FILE = os.path.expanduser("~/.huri_user_id")


def load_user_id() -> str | None:
if os.path.exists(USER_ID_FILE):
with open(USER_ID_FILE) as f:
return f.read().strip()
return None


def save_user_id(_user_id: str):
with open(USER_ID_FILE, "w") as f:
f.write(_user_id)

def load_client_config(path: str) -> ClientConfig:
with open(path) as f:
dict_config = OmegaConf.load(f)
raw_resolved = OmegaConf.to_container(dict_config, resolve=True)
raw_resolved = OmegaConf.to_container(dict_config, resolve=True)

if not isinstance(raw_resolved, Dict):
raise RuntimeError("error yaml does not output a dict")
if not isinstance(raw_resolved, Dict):
raise RuntimeError("error yaml does not output a dict")

return ClientConfig.from_dict(raw_resolved)
return ClientConfig.from_dict(raw_resolved)


async def stream_audio():
Expand All @@ -38,7 +53,19 @@ async def stream_audio():
async with websockets.connect(config.huri_url) as ws:
print("Connected to server")

await ws.send(json.dumps(asdict(config)))
payload = asdict(config)
_user_id = load_user_id()
if _user_id:
payload["_user_id"] = _user_id
print(f"Reconnecting with _user_id: {_user_id}")

await ws.send(json.dumps(payload))

init_msg = json.loads(await ws.recv())
if init_msg.get("type") == "session_init":
_user_id = init_msg["_user_id"]
save_user_id(_user_id)
print(f"Session started with _user_id: {_user_id}")

async def receive(ws: websockets.ClientConnection):
while True:
Expand Down
21 changes: 9 additions & 12 deletions src/core/huri.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from src.modules.factory import Module, ModuleFactory
from src.modules.utils.sender import Sender

from .app import app
from .dataclasses.config import ClientConfig
from .session import Session
Expand All @@ -24,29 +23,28 @@ def __init__(
self.factory = ModuleFactory(handles)
for name, module_cls in modules.items():
self.factory.register(name, module_cls)

self.clients: Dict[str, Session] = {}

@app.websocket("/session")
async def run_session(self, ws: WebSocket):
await ws.accept()

client_config_raw: Dict = await ws.receive_json()

client_config = ClientConfig.from_dict(client_config_raw)

_user_id = client_config_raw.get("_user_id") or str(uuid.uuid4())

senders: List[Module] = [
Sender(ws, topic) for topic in client_config.topic_list
]
modules: List[Module] = (
self.factory.create_from_config(client_config.modules) + senders
self.factory.create_from_config(_user_id, client_config.modules) + senders
)

session_id = str(uuid.uuid4())
await ws.send_json({"type": "session_init", "_user_id": _user_id})

session_id = str(uuid.uuid4())
self.clients[session_id] = Session(modules)

print("Client registered successfully with config:", client_config)
print(f"Client registered with _user_id={_user_id}, config: {client_config}")

async def receive_loop(session: Session, ws: WebSocket):
try:
Expand All @@ -55,9 +53,8 @@ async def receive_loop(session: Session, ws: WebSocket):
if "bytes" in msg:
chunk = msg["bytes"]
await session.publish("chunk", chunk)
# else:
# data = msg
# await session.publish(data["type"], data["data"])
except (WebSocketDisconnect, RuntimeError):
print(f"Client disconnected")
print(f"Client {_user_id} disconnected")

await receive_loop(self.clients[session_id], ws)
del self.clients[session_id]
15 changes: 12 additions & 3 deletions src/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ async def process(self, _) -> Optional[Any]:
class ModuleWithHandle(Module):
_handle_cls: Type[Any]

def __init__(self, handle: handle.DeploymentHandle):
super().__init__()
self.handle = handle
def __init__(self, _handle: handle.DeploymentHandle = None, **kwargs):
super().__init__(**kwargs)
self._handle = _handle

class ModuleWithId(Module):
def __init__(self, _user_id: str, **kwargs):
super().__init__(**kwargs)
self._user_id = _user_id

def get_user_context(self) -> dict:
"""Override in subclasses to provide user-specific context."""
return {"_user_id": self._user_id}
44 changes: 33 additions & 11 deletions src/modules/factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from os import name
from typing import Any, Dict, List, Mapping, Type

from src.core.dataclasses.config import ModuleConfig
from src.core.module import Module, ModuleWithHandle, handle
from src.core.module import Module, ModuleWithHandle, handle, ModuleWithId


class ModuleFactory:
def __init__(self, handles):
self._registry: Dict[str, Type[Module]] = {}
self._handles = handles


def register(self, name: str, module_cls: Type[Module]) -> None:
if not issubclass(module_cls, Module):
raise TypeError(f"{module_cls} must inherit from Module")
Expand All @@ -19,29 +21,40 @@ def register(self, name: str, module_cls: Type[Module]) -> None:
)
self._registry[name] = module_cls

def create(self, name: str, args: Mapping[str, Any] | None = None) -> Module:

def create(
self,
_user_id: str,
name: str,
args: Mapping[str, Any] | None = None
) -> Module:

if name not in self._registry:
raise ValueError(f"Unknown module '{name}'")

module_cls = self._registry[name]

if args is None:
args = {}
kwargs = dict(args or {})

if issubclass(module_cls, ModuleWithHandle):
if name not in self._handles:
raise RuntimeError(
f"Handles not bound for '{name}'. Check your config first."
)

return module_cls(handle=self._handles[name], **args)
return module_cls(**args)
kwargs["_handle"] = self._handles[name]

if issubclass(module_cls, ModuleWithId):
kwargs["_user_id"] = _user_id

return module_cls(**kwargs)

def create_from_config(
self, module_configs: Dict[str, ModuleConfig]
self, _user_id: str, module_configs: Dict[str, ModuleConfig]
) -> List[Module]:
modules: List[Module] = []
for _, module_config in module_configs.items():
modules.append(self.create(module_config.name, module_config.args))

modules.append(self.create(_user_id, module_config.name, module_config.args))
if modules == []:
raise Exception

Expand All @@ -50,15 +63,24 @@ def create_from_config(

def bind_deployment_handles(
modules: Dict[str, Type[Module]],
**service_handles,
) -> Dict[str, handle.DeploymentHandle]:
handles: Dict[str, handle.DeploymentHandle] = {}
for name, module_cls in modules.items():
if not issubclass(module_cls, ModuleWithHandle):
continue

if not hasattr(module_cls, "_handle_cls"):
raise TypeError(f"{module_cls.__name__} must define _handle_cls")

handle_cls = module_cls._handle_cls
handles[name] = handle_cls.bind()

if name == "rag" and service_handles:
handles[name] = handle_cls.bind(
ollama_handle=service_handles.get("ollama"),
qdrant_handle=service_handles.get("qdrant"),
)
else:
handles[name] = handle_cls.bind()

return handles
3 changes: 2 additions & 1 deletion src/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from src.modules.speech_to_text.record_speech import MIC
from src.modules.speech_to_text.speech_to_text import STT
from src.modules.speech_to_text.text_aggregator import TAG
from src.modules.rag.rag import RAG

from .factory import Module


def get_modules() -> Dict[str, Type[Module]]:
return {"mic": MIC, "stt": STT, "tag": TAG}
return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG}
Loading