"""
RAG Store
SQLite-based vector store for document retrieval.
"""

import asyncio
import functools
import io
import json
import sqlite3
import uuid
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np
from loguru import logger

from config import get_db_path, load_config_from_db, settings


class RAGStore:
    """
    SQLite-based RAG store with vector similarity search.

    Features:
    - Document storage and chunking
    - Vector embeddings via Ollama
    - Cosine similarity search
    - Document management (add, delete, list)

    Every method opens its own short-lived SQLite connection (see
    ``_connect``), so an instance holds no open connection between calls.
    """

    # Length of the zero vector returned when embedding generation fails.
    # NOTE(review): 768 matches some common embedding models but not
    # necessarily the configured one — confirm against the model in use.
    _FALLBACK_EMBEDDING_DIM = 768

    def __init__(self):
        """Resolve the database path and make sure the schema exists."""
        self.db_path = get_db_path()
        self._init_db()
        logger.info(f"RAG Store initialized at {self.db_path}")

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Open a connection to the store's database and always close it.

        Closing without a prior ``commit()`` discards any pending
        transaction, so a failure inside the ``with`` body rolls back.
        """
        conn = sqlite3.connect(str(self.db_path))
        try:
            yield conn
        finally:
            conn.close()

    def _init_db(self) -> None:
        """Create the ``documents`` and ``chunks`` tables and index if missing."""
        with self._connect() as conn:
            # Documents table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS documents (
                    id TEXT PRIMARY KEY,
                    filename TEXT NOT NULL,
                    file_type TEXT,
                    content_hash TEXT,
                    created_at TEXT,
                    metadata TEXT
                )
            """)

            # Chunks table; `embedding` holds raw float32 bytes.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS chunks (
                    id TEXT PRIMARY KEY,
                    document_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    chunk_index INTEGER,
                    embedding BLOB,
                    created_at TEXT,
                    FOREIGN KEY (document_id) REFERENCES documents(id)
                )
            """)

            # Create index for faster lookups of a document's chunks
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_chunks_document_id
                ON chunks(document_id)
            """)

            conn.commit()

    async def add_document(
        self,
        filename: str,
        content: bytes,
        file_type: str,
        chunk_size: int = 500,
        overlap: int = 50,
    ) -> str:
        """
        Add a document to the store.

        Extracts text from ``content``, splits it into word chunks, embeds
        each chunk, and stores the document row plus one row per chunk.

        Args:
            filename: Name recorded for the document.
            content: Raw file bytes.
            file_type: File extension including the dot (e.g. ``".pdf"``);
                selects the text extractor.
            chunk_size: Maximum number of words per chunk.
            overlap: Number of words shared by consecutive chunks.

        Returns:
            The new document's ID (a UUID4 string).

        Raises:
            ValueError: If no text could be extracted, or if ``chunk_size``
                / ``overlap`` are invalid (see ``_chunk_text``).
        """
        doc_id = str(uuid.uuid4())

        # Extract text based on file type
        text = self._extract_text(content, file_type)
        if not text.strip():
            raise ValueError("No text content extracted from document")

        chunks = self._chunk_text(text, chunk_size, overlap)

        # Embed every chunk BEFORE touching the database, so no write
        # transaction is held open across the network-bound awaits.
        embeddings = []
        for chunk in chunks:
            embeddings.append(await self.generate_embedding(chunk))

        created_at = datetime.now().isoformat()
        chunk_rows = [
            (
                str(uuid.uuid4()),
                doc_id,
                chunk,
                index,
                np.array(embedding, dtype=np.float32).tobytes(),
                created_at,
            )
            for index, (chunk, embedding) in enumerate(zip(chunks, embeddings))
        ]

        with self._connect() as conn:
            # Document and chunks are committed together; if any insert
            # fails, the connection closes uncommitted and nothing is kept.
            conn.execute(
                """
                INSERT INTO documents (id, filename, file_type, created_at, metadata)
                VALUES (?, ?, ?, ?, ?)
                """,
                (
                    doc_id,
                    filename,
                    file_type,
                    created_at,
                    json.dumps({"chunk_size": chunk_size, "overlap": overlap}),
                ),
            )
            conn.executemany(
                """
                INSERT INTO chunks (id, document_id, content, chunk_index, embedding, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                chunk_rows,
            )
            conn.commit()

        logger.info(f"Added document: {filename} ({len(chunks)} chunks)")
        return doc_id

    def _extract_text(self, content: bytes, file_type: str) -> str:
        """
        Extract plain text from raw file bytes.

        Dispatches on the extension (case-insensitively): ``.txt``/``.md``/
        ``.text`` are decoded as UTF-8, ``.pdf`` via pypdf, ``.docx`` via
        python-docx, ``.html``/``.htm`` via BeautifulSoup; anything else is
        decoded as UTF-8. Undecodable bytes are dropped. If pypdf or
        python-docx is missing, a placeholder string is returned. Any other
        extraction error is logged and yields ``""``.
        """
        ext = (file_type or "").lower()
        try:
            if ext in (".txt", ".md", ".text"):
                return content.decode("utf-8", errors="ignore")

            if ext == ".pdf":
                try:
                    from pypdf import PdfReader
                except ImportError:
                    logger.warning("pypdf not installed, cannot extract PDF text")
                    return "[PDF content - pypdf not installed]"
                reader = PdfReader(io.BytesIO(content))
                # Defensive `or ""`: a falsy page result must not raise and
                # discard the text of every other page.
                return "".join(
                    (page.extract_text() or "") + "\n" for page in reader.pages
                )

            if ext == ".docx":
                try:
                    from docx import Document
                except ImportError:
                    logger.warning("python-docx not installed, cannot extract DOCX text")
                    return "[DOCX content - python-docx not installed]"
                doc = Document(io.BytesIO(content))
                return "".join(para.text + "\n" for para in doc.paragraphs)

            if ext in (".html", ".htm"):
                from bs4 import BeautifulSoup

                soup = BeautifulSoup(content, "html.parser")
                return soup.get_text(separator="\n")

            # Unknown type: best-effort plain-text decode.
            return content.decode("utf-8", errors="ignore")
        except Exception as e:
            logger.error(f"Failed to extract text: {e}")
            return ""

    def _chunk_text(
        self,
        text: str,
        chunk_size: int,
        overlap: int,
    ) -> List[str]:
        """
        Split text into overlapping chunks of whitespace-separated words.

        If the text has at most ``chunk_size`` words it is returned unchanged
        as a single chunk. Otherwise each chunk holds up to ``chunk_size``
        words re-joined with single spaces, and consecutive chunks share
        ``overlap`` words. Chunking stops at the first chunk that reaches the
        end of the text, so there is no trailing chunk made only of overlap.

        Raises:
            ValueError: If ``chunk_size <= 0`` or ``overlap`` is outside
                ``[0, chunk_size)``; such values previously looped forever.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        words = text.split()
        if len(words) <= chunk_size:
            return [text]

        step = chunk_size - overlap  # > 0 thanks to the validation above
        chunks = []
        for start in range(0, len(words), step):
            end = start + chunk_size
            chunks.append(" ".join(words[start:end]))
            if end >= len(words):
                break
        return chunks

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for ``text`` using Ollama.

        Host and model come from ``load_config_from_db()``, falling back to
        ``settings``. The synchronous Ollama client call runs in the default
        executor so it does not block the event loop.

        Returns:
            The embedding (``[]`` if the response has no ``"embedding"``
            key), or a zero vector of ``_FALLBACK_EMBEDDING_DIM`` floats if
            the request fails.
        """
        import ollama

        config = load_config_from_db()
        ollama_host = config.get("ollama_host", settings.ollama_host)
        embedding_model = config.get("embedding_model", settings.embedding_model)

        client = ollama.Client(host=ollama_host)
        loop = asyncio.get_running_loop()

        try:
            response = await loop.run_in_executor(
                None,
                functools.partial(client.embeddings, model=embedding_model, prompt=text),
            )
            return response.get("embedding", [])
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            # Return zero vector as fallback
            return [0.0] * self._FALLBACK_EMBEDDING_DIM

    async def search(
        self,
        query: str,
        top_k: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Search for the chunks most similar to ``query``.

        Every stored chunk with a non-empty embedding is scored by cosine
        similarity against the query embedding (a brute-force scan of the
        whole ``chunks`` table).

        Returns:
            Up to ``top_k`` dicts with ``chunk_id``, ``content``,
            ``document_id``, ``document_name`` and ``score``, best first.
            ``[]`` when ``top_k <= 0``.
        """
        if top_k <= 0:
            return []

        # NOTE(review): if embedding fails, the query is a zero vector and
        # every score is 0.0, so the "top" results are arbitrary.
        query_vector = np.array(await self.generate_embedding(query), dtype=np.float32)

        with self._connect() as conn:
            rows = conn.execute("""
                SELECT c.id, c.content, c.document_id, c.embedding, d.filename
                FROM chunks c
                JOIN documents d ON c.document_id = d.id
            """).fetchall()

        results = []
        for chunk_id, content, doc_id, embedding_blob, filename in rows:
            # An empty/NULL blob (e.g. from an empty embedding response)
            # cannot be scored.
            if not embedding_blob:
                continue
            chunk_vector = np.frombuffer(embedding_blob, dtype=np.float32)
            results.append({
                "chunk_id": chunk_id,
                "content": content,
                "document_id": doc_id,
                "document_name": filename,
                "score": float(self._cosine_similarity(query_vector, chunk_vector)),
            })

        results.sort(key=lambda result: result["score"], reverse=True)
        return results[:top_k]

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when the lengths differ (e.g. embeddings from different
        models) or when either vector has zero norm.
        """
        if len(a) != len(b):
            return 0.0

        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(a, b) / (norm_a * norm_b))

    def delete_document(self, doc_id: str) -> None:
        """Delete a document and all its chunks in a single transaction."""
        with self._connect() as conn:
            # Chunks are removed explicitly: the schema declares no
            # ON DELETE CASCADE for chunks.document_id.
            conn.execute("DELETE FROM chunks WHERE document_id = ?", (doc_id,))
            conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
            conn.commit()

        logger.info(f"Deleted document: {doc_id}")

    def list_documents(self) -> List[Dict[str, Any]]:
        """List all documents, newest ``created_at`` first, with chunk counts."""
        with self._connect() as conn:
            rows = conn.execute("""
                SELECT d.id, d.filename, d.file_type, d.created_at,
                       COUNT(c.id) as chunk_count
                FROM documents d
                LEFT JOIN chunks c ON d.id = c.document_id
                GROUP BY d.id
                ORDER BY d.created_at DESC
            """).fetchall()

        return [
            {
                "id": doc_id,
                "filename": filename,
                "file_type": file_type,
                "created_at": created_at,
                "chunk_count": chunk_count,
            }
            for doc_id, filename, file_type, created_at, chunk_count in rows
        ]

    def get_document_count(self) -> int:
        """Return the total number of documents."""
        with self._connect() as conn:
            return conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]

    def get_chunk_count(self) -> int:
        """Return the total number of chunks across all documents."""
        with self._connect() as conn:
            return conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]