Fix: Auto-download websites BEFORE RAG retrieval
Key changes:
- Add URL extraction and detection functions
- Download websites BEFORE RAG retrieval (not after)
- Expand trigger keywords to include common phrases like 'go to', 'headlines', etc.
- Update the system prompt to tell the LLM it CAN access websites
- Improve streaming response handling

Now, when a user asks 'go to orovillemr.com and give me the headlines':
1. The system detects the URL and the access intent
2. It downloads and ingests the website content
3. RAG retrieves the relevant content
4. The LLM generates a response using the actual website content
This commit is contained in:
parent
6aecc4b231
commit
10e61dd2f1
218
main.py
218
main.py
@ -4,10 +4,10 @@ DocRAG - OpenAI-Compatible RAG Server
|
||||
|
||||
This application presents itself as a standard OpenAI API server that can be used
|
||||
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
||||
1. Processes user queries through a RAG system
|
||||
2. Retrieves relevant context from a knowledge base
|
||||
3. Passes the enriched context to GLM-4.7-Flash for response generation
|
||||
4. Optionally uses tools like website_downloader for enhanced capabilities
|
||||
1. Detects URLs in user messages and auto-downloads websites
|
||||
2. Ingests website content into the RAG knowledge base
|
||||
3. Retrieves relevant context from the knowledge base
|
||||
4. Passes the enriched context to GLM-4.7-Flash for response generation
|
||||
|
||||
The user sees a normal chat experience, but the system is actually doing
|
||||
sophisticated RAG operations in the background.
|
||||
@ -19,6 +19,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
@ -294,6 +295,94 @@ async def chat_completions(request: ChatCompletionRequest):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# URL Detection and Website Download
|
||||
# =============================================================================
|
||||
|
||||
def extract_urls_from_message(message: str) -> list[str]:
    """Extract candidate website URLs from a chat message.

    Two kinds of references are recognised:

    * Full URLs starting with ``http://`` or ``https://``. Trailing sentence
      punctuation (``.``, ``,``, ``?`` ...) is stripped so that
      "see https://example.com." yields ``https://example.com``.
    * Bare host names such as ``example.com`` or ``news.bbc.co.uk``. These are
      normalised to ``https://<host>``; a leading ``www.`` and any path are
      dropped.

    Args:
        message: Raw user message text. Empty or ``None`` yields ``[]``.

    Returns:
        De-duplicated URLs in order of appearance, full URLs first.
    """
    if not message:
        return []

    url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
    # Host = one or more dot-terminated labels followed by an alphabetic TLD.
    # The look-around guards keep us from starting in the middle of a word or
    # path and from matching the host part of an e-mail address.
    # NOTE(review): file-like names such as "main.py" still look like hosts to
    # this pattern; telling them apart would need a TLD allow-list.
    domain_pattern = (
        r'(?<![\w/@-])(?:www\.)?'
        r'((?:[a-zA-Z0-9][-a-zA-Z0-9]*\.)+[a-zA-Z]{2,})'
        r'(?:/[^\s]*)?(?![\w/@-])'
    )
    trailing_punctuation = ".,;:!?'"

    urls: list[str] = []

    for raw_url in re.findall(url_pattern, message):
        cleaned = raw_url.rstrip(trailing_punctuation)
        # A closing parenthesis belongs to the sentence, not the URL, unless
        # the URL itself contains an opening one (e.g. wiki-style links).
        if cleaned.endswith(')') and '(' not in cleaned:
            cleaned = cleaned.rstrip(trailing_punctuation + ')')
        has_host = bool(cleaned.partition('://')[2])
        if has_host and cleaned not in urls:
            urls.append(cleaned)

    # Blank out the full URLs before scanning for bare hosts; otherwise the
    # host inside "https://www.example.com" is reported a second time.
    remainder = re.sub(url_pattern, ' ', message)

    for domain in re.findall(domain_pattern, remainder):
        # Very short matches are more likely abbreviations than real hosts.
        if len(domain) > 4:
            normalized = f"https://{domain}"
            if normalized not in urls:
                urls.append(normalized)

    return urls
|
||||
|
||||
|
||||
def should_download_website(message: str, urls: list[str]) -> bool:
    """Decide whether the user wants content fetched from a referenced website.

    The check is a heuristic: the message must reference at least one URL and
    either contain an access-intent keyword or look like a question.

    Keywords are matched at the start of a word, so inflected forms
    ("visiting", "checking") still count, but a keyword buried inside an
    unrelated word ("get" in "together", "read" in "already") does not.

    Args:
        message: Raw user message text.
        urls: URLs already extracted from the message.

    Returns:
        True when a download should be attempted, False otherwise.
    """
    if not urls or not message:
        return False

    message_lower = message.lower()

    # Keywords indicating the user wants website content.
    access_keywords = [
        'go to', 'visit', 'check', 'look at', 'browse', 'open',
        'what is on', 'tell me about', 'give me', 'show me', 'get',
        'headlines', 'content', 'information from', 'from the', 'from',
        'on the website', 'on the site', 'the website', 'the site', 'website',
        'summarize', 'extract', 'read', 'analyze', 'find', 'search',
        'what does', 'what\'s on', 'what is', 'tell me', 'about',
        'news', 'articles', 'posts', 'page', 'pages',
    ]
    question_words = ['what', 'how', 'who', 'where', 'when', 'why']

    def _starts_a_word(terms: list[str]) -> bool:
        # \b only on the left: require a word start, allow any suffix.
        pattern = r'\b(?:' + '|'.join(re.escape(term) for term in terms) + r')'
        return re.search(pattern, message_lower) is not None

    has_access_intent = _starts_a_word(access_keywords)

    # A URL mentioned together with a question also counts as intent.
    has_question = '?' in message or _starts_a_word(question_words)

    return has_access_intent or has_question
|
||||
|
||||
|
||||
async def download_website_if_needed(user_message: str) -> dict[str, Any]:
    """Download and ingest a website when the user's message asks about one.

    URLs are extracted from the message and tried in order; the first one
    that the RAG system reports as successfully ingested wins and no further
    URLs are attempted.

    Args:
        user_message: The latest user message text.

    Returns:
        On success: ``{"downloaded": True, "url", "chunks", "pages",
        "local_path"}``. Otherwise ``{"downloaded": False, "reason": ...}``.
        Download errors are logged and never propagated to the caller.
    """
    urls = extract_urls_from_message(user_message)

    if not should_download_website(user_message, urls):
        return {"downloaded": False, "reason": "No website access intent detected"}

    if not state.rag_system:
        return {"downloaded": False, "reason": "RAG system not initialized"}

    for url in urls:
        try:
            log.info(f"Auto-downloading website: {url}")
            result = await state.rag_system.download_and_ingest_website(
                url=url,
                max_pages=30,
            )

            if result.get("success"):
                log.info(f"Successfully downloaded {url}: {result.get('total_chunks', 0)} chunks")
                return {
                    "downloaded": True,
                    "url": url,
                    "chunks": result.get("total_chunks", 0),
                    "pages": result.get("pages_processed", 0),
                    "local_path": result.get("local_path"),
                }

            # Previously an unsuccessful result was skipped silently, which
            # made "All download attempts failed" impossible to diagnose.
            # NOTE(review): the "error" key is an assumption about the result
            # schema; .get() keeps this safe if it is absent.
            log.warning(f"Download of {url} did not succeed: {result.get('error', 'no details provided')}")
        except Exception as e:
            # Best-effort: a failed download must not break the chat request.
            log.warning(f"Failed to download {url}: {e}")

    return {"downloaded": False, "reason": "All download attempts failed"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Completion Logic
|
||||
# =============================================================================
|
||||
@ -312,7 +401,12 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
|
||||
if not user_message:
|
||||
raise HTTPException(status_code=400, detail="No user message found")
|
||||
|
||||
# Step 1: RAG Retrieval
|
||||
# Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
|
||||
download_info = await download_website_if_needed(user_message)
|
||||
if download_info.get("downloaded"):
|
||||
log.info(f"Website auto-downloaded: {download_info.get('url')}")
|
||||
|
||||
# Step 2: RAG Retrieval (now includes newly downloaded content)
|
||||
context = ""
|
||||
sources = []
|
||||
if state.rag_system:
|
||||
@ -327,22 +421,14 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
|
||||
except Exception as e:
|
||||
log.warning(f"RAG retrieval failed: {e}")
|
||||
|
||||
# Step 2: Build enhanced prompt with context
|
||||
enhanced_messages = build_enhanced_messages(messages, context, sources)
|
||||
|
||||
# Step 3: Check for tool usage
|
||||
tool_calls_made = []
|
||||
if config.ENABLE_TOOLS and state.tool_manager:
|
||||
tool_calls_made = await check_and_execute_tools(
|
||||
user_message, enhanced_messages, request.tools
|
||||
)
|
||||
# Step 3: Build enhanced prompt with context
|
||||
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)
|
||||
|
||||
# Step 4: Generate response with upstream LLM
|
||||
response_content = await generate_response(
|
||||
enhanced_messages,
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
tools=request.tools,
|
||||
)
|
||||
|
||||
# Step 5: Build and return response
|
||||
@ -354,7 +440,6 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content=response_content,
|
||||
tool_calls=tool_calls_made if tool_calls_made else None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
@ -384,7 +469,12 @@ async def stream_chat_completion(
|
||||
yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
|
||||
return
|
||||
|
||||
# Step 1: RAG Retrieval
|
||||
# Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
|
||||
download_info = await download_website_if_needed(user_message)
|
||||
if download_info.get("downloaded"):
|
||||
log.info(f"Website auto-downloaded: {download_info.get('url')}")
|
||||
|
||||
# Step 2: RAG Retrieval (now includes newly downloaded content)
|
||||
context = ""
|
||||
sources = []
|
||||
if state.rag_system:
|
||||
@ -399,10 +489,10 @@ async def stream_chat_completion(
|
||||
except Exception as e:
|
||||
log.warning(f"RAG retrieval failed: {e}")
|
||||
|
||||
# Step 2: Build enhanced prompt with context
|
||||
enhanced_messages = build_enhanced_messages(messages, context, sources)
|
||||
# Step 3: Build enhanced prompt with context
|
||||
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)
|
||||
|
||||
# Step 3: Stream response from upstream LLM
|
||||
# Step 4: Stream response from upstream LLM
|
||||
created = int(time.time())
|
||||
|
||||
try:
|
||||
@ -456,9 +546,12 @@ async def stream_chat_completion(
|
||||
else:
|
||||
# Mock streaming response for testing
|
||||
mock_response = f"I understand you're asking about: {user_message}\n\n"
|
||||
if download_info.get("downloaded"):
|
||||
mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n"
|
||||
mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n"
|
||||
if context:
|
||||
mock_response += f"Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
|
||||
mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."
|
||||
mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n"
|
||||
mock_response += "\n\n[Demo mode - configure ZAI_API_KEY for full LLM responses]"
|
||||
|
||||
for char in mock_response:
|
||||
yield f"data: {json.dumps({
|
||||
@ -496,22 +589,30 @@ def build_enhanced_messages(
|
||||
messages: list[ChatMessage],
|
||||
context: str,
|
||||
sources: list[str],
|
||||
download_info: dict = None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Build enhanced messages with RAG context."""
|
||||
enhanced = []
|
||||
|
||||
# Add system message with RAG context
|
||||
system_content = (
|
||||
"You are a helpful AI assistant with access to a knowledge base. "
|
||||
"Use the provided context to give accurate and helpful responses. "
|
||||
"If the context doesn't contain relevant information, use your general knowledge "
|
||||
"but indicate when you're doing so."
|
||||
"You are a helpful AI assistant with the ability to access and analyze websites on-demand. "
|
||||
"When a user asks about a website, you can download and analyze its content directly. "
|
||||
"Use the provided context from the knowledge base to give accurate and helpful responses. "
|
||||
"If context from a website is provided, use it to answer the user's question directly with specific information. "
|
||||
"Be helpful, detailed, and provide the specific information the user is asking for (headlines, summaries, etc.)."
|
||||
)
|
||||
|
||||
if download_info and download_info.get("downloaded"):
|
||||
system_content += f"\n\n--- Website Access ---\n"
|
||||
system_content += f"I have successfully downloaded and analyzed the website: {download_info.get('url')}\n"
|
||||
system_content += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} text chunks.\n"
|
||||
system_content += "The context below contains the actual content from this website. Use it to answer the user's question."
|
||||
|
||||
if context:
|
||||
system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
|
||||
if sources:
|
||||
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources)
|
||||
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10])
|
||||
|
||||
enhanced.append(ChatMessage(role="system", content=system_content))
|
||||
|
||||
@ -527,7 +628,6 @@ async def generate_response(
|
||||
messages: list[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
tools: Optional[list[dict]] = None,
|
||||
) -> str:
|
||||
"""Generate response using upstream LLM."""
|
||||
if state.zai_client:
|
||||
@ -566,55 +666,6 @@ async def generate_response(
|
||||
return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."
|
||||
|
||||
|
||||
async def check_and_execute_tools(
    user_message: str,
    messages: list[ChatMessage],
    available_tools: Optional[list[dict]],
) -> list[dict]:
    """Run keyword-triggered tools for a user message and report the calls made.

    Only one tool is wired up here: when the message contains a website
    download phrase and an http(s) URL, the first URL is downloaded and
    ingested through the RAG system. Returns a list of OpenAI-style tool-call
    dicts (empty when nothing was executed or the download raised).
    """
    if not state.tool_manager or not available_tools:
        return []

    # Simple keyword-based detection; a production system would let the LLM
    # decide whether a tool is needed.
    lowered = user_message.lower()
    download_phrases = ["download website", "mirror site", "crawl", "archive site", "ingest site"]
    if not any(phrase in lowered for phrase in download_phrases):
        return []

    import re
    found_urls = re.findall(r'https?://[^\s]+', user_message)
    if not (found_urls and state.rag_system):
        return []

    target_url = found_urls[0]
    executed: list[dict] = []

    try:
        # Use the RAG system's integrated website downloader.
        outcome = await state.rag_system.download_and_ingest_website(
            url=target_url,
            max_pages=20,  # Reasonable default
        )

        call_arguments = json.dumps({
            "url": target_url,
            "success": outcome.get("success"),
            "chunks_ingested": outcome.get("total_chunks", 0),
        })
        executed.append({
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": "website_downloader",
                "arguments": call_arguments,
            }
        })
        log.info(f"Downloaded and ingested website: {target_url} -> {outcome.get('total_chunks', 0)} chunks")
    except Exception as e:
        log.error(f"Website download failed: {e}")

    return executed
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Document Management Endpoints
|
||||
# =============================================================================
|
||||
@ -702,12 +753,7 @@ async def download_website(request: WebsiteDownloadRequest):
|
||||
|
||||
@app.post("/v1/documents/url")
|
||||
async def add_document_from_url(request: dict):
|
||||
"""
|
||||
Add a document from URL to the knowledge base.
|
||||
|
||||
NOTE: For websites, prefer using /v1/documents/website instead
|
||||
as it downloads the entire site and provides better context.
|
||||
"""
|
||||
"""Add a document from URL to the knowledge base."""
|
||||
if not state.rag_system:
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
@ -758,11 +804,9 @@ async def get_site_info(url: str):
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
try:
|
||||
# URL will be passed as path parameter, need to decode
|
||||
from urllib.parse import unquote
|
||||
decoded_url = unquote(url)
|
||||
|
||||
# Add scheme if missing
|
||||
if not decoded_url.startswith(("http://", "https://")):
|
||||
decoded_url = "https://" + decoded_url
|
||||
|
||||
@ -799,11 +843,9 @@ async def delete_site(url: str):
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
try:
|
||||
# URL will be passed as path parameter, need to decode
|
||||
from urllib.parse import unquote
|
||||
decoded_url = unquote(url)
|
||||
|
||||
# Add scheme if missing
|
||||
if not decoded_url.startswith(("http://", "https://")):
|
||||
decoded_url = "https://" + decoded_url
|
||||
|
||||
@ -848,7 +890,7 @@ async def root():
|
||||
return {
|
||||
"name": "DocRAG API",
|
||||
"version": "1.0.0",
|
||||
"description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
|
||||
"description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash. Auto-downloads and analyzes websites when users ask about them.",
|
||||
"endpoints": {
|
||||
"chat": "/v1/chat/completions",
|
||||
"models": "/v1/models",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user