diff --git a/main.py b/main.py index 7d26e05..01e64d1 100644 --- a/main.py +++ b/main.py @@ -4,10 +4,10 @@ DocRAG - OpenAI-Compatible RAG Server This application presents itself as a standard OpenAI API server that can be used with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it: -1. Processes user queries through a RAG system -2. Retrieves relevant context from a knowledge base -3. Passes the enriched context to GLM-4.7-Flash for response generation -4. Optionally uses tools like website_downloader for enhanced capabilities +1. Detects URLs in user messages and auto-downloads websites +2. Ingests website content into the RAG knowledge base +3. Retrieves relevant context from the knowledge base +4. Passes the enriched context to GLM-4.7-Flash for response generation The user sees a normal chat experience, but the system is actually doing sophisticated RAG operations in the background. @@ -19,6 +19,7 @@ import asyncio import json import logging import os +import re import sys import time import uuid @@ -294,6 +295,94 @@ async def chat_completions(request: ChatCompletionRequest): raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================= +# URL Detection and Website Download +# ============================================================================= + +def extract_urls_from_message(message: str) -> list[str]: + """Extract URLs from a message, including domains without scheme.""" + # Match full URLs + url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+' + full_urls = re.findall(url_pattern, message) + + # Match domain names (with or without www) + domain_pattern = r'(?:^|[^\w/-])(?:www\.)?([a-zA-Z0-9][-a-zA-Z0-9]*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2,})?)(?:/[^\s]*)?(?=[^\w/-]|$)' + domains = re.findall(domain_pattern, message) + + urls = list(full_urls) + for domain in domains: + # Check if it's a valid domain (not just a word) + if '.' in domain and len(domain) > 4: + normalized = f"https://{domain}" if not domain.startswith(('http://', 'https://')) else domain + if normalized not in urls: + urls.append(normalized) + + return urls + + +def should_download_website(message: str, urls: list[str]) -> bool: + """Determine if the user wants to access content from a website.""" + if not urls: + return False + + message_lower = message.lower() + + # Keywords indicating user wants website content + access_keywords = [ + 'go to', 'visit', 'check', 'look at', 'browse', 'open', + 'what is on', 'tell me about', 'give me', 'show me', 'get', + 'headlines', 'content', 'information from', 'from the', 'from', + 'on the website', 'on the site', 'the website', 'the site', 'website', + 'summarize', 'extract', 'read', 'analyze', 'find', 'search', + 'what does', 'what\'s on', 'what is', 'tell me', 'about', + 'news', 'articles', 'posts', 'page', 'pages', + ] + + # Check if message contains access intent + has_access_intent = any(kw in message_lower for kw in access_keywords) + + # Also trigger if URL is directly mentioned with a question + has_question = '?' in message or any(qw in message_lower for qw in ['what', 'how', 'who', 'where', 'when', 'why']) + + return (has_access_intent or has_question) and len(urls) > 0 + + +async def download_website_if_needed(user_message: str) -> dict[str, Any]: + """ + Download website if user is asking about one. + Returns download info if successful. + """ + urls = extract_urls_from_message(user_message) + + if not should_download_website(user_message, urls): + return {"downloaded": False, "reason": "No website access intent detected"} + + if not state.rag_system: + return {"downloaded": False, "reason": "RAG system not initialized"} + + for url in urls: + try: + log.info(f"Auto-downloading website: {url}") + result = await state.rag_system.download_and_ingest_website( + url=url, + max_pages=30, + ) + + if result.get("success"): + log.info(f"Successfully downloaded {url}: {result.get('total_chunks', 0)} chunks") + return { + "downloaded": True, + "url": url, + "chunks": result.get("total_chunks", 0), + "pages": result.get("pages_processed", 0), + "local_path": result.get("local_path"), + } + except Exception as e: + log.warning(f"Failed to download {url}: {e}") + + return {"downloaded": False, "reason": "All download attempts failed"} + + # ============================================================================= # Chat Completion Logic # ============================================================================= @@ -312,7 +401,12 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat if not user_message: raise HTTPException(status_code=400, detail="No user message found") - # Step 1: RAG Retrieval + # Step 1: Download website if user is asking about one (BEFORE RAG retrieval) + download_info = await download_website_if_needed(user_message) + if download_info.get("downloaded"): + log.info(f"Website auto-downloaded: {download_info.get('url')}") + + # Step 2: RAG Retrieval (now includes newly downloaded content) context = "" sources = [] if state.rag_system: @@ -327,22 +421,14 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat except Exception as e: log.warning(f"RAG retrieval failed: {e}") - # Step 2: Build enhanced prompt with context - enhanced_messages = build_enhanced_messages(messages, context, sources) - - # Step 3: Check for tool usage - tool_calls_made = [] - if config.ENABLE_TOOLS and state.tool_manager: - tool_calls_made = await check_and_execute_tools( - user_message, enhanced_messages, request.tools - ) + # Step 3: Build enhanced prompt with context + enhanced_messages = build_enhanced_messages(messages, context, sources, download_info) # Step 4: Generate response with upstream LLM response_content = await generate_response( enhanced_messages, temperature=request.temperature, max_tokens=request.max_tokens, - tools=request.tools, ) # Step 5: Build and return response @@ -354,7 +440,6 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat message=ChatMessage( role="assistant", content=response_content, - tool_calls=tool_calls_made if tool_calls_made else None, ), finish_reason="stop", ) @@ -384,7 +469,12 @@ async def stream_chat_completion( yield f"data: {json.dumps({'error': 'No user message found'})}\n\n" return - # Step 1: RAG Retrieval + # Step 1: Download website if user is asking about one (BEFORE RAG retrieval) + download_info = await download_website_if_needed(user_message) + if download_info.get("downloaded"): + log.info(f"Website auto-downloaded: {download_info.get('url')}") + + # Step 2: RAG Retrieval (now includes newly downloaded content) context = "" sources = [] if state.rag_system: @@ -399,10 +489,10 @@ async def stream_chat_completion( except Exception as e: log.warning(f"RAG retrieval failed: {e}") - # Step 2: Build enhanced prompt with context - enhanced_messages = build_enhanced_messages(messages, context, sources) + # Step 3: Build enhanced prompt with context + enhanced_messages = build_enhanced_messages(messages, context, sources, download_info) - # Step 3: Stream response from upstream LLM + # Step 4: Stream response from upstream LLM created = int(time.time()) try: @@ -456,9 +546,12 @@ async def stream_chat_completion( else: # Mock streaming response for testing mock_response = f"I understand you're asking about: {user_message}\n\n" + if download_info.get("downloaded"): + mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n" + mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n" if context: - mock_response += f"Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n" - mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality." + mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n" + mock_response += "\n\n[Demo mode - configure ZAI_API_KEY for full LLM responses]" for char in mock_response: yield f"data: {json.dumps({ @@ -496,22 +589,30 @@ def build_enhanced_messages( messages: list[ChatMessage], context: str, sources: list[str], + download_info: dict = None, ) -> list[ChatMessage]: """Build enhanced messages with RAG context.""" enhanced = [] # Add system message with RAG context system_content = ( - "You are a helpful AI assistant with access to a knowledge base. " - "Use the provided context to give accurate and helpful responses. " - "If the context doesn't contain relevant information, use your general knowledge " - "but indicate when you're doing so." + "You are a helpful AI assistant with the ability to access and analyze websites on-demand. " + "When a user asks about a website, you can download and analyze its content directly. " + "Use the provided context from the knowledge base to give accurate and helpful responses. " + "If context from a website is provided, use it to answer the user's question directly with specific information. " + "Be helpful, detailed, and provide the specific information the user is asking for (headlines, summaries, etc.)." ) + if download_info and download_info.get("downloaded"): + system_content += f"\n\n--- Website Access ---\n" + system_content += f"I have successfully downloaded and analyzed the website: {download_info.get('url')}\n" + system_content += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} text chunks.\n" + system_content += "The context below contains the actual content from this website. Use it to answer the user's question." + if context: system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n" if sources: - system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources) + system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10]) enhanced.append(ChatMessage(role="system", content=system_content)) @@ -527,7 +628,6 @@ async def generate_response( messages: list[ChatMessage], temperature: float = 0.7, max_tokens: int = 4096, - tools: Optional[list[dict]] = None, ) -> str: """Generate response using upstream LLM.""" if state.zai_client: @@ -566,55 +666,6 @@ async def generate_response( return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality." -async def check_and_execute_tools( - user_message: str, - messages: list[ChatMessage], - available_tools: Optional[list[dict]], -) -> list[dict]: - """Check if tools should be used and execute them.""" - if not state.tool_manager or not available_tools: - return [] - - tool_calls = [] - - # Simple keyword-based tool detection - # In a production system, you'd use the LLM to decide tool usage - message_lower = user_message.lower() - - # Check for website download intent - use RAG system for full integration - if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site", "ingest site"]): - # Extract URL from message - import re - url_pattern = r'https?://[^\s]+' - urls = re.findall(url_pattern, user_message) - - if urls and state.rag_system: - # Use RAG system's integrated website downloader - try: - result = await state.rag_system.download_and_ingest_website( - url=urls[0], - max_pages=20, # Reasonable default - ) - - tool_calls.append({ - "id": f"call_{uuid.uuid4().hex[:24]}", - "type": "function", - "function": { - "name": "website_downloader", - "arguments": json.dumps({ - "url": urls[0], - "success": result.get("success"), - "chunks_ingested": result.get("total_chunks", 0), - }), - } - }) - log.info(f"Downloaded and ingested website: {urls[0]} -> {result.get('total_chunks', 0)} chunks") - except Exception as e: - log.error(f"Website download failed: {e}") - - return tool_calls - - # ============================================================================= # Document Management Endpoints # ============================================================================= @@ -702,12 +753,7 @@ async def download_website(request: WebsiteDownloadRequest): @app.post("/v1/documents/url") async def add_document_from_url(request: dict): - """ - Add a document from URL to the knowledge base. - - NOTE: For websites, prefer using /v1/documents/website instead - as it downloads the entire site and provides better context. - """ + """Add a document from URL to the knowledge base.""" if not state.rag_system: raise HTTPException(status_code=503, detail="RAG system not initialized") @@ -758,11 +804,9 @@ async def get_site_info(url: str): raise HTTPException(status_code=503, detail="RAG system not initialized") try: - # URL will be passed as path parameter, need to decode from urllib.parse import unquote decoded_url = unquote(url) - # Add scheme if missing if not decoded_url.startswith(("http://", "https://")): decoded_url = "https://" + decoded_url @@ -799,11 +843,9 @@ async def delete_site(url: str): raise HTTPException(status_code=503, detail="RAG system not initialized") try: - # URL will be passed as path parameter, need to decode from urllib.parse import unquote decoded_url = unquote(url) - # Add scheme if missing if not decoded_url.startswith(("http://", "https://")): decoded_url = "https://" + decoded_url @@ -848,7 +890,7 @@ async def root(): return { "name": "DocRAG API", "version": "1.0.0", - "description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash", + "description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash. Auto-downloads and analyzes websites when users ask about them.", "endpoints": { "chat": "/v1/chat/completions", "models": "/v1/models",