- Pass all registered tools to LLM during chat completion
- Handle tool_calls from LLM response
- Execute tools and feed results back to LLM
- Loop until LLM returns final response
- Updated system prompt to encourage tool use
- Updated streaming to handle tool calls
- Increased MAX_TOOL_ITERATIONS to 5
560 lines
19 KiB
Python
Executable File
560 lines
19 KiB
Python
Executable File
"""
RAG System - Retrieval Augmented Generation

This module provides the core RAG functionality for DocRAG, including:
- Website downloading and ingestion via website_downloader_tool
- Document processing and chunking
- Vector storage and similarity search
- Context retrieval for enhanced prompts
"""
|
|
|
|
from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import os
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from urllib.parse import urlparse

# Import the website downloader tool
from website_downloader_tool import website_downloader

from .document_processor import DocumentProcessor
from .retriever import Retriever
from .vector_store import VectorStore
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class RAGSystem:
    """
    Main RAG system that coordinates website downloading, document processing,
    storage, and retrieval.

    This class provides a unified interface for:
    - Downloading websites using website_downloader_tool
    - Processing downloaded content into the knowledge base
    - Querying for relevant context
    - Managing the document lifecycle

    Call ``await initialize()`` before any operation that touches the vector
    store; those operations raise ``RuntimeError`` otherwise.
    """

    # File name of the persisted site registry inside ``downloaded_sites_path``.
    _REGISTRY_FILENAME = "site_registry.json"

    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        vector_store_path: str = "./data/vectors",
        documents_path: str = "./data/documents",
        downloaded_sites_path: str = "./data/downloaded_sites",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """
        Store configuration only; no directories or components are created here.

        Args:
            embedding_model: Name of the embedding model passed to the vector store
            vector_store_path: Directory where the vector store persists its data
            documents_path: Directory reserved for document storage
            downloaded_sites_path: Directory holding downloaded sites and the registry
            chunk_size: Chunk size passed to the document processor
            chunk_overlap: Chunk overlap passed to the document processor
        """
        self.embedding_model = embedding_model
        self.vector_store_path = Path(vector_store_path)
        self.documents_path = Path(documents_path)
        self.downloaded_sites_path = Path(downloaded_sites_path)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        self._initialized = False
        self._document_processor: Optional[DocumentProcessor] = None
        self._vector_store: Optional[VectorStore] = None
        self._retriever: Optional[Retriever] = None

        # Track downloaded sites with their source URLs (site id -> site info)
        self._site_registry: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        """Initialize the RAG system components. Calling it again is a no-op."""
        if self._initialized:
            return

        log.info("Initializing RAG system...")

        # Create directories
        for directory in (
            self.vector_store_path,
            self.documents_path,
            self.downloaded_sites_path,
        ):
            directory.mkdir(parents=True, exist_ok=True)

        # Initialize document processor
        self._document_processor = DocumentProcessor(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

        # Initialize vector store
        self._vector_store = VectorStore(
            persist_directory=str(self.vector_store_path),
            embedding_model=self.embedding_model,
        )
        await self._vector_store.initialize()

        # Initialize retriever
        self._retriever = Retriever(
            vector_store=self._vector_store,
        )

        # Load existing site registry
        await self._load_site_registry()

        self._initialized = True
        log.info("RAG system initialized successfully")

    async def close(self) -> None:
        """Close the RAG system and release resources."""
        try:
            if self._vector_store:
                await self._vector_store.close()
        finally:
            # Persist the registry and mark the system closed even when the
            # vector store fails to close; the original error still propagates.
            await self._save_site_registry()
            self._initialized = False
        log.info("RAG system closed")

    def _ensure_initialized(self) -> None:
        """
        Ensure the RAG system is initialized.

        Raises:
            RuntimeError: If initialize() has not completed.
        """
        if not self._initialized:
            raise RuntimeError("RAG system not initialized. Call initialize() first.")

    async def _load_site_registry(self) -> None:
        """
        Load the site registry from disk.

        A missing file leaves the current registry untouched. An unreadable or
        malformed file is logged and replaced by an empty registry.
        """
        registry_file = self.downloaded_sites_path / self._REGISTRY_FILENAME
        if not registry_file.exists():
            return

        try:
            # The registry is written with ensure_ascii=False, so it must be read
            # as UTF-8 regardless of the platform's default encoding.
            with open(registry_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                raise ValueError("registry root must be a JSON object")
        except (OSError, ValueError) as e:
            log.warning(f"Failed to load site registry: {e}")
            self._site_registry = {}
            return

        self._site_registry = loaded
        log.info(f"Loaded site registry with {len(self._site_registry)} sites")

    async def _save_site_registry(self) -> None:
        """
        Save the site registry to disk.

        Failures are logged, not raised, so callers are never interrupted by a
        registry write problem.
        """
        registry_file = self.downloaded_sites_path / self._REGISTRY_FILENAME
        temp_file = registry_file.with_name(registry_file.name + ".tmp")
        try:
            # Write to a sibling temp file and swap it in, so an interrupted
            # write cannot leave a truncated registry behind.
            with open(temp_file, "w", encoding="utf-8") as f:
                json.dump(self._site_registry, f, indent=2, ensure_ascii=False)
            temp_file.replace(registry_file)
        except (OSError, TypeError, ValueError) as e:
            log.error(f"Failed to save site registry: {e}")
            return

        log.info(f"Saved site registry with {len(self._site_registry)} sites")

    async def download_and_ingest_website(
        self,
        url: str,
        max_pages: int = 50,
        threads: int = 6,
        download_external_assets: bool = False,
        external_domains: Optional[list[str]] = None,
    ) -> dict[str, Any]:
        """
        Download a website using website_downloader_tool and ingest all content
        into the knowledge base.

        This is the PRIMARY method for adding content to the RAG system.

        Args:
            url: URL of the website to download
            max_pages: Maximum number of pages to crawl
            threads: Number of concurrent download threads
            download_external_assets: Whether to download external assets
            external_domains: List of external domains to allow

        Returns:
            Dictionary with download and ingestion results. On failure it holds
            ``success=False``, ``message`` and ``url``. On success it holds
            ``success=True``, ``url``, ``local_path``, ``pages_processed``,
            ``total_chunks``, ``stats`` and ``errors`` (per-file ingestion
            failures).

        Raises:
            RuntimeError: If the system has not been initialized.
        """
        self._ensure_initialized()

        log.info(f"Downloading website: {url}")

        destination = self.downloaded_sites_path / self._get_site_folder(url)

        # website_downloader is invoked as a plain (non-awaited) function, so it
        # is run in the default executor to keep the event loop responsive while
        # the crawl is in progress.
        loop = asyncio.get_running_loop()
        download_result = await loop.run_in_executor(
            None,
            lambda: website_downloader(
                url=url,
                destination=str(destination),
                max_pages=max_pages,
                threads=threads,
                download_external_assets=download_external_assets,
                external_domains=external_domains,
            ),
        )

        if not download_result.get("success"):
            log.error(f"Website download failed: {download_result.get('message')}")
            return {
                "success": False,
                "message": download_result.get("message", "Download failed"),
                "url": url,
            }

        # Fall back to the requested destination: an empty output directory would
        # become Path("") (the current working directory) and ingest unrelated
        # HTML files.
        output_dir = download_result.get("output_directory") or str(destination)
        stats = download_result.get("stats") or {}

        log.info(f"Website downloaded to: {output_dir}")

        # Process all HTML files from the downloaded site
        ingestion_result = await self._ingest_downloaded_site(
            site_path=Path(output_dir),
            source_url=url,
        )

        # Register the site
        site_id = self._generate_site_id(url)
        self._site_registry[site_id] = {
            "url": url,
            "local_path": output_dir,
            "pages_downloaded": stats.get("pages_crawled", 0),
            "assets_downloaded": stats.get("assets_downloaded", 0),
            "chunks_ingested": ingestion_result.get("total_chunks", 0),
            "timestamp": self._get_timestamp(),
        }
        await self._save_site_registry()

        return {
            "success": True,
            "url": url,
            "local_path": output_dir,
            "pages_processed": ingestion_result.get("pages_processed", 0),
            "total_chunks": ingestion_result.get("total_chunks", 0),
            "stats": stats,
            "errors": ingestion_result.get("errors", []),
        }

    async def _ingest_downloaded_site(
        self,
        site_path: Path,
        source_url: str,
    ) -> dict[str, Any]:
        """
        Ingest all HTML files from a downloaded website into the knowledge base.

        Files are processed one by one; a failure on one file is recorded and
        does not stop the remaining files.

        Args:
            site_path: Path to the downloaded website directory
            source_url: Original URL of the website

        Returns:
            Dictionary with ingestion statistics: ``pages_processed``,
            ``total_chunks`` and ``errors``
        """
        pages_processed = 0
        total_chunks = 0
        errors: list[str] = []

        if not site_path.is_dir():
            log.warning(f"Downloaded site directory not found: {site_path}")
            errors.append(f"{site_path}: directory not found")
            return {
                "pages_processed": pages_processed,
                "total_chunks": total_chunks,
                "errors": errors,
            }

        # Find all HTML files (sorted so ingestion order is reproducible)
        html_files = sorted(site_path.rglob("*.html"))
        log.info(f"Found {len(html_files)} HTML files in {site_path}")

        for html_file in html_files:
            try:
                # Read the HTML file
                content = html_file.read_bytes()

                # Calculate relative path for the page pointer
                relative_path = html_file.relative_to(site_path)
                page_url = self._reconstruct_page_url(source_url, relative_path)

                # Extract text from HTML; only used to skip pages without text
                text_content = await self._document_processor.extract_text_from_html(content)
                if not text_content.strip():
                    continue

                # Process into chunks
                doc_info = await self._document_processor.process(
                    content=content,
                    filename=str(html_file),
                    metadata={
                        "source_url": source_url,
                        "page_url": page_url,
                        "local_path": str(html_file),
                        "relative_path": str(relative_path),
                        "source_type": "downloaded_website",
                    },
                )

                chunks = doc_info.get("chunks")
                if not chunks:
                    continue

                # Store chunks in vector store with pointers:
                # add source pointer to each chunk's metadata
                chunk_metadatas = doc_info.get("metadatas", [])
                for chunk_metadata in chunk_metadatas:
                    chunk_metadata["source_url"] = source_url
                    chunk_metadata["page_url"] = page_url
                    chunk_metadata["local_path"] = str(html_file)
                    chunk_metadata["pointer"] = {
                        "type": "downloaded_page",
                        "url": page_url,
                        "local_file": str(html_file),
                    }

                await self._vector_store.add_chunks(
                    chunks=chunks,
                    metadatas=chunk_metadatas,
                    ids=doc_info.get("ids", []),
                )

                total_chunks += len(chunks)
                pages_processed += 1

                log.debug(f"Ingested: {relative_path} -> {len(chunks)} chunks")

            except Exception as e:
                # Deliberate best-effort: one bad page must not abort the site.
                errors.append(f"{html_file}: {str(e)}")
                log.warning(f"Failed to process {html_file}: {e}")

        log.info(f"Ingestion complete: {pages_processed} pages, {total_chunks} chunks")

        return {
            "pages_processed": pages_processed,
            "total_chunks": total_chunks,
            "errors": errors,
        }

    def _get_site_folder(self, url: str) -> str:
        """
        Generate a folder name for a site from its URL.

        The network location is used with dots and colons replaced by
        underscores. The result is never empty, so a site can never map onto
        the downloads root directory itself.
        """
        netloc = urlparse(url).netloc
        if not netloc:
            # Scheme-less input such as "example.com/docs" parses with an empty
            # netloc; re-parse it as a network-path reference to recover the host.
            netloc = urlparse(f"//{url}").netloc

        folder = netloc.replace(".", "_").replace(":", "_")
        return folder or "unknown_site"

    def _generate_site_id(self, url: str) -> str:
        """Generate an ID for a site: the first 16 hex digits of the URL's MD5."""
        # MD5 is used only as a stable identifier here, not for security.
        return hashlib.md5(url.encode()).hexdigest()[:16]

    def _reconstruct_page_url(self, base_url: str, relative_path: Path) -> str:
        """
        Reconstruct the original URL for a downloaded page.

        Only the scheme and host of ``base_url`` are used; the URL path is built
        from ``relative_path``. A trailing ``index.html`` is dropped and any
        other ``.html`` extension is stripped.
        """
        parsed = urlparse(base_url)

        # Convert relative path back to URL path
        path_parts = list(relative_path.parts)

        if path_parts:
            last_part = path_parts[-1]
            if last_part == "index.html":
                # Handle index.html as directory
                path_parts.pop()
            elif last_part.endswith(".html"):
                # Remove .html extension from other files
                path_parts[-1] = last_part[: -len(".html")]

        url_path = "/".join(path_parts)
        return f"{parsed.scheme}://{parsed.netloc}/{url_path}"

    def _get_timestamp(self) -> str:
        """Get current UTC timestamp in ISO format (timezone-aware, ``+00:00``)."""
        return datetime.now(timezone.utc).isoformat()

    async def add_document(
        self,
        content: bytes,
        filename: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Add a document to the knowledge base.

        Note: For websites, prefer using download_and_ingest_website() instead.

        Args:
            content: Raw document content
            filename: Original filename
            metadata: Optional metadata

        Returns:
            Dictionary with processing results: ``chunks`` (number of chunks
            stored) and ``document_id``

        Raises:
            RuntimeError: If the system has not been initialized.
        """
        self._ensure_initialized()

        # Process document
        doc_info = await self._document_processor.process(
            content=content,
            filename=filename,
            metadata=metadata,
        )

        chunks = doc_info.get("chunks") or []

        # Store chunks in vector store with pointers
        if chunks:
            # Add pointer metadata (loop variable named so that it does not
            # shadow the ``metadata`` parameter)
            chunk_metadatas = doc_info.get("metadatas", [])
            for chunk_metadata in chunk_metadatas:
                chunk_metadata["pointer"] = {
                    "type": "uploaded_file",
                    "filename": filename,
                }

            await self._vector_store.add_chunks(
                chunks=chunks,
                metadatas=chunk_metadatas,
                ids=doc_info.get("ids", []),
            )

        log.info(f"Added document '{filename}' with {len(chunks)} chunks")
        return {"chunks": len(chunks), "document_id": doc_info.get("document_id")}

    async def query(
        self,
        query: str,
        top_k: int = 5,
        filter_metadata: Optional[dict] = None,
        include_pointers: bool = True,
    ) -> dict[str, Any]:
        """
        Query the knowledge base for relevant context.

        Args:
            query: Query string
            top_k: Number of results to return
            filter_metadata: Optional metadata filters
            include_pointers: Whether to include page pointers in results

        Returns:
            Dictionary with ``context`` (numbered chunks joined by blank lines),
            ``sources`` (unique, in retrieval order), ``pointers``,
            ``num_results`` and the raw ``results``

        Raises:
            RuntimeError: If the system has not been initialized.
        """
        self._ensure_initialized()

        # Retrieve relevant chunks
        results = await self._retriever.retrieve(
            query=query,
            top_k=top_k,
            filter_metadata=filter_metadata,
        )

        # Build context string and collect pointers
        context_parts = []
        sources = []
        pointers = []

        for position, result in enumerate(results, start=1):
            context_parts.append(f"[{position}] {result['content']}")

            metadata = result.get("metadata") or {}

            # Collect source info: most specific key wins
            source = (
                metadata.get("page_url")
                or metadata.get("source_url")
                or metadata.get("source")
            )
            if source:
                sources.append(source)

            # Collect pointer info
            if include_pointers and metadata.get("pointer"):
                # Copy before annotating so the retrieved metadata (which is also
                # returned under "results") is not mutated.
                pointer = dict(metadata["pointer"])
                pointer["chunk_id"] = result.get("id")
                pointer["score"] = result.get("score")
                pointers.append(pointer)

        return {
            "context": "\n\n".join(context_parts),
            # dict.fromkeys de-duplicates while keeping retrieval order
            "sources": list(dict.fromkeys(sources)),
            "pointers": pointers,
            "num_results": len(results),
            "results": results,
        }

    async def list_documents(self) -> list[dict[str, Any]]:
        """
        List all documents in the knowledge base.

        Documents whose ``source_url`` matches a registered site get that
        site's registry entry attached under ``site_info``.

        Raises:
            RuntimeError: If the system has not been initialized.
        """
        self._ensure_initialized()

        # Get documents from vector store
        docs = await self._vector_store.list_documents()

        # Enrich with site registry info
        for doc in docs:
            source_url = doc.get("source_url")
            if not source_url:
                continue
            site_info = self._site_registry.get(self._generate_site_id(source_url))
            if site_info is not None:
                doc["site_info"] = site_info

        return docs

    async def list_downloaded_sites(self) -> list[dict[str, Any]]:
        """List all downloaded websites."""
        return list(self._site_registry.values())

    async def delete_document(self, document_id: str) -> None:
        """
        Delete a document from the knowledge base.

        Raises:
            RuntimeError: If the system has not been initialized.
        """
        self._ensure_initialized()
        await self._vector_store.delete_document(document_id)
        log.info(f"Deleted document {document_id}")

    async def delete_site(self, url: str) -> dict[str, Any]:
        """
        Delete a downloaded website and all its content from the knowledge base.

        Args:
            url: URL of the site to delete

        Returns:
            Dictionary with deletion results. ``success=False`` and ``message``
            when the site is not registered; otherwise ``success=True``, ``url``,
            ``deleted_chunks`` and ``deleted_path``.

        Raises:
            RuntimeError: If the system has not been initialized.
        """
        self._ensure_initialized()

        site_id = self._generate_site_id(url)
        site_info = self._site_registry.get(site_id)
        if site_info is None:
            return {"success": False, "message": f"Site not found: {url}"}

        local_path = site_info.get("local_path")

        # Delete from vector store
        deleted_chunks = await self._vector_store.delete_by_source_url(url)

        # Delete local files
        if local_path and Path(local_path).exists():
            shutil.rmtree(local_path)

        # Remove from registry
        del self._site_registry[site_id]
        await self._save_site_registry()

        log.info(f"Deleted site: {url}")

        return {
            "success": True,
            "url": url,
            "deleted_chunks": deleted_chunks,
            "deleted_path": local_path,
        }

    def get_site_info(self, url: str) -> Optional[dict[str, Any]]:
        """Get information about a downloaded site, or None if it is unknown."""
        return self._site_registry.get(self._generate_site_id(url))
|
|
|
|
|
|
# Global RAG system instance
_rag_system: Optional[RAGSystem] = None

# Serializes creation/initialization of the global instance. Created lazily
# inside get_rag_system() so that no event loop is needed at import time.
_rag_system_lock: Optional[asyncio.Lock] = None


async def get_rag_system(
    embedding_model: str = "text-embedding-3-small",
    vector_store_path: str = "./data/vectors",
    documents_path: str = "./data/documents",
    downloaded_sites_path: str = "./data/downloaded_sites",
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
) -> RAGSystem:
    """
    Get or create the global RAG system instance.

    The arguments are only used when the global instance is first created;
    later calls return the existing instance and ignore them. If the existing
    instance has been closed, it is initialized again before being returned.

    Args:
        embedding_model: Name of the embedding model
        vector_store_path: Path to vector store
        documents_path: Path to document storage
        downloaded_sites_path: Path to downloaded websites
        chunk_size: Size of document chunks
        chunk_overlap: Overlap between chunks

    Returns:
        Initialized RAGSystem instance
    """
    global _rag_system, _rag_system_lock

    # Fast path: the instance exists and is ready to use.
    if _rag_system is not None and _rag_system._initialized:
        return _rag_system

    # There is no await between the check and the assignment, so only one lock
    # is ever created per event loop.
    if _rag_system_lock is None:
        _rag_system_lock = asyncio.Lock()

    async with _rag_system_lock:
        if _rag_system is None:
            instance = RAGSystem(
                embedding_model=embedding_model,
                vector_store_path=vector_store_path,
                documents_path=documents_path,
                downloaded_sites_path=downloaded_sites_path,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            )
            await instance.initialize()
            # Publish only after initialization succeeded, so concurrent callers
            # never receive a half-initialized instance and a failed start-up is
            # retried on the next call.
            _rag_system = instance
        else:
            # No-op when already initialized; re-initializes a closed instance.
            await _rag_system.initialize()

    return _rag_system
|