From 81826b7b4701b1ba0f91fdc0700e369713191db0 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Thu, 7 May 2026 12:06:18 +0100 Subject: [PATCH 1/9] wip(rag): V1 of a rag working with an ollama llm working with the current pipeline with a RAG + Embedding + LLM (ollama). Should work with vLLM but not tested --- src/modules/modules.py | 3 +- src/modules/rag/rag.py | 310 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 src/modules/rag/rag.py diff --git a/src/modules/modules.py b/src/modules/modules.py index 69ebb45..7283983 100644 --- a/src/modules/modules.py +++ b/src/modules/modules.py @@ -3,9 +3,10 @@ from src.modules.speech_to_text.record_speech import MIC from src.modules.speech_to_text.speech_to_text import STT from src.modules.speech_to_text.text_aggregator import TAG +from src.modules.rag.rag import RAG from .factory import Module def get_modules() -> Dict[str, Type[Module]]: - return {"mic": MIC, "stt": STT, "tag": TAG} + return {"mic": MIC, "stt": STT, "tag": TAG, "rag": RAG} diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py new file mode 100644 index 0000000..0884c76 --- /dev/null +++ b/src/modules/rag/rag.py @@ -0,0 +1,310 @@ +from typing import Any, Optional +from dataclasses import dataclass, field + +from ray import data, serve +from ray.serve import handle +from src.core.module import ModuleWithHandle +from qdrant_client.models import Filter, FieldCondition, MatchValue +from sentence_transformers import SentenceTransformer +from qdrant_client import QdrantClient + + +import httpx + + +@dataclass +class RAGQuery: + """What flows from RAG module to RAGHandle.""" + user_id: str + question: str + preferences: dict = field(default_factory=dict) + # preferences can include: language, tone, response_format, max_length, system_prompt, extra_instructions, etc. + + +@dataclass +class RAGResult: + """What RAGHandle returns.""" + answer: str + sources: list[dict] = field(default_factory=list) + + +@serve.deployment( + num_replicas=2, + ray_actor_options={"num_cpus": 1}, +) +class RAGHandle: + """ + Stateless RAG processor. Knows nothing about sessions. + Receives a user_id + question, uses user_id to find the right + collection/data in the vector DB, runs embed -> search -> LLM. + """ + + def __init__( + self, + qdrant_url: str = "http://localhost:6333", + default_collection: str = "documents", + embedding_model: str = "BAAI/bge-large-en-v1.5", + llm_provider: str = "ollama", # "vllm", "ollama", "api" + llm_url: str = "http://localhost:11434", + llm_model: str = "mistral:7b", + llm_api_key: str = "", + top_k: int = 5, + score_threshold: float = 0.5, + ): + self.embed_model = SentenceTransformer(embedding_model) + self.qdrant = QdrantClient(url=qdrant_url) + self.default_collection = default_collection + self.top_k = top_k + self.score_threshold = score_threshold + + self.llm_provider = llm_provider + self.llm_url = llm_url + self.llm_model = llm_model + self.llm_api_key = llm_api_key + + def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: + """ + Given a user_id, decide which collection to search + and which filters to apply. 
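+        Returns a (collection_name, payload_filters) pair, where payload_filters
+        is a dict or None; _search() turns the dict into a Qdrant Filter.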
+ + Options (pick what fits your data model): + A) One collection per user: collection = f"user_{user_id}" + B) Shared collection, filter by user_id in payload + C) Lookup in a DB to find the user's config + """ + + # Option A: separate collection per user + # collection = f"user_{user_id}" + # filters = None + + # Option B: shared collection with user_id filter (recommended) + collection = self.default_collection + filters = {"user_id": user_id} + + return collection, filters + + + def _embed(self, text) -> list[float]: + return self.embed_model.encode(str(text), normalize_embeddings=True).tolist() + + + + def _search( + self, + query_vector: list[float], + collection: str, + filters: dict | None = None, + ) -> list[dict]: + + # Build qdrant filter from user context + qdrant_filter = None + if filters: + conditions = [ + FieldCondition(key=k, match=MatchValue(value=v)) + for k, v in filters.items() + ] + qdrant_filter = Filter(must=conditions) + + results = self.qdrant.query_points( + collection_name=collection, + query=query_vector, + query_filter=qdrant_filter, + limit=self.top_k, + score_threshold=self.score_threshold, + ).points + + return [ + { + "text": point.payload.get("text", ""), + "score": point.score, + "metadata": {k: v for k, v in point.payload.items() if k != "text"}, + } + for point in results + ] + + + def _build_prompt( + self, + question: str, + chunks: list[dict], + preferences: dict, + ) -> tuple[str, str]: + + parts = [ + "You are a helpful assistant. Answer based on the provided context.", + "If the context is insufficient, say so clearly.", + ] + if preferences.get("language"): + parts.append(f"Always respond in {preferences['language']}.") + if preferences.get("tone"): + parts.append(f"Use a {preferences['tone']} tone.") + if preferences.get("response_format") == "bullet_points": + parts.append("Format your answer as bullet points.") + elif preferences.get("response_format") == "short": + parts.append("Keep your answer to 2-3 sentences maximum.") + if preferences.get("extra_instructions"): + parts.append(preferences["extra_instructions"]) + system_prompt = " ".join(parts) + + if not chunks: + user_prompt = ( + "No relevant context was found.\n\n" + f"Question: {question}\n\n" + "Answer based on general knowledge and mention no documents were found." + ) + else: + context_parts = [] + for i, chunk in enumerate(chunks, 1): + source = chunk["metadata"].get("source", "unknown") + context_parts.append( + f"[{i}] (source: {source}, score: {chunk['score']:.2f})\n{chunk['text']}" + ) + context_block = "\n\n".join(context_parts) + user_prompt = ( + f"Context:\n{context_block}\n\n" + f"Question: {question}\n\n" + "Answer based on the context above. Cite sources by number." 
+ ) + + return system_prompt, user_prompt + + + async def _llm_generate( + self, + system_prompt: str, + user_prompt: str, + preferences: dict, + ) -> str: + max_tokens = preferences.get("max_length", 1024) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + if self.llm_provider == "vllm": + return await self._call_openai_compatible( + f"{self.llm_url}/v1/chat/completions", messages, max_tokens + ) + elif self.llm_provider == "ollama": + return await self._call_ollama(messages, max_tokens) + elif self.llm_provider == "api": + return await self._call_openai_compatible( + f"{self.llm_url}/v1/chat/completions", messages, max_tokens, self.llm_api_key + ) + else: + raise ValueError(f"Unknown llm_provider: {self.llm_provider}") + + + async def _call_openai_compatible( + self, url: str, messages: list, max_tokens: int, api_key: str = "" + ) -> str: + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(url, headers=headers, json={ + "model": self.llm_model, + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0.1, + }) + resp.raise_for_status() + return resp.json()["choices"][0]["message"]["content"] + + + async def _call_ollama(self, messages: list, max_tokens: int) -> str: + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post(f"{self.llm_url}/api/chat", json={ + "model": self.llm_model, + "messages": messages, + "stream": False, + "options": {"num_predict": max_tokens, "temperature": 0.1}, + }) + resp.raise_for_status() + return resp.json()["message"]["content"] + + + async def process(self, query: RAGQuery) -> RAGResult: + """ + Main entry point. Called by the RAG module. + Uses user_id to determine which collection / filters to use. + """ + + print(f"[RAG] Question: {query.question}") + collection, filters = self._resolve_user_context(query.user_id) + query_vector = self._embed(query.question) + chunks = self._search(query_vector, collection, filters) + + + print(f"[RAG] Found {len(chunks)} chunks") + for c in chunks: + print(f" - score: {c['score']:.2f} | {c['text'][:100]}...") + + system_prompt, user_prompt = self._build_prompt( + query.question, chunks, query.preferences + ) + print(f"[RAG] System prompt: {system_prompt[:200]}...") + answer = await self._llm_generate(system_prompt, user_prompt, query.preferences) + print(f"[RAG] Answer: {answer}") + + return RAGResult( + answer=answer, + sources=[ + {"text": c["text"], "score": c["score"], "metadata": c["metadata"]} + for c in chunks + ], + ) + + +class RAG(ModuleWithHandle): + """ + Session-bound module. HuRI instantiates this when a client connects, + passing the user_id from the WebSocket config. + + Listens to "question" events. + Forwards question + user_id to the detached RAGHandle. + Emits "rag_response" event with the answer. 
+ """ + _handle_cls = RAGHandle + input_type = "question" + output_type = "rag_response" + + def __init__( + self, + handle: handle.DeploymentHandle[RAGHandle], + user_id: str = "", + language: str = "en", + tone: str = "formal", + response_format: str = "paragraph", + max_length: int = 1024, + extra_instructions: str = "", + ): + super().__init__(handle) + self.user_id = user_id + self.preferences = { + "language": language, + "tone": tone, + "response_format": response_format, + "max_length": max_length, + "extra_instructions": extra_instructions, + } + + async def process(self, data) -> Optional[Any]: + """ + Called when a "question" event arrives through the event bus. + Packages user_id + question, sends to the stateless RAGHandle. + """ + question_text = data.text if hasattr(data, 'text') else str(data) + + query = RAGQuery( + user_id=self.user_id if self.user_id else "anonymous", + question=question_text, + preferences=self.preferences, + ) + + result: RAGResult = await self.handle.process.remote(query) + return result + + def update_preferences(self, new_preferences: dict): + """Client can update preferences mid-session via the event bus.""" + self.preferences.update(new_preferences) From 87588fbd4f3093afe2556b4abc26fee2fffb1ef2 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Thu, 7 May 2026 12:17:18 +0100 Subject: [PATCH 2/9] wip(rag): set the filter at None to be able to restrieve collections without a user_id --- src/modules/rag/rag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 0884c76..07a3244 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -79,7 +79,7 @@ def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: # Option B: shared collection with user_id filter (recommended) collection = self.default_collection - filters = {"user_id": user_id} + filters = None #{"user_id": user_id} return collection, filters From 9b7dadea6f372792d867071451a63807d9513bb6 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Mon, 11 May 2026 15:39:53 +0100 Subject: [PATCH 3/9] feat(id): add ids to make it work with the rag system + an ingestion system --- src/client.py | 37 +++++++++++++++--- src/core/huri.py | 12 ++++-- src/modules/rag/ingestion.py | 75 ++++++++++++++++++++++++++++++++++++ src/modules/rag/rag.py | 8 ++-- 4 files changed, 119 insertions(+), 13 deletions(-) create mode 100644 src/modules/rag/ingestion.py diff --git a/src/client.py b/src/client.py index ca29146..2240709 100644 --- a/src/client.py +++ b/src/client.py @@ -1,6 +1,7 @@ import argparse import asyncio import json +import os from dataclasses import asdict from typing import Dict @@ -12,15 +13,29 @@ from src.core.dataclasses.config import ClientConfig +USER_ID_FILE = os.path.expanduser("~/.huri_user_id") + + +def load_user_id() -> str | None: + if os.path.exists(USER_ID_FILE): + with open(USER_ID_FILE) as f: + return f.read().strip() + return None + + +def save_user_id(user_id: str): + with open(USER_ID_FILE, "w") as f: + f.write(user_id) + def load_client_config(path: str) -> ClientConfig: with open(path) as f: dict_config = OmegaConf.load(f) - raw_resolved = OmegaConf.to_container(dict_config, resolve=True) + raw_resolved = OmegaConf.to_container(dict_config, resolve=True) - if not isinstance(raw_resolved, Dict): - raise RuntimeError("error yaml does not output a dict") + if not isinstance(raw_resolved, Dict): + raise RuntimeError("error yaml does not output a dict") - return ClientConfig.from_dict(raw_resolved) 
+ return ClientConfig.from_dict(raw_resolved) async def stream_audio(): @@ -38,7 +53,19 @@ async def stream_audio(): async with websockets.connect(config.huri_url) as ws: print("Connected to server") - await ws.send(json.dumps(asdict(config))) + payload = asdict(config) + user_id = load_user_id() + if user_id: + payload["user_id"] = user_id + print(f"Reconnecting with user_id: {user_id}") + + await ws.send(json.dumps(payload)) + + init_msg = json.loads(await ws.recv()) + if init_msg.get("type") == "session_init": + user_id = init_msg["user_id"] + save_user_id(user_id) + print(f"Session started with user_id: {user_id}") async def receive(ws: websockets.ClientConnection): while True: diff --git a/src/core/huri.py b/src/core/huri.py index 6d6d747..a07c4fe 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -30,11 +30,17 @@ def __init__( @app.websocket("/session") async def run_session(self, ws: WebSocket): await ws.accept() - client_config_raw: Dict = await ws.receive_json() - client_config = ClientConfig.from_dict(client_config_raw) + user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) + await ws.send_json({"type": "session_init", "user_id": user_id}) + + if "rag" in client_config.modules: + if client_config.modules["rag"].args is None: + client_config.modules["rag"].args = {} + client_config.modules["rag"].args["user_id"] = user_id + senders: List[Module] = [ Sender(ws, topic) for topic in client_config.topic_list ] @@ -43,9 +49,7 @@ async def run_session(self, ws: WebSocket): ) session_id = str(uuid.uuid4()) - self.clients[session_id] = Session(modules) - print("Client registered successfully with config:", client_config) async def receive_loop(session: Session, ws: WebSocket): diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py new file mode 100644 index 0000000..d517c67 --- /dev/null +++ b/src/modules/rag/ingestion.py @@ -0,0 +1,75 @@ +# ingestion.py +import argparse +import os +import uuid + +from qdrant_client import QdrantClient +from qdrant_client.models import VectorParams, Distance, PointStruct +from sentence_transformers import SentenceTransformer + +USER_ID_FILE = os.path.expanduser("~/.huri_user_id") + + +def get_user_id(provided_id: str = None) -> str: + """Use provided ID, or load from file, or generate new one.""" + if provided_id: + return provided_id + if os.path.exists(USER_ID_FILE): + with open(USER_ID_FILE) as f: + return f.read().strip() + new_id = str(uuid.uuid4()) + with open(USER_ID_FILE, "w") as f: + f.write(new_id) + return new_id + + +def main(): + parser = argparse.ArgumentParser(description="Ingest documents into Qdrant") + parser.add_argument("--user-id", type=str, default=None, help="User ID (reads from ~/.huri_user_id if not provided)") + parser.add_argument("--collection", type=str, default="documents") + parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333") + args = parser.parse_args() + + user_id = get_user_id(args.user_id) + print(f"Ingesting for user_id: {user_id}") + + client = QdrantClient(url=args.qdrant_url) + model = SentenceTransformer("BAAI/bge-large-en-v1.5") + + # Create collection if it doesn't exist + collections = [c.name for c in client.get_collections().collections] + if args.collection not in collections: + client.create_collection( + collection_name=args.collection, + vectors_config=VectorParams(size=1024, distance=Distance.COSINE), + ) + print(f"Created collection: {args.collection}") + + # Sample documents + docs = [ + {"text": "The company budget for 2026 is 2 million 
euros.", "source": "budget.pdf"}, + {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"}, + {"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"}, + {"text": "The main office is located in Paris, France.", "source": "info.pdf"}, + ] + + # Embed and insert with user_id + points = [] + for doc in docs: + vector = model.encode(doc["text"], normalize_embeddings=True).tolist() + points.append(PointStruct( + id=str(uuid.uuid4()), + vector=vector, + payload={ + "text": doc["text"], + "source": doc["source"], + "user_id": user_id, # ← scoped to this user + }, + )) + + client.upsert(collection_name=args.collection, points=points) + print(f"Ingested {len(points)} documents for user {user_id}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index 07a3244..e37db92 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -79,7 +79,7 @@ def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: # Option B: shared collection with user_id filter (recommended) collection = self.default_collection - filters = None #{"user_id": user_id} + filters = {"user_id": user_id} return collection, filters @@ -131,7 +131,7 @@ def _build_prompt( ) -> tuple[str, str]: parts = [ - "You are a helpful assistant. Answer based on the provided context.", + "You are a robot speaking to a user. Answer based on the provided context.", "If the context is insufficient, say so clearly.", ] if preferences.get("language"): @@ -150,7 +150,7 @@ def _build_prompt( user_prompt = ( "No relevant context was found.\n\n" f"Question: {question}\n\n" - "Answer based on general knowledge and mention no documents were found." + "Answer based on general knowledge." ) else: context_parts = [] @@ -163,7 +163,7 @@ def _build_prompt( user_prompt = ( f"Context:\n{context_block}\n\n" f"Question: {question}\n\n" - "Answer based on the context above. Cite sources by number." + "Answer based on the context above. Don't speak about the sources, just use them to answer the question." 
) return system_prompt, user_prompt From 52b2cc567f51ce3861bcfa8f6256c724a3f28a9f Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Mon, 11 May 2026 15:43:54 +0100 Subject: [PATCH 4/9] clean(id): clean code --- src/modules/rag/ingestion.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index d517c67..5c458d1 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -36,7 +36,6 @@ def main(): client = QdrantClient(url=args.qdrant_url) model = SentenceTransformer("BAAI/bge-large-en-v1.5") - # Create collection if it doesn't exist collections = [c.name for c in client.get_collections().collections] if args.collection not in collections: client.create_collection( @@ -45,7 +44,6 @@ def main(): ) print(f"Created collection: {args.collection}") - # Sample documents docs = [ {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"}, {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"}, @@ -53,7 +51,6 @@ def main(): {"text": "The main office is located in Paris, France.", "source": "info.pdf"}, ] - # Embed and insert with user_id points = [] for doc in docs: vector = model.encode(doc["text"], normalize_embeddings=True).tolist() @@ -63,7 +60,7 @@ def main(): payload={ "text": doc["text"], "source": doc["source"], - "user_id": user_id, # ← scoped to this user + "user_id": user_id, }, )) From 209969e682558e85482b8c19ce0eb274ada48f12 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Tue, 12 May 2026 17:50:00 +0100 Subject: [PATCH 5/9] feat(ingestion): ingestion done with the possibility of semantic and word base ingestion. --- src/modules/rag/ingestion.py | 406 +++++++++++++++++++++++++--- src/modules/rag/semantic_chunker.py | 283 +++++++++++++++++++ 2 files changed, 657 insertions(+), 32 deletions(-) create mode 100644 src/modules/rag/semantic_chunker.py diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index 5c458d1..c0eef46 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -1,72 +1,414 @@ -# ingestion.py +import re import argparse import os +import sys import uuid +from pathlib import Path +from datetime import datetime +from pypdf import PdfReader from qdrant_client import QdrantClient -from qdrant_client.models import VectorParams, Distance, PointStruct +from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchValue from sentence_transformers import SentenceTransformer +from semantic_chunker import SemanticChunker + USER_ID_FILE = os.path.expanduser("~/.huri_user_id") +def _split_sentences(text: str) -> list[str]: + """Simple sentence splitter.""" + sentences = re.split(r'(?<=[.!?])\s+', text) + + result = [] + for s in sentences: + parts = s.split("\n\n") + result.extend(parts) + return [s.strip() for s in result if s.strip()] + + +def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]: + """ + Fallback: fixed-size chunking by sentences. + Used when --chunking=fixed. 
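+
+    Rough example (illustrative): with the defaults (chunk_size=500, overlap=50),
+    each chunk targets ~500 words and re-uses the trailing sentences of the
+    previous chunk (about 50 words' worth) as overlap.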
+ """ + sentences = _split_sentences(text) + chunks = [] + current_chunk = [] + current_length = 0 + + for sentence in sentences: + sentence = sentence.strip() + if not sentence: + continue + + sentence_length = len(sentence.split()) + + if current_length + sentence_length > chunk_size and current_chunk: + chunks.append(" ".join(current_chunk)) + + overlap_words = 0 + overlap_sentences = [] + for s in reversed(current_chunk): + overlap_words += len(s.split()) + overlap_sentences.insert(0, s) + if overlap_words >= overlap: + break + + current_chunk = overlap_sentences + current_length = overlap_words + + current_chunk.append(sentence) + current_length += sentence_length + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + return chunks + + +def extract_text_from_pdf(pdf_path: str) -> str: + """Extract text from a PDF file.""" + try: + reader = PdfReader(pdf_path) + text = "" + for page in reader.pages: + text += page.extract_text() + "\n" + return text.strip() + except ImportError: + pass + + print("ERROR: Install a PDF library: pip install pymupdf OR pip install pypdf") + sys.exit(1) + + +# --- User ID --- + def get_user_id(provided_id: str = None) -> str: - """Use provided ID, or load from file, or generate new one.""" if provided_id: return provided_id if os.path.exists(USER_ID_FILE): with open(USER_ID_FILE) as f: - return f.read().strip() + uid = f.read().strip() + if uid: + return uid new_id = str(uuid.uuid4()) with open(USER_ID_FILE, "w") as f: f.write(new_id) + print(f"Generated new user_id: {new_id}") return new_id -def main(): - parser = argparse.ArgumentParser(description="Ingest documents into Qdrant") - parser.add_argument("--user-id", type=str, default=None, help="User ID (reads from ~/.huri_user_id if not provided)") - parser.add_argument("--collection", type=str, default="documents") - parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333") - args = parser.parse_args() - - user_id = get_user_id(args.user_id) - print(f"Ingesting for user_id: {user_id}") - - client = QdrantClient(url=args.qdrant_url) - model = SentenceTransformer("BAAI/bge-large-en-v1.5") +# --- Qdrant helpers --- +def ensure_collection(client: QdrantClient, collection: str, vector_size: int): collections = [c.name for c in client.get_collections().collections] - if args.collection not in collections: + if collection not in collections: client.create_collection( - collection_name=args.collection, - vectors_config=VectorParams(size=1024, distance=Distance.COSINE), + collection_name=collection, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), ) - print(f"Created collection: {args.collection}") + print(f"Created collection: {collection}") - docs = [ - {"text": "The company budget for 2026 is 2 million euros.", "source": "budget.pdf"}, - {"text": "The project deadline is June 15th 2026.", "source": "planning.pdf"}, - {"text": "The team consists of 5 developers and 2 designers.", "source": "team.pdf"}, - {"text": "The main office is located in Paris, France.", "source": "info.pdf"}, - ] +def ingest_chunks( + client: QdrantClient, + model: SentenceTransformer, + collection: str, + chunks: list[str], + user_id: str, + source: str, + doc_type: str = "document", +): + """Embed chunks and upsert into Qdrant.""" points = [] - for doc in docs: - vector = model.encode(doc["text"], normalize_embeddings=True).tolist() + timestamp = datetime.now().isoformat() + + for i, chunk in enumerate(chunks): + vector = model.encode(chunk, normalize_embeddings=True).tolist() 
points.append(PointStruct( id=str(uuid.uuid4()), vector=vector, payload={ - "text": doc["text"], - "source": doc["source"], + "text": chunk, "user_id": user_id, + "source": source, + "type": doc_type, + "chunk_index": i, + "timestamp": timestamp, }, )) - client.upsert(collection_name=args.collection, points=points) - print(f"Ingested {len(points)} documents for user {user_id}") + if points: + # Upsert in batches of 100 + batch_size = 100 + for i in range(0, len(points), batch_size): + batch = points[i:i + batch_size] + client.upsert(collection_name=collection, points=batch) + + return len(points) + + +def chunk_strat(text: str, args, model: SentenceTransformer) -> list[str]: + """Pick the right chunking strategy based on args.""" + if args.chunking == "semantic": + chunker = SemanticChunker( + model=model, + strategy=args.semantic_strategy, + ) + return chunker.chunk(text) + else: + return chunk_text(text, chunk_size=args.chunk_size, overlap=args.overlap) + + +def cmd_pdf(args, client, model, user_id): + """Ingest PDF files.""" + files = [] + for path in args.files: + p = Path(path) + if p.is_dir(): + files.extend(p.glob("**/*.pdf")) + elif p.suffix.lower() == ".pdf": + files.append(p) + else: + print(f"Skipping non-PDF: {path}") + + if not files: + print("No PDF files found.") + return + + sample = model.encode("test", normalize_embeddings=True) + ensure_collection(client, args.collection, len(sample)) + + total = 0 + for pdf_path in files: + print(f"\nProcessing: {pdf_path}") + text = extract_text_from_pdf(str(pdf_path)) + + if not text.strip(): + print(f" WARNING: No text extracted from {pdf_path}") + continue + + chunks = chunk_strat(text, args, model) + count = ingest_chunks( + client, model, args.collection, chunks, + user_id, source=pdf_path.name, doc_type="pdf", + ) + print(f" → {count} chunks ingested") + total += count + + print(f"\nDone. Total: {total} chunks from {len(files)} PDF(s)") + + +def cmd_text(args, client, model, user_id): + """Ingest text files.""" + sample = model.encode("test", normalize_embeddings=True) + ensure_collection(client, args.collection, len(sample)) + + total = 0 + for file_path in args.files: + p = Path(file_path) + if not p.exists(): + print(f"File not found: {file_path}") + continue + + print(f"\nProcessing: {file_path}") + text = p.read_text(encoding="utf-8") + + if not text.strip(): + print(f" WARNING: File is empty: {file_path}") + continue + + chunks = chunk_strat(text, args, model) + count = ingest_chunks( + client, model, args.collection, chunks, + user_id, source=p.name, doc_type="text", + ) + print(f" -> {count} chunks ingested") + total += count + + print(f"\nDone. 
Total: {total} chunks from {len(args.files)} file(s)") + + +def cmd_write(args, client, model, user_id): + """Write text interactively and ingest it.""" + title = args.title or f"note_{datetime.now().strftime('%Y%m%d_%H%M%S')}" + + print(f"Write your text below (title: '{title}')") + print("Press Ctrl+D (Linux/Mac) or Ctrl+Z then Enter (Windows) when done.") + print("-" * 40) + + lines = [] + try: + while True: + line = input() + lines.append(line) + except EOFError: + pass + + text = "\n".join(lines).strip() + + if not text: + print("Nothing to ingest.") + return + + print(f"\n{'-' * 40}") + print(f"Received {len(text)} characters") + + sample = model.encode("test", normalize_embeddings=True) + ensure_collection(client, args.collection, len(sample)) + + chunks = chunk_strat(text, args, model) + count = ingest_chunks( + client, model, args.collection, chunks, + user_id, source=title, doc_type="manual", + ) + + print(f"Done. Ingested {count} chunks as '{title}'") + + +def cmd_list(args, client, model, user_id): + """List what's in the database for this user.""" + + try: + info = client.get_collection(args.collection) + print(f"Collection: {args.collection}") + print(f"Total points: {info.points_count}") + except Exception: + print(f"Collection '{args.collection}' doesn't exist.") + return + + results = client.scroll( + collection_name=args.collection, + scroll_filter=Filter(must=[ + FieldCondition(key="user_id", match=MatchValue(value=user_id)), + ]), + limit=100, + with_payload=True, + with_vectors=False, + ) + + points = results[0] + if not points: + print(f"No documents found for user {user_id}") + return + + sources = {} + for p in points: + source = p.payload.get("source", "unknown") + doc_type = p.payload.get("type", "unknown") + if source not in sources: + sources[source] = {"count": 0, "type": doc_type} + sources[source]["count"] += 1 + + print(f"\nDocuments for user {user_id}:") + print(f"{'Source':<40} {'Type':<10} {'Chunks':<8}") + print("-" * 60) + for source, info in sorted(sources.items()): + print(f"{source:<40} {info['type']:<10} {info['count']:<8}") + print(f"\nTotal: {len(points)} chunks across {len(sources)} sources") + + +def cmd_delete(args, client, model, user_id): + """Delete documents by source name.""" + + if not args.source: + print("Specify --source to delete. 
Use 'list' command to see sources.") + return + + filter_conditions = [ + FieldCondition(key="user_id", match=MatchValue(value=user_id)), + FieldCondition(key="source", match=MatchValue(value=args.source)), + ] + + client.delete( + collection_name=args.collection, + points_selector=Filter(must=filter_conditions), + ) + print(f"Deleted all chunks from source '{args.source}' for user {user_id}") + + + +def main(): + parser = argparse.ArgumentParser(description="HuRI RAG Ingestion Tool") + parser.add_argument("--user-id", type=str, default=None) + parser.add_argument("--collection", type=str, default="documents") + parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333") + parser.add_argument("--embedding-model", type=str, default="BAAI/bge-large-en-v1.5") + parser.add_argument("--chunk-size", type=int, default=500, help="Target chunk size in words (fixed mode)") + parser.add_argument("--overlap", type=int, default=50, help="Overlap between chunks in words (fixed mode)") + parser.add_argument("--chunking", type=str, default="fixed", + choices=["semantic", "fixed"], + help="Chunking strategy: 'semantic' (default) or 'fixed'") + parser.add_argument("--semantic-strategy", type=str, default="percentile", + choices=["percentile", "threshold", "stddev"], + help="Semantic chunking strategy (default: percentile)") + + subparsers = parser.add_subparsers(dest="command", required=True) + + # pdf + p_pdf = subparsers.add_parser("pdf", help="Ingest PDF files") + p_pdf.add_argument("files", nargs="+", help="PDF files or directories") + + # text + p_text = subparsers.add_parser("text", help="Ingest text files (.txt, .md)") + p_text.add_argument("files", nargs="+", help="Text files") + + # write + p_write = subparsers.add_parser("write", help="Write text interactively") + p_write.add_argument("--title", type=str, default=None, help="Title/source name") + + # list + p_list = subparsers.add_parser("list", help="List ingested documents") + + # delete + p_delete = subparsers.add_parser("delete", help="Delete documents by source") + p_delete.add_argument("--source", type=str, required=True, help="Source name to delete") + + args = parser.parse_args() + + # Init + user_id = get_user_id(args.user_id) + print(f"User: {user_id}") + + client = QdrantClient(url=args.qdrant_url) + model = SentenceTransformer(args.embedding_model) + + # Dispatch + commands = { + "pdf": cmd_pdf, + "text": cmd_text, + "write": cmd_write, + "list": cmd_list, + "delete": cmd_delete, + } + commands[args.command](args, client, model, user_id) if __name__ == "__main__": + """ + Ingestion tool for HuRI RAG. + + Usage: + # Ingest a PDF + python ingestion.py pdf report.pdf + + # Ingest multiple PDFs + python ingestion.py pdf doc1.pdf doc2.pdf doc3.pdf + + # Ingest a whole folder of PDFs + python ingestion.py pdf ./my_documents/ + + # Write text interactively (type, then Ctrl+D to save) + python ingestion.py write --title "My meeting notes" + + # Ingest a text file + python ingestion.py text notes.txt story.md + + # Specify a user ID (otherwise reads from ~/.huri_user_id) + python ingestion.py pdf report.pdf --user-id "abc-123" + + # Use a different collection + python ingestion.py pdf report.pdf --collection "my_docs" + """ main() \ No newline at end of file diff --git a/src/modules/rag/semantic_chunker.py b/src/modules/rag/semantic_chunker.py new file mode 100644 index 0000000..d5d81cd --- /dev/null +++ b/src/modules/rag/semantic_chunker.py @@ -0,0 +1,283 @@ +""" +Semantic Chunking for RAG. + +Three strategies: + 1. 
percentile - cut where similarity is below the Nth percentile (default) + 2. threshold - cut where similarity drops below a fixed value + 3. stddev - cut where similarity is more than N std devs below the mean + +Usage: + from semantic_chunker import SemanticChunker + + chunker = SemanticChunker(embedding_model) + chunks = chunker.chunk(text) +""" + +import re +import numpy as np +from dataclasses import dataclass, field +from sentence_transformers import SentenceTransformer + + +@dataclass +class Chunk: + text: str + sentences: list[str] = field(default_factory=list) + start_idx: int = 0 + end_idx: int = 0 + + +class SemanticChunker: + def __init__( + self, + model: SentenceTransformer, + strategy: str = "percentile", # "percentile", "threshold", "stddev" + percentile_cutoff: float = 25, # for percentile strategy + threshold_cutoff: float = 0.5, # for threshold strategy + stddev_cutoff: float = 1.0, # for stddev strategy (N std devs below mean) + min_chunk_size: int = 2, # minimum sentences per chunk + max_chunk_size: int = 50, # maximum sentences per chunk + buffer_size: int = 1, # sentences to look around for context + ): + self.model = model + self.strategy = strategy + self.percentile_cutoff = percentile_cutoff + self.threshold_cutoff = threshold_cutoff + self.stddev_cutoff = stddev_cutoff + self.min_chunk_size = min_chunk_size + self.max_chunk_size = max_chunk_size + self.buffer_size = buffer_size + + def chunk(self, text: str) -> list[str]: + """Main entry point. Returns list of chunk texts.""" + sentences = self._split_sentences(text) + + if len(sentences) <= self.min_chunk_size: + return [text.strip()] if text.strip() else [] + + # 1. Combine sentences with buffer for better embeddings + combined = self._combine_with_buffer(sentences) + + # 2. Embed all combined sentences + embeddings = self.model.encode(combined, normalize_embeddings=True) + + # 3. Calculate similarity between consecutive sentences + similarities = self._calculate_similarities(embeddings) + + # 4. Find breakpoints based on strategy + breakpoints = self._find_breakpoints(similarities) + + # 5. 
Group sentences into chunks + chunks = self._create_chunks(sentences, breakpoints) + + return chunks + + def chunk_detailed(self, text: str) -> list[Chunk]: + """Returns detailed Chunk objects with metadata.""" + sentences = self._split_sentences(text) + + if len(sentences) <= self.min_chunk_size: + return [Chunk(text=text.strip(), sentences=sentences, start_idx=0, end_idx=len(sentences))] + + combined = self._combine_with_buffer(sentences) + embeddings = self.model.encode(combined, normalize_embeddings=True) + similarities = self._calculate_similarities(embeddings) + breakpoints = self._find_breakpoints(similarities) + + chunks = [] + start = 0 + for bp in breakpoints: + end = bp + 1 + chunk_sentences = sentences[start:end] + chunks.append(Chunk( + text=" ".join(chunk_sentences), + sentences=chunk_sentences, + start_idx=start, + end_idx=end, + )) + start = end + + # Last chunk + if start < len(sentences): + chunk_sentences = sentences[start:] + chunks.append(Chunk( + text=" ".join(chunk_sentences), + sentences=chunk_sentences, + start_idx=start, + end_idx=len(sentences), + )) + + return chunks + + # --- Internal methods --- + + def _split_sentences(self, text: str) -> list[str]: + """Split text into sentences, respecting paragraph boundaries.""" + paragraphs = text.split("\n\n") + sentences = [] + for para in paragraphs: + para = para.strip() + if not para: + continue + parts = re.split(r'(?<=[.!?])\s+', para) + for part in parts: + part = part.strip() + if part: + sentences.append(part) + return sentences + + def _combine_with_buffer(self, sentences: list[str]) -> list[str]: + """ + Combine each sentence with its neighbors for richer embeddings. + Sentence at index i gets combined with sentences [i-buffer, i+buffer]. + This gives the embedding model more context to understand each sentence. 
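+        Example: with buffer_size=1, sentence i is embedded together with
+        sentences i-1 and i+1 (clamped at the start/end of the text).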
+ """ + combined = [] + for i in range(len(sentences)): + start = max(0, i - self.buffer_size) + end = min(len(sentences), i + self.buffer_size + 1) + window = " ".join(sentences[start:end]) + combined.append(window) + return combined + + def _calculate_similarities(self, embeddings: np.ndarray) -> list[float]: + """Calculate cosine similarity between consecutive sentence embeddings.""" + similarities = [] + for i in range(len(embeddings) - 1): + sim = np.dot(embeddings[i], embeddings[i + 1]) + similarities.append(float(sim)) + return similarities + + def _find_breakpoints(self, similarities: list[float]) -> list[int]: + """Find where to split based on the chosen strategy.""" + if not similarities: + return [] + + sims = np.array(similarities) + + if self.strategy == "percentile": + cutoff = np.percentile(sims, self.percentile_cutoff) + candidate_indices = [i for i, s in enumerate(similarities) if s < cutoff] + + elif self.strategy == "threshold": + candidate_indices = [i for i, s in enumerate(similarities) if s < self.threshold_cutoff] + + elif self.strategy == "stddev": + mean = np.mean(sims) + std = np.std(sims) + cutoff = mean - (self.stddev_cutoff * std) + candidate_indices = [i for i, s in enumerate(similarities) if s < cutoff] + + else: + raise ValueError(f"Unknown strategy: {self.strategy}") + + breakpoints = self._enforce_chunk_sizes(candidate_indices, len(similarities) + 1) + + return breakpoints + + def _enforce_chunk_sizes(self, candidates: list[int], num_sentences: int) -> list[int]: + """Ensure chunks respect min and max size constraints.""" + if not candidates: + breakpoints = [] + pos = self.max_chunk_size - 1 + while pos < num_sentences - 1: + breakpoints.append(pos) + pos += self.max_chunk_size + return breakpoints + + breakpoints = [] + last_break = -1 + + for candidate in sorted(candidates): + chunk_size = candidate - last_break + + if chunk_size < self.min_chunk_size: + continue + + if chunk_size > self.max_chunk_size: + pos = last_break + self.max_chunk_size + while pos < candidate: + breakpoints.append(pos) + last_break = pos + pos += self.max_chunk_size + + breakpoints.append(candidate) + last_break = candidate + + remaining = num_sentences - 1 - last_break + if remaining > self.max_chunk_size: + pos = last_break + self.max_chunk_size + while pos < num_sentences - 1: + breakpoints.append(pos) + pos += self.max_chunk_size + + return breakpoints + + def _create_chunks(self, sentences: list[str], breakpoints: list[int]) -> list[str]: + """Group sentences into chunks based on breakpoints.""" + chunks = [] + start = 0 + + for bp in breakpoints: + end = bp + 1 + chunk_text = " ".join(sentences[start:end]).strip() + if chunk_text: + chunks.append(chunk_text) + start = end + + # Last chunk + if start < len(sentences): + chunk_text = " ".join(sentences[start:]).strip() + if chunk_text: + chunks.append(chunk_text) + + return chunks + + + +def create_chunker( + model: SentenceTransformer = None, + model_name: str = "BAAI/bge-large-en-v1.5", + strategy: str = "percentile", + **kwargs, +) -> SemanticChunker: + """Create a chunker with defaults.""" + if model is None: + model = SentenceTransformer(model_name) + return SemanticChunker(model=model, strategy=strategy, **kwargs) + + + +if __name__ == "__main__": + print("Loading model...") + model = SentenceTransformer("BAAI/bge-large-en-v1.5") + chunker = SemanticChunker(model, strategy="stddev") + + text = """ + The company budget for 2026 is set at 2 million euros. This represents a 10% increase + from the previous year. 
The finance department has approved the allocation after extensive review. + + The engineering team is growing rapidly. We hired 5 new developers last quarter. + The team now consists of 15 engineers and 3 designers. We plan to hire 2 more QA engineers + by the end of Q2. + + Our main office is relocating to Lyon in September. The new building has 3 floors + and modern facilities. The move will affect approximately 50 employees. We are organizing + transport for all office equipment. + + The product roadmap for Q3 includes a major redesign of the dashboard. User feedback + indicated that the current interface is too complex. We will conduct usability testing + in July before the final release. + """ + + print("\n--- Percentile strategy (default) ---") + chunks = chunker.chunk(text) + for i, chunk in enumerate(chunks, 1): + print(f"\nChunk {i} ({len(chunk.split())} words):") + print(f" {chunk[:150]}...") + + print("\n--- Detailed output ---") + detailed = chunker.chunk_detailed(text) + for i, chunk in enumerate(detailed, 1): + print(f"\nChunk {i}: sentences {chunk.start_idx}-{chunk.end_idx} ({len(chunk.sentences)} sentences)") + print(f" {chunk.text[:150]}...") \ No newline at end of file From b5f810fc531b3bb78709caefc9abd5dc5dd41124 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Tue, 12 May 2026 17:54:29 +0100 Subject: [PATCH 6/9] wip(todo): Add some todos to not forget the work I have to do --- src/modules/rag/ingestion.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index c0eef46..2b57469 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -396,7 +396,8 @@ def main(): # Ingest multiple PDFs python ingestion.py pdf doc1.pdf doc2.pdf doc3.pdf - # Ingest a whole folder of PDFs + # Ingest a whole folder of PDFs + # TODO: To verify and to add the support of hole paths python ingestion.py pdf ./my_documents/ # Write text interactively (type, then Ctrl+D to save) @@ -406,9 +407,13 @@ def main(): python ingestion.py text notes.txt story.md # Specify a user ID (otherwise reads from ~/.huri_user_id) - python ingestion.py pdf report.pdf --user-id "abc-123" + python ingestion.py --user-id "abc-123" pdf report.pdf # Use a different collection - python ingestion.py pdf report.pdf --collection "my_docs" + python ingestion.py --collection "my_docs" pdf report.pdf + + # Use a different ingestion strategy + python src/modules/rag/ingestion.py --chunking semantic --semantic-strategy threshold pdf "EN.pdf" + """ main() \ No newline at end of file From a0418a9e6f61b9f8bf734b79dc966da6aa949fff Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Fri, 15 May 2026 13:41:55 +0100 Subject: [PATCH 7/9] refacto(user_ids): user_id -> user_id --- src/modules/rag/ingestion.py | 38 ++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/src/modules/rag/ingestion.py b/src/modules/rag/ingestion.py index 2b57469..d0434a9 100644 --- a/src/modules/rag/ingestion.py +++ b/src/modules/rag/ingestion.py @@ -116,7 +116,7 @@ def ingest_chunks( model: SentenceTransformer, collection: str, chunks: list[str], - user_id: str, + _user_id: str, source: str, doc_type: str = "document", ): @@ -131,7 +131,7 @@ def ingest_chunks( vector=vector, payload={ "text": chunk, - "user_id": user_id, + "_user_id": _user_id, "source": source, "type": doc_type, "chunk_index": i, @@ -161,7 +161,7 @@ def chunk_strat(text: str, args, model: SentenceTransformer) -> list[str]: return 
chunk_text(text, chunk_size=args.chunk_size, overlap=args.overlap) -def cmd_pdf(args, client, model, user_id): +def cmd_pdf(args, client, model, _user_id): """Ingest PDF files.""" files = [] for path in args.files: @@ -192,15 +192,15 @@ def cmd_pdf(args, client, model, user_id): chunks = chunk_strat(text, args, model) count = ingest_chunks( client, model, args.collection, chunks, - user_id, source=pdf_path.name, doc_type="pdf", + _user_id, source=pdf_path.name, doc_type="pdf", ) - print(f" → {count} chunks ingested") + print(f" -> {count} chunks ingested") total += count print(f"\nDone. Total: {total} chunks from {len(files)} PDF(s)") -def cmd_text(args, client, model, user_id): +def cmd_text(args, client, model, _user_id): """Ingest text files.""" sample = model.encode("test", normalize_embeddings=True) ensure_collection(client, args.collection, len(sample)) @@ -222,7 +222,7 @@ def cmd_text(args, client, model, user_id): chunks = chunk_strat(text, args, model) count = ingest_chunks( client, model, args.collection, chunks, - user_id, source=p.name, doc_type="text", + _user_id, source=p.name, doc_type="text", ) print(f" -> {count} chunks ingested") total += count @@ -230,7 +230,7 @@ def cmd_text(args, client, model, user_id): print(f"\nDone. Total: {total} chunks from {len(args.files)} file(s)") -def cmd_write(args, client, model, user_id): +def cmd_write(args, client, model, _user_id): """Write text interactively and ingest it.""" title = args.title or f"note_{datetime.now().strftime('%Y%m%d_%H%M%S')}" @@ -261,13 +261,13 @@ def cmd_write(args, client, model, user_id): chunks = chunk_strat(text, args, model) count = ingest_chunks( client, model, args.collection, chunks, - user_id, source=title, doc_type="manual", + _user_id, source=title, doc_type="manual", ) print(f"Done. 
Ingested {count} chunks as '{title}'") -def cmd_list(args, client, model, user_id): +def cmd_list(args, client, model, _user_id): """List what's in the database for this user.""" try: @@ -281,7 +281,7 @@ def cmd_list(args, client, model, user_id): results = client.scroll( collection_name=args.collection, scroll_filter=Filter(must=[ - FieldCondition(key="user_id", match=MatchValue(value=user_id)), + FieldCondition(key="_user_id", match=MatchValue(value=_user_id)), ]), limit=100, with_payload=True, @@ -290,7 +290,7 @@ def cmd_list(args, client, model, user_id): points = results[0] if not points: - print(f"No documents found for user {user_id}") + print(f"No documents found for user {_user_id}") return sources = {} @@ -301,7 +301,7 @@ def cmd_list(args, client, model, user_id): sources[source] = {"count": 0, "type": doc_type} sources[source]["count"] += 1 - print(f"\nDocuments for user {user_id}:") + print(f"\nDocuments for user {_user_id}:") print(f"{'Source':<40} {'Type':<10} {'Chunks':<8}") print("-" * 60) for source, info in sorted(sources.items()): @@ -309,7 +309,7 @@ def cmd_list(args, client, model, user_id): print(f"\nTotal: {len(points)} chunks across {len(sources)} sources") -def cmd_delete(args, client, model, user_id): +def cmd_delete(args, client, model, _user_id): """Delete documents by source name.""" if not args.source: @@ -317,7 +317,7 @@ def cmd_delete(args, client, model, user_id): return filter_conditions = [ - FieldCondition(key="user_id", match=MatchValue(value=user_id)), + FieldCondition(key="_user_id", match=MatchValue(value=_user_id)), FieldCondition(key="source", match=MatchValue(value=args.source)), ] @@ -325,7 +325,7 @@ def cmd_delete(args, client, model, user_id): collection_name=args.collection, points_selector=Filter(must=filter_conditions), ) - print(f"Deleted all chunks from source '{args.source}' for user {user_id}") + print(f"Deleted all chunks from source '{args.source}' for user {_user_id}") @@ -368,8 +368,8 @@ def main(): args = parser.parse_args() # Init - user_id = get_user_id(args.user_id) - print(f"User: {user_id}") + _user_id = get_user_id(args._user_id) + print(f"User: {_user_id}") client = QdrantClient(url=args.qdrant_url) model = SentenceTransformer(args.embedding_model) @@ -382,7 +382,7 @@ def main(): "list": cmd_list, "delete": cmd_delete, } - commands[args.command](args, client, model, user_id) + commands[args.command](args, client, model, _user_id) if __name__ == "__main__": From 529297ae7a592f907d00ec3695206efd32cc21e3 Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Fri, 15 May 2026 13:45:35 +0100 Subject: [PATCH 8/9] wip(pr): add a module with ids and generate a rag class with module with id and module with handle. 
Make some refacto : user_id -> _user_id and handle -> _handle --- src/client.py | 18 ++++---- src/core/huri.py | 23 ++++------ src/core/module.py | 15 +++++-- src/modules/factory.py | 31 ++++++++++---- src/modules/rag/rag.py | 67 +++++++++++++----------------- src/modules/reasoning/embedding.py | 6 +-- 6 files changed, 83 insertions(+), 77 deletions(-) diff --git a/src/client.py b/src/client.py index 2240709..63f490a 100644 --- a/src/client.py +++ b/src/client.py @@ -23,9 +23,9 @@ def load_user_id() -> str | None: return None -def save_user_id(user_id: str): +def save_user_id(_user_id: str): with open(USER_ID_FILE, "w") as f: - f.write(user_id) + f.write(_user_id) def load_client_config(path: str) -> ClientConfig: with open(path) as f: @@ -54,18 +54,18 @@ async def stream_audio(): print("Connected to server") payload = asdict(config) - user_id = load_user_id() - if user_id: - payload["user_id"] = user_id - print(f"Reconnecting with user_id: {user_id}") + _user_id = load_user_id() + if _user_id: + payload["_user_id"] = _user_id + print(f"Reconnecting with _user_id: {_user_id}") await ws.send(json.dumps(payload)) init_msg = json.loads(await ws.recv()) if init_msg.get("type") == "session_init": - user_id = init_msg["user_id"] - save_user_id(user_id) - print(f"Session started with user_id: {user_id}") + _user_id = init_msg["_user_id"] + save_user_id(_user_id) + print(f"Session started with _user_id: {_user_id}") async def receive(ws: websockets.ClientConnection): while True: diff --git a/src/core/huri.py b/src/core/huri.py index a07c4fe..4c4207a 100644 --- a/src/core/huri.py +++ b/src/core/huri.py @@ -7,7 +7,6 @@ from src.modules.factory import Module, ModuleFactory from src.modules.utils.sender import Sender - from .app import app from .dataclasses.config import ClientConfig from .session import Session @@ -24,7 +23,6 @@ def __init__( self.factory = ModuleFactory(handles) for name, module_cls in modules.items(): self.factory.register(name, module_cls) - self.clients: Dict[str, Session] = {} @app.websocket("/session") @@ -33,24 +31,20 @@ async def run_session(self, ws: WebSocket): client_config_raw: Dict = await ws.receive_json() client_config = ClientConfig.from_dict(client_config_raw) - user_id = client_config_raw.get("user_id") or str(uuid.uuid4()) - await ws.send_json({"type": "session_init", "user_id": user_id}) - - if "rag" in client_config.modules: - if client_config.modules["rag"].args is None: - client_config.modules["rag"].args = {} - client_config.modules["rag"].args["user_id"] = user_id + _user_id = client_config_raw.get("_user_id") or str(uuid.uuid4()) senders: List[Module] = [ Sender(ws, topic) for topic in client_config.topic_list ] modules: List[Module] = ( - self.factory.create_from_config(client_config.modules) + senders + self.factory.create_from_config(_user_id, client_config.modules) + senders ) + await ws.send_json({"type": "session_init", "_user_id": _user_id}) + session_id = str(uuid.uuid4()) self.clients[session_id] = Session(modules) - print("Client registered successfully with config:", client_config) + print(f"Client registered with _user_id={_user_id}, config: {client_config}") async def receive_loop(session: Session, ws: WebSocket): try: @@ -59,9 +53,8 @@ async def receive_loop(session: Session, ws: WebSocket): if "bytes" in msg: chunk = msg["bytes"] await session.publish("chunk", chunk) - # else: - # data = msg - # await session.publish(data["type"], data["data"]) except (WebSocketDisconnect, RuntimeError): - print(f"Client disconnected") + print(f"Client 
{_user_id} disconnected") + await receive_loop(self.clients[session_id], ws) + del self.clients[session_id] \ No newline at end of file diff --git a/src/core/module.py b/src/core/module.py index 12543cf..e57dba4 100644 --- a/src/core/module.py +++ b/src/core/module.py @@ -14,6 +14,15 @@ async def process(self, _) -> Optional[Any]: class ModuleWithHandle(Module): _handle_cls: Type[Any] - def __init__(self, handle: handle.DeploymentHandle): - super().__init__() - self.handle = handle + def __init__(self, _handle: handle.DeploymentHandle = None, **kwargs): + super().__init__(**kwargs) + self._handle = _handle + +class ModuleWithId(Module): + def __init__(self, _user_id: str, **kwargs): + super().__init__(**kwargs) + self._user_id = _user_id + + def get_user_context(self) -> dict: + """Override in subclasses to provide user-specific context.""" + return {"_user_id": self._user_id} diff --git a/src/modules/factory.py b/src/modules/factory.py index 39a1a83..d252ab1 100644 --- a/src/modules/factory.py +++ b/src/modules/factory.py @@ -1,7 +1,8 @@ +from os import name from typing import Any, Dict, List, Mapping, Type from src.core.dataclasses.config import ModuleConfig -from src.core.module import Module, ModuleWithHandle, handle +from src.core.module import Module, ModuleWithHandle, handle, ModuleWithId class ModuleFactory: @@ -9,6 +10,7 @@ def __init__(self, handles): self._registry: Dict[str, Type[Module]] = {} self._handles = handles + def register(self, name: str, module_cls: Type[Module]) -> None: if not issubclass(module_cls, Module): raise TypeError(f"{module_cls} must inherit from Module") @@ -19,29 +21,40 @@ def register(self, name: str, module_cls: Type[Module]) -> None: ) self._registry[name] = module_cls - def create(self, name: str, args: Mapping[str, Any] | None = None) -> Module: + + def create( + self, + _user_id: str, + name: str, + args: Mapping[str, Any] | None = None + ) -> Module: + if name not in self._registry: raise ValueError(f"Unknown module '{name}'") + module_cls = self._registry[name] - if args is None: - args = {} + kwargs = dict(args or {}) + if issubclass(module_cls, ModuleWithHandle): if name not in self._handles: raise RuntimeError( f"Handles not bound for '{name}'. Check your config first." 
) - return module_cls(handle=self._handles[name], **args) - return module_cls(**args) + kwargs["_handle"] = self._handles[name] + + if issubclass(module_cls, ModuleWithId): + kwargs["_user_id"] = _user_id + + return module_cls(**kwargs) def create_from_config( - self, module_configs: Dict[str, ModuleConfig] + self, _user_id: str, module_configs: Dict[str, ModuleConfig] ) -> List[Module]: modules: List[Module] = [] for _, module_config in module_configs.items(): - modules.append(self.create(module_config.name, module_config.args)) - + modules.append(self.create(_user_id, module_config.name, module_config.args)) if modules == []: raise Exception diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index e37db92..ac594d8 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -1,9 +1,8 @@ from typing import Any, Optional from dataclasses import dataclass, field -from ray import data, serve -from ray.serve import handle -from src.core.module import ModuleWithHandle +from ray import serve +from src.core.module import ModuleWithHandle, ModuleWithId, handle from qdrant_client.models import Filter, FieldCondition, MatchValue from sentence_transformers import SentenceTransformer from qdrant_client import QdrantClient @@ -15,7 +14,7 @@ @dataclass class RAGQuery: """What flows from RAG module to RAGHandle.""" - user_id: str + _user_id: str question: str preferences: dict = field(default_factory=dict) # preferences can include: language, tone, response_format, max_length, system_prompt, extra_instructions, etc. @@ -35,7 +34,7 @@ class RAGResult: class RAGHandle: """ Stateless RAG processor. Knows nothing about sessions. - Receives a user_id + question, uses user_id to find the right + Receives a _user_id + question, uses _user_id to find the right collection/data in the vector DB, runs embed -> search -> LLM. """ @@ -62,24 +61,24 @@ def __init__( self.llm_model = llm_model self.llm_api_key = llm_api_key - def _resolve_user_context(self, user_id: str) -> tuple[str, dict | None]: + def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: """ - Given a user_id, decide which collection to search + Given a _user_id, decide which collection to search and which filters to apply. Options (pick what fits your data model): - A) One collection per user: collection = f"user_{user_id}" - B) Shared collection, filter by user_id in payload + A) One collection per user: collection = f"user_{_user_id}" + B) Shared collection, filter by _user_id in payload C) Lookup in a DB to find the user's config """ # Option A: separate collection per user - # collection = f"user_{user_id}" + # collection = f"user_{_user_id}" # filters = None - # Option B: shared collection with user_id filter (recommended) + # Option B: shared collection with _user_id filter (recommended) collection = self.default_collection - filters = {"user_id": user_id} + filters = {"_user_id": _user_id} return collection, filters @@ -96,7 +95,6 @@ def _search( filters: dict | None = None, ) -> list[dict]: - # Build qdrant filter from user context qdrant_filter = None if filters: conditions = [ @@ -227,11 +225,11 @@ async def _call_ollama(self, messages: list, max_tokens: int) -> str: async def process(self, query: RAGQuery) -> RAGResult: """ Main entry point. Called by the RAG module. - Uses user_id to determine which collection / filters to use. + Uses _user_id to determine which collection / filters to use. 
""" print(f"[RAG] Question: {query.question}") - collection, filters = self._resolve_user_context(query.user_id) + collection, filters = self._resolve_user_context(query._user_id) query_vector = self._embed(query.question) chunks = self._search(query_vector, collection, filters) @@ -256,31 +254,23 @@ async def process(self, query: RAGQuery) -> RAGResult: ) -class RAG(ModuleWithHandle): - """ - Session-bound module. HuRI instantiates this when a client connects, - passing the user_id from the WebSocket config. - - Listens to "question" events. - Forwards question + user_id to the detached RAGHandle. - Emits "rag_response" event with the answer. - """ +class RAG(ModuleWithHandle, ModuleWithId): _handle_cls = RAGHandle input_type = "question" output_type = "rag_response" def __init__( self, - handle: handle.DeploymentHandle[RAGHandle], - user_id: str = "", - language: str = "en", - tone: str = "formal", - response_format: str = "paragraph", - max_length: int = 1024, - extra_instructions: str = "", + _handle=None, + _user_id="", + language="en", + tone="formal", + response_format="paragraph", + max_length=1024, + extra_instructions="", + **kwargs, ): - super().__init__(handle) - self.user_id = user_id + super().__init__(_handle=_handle, _user_id=_user_id, **kwargs) self.preferences = { "language": language, "tone": tone, @@ -288,23 +278,24 @@ def __init__( "max_length": max_length, "extra_instructions": extra_instructions, } - + async def process(self, data) -> Optional[Any]: """ Called when a "question" event arrives through the event bus. - Packages user_id + question, sends to the stateless RAGHandle. + Packages _user_id + question, sends to the stateless RAGHandle. """ question_text = data.text if hasattr(data, 'text') else str(data) query = RAGQuery( - user_id=self.user_id if self.user_id else "anonymous", + _user_id=self._user_id if self._user_id else "anonymous", question=question_text, preferences=self.preferences, ) - result: RAGResult = await self.handle.process.remote(query) + result: RAGResult = await self._handle.process.remote(query) return result - + + def update_preferences(self, new_preferences: dict): """Client can update preferences mid-session via the event bus.""" self.preferences.update(new_preferences) diff --git a/src/modules/reasoning/embedding.py b/src/modules/reasoning/embedding.py index ff14897..b6139d9 100644 --- a/src/modules/reasoning/embedding.py +++ b/src/modules/reasoning/embedding.py @@ -29,13 +29,13 @@ class EMB(ModuleWithHandle): input_type = "toembed" output_type = "embedded" - def __init__(self, handle: handle.DeploymentHandle[EMBHandle]): - super().__init__(handle) + def __init__(self, _handle: handle.DeploymentHandle[EMBHandle]): + super().__init__(_handle) self.database = "" async def process(self, data_to_embed: np.ndarray) -> Optional[Any]: - embedded = await self.handle.embbed.remote(data_to_embed) + embedded = await self._handle.embbed.remote(data_to_embed) # TODO write embedding return embedded From 04073a0909d43d27f1e107bf0db154417ec5326b Mon Sep 17 00:00:00 2001 From: Kaiserbuffle Date: Fri, 15 May 2026 17:35:50 +0100 Subject: [PATCH 9/9] wip(docker): add a docker that launch with one commande. Only work with Ollama for other provider we need to change the code. 
--- src/app.py | 12 +- src/modules/factory.py | 13 +- src/modules/rag/docker_services.py | 323 +++++++++++++++++++++++++++++ src/modules/rag/rag.py | 55 +++-- 4 files changed, 382 insertions(+), 21 deletions(-) create mode 100644 src/modules/rag/docker_services.py diff --git a/src/app.py b/src/app.py index 1d19fa6..35ed16a 100644 --- a/src/app.py +++ b/src/app.py @@ -3,13 +3,21 @@ from src.core.huri import HuRI from src.modules.factory import bind_deployment_handles from src.modules.modules import get_modules +from src.modules.rag.docker_services import OllamaService, QdrantService def build_app() -> Application: modules = get_modules() - handles = bind_deployment_handles(modules) - app: Application = HuRI.bind(modules, handles) # type: ignore[attr-defined] + qdrant = QdrantService.bind(port=6333) + ollama = OllamaService.options(num_replicas=1).bind( + model="mistral:7b", + image="ollama/ollama:rocm", + gpu_devices=True, + ) + + handles = bind_deployment_handles(modules, ollama=ollama, qdrant=qdrant) + app: Application = HuRI.bind(modules, handles) return app diff --git a/src/modules/factory.py b/src/modules/factory.py index d252ab1..56b4d7f 100644 --- a/src/modules/factory.py +++ b/src/modules/factory.py @@ -63,15 +63,24 @@ def create_from_config( def bind_deployment_handles( modules: Dict[str, Type[Module]], + **service_handles, ) -> Dict[str, handle.DeploymentHandle]: handles: Dict[str, handle.DeploymentHandle] = {} for name, module_cls in modules.items(): if not issubclass(module_cls, ModuleWithHandle): continue - + if not hasattr(module_cls, "_handle_cls"): raise TypeError(f"{module_cls.__name__} must define _handle_cls") + handle_cls = module_cls._handle_cls - handles[name] = handle_cls.bind() + + if name == "rag" and service_handles: + handles[name] = handle_cls.bind( + ollama_handle=service_handles.get("ollama"), + qdrant_handle=service_handles.get("qdrant"), + ) + else: + handles[name] = handle_cls.bind() return handles diff --git a/src/modules/rag/docker_services.py b/src/modules/rag/docker_services.py new file mode 100644 index 0000000..6fb3de9 --- /dev/null +++ b/src/modules/rag/docker_services.py @@ -0,0 +1,323 @@ +""" +Docker Services for HuRI — Ray-managed Docker containers. + +WHAT THIS DOES: + Each service is a Ray Serve deployment that: + 1. Starts a Docker container when the deployment initializes + 2. Exposes methods to interact with that container + 3. Cleans up the container when the deployment is destroyed + +WHY: + Instead of manually running `docker run ...` before starting HuRI, + Ray does it for you. And if you need more instances (e.g. more LLMs), + you just increase num_replicas — Ray starts more containers. + +HOW IT WORKS: + OllamaService: + - Starts an Ollama Docker container on a random free port + - Pulls the requested model + - Exposes generate() to send prompts and get answers + - Scales horizontally: 2 replicas = 2 containers = 2x throughput + + QdrantService: + - Starts a Qdrant Docker container (or reuses existing one) + - Exposes get_url() so other services know where to connect + - Always 1 replica: it's a database, data must be in one place +""" + +import time +import socket +import subprocess + +import httpx +from ray import serve + + +def find_free_port() -> int: + """ + Ask the OS for a random free port. + We need this because if we run multiple Ollama containers, + they can't all use port 11434 — each needs its own. 
+ """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def wait_for_service(url: str, timeout: int = 120) -> bool: + """ + Returns True if ready, False if timeout. + """ + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(url, timeout=5) + if resp.status_code == 200: + return True + except Exception: + pass + time.sleep(2) + return False + + +def is_container_running(name: str) -> bool: + """Check if a Docker container with this name is already running.""" + result = subprocess.run( + ["docker", "ps", "-q", "-f", f"name=^{name}$"], + capture_output=True, text=True, + ) + return bool(result.stdout.strip()) + + +def remove_container(name: str): + """Force remove a container by name (ignores errors if it doesn't exist).""" + subprocess.run(["docker", "rm", "-f", name], capture_output=True) + + +@serve.deployment +class OllamaService: + """ + Manages one Ollama Docker container. + + LIFECYCLE: + __init__: starts container -> waits for it -> pulls model + generate: sends a prompt to the container, returns the answer + __del__: stops and removes the container + """ + + def __init__( + self, + model: str = "mistral:7b", + image: str = "ollama/ollama:latest", + gpu_devices: bool = False, + ): + self.model = model + self.port = find_free_port() + self.container_name = f"ollama-ray-{self.port}" + self.base_url = f"http://localhost:{self.port}" + + remove_container(self.container_name) + + cmd = [ + "docker", "run", "-d", + "--name", self.container_name, + "-p", f"{self.port}:11434", + "-v", "ollama_shared:/root/.ollama", + ] + + if gpu_devices: + cmd.extend([ + "--device=/dev/kfd", + "--device=/dev/dri", + "--group-add=video", + ]) + + cmd.append(image) + + print(f"[OllamaService] Starting container '{self.container_name}' on port {self.port}...") + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Docker failed: {result.stderr}") + + print(f"[OllamaService] Waiting for Ollama to be ready...") + if not wait_for_service(f"{self.base_url}/api/tags"): + raise RuntimeError(f"Ollama didn't start within timeout on port {self.port}") + + print(f"[OllamaService] Pulling model '{model}'...") + pull_result = subprocess.run( + ["docker", "exec", self.container_name, "ollama", "pull", model], + capture_output=True, text=True, + ) + if pull_result.returncode != 0: + raise RuntimeError(f"Failed to pull model: {pull_result.stderr}") + + print(f"[OllamaService] Ready! container='{self.container_name}', port={self.port}, model='{model}'") + + + async def generate( + self, + messages: list, + max_tokens: int = 1024, + temperature: float = 0.1, + ) -> str: + """ + Send messages to Ollama and return the response. + This is what RAGHandle calls to get LLM answers. 
+ """ + async with httpx.AsyncClient(timeout=60.0) as client: + resp = await client.post( + f"{self.base_url}/api/chat", + json={ + "model": self.model, + "messages": messages, + "stream": False, + "options": { + "num_predict": max_tokens, + "temperature": temperature, + }, + }, + ) + resp.raise_for_status() + return resp.json()["message"]["content"] + + async def health(self) -> dict: + """Check if this Ollama instance is alive.""" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get(f"{self.base_url}/api/tags") + return {"status": "ok", "port": self.port, "container": self.container_name} + except Exception as e: + return {"status": "error", "error": str(e)} + + def __del__(self): + """Cleanup when Ray destroys this replica.""" + print(f"[OllamaService] Removing container '{self.container_name}'") + remove_container(self.container_name) + + +@serve.deployment(num_replicas=1) +class QdrantService: + """ + Manages a Qdrant Docker container. + + LIFECYCLE: + __init__: starts container (or reuses if already running) + get_url: returns the URL other services should connect to + __del__: leaves the container running (it has data!) + """ + + def __init__( + self, + port: int = 6333, + image: str = "qdrant/qdrant:latest", + storage_volume: str = "qdrant_data", + ): + self.port = port + self.container_name = "qdrant-ray" + self.url = f"http://localhost:{self.port}" + + if self._is_healthy(): + print(f"[QdrantService] Qdrant already running on port {self.port}") + return + + remove_container(self.container_name) + + cmd = [ + "docker", "run", "-d", + "--name", self.container_name, + "-p", f"{self.port}:6333", + "-v", f"{storage_volume}:/qdrant/storage", + image, + ] + + print(f"[QdrantService] Starting Qdrant on port {self.port}...") + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"Docker failed: {result.stderr}") + + if not wait_for_service(f"{self.url}/healthz"): + raise RuntimeError(f"Qdrant didn't start within timeout on port {self.port}") + + print(f"[QdrantService] Ready on port {self.port}") + + + def _is_healthy(self) -> bool: + try: + resp = httpx.get(f"{self.url}/healthz", timeout=3) + return resp.status_code == 200 + except Exception: + return False + + + async def get_url(self) -> str: + """Return the URL. Called by RAGHandle to know where Qdrant is.""" + return self.url + + + async def health(self) -> dict: + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get(f"{self.url}/healthz") + return {"status": "ok", "port": self.port, "url": self.url} + except Exception as e: + return {"status": "error", "error": str(e)} + + + def __del__(self): + print(f"[QdrantService] Actor destroyed. Container '{self.container_name}' left running.") + + +if __name__ == "__main__": + """ + Test the services independently, without HuRI. + + Run: + python docker_services.py + + What it does: + 1. Starts Ray + 2. Deploys QdrantService -> starts Qdrant container + 3. Deploys OllamaService -> starts Ollama container + pulls model + 4. Sends a test prompt to Ollama + 5. Prints results + 6. 
Ctrl+C to stop and cleanup + """ + import ray + + print("=" * 60) + print("Docker Services — Standalone Test") + print("=" * 60) + + # Step 1: Start Ray + print("\n[1/5] Starting Ray...") + ray.init() + serve.start() + + # Step 2: Deploy Qdrant + print("\n[2/5] Deploying QdrantService...") + qdrant_app = serve.run( + QdrantService.bind(port=6333), + name="qdrant-test", + route_prefix="/qdrant-test", + ) + qdrant_health = qdrant_app.health.remote().result() + print(f" Qdrant health: {qdrant_health}") + qdrant_url = qdrant_app.get_url.remote().result() + print(f" Qdrant URL: {qdrant_url}") + + # Step 3: Deploy Ollama + print("\n[3/5] Deploying OllamaService (this may take a minute)...") + ollama_app = serve.run( + OllamaService.bind( + model="mistral:7b", + image="ollama/ollama:rocm", # change to ollama/ollama:latest if no AMD GPU + gpu_devices=True, # set False if no AMD GPU + ), + name="ollama-test", + route_prefix="/ollama-test", + ) + ollama_health = ollama_app.health.remote().result() + print(f" Ollama health: {ollama_health}") + + # Step 4: Test generation + print("\n[4/5] Sending test prompt to Ollama...") + answer = ollama_app.generate.remote( + messages=[{"role": "user", "content": "Say hello in exactly 5 words."}], + max_tokens=50, + ).result() + print(f" Ollama response: {answer}") + + # Step 5: Done + print("\n[5/5] All tests passed!") + print(f" Qdrant running at: {qdrant_url}") + print(f" Ollama running at: port from health check above") + print("\nPress Ctrl+C to stop and cleanup.") + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("\nShutting down...") + serve.shutdown() + ray.shutdown() + print("Done.") \ No newline at end of file diff --git a/src/modules/rag/rag.py b/src/modules/rag/rag.py index ac594d8..7d2a9eb 100644 --- a/src/modules/rag/rag.py +++ b/src/modules/rag/rag.py @@ -40,6 +40,8 @@ class RAGHandle: def __init__( self, + ollama_handle=None, + qdrant_handle=None, qdrant_url: str = "http://localhost:6333", default_collection: str = "documents", embedding_model: str = "BAAI/bge-large-en-v1.5", @@ -51,15 +53,31 @@ def __init__( score_threshold: float = 0.5, ): self.embed_model = SentenceTransformer(embedding_model) - self.qdrant = QdrantClient(url=qdrant_url) self.default_collection = default_collection self.top_k = top_k self.score_threshold = score_threshold - + self.llm_provider = llm_provider self.llm_url = llm_url self.llm_model = llm_model self.llm_api_key = llm_api_key + + self.ollama_handle = ollama_handle + self.qdrant_handle = qdrant_handle + + self._qdrant_url = qdrant_url + self._qdrant = None + + + async def _get_qdrant(self): + """Connect to Qdrant on first use. 
Solves the async-in-init problem.""" + if self._qdrant is None: + if self.qdrant_handle: + self._qdrant_url = await self.qdrant_handle.get_url.remote() + self._qdrant = QdrantClient(url=self._qdrant_url) + print(f"[RAGHandle] Connected to Qdrant at {self._qdrant_url}") + return self._qdrant + def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: """ @@ -72,11 +90,6 @@ def _resolve_user_context(self, _user_id: str) -> tuple[str, dict | None]: C) Lookup in a DB to find the user's config """ - # Option A: separate collection per user - # collection = f"user_{_user_id}" - # filters = None - - # Option B: shared collection with _user_id filter (recommended) collection = self.default_collection filters = {"_user_id": _user_id} @@ -87,9 +100,9 @@ def _embed(self, text) -> list[float]: return self.embed_model.encode(str(text), normalize_embeddings=True).tolist() - def _search( self, + qdrant, query_vector: list[float], collection: str, filters: dict | None = None, @@ -103,14 +116,16 @@ def _search( ] qdrant_filter = Filter(must=conditions) - results = self.qdrant.query_points( - collection_name=collection, - query=query_vector, - query_filter=qdrant_filter, - limit=self.top_k, - score_threshold=self.score_threshold, - ).points - + try: + results = qdrant.query_points( + collection_name=collection, + query=query_vector, + query_filter=qdrant_filter, + limit=self.top_k, + score_threshold=self.score_threshold, + ).points + except Exception: + results = [] return [ { "text": point.payload.get("text", ""), @@ -178,6 +193,9 @@ async def _llm_generate( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] + + if self.ollama_handle: + return await self.ollama_handle.generate.remote(messages, max_tokens) if self.llm_provider == "vllm": return await self._call_openai_compatible( @@ -229,9 +247,12 @@ async def process(self, query: RAGQuery) -> RAGResult: """ print(f"[RAG] Question: {query.question}") + + qdrant = await self._get_qdrant() + collection, filters = self._resolve_user_context(query._user_id) query_vector = self._embed(query.question) - chunks = self._search(query_vector, collection, filters) + chunks = self._search(qdrant, query_vector, collection, filters) print(f"[RAG] Found {len(chunks)} chunks")