diff --git a/src/client.py b/src/client.py index ca29146..63f490a 100644 --- a/src/client.py +++ b/src/client.py @@ -1,6 +1,7 @@ import argparse import asyncio import json +import os from dataclasses import asdict from typing import Dict @@ -12,15 +13,29 @@ from src.core.dataclasses.config import ClientConfig +USER_ID_FILE = os.path.expanduser("~/.huri_user_id") + + +def load_user_id() -> str | None: + if os.path.exists(USER_ID_FILE): + with open(USER_ID_FILE) as f: + return f.read().strip() + return None + + +def save_user_id(_user_id: str): + with open(USER_ID_FILE, "w") as f: + f.write(_user_id) + def load_client_config(path: str) -> ClientConfig: with open(path) as f: dict_config = OmegaConf.load(f) - raw_resolved = OmegaConf.to_container(dict_config, resolve=True) + raw_resolved = OmegaConf.to_container(dict_config, resolve=True) - if not isinstance(raw_resolved, Dict): - raise RuntimeError("error yaml does not output a dict") + if not isinstance(raw_resolved, Dict): + raise RuntimeError("error yaml does not output a dict") - return ClientConfig.from_dict(raw_resolved) + return ClientConfig.from_dict(raw_resolved) async def stream_audio(): @@ -38,7 +53,19 @@ async def stream_audio(): async with websockets.connect(config.huri_url) as ws: print("Connected to server") - await ws.send(json.dumps(asdict(config))) + payload = asdict(config) + _user_id = load_user_id() + if _user_id: + payload["_user_id"] = _user_id + print(f"Reconnecting with _user_id: {_user_id}") + + await ws.send(json.dumps(payload)) + + init_msg = json.loads(await ws.recv()) + if init_msg.get("type") == "session_init": + _user_id = init_msg["_user_id"] + save_user_id(_user_id) + print(f"Session started with _user_id: {_user_id}") async def receive(ws: websockets.ClientConnection): while True: diff --git a/src/core/huri.py b/src/core/huri.py index 6d6d747..4c4207a 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -7,7 +7,6 @@ from src.modules.factory import Module, ModuleFactory from 
src.modules.utils.sender import Sender - from .app import app from .dataclasses.config import ClientConfig from .session import Session @@ -24,29 +23,28 @@ def __init__( self.factory = ModuleFactory(handles) for name, module_cls in modules.items(): self.factory.register(name, module_cls) - self.clients: Dict[str, Session] = {} @app.websocket("/session") async def run_session(self, ws: WebSocket): await ws.accept() - client_config_raw: Dict = await ws.receive_json() - client_config = ClientConfig.from_dict(client_config_raw) + _user_id = client_config_raw.get("_user_id") or str(uuid.uuid4()) + senders: List[Module] = [ Sender(ws, topic) for topic in client_config.topic_list ] modules: List[Module] = ( - self.factory.create_from_config(client_config.modules) + senders + self.factory.create_from_config(_user_id, client_config.modules) + senders ) - session_id = str(uuid.uuid4()) + await ws.send_json({"type": "session_init", "_user_id": _user_id}) + session_id = str(uuid.uuid4()) self.clients[session_id] = Session(modules) - - print("Client registered successfully with config:", client_config) + print(f"Client registered with _user_id={_user_id}, config: {client_config}") async def receive_loop(session: Session, ws: WebSocket): try: @@ -55,9 +53,8 @@ async def receive_loop(session: Session, ws: WebSocket): if "bytes" in msg: chunk = msg["bytes"] await session.publish("chunk", chunk) - # else: - # data = msg - # await session.publish(data["type"], data["data"]) except (WebSocketDisconnect, RuntimeError): - print(f"Client disconnected") + print(f"Client {_user_id} disconnected") + await receive_loop(self.clients[session_id], ws) + del self.clients[session_id] \ No newline at end of file diff --git a/src/core/module.py b/src/core/module.py index 12543cf..e57dba4 100644 --- a/src/core/module.py +++ b/src/core/module.py @@ -14,6 +14,15 @@ async def process(self, _) -> Optional[Any]: class ModuleWithHandle(Module): _handle_cls: Type[Any] - def __init__(self, handle: 
class ModuleWithId(Module):
    """Module base that remembers the id of the user it was created for.

    ``_user_id`` is consumed here; every remaining keyword argument is passed
    on unchanged to the next ``__init__`` in the MRO.
    """

    def __init__(self, _user_id: str, **kwargs):
        super().__init__(**kwargs)
        self._user_id = _user_id

    def get_user_context(self) -> dict:
        """Return the user-specific context; subclasses may override this."""
        context = {"_user_id": self._user_id}
        return context
) - return module_cls(handle=self._handles[name], **args) - return module_cls(**args) + kwargs["_handle"] = self._handles[name] + + if issubclass(module_cls, ModuleWithId): + kwargs["_user_id"] = _user_id + + return module_cls(**kwargs) def create_from_config( - self, module_configs: Dict[str, ModuleConfig] + self, _user_id: str, module_configs: Dict[str, ModuleConfig] ) -> List[Module]: modules: List[Module] = [] for _, module_config in module_configs.items(): - modules.append(self.create(module_config.name, module_config.args)) - + modules.append(self.create(_user_id, module_config.name, module_config.args)) if modules == []: raise Exception diff --git a/src/modules/modules.py b/src/modules/modules.py index 69ebb45..7283983 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -3,9 +3,10 @@ from src.modules.speech_to_text.record_speech import MIC from src.modules.speech_to_text.speech_to_text import STT from src.modules.speech_to_text.text_aggregator import TAG +from src.modules.rag.rag import RAG from .factory import Module def get_modules() -> Dict[str, Type[Module]]: - return {"mic": MIC, "stt": STT, "tag": TAG} + return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG} diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py new file mode 100644 index 0000000..5c458d1 --- /dev/null +++ b/src/modules/rag/ingestion.py @@ -0,0 +1,72 @@ +# ingestion.py +import argparse +import os +import uuid + +from qdrant_client import QdrantClient +from qdrant_client.models import VectorParams, Distance, PointStruct +from sentence_transformers import SentenceTransformer + +USER_ID_FILE = os.path.expanduser("~/.huri_user_id") + + +def get_user_id(provided_id: str = None) -> str: + """Use provided ID, or load from file, or generate new one.""" + if provided_id: + return provided_id + if os.path.exists(USER_ID_FILE): + with open(USER_ID_FILE) as f: + return f.read().strip() + new_id = str(uuid.uuid4()) + with open(USER_ID_FILE, "w") as f: + 
f.write(new_id) + return new_id + + +def main(): + parser = argparse.ArgumentParser(description="Ingest documents into Qdrant") + parser.add_argument("--user-id", type=str, default=None, help="User ID (reads from ~/.huri_user_id if not provided)") + parser.add_argument("--collection", type=str, default="documents") + parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333") + args = parser.parse_args() + + user_id = get_user_id(args.user_id) + print(f"Ingesting for user_id: {user_id}") + + client = QdrantClient(url=args.qdrant_url) + model = SentenceTransformer("BAAI/bge-large-en-v1.5") + + collections = [c.name for c in client.get_collections().collections] + if args.collection not in collections: + client.create_collection( + collection_name=args.collection, + vectors_config=VectorParams(size=1024, distance=Distance.COSINE), + ) + print(f"Created collection: {args.collection}") + + docs = [ + {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"}, + {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"}, + {"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"}, + {"text": "The main office is located in Paris, France.", "source": "info.pdf"}, + ] + + points = [] + for doc in docs: + vector = model.encode(doc["text"], normalize_embeddings=True).tolist() + points.append(PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "text": doc["text"], + "source": doc["source"], + "user_id": user_id, + }, + )) + + client.upsert(collection_name=args.collection, points=points) + print(f"Ingested {len(points)} documents for user {user_id}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py new file mode 100644 index 0000000..ac594d8 --- /dev/null +++ b/src/modules/rag/rag.py @@ -0,0 +1,301 @@ +from typing import Any, Optional +from dataclasses import dataclass, field + +from ray 
import serve +from src.core.module import ModuleWithHandle, ModuleWithId, handle +from qdrant_client.models import Filter, FieldCondition, MatchValue +from sentence_transformers import SentenceTransformer +from qdrant_client import QdrantClient + + +import httpx + + +@dataclass +class RAGQuery: + """What flows from RAG module to RAGHandle.""" + _user_id: str + question: str + preferences: dict = field(default_factory=dict) + # preferences can include: language, tone, response_format, max_length, system_prompt, extra_instructions, etc. + + +@dataclass +class RAGResult: + """What RAGHandle returns.""" + answer: str + sources: list[dict] = field(default_factory=list) + + +@serve.deployment( + num_replicas=2, + ray_actor_options={"num_cpus": 1}, +) +class RAGHandle: + """ + Stateless RAG processor. Knows nothing about sessions. + Receives a _user_id + question, uses _user_id to find the right + collection/data in the vector DB, runs embed -> search -> LLM. + """ + + def __init__( + self, + qdrant_url: str = "http://localhost:6333", + default_collection: str = "documents", + embedding_model: str = "BAAI/bge-large-en-v1.5", + llm_provider: str = "ollama", # "vllm", "ollama", "api" + llm_url: str = "http://localhost:11434", + llm_model: str = "mistral:7b", + llm_api_key: str = "", + top_k: int = 5, + score_threshold: float = 0.5, + ): + self.embed_model = SentenceTransformer(embedding_model) + self.qdrant = QdrantClient(url=qdrant_url) + self.default_collection = default_collection + self.top_k = top_k + self.score_threshold = score_threshold + + self.llm_provider = llm_provider + self.llm_url = llm_url + self.llm_model = llm_model + self.llm_api_key = llm_api_key + + def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: + """ + Given a _user_id, decide which collection to search + and which filters to apply. 
+ + Options (pick what fits your data model): + A) One collection per user: collection = f"user_{_user_id}" + B) Shared collection, filter by _user_id in payload + C) Lookup in a DB to find the user's config + """ + + # Option A: separate collection per user + # collection = f"user_{_user_id}" + # filters = None + + # Option B: shared collection with _user_id filter (recommended) + collection = self.default_collection + filters = {"_user_id": _user_id} + + return collection, filters + + + def _embed(self, text) -> list[float]: + return self.embed_model.encode(str(text), normalize_embeddings=True).tolist() + + + + def _search( + self, + query_vector: list[float], + collection: str, + filters: dict | None = None, + ) -> list[dict]: + + qdrant_filter = None + if filters: + conditions = [ + FieldCondition(key=k, match=MatchValue(value=v)) + for k, v in filters.items() + ] + qdrant_filter = Filter(must=conditions) + + results = self.qdrant.query_points( + collection_name=collection, + query=query_vector, + query_filter=qdrant_filter, + limit=self.top_k, + score_threshold=self.score_threshold, + ).points + + return [ + { + "text": point.payload.get("text", ""), + "score": point.score, + "metadata": {k: v for k, v in point.payload.items() if k != "text"}, + } + for point in results + ] + + + def _build_prompt( + self, + question: str, + chunks: list[dict], + preferences: dict, + ) -> tuple[str, str]: + + parts = [ + "You are a robot speaking to a user. 
Answer based on the provided context.", + "If the context is insufficient, say so clearly.", + ] + if preferences.get("language"): + parts.append(f"Always respond in {preferences['language']}.") + if preferences.get("tone"): + parts.append(f"Use a {preferences['tone']} tone.") + if preferences.get("response_format") == "bullet_points": + parts.append("Format your answer as bullet points.") + elif preferences.get("response_format") == "short": + parts.append("Keep your answer to 2-3 sentences maximum.") + if preferences.get("extra_instructions"): + parts.append(preferences["extra_instructions"]) + system_prompt = " ".join(parts) + + if not chunks: + user_prompt = ( + "No relevant context was found.\n\n" + f"Question: {question}\n\n" + "Answer based on general knowledge." + ) + else: + context_parts = [] + for i, chunk in enumerate(chunks, 1): + source = chunk["metadata"].get("source", "unknown") + context_parts.append( + f"[{i}] (source: {source}, score: {chunk['score']:.2f})\n{chunk['text']}" + ) + context_block = "\n\n".join(context_parts) + user_prompt = ( + f"Context:\n{context_block}\n\n" + f"Question: {question}\n\n" + "Answer based on the context above. Don't speak about the sources, just use them to answer the question." 
async def process(self, query: RAGQuery) -> RAGResult:
    """Answer one question: embed -> search -> build prompt -> call the LLM.

    The user id carried by ``query`` selects the collection and payload
    filters used for the search.

    Args:
        query: User id, question text and response preferences.

    Returns:
        The generated answer together with the retrieved chunks (text,
        score, metadata) that were offered to the LLM as context.
    """
    # Local import keeps this method self-contained.
    import asyncio

    print(f"[RAG] Question: {query.question}")
    collection, filters = self._resolve_user_context(query._user_id)

    # _embed and _search are synchronous (model inference / Qdrant client
    # call). Running them inline would stall this coroutine's event loop for
    # their whole duration, so hand them to worker threads instead.
    query_vector = await asyncio.to_thread(self._embed, query.question)
    chunks = await asyncio.to_thread(
        self._search, query_vector, collection, filters
    )

    print(f"[RAG] Found {len(chunks)} chunks")
    for c in chunks:
        print(f"  - score: {c['score']:.2f} | {c['text'][:100]}...")

    system_prompt, user_prompt = self._build_prompt(
        query.question, chunks, query.preferences
    )
    print(f"[RAG] System prompt: {system_prompt[:200]}...")
    answer = await self._llm_generate(system_prompt, user_prompt, query.preferences)
    print(f"[RAG] Answer: {answer}")

    return RAGResult(
        answer=answer,
        sources=[
            {"text": c["text"], "score": c["score"], "metadata": c["metadata"]}
            for c in chunks
        ],
    )
+ """ + question_text = data.text if hasattr(data, 'text') else str(data) + + query = RAGQuery( + _user_id=self._user_id if self._user_id else "anonymous", + question=question_text, + preferences=self.preferences, + ) + + result: RAGResult = await self._handle.process.remote(query) + return result + + + def update_preferences(self, new_preferences: dict): + """Client can update preferences mid-session via the event bus.""" + self.preferences.update(new_preferences) diff --git a/src/modules/reasoning/embedding.py b/src/modules/reasoning/embedding.py index ff14897..b6139d9 100644 --- a/src/modules/reasoning/embedding.py +++ b/src/modules/reasoning/embedding.py @@ -29,13 +29,13 @@ class EMB(ModuleWithHandle): input_type = "toembed" output_type = "embedded" - def __init__(self, handle: handle.DeploymentHandle[EMBHandle]): - super().__init__(handle) + def __init__(self, _handle: handle.DeploymentHandle[EMBHandle]): + super().__init__(_handle) self.database = "" async def process(self, data_to_embed: np.ndarray) -> Optional[Any]: - embedded = await self.handle.embbed.remote(data_to_embed) + embedded = await self._handle.embbed.remote(data_to_embed) # TODO write embedding return embedded