docrag/rag/__init__.py
Z User 6aecc4b231 Integrate website_downloader_tool into RAG system
Features:
- RAG system now uses website_downloader_tool as primary content ingestion method
- download_and_ingest_website() method for complete website processing
- Stores page pointers (source_url, page_url, local_path) in vector store
- Site registry tracks all downloaded websites with metadata
- New API endpoints for website management:
  - POST /v1/documents/website - Download and ingest a website
  - GET /v1/documents/sites - List all downloaded sites
  - GET /v1/documents/sites/{url} - Get site info
  - DELETE /v1/documents/sites/{url} - Delete a site and its content

Changes:
- rag/__init__.py: Added download_and_ingest_website(), site registry
- rag/document_processor.py: Added extract_text_from_html() public method
- rag/vector_store.py: Added delete_by_source_url(), get_stats()
- main.py: New website endpoints, integrated tool with RAG system
2026-03-29 02:36:59 +00:00

560 lines
19 KiB
Python

"""
RAG System - Retrieval Augmented Generation
This module provides the core RAG functionality for DocRAG, including:
- Website downloading and ingestion via website_downloader_tool
- Document processing and chunking
- Vector storage and similarity search
- Context retrieval for enhanced prompts
"""
from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import os
import shutil
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional
from urllib.parse import urlparse

# Import the website downloader tool
from website_downloader_tool import website_downloader

from .document_processor import DocumentProcessor
from .vector_store import VectorStore
from .retriever import Retriever
# Module-level logger, named after this module, used by the code below.
log = logging.getLogger(__name__)
class RAGSystem:
    """
    Main RAG system that coordinates website downloading, document processing,
    storage, and retrieval.

    This class provides a unified interface for:
    - Downloading websites using website_downloader_tool
    - Processing downloaded content into the knowledge base
    - Querying for relevant context
    - Managing the document lifecycle

    Call ``await initialize()`` before using any method that touches the
    vector store; those methods raise ``RuntimeError`` otherwise.
    """

    # Name of the JSON file (inside ``downloaded_sites_path``) holding the registry.
    _REGISTRY_FILENAME = "site_registry.json"

    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        vector_store_path: str = "./data/vectors",
        documents_path: str = "./data/documents",
        downloaded_sites_path: str = "./data/downloaded_sites",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """
        Store the configuration; no I/O happens until ``initialize()``.

        Args:
            embedding_model: Name of the embedding model passed to the vector store
            vector_store_path: Directory used to persist the vector store
            documents_path: Directory for document storage
            downloaded_sites_path: Directory for downloaded websites and the registry
            chunk_size: Chunk size passed to the document processor
            chunk_overlap: Chunk overlap passed to the document processor
        """
        self.embedding_model = embedding_model
        self.vector_store_path = Path(vector_store_path)
        self.documents_path = Path(documents_path)
        self.downloaded_sites_path = Path(downloaded_sites_path)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self._initialized = False
        self._document_processor: Optional[DocumentProcessor] = None
        self._vector_store: Optional[VectorStore] = None
        self._retriever: Optional[Retriever] = None
        # Track downloaded sites with their source URLs (keyed by site id)
        self._site_registry: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        """Initialize the RAG system components (no-op if already initialized)."""
        if self._initialized:
            return
        log.info("Initializing RAG system...")
        # Create directories
        self.vector_store_path.mkdir(parents=True, exist_ok=True)
        self.documents_path.mkdir(parents=True, exist_ok=True)
        self.downloaded_sites_path.mkdir(parents=True, exist_ok=True)
        # Initialize document processor
        self._document_processor = DocumentProcessor(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        # Initialize vector store
        self._vector_store = VectorStore(
            persist_directory=str(self.vector_store_path),
            embedding_model=self.embedding_model,
        )
        await self._vector_store.initialize()
        # Initialize retriever
        self._retriever = Retriever(
            vector_store=self._vector_store,
        )
        # Load existing site registry
        await self._load_site_registry()
        self._initialized = True
        log.info("RAG system initialized successfully")

    async def close(self) -> None:
        """Close the RAG system and release resources."""
        if self._vector_store:
            await self._vector_store.close()
        # Persist only when initialize() completed: before that the in-memory
        # registry is empty and saving it could overwrite a populated file.
        if self._initialized:
            await self._save_site_registry()
        self._initialized = False
        log.info("RAG system closed")

    def _ensure_initialized(self) -> None:
        """
        Ensure the RAG system is initialized.

        Raises:
            RuntimeError: If ``initialize()`` has not completed.
        """
        if not self._initialized:
            raise RuntimeError("RAG system not initialized. Call initialize() first.")

    async def _load_site_registry(self) -> None:
        """
        Load the site registry from disk.

        A missing file leaves the registry untouched; an unreadable, malformed
        or non-object JSON file resets the registry to empty and logs a warning.
        """
        registry_file = self.downloaded_sites_path / self._REGISTRY_FILENAME
        if not registry_file.exists():
            return
        try:
            # Explicit UTF-8: the file is written with ensure_ascii=False.
            with open(registry_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            log.warning(f"Failed to load site registry: {e}")
            self._site_registry = {}
            return
        if not isinstance(loaded, dict):
            log.warning(
                f"Failed to load site registry: expected a JSON object, "
                f"got {type(loaded).__name__}"
            )
            self._site_registry = {}
            return
        self._site_registry = loaded
        log.info(f"Loaded site registry with {len(self._site_registry)} sites")

    async def _save_site_registry(self) -> None:
        """
        Save the site registry to disk.

        Best effort: failures are logged, not raised. The data is written to a
        temporary sibling file first and then moved into place, so an
        interrupted write does not leave a truncated registry behind.
        """
        registry_file = self.downloaded_sites_path / self._REGISTRY_FILENAME
        tmp_file = registry_file.with_name(registry_file.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(self._site_registry, f, indent=2, ensure_ascii=False)
            tmp_file.replace(registry_file)
            log.info(f"Saved site registry with {len(self._site_registry)} sites")
        except (OSError, TypeError, ValueError) as e:
            log.error(f"Failed to save site registry: {e}")

    async def download_and_ingest_website(
        self,
        url: str,
        max_pages: int = 50,
        threads: int = 6,
        download_external_assets: bool = False,
        external_domains: Optional[list[str]] = None,
    ) -> dict[str, Any]:
        """
        Download a website using website_downloader_tool and ingest all content
        into the knowledge base.

        This is the PRIMARY method for adding content to the RAG system.

        Args:
            url: URL of the website to download (must include scheme and host)
            max_pages: Maximum number of pages to crawl
            threads: Number of concurrent download threads
            download_external_assets: Whether to download external assets
            external_domains: List of external domains to allow

        Returns:
            Dictionary with download and ingestion results. On failure it
            contains ``success=False``, ``message`` and ``url``. On success it
            contains ``success=True``, ``url``, ``local_path``,
            ``pages_processed``, ``total_chunks``, ``stats`` and ``errors``
            (per-file ingestion failures).

        Raises:
            RuntimeError: If the system is not initialized.
        """
        self._ensure_initialized()

        # Without a host the site folder name would be empty and the download
        # would land directly in the sites root (next to the registry file).
        parsed = urlparse(url)
        if not parsed.scheme or not parsed.netloc:
            message = f"Invalid URL (scheme and host required): {url}"
            log.error(f"Website download failed: {message}")
            return {"success": False, "message": message, "url": url}

        log.info(f"Downloading website: {url}")
        # NOTE(review): the folder is derived from the host only, so two URLs on
        # the same host share one directory - confirm this is intended.
        destination = self.downloaded_sites_path / self._get_site_folder(url)

        # website_downloader is a synchronous call; run it in the default
        # executor so the event loop stays responsive while the crawl runs.
        loop = asyncio.get_running_loop()
        download_result = await loop.run_in_executor(
            None,
            lambda: website_downloader(
                url=url,
                destination=str(destination),
                max_pages=max_pages,
                threads=threads,
                download_external_assets=download_external_assets,
                external_domains=external_domains,
            ),
        )
        if not download_result.get("success"):
            log.error(f"Website download failed: {download_result.get('message')}")
            return {
                "success": False,
                "message": download_result.get("message", "Download failed"),
                "url": url,
            }

        # Fall back to the requested destination: an empty value would become
        # Path("") and make ingestion crawl the current working directory.
        output_dir = download_result.get("output_directory") or str(destination)
        stats = download_result.get("stats") or {}
        log.info(f"Website downloaded to: {output_dir}")

        # Process all HTML files from the downloaded site
        ingestion_result = await self._ingest_downloaded_site(
            site_path=Path(output_dir),
            source_url=url,
        )

        # Register the site
        site_id = self._generate_site_id(url)
        self._site_registry[site_id] = {
            "url": url,
            "local_path": output_dir,
            "pages_downloaded": stats.get("pages_crawled", 0),
            "assets_downloaded": stats.get("assets_downloaded", 0),
            "chunks_ingested": ingestion_result.get("total_chunks", 0),
            "timestamp": self._get_timestamp(),
        }
        await self._save_site_registry()
        return {
            "success": True,
            "url": url,
            "local_path": output_dir,
            "pages_processed": ingestion_result.get("pages_processed", 0),
            "total_chunks": ingestion_result.get("total_chunks", 0),
            "stats": stats,
            "errors": ingestion_result.get("errors", []),
        }

    async def _ingest_downloaded_site(
        self,
        site_path: Path,
        source_url: str,
    ) -> dict[str, Any]:
        """
        Ingest all HTML files from a downloaded website into the knowledge base.

        Each file is handled independently: a failure is recorded in the
        returned ``errors`` list and the remaining files are still processed.

        Args:
            site_path: Path to the downloaded website directory
            source_url: Original URL of the website

        Returns:
            Dictionary with ``pages_processed``, ``total_chunks`` and ``errors``
        """
        pages_processed = 0
        total_chunks = 0
        errors: list[str] = []

        # Find all HTML files; sorted so ingestion order is reproducible.
        html_files = sorted(p for p in site_path.rglob("*.html") if p.is_file())
        log.info(f"Found {len(html_files)} HTML files in {site_path}")

        for html_file in html_files:
            try:
                # Read the HTML file
                content = html_file.read_bytes()
                # Calculate relative path for the page pointer
                relative_path = html_file.relative_to(site_path)
                page_url = self._reconstruct_page_url(source_url, relative_path)

                # Skip pages that contain no extractable text
                text_content = await self._document_processor.extract_text_from_html(content)
                if not text_content.strip():
                    continue

                # Process into chunks
                doc_info = await self._document_processor.process(
                    content=content,
                    filename=str(html_file),
                    metadata={
                        "source_url": source_url,
                        "page_url": page_url,
                        "local_path": str(html_file),
                        "relative_path": str(relative_path),
                        "source_type": "downloaded_website",
                    },
                )

                chunks = doc_info.get("chunks")
                if not chunks:
                    continue

                # Add source pointer to each chunk's metadata
                metadatas = doc_info.get("metadatas", [])
                for chunk_metadata in metadatas:
                    chunk_metadata["source_url"] = source_url
                    chunk_metadata["page_url"] = page_url
                    chunk_metadata["local_path"] = str(html_file)
                    chunk_metadata["pointer"] = {
                        "type": "downloaded_page",
                        "url": page_url,
                        "local_file": str(html_file),
                    }
                await self._vector_store.add_chunks(
                    chunks=chunks,
                    metadatas=metadatas,
                    ids=doc_info.get("ids", []),
                )
                total_chunks += len(chunks)
                pages_processed += 1
                log.debug(f"Ingested: {relative_path} -> {len(chunks)} chunks")
            except Exception as e:
                # Per-file boundary: one bad page must not abort the whole site.
                errors.append(f"{html_file}: {e}")
                log.warning(f"Failed to process {html_file}: {e}")

        log.info(f"Ingestion complete: {pages_processed} pages, {total_chunks} chunks")
        return {
            "pages_processed": pages_processed,
            "total_chunks": total_chunks,
            "errors": errors,
        }

    def _get_site_folder(self, url: str) -> str:
        """Generate a folder name for a site from its URL's host[:port]."""
        parsed = urlparse(url)
        # Use domain name as folder, replace dots and the port colon with underscores
        return parsed.netloc.replace(".", "_").replace(":", "_")

    def _generate_site_id(self, url: str) -> str:
        """
        Generate the registry key for a site.

        The key is the first 16 hex characters of the MD5 digest of the exact
        URL string, so URLs that differ in any character get different ids.
        """
        return hashlib.md5(url.encode()).hexdigest()[:16]

    def _reconstruct_page_url(self, base_url: str, relative_path: Path) -> str:
        """
        Reconstruct the original URL for a downloaded page.

        Args:
            base_url: URL the site was downloaded from (scheme and host are used)
            relative_path: Path of the HTML file relative to the site directory

        Returns:
            ``scheme://host/<path>`` where a trailing ``index.html`` is dropped
            and a trailing ``.html`` extension is removed.
        """
        parsed = urlparse(base_url)
        # Convert relative path back to URL path
        path_parts = list(relative_path.parts)
        if path_parts:
            last = path_parts[-1]
            if last == "index.html":
                # Handle index.html as directory
                path_parts.pop()
            elif last.endswith(".html"):
                # Remove .html extension from other files
                path_parts[-1] = last[: -len(".html")]
        url_path = "/".join(path_parts)
        return f"{parsed.scheme}://{parsed.netloc}/{url_path}"

    def _get_timestamp(self) -> str:
        """Get the current UTC time as a naive ISO-8601 string."""
        # datetime.utcnow() is deprecated; dropping tzinfo keeps the stored
        # format (no "+00:00" suffix) identical to existing registry entries.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    async def add_document(
        self,
        content: bytes,
        filename: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Add a document to the knowledge base.

        Note: For websites, prefer using download_and_ingest_website() instead.

        Args:
            content: Raw document content
            filename: Original filename
            metadata: Optional metadata

        Returns:
            Dictionary with ``chunks`` (number of chunks produced) and
            ``document_id``

        Raises:
            RuntimeError: If the system is not initialized.
        """
        self._ensure_initialized()
        # Process document
        doc_info = await self._document_processor.process(
            content=content,
            filename=filename,
            metadata=metadata,
        )
        chunks = doc_info.get("chunks")
        # Store chunks in vector store with pointers
        if chunks:
            metadatas = doc_info.get("metadatas", [])
            # Add pointer metadata (loop name chosen so it does not shadow
            # the ``metadata`` parameter)
            for chunk_metadata in metadatas:
                chunk_metadata["pointer"] = {
                    "type": "uploaded_file",
                    "filename": filename,
                }
            await self._vector_store.add_chunks(
                chunks=chunks,
                metadatas=metadatas,
                ids=doc_info.get("ids", []),
            )
        chunk_count = len(doc_info.get("chunks", []))
        log.info(f"Added document '{filename}' with {chunk_count} chunks")
        return {"chunks": chunk_count, "document_id": doc_info.get("document_id")}

    async def query(
        self,
        query: str,
        top_k: int = 5,
        filter_metadata: Optional[dict] = None,
        include_pointers: bool = True,
    ) -> dict[str, Any]:
        """
        Query the knowledge base for relevant context.

        Args:
            query: Query string
            top_k: Number of results to return
            filter_metadata: Optional metadata filters
            include_pointers: Whether to include page pointers in results

        Returns:
            Dictionary with ``context`` (numbered chunks joined by blank lines),
            ``sources`` (unique sources in retrieval order), ``pointers``,
            ``num_results`` and the raw ``results``

        Raises:
            RuntimeError: If the system is not initialized.
        """
        self._ensure_initialized()
        # Retrieve relevant chunks
        results = await self._retriever.retrieve(
            query=query,
            top_k=top_k,
            filter_metadata=filter_metadata,
        )

        # Build context string and collect pointers
        context_parts = []
        sources = []
        pointers = []
        for position, result in enumerate(results, start=1):
            context_parts.append(f"[{position}] {result['content']}")
            metadata = result.get("metadata") or {}

            # Collect source info: most specific location first
            source = (
                metadata.get("page_url")
                or metadata.get("source_url")
                or metadata.get("source")
            )
            if source:
                sources.append(source)

            # Collect pointer info on a copy so the metadata inside ``results``
            # (also returned to the caller) is not modified.
            pointer = metadata.get("pointer")
            if include_pointers and pointer and isinstance(pointer, dict):
                enriched = dict(pointer)
                enriched["chunk_id"] = result.get("id")
                enriched["score"] = result.get("score")
                pointers.append(enriched)

        return {
            "context": "\n\n".join(context_parts),
            # dict.fromkeys de-duplicates while keeping retrieval order.
            "sources": list(dict.fromkeys(sources)),
            "pointers": pointers,
            "num_results": len(results),
            "results": results,
        }

    async def list_documents(self) -> list[dict[str, Any]]:
        """
        List all documents in the knowledge base.

        Documents whose ``source_url`` matches a registered site get that
        site's registry entry attached under ``site_info``.

        Raises:
            RuntimeError: If the system is not initialized.
        """
        self._ensure_initialized()
        # Get documents from vector store
        docs = await self._vector_store.list_documents()
        # Enrich with site registry info
        for doc in docs:
            source_url = doc.get("source_url")
            if not source_url:
                continue
            site_info = self._site_registry.get(self._generate_site_id(source_url))
            if site_info is not None:
                doc["site_info"] = site_info
        return docs

    async def list_downloaded_sites(self) -> list[dict[str, Any]]:
        """List the registry entries of all downloaded websites."""
        return list(self._site_registry.values())

    async def delete_document(self, document_id: str) -> None:
        """
        Delete a document from the knowledge base.

        Raises:
            RuntimeError: If the system is not initialized.
        """
        self._ensure_initialized()
        await self._vector_store.delete_document(document_id)
        log.info(f"Deleted document {document_id}")

    async def delete_site(self, url: str) -> dict[str, Any]:
        """
        Delete a downloaded website and all its content from the knowledge base.

        Args:
            url: URL of the site to delete (must match the URL used to download it)

        Returns:
            Dictionary with deletion results: ``success``, and either
            ``message`` (site unknown) or ``url``, ``deleted_chunks`` and
            ``deleted_path`` (``None`` when removing the directory was refused)

        Raises:
            RuntimeError: If the system is not initialized.
        """
        self._ensure_initialized()
        site_id = self._generate_site_id(url)
        site_info = self._site_registry.get(site_id)
        if site_info is None:
            return {"success": False, "message": f"Site not found: {url}"}
        local_path = site_info.get("local_path")
        deleted_path = local_path

        # Delete from vector store
        deleted_chunks = await self._vector_store.delete_by_source_url(url)

        # Delete local files
        if local_path and Path(local_path).exists():
            sites_root = self.downloaded_sites_path.resolve()
            target = Path(local_path).resolve()
            if target == sites_root or target in sites_root.parents:
                # Removing the root (or an ancestor) would wipe every other
                # site and the registry file, so leave the files in place.
                log.warning(f"Refusing to remove {local_path}: it contains the sites root")
                deleted_path = None
            else:
                shutil.rmtree(local_path)

        # Remove from registry
        del self._site_registry[site_id]
        await self._save_site_registry()
        log.info(f"Deleted site: {url}")
        return {
            "success": True,
            "url": url,
            "deleted_chunks": deleted_chunks,
            "deleted_path": deleted_path,
        }

    def get_site_info(self, url: str) -> Optional[dict[str, Any]]:
        """Get the registry entry of a downloaded site, or ``None`` if unknown."""
        site_id = self._generate_site_id(url)
        return self._site_registry.get(site_id)
# Global RAG system instance
_rag_system: Optional[RAGSystem] = None


async def get_rag_system(
    embedding_model: str = "text-embedding-3-small",
    vector_store_path: str = "./data/vectors",
    documents_path: str = "./data/documents",
    downloaded_sites_path: str = "./data/downloaded_sites",
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
) -> RAGSystem:
    """
    Get or create the global RAG system instance.

    The arguments are only used when the instance is first created; later
    calls return the existing instance and ignore them.

    Args:
        embedding_model: Name of the embedding model
        vector_store_path: Path to vector store
        documents_path: Path to document storage
        downloaded_sites_path: Path to downloaded websites
        chunk_size: Size of document chunks
        chunk_overlap: Overlap between chunks

    Returns:
        Initialized RAGSystem instance
    """
    global _rag_system
    if _rag_system is None:
        _rag_system = RAGSystem(
            embedding_model=embedding_model,
            vector_store_path=vector_store_path,
            documents_path=documents_path,
            downloaded_sites_path=downloaded_sites_path,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    # initialize() returns immediately once it has succeeded, so awaiting it on
    # every call is cheap. It also means a call made after a failed first
    # initialization retries instead of handing out an unusable instance.
    await _rag_system.initialize()
    return _rag_system