diff --git a/main.py b/main.py index 3aff528..7d26e05 100644 --- a/main.py +++ b/main.py @@ -581,27 +581,36 @@ async def check_and_execute_tools( # In a production system, you'd use the LLM to decide tool usage message_lower = user_message.lower() - # Check for website download intent - if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site"]): + # Check for website download intent - use RAG system for full integration + if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site", "ingest site"]): # Extract URL from message import re url_pattern = r'https?://[^\s]+' urls = re.findall(url_pattern, user_message) - if urls: - tool_result = state.tool_manager.execute_tool( - "website_downloader", - {"url": urls[0], "max_pages": 10} - ) - tool_calls.append({ - "id": f"call_{uuid.uuid4().hex[:24]}", - "type": "function", - "function": { - "name": "website_downloader", - "arguments": json.dumps({"url": urls[0]}), - } - }) - log.info(f"Executed website_downloader tool: {tool_result}") + if urls and state.rag_system: + # Use RAG system's integrated website downloader + try: + result = await state.rag_system.download_and_ingest_website( + url=urls[0], + max_pages=20, # Reasonable default + ) + + tool_calls.append({ + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": "website_downloader", + "arguments": json.dumps({ + "url": urls[0], + "success": result.get("success"), + "chunks_ingested": result.get("total_chunks", 0), + }), + } + }) + log.info(f"Downloaded and ingested website: {urls[0]} -> {result.get('total_chunks', 0)} chunks") + except Exception as e: + log.error(f"Website download failed: {e}") return tool_calls @@ -638,9 +647,67 @@ async def upload_document(request: Request): raise HTTPException(status_code=500, detail=str(e)) +class WebsiteDownloadRequest(BaseModel): + """Request model for website download.""" + url: str + max_pages: int = 50 + threads: int = 6 + download_external_assets: bool = False + external_domains: Optional[list[str]] = None + + +@app.post("/v1/documents/website") +async def download_website(request: WebsiteDownloadRequest): + """ + Download a website and ingest it into the knowledge base. + + This is the PRIMARY way to add content to the RAG system. + Uses the website_downloader_tool to download and process websites. + """ + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + log.info(f"Downloading website: {request.url}") + + result = await state.rag_system.download_and_ingest_website( + url=request.url, + max_pages=request.max_pages, + threads=request.threads, + download_external_assets=request.download_external_assets, + external_domains=request.external_domains, + ) + + if result.get("success"): + return { + "success": True, + "message": f"Website downloaded and ingested: {request.url}", + "url": request.url, + "local_path": result.get("local_path"), + "pages_processed": result.get("pages_processed", 0), + "total_chunks": result.get("total_chunks", 0), + } + else: + raise HTTPException( + status_code=500, + detail=result.get("message", "Website download failed") + ) + + except HTTPException: + raise + except Exception as e: + log.exception("Website download failed") + raise HTTPException(status_code=500, detail=str(e)) + + @app.post("/v1/documents/url") async def add_document_from_url(request: dict): - """Add a document from URL to the knowledge base.""" + """ + Add a document from URL to the knowledge base. + + NOTE: For websites, prefer using /v1/documents/website instead + as it downloads the entire site and provides better context. + """ if not state.rag_system: raise HTTPException(status_code=503, detail="RAG system not initialized") @@ -670,6 +737,47 @@ async def list_documents(): raise HTTPException(status_code=500, detail=str(e)) +@app.get("/v1/documents/sites") +async def list_downloaded_sites(): + """List all downloaded websites in the knowledge base.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + sites = await state.rag_system.list_downloaded_sites() + return {"sites": sites} + except Exception as e: + log.exception("Site listing failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/v1/documents/sites/{url:path}") +async def get_site_info(url: str): + """Get information about a specific downloaded site.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + # URL will be passed as path parameter, need to decode + from urllib.parse import unquote + decoded_url = unquote(url) + + # Add scheme if missing + if not decoded_url.startswith(("http://", "https://")): + decoded_url = "https://" + decoded_url + + site_info = state.rag_system.get_site_info(decoded_url) + if site_info: + return {"site": site_info} + else: + raise HTTPException(status_code=404, detail="Site not found") + except HTTPException: + raise + except Exception as e: + log.exception("Site info retrieval failed") + raise HTTPException(status_code=500, detail=str(e)) + + @app.delete("/v1/documents/{doc_id}") async def delete_document(doc_id: str): """Delete a document from the knowledge base.""" @@ -684,6 +792,40 @@ async def delete_document(doc_id: str): raise HTTPException(status_code=500, detail=str(e)) +@app.delete("/v1/documents/sites/{url:path}") +async def delete_site(url: str): + """Delete a downloaded website and all its content from the knowledge base.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + # URL will be passed as path parameter, need to decode + from urllib.parse import unquote + decoded_url = unquote(url) + + # Add scheme if missing + if not decoded_url.startswith(("http://", "https://")): + decoded_url = "https://" + decoded_url + + result = await state.rag_system.delete_site(decoded_url) + + if result.get("success"): + return { + "success": True, + "message": f"Site {decoded_url} deleted", + "deleted_chunks": result.get("deleted_chunks", 0), + "deleted_path": result.get("deleted_path"), + } + else: + raise HTTPException(status_code=404, detail=result.get("message", "Site not found")) + + except HTTPException: + raise + except Exception as e: + log.exception("Site deletion failed") + raise HTTPException(status_code=500, detail=str(e)) + + # ============================================================================= # Health and Status Endpoints # ============================================================================= @@ -711,6 +853,8 @@ async def root(): "chat": "/v1/chat/completions", "models": "/v1/models", "documents": "/v1/documents", + "download_website": "/v1/documents/website", + "list_sites": "/v1/documents/sites", "health": "/health", }, } diff --git a/rag/__init__.py b/rag/__init__.py index b6c4e91..acf1d70 100644 --- a/rag/__init__.py +++ b/rag/__init__.py @@ -2,6 +2,7 @@ RAG System - Retrieval Augmented Generation This module provides the core RAG functionality for DocRAG, including: +- Website downloading and ingestion via website_downloader_tool - Document processing and chunking - Vector storage and similarity search - Context retrieval for enhanced prompts @@ -14,20 +15,26 @@ import logging import os from pathlib import Path from typing import Any, Optional +from urllib.parse import urlparse from .document_processor import DocumentProcessor from .vector_store import VectorStore from .retriever import Retriever +# Import the website downloader tool +from website_downloader_tool import website_downloader + log = logging.getLogger(__name__) class RAGSystem: """ - Main RAG system that coordinates document processing, storage, and retrieval. + Main RAG system that coordinates website downloading, document processing, + storage, and retrieval. This class provides a unified interface for: - - Adding documents to the knowledge base + - Downloading websites using website_downloader_tool + - Processing downloaded content into the knowledge base - Querying for relevant context - Managing the document lifecycle """ @@ -37,12 +44,14 @@ class RAGSystem: embedding_model: str = "text-embedding-3-small", vector_store_path: str = "./data/vectors", documents_path: str = "./data/documents", + downloaded_sites_path: str = "./data/downloaded_sites", chunk_size: int = 1000, chunk_overlap: int = 200, ): self.embedding_model = embedding_model self.vector_store_path = Path(vector_store_path) self.documents_path = Path(documents_path) + self.downloaded_sites_path = Path(downloaded_sites_path) self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap @@ -50,6 +59,9 @@ class RAGSystem: self._document_processor: Optional[DocumentProcessor] = None self._vector_store: Optional[VectorStore] = None self._retriever: Optional[Retriever] = None + + # Track downloaded sites with their source URLs + self._site_registry: dict[str, dict[str, Any]] = {} async def initialize(self) -> None: """Initialize the RAG system components.""" @@ -61,6 +73,7 @@ class RAGSystem: # Create directories self.vector_store_path.mkdir(parents=True, exist_ok=True) self.documents_path.mkdir(parents=True, exist_ok=True) + self.downloaded_sites_path.mkdir(parents=True, exist_ok=True) # Initialize document processor self._document_processor = DocumentProcessor( @@ -73,12 +86,16 @@ class RAGSystem: persist_directory=str(self.vector_store_path), embedding_model=self.embedding_model, ) + await self._vector_store.initialize() # Initialize retriever self._retriever = Retriever( vector_store=self._vector_store, ) + # Load existing site registry + await self._load_site_registry() + self._initialized = True log.info("RAG system initialized successfully") @@ -86,6 +103,7 @@ class RAGSystem: """Close the RAG system and release resources.""" if self._vector_store: await self._vector_store.close() + await self._save_site_registry() self._initialized = False log.info("RAG system closed") @@ -94,6 +112,228 @@ class RAGSystem: if not self._initialized: raise RuntimeError("RAG system not initialized. Call initialize() first.") + async def _load_site_registry(self) -> None: + """Load the site registry from disk.""" + import json + registry_file = self.downloaded_sites_path / "site_registry.json" + if registry_file.exists(): + try: + with open(registry_file, "r") as f: + self._site_registry = json.load(f) + log.info(f"Loaded site registry with {len(self._site_registry)} sites") + except Exception as e: + log.warning(f"Failed to load site registry: {e}") + self._site_registry = {} + + async def _save_site_registry(self) -> None: + """Save the site registry to disk.""" + import json + registry_file = self.downloaded_sites_path / "site_registry.json" + try: + with open(registry_file, "w") as f: + json.dump(self._site_registry, f, indent=2, ensure_ascii=False) + log.info(f"Saved site registry with {len(self._site_registry)} sites") + except Exception as e: + log.error(f"Failed to save site registry: {e}") + + async def download_and_ingest_website( + self, + url: str, + max_pages: int = 50, + threads: int = 6, + download_external_assets: bool = False, + external_domains: Optional[list[str]] = None, + ) -> dict[str, Any]: + """ + Download a website using website_downloader_tool and ingest all content + into the knowledge base. + + This is the PRIMARY method for adding content to the RAG system. + + Args: + url: URL of the website to download + max_pages: Maximum number of pages to crawl + threads: Number of concurrent download threads + download_external_assets: Whether to download external assets + external_domains: List of external domains to allow + + Returns: + Dictionary with download and ingestion results + """ + self._ensure_initialized() + + log.info(f"Downloading website: {url}") + + # Use website_downloader_tool to download the site + download_result = website_downloader( + url=url, + destination=str(self.downloaded_sites_path / self._get_site_folder(url)), + max_pages=max_pages, + threads=threads, + download_external_assets=download_external_assets, + external_domains=external_domains, + ) + + if not download_result.get("success"): + log.error(f"Website download failed: {download_result.get('message')}") + return { + "success": False, + "message": download_result.get("message", "Download failed"), + "url": url, + } + + output_dir = download_result.get("output_directory", "") + stats = download_result.get("stats", {}) + + log.info(f"Website downloaded to: {output_dir}") + + # Process all HTML files from the downloaded site + ingestion_result = await self._ingest_downloaded_site( + site_path=Path(output_dir), + source_url=url, + ) + + # Register the site + site_id = self._generate_site_id(url) + self._site_registry[site_id] = { + "url": url, + "local_path": output_dir, + "pages_downloaded": stats.get("pages_crawled", 0), + "assets_downloaded": stats.get("assets_downloaded", 0), + "chunks_ingested": ingestion_result.get("total_chunks", 0), + "timestamp": self._get_timestamp(), + } + await self._save_site_registry() + + return { + "success": True, + "url": url, + "local_path": output_dir, + "pages_processed": ingestion_result.get("pages_processed", 0), + "total_chunks": ingestion_result.get("total_chunks", 0), + "stats": stats, + } + + async def _ingest_downloaded_site( + self, + site_path: Path, + source_url: str, + ) -> dict[str, Any]: + """ + Ingest all HTML files from a downloaded website into the knowledge base. + + Args: + site_path: Path to the downloaded website directory + source_url: Original URL of the website + + Returns: + Dictionary with ingestion statistics + """ + pages_processed = 0 + total_chunks = 0 + errors = [] + + # Find all HTML files + html_files = list(site_path.rglob("*.html")) + log.info(f"Found {len(html_files)} HTML files in {site_path}") + + for html_file in html_files: + try: + # Read the HTML file + content = html_file.read_bytes() + + # Calculate relative path for the page pointer + relative_path = html_file.relative_to(site_path) + page_url = self._reconstruct_page_url(source_url, relative_path) + + # Extract text from HTML + text_content = await self._document_processor.extract_text_from_html(content) + + if not text_content.strip(): + continue + + # Process into chunks + doc_info = await self._document_processor.process( + content=content, + filename=str(html_file), + metadata={ + "source_url": source_url, + "page_url": page_url, + "local_path": str(html_file), + "relative_path": str(relative_path), + "source_type": "downloaded_website", + }, + ) + + # Store chunks in vector store with pointers + if doc_info.get("chunks"): + # Add source pointer to each chunk's metadata + for metadata in doc_info.get("metadatas", []): + metadata["source_url"] = source_url + metadata["page_url"] = page_url + metadata["local_path"] = str(html_file) + metadata["pointer"] = { + "type": "downloaded_page", + "url": page_url, + "local_file": str(html_file), + } + + await self._vector_store.add_chunks( + chunks=doc_info["chunks"], + metadatas=doc_info.get("metadatas", []), + ids=doc_info.get("ids", []), + ) + + total_chunks += len(doc_info["chunks"]) + pages_processed += 1 + + log.debug(f"Ingested: {relative_path} -> {len(doc_info['chunks'])} chunks") + + except Exception as e: + errors.append(f"{html_file}: {str(e)}") + log.warning(f"Failed to process {html_file}: {e}") + + log.info(f"Ingestion complete: {pages_processed} pages, {total_chunks} chunks") + + return { + "pages_processed": pages_processed, + "total_chunks": total_chunks, + "errors": errors, + } + + def _get_site_folder(self, url: str) -> str: + """Generate a folder name for a site from its URL.""" + parsed = urlparse(url) + # Use domain name as folder, replace dots with underscores + folder = parsed.netloc.replace(".", "_").replace(":", "_") + return folder + + def _generate_site_id(self, url: str) -> str: + """Generate a unique ID for a site.""" + import hashlib + return hashlib.md5(url.encode()).hexdigest()[:16] + + def _reconstruct_page_url(self, base_url: str, relative_path: Path) -> str: + """Reconstruct the original URL for a downloaded page.""" + parsed = urlparse(base_url) + # Convert relative path back to URL path + path_parts = list(relative_path.parts) + + # Handle index.html as directory + if path_parts and path_parts[-1] == "index.html": + path_parts = path_parts[:-1] + # Remove .html extension from other files + elif path_parts and path_parts[-1].endswith(".html"): + path_parts[-1] = path_parts[-1][:-5] + + url_path = "/".join(path_parts) + return f"{parsed.scheme}://{parsed.netloc}/{url_path}" + + def _get_timestamp(self) -> str: + """Get current timestamp in ISO format.""" + from datetime import datetime + return datetime.utcnow().isoformat() + async def add_document( self, content: bytes, @@ -102,6 +342,8 @@ class RAGSystem: ) -> dict[str, Any]: """ Add a document to the knowledge base. + + Note: For websites, prefer using download_and_ingest_website() instead. Args: content: Raw document content @@ -120,8 +362,15 @@ class RAGSystem: metadata=metadata, ) - # Store chunks in vector store + # Store chunks in vector store with pointers if doc_info.get("chunks"): + # Add pointer metadata + for metadata in doc_info.get("metadatas", []): + metadata["pointer"] = { + "type": "uploaded_file", + "filename": filename, + } + await self._vector_store.add_chunks( chunks=doc_info["chunks"], metadatas=doc_info.get("metadatas", []), @@ -131,37 +380,12 @@ class RAGSystem: log.info(f"Added document '{filename}' with {len(doc_info.get('chunks', []))} chunks") return {"chunks": len(doc_info.get("chunks", [])), "document_id": doc_info.get("document_id")} - async def add_document_from_url(self, url: str) -> dict[str, Any]: - """ - Add a document from a URL to the knowledge base. - - Args: - url: URL to fetch and process - - Returns: - Dictionary with processing results - """ - self._ensure_initialized() - - # Fetch content from URL - import aiohttp - async with aiohttp.ClientSession() as session: - async with session.get(url, timeout=30) as response: - response.raise_for_status() - content = await response.read() - - # Extract filename from URL - from urllib.parse import urlparse - parsed = urlparse(url) - filename = os.path.basename(parsed.path) or "webpage.html" - - return await self.add_document(content=content, filename=filename, metadata={"source_url": url}) - async def query( self, query: str, top_k: int = 5, filter_metadata: Optional[dict] = None, + include_pointers: bool = True, ) -> dict[str, Any]: """ Query the knowledge base for relevant context. @@ -170,9 +394,10 @@ class RAGSystem: query: Query string top_k: Number of results to return filter_metadata: Optional metadata filters + include_pointers: Whether to include page pointers in results Returns: - Dictionary with context and sources + Dictionary with context, sources, and page pointers """ self._ensure_initialized() @@ -183,20 +408,37 @@ class RAGSystem: filter_metadata=filter_metadata, ) - # Build context string + # Build context string and collect pointers context_parts = [] sources = [] + pointers = [] for i, result in enumerate(results): context_parts.append(f"[{i+1}] {result['content']}") - if result.get("metadata", {}).get("source"): - sources.append(result["metadata"]["source"]) + + metadata = result.get("metadata", {}) + + # Collect source info + if metadata.get("page_url"): + sources.append(metadata["page_url"]) + elif metadata.get("source_url"): + sources.append(metadata["source_url"]) + elif metadata.get("source"): + sources.append(metadata["source"]) + + # Collect pointer info + if include_pointers and metadata.get("pointer"): + pointer = metadata["pointer"] + pointer["chunk_id"] = result.get("id") + pointer["score"] = result.get("score") + pointers.append(pointer) context = "\n\n".join(context_parts) return { "context": context, "sources": list(set(sources)), + "pointers": pointers, "num_results": len(results), "results": results, } @@ -204,7 +446,23 @@ class RAGSystem: async def list_documents(self) -> list[dict[str, Any]]: """List all documents in the knowledge base.""" self._ensure_initialized() - return await self._vector_store.list_documents() + + # Get documents from vector store + docs = await self._vector_store.list_documents() + + # Enrich with site registry info + for doc in docs: + source_url = doc.get("source_url") + if source_url: + site_id = self._generate_site_id(source_url) + if site_id in self._site_registry: + doc["site_info"] = self._site_registry[site_id] + + return docs + + async def list_downloaded_sites(self) -> list[dict[str, Any]]: + """List all downloaded websites.""" + return list(self._site_registry.values()) async def delete_document(self, document_id: str) -> None: """Delete a document from the knowledge base.""" @@ -212,6 +470,52 @@ class RAGSystem: await self._vector_store.delete_document(document_id) log.info(f"Deleted document {document_id}") + async def delete_site(self, url: str) -> dict[str, Any]: + """ + Delete a downloaded website and all its content from the knowledge base. + + Args: + url: URL of the site to delete + + Returns: + Dictionary with deletion results + """ + self._ensure_initialized() + + site_id = self._generate_site_id(url) + + if site_id not in self._site_registry: + return {"success": False, "message": f"Site not found: {url}"} + + site_info = self._site_registry[site_id] + local_path = site_info.get("local_path") + + # Delete from vector store + deleted_chunks = await self._vector_store.delete_by_source_url(url) + + # Delete local files + import shutil + if local_path and Path(local_path).exists(): + shutil.rmtree(local_path) + + # Remove from registry + del self._site_registry[site_id] + await self._save_site_registry() + + log.info(f"Deleted site: {url}") + + return { + "success": True, + "url": url, + "deleted_chunks": deleted_chunks, + "deleted_path": local_path, + } + + def get_site_info(self, url: str) -> Optional[dict[str, Any]]: + """Get information about a downloaded site.""" + site_id = self._generate_site_id(url) + return self._site_registry.get(site_id) + # Global RAG system instance _rag_system: Optional[RAGSystem] = None @@ -221,6 +525,7 @@ async def get_rag_system( embedding_model: str = "text-embedding-3-small", vector_store_path: str = "./data/vectors", documents_path: str = "./data/documents", + downloaded_sites_path: str = "./data/downloaded_sites", chunk_size: int = 1000, chunk_overlap: int = 200, ) -> RAGSystem: @@ -231,6 +536,7 @@ async def get_rag_system( embedding_model: Name of the embedding model vector_store_path: Path to vector store documents_path: Path to document storage + downloaded_sites_path: Path to downloaded websites chunk_size: Size of document chunks chunk_overlap: Overlap between chunks @@ -244,6 +550,7 @@ async def get_rag_system( embedding_model=embedding_model, vector_store_path=vector_store_path, documents_path=documents_path, + downloaded_sites_path=downloaded_sites_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap, ) diff --git a/rag/document_processor.py b/rag/document_processor.py index 0cdb838..83430b4 100644 --- a/rag/document_processor.py +++ b/rag/document_processor.py @@ -167,6 +167,18 @@ class DocumentProcessor: log.error(f"HTML extraction failed: {e}") return "" + async def extract_text_from_html(self, content: bytes) -> str: + """ + Public method to extract text from HTML content. + + Args: + content: Raw HTML content + + Returns: + Extracted text content + """ + return await self._extract_html(content) + async def _extract_docx(self, content: bytes) -> str: """Extract text from DOCX.""" try: diff --git a/rag/vector_store.py b/rag/vector_store.py index ff8b181..52eafa3 100644 --- a/rag/vector_store.py +++ b/rag/vector_store.py @@ -283,3 +283,55 @@ class VectorStore: await self._save() log.info(f"Deleted document {document_id} ({len(indices_to_remove)} chunks)") + + async def delete_by_source_url(self, source_url: str) -> int: + """ + Delete all chunks from a specific source URL. + + Args: + source_url: The source URL to delete + + Returns: + Number of deleted chunks + """ + self._ensure_initialized() + + # Find indices to remove + indices_to_remove = [ + i + for i, metadata in enumerate(self._metadata) + if metadata.get("source_url") == source_url + ] + + # Remove in reverse order to maintain indices + for i in sorted(indices_to_remove, reverse=True): + self._chunks.pop(i) + self._embeddings.pop(i) + self._metadata.pop(i) + self._ids.pop(i) + + # Save changes + await self._save() + + log.info(f"Deleted {len(indices_to_remove)} chunks from source: {source_url}") + return len(indices_to_remove) + + async def get_stats(self) -> dict[str, Any]: + """Get statistics about the vector store.""" + self._ensure_initialized() + + # Count unique sources + sources = set() + source_urls = set() + for metadata in self._metadata: + if metadata.get("source"): + sources.add(metadata.get("source")) + if metadata.get("source_url"): + source_urls.add(metadata.get("source_url")) + + return { + "total_chunks": len(self._chunks), + "unique_sources": len(sources), + "unique_urls": len(source_urls), + "embedding_dimension": len(self._embeddings[0]) if self._embeddings else 0, + }