docrag/main.py
Z User 6aecc4b231 Integrate website_downloader_tool into RAG system
Features:
- The RAG system now uses website_downloader_tool as its primary content-ingestion method
- download_and_ingest_website() method for complete website processing
- Stores page pointers (source_url, page_url, local_path) in the vector store
- Site registry tracks all downloaded websites with metadata
- New API endpoints for website management:
  - POST /v1/documents/website - Download and ingest a website
  - GET /v1/documents/sites - List all downloaded sites
  - GET /v1/documents/sites/{url} - Get site info
  - DELETE /v1/documents/sites/{url} - Delete a site and its content

Changes:
- rag/__init__.py: Added download_and_ingest_website(), site registry
- rag/document_processor.py: Added extract_text_from_html() public method
- rag/vector_store.py: Added delete_by_source_url(), get_stats()
- main.py: Added new website endpoints and integrated the tool with the RAG system
2026-03-29 02:36:59 +00:00

876 lines
29 KiB
Python

#!/usr/bin/env python3
"""
DocRAG - OpenAI-Compatible RAG Server
This application presents itself as a standard OpenAI API server that can be used
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
1. Processes user queries through a RAG system
2. Retrieves relevant context from a knowledge base
3. Passes the enriched context to GLM-4.7-Flash for response generation
4. Optionally uses tools like website_downloader for enhanced capabilities
The user sees a normal chat experience, but the system is actually doing
sophisticated RAG operations in the background.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sys
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Optional
# Configure the root logger once for the whole process, then obtain this
# module's own named logger.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger(name=__name__)
# FastAPI imports
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
# Import RAG components
from rag import RAGSystem, get_rag_system
from rag.document_processor import DocumentProcessor
# Import tools
from tools import ToolManager, get_tool_manager
# Import SDK for GLM-4.7-Flash.
# The SDK is optional: without it `ZAI` stays None and the server runs with
# mock responses (see `lifespan`).
ZAI = None
try:
    from zai import ZaiClient as ZAI
except ImportError:
    log.warning("z-ai-web-dev-sdk not installed. Install with: pip install z-ai-web-dev-sdk")
# =============================================================================
# Configuration
# =============================================================================
class Config:
    """Application settings, read from environment variables once at import time."""

    # --- Server ---
    HOST: str = os.environ.get("HOST", "0.0.0.0")
    PORT: int = int(os.environ.get("PORT", "8000"))
    DEBUG: bool = os.environ.get("DEBUG", "false").lower() == "true"  # only "true" (any case) enables it

    # --- Models ---
    MODEL_NAME: str = os.environ.get("MODEL_NAME", "DocRAG-GLM-4.7")  # id advertised to clients
    UPSTREAM_MODEL: str = os.environ.get("UPSTREAM_MODEL", "glm-4.7")  # id sent to the upstream LLM

    # --- Upstream LLM credentials (empty string => mock responses) ---
    ZAI_API_KEY: str = os.environ.get("ZAI_API_KEY", "")

    # --- RAG (passed to get_rag_system / rag_system.query) ---
    EMBEDDING_MODEL: str = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = os.environ.get("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = os.environ.get("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = int(os.environ.get("CHUNK_SIZE", "1000"))
    CHUNK_OVERLAP: int = int(os.environ.get("CHUNK_OVERLAP", "200"))
    TOP_K_RESULTS: int = int(os.environ.get("TOP_K_RESULTS", "5"))

    # --- Tools ---
    ENABLE_TOOLS: bool = os.environ.get("ENABLE_TOOLS", "true").lower() == "true"
    # NOTE(review): not referenced in this module; confirm whether anything reads it.
    MAX_TOOL_ITERATIONS: int = int(os.environ.get("MAX_TOOL_ITERATIONS", "3"))


config = Config()
# =============================================================================
# OpenAI-Compatible Models
# =============================================================================
class ChatMessage(BaseModel):
    """One message in the OpenAI chat format; every field except ``role`` defaults to None."""

    role: str
    # NOTE(review): typed as plain text; clients that send list-style
    # (multimodal) content would fail validation here — confirm that is intended.
    content: str | None = None
    name: str | None = None
    tool_calls: list[dict] | None = None
    tool_call_id: str | None = None
class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions`` in the OpenAI request format.

    The handlers in this module read only ``messages``, ``stream``,
    ``temperature``, ``max_tokens`` and ``tools``. The remaining fields are
    accepted so standard OpenAI clients validate, but they are not forwarded
    upstream.
    """

    model: str = "DocRAG-GLM-4.7"
    messages: list[ChatMessage]
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    n: int | None = 1
    stream: bool | None = False
    stop: list[str] | None = None
    max_tokens: int | None = None
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0
    user: str | None = None
    tools: list[dict] | None = None
    tool_choice: str | dict | None = None
class ChatCompletionChoice(BaseModel):
    """A single choice in a non-streaming completion (this server always returns index 0)."""

    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"
class ChatCompletionUsage(BaseModel):
    """Token counts reported in a completion response; all default to zero."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
class ChatCompletionResponse(BaseModel):
    """Non-streaming completion payload in the OpenAI response format.

    ``id`` and ``created`` are generated per instance. ``model`` defaults to
    the configured public model name.
    """

    id: str = Field(default_factory=lambda: "chatcmpl-" + uuid.uuid4().hex[:24])
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = config.MODEL_NAME
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage = Field(default_factory=ChatCompletionUsage)
class ModelInfo(BaseModel):
    """One entry of the OpenAI model listing; ``created`` is stamped when the object is built."""

    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "organization"
class ModelList(BaseModel):
    """Envelope returned by the model-listing endpoint (OpenAI ``list`` object)."""

    object: str = "list"
    data: list[ModelInfo]
# =============================================================================
# Application State
# =============================================================================
class AppState:
    """Holder for the process-wide components that ``lifespan`` creates.

    The attributes are class-level defaults. ``lifespan`` overwrites them on
    the shared ``state`` instance, and a ``None`` value means that component
    is unavailable.
    """

    rag_system: RAGSystem | None = None
    tool_manager: ToolManager | None = None
    zai_client: Any = None  # upstream SDK client, or None in mock/demo mode
    startup_time: float = time.time()  # evaluated once, when the module is imported


state = AppState()
# =============================================================================
# Lifespan Management
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: build shared components, then clean up.

    Startup initialises three components independently: the RAG system, the
    tool manager and the upstream ZAI client. Any failure is logged and
    leaves the matching ``state`` attribute as ``None``, so the server still
    starts; chat falls back to mock responses when there is no ZAI client.

    Shutdown runs in a ``finally`` block, so the RAG system is closed even
    when the server stops because of an exception. An error while closing is
    logged rather than allowed to mask that exception.
    """
    log.info("Starting DocRAG server...")

    # Initialize RAG system
    try:
        log.info("Initializing RAG system...")
        state.rag_system = await get_rag_system(
            embedding_model=config.EMBEDDING_MODEL,
            vector_store_path=config.VECTOR_STORE_PATH,
            documents_path=config.DOCUMENTS_PATH,
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
        )
        log.info("RAG system initialized successfully")
    except Exception as e:
        log.warning(f"RAG system initialization deferred: {e}")
        state.rag_system = None

    # Initialize tool manager
    try:
        log.info("Initializing tool manager...")
        state.tool_manager = get_tool_manager()
        log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
    except Exception as e:
        log.warning(f"Tool manager initialization failed: {e}")
        state.tool_manager = None

    # Initialize ZAI client for upstream LLM (both an API key and the SDK are required)
    try:
        if config.ZAI_API_KEY and ZAI is not None:
            log.info("Initializing ZAI client...")
            state.zai_client = ZAI(api_key=config.ZAI_API_KEY)
            log.info("ZAI client initialized successfully")
        else:
            log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses")
            state.zai_client = None
    except Exception as e:
        log.error(f"Failed to initialize ZAI client: {e}")
        state.zai_client = None

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")
    try:
        yield
    finally:
        # Cleanup must run however the serving phase ends.
        log.info("Shutting down DocRAG server...")
        if state.rag_system:
            try:
                await state.rag_system.close()
            except Exception:
                log.exception("Failed to close RAG system cleanly")
        log.info("DocRAG server stopped")
# =============================================================================
# FastAPI Application
# =============================================================================
app = FastAPI(
    lifespan=lifespan,
    title="DocRAG API",
    version="1.0.0",
    description="OpenAI-compatible RAG server powered by GLM-4.7-Flash",
)

# Fully open CORS so browser-based clients such as Open WebUI can call the API
# from any origin.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets
# any site issue credentialed requests if the middleware reflects the caller's
# Origin; consider an explicit origin list for non-local deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# =============================================================================
# OpenAI-Compatible Endpoints
# =============================================================================
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List the advertised models (OpenAI-compatible).

    ``config.MODEL_NAME`` defaults to ``"DocRAG-GLM-4.7"``, which is also the
    built-in id. The ids are therefore de-duplicated, preserving order, so
    clients do not see the same model twice.
    """
    model_ids = dict.fromkeys([config.MODEL_NAME, "DocRAG-GLM-4.7"])
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Return metadata for one advertised model (OpenAI-compatible); 404 if unknown."""
    known_ids = (config.MODEL_NAME, "DocRAG-GLM-4.7")
    if model_id in known_ids:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible), streaming or not.

    Raises:
        HTTPException: passed through unchanged when raised by the pipeline
            (e.g. 400 when there is no user message); any other error becomes
            a 500 whose detail is the error text.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    try:
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Keep deliberate client errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# =============================================================================
# Chat Completion Logic
# =============================================================================
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Run the non-streaming RAG pipeline and build an OpenAI-style response.

    The steps are:

    1. Pick the latest non-empty user message.
    2. Retrieve knowledge-base context (best effort).
    3. Build the enhanced prompt.
    4. Optionally run keyword-triggered tools.
    5. Ask the upstream LLM, or the mock, for the answer.

    Raises:
        HTTPException: 400 if the conversation has no non-empty user message.
    """
    messages = request.messages

    # Extract the last user message
    user_message = next(
        (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
        "",
    )
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    # Step 1: RAG retrieval — a failure degrades to "no context" instead of failing the request
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 2: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources)

    # Step 3: Check for tool usage.
    # NOTE(review): retrieval above runs before any tool ingests new content, and
    # tool results are not added to the prompt — confirm this ordering is intended.
    tool_calls_made = []
    if config.ENABLE_TOOLS and state.tool_manager:
        tool_calls_made = await check_and_execute_tools(
            user_message, enhanced_messages, request.tools
        )

    # Step 4: Generate response. The request fields are Optional; fall back to
    # generate_response's own defaults (0.7 / 4096) instead of forwarding None.
    # `is None` (not truthiness) keeps an explicit temperature of 0.
    temperature = 0.7 if request.temperature is None else request.temperature
    max_tokens = 4096 if request.max_tokens is None else request.max_tokens
    response_content = await generate_response(
        enhanced_messages,
        temperature=temperature,
        max_tokens=max_tokens,
        tools=request.tools,
    )

    # Step 5: Build the response. Token counts are rough (~4 characters per token).
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4
    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                    tool_calls=tool_calls_made or None,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion as OpenAI-style Server-Sent Events.

    This runs the same retrieval pipeline as ``complete_chat``. It then
    yields SSE ``data:`` lines in this order:

    - ``chat.completion.chunk`` objects carrying the content;
    - a final chunk with ``finish_reason="stop"``;
    - a ``[DONE]`` terminator.

    Model "thinking" (``reasoning_content``) is only logged at debug level and
    is never sent to the client. Errors are reported in-band as a
    ``{"error": ...}`` event.
    """
    messages = request.messages

    # Extract the last user message
    user_message = next(
        (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
        "",
    )
    if not user_message:
        yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
        return

    # Step 1: RAG retrieval — a failure degrades to "no context"
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 2: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources)

    # Step 3: Stream the response
    created = int(time.time())

    def sse_chunk(delta: dict, finish_reason: Optional[str] = None) -> str:
        """Serialize one ``chat.completion.chunk`` as an SSE ``data:`` line."""
        payload = {
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": config.MODEL_NAME,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(payload)}\n\n"

    try:
        if state.zai_client:
            # temperature=0 is a valid request, so test for None, not truthiness.
            temperature = 0.7 if request.temperature is None else request.temperature
            max_tokens = 4096 if request.max_tokens is None else request.max_tokens

            # The SDK call and the chunk iterator it returns are synchronous; run
            # both in worker threads so a long generation does not block the
            # event loop (and every other request) while waiting on the network.
            response = await asyncio.to_thread(
                state.zai_client.chat.completions.create,
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                thinking={"type": "enabled"},
            )
            chunk_iter = iter(response)
            exhausted = object()
            while (chunk := await asyncio.to_thread(next, chunk_iter, exhausted)) is not exhausted:
                choices = getattr(chunk, "choices", None)
                if not choices:
                    # Skip chunks without choices instead of raising IndexError.
                    continue
                delta = choices[0].delta
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning:
                    # Don't expose thinking to the user
                    log.debug(f"Thinking: {reasoning[:100]}...")
                    continue
                content = getattr(delta, "content", None)
                if content:
                    yield sse_chunk({"content": content})
            yield sse_chunk({}, "stop")
            yield "data: [DONE]\n\n"
        else:
            # Mock streaming response (no upstream client configured)
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if context:
                mock_response += "Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
            mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."
            for char in mock_response:
                yield sse_chunk({"content": char})
                await asyncio.sleep(0.01)
            yield sse_chunk({}, "stop")
            yield "data: [DONE]\n\n"
    except Exception as e:
        log.exception("Streaming failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
) -> list[ChatMessage]:
    """Return the conversation with a single RAG system prompt in front.

    Any system messages in ``messages`` are dropped. They are replaced by one
    system message that holds the base instructions plus, when non-empty, the
    retrieved ``context`` and a bulleted ``sources`` list.
    """
    prompt_parts = [
        "You are a helpful AI assistant with access to a knowledge base. "
        "Use the provided context to give accurate and helpful responses. "
        "If the context doesn't contain relevant information, use your general knowledge "
        "but indicate when you're doing so."
    ]
    if context:
        prompt_parts.append(f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n")
    if sources:
        bullet_list = "\n".join(f"- {source}" for source in sources)
        prompt_parts.append("\n--- Sources ---\n" + bullet_list)

    system_message = ChatMessage(role="system", content="".join(prompt_parts))
    history = [msg for msg in messages if msg.role != "system"]
    return [system_message, *history]
def _extract_completion_text(response: Any) -> str:
    """Return the assistant text of an upstream completion, or "" if none is found.

    A regular (non-streaming) completion exposes ``choices[0].message``.
    Defensively, an iterable of stream chunks exposing ``choices[0].delta``
    is also accepted.
    """
    choices = getattr(response, "choices", None)
    if choices:
        message = getattr(choices[0], "message", None)
        if message is not None:
            return message.content or ""
    # Fallback: treat the response as a chunk stream. A non-iterable object
    # raises here and is handled by the caller.
    parts: list[str] = []
    for chunk in response:
        chunk_choices = getattr(chunk, "choices", None)
        if not chunk_choices:
            continue
        first = chunk_choices[0]
        message = getattr(first, "message", None)
        if message:
            return message.content or ""
        text = getattr(getattr(first, "delta", None), "content", None)
        if text:
            parts.append(text)
    return "".join(parts)


async def generate_response(
    messages: list[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 4096,
    tools: Optional[list[dict]] = None,
) -> str:
    """Generate the assistant reply for ``messages`` with the upstream LLM.

    Args:
        messages: Prompt messages; those without content are not sent.
        temperature: Sampling temperature. ``None`` falls back to 0.7.
        max_tokens: Completion token limit. ``None`` falls back to 4096.
        tools: Accepted for interface compatibility; not forwarded upstream.

    Returns:
        The reply text. Without an upstream client a demo-mode string is
        returned. If the upstream call fails, an error string is returned,
        so this function does not raise for upstream errors.
    """
    if not state.zai_client:
        # Mock response for testing
        user_msg = next(
            (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
            "",
        )
        return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."

    try:
        # The SDK call is synchronous; run it in a worker thread so the event
        # loop keeps serving other requests while the model generates.
        response = await asyncio.to_thread(
            state.zai_client.chat.completions.create,
            model=config.UPSTREAM_MODEL,
            messages=[{"role": m.role, "content": m.content} for m in messages if m.content],
            temperature=0.7 if temperature is None else temperature,
            max_tokens=4096 if max_tokens is None else max_tokens,
            stream=False,
            thinking={"type": "enabled"},
        )
        # With stream=False the response is a single completion object, not a
        # chunk stream, so read choices[0].message instead of iterating it.
        content = _extract_completion_text(response)
        return content or "I apologize, but I couldn't generate a response."
    except Exception as e:
        log.error(f"Upstream LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
def _strip_trailing_punctuation(url: str) -> str:
    """Remove sentence punctuation and unmatched closing brackets glued to the end of *url*."""
    closers = {")": "(", "]": "[", "}": "{", ">": "<"}
    while url:
        last = url[-1]
        if last in ".,;:!?'\"":
            url = url[:-1]
        elif last in closers and url.count(last) > url.count(closers[last]):
            # Only an *unbalanced* bracket is stripped, so "…/Foo_(bar)" survives.
            url = url[:-1]
        else:
            break
    return url


async def check_and_execute_tools(
    user_message: str,
    messages: list[ChatMessage],
    available_tools: Optional[list[dict]],
) -> list[dict]:
    """Run keyword-triggered tools and describe them as OpenAI-style tool calls.

    Currently only website ingestion is supported. When the message contains a
    trigger phrase and a URL, the first URL is downloaded and ingested through
    ``state.rag_system.download_and_ingest_website``. Failures are logged and
    skipped (best effort).

    Args:
        user_message: Latest user message, scanned for trigger phrases and URLs.
        messages: Prompt messages; currently unused.
        available_tools: Tools offered by the client. Nothing runs if it is empty.

    Returns:
        One ``{"id", "type", "function"}`` dict per tool that ran.
    """
    import re

    if not state.tool_manager or not available_tools:
        return []

    tool_calls = []
    # Simple keyword-based tool detection.
    # In a production system, you'd use the LLM to decide tool usage.
    message_lower = user_message.lower()
    trigger_phrases = ("download website", "mirror site", "crawl", "archive site", "ingest site")
    if not any(phrase in message_lower for phrase in trigger_phrases):
        return tool_calls

    # Extract URLs; drop punctuation that belongs to the sentence, not the URL
    # ("see https://example.com." or "(https://example.com)").
    candidates = (_strip_trailing_punctuation(u) for u in re.findall(r"https?://\S+", user_message))
    urls = [u for u in candidates if u.split("://", 1)[1]]

    if urls and state.rag_system:
        # Use RAG system's integrated website downloader
        try:
            result = await state.rag_system.download_and_ingest_website(
                url=urls[0],
                max_pages=20,  # Reasonable default
            )
            tool_calls.append({
                "id": f"call_{uuid.uuid4().hex[:24]}",
                "type": "function",
                "function": {
                    "name": "website_downloader",
                    "arguments": json.dumps({
                        "url": urls[0],
                        "success": result.get("success"),
                        "chunks_ingested": result.get("total_chunks", 0),
                    }),
                },
            })
            log.info(f"Downloaded and ingested website: {urls[0]} -> {result.get('total_chunks', 0)} chunks")
        except Exception as e:
            log.error(f"Website download failed: {e}")
    return tool_calls
# =============================================================================
# Document Management Endpoints
# =============================================================================
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Add one uploaded file (multipart form field ``file``) to the knowledge base.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 400 if no file
            was uploaded, 500 if reading or ingesting the file fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        form = await request.form()
        file = form.get("file")
        # A missing part and a plain text field (str) both mean no file was sent;
        # the latter would otherwise fail on `.read()` and surface as a 500.
        if not file or isinstance(file, str):
            raise HTTPException(status_code=400, detail="No file provided")
        content = await file.read()
        filename = file.filename or "unknown"
        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )
        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}
    except HTTPException:
        # Keep the 400 above instead of rewrapping it as a 500.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
class WebsiteDownloadRequest(BaseModel):
    """JSON body for ``POST /v1/documents/website``.

    Every field is forwarded unchanged to
    ``RAGSystem.download_and_ingest_website``.
    """

    url: str  # site to download
    max_pages: int = 50  # page limit handed to the downloader
    threads: int = 6  # worker count handed to the downloader
    download_external_assets: bool = False
    # presumably an allow-list of extra domains for assets — confirm in rag/
    external_domains: list[str] | None = None
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """
    Download a website and ingest it into the knowledge base.

    This is the PRIMARY way to add content to the RAG system; the download
    itself is delegated to ``RAGSystem.download_and_ingest_website``.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 500 if the
            download reports failure or raises.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        log.info(f"Downloading website: {request.url}")
        outcome = await state.rag_system.download_and_ingest_website(
            url=request.url,
            max_pages=request.max_pages,
            threads=request.threads,
            download_external_assets=request.download_external_assets,
            external_domains=request.external_domains,
        )
        if not outcome.get("success"):
            raise HTTPException(
                status_code=500,
                detail=outcome.get("message", "Website download failed"),
            )
        return {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": outcome.get("local_path"),
            "pages_processed": outcome.get("pages_processed", 0),
            "total_chunks": outcome.get("total_chunks", 0),
        }
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """
    Add a single document fetched from ``request["url"]`` to the knowledge base.

    NOTE: For websites, prefer /v1/documents/website instead, because it
    downloads the entire site and provides better context.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 400 if ``url`` is
            missing, not a string, or blank; 500 if fetching or ingesting fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="No URL provided")
    # The body is untyped JSON; reject non-strings and whitespace here rather
    # than letting them fail downstream as a 500.
    if not isinstance(url, str) or not url.strip():
        raise HTTPException(status_code=400, detail="URL must be a non-empty string")
    url = url.strip()
    try:
        result = await state.rag_system.add_document_from_url(url)
        return {"success": True, "message": f"Document from {url} added", "chunks": result.get("chunks", 0)}
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/v1/documents")
async def list_documents():
    """Return ``{"documents": [...]}`` from the RAG system (503 if it is unavailable)."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        documents = await state.rag_system.list_documents()
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))
    return {"documents": documents}
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return ``{"sites": [...]}`` for every downloaded website (503 if RAG is unavailable)."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        downloaded = await state.rag_system.list_downloaded_sites()
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))
    return {"sites": downloaded}
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Return ``{"site": ...}`` for one downloaded website.

    The site URL is the rest of the path. It is percent-decoded and given an
    ``https://`` scheme when it has none.

    Raises:
        HTTPException: 503 if RAG is unavailable, 404 if the site is unknown,
            500 on unexpected errors.
    """
    from urllib.parse import unquote

    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        # NOTE(review): if the framework already decodes the path, this is a
        # second decode — confirm for site URLs that contain a literal '%'.
        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"
        # Called without await, unlike delete_site — presumably synchronous; verify in rag/.
        site_info = state.rag_system.get_site_info(site_url)
        if not site_info:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": site_info}
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Delete one document by id (503 if RAG is unavailable, 500 if deletion raises)."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        await state.rag_system.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "message": f"Document {doc_id} deleted"}
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Remove a downloaded website and everything ingested from it.

    The site URL is the rest of the path. It is percent-decoded and given an
    ``https://`` scheme when it has none.

    Raises:
        HTTPException: 503 if RAG is unavailable; 404 when ``delete_site``
            reports failure; 500 on unexpected errors.
    """
    from urllib.parse import unquote

    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"
        outcome = await state.rag_system.delete_site(site_url)
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))
        return {
            "success": True,
            "message": f"Site {site_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Health and Status Endpoints
# =============================================================================
@app.get("/health")
async def health_check():
    """Report liveness, seconds since ``state.startup_time``, and which components are available."""
    uptime_seconds = time.time() - state.startup_time
    return {
        "status": "healthy",
        "uptime": uptime_seconds,
        "rag_enabled": state.rag_system is not None,
        "tools_enabled": state.tool_manager is not None,
        "llm_connected": state.zai_client is not None,
    }
@app.get("/")
async def root():
    """Describe the API: name, version and the main endpoint paths."""
    endpoints = dict(
        chat="/v1/chat/completions",
        models="/v1/models",
        documents="/v1/documents",
        download_website="/v1/documents/website",
        list_sites="/v1/documents/sites",
        health="/health",
    )
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
        "endpoints": endpoints,
    }
# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
    import uvicorn

    # The app is referenced by the import string "main:app", so this assumes
    # the file is importable as module `main`. DEBUG enables auto-reload.
    uvicorn.run("main:app", host=config.HOST, port=config.PORT, reload=config.DEBUG)