Fix: Auto-download websites BEFORE RAG retrieval
Key changes: - Add URL extraction and detection functions - Download websites BEFORE RAG retrieval (not after) - Expand trigger keywords to include common phrases like 'go to', 'headlines', etc. - Update system prompt to tell LLM it CAN access websites - Improve streaming response handling Now when user asks 'go to orovillemr.com and give me the headlines': 1. System detects URL and access intent 2. Downloads and ingests website content 3. RAG retrieves relevant content 4. LLM generates response with actual website content
This commit is contained in:
parent
6aecc4b231
commit
10e61dd2f1
218
main.py
218
main.py
@ -4,10 +4,10 @@ DocRAG - OpenAI-Compatible RAG Server
|
|||||||
|
|
||||||
This application presents itself as a standard OpenAI API server that can be used
|
This application presents itself as a standard OpenAI API server that can be used
|
||||||
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
||||||
1. Processes user queries through a RAG system
|
1. Detects URLs in user messages and auto-downloads websites
|
||||||
2. Retrieves relevant context from a knowledge base
|
2. Ingests website content into the RAG knowledge base
|
||||||
3. Passes the enriched context to GLM-4.7-Flash for response generation
|
3. Retrieves relevant context from the knowledge base
|
||||||
4. Optionally uses tools like website_downloader for enhanced capabilities
|
4. Passes the enriched context to GLM-4.7-Flash for response generation
|
||||||
|
|
||||||
The user sees a normal chat experience, but the system is actually doing
|
The user sees a normal chat experience, but the system is actually doing
|
||||||
sophisticated RAG operations in the background.
|
sophisticated RAG operations in the background.
|
||||||
@ -19,6 +19,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@ -294,6 +295,94 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# URL Detection and Website Download
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def extract_urls_from_message(message: str) -> list[str]:
|
||||||
|
"""Extract URLs from a message, including domains without scheme."""
|
||||||
|
# Match full URLs
|
||||||
|
url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
|
||||||
|
full_urls = re.findall(url_pattern, message)
|
||||||
|
|
||||||
|
# Match domain names (with or without www)
|
||||||
|
domain_pattern = r'(?:^|[^\w/-])(?:www\.)?([a-zA-Z0-9][-a-zA-Z0-9]*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2,})?)(?:/[^\s]*)?(?=[^\w/-]|$)'
|
||||||
|
domains = re.findall(domain_pattern, message)
|
||||||
|
|
||||||
|
urls = list(full_urls)
|
||||||
|
for domain in domains:
|
||||||
|
# Check if it's a valid domain (not just a word)
|
||||||
|
if '.' in domain and len(domain) > 4:
|
||||||
|
normalized = f"https://{domain}" if not domain.startswith(('http://', 'https://')) else domain
|
||||||
|
if normalized not in urls:
|
||||||
|
urls.append(normalized)
|
||||||
|
|
||||||
|
return urls
|
||||||
|
|
||||||
|
|
||||||
|
def should_download_website(message: str, urls: list[str]) -> bool:
|
||||||
|
"""Determine if the user wants to access content from a website."""
|
||||||
|
if not urls:
|
||||||
|
return False
|
||||||
|
|
||||||
|
message_lower = message.lower()
|
||||||
|
|
||||||
|
# Keywords indicating user wants website content
|
||||||
|
access_keywords = [
|
||||||
|
'go to', 'visit', 'check', 'look at', 'browse', 'open',
|
||||||
|
'what is on', 'tell me about', 'give me', 'show me', 'get',
|
||||||
|
'headlines', 'content', 'information from', 'from the', 'from',
|
||||||
|
'on the website', 'on the site', 'the website', 'the site', 'website',
|
||||||
|
'summarize', 'extract', 'read', 'analyze', 'find', 'search',
|
||||||
|
'what does', 'what\'s on', 'what is', 'tell me', 'about',
|
||||||
|
'news', 'articles', 'posts', 'page', 'pages',
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if message contains access intent
|
||||||
|
has_access_intent = any(kw in message_lower for kw in access_keywords)
|
||||||
|
|
||||||
|
# Also trigger if URL is directly mentioned with a question
|
||||||
|
has_question = '?' in message or any(qw in message_lower for qw in ['what', 'how', 'who', 'where', 'when', 'why'])
|
||||||
|
|
||||||
|
return (has_access_intent or has_question) and len(urls) > 0
|
||||||
|
|
||||||
|
|
||||||
|
async def download_website_if_needed(user_message: str) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Download website if user is asking about one.
|
||||||
|
Returns download info if successful.
|
||||||
|
"""
|
||||||
|
urls = extract_urls_from_message(user_message)
|
||||||
|
|
||||||
|
if not should_download_website(user_message, urls):
|
||||||
|
return {"downloaded": False, "reason": "No website access intent detected"}
|
||||||
|
|
||||||
|
if not state.rag_system:
|
||||||
|
return {"downloaded": False, "reason": "RAG system not initialized"}
|
||||||
|
|
||||||
|
for url in urls:
|
||||||
|
try:
|
||||||
|
log.info(f"Auto-downloading website: {url}")
|
||||||
|
result = await state.rag_system.download_and_ingest_website(
|
||||||
|
url=url,
|
||||||
|
max_pages=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.get("success"):
|
||||||
|
log.info(f"Successfully downloaded {url}: {result.get('total_chunks', 0)} chunks")
|
||||||
|
return {
|
||||||
|
"downloaded": True,
|
||||||
|
"url": url,
|
||||||
|
"chunks": result.get("total_chunks", 0),
|
||||||
|
"pages": result.get("pages_processed", 0),
|
||||||
|
"local_path": result.get("local_path"),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
log.warning(f"Failed to download {url}: {e}")
|
||||||
|
|
||||||
|
return {"downloaded": False, "reason": "All download attempts failed"}
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Chat Completion Logic
|
# Chat Completion Logic
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -312,7 +401,12 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
|
|||||||
if not user_message:
|
if not user_message:
|
||||||
raise HTTPException(status_code=400, detail="No user message found")
|
raise HTTPException(status_code=400, detail="No user message found")
|
||||||
|
|
||||||
# Step 1: RAG Retrieval
|
# Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
|
||||||
|
download_info = await download_website_if_needed(user_message)
|
||||||
|
if download_info.get("downloaded"):
|
||||||
|
log.info(f"Website auto-downloaded: {download_info.get('url')}")
|
||||||
|
|
||||||
|
# Step 2: RAG Retrieval (now includes newly downloaded content)
|
||||||
context = ""
|
context = ""
|
||||||
sources = []
|
sources = []
|
||||||
if state.rag_system:
|
if state.rag_system:
|
||||||
@ -327,22 +421,14 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"RAG retrieval failed: {e}")
|
log.warning(f"RAG retrieval failed: {e}")
|
||||||
|
|
||||||
# Step 2: Build enhanced prompt with context
|
# Step 3: Build enhanced prompt with context
|
||||||
enhanced_messages = build_enhanced_messages(messages, context, sources)
|
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)
|
||||||
|
|
||||||
# Step 3: Check for tool usage
|
|
||||||
tool_calls_made = []
|
|
||||||
if config.ENABLE_TOOLS and state.tool_manager:
|
|
||||||
tool_calls_made = await check_and_execute_tools(
|
|
||||||
user_message, enhanced_messages, request.tools
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 4: Generate response with upstream LLM
|
# Step 4: Generate response with upstream LLM
|
||||||
response_content = await generate_response(
|
response_content = await generate_response(
|
||||||
enhanced_messages,
|
enhanced_messages,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
tools=request.tools,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 5: Build and return response
|
# Step 5: Build and return response
|
||||||
@ -354,7 +440,6 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
|
|||||||
message=ChatMessage(
|
message=ChatMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=response_content,
|
content=response_content,
|
||||||
tool_calls=tool_calls_made if tool_calls_made else None,
|
|
||||||
),
|
),
|
||||||
finish_reason="stop",
|
finish_reason="stop",
|
||||||
)
|
)
|
||||||
@ -384,7 +469,12 @@ async def stream_chat_completion(
|
|||||||
yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
|
yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
|
||||||
return
|
return
|
||||||
|
|
||||||
# Step 1: RAG Retrieval
|
# Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
|
||||||
|
download_info = await download_website_if_needed(user_message)
|
||||||
|
if download_info.get("downloaded"):
|
||||||
|
log.info(f"Website auto-downloaded: {download_info.get('url')}")
|
||||||
|
|
||||||
|
# Step 2: RAG Retrieval (now includes newly downloaded content)
|
||||||
context = ""
|
context = ""
|
||||||
sources = []
|
sources = []
|
||||||
if state.rag_system:
|
if state.rag_system:
|
||||||
@ -399,10 +489,10 @@ async def stream_chat_completion(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"RAG retrieval failed: {e}")
|
log.warning(f"RAG retrieval failed: {e}")
|
||||||
|
|
||||||
# Step 2: Build enhanced prompt with context
|
# Step 3: Build enhanced prompt with context
|
||||||
enhanced_messages = build_enhanced_messages(messages, context, sources)
|
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)
|
||||||
|
|
||||||
# Step 3: Stream response from upstream LLM
|
# Step 4: Stream response from upstream LLM
|
||||||
created = int(time.time())
|
created = int(time.time())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -456,9 +546,12 @@ async def stream_chat_completion(
|
|||||||
else:
|
else:
|
||||||
# Mock streaming response for testing
|
# Mock streaming response for testing
|
||||||
mock_response = f"I understand you're asking about: {user_message}\n\n"
|
mock_response = f"I understand you're asking about: {user_message}\n\n"
|
||||||
|
if download_info.get("downloaded"):
|
||||||
|
mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n"
|
||||||
|
mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n"
|
||||||
if context:
|
if context:
|
||||||
mock_response += f"Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
|
mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n"
|
||||||
mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."
|
mock_response += "\n\n[Demo mode - configure ZAI_API_KEY for full LLM responses]"
|
||||||
|
|
||||||
for char in mock_response:
|
for char in mock_response:
|
||||||
yield f"data: {json.dumps({
|
yield f"data: {json.dumps({
|
||||||
@ -496,22 +589,30 @@ def build_enhanced_messages(
|
|||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
context: str,
|
context: str,
|
||||||
sources: list[str],
|
sources: list[str],
|
||||||
|
download_info: dict = None,
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
"""Build enhanced messages with RAG context."""
|
"""Build enhanced messages with RAG context."""
|
||||||
enhanced = []
|
enhanced = []
|
||||||
|
|
||||||
# Add system message with RAG context
|
# Add system message with RAG context
|
||||||
system_content = (
|
system_content = (
|
||||||
"You are a helpful AI assistant with access to a knowledge base. "
|
"You are a helpful AI assistant with the ability to access and analyze websites on-demand. "
|
||||||
"Use the provided context to give accurate and helpful responses. "
|
"When a user asks about a website, you can download and analyze its content directly. "
|
||||||
"If the context doesn't contain relevant information, use your general knowledge "
|
"Use the provided context from the knowledge base to give accurate and helpful responses. "
|
||||||
"but indicate when you're doing so."
|
"If context from a website is provided, use it to answer the user's question directly with specific information. "
|
||||||
|
"Be helpful, detailed, and provide the specific information the user is asking for (headlines, summaries, etc.)."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if download_info and download_info.get("downloaded"):
|
||||||
|
system_content += f"\n\n--- Website Access ---\n"
|
||||||
|
system_content += f"I have successfully downloaded and analyzed the website: {download_info.get('url')}\n"
|
||||||
|
system_content += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} text chunks.\n"
|
||||||
|
system_content += "The context below contains the actual content from this website. Use it to answer the user's question."
|
||||||
|
|
||||||
if context:
|
if context:
|
||||||
system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
|
system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
|
||||||
if sources:
|
if sources:
|
||||||
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources)
|
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10])
|
||||||
|
|
||||||
enhanced.append(ChatMessage(role="system", content=system_content))
|
enhanced.append(ChatMessage(role="system", content=system_content))
|
||||||
|
|
||||||
@ -527,7 +628,6 @@ async def generate_response(
|
|||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 4096,
|
max_tokens: int = 4096,
|
||||||
tools: Optional[list[dict]] = None,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate response using upstream LLM."""
|
"""Generate response using upstream LLM."""
|
||||||
if state.zai_client:
|
if state.zai_client:
|
||||||
@ -566,55 +666,6 @@ async def generate_response(
|
|||||||
return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."
|
return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."
|
||||||
|
|
||||||
|
|
||||||
async def check_and_execute_tools(
|
|
||||||
user_message: str,
|
|
||||||
messages: list[ChatMessage],
|
|
||||||
available_tools: Optional[list[dict]],
|
|
||||||
) -> list[dict]:
|
|
||||||
"""Check if tools should be used and execute them."""
|
|
||||||
if not state.tool_manager or not available_tools:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
|
|
||||||
# Simple keyword-based tool detection
|
|
||||||
# In a production system, you'd use the LLM to decide tool usage
|
|
||||||
message_lower = user_message.lower()
|
|
||||||
|
|
||||||
# Check for website download intent - use RAG system for full integration
|
|
||||||
if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site", "ingest site"]):
|
|
||||||
# Extract URL from message
|
|
||||||
import re
|
|
||||||
url_pattern = r'https?://[^\s]+'
|
|
||||||
urls = re.findall(url_pattern, user_message)
|
|
||||||
|
|
||||||
if urls and state.rag_system:
|
|
||||||
# Use RAG system's integrated website downloader
|
|
||||||
try:
|
|
||||||
result = await state.rag_system.download_and_ingest_website(
|
|
||||||
url=urls[0],
|
|
||||||
max_pages=20, # Reasonable default
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_calls.append({
|
|
||||||
"id": f"call_{uuid.uuid4().hex[:24]}",
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "website_downloader",
|
|
||||||
"arguments": json.dumps({
|
|
||||||
"url": urls[0],
|
|
||||||
"success": result.get("success"),
|
|
||||||
"chunks_ingested": result.get("total_chunks", 0),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
})
|
|
||||||
log.info(f"Downloaded and ingested website: {urls[0]} -> {result.get('total_chunks', 0)} chunks")
|
|
||||||
except Exception as e:
|
|
||||||
log.error(f"Website download failed: {e}")
|
|
||||||
|
|
||||||
return tool_calls
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# Document Management Endpoints
|
# Document Management Endpoints
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@ -702,12 +753,7 @@ async def download_website(request: WebsiteDownloadRequest):
|
|||||||
|
|
||||||
@app.post("/v1/documents/url")
|
@app.post("/v1/documents/url")
|
||||||
async def add_document_from_url(request: dict):
|
async def add_document_from_url(request: dict):
|
||||||
"""
|
"""Add a document from URL to the knowledge base."""
|
||||||
Add a document from URL to the knowledge base.
|
|
||||||
|
|
||||||
NOTE: For websites, prefer using /v1/documents/website instead
|
|
||||||
as it downloads the entire site and provides better context.
|
|
||||||
"""
|
|
||||||
if not state.rag_system:
|
if not state.rag_system:
|
||||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||||
|
|
||||||
@ -758,11 +804,9 @@ async def get_site_info(url: str):
|
|||||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# URL will be passed as path parameter, need to decode
|
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
decoded_url = unquote(url)
|
decoded_url = unquote(url)
|
||||||
|
|
||||||
# Add scheme if missing
|
|
||||||
if not decoded_url.startswith(("http://", "https://")):
|
if not decoded_url.startswith(("http://", "https://")):
|
||||||
decoded_url = "https://" + decoded_url
|
decoded_url = "https://" + decoded_url
|
||||||
|
|
||||||
@ -799,11 +843,9 @@ async def delete_site(url: str):
|
|||||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# URL will be passed as path parameter, need to decode
|
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
decoded_url = unquote(url)
|
decoded_url = unquote(url)
|
||||||
|
|
||||||
# Add scheme if missing
|
|
||||||
if not decoded_url.startswith(("http://", "https://")):
|
if not decoded_url.startswith(("http://", "https://")):
|
||||||
decoded_url = "https://" + decoded_url
|
decoded_url = "https://" + decoded_url
|
||||||
|
|
||||||
@ -848,7 +890,7 @@ async def root():
|
|||||||
return {
|
return {
|
||||||
"name": "DocRAG API",
|
"name": "DocRAG API",
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
|
"description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash. Auto-downloads and analyzes websites when users ask about them.",
|
||||||
"endpoints": {
|
"endpoints": {
|
||||||
"chat": "/v1/chat/completions",
|
"chat": "/v1/chat/completions",
|
||||||
"models": "/v1/models",
|
"models": "/v1/models",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user