def load_services_config() -> dict:
    """Return the ``services`` mapping from ``config/huri.yaml``.

    The file is looked up one directory above this module's directory.

    Returns:
        The ``services`` section, or an empty dict when the file is empty,
        has no ``services`` key, or the key is present but null.
    """
    config_path = Path(__file__).resolve().parents[1] / "config" / "huri.yaml"
    with open(config_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document; normalise to a dict
        # so the lookup below cannot fail on NoneType.
        config = yaml.safe_load(f) or {}
    # "services:" with no value parses as None; callers expect a dict.
    return config.get("services") or {}
def bind_deployment_handles(
    modules: Dict[str, Type[Module]],
    **service_handles,
) -> Dict[str, handle.DeploymentHandle]:
    """Bind the handle class of every handle-backed module.

    Modules that are not ``ModuleWithHandle`` subclasses are skipped. The
    module registered under the name ``"rag"`` is bound with the ``ollama``
    and ``qdrant`` entries of *service_handles* whenever any service handle
    was passed; every other module is bound without arguments.

    Raises:
        TypeError: a ``ModuleWithHandle`` subclass lacks ``_handle_cls``.
    """
    bound: Dict[str, handle.DeploymentHandle] = {}

    for module_name, module_type in modules.items():
        if not issubclass(module_type, ModuleWithHandle):
            continue
        if not hasattr(module_type, "_handle_cls"):
            raise TypeError(f"{module_type.__name__} must define _handle_cls")

        # Only the RAG module receives the service handles.
        if module_name == "rag" and service_handles:
            bind_kwargs = {
                "ollama_handle": service_handles.get("ollama"),
                "qdrant_handle": service_handles.get("qdrant"),
            }
        else:
            bind_kwargs = {}

        bound[module_name] = module_type._handle_cls.bind(**bind_kwargs)

    return bound
def wait_for_service(url: str, timeout: int = 120) -> bool:
    """Poll *url* until it answers HTTP 200 or *timeout* seconds elapse.

    Args:
        url: Endpoint to probe with a GET request.
        timeout: Overall budget in seconds; each probe has its own 5 s limit.

    Returns:
        True if the service answered with status 200, False on timeout.
    """
    # Monotonic clock: the deadline must not move when the wall clock is
    # adjusted (NTP sync, suspend/resume) while we are waiting.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if httpx.get(url, timeout=5).status_code == 200:
                return True
        except httpx.HTTPError:
            # Connection refused / timed out: the service is not up yet.
            pass
        time.sleep(2)
    return False
def __init__(
    self,
    model: str = "mistral:7b",
    image: str = "ollama/ollama:latest",
    gpu_devices: bool = False,
):
    """Start an Ollama container on a free host port and pull *model*.

    Args:
        model: Model name passed to ``ollama pull`` inside the container.
        image: Docker image to run.
        gpu_devices: When true, pass ``/dev/kfd`` and ``/dev/dri`` through
            and add the ``video`` group (presumably for AMD ROCm images -
            confirm against the deployment target).

    Raises:
        RuntimeError: ``docker run`` failed, the service did not become
            ready in time, or the model pull failed.
    """
    self.model = model
    self.port = find_free_port()
    self.container_name = f"ollama-ray-{self.port}"
    self.base_url = f"http://localhost:{self.port}"

    # A stale container with the same name would make `docker run` fail.
    remove_container(self.container_name)

    cmd = [
        "docker", "run", "-d",
        "--name", self.container_name,
        "-p", f"{self.port}:11434",
        "-v", "ollama_shared:/root/.ollama",
    ]
    if gpu_devices:
        cmd.extend([
            "--device=/dev/kfd",
            "--device=/dev/dri",
            "--group-add=video",
        ])
    cmd.append(image)

    print(f"[OllamaService] Starting container '{self.container_name}' on port {self.port}...")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"Docker failed: {result.stderr}")

    # From here on the container exists: remove it explicitly if start-up
    # fails instead of relying on __del__ being run for a half-built object.
    try:
        print("[OllamaService] Waiting for Ollama to be ready...")
        if not wait_for_service(f"{self.base_url}/api/tags"):
            raise RuntimeError(f"Ollama didn't start within timeout on port {self.port}")

        print(f"[OllamaService] Pulling model '{model}'...")
        pull_result = subprocess.run(
            ["docker", "exec", self.container_name, "ollama", "pull", model],
            capture_output=True, text=True,
        )
        if pull_result.returncode != 0:
            raise RuntimeError(f"Failed to pull model: {pull_result.stderr}")
    except Exception:
        remove_container(self.container_name)
        raise

    print(f"[OllamaService] Ready! container='{self.container_name}', port={self.port}, model='{model}'")
async def health(self) -> dict:
    """Check if this Ollama instance is alive.

    Returns:
        ``{"status": "ok", "port": ..., "container": ...}`` when
        ``/api/tags`` answers with a success status, otherwise
        ``{"status": "error", "error": <message>}``.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(f"{self.base_url}/api/tags")
            # A 4xx/5xx answer means the instance is reachable but unhealthy;
            # without this check any HTTP response was reported as "ok".
            resp.raise_for_status()
        return {"status": "ok", "port": self.port, "container": self.container_name}
    except Exception as e:
        # Health probes must never raise: report the failure instead.
        return {"status": "error", "error": str(e)}
async def health(self) -> dict:
    """Report whether Qdrant answers its ``/healthz`` probe.

    Returns:
        ``{"status": "ok", "port": ..., "url": ...}`` when the probe answers
        with a success status, otherwise
        ``{"status": "error", "error": <message>}``.
    """
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.get(f"{self.url}/healthz")
            # Check the status like _is_healthy does; previously any HTTP
            # response (including 5xx) was reported as "ok".
            resp.raise_for_status()
        return {"status": "ok", "port": self.port, "url": self.url}
    except Exception as e:
        # Health probes must never raise: report the failure instead.
        return {"status": "error", "error": str(e)}
def _split_sentences(text: str) -> list[str]:
    """Split *text* after ``.``, ``!`` or ``?`` and at blank lines.

    Returns the stripped, non-empty pieces in order.
    """
    pieces: list[str] = []
    for sentence in re.split(r'(?<=[.!?])\s+', text):
        # Paragraph breaks that do not follow punctuation survive the regex.
        pieces.extend(sentence.split("\n\n"))
    return [piece.strip() for piece in pieces if piece.strip()]


def chunk_text(text: str, chunk_size: int = 500, overlap: int = 50) -> list[str]:
    """Fallback: fixed-size chunking by sentences (``--chunking=fixed``).

    Sentences are packed greedily until adding the next one would exceed
    *chunk_size* words. Each new chunk starts with the trailing sentences of
    the previous one, taken until at least *overlap* words are carried over.

    Args:
        text: Raw text to split.
        chunk_size: Target maximum number of words per chunk. A single
            sentence longer than this still forms its own chunk.
        overlap: Minimum number of words to repeat from the previous chunk;
            ``0`` (or less) disables overlap.

    Returns:
        The chunk strings in document order; ``[]`` for blank input.
    """
    chunks: list[str] = []
    current: list[str] = []
    current_words = 0

    for sentence in _split_sentences(text):
        sentence_words = len(sentence.split())

        if current and current_words + sentence_words > chunk_size:
            chunks.append(" ".join(current))

            carried: list[str] = []
            carried_words = 0
            # overlap <= 0 means "no overlap"; the loop below would otherwise
            # always carry one sentence because it checks after adding.
            if overlap > 0:
                for previous in reversed(current):
                    previous_words = len(previous.split())
                    # Never carry so much that the overlap alone fills the
                    # chunk: an over-long sentence would then be repeated in
                    # every following chunk.
                    if carried_words + previous_words >= chunk_size:
                        break
                    carried.insert(0, previous)
                    carried_words += previous_words
                    if carried_words >= overlap:
                        break

            current = carried
            current_words = carried_words

        current.append(sentence)
        current_words += sentence_words

    if current:
        chunks.append(" ".join(current))

    return chunks
def ingest_chunks(
    client: QdrantClient,
    model: SentenceTransformer,
    collection: str,
    chunks: list[str],
    _user_id: str,
    source: str,
    doc_type: str = "document",
):
    """Embed every chunk and upsert the resulting points into *collection*.

    Each point gets a fresh UUID and a payload holding the chunk text, the
    owning ``_user_id``, the ``source`` name, the ``type``, the chunk's
    position and one ISO timestamp shared by the whole call.

    Returns:
        The number of points written.
    """
    created_at = datetime.now().isoformat()

    def build_point(position: int, chunk: str):
        embedding = model.encode(chunk, normalize_embeddings=True).tolist()
        return PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding,
            payload={
                "text": chunk,
                "_user_id": _user_id,
                "source": source,
                "type": doc_type,
                "chunk_index": position,
                "timestamp": created_at,
            },
        )

    points = [build_point(position, chunk) for position, chunk in enumerate(chunks)]

    # Upsert in batches of 100; an empty list yields no batches at all.
    batch_size = 100
    for offset in range(0, len(points), batch_size):
        client.upsert(
            collection_name=collection,
            points=points[offset:offset + batch_size],
        )

    return len(points)
def cmd_text(args, client, model, _user_id):
    """Ingest each text file listed in ``args.files``.

    Missing and blank files are reported and skipped; every other file is
    chunked with the configured strategy and stored with ``doc_type="text"``.
    """
    # One probe embedding tells us the vector size for the collection.
    probe = model.encode("test", normalize_embeddings=True)
    ensure_collection(client, args.collection, len(probe))

    ingested = 0
    for raw_path in args.files:
        path = Path(raw_path)
        if not path.exists():
            print(f"File not found: {raw_path}")
            continue

        print(f"\nProcessing: {raw_path}")
        content = path.read_text(encoding="utf-8")

        if content.strip():
            pieces = chunk_strat(content, args, model)
            count = ingest_chunks(
                client, model, args.collection, pieces,
                _user_id, source=path.name, doc_type="text",
            )
            print(f" -> {count} chunks ingested")
            ingested += count
        else:
            print(f" WARNING: File is empty: {raw_path}")

    print(f"\nDone. Total: {ingested} chunks from {len(args.files)} file(s)")
def cmd_list(args, client, model, _user_id):
    """List what's in the database for this user.

    Prints the collection's total point count, then one row per source
    (name, type, chunk count) for points whose ``_user_id`` payload matches.
    All pages of the scroll are read, so the totals cover every chunk rather
    than only the first 100.
    """
    try:
        info = client.get_collection(args.collection)
    except Exception:
        print(f"Collection '{args.collection}' doesn't exist.")
        return

    print(f"Collection: {args.collection}")
    print(f"Total points: {info.points_count}")

    user_filter = Filter(must=[
        FieldCondition(key="_user_id", match=MatchValue(value=_user_id)),
    ])

    sources = {}
    total_chunks = 0
    offset = None
    while True:
        # scroll() returns (points, next_offset); next_offset is None once
        # the last page has been delivered.
        points, offset = client.scroll(
            collection_name=args.collection,
            scroll_filter=user_filter,
            limit=100,
            offset=offset,
            with_payload=True,
            with_vectors=False,
        )
        for point in points:
            payload = point.payload or {}
            source = payload.get("source", "unknown")
            entry = sources.setdefault(
                source, {"count": 0, "type": payload.get("type", "unknown")}
            )
            entry["count"] += 1
        total_chunks += len(points)
        if offset is None:
            break

    if not total_chunks:
        print(f"No documents found for user {_user_id}")
        return

    print(f"\nDocuments for user {_user_id}:")
    print(f"{'Source':<40} {'Type':<10} {'Chunks':<8}")
    print("-" * 60)
    for source, entry in sorted(sources.items()):
        print(f"{source:<40} {entry['type']:<10} {entry['count']:<8}")
    print(f"\nTotal: {total_chunks} chunks across {len(sources)} sources")
def _build_parser():
    """Build the command-line parser: global options plus one subcommand."""
    parser = argparse.ArgumentParser(description="HuRI RAG Ingestion Tool")
    parser.add_argument("--user-id", type=str, default=None)
    parser.add_argument("--collection", type=str, default="documents")
    parser.add_argument("--qdrant-url", type=str, default="http://localhost:6333")
    parser.add_argument("--embedding-model", type=str, default="BAAI/bge-large-en-v1.5")
    parser.add_argument("--chunk-size", type=int, default=500,
                        help="Target chunk size in words (fixed mode)")
    parser.add_argument("--overlap", type=int, default=50,
                        help="Overlap between chunks in words (fixed mode)")
    parser.add_argument("--chunking", type=str, default="fixed",
                        choices=["semantic", "fixed"],
                        help="Chunking strategy: 'semantic' or 'fixed' (default: fixed)")
    parser.add_argument("--semantic-strategy", type=str, default="percentile",
                        choices=["percentile", "threshold", "stddev"],
                        help="Semantic chunking strategy (default: percentile)")

    subparsers = parser.add_subparsers(dest="command", required=True)

    p_pdf = subparsers.add_parser("pdf", help="Ingest PDF files")
    p_pdf.add_argument("files", nargs="+", help="PDF files or directories")

    p_text = subparsers.add_parser("text", help="Ingest text files (.txt, .md)")
    p_text.add_argument("files", nargs="+", help="Text files")

    p_write = subparsers.add_parser("write", help="Write text interactively")
    p_write.add_argument("--title", type=str, default=None, help="Title/source name")

    subparsers.add_parser("list", help="List ingested documents")

    p_delete = subparsers.add_parser("delete", help="Delete documents by source")
    p_delete.add_argument("--source", type=str, required=True, help="Source name to delete")

    return parser


def main(argv=None):
    """Parse the command line and run the selected ingestion subcommand.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.
    """
    args = _build_parser().parse_args(argv)

    # argparse stores "--user-id" under the dest "user_id".
    _user_id = get_user_id(args.user_id)
    print(f"User: {_user_id}")

    client = QdrantClient(url=args.qdrant_url)
    model = SentenceTransformer(args.embedding_model)

    commands = {
        "pdf": cmd_pdf,
        "text": cmd_text,
        "write": cmd_write,
        "list": cmd_list,
        "delete": cmd_delete,
    }
    commands[args.command](args, client, model, _user_id)


if __name__ == "__main__":
    # Ingestion tool for HuRI RAG.
    #
    # Usage:
    #   # Ingest a PDF
    #   python ingestion.py pdf report.pdf
    #
    #   # Ingest multiple PDFs
    #   python ingestion.py pdf doc1.pdf doc2.pdf doc3.pdf
    #
    #   # Ingest a whole folder of PDFs
    #   # TODO: verify this, and add support for whole paths
    #   python ingestion.py pdf ./my_documents/
    #
    #   # Write text interactively (type, then Ctrl+D to save)
    #   python ingestion.py write --title "My meeting notes"
    #
    #   # Ingest a text file
    #   python ingestion.py text notes.txt story.md
    #
    #   # Specify a user ID (otherwise reads from ~/.huri_user_id)
    #   python ingestion.py --user-id "abc-123" pdf report.pdf
    #
    #   # Use a different collection
    #   python ingestion.py --collection "my_docs" pdf report.pdf
    #
    #   # Use a different ingestion strategy
    #   python src/modules/rag/ingestion.py --chunking semantic --semantic-strategy threshold pdf "EN.pdf"
    main()
async def _get_qdrant(self):
    """Return the Qdrant client, creating it on the first call.

    The client is built lazily because resolving the URL through
    ``qdrant_handle`` needs an ``await``, which ``__init__`` cannot do.
    When no handle was supplied, the URL stored at construction is used.
    """
    if self._qdrant is not None:
        return self._qdrant

    if self.qdrant_handle:
        # Ask the Qdrant service deployment where it is listening.
        self._qdrant_url = await self.qdrant_handle.get_url.remote()

    self._qdrant = QdrantClient(url=self._qdrant_url)
    print(f"[RAGHandle] Connected to Qdrant at {self._qdrant_url}")
    return self._qdrant
_llm_generate( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] + + if self.ollama_handle: + return await self.ollama_handle.generate.remote(messages, max_tokens) if self.llm_provider == "vllm": return await self._call_openai_compatible( @@ -185,6 +203,7 @@ async def _llm_generate( ) elif self.llm_provider == "ollama": return await self._call_ollama(messages, max_tokens) + elif self.llm_provider == "api": return await self._call_openai_compatible( f"{self.llm_url}/v1/chat/completions", messages, max_tokens, self.llm_api_key @@ -229,9 +248,12 @@ async def process(self, query: RAGQuery) -> RAGResult: """ print(f"[RAG] Question: {query.question}") + + qdrant = await self._get_qdrant() + collection, filters = self._resolve_user_context(query._user_id) query_vector = self._embed(query.question) - chunks = self._search(query_vector, collection, filters) + chunks = self._search(qdrant, query_vector, collection, filters) print(f"[RAG] Found {len(chunks)} chunks") diff --git a/src/modules/rag/semantic_chunker.py b/src/modules/rag/semantic_chunker.py new file mode 100644 index 0000000..1ddbead --- /dev/null +++ b/src/modules/rag/semantic_chunker.py @@ -0,0 +1,236 @@ +""" +Semantic Chunking for RAG. + +Three strategies: + 1. percentile - cut where similarity is below the Nth percentile (default) + 2. threshold - cut where similarity drops below a fixed value + 3. 
@dataclass
class Chunk:
    """One contiguous run of sentences produced by the chunker."""

    text: str
    sentences: list[str] = field(default_factory=list)
    start_idx: int = 0  # index of the first sentence (inclusive)
    end_idx: int = 0  # index one past the last sentence (exclusive)


class SemanticChunker:
    """Split text where consecutive sentences stop being similar.

    Sentences are embedded together with their neighbours, the similarity of
    each consecutive pair is computed, and a breakpoint is placed wherever
    the similarity falls below a cutoff chosen by ``strategy``.
    """

    def __init__(
        self,
        model: "SentenceTransformer",
        strategy: str = "percentile",     # "percentile", "threshold", "stddev"
        percentile_cutoff: float = 25,    # for percentile strategy
        threshold_cutoff: float = 0.5,    # for threshold strategy
        stddev_cutoff: float = 1.0,       # for stddev strategy (N std devs below mean)
        min_chunk_size: int = 2,          # minimum sentences per chunk
        max_chunk_size: int = 50,         # maximum sentences per chunk
        buffer_size: int = 1,             # sentences to look around for context
    ):
        """Store the embedding model and the chunking parameters.

        Raises:
            ValueError: ``max_chunk_size`` is smaller than 1 (the forced
                splitting loop could never advance).
        """
        if max_chunk_size < 1:
            raise ValueError(f"max_chunk_size must be >= 1, got {max_chunk_size}")
        self.model = model
        self.strategy = strategy
        self.percentile_cutoff = percentile_cutoff
        self.threshold_cutoff = threshold_cutoff
        self.stddev_cutoff = stddev_cutoff
        self.min_chunk_size = min_chunk_size
        self.max_chunk_size = max_chunk_size
        self.buffer_size = buffer_size

    def chunk(self, text: str) -> list[str]:
        """Main entry point. Returns list of chunk texts.

        Text with at most ``min_chunk_size`` sentences is returned as a
        single chunk; blank text yields ``[]``.
        """
        sentences = self._split_sentences(text)

        if len(sentences) <= self.min_chunk_size:
            stripped = text.strip()
            return [stripped] if stripped else []

        return self._create_chunks(sentences, self._breakpoints_for(sentences))

    def chunk_detailed(self, text: str) -> list[Chunk]:
        """Return ``Chunk`` objects carrying sentences and index ranges.

        Produces the same segmentation as :meth:`chunk`, including ``[]``
        for blank text.
        """
        sentences = self._split_sentences(text)

        if len(sentences) <= self.min_chunk_size:
            stripped = text.strip()
            if not stripped:
                return []
            return [Chunk(text=stripped, sentences=sentences, start_idx=0, end_idx=len(sentences))]

        # A breakpoint bp closes a chunk after sentence bp, so the exclusive
        # end is bp + 1; the final chunk always ends at len(sentences).
        ends = [bp + 1 for bp in self._breakpoints_for(sentences)]
        ends.append(len(sentences))

        chunks = []
        start = 0
        for end in ends:
            if end <= start:
                continue
            group = sentences[start:end]
            chunks.append(Chunk(
                text=" ".join(group),
                sentences=group,
                start_idx=start,
                end_idx=end,
            ))
            start = end
        return chunks

    def _breakpoints_for(self, sentences: list[str]) -> list[int]:
        """Embed the buffered sentences and return the chosen breakpoints."""
        windows = self._combine_with_buffer(sentences)
        embeddings = self.model.encode(windows, normalize_embeddings=True)
        return self._find_breakpoints(self._calculate_similarities(embeddings))

    def _split_sentences(self, text: str) -> list[str]:
        """Split text into sentences, respecting paragraph boundaries."""
        sentences = []
        for paragraph in text.split("\n\n"):
            paragraph = paragraph.strip()
            if not paragraph:
                continue
            for part in re.split(r'(?<=[.!?])\s+', paragraph):
                part = part.strip()
                if part:
                    sentences.append(part)
        return sentences

    def _combine_with_buffer(self, sentences: list[str]) -> list[str]:
        """Join each sentence with up to ``buffer_size`` neighbours per side.

        The window for index i covers sentences [i - buffer, i + buffer],
        clipped to the list bounds, so each embedding sees local context.
        """
        total = len(sentences)
        return [
            " ".join(sentences[max(0, i - self.buffer_size):min(total, i + self.buffer_size + 1)])
            for i in range(total)
        ]

    def _calculate_similarities(self, embeddings: np.ndarray) -> list[float]:
        """Return the dot product of each consecutive pair of embeddings.

        With the normalised embeddings requested by the callers this is the
        cosine similarity. Fewer than two embeddings yield ``[]``.
        """
        vectors = np.asarray(embeddings)
        return [float(np.dot(a, b)) for a, b in zip(vectors[:-1], vectors[1:])]

    def _find_breakpoints(self, similarities: list[float]) -> list[int]:
        """Find where to split based on the chosen strategy.

        Raises:
            ValueError: ``self.strategy`` is not a known strategy.
        """
        if not similarities:
            return []

        sims = np.asarray(similarities, dtype=float)

        if self.strategy == "percentile":
            cutoff = np.percentile(sims, self.percentile_cutoff)
        elif self.strategy == "threshold":
            cutoff = self.threshold_cutoff
        elif self.strategy == "stddev":
            cutoff = sims.mean() - (self.stddev_cutoff * sims.std())
        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

        candidates = [i for i, s in enumerate(similarities) if s < cutoff]
        # n similarities separate n + 1 sentences.
        return self._enforce_chunk_sizes(candidates, len(similarities) + 1)

    def _enforce_chunk_sizes(self, candidates: list[int], num_sentences: int) -> list[int]:
        """Filter candidate breakpoints so chunks respect the size limits.

        A breakpoint ``b`` ends a chunk after sentence ``b``. Candidates that
        would create a chunk shorter than ``min_chunk_size`` - before them,
        after a forced split, or as the trailing chunk - are dropped. Runs
        longer than ``max_chunk_size`` are split at fixed intervals; the
        maximum is the hard limit, so a forced split may still leave a short
        remainder.
        """
        last_index = num_sentences - 1
        breakpoints: list[int] = []
        last_break = -1

        for candidate in sorted(set(candidates)):
            if not 0 <= candidate < last_index:
                continue
            if last_index - candidate < self.min_chunk_size:
                # Candidates are sorted: every later one leaves an even
                # shorter trailing chunk.
                break
            if candidate - last_break < self.min_chunk_size:
                continue

            last_break = self._force_splits(breakpoints, last_break, candidate)
            # A forced split may have landed right before the candidate.
            if candidate - last_break < self.min_chunk_size:
                continue

            breakpoints.append(candidate)
            last_break = candidate

        self._force_splits(breakpoints, last_break, last_index)
        return breakpoints

    def _force_splits(self, breakpoints: list[int], last_break: int, limit: int) -> int:
        """Append a breakpoint every ``max_chunk_size`` sentences before *limit*.

        Returns the position of the last breakpoint in effect afterwards.
        """
        pos = last_break + self.max_chunk_size
        while pos < limit:
            breakpoints.append(pos)
            last_break = pos
            pos += self.max_chunk_size
        return last_break

    def _create_chunks(self, sentences: list[str], breakpoints: list[int]) -> list[str]:
        """Group sentences into chunks based on breakpoints."""
        chunks = []
        start = 0

        for bp in breakpoints:
            end = bp + 1
            chunk_text = " ".join(sentences[start:end]).strip()
            if chunk_text:
                chunks.append(chunk_text)
            start = end

        if start < len(sentences):
            chunk_text = " ".join(sentences[start:]).strip()
            if chunk_text:
                chunks.append(chunk_text)

        return chunks


def create_chunker(
    model: "SentenceTransformer | None" = None,
    model_name: str = "BAAI/bge-large-en-v1.5",
    strategy: str = "percentile",
    **kwargs,
) -> SemanticChunker:
    """Create a chunker with defaults.

    When *model* is None a ``SentenceTransformer`` is loaded from
    *model_name*; extra keyword arguments go to ``SemanticChunker``.
    """
    if model is None:
        model = SentenceTransformer(model_name)
    return SemanticChunker(model=model, strategy=strategy, **kwargs)