355 lines
11 KiB
Python
Executable File
355 lines
11 KiB
Python
Executable File
"""
|
|
RAG Store
|
|
SQLite-based vector store for document retrieval.
|
|
"""
|
|
import sqlite3
|
|
import json
|
|
import uuid
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import numpy as np
|
|
from loguru import logger
|
|
|
|
from config import get_db_path, load_config_from_db, settings
|
|
|
|
|
|
class RAGStore:
    """
    SQLite-based RAG store with vector similarity search.

    Features:
    - Document storage and word-based chunking
    - Vector embeddings via Ollama
    - Brute-force cosine similarity search over every stored chunk
    - Document management (add, delete, list)

    Every operation opens its own short-lived SQLite connection that is
    always closed, even when the operation raises; writes run inside a single
    transaction that is committed on success and rolled back on error.
    """

    # Length of the all-zero vector returned by generate_embedding() when the
    # Ollama request fails. A zero vector has zero norm, so
    # _cosine_similarity() scores it 0.0 against anything.
    _FALLBACK_EMBEDDING_DIM = 768

    def __init__(self):
        self.db_path = get_db_path()
        self._init_db()
        logger.info(f"RAG Store initialized at {self.db_path}")

    def _connect(self):
        """Open a connection to ``self.db_path`` that is closed on exit.

        Use as ``with self._connect() as conn:``. ``contextlib.closing`` only
        guarantees ``close()``; wrap writes in an inner ``with conn:`` so they
        are committed on success and rolled back on error.
        """
        from contextlib import closing

        return closing(sqlite3.connect(str(self.db_path)))

    def _init_db(self) -> None:
        """Create the ``documents`` and ``chunks`` tables and the chunk index if missing."""
        with self._connect() as conn:
            with conn:  # commit on success, roll back on error
                # Documents table
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS documents (
                        id TEXT PRIMARY KEY,
                        filename TEXT NOT NULL,
                        file_type TEXT,
                        content_hash TEXT,
                        created_at TEXT,
                        metadata TEXT
                    )
                """)

                # Chunks table
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS chunks (
                        id TEXT PRIMARY KEY,
                        document_id TEXT NOT NULL,
                        content TEXT NOT NULL,
                        chunk_index INTEGER,
                        embedding BLOB,
                        created_at TEXT,
                        FOREIGN KEY (document_id) REFERENCES documents(id)
                    )
                """)

                # Index so per-document chunk lookups/deletes avoid a full scan
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_chunks_document_id
                    ON chunks(document_id)
                """)

    async def add_document(
        self,
        filename: str,
        content: bytes,
        file_type: str,
        chunk_size: int = 500,
        overlap: int = 50,
    ) -> str:
        """
        Add a document to the store.

        The text is extracted, split into chunks of at most ``chunk_size``
        words that share ``overlap`` words, and every chunk is embedded.
        All embeddings are computed *before* the database is opened, so no
        write transaction is held across the Ollama calls; the document row
        and all of its chunks are then committed in one transaction.

        Args:
            filename: Name recorded for the document.
            content: Raw file bytes (also hashed into ``content_hash``).
            file_type: File extension such as ``".pdf"``; stored as given.
            chunk_size: Maximum words per chunk (must be >= 1).
            overlap: Words shared by consecutive chunks
                (must satisfy ``0 <= overlap < chunk_size``).

        Returns:
            The new document's ID (a UUID4 string).

        Raises:
            ValueError: If no text could be extracted, or the chunking
                parameters are invalid.
        """
        import hashlib

        text = self._extract_text(content, file_type)
        if not text.strip():
            raise ValueError("No text content extracted from document")

        chunks = self._chunk_text(text, chunk_size, overlap)

        # Embed first: each call is a network round-trip to Ollama.
        embedding_blobs = []
        for chunk in chunks:
            embedding = await self.generate_embedding(chunk)
            embedding_blobs.append(np.array(embedding, dtype=np.float32).tobytes())

        doc_id = str(uuid.uuid4())
        now = datetime.now().isoformat()
        chunk_rows = [
            (str(uuid.uuid4()), doc_id, chunk, index, blob, now)
            for index, (chunk, blob) in enumerate(zip(chunks, embedding_blobs))
        ]

        with self._connect() as conn:
            with conn:  # document + chunks commit together or not at all
                conn.execute("""
                    INSERT INTO documents (id, filename, file_type, content_hash, created_at, metadata)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, (
                    doc_id,
                    filename,
                    file_type,
                    hashlib.sha256(content).hexdigest(),
                    now,
                    json.dumps({"chunk_size": chunk_size, "overlap": overlap}),
                ))
                conn.executemany("""
                    INSERT INTO chunks (id, document_id, content, chunk_index, embedding, created_at)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, chunk_rows)

        logger.info(f"Added document: {filename} ({len(chunks)} chunks)")
        return doc_id

    def _extract_text(self, content: bytes, file_type: str) -> str:
        """Extract plain text from ``content`` according to ``file_type``.

        ``file_type`` is an extension matched case-insensitively, with or
        without the leading dot (``".PDF"`` and ``"pdf"`` both mean PDF).
        Plain-text and unknown types are decoded as UTF-8 with undecodable
        bytes dropped. PDF pages and DOCX paragraphs are each followed by a
        newline; HTML text is joined with newlines.

        If pypdf / python-docx is not installed, a placeholder string is
        returned (and will be stored and embedded like real text). Any other
        extraction failure, including a missing BeautifulSoup for HTML, is
        logged and yields ``""``.
        """
        import io

        ext = (file_type or "").lower()
        if ext and not ext.startswith("."):
            ext = "." + ext

        try:
            if ext in (".txt", ".md", ".text"):
                return content.decode("utf-8", errors="ignore")

            if ext == ".pdf":
                try:
                    from pypdf import PdfReader
                except ImportError:
                    logger.warning("pypdf not installed, cannot extract PDF text")
                    return "[PDF content - pypdf not installed]"
                reader = PdfReader(io.BytesIO(content))
                # Defensive `or ""`: a page with no text layer must not break
                # the whole document. NOTE(review): confirm whether the
                # installed pypdf version can return None here.
                return "".join(
                    (page.extract_text() or "") + "\n" for page in reader.pages
                )

            if ext == ".docx":
                try:
                    from docx import Document
                except ImportError:
                    logger.warning("python-docx not installed, cannot extract DOCX text")
                    return "[DOCX content - python-docx not installed]"
                doc = Document(io.BytesIO(content))
                return "".join(para.text + "\n" for para in doc.paragraphs)

            if ext in (".html", ".htm"):
                from bs4 import BeautifulSoup

                soup = BeautifulSoup(content, "html.parser")
                return soup.get_text(separator="\n")

            # Try as plain text
            return content.decode("utf-8", errors="ignore")

        except Exception as e:
            logger.error(f"Failed to extract text: {e}")
            return ""

    def _chunk_text(
        self,
        text: str,
        chunk_size: int,
        overlap: int
    ) -> List[str]:
        """Split ``text`` into overlapping word-based chunks.

        Words are split on whitespace and re-joined with single spaces. Each
        chunk holds at most ``chunk_size`` words and starts
        ``chunk_size - overlap`` words after the previous one. Chunking stops
        at the first chunk that reaches the end of the text, so the final
        chunk is never just a copy of the previous chunk's tail.

        Text with at most ``chunk_size`` words is returned unchanged as a
        single chunk.

        Raises:
            ValueError: If ``chunk_size < 1`` or ``overlap`` is outside
                ``[0, chunk_size)`` (these would loop forever or skip words).
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
            )

        words = text.split()
        if len(words) <= chunk_size:
            return [text]

        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            end = start + chunk_size
            chunks.append(" ".join(words[start:end]))
            if end >= len(words):
                break
        return chunks

    async def generate_embedding(self, text: str) -> List[float]:
        """Embed ``text`` with the configured Ollama model.

        Host and model are read from ``load_config_from_db()`` on every call,
        falling back to ``settings``. The synchronous
        ``ollama.Client.embeddings`` request runs in the event loop's default
        executor so other coroutines are not blocked while it is in flight.

        Returns:
            The embedding as a list of floats, or ``[]`` if the response has
            no ``"embedding"`` entry. If the request raises, the error is
            logged and an all-zero vector of length ``_FALLBACK_EMBEDDING_DIM``
            is returned instead (best effort: it always scores 0.0 in search).
            A missing ``ollama`` package still raises ``ImportError``.
        """
        import asyncio

        import ollama

        config = load_config_from_db()
        ollama_host = config.get("ollama_host", settings.ollama_host)
        embedding_model = config.get("embedding_model", settings.embedding_model)

        client = ollama.Client(host=ollama_host)

        def _request():
            return client.embeddings(model=embedding_model, prompt=text)

        try:
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, _request)
            return response.get("embedding", [])
        except Exception as e:
            logger.error(f"Failed to generate embedding: {e}")
            # Return zero vector as fallback
            return [0.0] * self._FALLBACK_EMBEDDING_DIM

    async def search(
        self,
        query: str,
        top_k: int = 5
    ) -> List[Dict[str, Any]]:
        """
        Search for the chunks most similar to ``query``.

        This is a brute-force scan: every stored embedding is loaded and
        compared with the query embedding by cosine similarity.

        Args:
            query: Free-text query, embedded with ``generate_embedding``.
            top_k: Maximum number of results; ``top_k <= 0`` returns ``[]``
                without embedding the query.

        Returns:
            Up to ``top_k`` dicts with ``chunk_id``, ``content``,
            ``document_id``, ``document_name`` and ``score``, best score
            first. Chunks with no stored embedding are skipped.
        """
        import heapq

        if top_k <= 0:
            return []

        query_embedding = await self.generate_embedding(query)
        query_vector = np.array(query_embedding, dtype=np.float32)

        with self._connect() as conn:
            rows = conn.execute("""
                SELECT c.id, c.content, c.document_id, c.embedding, d.filename
                FROM chunks c
                JOIN documents d ON c.document_id = d.id
            """).fetchall()

        results = [
            {
                "chunk_id": chunk_id,
                "content": content,
                "document_id": doc_id,
                "document_name": filename,
                "score": float(
                    self._cosine_similarity(
                        query_vector, np.frombuffer(embedding_blob, dtype=np.float32)
                    )
                ),
            }
            for chunk_id, content, doc_id, embedding_blob, filename in rows
            if embedding_blob
        ]

        # Documented as equivalent to sorted(..., reverse=True)[:top_k],
        # but O(n log k) instead of sorting every chunk.
        return heapq.nlargest(top_k, results, key=lambda r: r["score"])

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Return the cosine similarity of ``a`` and ``b``.

        Returns 0.0 when the vectors have different lengths (e.g. embeddings
        from different models) or when either has zero norm.
        """
        if len(a) != len(b):
            return 0.0

        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(a, b) / (norm_a * norm_b))

    def delete_document(self, doc_id: str) -> None:
        """Delete a document and all of its chunks in one transaction.

        The schema declares no ``ON DELETE CASCADE``, so chunks are deleted
        explicitly before the document row. An unknown ``doc_id`` deletes
        nothing (the log line is still written).
        """
        with self._connect() as conn:
            with conn:  # both deletes commit together or not at all
                conn.execute("DELETE FROM chunks WHERE document_id = ?", (doc_id,))
                conn.execute("DELETE FROM documents WHERE id = ?", (doc_id,))

        logger.info(f"Deleted document: {doc_id}")

    def list_documents(self) -> List[Dict[str, Any]]:
        """List all documents, newest ``created_at`` first, with chunk counts.

        Returns:
            Dicts with ``id``, ``filename``, ``file_type``, ``created_at`` and
            ``chunk_count`` (0 for a document without chunks, via LEFT JOIN).
        """
        with self._connect() as conn:
            rows = conn.execute("""
                SELECT d.id, d.filename, d.file_type, d.created_at,
                       COUNT(c.id) as chunk_count
                FROM documents d
                LEFT JOIN chunks c ON d.id = c.document_id
                GROUP BY d.id
                ORDER BY d.created_at DESC
            """).fetchall()

        return [
            {
                "id": doc_id,
                "filename": filename,
                "file_type": file_type,
                "created_at": created_at,
                "chunk_count": chunk_count,
            }
            for doc_id, filename, file_type, created_at, chunk_count in rows
        ]

    def get_document_count(self) -> int:
        """Return the total number of stored documents."""
        with self._connect() as conn:
            return conn.execute("SELECT COUNT(*) FROM documents").fetchone()[0]

    def get_chunk_count(self) -> int:
        """Return the total number of stored chunks across all documents."""
        with self._connect() as conn:
            return conn.execute("SELECT COUNT(*) FROM chunks").fetchone()[0]
|