Integrate website_downloader_tool into RAG system
Features:
- RAG system now uses website_downloader_tool as primary content ingestion method
- download_and_ingest_website() method for complete website processing
- Stores page pointers (source_url, page_url, local_path) in vector store
- Site registry tracks all downloaded websites with metadata
- New API endpoints for website management:
  - POST /v1/documents/website - Download and ingest a website
  - GET /v1/documents/sites - List all downloaded sites
  - GET /v1/documents/sites/{url} - Get site info
  - DELETE /v1/documents/sites/{url} - Delete a site and its content
Changes:
- rag/__init__.py: Added download_and_ingest_website(), site registry
- rag/document_processor.py: Added extract_text_from_html() public method
- rag/vector_store.py: Added delete_by_source_url(), get_stats()
- main.py: New website endpoints, integrated tool with RAG system
This commit is contained in:
parent
eabdadfb62
commit
6aecc4b231
178
main.py
178
main.py
@ -581,27 +581,36 @@ async def check_and_execute_tools(
|
||||
# In a production system, you'd use the LLM to decide tool usage
|
||||
message_lower = user_message.lower()
|
||||
|
||||
# Check for website download intent
|
||||
if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site"]):
|
||||
# Check for website download intent - use RAG system for full integration
|
||||
if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site", "ingest site"]):
|
||||
# Extract URL from message
|
||||
import re
|
||||
url_pattern = r'https?://[^\s]+'
|
||||
urls = re.findall(url_pattern, user_message)
|
||||
|
||||
if urls:
|
||||
tool_result = state.tool_manager.execute_tool(
|
||||
"website_downloader",
|
||||
{"url": urls[0], "max_pages": 10}
|
||||
)
|
||||
tool_calls.append({
|
||||
"id": f"call_{uuid.uuid4().hex[:24]}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "website_downloader",
|
||||
"arguments": json.dumps({"url": urls[0]}),
|
||||
}
|
||||
})
|
||||
log.info(f"Executed website_downloader tool: {tool_result}")
|
||||
if urls and state.rag_system:
|
||||
# Use RAG system's integrated website downloader
|
||||
try:
|
||||
result = await state.rag_system.download_and_ingest_website(
|
||||
url=urls[0],
|
||||
max_pages=20, # Reasonable default
|
||||
)
|
||||
|
||||
tool_calls.append({
|
||||
"id": f"call_{uuid.uuid4().hex[:24]}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "website_downloader",
|
||||
"arguments": json.dumps({
|
||||
"url": urls[0],
|
||||
"success": result.get("success"),
|
||||
"chunks_ingested": result.get("total_chunks", 0),
|
||||
}),
|
||||
}
|
||||
})
|
||||
log.info(f"Downloaded and ingested website: {urls[0]} -> {result.get('total_chunks', 0)} chunks")
|
||||
except Exception as e:
|
||||
log.error(f"Website download failed: {e}")
|
||||
|
||||
return tool_calls
|
||||
|
||||
@ -638,9 +647,67 @@ async def upload_document(request: Request):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class WebsiteDownloadRequest(BaseModel):
    """Request model for website download."""

    # Body of POST /v1/documents/website; every field is passed straight
    # through to RAGSystem.download_and_ingest_website().
    # NOTE(review): no range validation is declared here, so zero, negative or
    # very large max_pages/threads values reach the downloader unchanged.

    # URL of the website to download.
    url: str
    # Maximum number of pages to crawl.
    max_pages: int = 50
    # Number of concurrent download threads.
    threads: int = 6
    # Whether to download external assets.
    download_external_assets: bool = False
    # External domains to allow; None is forwarded as-is
    # (presumably "use the downloader's default" - confirm against the tool).
    external_domains: Optional[list[str]] = None
|
||||
|
||||
|
||||
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """
    Download a website and ingest it into the knowledge base.

    This is the PRIMARY way to add content to the RAG system.
    Uses the website_downloader_tool to download and process websites.
    """
    # Without a RAG system there is nothing to ingest into.
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")

        # Forward every request field unchanged to the RAG system.
        options = {
            "url": request.url,
            "max_pages": request.max_pages,
            "threads": request.threads,
            "download_external_assets": request.download_external_assets,
            "external_domains": request.external_domains,
        }
        result = await state.rag_system.download_and_ingest_website(**options)

        # A reported failure becomes a 500 carrying the tool's own message.
        if not result.get("success"):
            raise HTTPException(
                status_code=500,
                detail=result.get("message", "Website download failed"),
            )

        return {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": result.get("local_path"),
            "pages_processed": result.get("pages_processed", 0),
            "total_chunks": result.get("total_chunks", 0),
        }

    except HTTPException:
        # Already a well-formed HTTP error; do not rewrap it.
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/documents/url")
|
||||
async def add_document_from_url(request: dict):
|
||||
"""Add a document from URL to the knowledge base."""
|
||||
"""
|
||||
Add a document from URL to the knowledge base.
|
||||
|
||||
NOTE: For websites, prefer using /v1/documents/website instead
|
||||
as it downloads the entire site and provides better context.
|
||||
"""
|
||||
if not state.rag_system:
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
@ -670,6 +737,47 @@ async def list_documents():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """List all downloaded websites in the knowledge base."""
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        # The registry entries are returned unmodified under the "sites" key.
        return {"sites": await rag.list_downloaded_sites()}
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Get information about a specific downloaded site."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        # The site URL travels inside the request path, so percent-decode it.
        site_url = unquote(url)

        # Requests may omit the scheme; https is assumed in that case.
        has_scheme = site_url.startswith("http://") or site_url.startswith("https://")
        if not has_scheme:
            site_url = "https://" + site_url

        record = state.rag_system.get_site_info(site_url)
        if not record:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": record}
    except HTTPException:
        # Keep the 404 above intact instead of turning it into a 500.
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/v1/documents/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""Delete a document from the knowledge base."""
|
||||
@ -684,6 +792,40 @@ async def delete_document(doc_id: str):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website and all its content from the knowledge base."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        # The site URL travels inside the request path, so percent-decode it.
        decoded_url = unquote(url)

        # Requests may omit the scheme; https is assumed in that case.
        if not (decoded_url.startswith("http://") or decoded_url.startswith("https://")):
            decoded_url = "https://" + decoded_url

        result = await state.rag_system.delete_site(decoded_url)

        # An unsuccessful result is reported as 404 with the RAG system's message.
        if not result.get("success"):
            raise HTTPException(status_code=404, detail=result.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {decoded_url} deleted",
            "deleted_chunks": result.get("deleted_chunks", 0),
            "deleted_path": result.get("deleted_path"),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Health and Status Endpoints
|
||||
# =============================================================================
|
||||
@ -711,6 +853,8 @@ async def root():
|
||||
"chat": "/v1/chat/completions",
|
||||
"models": "/v1/models",
|
||||
"documents": "/v1/documents",
|
||||
"download_website": "/v1/documents/website",
|
||||
"list_sites": "/v1/documents/sites",
|
||||
"health": "/health",
|
||||
},
|
||||
}
|
||||
|
||||
375
rag/__init__.py
375
rag/__init__.py
@ -2,6 +2,7 @@
|
||||
RAG System - Retrieval Augmented Generation
|
||||
|
||||
This module provides the core RAG functionality for DocRAG, including:
|
||||
- Website downloading and ingestion via website_downloader_tool
|
||||
- Document processing and chunking
|
||||
- Vector storage and similarity search
|
||||
- Context retrieval for enhanced prompts
|
||||
@ -14,20 +15,26 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .document_processor import DocumentProcessor
|
||||
from .vector_store import VectorStore
|
||||
from .retriever import Retriever
|
||||
|
||||
# Import the website downloader tool
|
||||
from website_downloader_tool import website_downloader
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGSystem:
|
||||
"""
|
||||
Main RAG system that coordinates document processing, storage, and retrieval.
|
||||
Main RAG system that coordinates website downloading, document processing,
|
||||
storage, and retrieval.
|
||||
|
||||
This class provides a unified interface for:
|
||||
- Adding documents to the knowledge base
|
||||
- Downloading websites using website_downloader_tool
|
||||
- Processing downloaded content into the knowledge base
|
||||
- Querying for relevant context
|
||||
- Managing the document lifecycle
|
||||
"""
|
||||
@ -37,12 +44,14 @@ class RAGSystem:
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
vector_store_path: str = "./data/vectors",
|
||||
documents_path: str = "./data/documents",
|
||||
downloaded_sites_path: str = "./data/downloaded_sites",
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
):
|
||||
self.embedding_model = embedding_model
|
||||
self.vector_store_path = Path(vector_store_path)
|
||||
self.documents_path = Path(documents_path)
|
||||
self.downloaded_sites_path = Path(downloaded_sites_path)
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
@ -50,6 +59,9 @@ class RAGSystem:
|
||||
self._document_processor: Optional[DocumentProcessor] = None
|
||||
self._vector_store: Optional[VectorStore] = None
|
||||
self._retriever: Optional[Retriever] = None
|
||||
|
||||
# Track downloaded sites with their source URLs
|
||||
self._site_registry: dict[str, dict[str, Any]] = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the RAG system components."""
|
||||
@ -61,6 +73,7 @@ class RAGSystem:
|
||||
# Create directories
|
||||
self.vector_store_path.mkdir(parents=True, exist_ok=True)
|
||||
self.documents_path.mkdir(parents=True, exist_ok=True)
|
||||
self.downloaded_sites_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize document processor
|
||||
self._document_processor = DocumentProcessor(
|
||||
@ -73,12 +86,16 @@ class RAGSystem:
|
||||
persist_directory=str(self.vector_store_path),
|
||||
embedding_model=self.embedding_model,
|
||||
)
|
||||
await self._vector_store.initialize()
|
||||
|
||||
# Initialize retriever
|
||||
self._retriever = Retriever(
|
||||
vector_store=self._vector_store,
|
||||
)
|
||||
|
||||
# Load existing site registry
|
||||
await self._load_site_registry()
|
||||
|
||||
self._initialized = True
|
||||
log.info("RAG system initialized successfully")
|
||||
|
||||
@ -86,6 +103,7 @@ class RAGSystem:
|
||||
"""Close the RAG system and release resources."""
|
||||
if self._vector_store:
|
||||
await self._vector_store.close()
|
||||
await self._save_site_registry()
|
||||
self._initialized = False
|
||||
log.info("RAG system closed")
|
||||
|
||||
@ -94,6 +112,228 @@ class RAGSystem:
|
||||
if not self._initialized:
|
||||
raise RuntimeError("RAG system not initialized. Call initialize() first.")
|
||||
|
||||
async def _load_site_registry(self) -> None:
    """Load the site registry from disk.

    Reads ``site_registry.json`` from the downloaded-sites directory into
    ``self._site_registry``. A missing file leaves the registry untouched.
    An unreadable file, invalid JSON, or JSON whose top level is not an
    object is logged as a warning and resets the registry to an empty dict,
    so later lookups always operate on a dict.
    """
    import json

    registry_file = self.downloaded_sites_path / "site_registry.json"
    if not registry_file.exists():
        return

    try:
        with open(registry_file, "r") as f:
            loaded = json.load(f)
        # Valid JSON is not necessarily a registry: a list or scalar would
        # break every later `in` / `.get()` / `.values()` call.
        if not isinstance(loaded, dict):
            raise ValueError(
                f"expected a JSON object, got {type(loaded).__name__}"
            )
        self._site_registry = loaded
        log.info(f"Loaded site registry with {len(self._site_registry)} sites")
    except Exception as e:
        # Best effort: start with an empty registry rather than fail startup.
        log.warning(f"Failed to load site registry: {e}")
        self._site_registry = {}
|
||||
|
||||
async def _save_site_registry(self) -> None:
    """Save the site registry to disk.

    Writes ``self._site_registry`` as indented JSON to
    ``site_registry.json`` in the downloaded-sites directory. The data is
    first written to a sibling temporary file and then moved over the real
    file, so a failed or interrupted dump cannot leave a truncated registry
    behind. Failures are logged and swallowed (saving is best effort).
    """
    import json

    registry_file = self.downloaded_sites_path / "site_registry.json"
    temp_file = registry_file.with_name(registry_file.name + ".tmp")
    try:
        with open(temp_file, "w") as f:
            json.dump(self._site_registry, f, indent=2, ensure_ascii=False)
        # Path.replace overwrites the destination in one rename step.
        temp_file.replace(registry_file)
        log.info(f"Saved site registry with {len(self._site_registry)} sites")
    except Exception as e:
        log.error(f"Failed to save site registry: {e}")
        # Do not leave a half-written temporary file next to the registry.
        try:
            temp_file.unlink()
        except OSError:
            pass
|
||||
|
||||
async def download_and_ingest_website(
    self,
    url: str,
    max_pages: int = 50,
    threads: int = 6,
    download_external_assets: bool = False,
    external_domains: Optional[list[str]] = None,
) -> dict[str, Any]:
    """
    Download a website using website_downloader_tool and ingest all content
    into the knowledge base.

    This is the PRIMARY method for adding content to the RAG system.

    Args:
        url: URL of the website to download
        max_pages: Maximum number of pages to crawl
        threads: Number of concurrent download threads
        download_external_assets: Whether to download external assets
        external_domains: List of external domains to allow

    Returns:
        Dictionary with download and ingestion results. On failure it holds
        ``success=False``, ``message`` and ``url``; on success it holds
        ``success=True``, ``url``, ``local_path``, ``pages_processed``,
        ``total_chunks`` and the downloader's ``stats``.

    Raises:
        RuntimeError: If the RAG system has not been initialized.
    """
    import asyncio

    self._ensure_initialized()

    log.info(f"Downloading website: {url}")

    destination = str(self.downloaded_sites_path / self._get_site_folder(url))

    # website_downloader is a plain blocking call; running it in a worker
    # thread keeps the event loop responsive while the crawl is in progress.
    download_result = await asyncio.to_thread(
        website_downloader,
        url=url,
        destination=destination,
        max_pages=max_pages,
        threads=threads,
        download_external_assets=download_external_assets,
        external_domains=external_domains,
    )

    if not download_result.get("success"):
        log.error(f"Website download failed: {download_result.get('message')}")
        return {
            "success": False,
            "message": download_result.get("message", "Download failed"),
            "url": url,
        }

    # Fall back to the requested destination: an empty path would resolve to
    # the current working directory and ingest unrelated HTML files.
    output_dir = download_result.get("output_directory") or destination
    stats = download_result.get("stats", {})

    log.info(f"Website downloaded to: {output_dir}")

    # Process all HTML files from the downloaded site
    ingestion_result = await self._ingest_downloaded_site(
        site_path=Path(output_dir),
        source_url=url,
    )

    # Register the site so it can be listed, inspected and deleted later
    site_id = self._generate_site_id(url)
    self._site_registry[site_id] = {
        "url": url,
        "local_path": output_dir,
        "pages_downloaded": stats.get("pages_crawled", 0),
        "assets_downloaded": stats.get("assets_downloaded", 0),
        "chunks_ingested": ingestion_result.get("total_chunks", 0),
        "timestamp": self._get_timestamp(),
    }
    await self._save_site_registry()

    return {
        "success": True,
        "url": url,
        "local_path": output_dir,
        "pages_processed": ingestion_result.get("pages_processed", 0),
        "total_chunks": ingestion_result.get("total_chunks", 0),
        "stats": stats,
    }
|
||||
|
||||
async def _ingest_downloaded_site(
    self,
    site_path: Path,
    source_url: str,
) -> dict[str, Any]:
    """
    Ingest every ``*.html`` file below a downloaded site into the knowledge base.

    Each page is chunked by the document processor, every chunk is tagged
    with pointers back to its page, and the chunks are added to the vector
    store. A page that fails is recorded and does not stop the run.

    Args:
        site_path: Path to the downloaded website directory
        source_url: Original URL of the website

    Returns:
        Dictionary with ``pages_processed``, ``total_chunks`` and ``errors``
    """
    pages_processed = 0
    total_chunks = 0
    errors = []

    # Collect the pages up front so the count can be logged.
    html_files = list(site_path.rglob("*.html"))
    log.info(f"Found {len(html_files)} HTML files in {site_path}")

    for html_file in html_files:
        try:
            content = html_file.read_bytes()
            local_file = str(html_file)

            # Page pointer: where this file lived on the original site.
            relative_path = html_file.relative_to(site_path)
            page_url = self._reconstruct_page_url(source_url, relative_path)

            # Pages without any extractable text are skipped entirely.
            text_content = await self._document_processor.extract_text_from_html(content)
            if not text_content.strip():
                continue

            page_metadata = {
                "source_url": source_url,
                "page_url": page_url,
                "local_path": local_file,
                "relative_path": str(relative_path),
                "source_type": "downloaded_website",
            }
            doc_info = await self._document_processor.process(
                content=content,
                filename=local_file,
                metadata=page_metadata,
            )

            # Nothing to store when the processor produced no chunks.
            if not doc_info.get("chunks"):
                continue

            # Stamp each chunk with pointers back to its source page.
            for metadata in doc_info.get("metadatas", []):
                metadata.update(
                    {
                        "source_url": source_url,
                        "page_url": page_url,
                        "local_path": local_file,
                        "pointer": {
                            "type": "downloaded_page",
                            "url": page_url,
                            "local_file": local_file,
                        },
                    }
                )

            await self._vector_store.add_chunks(
                chunks=doc_info["chunks"],
                metadatas=doc_info.get("metadatas", []),
                ids=doc_info.get("ids", []),
            )

            total_chunks += len(doc_info["chunks"])
            pages_processed += 1

            log.debug(f"Ingested: {relative_path} -> {len(doc_info['chunks'])} chunks")

        except Exception as e:
            # One broken page must not abort the whole site.
            errors.append(f"{html_file}: {str(e)}")
            log.warning(f"Failed to process {html_file}: {e}")

    log.info(f"Ingestion complete: {pages_processed} pages, {total_chunks} chunks")

    return {
        "pages_processed": pages_processed,
        "total_chunks": total_chunks,
        "errors": errors,
    }
|
||||
|
||||
def _get_site_folder(self, url: str) -> str:
|
||||
"""Generate a folder name for a site from its URL."""
|
||||
parsed = urlparse(url)
|
||||
# Use domain name as folder, replace dots with underscores
|
||||
folder = parsed.netloc.replace(".", "_").replace(":", "_")
|
||||
return folder
|
||||
|
||||
def _generate_site_id(self, url: str) -> str:
|
||||
"""Generate a unique ID for a site."""
|
||||
import hashlib
|
||||
return hashlib.md5(url.encode()).hexdigest()[:16]
|
||||
|
||||
def _reconstruct_page_url(self, base_url: str, relative_path: Path) -> str:
|
||||
"""Reconstruct the original URL for a downloaded page."""
|
||||
parsed = urlparse(base_url)
|
||||
# Convert relative path back to URL path
|
||||
path_parts = list(relative_path.parts)
|
||||
|
||||
# Handle index.html as directory
|
||||
if path_parts and path_parts[-1] == "index.html":
|
||||
path_parts = path_parts[:-1]
|
||||
# Remove .html extension from other files
|
||||
elif path_parts and path_parts[-1].endswith(".html"):
|
||||
path_parts[-1] = path_parts[-1][:-5]
|
||||
|
||||
url_path = "/".join(path_parts)
|
||||
return f"{parsed.scheme}://{parsed.netloc}/{url_path}"
|
||||
|
||||
def _get_timestamp(self) -> str:
|
||||
"""Get current timestamp in ISO format."""
|
||||
from datetime import datetime
|
||||
return datetime.utcnow().isoformat()
|
||||
|
||||
async def add_document(
|
||||
self,
|
||||
content: bytes,
|
||||
@ -102,6 +342,8 @@ class RAGSystem:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Add a document to the knowledge base.
|
||||
|
||||
Note: For websites, prefer using download_and_ingest_website() instead.
|
||||
|
||||
Args:
|
||||
content: Raw document content
|
||||
@ -120,8 +362,15 @@ class RAGSystem:
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Store chunks in vector store
|
||||
# Store chunks in vector store with pointers
|
||||
if doc_info.get("chunks"):
|
||||
# Add pointer metadata
|
||||
for metadata in doc_info.get("metadatas", []):
|
||||
metadata["pointer"] = {
|
||||
"type": "uploaded_file",
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
await self._vector_store.add_chunks(
|
||||
chunks=doc_info["chunks"],
|
||||
metadatas=doc_info.get("metadatas", []),
|
||||
@ -131,37 +380,12 @@ class RAGSystem:
|
||||
log.info(f"Added document '{filename}' with {len(doc_info.get('chunks', []))} chunks")
|
||||
return {"chunks": len(doc_info.get("chunks", [])), "document_id": doc_info.get("document_id")}
|
||||
|
||||
async def add_document_from_url(self, url: str) -> dict[str, Any]:
    """
    Fetch a single URL and add its content to the knowledge base.

    Args:
        url: URL to fetch and process

    Returns:
        Dictionary with processing results, as returned by add_document()
    """
    self._ensure_initialized()

    import aiohttp
    from urllib.parse import urlparse

    # Download the raw body; HTTP error statuses raise before it is read.
    async with aiohttp.ClientSession() as session, session.get(url, timeout=30) as response:
        response.raise_for_status()
        content = await response.read()

    # Name the document after the last path segment of the URL, if it has one.
    filename = os.path.basename(urlparse(url).path)
    if not filename:
        filename = "webpage.html"

    return await self.add_document(
        content=content,
        filename=filename,
        metadata={"source_url": url},
    )
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filter_metadata: Optional[dict] = None,
|
||||
include_pointers: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Query the knowledge base for relevant context.
|
||||
@ -170,9 +394,10 @@ class RAGSystem:
|
||||
query: Query string
|
||||
top_k: Number of results to return
|
||||
filter_metadata: Optional metadata filters
|
||||
include_pointers: Whether to include page pointers in results
|
||||
|
||||
Returns:
|
||||
Dictionary with context and sources
|
||||
Dictionary with context, sources, and page pointers
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
@ -183,20 +408,37 @@ class RAGSystem:
|
||||
filter_metadata=filter_metadata,
|
||||
)
|
||||
|
||||
# Build context string
|
||||
# Build context string and collect pointers
|
||||
context_parts = []
|
||||
sources = []
|
||||
pointers = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
context_parts.append(f"[{i+1}] {result['content']}")
|
||||
if result.get("metadata", {}).get("source"):
|
||||
sources.append(result["metadata"]["source"])
|
||||
|
||||
metadata = result.get("metadata", {})
|
||||
|
||||
# Collect source info
|
||||
if metadata.get("page_url"):
|
||||
sources.append(metadata["page_url"])
|
||||
elif metadata.get("source_url"):
|
||||
sources.append(metadata["source_url"])
|
||||
elif metadata.get("source"):
|
||||
sources.append(metadata["source"])
|
||||
|
||||
# Collect pointer info
|
||||
if include_pointers and metadata.get("pointer"):
|
||||
pointer = metadata["pointer"]
|
||||
pointer["chunk_id"] = result.get("id")
|
||||
pointer["score"] = result.get("score")
|
||||
pointers.append(pointer)
|
||||
|
||||
context = "\n\n".join(context_parts)
|
||||
|
||||
return {
|
||||
"context": context,
|
||||
"sources": list(set(sources)),
|
||||
"pointers": pointers,
|
||||
"num_results": len(results),
|
||||
"results": results,
|
||||
}
|
||||
@ -204,7 +446,23 @@ class RAGSystem:
|
||||
async def list_documents(self) -> list[dict[str, Any]]:
    """List all documents in the knowledge base.

    Documents come from the vector store. Every document that carries a
    ``source_url`` belonging to a registered site additionally gets that
    site's registry entry under ``site_info``.

    Returns:
        The document dictionaries, enriched where a registry entry exists.

    Raises:
        RuntimeError: If the RAG system has not been initialized.
    """
    self._ensure_initialized()

    # Get documents from vector store (no early return here: the registry
    # enrichment below must actually run).
    docs = await self._vector_store.list_documents()

    # Enrich with site registry info
    for doc in docs:
        source_url = doc.get("source_url")
        if not source_url:
            continue
        site_info = self._site_registry.get(self._generate_site_id(source_url))
        if site_info is not None:
            doc["site_info"] = site_info

    return docs
|
||||
|
||||
async def list_downloaded_sites(self) -> list[dict[str, Any]]:
    """Return the registry entries of all downloaded websites as a new list."""
    registry_entries = self._site_registry.values()
    return [*registry_entries]
|
||||
|
||||
async def delete_document(self, document_id: str) -> None:
|
||||
"""Delete a document from the knowledge base."""
|
||||
@ -212,6 +470,52 @@ class RAGSystem:
|
||||
await self._vector_store.delete_document(document_id)
|
||||
log.info(f"Deleted document {document_id}")
|
||||
|
||||
async def delete_site(self, url: str) -> dict[str, Any]:
    """
    Delete a downloaded website and all its content from the knowledge base.

    Removes the site's chunks from the vector store, deletes its download
    folder and drops its registry entry. The folder is left in place when
    another registered site still points at it (folders are named after the
    host, so several URLs on one host share a folder) or when it is the
    downloaded-sites directory itself or one of its parents.

    Args:
        url: URL of the site to delete

    Returns:
        Dictionary with deletion results. ``success`` is False with a
        ``message`` when the URL is not registered. Otherwise it holds
        ``success``, ``url``, ``deleted_chunks`` and ``deleted_path``
        (None when the folder was deliberately kept).

    Raises:
        RuntimeError: If the RAG system has not been initialized.
    """
    import shutil

    self._ensure_initialized()

    site_id = self._generate_site_id(url)

    if site_id not in self._site_registry:
        return {"success": False, "message": f"Site not found: {url}"}

    site_info = self._site_registry[site_id]
    local_path = site_info.get("local_path")

    # Delete from vector store
    deleted_chunks = await self._vector_store.delete_by_source_url(url)

    # Delete local files
    deleted_path = local_path
    if local_path and Path(local_path).exists():
        site_dir = Path(local_path).resolve()
        sites_root = self.downloaded_sites_path.resolve()

        shared_with_other_site = any(
            other_id != site_id and other.get("local_path") == local_path
            for other_id, other in self._site_registry.items()
        )
        # Removing the root (or a parent of it) would wipe every site and
        # the registry file along with it.
        would_remove_root = site_dir == sites_root or site_dir in sites_root.parents

        if shared_with_other_site or would_remove_root:
            log.warning(f"Keeping directory {local_path} while deleting site: {url}")
            deleted_path = None
        else:
            shutil.rmtree(local_path)

    # Remove from registry
    del self._site_registry[site_id]
    await self._save_site_registry()

    log.info(f"Deleted site: {url}")

    return {
        "success": True,
        "url": url,
        "deleted_chunks": deleted_chunks,
        "deleted_path": deleted_path,
    }
|
||||
|
||||
def get_site_info(self, url: str) -> Optional[dict[str, Any]]:
    """Look up the registry entry of a downloaded site.

    Returns the entry stored under the ID derived from ``url``, or None when
    no site is registered under that ID.
    """
    registry_key = self._generate_site_id(url)
    entry = self._site_registry.get(registry_key)
    return entry
|
||||
|
||||
|
||||
# Global RAG system instance
|
||||
_rag_system: Optional[RAGSystem] = None
|
||||
@ -221,6 +525,7 @@ async def get_rag_system(
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
vector_store_path: str = "./data/vectors",
|
||||
documents_path: str = "./data/documents",
|
||||
downloaded_sites_path: str = "./data/downloaded_sites",
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
) -> RAGSystem:
|
||||
@ -231,6 +536,7 @@ async def get_rag_system(
|
||||
embedding_model: Name of the embedding model
|
||||
vector_store_path: Path to vector store
|
||||
documents_path: Path to document storage
|
||||
downloaded_sites_path: Path to downloaded websites
|
||||
chunk_size: Size of document chunks
|
||||
chunk_overlap: Overlap between chunks
|
||||
|
||||
@ -244,6 +550,7 @@ async def get_rag_system(
|
||||
embedding_model=embedding_model,
|
||||
vector_store_path=vector_store_path,
|
||||
documents_path=documents_path,
|
||||
downloaded_sites_path=downloaded_sites_path,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
|
||||
@ -167,6 +167,18 @@ class DocumentProcessor:
|
||||
log.error(f"HTML extraction failed: {e}")
|
||||
return ""
|
||||
|
||||
async def extract_text_from_html(self, content: bytes) -> str:
    """
    Extract text from raw HTML (public entry point).

    Delegates to the processor's internal HTML extraction.

    Args:
        content: Raw HTML content

    Returns:
        Extracted text content
    """
    extracted_text = await self._extract_html(content)
    return extracted_text
|
||||
|
||||
async def _extract_docx(self, content: bytes) -> str:
|
||||
"""Extract text from DOCX."""
|
||||
try:
|
||||
|
||||
@ -283,3 +283,55 @@ class VectorStore:
|
||||
await self._save()
|
||||
|
||||
log.info(f"Deleted document {document_id} ({len(indices_to_remove)} chunks)")
|
||||
|
||||
async def delete_by_source_url(self, source_url: str) -> int:
    """
    Delete all chunks from a specific source URL.

    A chunk matches when its metadata's ``source_url`` equals ``source_url``.
    The four parallel stores (chunks, embeddings, metadata, ids) are filtered
    in place in a single pass each, and the store is saved afterwards.

    Args:
        source_url: The source URL to delete

    Returns:
        Number of deleted chunks
    """
    self._ensure_initialized()

    # Positions to keep. Filtering once avoids shifting every list on each
    # removal, which popping index by index would do.
    keep = [
        i
        for i, metadata in enumerate(self._metadata)
        if metadata.get("source_url") != source_url
    ]
    removed = len(self._metadata) - len(keep)

    if removed:
        # Slice assignment mutates the existing lists, so any other holder
        # of these list objects sees the change too.
        for store in (self._chunks, self._embeddings, self._metadata, self._ids):
            store[:] = [store[i] for i in keep]

    # Save changes
    await self._save()

    log.info(f"Deleted {removed} chunks from source: {source_url}")
    return removed
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
    """Summarize the vector store.

    Returns:
        Dictionary with the chunk count, the number of distinct non-empty
        ``source`` and ``source_url`` metadata values, and the length of the
        first embedding (0 when there are no embeddings).
    """
    self._ensure_initialized()

    # Distinct, non-empty values per metadata key.
    distinct_sources = {
        meta.get("source") for meta in self._metadata if meta.get("source")
    }
    distinct_urls = {
        meta.get("source_url") for meta in self._metadata if meta.get("source_url")
    }

    # The first embedding's length stands in for the store's dimension.
    if self._embeddings:
        dimension = len(self._embeddings[0])
    else:
        dimension = 0

    return {
        "total_chunks": len(self._chunks),
        "unique_sources": len(distinct_sources),
        "unique_urls": len(distinct_urls),
        "embedding_dimension": dimension,
    }
|
||||
|
||||
Loading…
Reference in New Issue
Block a user