Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions src/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import asyncio
import json
import os
from dataclasses import asdict
from typing import Dict

Expand All @@ -12,15 +13,29 @@
from src.core.dataclasses.config import ClientConfig


USER_ID_FILE = os.path.expanduser("~/.huri_user_id")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mettre dans un .env peut etre ? ou pas en vrai

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Je ne sais pas comment c'est fait dans l'industrie, faut se renseigner



def load_user_id() -> str | None:
if os.path.exists(USER_ID_FILE):
with open(USER_ID_FILE) as f:
return f.read().strip()
return None


def save_user_id(_user_id: str):
with open(USER_ID_FILE, "w") as f:
f.write(_user_id)

def load_client_config(path: str) -> ClientConfig:
with open(path) as f:
dict_config = OmegaConf.load(f)
raw_resolved = OmegaConf.to_container(dict_config, resolve=True)
raw_resolved = OmegaConf.to_container(dict_config, resolve=True)

if not isinstance(raw_resolved, Dict):
raise RuntimeError("error yaml does not output a dict")
if not isinstance(raw_resolved, Dict):
raise RuntimeError("error yaml does not output a dict")

return ClientConfig.from_dict(raw_resolved)
return ClientConfig.from_dict(raw_resolved)


async def stream_audio():
Expand All @@ -38,7 +53,19 @@ async def stream_audio():
async with websockets.connect(config.huri_url) as ws:
print("Connected to server")

await ws.send(json.dumps(asdict(config)))
payload = asdict(config)
_user_id = load_user_id()
if _user_id:
payload["_user_id"] = _user_id
print(f"Reconnecting with _user_id: {_user_id}")

await ws.send(json.dumps(payload))

init_msg = json.loads(await ws.recv())
if init_msg.get("type") == "session_init":
_user_id = init_msg["_user_id"]
save_user_id(_user_id)
print(f"Session started with _user_id: {_user_id}")

async def receive(ws: websockets.ClientConnection):
while True:
Expand Down
21 changes: 9 additions & 12 deletions src/core/huri.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Je pense qu'il faut initialiser la Session avec le user_id pour eviter le if "rag"
Quitte a ajouter un ModuleWithId qui s'initialise avec un user id, dans le module Factory

Copy link
Copy Markdown
Contributor Author

@MatthiasvonRakowski MatthiasvonRakowski May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On doit laisser la création du User ID dans huri.py car si on a plusieurs ModuleWithHandleAndID on peut se trouver avec plusieurs ID différent. Par contre je suis d'accord pour le "rag" Et je change ça

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from src.modules.factory import Module, ModuleFactory
from src.modules.utils.sender import Sender

from .app import app
from .dataclasses.config import ClientConfig
from .session import Session
Expand All @@ -24,29 +23,28 @@ def __init__(
self.factory = ModuleFactory(handles)
for name, module_cls in modules.items():
self.factory.register(name, module_cls)

self.clients: Dict[str, Session] = {}

@app.websocket("/session")
async def run_session(self, ws: WebSocket):
await ws.accept()

client_config_raw: Dict = await ws.receive_json()

client_config = ClientConfig.from_dict(client_config_raw)

_user_id = client_config_raw.get("_user_id") or str(uuid.uuid4())

senders: List[Module] = [
Sender(ws, topic) for topic in client_config.topic_list
]
modules: List[Module] = (
self.factory.create_from_config(client_config.modules) + senders
self.factory.create_from_config(_user_id, client_config.modules) + senders
)

session_id = str(uuid.uuid4())
await ws.send_json({"type": "session_init", "_user_id": _user_id})

session_id = str(uuid.uuid4())
self.clients[session_id] = Session(modules)

print("Client registered successfully with config:", client_config)
print(f"Client registered with _user_id={_user_id}, config: {client_config}")

async def receive_loop(session: Session, ws: WebSocket):
try:
Expand All @@ -55,9 +53,8 @@ async def receive_loop(session: Session, ws: WebSocket):
if "bytes" in msg:
chunk = msg["bytes"]
await session.publish("chunk", chunk)
# else:
# data = msg
# await session.publish(data["type"], data["data"])
except (WebSocketDisconnect, RuntimeError):
print(f"Client disconnected")
print(f"Client {_user_id} disconnected")

await receive_loop(self.clients[session_id], ws)
del self.clients[session_id]
15 changes: 12 additions & 3 deletions src/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ async def process(self, _) -> Optional[Any]:
class ModuleWithHandle(Module):
_handle_cls: Type[Any]

def __init__(self, handle: handle.DeploymentHandle):
super().__init__()
self.handle = handle
def __init__(self, _handle: handle.DeploymentHandle = None, **kwargs):
super().__init__(**kwargs)
self._handle = _handle

class ModuleWithId(Module):
def __init__(self, _user_id: str, **kwargs):
super().__init__(**kwargs)
self._user_id = _user_id

def get_user_context(self) -> dict:
"""Override in subclasses to provide user-specific context."""
return {"_user_id": self._user_id}
31 changes: 22 additions & 9 deletions src/modules/factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from os import name
from typing import Any, Dict, List, Mapping, Type

from src.core.dataclasses.config import ModuleConfig
from src.core.module import Module, ModuleWithHandle, handle
from src.core.module import Module, ModuleWithHandle, handle, ModuleWithId


class ModuleFactory:
def __init__(self, handles):
self._registry: Dict[str, Type[Module]] = {}
self._handles = handles


def register(self, name: str, module_cls: Type[Module]) -> None:
if not issubclass(module_cls, Module):
raise TypeError(f"{module_cls} must inherit from Module")
Expand All @@ -19,29 +21,40 @@ def register(self, name: str, module_cls: Type[Module]) -> None:
)
self._registry[name] = module_cls

def create(self, name: str, args: Mapping[str, Any] | None = None) -> Module:

def create(
self,
_user_id: str,
name: str,
args: Mapping[str, Any] | None = None
) -> Module:

if name not in self._registry:
raise ValueError(f"Unknown module '{name}'")

module_cls = self._registry[name]

if args is None:
args = {}
kwargs = dict(args or {})

if issubclass(module_cls, ModuleWithHandle):
if name not in self._handles:
raise RuntimeError(
f"Handles not bound for '{name}'. Check your config first."
)

return module_cls(handle=self._handles[name], **args)
return module_cls(**args)
kwargs["_handle"] = self._handles[name]

if issubclass(module_cls, ModuleWithId):
kwargs["_user_id"] = _user_id

return module_cls(**kwargs)

def create_from_config(
self, module_configs: Dict[str, ModuleConfig]
self, _user_id: str, module_configs: Dict[str, ModuleConfig]
) -> List[Module]:
modules: List[Module] = []
for _, module_config in module_configs.items():
modules.append(self.create(module_config.name, module_config.args))

modules.append(self.create(_user_id, module_config.name, module_config.args))
if modules == []:
raise Exception

Expand Down
3 changes: 2 additions & 1 deletion src/modules/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from src.modules.speech_to_text.record_speech import MIC
from src.modules.speech_to_text.speech_to_text import STT
from src.modules.speech_to_text.text_aggregator import TAG
from src.modules.rag.rag import RAG

from .factory import Module


def get_modules() -> Dict[str, Type[Module]]:
return {"mic": MIC, "stt": STT, "tag": TAG}
return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG}
72 changes: 72 additions & 0 deletions src/modules/rag/ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# ingestion.py
import argparse
import os
import uuid

from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
from sentence_transformers import SentenceTransformer

USER_ID_FILE = os.path.expanduser("~/.huri_user_id")


def get_user_id(provided_id: str = None) -> str:
"""Use provided ID, or load from file, or generate new one."""
if provided_id:
return provided_id
if os.path.exists(USER_ID_FILE):
with open(USER_ID_FILE) as f:
return f.read().strip()
new_id = str(uuid.uuid4())
with open(USER_ID_FILE, "w") as f:
f.write(new_id)
return new_id


def main():
parser = argparse.ArgumentParser(description="Ingest documents into Qdrant")
parser.add_argument("--user-id", type=str, default=None, help="User ID (reads from ~/.huri_user_id if not provided)")
parser.add_argument("--collection", type=str, default="documents")
parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333")
args = parser.parse_args()

user_id = get_user_id(args.user_id)
print(f"Ingesting for user_id: {user_id}")

client = QdrantClient(url=args.qdrant_url)
model = SentenceTransformer("BAAI/bge-large-en-v1.5")

collections = [c.name for c in client.get_collections().collections]
if args.collection not in collections:
client.create_collection(
collection_name=args.collection,
vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
)
print(f"Created collection: {args.collection}")

docs = [
{"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"},
{"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"},
{"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"},
{"text": "The main office is located in Paris, France.", "source": "info.pdf"},
]

points = []
for doc in docs:
vector = model.encode(doc["text"], normalize_embeddings=True).tolist()
points.append(PointStruct(
id=str(uuid.uuid4()),
vector=vector,
payload={
"text": doc["text"],
"source": doc["source"],
"user_id": user_id,
},
))

client.upsert(collection_name=args.collection, points=points)
print(f"Ingested {len(points)} documents for user {user_id}")


if __name__ == "__main__":
main()
Loading