# test/moxie/rag/store.py
"""
RAG Store
SQLite-based vector store for document retrieval.
"""
import sqlite3
import json
import uuid
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
from datetime import datetime
import numpy as np
from loguru import logger
from config import get_db_path, load_config_from_db, settings
class RAGStore:
    """
    SQLite-based RAG store with vector similarity search.

    Features:
    - Document storage and chunking
    - Vector embeddings via Ollama
    - Cosine similarity search
    - Document management (add, delete, list)

    Embeddings are stored as raw little-endian float32 blobs; search is
    brute-force cosine similarity over all chunks (fine for small corpora).
    """

    def __init__(self):
        # DB location comes from app config; schema creation is idempotent.
        self.db_path = get_db_path()
        self._init_db()
        logger.info(f"RAG Store initialized at {self.db_path}")

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the store's SQLite database."""
        return sqlite3.connect(str(self.db_path))

    def _init_db(self) -> None:
        """Create tables and indexes if they do not already exist."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            # Documents table: one row per uploaded file.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS documents (
                    id TEXT PRIMARY KEY,
                    filename TEXT NOT NULL,
                    file_type TEXT,
                    content_hash TEXT,
                    created_at TEXT,
                    metadata TEXT
                )
            """)
            # Chunks table: one row per text chunk of a document.
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS chunks (
                    id TEXT PRIMARY KEY,
                    document_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    chunk_index INTEGER,
                    embedding BLOB,
                    created_at TEXT,
                    FOREIGN KEY (document_id) REFERENCES documents(id)
                )
            """)
            # Speeds up per-document chunk lookups and deletes.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_chunks_document_id
                ON chunks(document_id)
            """)
            conn.commit()
        finally:
            # BUGFIX: the connection was previously leaked on any error above.
            conn.close()

    async def add_document(
        self,
        filename: str,
        content: bytes,
        file_type: str,
        chunk_size: int = 500,
        overlap: int = 50
    ) -> str:
        """
        Add a document to the store: extract text, chunk it, embed each
        chunk, and persist document + chunks in a single transaction.

        Args:
            filename: Original file name (stored for display in results).
            content: Raw file bytes.
            file_type: Extension including the dot, e.g. ".pdf".
            chunk_size: Chunk length in words.
            overlap: Word overlap between consecutive chunks.

        Returns:
            The generated document ID.

        Raises:
            ValueError: If no text could be extracted, or if ``overlap``
                is not smaller than ``chunk_size``.
        """
        doc_id = str(uuid.uuid4())

        text = self._extract_text(content, file_type)
        if not text.strip():
            raise ValueError("No text content extracted from document")

        chunks = self._chunk_text(text, chunk_size, overlap)

        # BUGFIX: embed BEFORE opening the DB connection, so a slow or
        # failing Ollama call neither holds the connection open nor leaves
        # a partial (uncommitted-but-open) write behind.
        embedding_blobs = []
        for chunk in chunks:
            embedding = await self.generate_embedding(chunk)
            embedding_blobs.append(np.array(embedding, dtype=np.float32).tobytes())

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO documents (id, filename, file_type, created_at, metadata)
                VALUES (?, ?, ?, ?, ?)
            """, (
                doc_id,
                filename,
                file_type,
                datetime.now().isoformat(),
                json.dumps({"chunk_size": chunk_size, "overlap": overlap})
            ))
            for i, (chunk, blob) in enumerate(zip(chunks, embedding_blobs)):
                cursor.execute("""
                    INSERT INTO chunks (id, document_id, content, chunk_index, embedding, created_at)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    str(uuid.uuid4()),
                    doc_id,
                    chunk,
                    i,
                    blob,
                    datetime.now().isoformat()
                ))
            conn.commit()
        finally:
            # BUGFIX: connection previously leaked if an insert failed.
            conn.close()

        # BUGFIX: this log line printed a literal placeholder instead of
        # the actual filename.
        logger.info(f"Added document: {filename} ({len(chunks)} chunks)")
        return doc_id

    def _extract_text(self, content: bytes, file_type: str) -> str:
        """
        Extract plain text from raw file bytes.

        Supports .txt/.md/.text, .pdf (pypdf), .docx (python-docx) and
        .html/.htm (BeautifulSoup); anything else is decoded as UTF-8
        best-effort. Returns "" when extraction fails entirely — callers
        treat that as "no text".
        """
        text = ""
        try:
            if file_type in (".txt", ".md", ".text"):
                text = content.decode("utf-8", errors="ignore")
            elif file_type == ".pdf":
                try:
                    import io
                    from pypdf import PdfReader
                    reader = PdfReader(io.BytesIO(content))
                    for page in reader.pages:
                        # BUGFIX: extract_text() can return None for
                        # image-only pages; don't let one page abort the
                        # whole document.
                        text += (page.extract_text() or "") + "\n"
                except ImportError:
                    logger.warning("pypdf not installed, cannot extract PDF text")
                    text = "[PDF content - pypdf not installed]"
            elif file_type == ".docx":
                try:
                    import io
                    from docx import Document
                    doc = Document(io.BytesIO(content))
                    for para in doc.paragraphs:
                        text += para.text + "\n"
                except ImportError:
                    logger.warning("python-docx not installed, cannot extract DOCX text")
                    text = "[DOCX content - python-docx not installed]"
            elif file_type in (".html", ".htm"):
                from bs4 import BeautifulSoup
                soup = BeautifulSoup(content, "html.parser")
                text = soup.get_text(separator="\n")
            else:
                # Unknown extension: try as plain text.
                text = content.decode("utf-8", errors="ignore")
        except Exception as e:
            # Best-effort by design: never raise from extraction.
            logger.error(f"Failed to extract text: {e}")
            text = ""
        return text

    def _chunk_text(
        self,
        text: str,
        chunk_size: int,
        overlap: int
    ) -> List[str]:
        """
        Split text into overlapping word-based chunks.

        Raises:
            ValueError: If ``overlap >= chunk_size`` (the window would
                never advance).
        """
        words = text.split()
        if len(words) <= chunk_size:
            return [text]

        # BUGFIX: with overlap >= chunk_size the original loop never
        # advanced and spun forever; fail fast instead.
        step = chunk_size - overlap
        if step <= 0:
            raise ValueError("overlap must be smaller than chunk_size")

        chunks = []
        start = 0
        while start < len(words):
            end = start + chunk_size
            chunks.append(" ".join(words[start:end]))
            if end >= len(words):
                # BUGFIX: stop at the last full window; the original could
                # emit an extra tail chunk fully contained in the previous
                # one, wasting storage and embedding calls.
                break
            start += step
        return chunks

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for ``text`` via the configured Ollama model.

        Returns a zero vector on failure so callers keep working (a zero
        vector scores 0.0 in cosine similarity and never ranks).
        """
        import ollama

        config = load_config_from_db()
        ollama_host = config.get("ollama_host", settings.ollama_host)
        embedding_model = config.get("embedding_model", settings.embedding_model)
        client = ollama.Client(host=ollama_host)
        try:
            response = client.embeddings(
                model=embedding_model,
                prompt=text
            )
            return response.get("embedding", [])
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            # NOTE(review): 768 assumes the model's dimension; a mismatch is
            # harmless (cosine similarity returns 0.0 on length mismatch),
            # but confirm against the configured embedding model.
            return [0.0] * 768

    async def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Brute-force cosine-similarity search over all stored chunks.

        Returns up to ``top_k`` results sorted by descending score, each a
        dict with chunk_id, content, document_id, document_name and score.
        """
        query_vector = np.array(await self.generate_embedding(query), dtype=np.float32)

        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT c.id, c.content, c.document_id, c.embedding, d.filename
                FROM chunks c
                JOIN documents d ON c.document_id = d.id
            """)
            rows = cursor.fetchall()
        finally:
            # BUGFIX: connection previously leaked if the query raised.
            conn.close()

        results = []
        for chunk_id, content, doc_id, embedding_blob, filename in rows:
            if not embedding_blob:
                continue  # chunk stored without an embedding; cannot score
            chunk_vector = np.frombuffer(embedding_blob, dtype=np.float32)
            results.append({
                "chunk_id": chunk_id,
                "content": content,
                "document_id": doc_id,
                "document_name": filename,
                "score": float(self._cosine_similarity(query_vector, chunk_vector)),
            })

        results.sort(key=lambda r: r["score"], reverse=True)
        return results[:top_k]

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """
        Cosine similarity between two 1-D vectors.

        Returns 0.0 on dimension mismatch or when either vector is zero
        (this is what makes the zero-vector embedding fallback safe).
        """
        if len(a) != len(b):
            return 0.0
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    def delete_document(self, doc_id: str) -> None:
        """Delete a document and all its chunks (no-op if absent)."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            # Chunks first: the FK references documents(id).
            cursor.execute("DELETE FROM chunks WHERE document_id = ?", (doc_id,))
            cursor.execute("DELETE FROM documents WHERE id = ?", (doc_id,))
            conn.commit()
        finally:
            # BUGFIX: connection previously leaked if a delete raised.
            conn.close()
        logger.info(f"Deleted document: {doc_id}")

    def list_documents(self) -> List[Dict[str, Any]]:
        """List all documents (newest first) with their chunk counts."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT d.id, d.filename, d.file_type, d.created_at,
                       COUNT(c.id) as chunk_count
                FROM documents d
                LEFT JOIN chunks c ON d.id = c.document_id
                GROUP BY d.id
                ORDER BY d.created_at DESC
            """)
            rows = cursor.fetchall()
        finally:
            conn.close()

        keys = ("id", "filename", "file_type", "created_at", "chunk_count")
        return [dict(zip(keys, row)) for row in rows]

    def get_document_count(self) -> int:
        """Return the total number of stored documents."""
        return self._count_rows("documents")

    def get_chunk_count(self) -> int:
        """Return the total number of stored chunks."""
        return self._count_rows("chunks")

    def _count_rows(self, table: str) -> int:
        """COUNT(*) over one of the store's own tables (trusted names only)."""
        conn = self._connect()
        try:
            # Table name is internal ("documents"/"chunks"), never user input,
            # so string interpolation is safe here.
            return conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
        finally:
            # BUGFIX: connection previously leaked if the query raised.
            conn.close()