From 81826b7b4701b1ba0f91fdc0700e369713191db0 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Thu, 7 May 2026 12:06:18 +0100 Subject: [PATCH 1/5] wip(rag): V1 of a rag working with an ollama llm working with the current pipeline with a RAG + Embedding + LLM (ollama). Should work with vLLM but not tested --- src/modules/modules.py | 3 +- src/modules/rag/rag.py | 310 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 src/modules/rag/rag.py diff --git a/src/modules/modules.py b/src/modules/modules.py index 69ebb45..7283983 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -3,9 +3,10 @@ from src.modules.speech_to_text.record_speech import MIC from src.modules.speech_to_text.speech_to_text import STT from src.modules.speech_to_text.text_aggregator import TAG +from src.modules.rag.rag import RAG from .factory import Module def get_modules() -> Dict[str, Type[Module]]: - return {"mic": MIC, "stt": STT, "tag": TAG} + return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG} diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py new file mode 100644 index 0000000..0884c76 --- /dev/null +++ b/src/modules/rag/rag.py @@ -0,0 +1,310 @@ +from typing import Any, Optional +from dataclasses import dataclass, field + +from ray import data, serve +from ray.serve import handle +from src.core.module import ModuleWithHandle +from qdrant_client.models import Filter, FieldCondition, MatchValue +from sentence_transformers import SentenceTransformer +from qdrant_client import QdrantClient + + +import httpx + + +@dataclass +class RAGQuery: + """What flows from RAG module to RAGHandle.""" + user_id: str + question: str + preferences: dict = field(default_factory=dict) + # preferences can include: language, tone, response_format, max_length, system_prompt, extra_instructions, etc. + + +@dataclass +class RAGResult: + """What RAGHandle returns.""" + answer: str + sources: list[dict] = field(default_factory=list) + + +@serve.deployment( + num_replicas=2, + ray_actor_options={"num_cpus": 1}, +) +class RAGHandle: + """ + Stateless RAG processor. Knows nothing about sessions. + Receives a user_id + question, uses user_id to find the right + collection/data in the vector DB, runs embed -> search -> LLM. + """ + + def __init__( + self, + qdrant_url: str = "http://localhost:6333", + default_collection: str = "documents", + embedding_model: str = "BAAI/bge-large-en-v1.5", + llm_provider: str = "ollama", # "vllm", "ollama", "api" + llm_url: str = "http://localhost:11434", + llm_model: str = "mistral:7b", + llm_api_key: str = "", + top_k: int = 5, + score_threshold: float = 0.5, + ): + self.embed_model = SentenceTransformer(embedding_model) + self.qdrant = QdrantClient(url=qdrant_url) + self.default_collection = default_collection + self.top_k = top_k + self.score_threshold = score_threshold + + self.llm_provider = llm_provider + self.llm_url = llm_url + self.llm_model = llm_model + self.llm_api_key = llm_api_key + + def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: + """ + Given a user_id, decide which collection to search + and which filters to apply. + + Options (pick what fits your data model): + A) One collection per user: collection = f"user_{user_id}" + B) Shared collection, filter by user_id in payload + C) Lookup in a DB to find the user's config + """ + + # Option A: separate collection per user + # collection = f"user_{user_id}" + # filters = None + + # Option B: shared collection with user_id filter (recommended) + collection = self.default_collection + filters = {"user_id": user_id} + + return collection, filters + + + def _embed(self, text) -> list[float]: + return self.embed_model.encode(str(text), normalize_embeddings=True).tolist() + + + + def _search( + self, + query_vector: list[float], + collection: str, + filters: dict | None = None, + ) -> list[dict]: + + # Build qdrant filter from user context + qdrant_filter = None + if filters: + conditions = [ + FieldCondition(key=k, match=MatchValue(value=v)) + for k, v in filters.items() + ] + qdrant_filter = Filter(must=conditions) + + results = self.qdrant.query_points( + collection_name=collection, + query=query_vector, + query_filter=qdrant_filter, + limit=self.top_k, + score_threshold=self.score_threshold, + ).points + + return [ + { + "text": point.payload.get("text", ""), + "score": point.score, + "metadata": {k: v for k, v in point.payload.items() if k != "text"}, + } + for point in results + ] + + + def _build_prompt( + self, + question: str, + chunks: list[dict], + preferences: dict, + ) -> tuple[str, str]: + + parts = [ + "You are a helpful assistant. Answer based on the provided context.", + "If the context is insufficient, say so clearly.", + ] + if preferences.get("language"): + parts.append(f"Always respond in {preferences['language']}.") + if preferences.get("tone"): + parts.append(f"Use a {preferences['tone']} tone.") + if preferences.get("response_format") == "bullet_points": + parts.append("Format your answer as bullet points.") + elif preferences.get("response_format") == "short": + parts.append("Keep your answer to 2-3 sentences maximum.") + if preferences.get("extra_instructions"): + parts.append(preferences["extra_instructions"]) + system_prompt = " ".join(parts) + + if not chunks: + user_prompt = ( + "No relevant context was found.\n\n" + f"Question: {question}\n\n" + "Answer based on general knowledge and mention no documents were found." + ) + else: + context_parts = [] + for i, chunk in enumerate(chunks, 1): + source = chunk["metadata"].get("source", "unknown") + context_parts.append( + f"[{i}] (source: {source}, score: {chunk['score']:.2f})\n{chunk['text']}" + ) + context_block = "\n\n".join(context_parts) + user_prompt = ( + f"Context:\n{context_block}\n\n" + f"Question: {question}\n\n" + "Answer based on the context above. Cite sources by number." + ) + + return system_prompt, user_prompt + + + async def _llm_generate( + self, + system_prompt: str, + user_prompt: str, + preferences: dict, + ) -> str: + max_tokens = preferences.get("max_length", 1024) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + if self.llm_provider == "vllm": + return await self._call_openai_compatible( + f"{self.llm_url}/v1/chat/completions", messages, max_tokens + ) + elif self.llm_provider == "ollama": + return await self._call_ollama(messages, max_tokens) + elif self.llm_provider == "api": + return await self._call_openai_compatible( + f"{self.llm_url}/v1/chat/completions", messages, max_tokens, self.llm_api_key + ) + else: + raise ValueError(f"Unknown llm_provider: {self.llm_provider}") + + + async def _call_openai_compatible( + self, url: str, messages: list, max_tokens: int, api_key: str = "" + ) -> str: + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(url, headers=headers, json={ + "model": self.llm_model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0.1, + }) + resp.raise_for_status() + return resp.json()["choices"][0]["message"]["content"] + + + async def _call_ollama(self, messages: list, max_tokens: int) -> str: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{self.llm_url}/api/chat", json={ + "model": self.llm_model, + "messages": messages, + "stream": False, + "options": {"num_predict": max_tokens, "temperature": 0.1}, + }) + resp.raise_for_status() + return resp.json()["message"]["content"] + + + async def process(self, query: RAGQuery) -> RAGResult: + """ + Main entry point. Called by the RAG module. + Uses user_id to determine which collection / filters to use. + """ + + print(f"[RAG] Question: {query.question}") + collection, filters = self._resolve_user_context(query.user_id) + query_vector = self._embed(query.question) + chunks = self._search(query_vector, collection, filters) + + + print(f"[RAG] Found {len(chunks)} chunks") + for c in chunks: + print(f" - score: {c['score']:.2f} | {c['text'][:100]}...") + + system_prompt, user_prompt = self._build_prompt( + query.question, chunks, query.preferences + ) + print(f"[RAG] System prompt: {system_prompt[:200]}...") + answer = await self._llm_generate(system_prompt, user_prompt, query.preferences) + print(f"[RAG] Answer: {answer}") + + return RAGResult( + answer=answer, + sources=[ + {"text": c["text"], "score": c["score"], "metadata": c["metadata"]} + for c in chunks + ], + ) + + +class RAG(ModuleWithHandle): + """ + Session-bound module. HuRI instantiates this when a client connects, + passing the user_id from the WebSocket config. + + Listens to "question" events. + Forwards question + user_id to the detached RAGHandle. + Emits "rag_response" event with the answer. + """ + _handle_cls = RAGHandle + input_type = "question" + output_type = "rag_response" + + def __init__( + self, + handle: handle.DeploymentHandle[RAGHandle], + user_id: str = "", + language: str = "en", + tone: str = "formal", + response_format: str = "paragraph", + max_length: int = 1024, + extra_instructions: str = "", + ): + super().__init__(handle) + self.user_id = user_id + self.preferences = { + "language": language, + "tone": tone, + "response_format": response_format, + "max_length": max_length, + "extra_instructions": extra_instructions, + } + + async def process(self, data) -> Optional[Any]: + """ + Called when a "question" event arrives through the event bus. + Packages user_id + question, sends to the stateless RAGHandle. + """ + question_text = data.text if hasattr(data, 'text') else str(data) + + query = RAGQuery( + user_id=self.user_id if self.user_id else "anonymous", + question=question_text, + preferences=self.preferences, + ) + + result: RAGResult = await self.handle.process.remote(query) + return result + + def update_preferences(self, new_preferences: dict): + """Client can update preferences mid-session via the event bus.""" + self.preferences.update(new_preferences) From 87588fbd4f3093afe2556b4abc26fee2fffb1ef2 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Thu, 7 May 2026 12:17:18 +0100 Subject: [PATCH 2/5] wip(rag): set the filter at None to be able to restrieve collections without a user_id --- src/modules/rag/rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 0884c76..07a3244 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -79,7 +79,7 @@ def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: # Option B: shared collection with user_id filter (recommended) collection = self.default_collection - filters = {"user_id": user_id} + filters = None #{"user_id": user_id} return collection, filters From 9b7dadea6f372792d867071451a63807d9513bb6 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Mon, 11 May 2026 15:39:53 +0100 Subject: [PATCH 3/5] feat(id): add ids to make it work with the rag system + an ingestion system --- src/client.py | 37 +++++++++++++++--- src/core/huri.py | 12 ++++-- src/modules/rag/ingestion.py | 75 ++++++++++++++++++++++++++++++++++++ src/modules/rag/rag.py | 8 ++-- 4 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 src/modules/rag/ingestion.py diff --git a/src/client.py b/src/client.py index ca29146..2240709 100644 --- a/src/client.py +++ b/src/client.py @@ -1,6 +1,7 @@ import argparse import asyncio import json +import os from dataclasses import asdict from typing import Dict @@ -12,15 +13,29 @@ from src.core.dataclasses.config import ClientConfig +USER_ID_FILE = os.path.expanduser("~/.huri_user_id") + + +def load_user_id() -> str | None: + if os.path.exists(USER_ID_FILE): + with open(USER_ID_FILE) as f: + return f.read().strip() + return None + + +def save_user_id(user_id: str): + with open(USER_ID_FILE, "w") as f: + f.write(user_id) + def load_client_config(path: str) -> ClientConfig: with open(path) as f: dict_config = OmegaConf.load(f) - raw_resolved = OmegaConf.to_container(dict_config, resolve=True) + raw_resolved = OmegaConf.to_container(dict_config, resolve=True) - if not isinstance(raw_resolved, Dict): - raise RuntimeError("error yaml does not output a dict") + if not isinstance(raw_resolved, Dict): + raise RuntimeError("error yaml does not output a dict") - return ClientConfig.from_dict(raw_resolved) + return ClientConfig.from_dict(raw_resolved) async def stream_audio(): @@ -38,7 +53,19 @@ async def stream_audio(): async with websockets.connect(config.huri_url) as ws: print("Connected to server") - await ws.send(json.dumps(asdict(config))) + payload = asdict(config) + user_id = load_user_id() + if user_id: + payload["user_id"] = user_id + print(f"Reconnecting with user_id: {user_id}") + + await ws.send(json.dumps(payload)) + + init_msg = json.loads(await ws.recv()) + if init_msg.get("type") == "session_init": + user_id = init_msg["user_id"] + save_user_id(user_id) + print(f"Session started with user_id: {user_id}") async def receive(ws: websockets.ClientConnection): while True: diff --git a/src/core/huri.py b/src/core/huri.py index 6d6d747..a07c4fe 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -30,11 +30,17 @@ def __init__( @app.websocket("/session") async def run_session(self, ws: WebSocket): await ws.accept() - client_config_raw: Dict = await ws.receive_json() - client_config = ClientConfig.from_dict(client_config_raw) + user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) + await ws.send_json({"type": "session_init", "user_id": user_id}) + + if "rag" in client_config.modules: + if client_config.modules["rag"].args is None: + client_config.modules["rag"].args = {} + client_config.modules["rag"].args["user_id"] = user_id + senders: List[Module] = [ Sender(ws, topic) for topic in client_config.topic_list ] @@ -43,9 +49,7 @@ async def run_session(self, ws: WebSocket): ) session_id = str(uuid.uuid4()) - self.clients[session_id] = Session(modules) - print("Client registered successfully with config:", client_config) async def receive_loop(session: Session, ws: WebSocket): diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py new file mode 100644 index 0000000..d517c67 --- /dev/null +++ b/src/modules/rag/ingestion.py @@ -0,0 +1,75 @@ +# ingestion.py +import argparse +import os +import uuid + +from qdrant_client import QdrantClient +from qdrant_client.models import VectorParams, Distance, PointStruct +from sentence_transformers import SentenceTransformer + +USER_ID_FILE = os.path.expanduser("~/.huri_user_id") + + +def get_user_id(provided_id: str = None) -> str: + """Use provided ID, or load from file, or generate new one.""" + if provided_id: + return provided_id + if os.path.exists(USER_ID_FILE): + with open(USER_ID_FILE) as f: + return f.read().strip() + new_id = str(uuid.uuid4()) + with open(USER_ID_FILE, "w") as f: + f.write(new_id) + return new_id + + +def main(): + parser = argparse.ArgumentParser(description="Ingest documents into Qdrant") + parser.add_argument("--user-id", type=str, default=None, help="User ID (reads from ~/.huri_user_id if not provided)") + parser.add_argument("--collection", type=str, default="documents") + parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333") + args = parser.parse_args() + + user_id = get_user_id(args.user_id) + print(f"Ingesting for user_id: {user_id}") + + client = QdrantClient(url=args.qdrant_url) + model = SentenceTransformer("BAAI/bge-large-en-v1.5") + + # Create collection if it doesn't exist + collections = [c.name for c in client.get_collections().collections] + if args.collection not in collections: + client.create_collection( + collection_name=args.collection, + vectors_config=VectorParams(size=1024, distance=Distance.COSINE), + ) + print(f"Created collection: {args.collection}") + + # Sample documents + docs = [ + {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"}, + {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"}, + {"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"}, + {"text": "The main office is located in Paris, France.", "source": "info.pdf"}, + ] + + # Embed and insert with user_id + points = [] + for doc in docs: + vector = model.encode(doc["text"], normalize_embeddings=True).tolist() + points.append(PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "text": doc["text"], + "source": doc["source"], + "user_id": user_id, # ← scoped to this user + }, + )) + + client.upsert(collection_name=args.collection, points=points) + print(f"Ingested {len(points)} documents for user {user_id}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 07a3244..e37db92 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -79,7 +79,7 @@ def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: # Option B: shared collection with user_id filter (recommended) collection = self.default_collection - filters = None #{"user_id": user_id} + filters = {"user_id": user_id} return collection, filters @@ -131,7 +131,7 @@ def _build_prompt( ) -> tuple[str, str]: parts = [ - "You are a helpful assistant. Answer based on the provided context.", + "You are a robot speaking to a user. Answer based on the provided context.", "If the context is insufficient, say so clearly.", ] if preferences.get("language"): @@ -150,7 +150,7 @@ def _build_prompt( user_prompt = ( "No relevant context was found.\n\n" f"Question: {question}\n\n" - "Answer based on general knowledge and mention no documents were found." + "Answer based on general knowledge." ) else: context_parts = [] @@ -163,7 +163,7 @@ def _build_prompt( user_prompt = ( f"Context:\n{context_block}\n\n" f"Question: {question}\n\n" - "Answer based on the context above. Cite sources by number." + "Answer based on the context above. Don't speak about the sources, just use them to answer the question." ) return system_prompt, user_prompt From 52b2cc567f51ce3861bcfa8f6256c724a3f28a9f Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Mon, 11 May 2026 15:43:54 +0100 Subject: [PATCH 4/5] clean(id): clean code --- src/modules/rag/ingestion.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index d517c67..5c458d1 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -36,7 +36,6 @@ def main(): client = QdrantClient(url=args.qdrant_url) model = SentenceTransformer("BAAI/bge-large-en-v1.5") - # Create collection if it doesn't exist collections = [c.name for c in client.get_collections().collections] if args.collection not in collections: client.create_collection( @@ -45,7 +44,6 @@ def main(): ) print(f"Created collection: {args.collection}") - # Sample documents docs = [ {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"}, {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"}, @@ -53,7 +51,6 @@ def main(): {"text": "The main office is located in Paris, France.", "source": "info.pdf"}, ] - # Embed and insert with user_id points = [] for doc in docs: vector = model.encode(doc["text"], normalize_embeddings=True).tolist() @@ -63,7 +60,7 @@ def main(): payload={ "text": doc["text"], "source": doc["source"], - "user_id": user_id, # ← scoped to this user + "user_id": user_id, }, )) From 529297ae7a592f907d00ec3695206efd32cc21e3 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Fri, 15 May 2026 13:45:35 +0100 Subject: [PATCH 5/5] wip(pr): add a module with ids and generate a rag class with module with id and module with handle. Make some refacto : user_id -> _user_id and handle -> _handle --- src/client.py | 18 ++++---- src/core/huri.py | 23 ++++------ src/core/module.py | 15 +++++-- src/modules/factory.py | 31 ++++++++++---- src/modules/rag/rag.py | 67 +++++++++++++----------------- src/modules/reasoning/embedding.py | 6 +-- 6 files changed, 83 insertions(+), 77 deletions(-) diff --git a/src/client.py b/src/client.py index 2240709..63f490a 100644 --- a/src/client.py +++ b/src/client.py @@ -23,9 +23,9 @@ def load_user_id() -> str | None: return None -def save_user_id(user_id: str): +def save_user_id(_user_id: str): with open(USER_ID_FILE, "w") as f: - f.write(user_id) + f.write(_user_id) def load_client_config(path: str) -> ClientConfig: with open(path) as f: @@ -54,18 +54,18 @@ async def stream_audio(): print("Connected to server") payload = asdict(config) - user_id = load_user_id() - if user_id: - payload["user_id"] = user_id - print(f"Reconnecting with user_id: {user_id}") + _user_id = load_user_id() + if _user_id: + payload["_user_id"] = _user_id + print(f"Reconnecting with _user_id: {_user_id}") await ws.send(json.dumps(payload)) init_msg = json.loads(await ws.recv()) if init_msg.get("type") == "session_init": - user_id = init_msg["user_id"] - save_user_id(user_id) - print(f"Session started with user_id: {user_id}") + _user_id = init_msg["_user_id"] + save_user_id(_user_id) + print(f"Session started with _user_id: {_user_id}") async def receive(ws: websockets.ClientConnection): while True: diff --git a/src/core/huri.py b/src/core/huri.py index a07c4fe..4c4207a 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -7,7 +7,6 @@ from src.modules.factory import Module, ModuleFactory from src.modules.utils.sender import Sender - from .app import app from .dataclasses.config import ClientConfig from .session import Session @@ -24,7 +23,6 @@ def __init__( self.factory = ModuleFactory(handles) for name, module_cls in modules.items(): self.factory.register(name, module_cls) - self.clients: Dict[str, Session] = {} @app.websocket("/session") @@ -33,24 +31,20 @@ async def run_session(self, ws: WebSocket): client_config_raw: Dict = await ws.receive_json() client_config = ClientConfig.from_dict(client_config_raw) - user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) - await ws.send_json({"type": "session_init", "user_id": user_id}) - - if "rag" in client_config.modules: - if client_config.modules["rag"].args is None: - client_config.modules["rag"].args = {} - client_config.modules["rag"].args["user_id"] = user_id + _user_id = client_config_raw.get("_user_id") or str(uuid.uuid4()) senders: List[Module] = [ Sender(ws, topic) for topic in client_config.topic_list ] modules: List[Module] = ( - self.factory.create_from_config(client_config.modules) + senders + self.factory.create_from_config(_user_id, client_config.modules) + senders ) + await ws.send_json({"type": "session_init", "_user_id": _user_id}) + session_id = str(uuid.uuid4()) self.clients[session_id] = Session(modules) - print("Client registered successfully with config:", client_config) + print(f"Client registered with _user_id={_user_id}, config: {client_config}") async def receive_loop(session: Session, ws: WebSocket): try: @@ -59,9 +53,8 @@ async def receive_loop(session: Session, ws: WebSocket): if "bytes" in msg: chunk = msg["bytes"] await session.publish("chunk", chunk) - # else: - # data = msg - # await session.publish(data["type"], data["data"]) except (WebSocketDisconnect, RuntimeError): - print(f"Client disconnected") + print(f"Client {_user_id} disconnected") + await receive_loop(self.clients[session_id], ws) + del self.clients[session_id] \ No newline at end of file diff --git a/src/core/module.py b/src/core/module.py index 12543cf..e57dba4 100644 --- a/src/core/module.py +++ b/src/core/module.py @@ -14,6 +14,15 @@ async def process(self, _) -> Optional[Any]: class ModuleWithHandle(Module): _handle_cls: Type[Any] - def __init__(self, handle: handle.DeploymentHandle): - super().__init__() - self.handle = handle + def __init__(self, _handle: handle.DeploymentHandle = None, **kwargs): + super().__init__(**kwargs) + self._handle = _handle + +class ModuleWithId(Module): + def __init__(self, _user_id: str, **kwargs): + super().__init__(**kwargs) + self._user_id = _user_id + + def get_user_context(self) -> dict: + """Override in subclasses to provide user-specific context.""" + return {"_user_id": self._user_id} diff --git a/src/modules/factory.py b/src/modules/factory.py index 39a1a83..d252ab1 100644 --- a/src/modules/factory.py +++ b/src/modules/factory.py @@ -1,7 +1,8 @@ +from os import name from typing import Any, Dict, List, Mapping, Type from src.core.dataclasses.config import ModuleConfig -from src.core.module import Module, ModuleWithHandle, handle +from src.core.module import Module, ModuleWithHandle, handle, ModuleWithId class ModuleFactory: @@ -9,6 +10,7 @@ def __init__(self, handles): self._registry: Dict[str, Type[Module]] = {} self._handles = handles + def register(self, name: str, module_cls: Type[Module]) -> None: if not issubclass(module_cls, Module): raise TypeError(f"{module_cls} must inherit from Module") @@ -19,29 +21,40 @@ def register(self, name: str, module_cls: Type[Module]) -> None: ) self._registry[name] = module_cls - def create(self, name: str, args: Mapping[str, Any] | None = None) -> Module: + + def create( + self, + _user_id: str, + name: str, + args: Mapping[str, Any] | None = None + ) -> Module: + if name not in self._registry: raise ValueError(f"Unknown module '{name}'") + module_cls = self._registry[name] - if args is None: - args = {} + kwargs = dict(args or {}) + if issubclass(module_cls, ModuleWithHandle): if name not in self._handles: raise RuntimeError( f"Handles not bound for '{name}'. Check your config first." ) - return module_cls(handle=self._handles[name], **args) - return module_cls(**args) + kwargs["_handle"] = self._handles[name] + + if issubclass(module_cls, ModuleWithId): + kwargs["_user_id"] = _user_id + + return module_cls(**kwargs) def create_from_config( - self, module_configs: Dict[str, ModuleConfig] + self, _user_id: str, module_configs: Dict[str, ModuleConfig] ) -> List[Module]: modules: List[Module] = [] for _, module_config in module_configs.items(): - modules.append(self.create(module_config.name, module_config.args)) - + modules.append(self.create(_user_id, module_config.name, module_config.args)) if modules == []: raise Exception diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index e37db92..ac594d8 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -1,9 +1,8 @@ from typing import Any, Optional from dataclasses import dataclass, field -from ray import data, serve -from ray.serve import handle -from src.core.module import ModuleWithHandle +from ray import serve +from src.core.module import ModuleWithHandle, ModuleWithId, handle from qdrant_client.models import Filter, FieldCondition, MatchValue from sentence_transformers import SentenceTransformer from qdrant_client import QdrantClient @@ -15,7 +14,7 @@ @dataclass class RAGQuery: """What flows from RAG module to RAGHandle.""" - user_id: str + _user_id: str question: str preferences: dict = field(default_factory=dict) # preferences can include: language, tone, response_format, max_length, system_prompt, extra_instructions, etc. @@ -35,7 +34,7 @@ class RAGResult: class RAGHandle: """ Stateless RAG processor. Knows nothing about sessions. - Receives a user_id + question, uses user_id to find the right + Receives a _user_id + question, uses _user_id to find the right collection/data in the vector DB, runs embed -> search -> LLM. """ @@ -62,24 +61,24 @@ def __init__( self.llm_model = llm_model self.llm_api_key = llm_api_key - def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: + def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: """ - Given a user_id, decide which collection to search + Given a _user_id, decide which collection to search and which filters to apply. Options (pick what fits your data model): - A) One collection per user: collection = f"user_{user_id}" - B) Shared collection, filter by user_id in payload + A) One collection per user: collection = f"user_{_user_id}" + B) Shared collection, filter by _user_id in payload C) Lookup in a DB to find the user's config """ # Option A: separate collection per user - # collection = f"user_{user_id}" + # collection = f"user_{_user_id}" # filters = None - # Option B: shared collection with user_id filter (recommended) + # Option B: shared collection with _user_id filter (recommended) collection = self.default_collection - filters = {"user_id": user_id} + filters = {"_user_id": _user_id} return collection, filters @@ -96,7 +95,6 @@ def _search( filters: dict | None = None, ) -> list[dict]: - # Build qdrant filter from user context qdrant_filter = None if filters: conditions = [ @@ -227,11 +225,11 @@ async def _call_ollama(self, messages: list, max_tokens: int) -> str: async def process(self, query: RAGQuery) -> RAGResult: """ Main entry point. Called by the RAG module. - Uses user_id to determine which collection / filters to use. + Uses _user_id to determine which collection / filters to use. """ print(f"[RAG] Question: {query.question}") - collection, filters = self._resolve_user_context(query.user_id) + collection, filters = self._resolve_user_context(query._user_id) query_vector = self._embed(query.question) chunks = self._search(query_vector, collection, filters) @@ -256,31 +254,23 @@ async def process(self, query: RAGQuery) -> RAGResult: ) -class RAG(ModuleWithHandle): - """ - Session-bound module. HuRI instantiates this when a client connects, - passing the user_id from the WebSocket config. - - Listens to "question" events. - Forwards question + user_id to the detached RAGHandle. - Emits "rag_response" event with the answer. - """ +class RAG(ModuleWithHandle, ModuleWithId): _handle_cls = RAGHandle input_type = "question" output_type = "rag_response" def __init__( self, - handle: handle.DeploymentHandle[RAGHandle], - user_id: str = "", - language: str = "en", - tone: str = "formal", - response_format: str = "paragraph", - max_length: int = 1024, - extra_instructions: str = "", + _handle=None, + _user_id="", + language="en", + tone="formal", + response_format="paragraph", + max_length=1024, + extra_instructions="", + **kwargs, ): - super().__init__(handle) - self.user_id = user_id + super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) self.preferences = { "language": language, "tone": tone, @@ -288,23 +278,24 @@ def __init__( "max_length": max_length, "extra_instructions": extra_instructions, } - + async def process(self, data) -> Optional[Any]: """ Called when a "question" event arrives through the event bus. - Packages user_id + question, sends to the stateless RAGHandle. + Packages _user_id + question, sends to the stateless RAGHandle. """ question_text = data.text if hasattr(data, 'text') else str(data) query = RAGQuery( - user_id=self.user_id if self.user_id else "anonymous", + _user_id=self._user_id if self._user_id else "anonymous", question=question_text, preferences=self.preferences, ) - result: RAGResult = await self.handle.process.remote(query) + result: RAGResult = await self._handle.process.remote(query) return result - + + def update_preferences(self, new_preferences: dict): """Client can update preferences mid-session via the event bus.""" self.preferences.update(new_preferences) diff --git a/src/modules/reasoning/embedding.py b/src/modules/reasoning/embedding.py index ff14897..b6139d9 100644 --- a/src/modules/reasoning/embedding.py +++ b/src/modules/reasoning/embedding.py @@ -29,13 +29,13 @@ class EMB(ModuleWithHandle): input_type = "toembed" output_type = "embedded" - def __init__(self, handle: handle.DeploymentHandle[EMBHandle]): - super().__init__(handle) + def __init__(self, _handle: handle.DeploymentHandle[EMBHandle]): + super().__init__(_handle) self.database = "" async def process(self, data_to_embed: np.ndarray) -> Optional[Any]: - embedded = await self.handle.embbed.remote(data_to_embed) + embedded = await self._handle.embbed.remote(data_to_embed) # TODO write embedding return embedded