Key changes: - Add URL extraction and access-intent detection functions. - Download websites BEFORE RAG retrieval (not after). - Expand the trigger keywords to include common phrases such as 'go to' and 'headlines'. - Update the system prompt to tell the LLM that it CAN access websites. - Improve streaming response handling. Now, when a user asks 'go to orovillemr.com and give me the headlines': 1. the system detects the URL and the access intent; 2. it downloads and ingests the website content; 3. RAG retrieves the relevant content; 4. the LLM generates a response using the actual website content.
918 lines
32 KiB
Python
918 lines
32 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
DocRAG - OpenAI-Compatible RAG Server
|
|
|
|
This application presents itself as a standard OpenAI API server that can be used
|
|
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
|
1. Detects URLs in user messages and auto-downloads websites
|
|
2. Ingests website content into the RAG knowledge base
|
|
3. Retrieves relevant context from the knowledge base
|
|
4. Passes the enriched context to GLM-4.7-Flash for response generation
|
|
|
|
The user sees a normal chat experience, but the system is actually doing
|
|
sophisticated RAG operations in the background.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncIterator, Optional
|
|
|
|
# Configure logging: one line per record -> "time | LEVEL | logger | message".
_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt="%H:%M:%S")

log = logging.getLogger(__name__)
|
|
|
|
# FastAPI imports
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Import RAG components
|
|
from rag import RAGSystem, get_rag_system
|
|
from rag.document_processor import DocumentProcessor
|
|
|
|
# Import tools
|
|
from tools import ToolManager, get_tool_manager
|
|
|
|
# Import SDK for GLM-4.7-Flash
|
|
try:
|
|
from zai import ZaiClient as ZAI
|
|
except ImportError:
|
|
ZAI = None
|
|
log.warning("z-ai-web-dev-sdk not installed. Install with: pip install z-ai-web-dev-sdk")
|
|
|
|
|
|
# =============================================================================
|
|
# Configuration
|
|
# =============================================================================
|
|
|
|
def _env_str(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset."""
    return os.getenv(name, default)


def _env_int(name: str, default: str) -> int:
    """Return *name* parsed as an int (ValueError on a non-numeric value)."""
    return int(_env_str(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when *name* equals "true", compared case-insensitively."""
    return _env_str(name, default).lower() == "true"


class Config:
    """Settings resolved from environment variables when this class is defined."""

    # --- Server ------------------------------------------------------------
    HOST: str = _env_str("HOST", "0.0.0.0")
    PORT: int = _env_int("PORT", "8000")
    DEBUG: bool = _env_flag("DEBUG", "false")

    # --- Model names (public alias vs. upstream model id) ------------------
    MODEL_NAME: str = _env_str("MODEL_NAME", "DocRAG-GLM-4.7")
    UPSTREAM_MODEL: str = _env_str("UPSTREAM_MODEL", "glm-4.7")

    # --- Upstream LLM credentials (empty string -> mock responses) ---------
    ZAI_API_KEY: str = _env_str("ZAI_API_KEY", "")

    # --- Retrieval-augmented generation ------------------------------------
    EMBEDDING_MODEL: str = _env_str("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = _env_str("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = _env_str("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = _env_int("CHUNK_SIZE", "1000")
    CHUNK_OVERLAP: int = _env_int("CHUNK_OVERLAP", "200")
    TOP_K_RESULTS: int = _env_int("TOP_K_RESULTS", "5")

    # --- Tools -------------------------------------------------------------
    ENABLE_TOOLS: bool = _env_flag("ENABLE_TOOLS", "true")
    MAX_TOOL_ITERATIONS: int = _env_int("MAX_TOOL_ITERATIONS", "3")


config = Config()
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Models
|
|
# =============================================================================
|
|
|
|
class ChatMessage(BaseModel):
    """One conversation turn in the OpenAI chat schema."""

    role: str  # e.g. "system", "user", "assistant"
    content: str | None = None
    name: str | None = None
    tool_calls: list[dict] | None = None  # passed through without inspection
    tool_call_id: str | None = None
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request format.

    In this module only ``messages``, ``stream``, ``temperature`` and
    ``max_tokens`` are read; the remaining fields are accepted so that
    standard OpenAI clients pass validation.
    """

    model: str = "DocRAG-GLM-4.7"
    messages: list[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    # The OpenAI API allows either a single stop string or a list of them;
    # accepting only a list made requests such as {"stop": "\n"} fail with 422.
    stop: Optional[str | list[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    tools: Optional[list[dict]] = None
    tool_choice: Optional[str | dict] = None
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """A single generated alternative inside a chat completion response."""

    index: int = 0  # position among the returned alternatives
    message: ChatMessage
    finish_reason: str = "stop"  # why generation ended
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
    """Token accounting attached to a chat completion (all counters default to 0)."""

    prompt_tokens: int = 0  # tokens sent to the model
    completion_tokens: int = 0  # tokens produced by the model
    total_tokens: int = 0  # not derived automatically; callers must fill it in
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion payload in the OpenAI schema."""

    # A fresh "chatcmpl-<24 hex chars>" id is generated per instance.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
    object: str = "chat.completion"
    # Unix timestamp (seconds), taken when the response object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = config.MODEL_NAME  # read once, when the class is defined
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage = ChatCompletionUsage()
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Entry describing one model, as returned by the OpenAI models endpoints."""

    id: str
    object: str = "model"
    # Unix timestamp (seconds) taken when the entry is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "organization"
|
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope returned by the OpenAI "list models" endpoint."""

    object: str = "list"
    data: list[ModelInfo]  # the advertised models, in order
|
|
|
|
|
|
# =============================================================================
|
|
# Application State
|
|
# =============================================================================
|
|
|
|
class AppState:
    """Process-wide holder for the services that ``lifespan`` sets up.

    Every service attribute is declared at class level with a ``None``
    default, so it reads as "unavailable" until startup assigns it on the
    shared ``state`` instance below.
    """

    rag_system: RAGSystem | None = None  # None if RAG init failed or was deferred
    tool_manager: ToolManager | None = None
    zai_client: Any = None  # upstream SDK client; None means mock responses
    # Evaluated once, when this class body runs (i.e. at module import).
    startup_time: float = time.time()


state = AppState()
|
|
|
|
|
|
# =============================================================================
|
|
# Lifespan Management
|
|
# =============================================================================
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan: initialize services on startup, clean up on shutdown.

    Each service is initialized independently. A failure is logged and leaves
    the corresponding ``state`` attribute as ``None``, so the server still starts
    in a degraded mode (no RAG, no tools, or mock LLM responses).

    The tool manager is only created when ``config.ENABLE_TOOLS`` is true.
    Cleanup runs in ``finally`` so it also happens when the server is torn down
    by an exception. A failing ``close()`` is logged rather than propagated.
    """
    log.info("Starting DocRAG server...")

    # Initialize RAG system
    try:
        log.info("Initializing RAG system...")
        state.rag_system = await get_rag_system(
            embedding_model=config.EMBEDDING_MODEL,
            vector_store_path=config.VECTOR_STORE_PATH,
            documents_path=config.DOCUMENTS_PATH,
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
        )
        log.info("RAG system initialized successfully")
    except Exception as e:
        log.warning(f"RAG system initialization deferred: {e}")
        state.rag_system = None

    # Initialize tool manager (previously created even when ENABLE_TOOLS=false)
    if config.ENABLE_TOOLS:
        try:
            log.info("Initializing tool manager...")
            state.tool_manager = get_tool_manager()
            log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
        except Exception as e:
            log.warning(f"Tool manager initialization failed: {e}")
            state.tool_manager = None
    else:
        log.info("Tools disabled (ENABLE_TOOLS=false)")
        state.tool_manager = None

    # Initialize ZAI client for upstream LLM
    try:
        if config.ZAI_API_KEY and ZAI is not None:
            log.info("Initializing ZAI client...")
            state.zai_client = ZAI(api_key=config.ZAI_API_KEY)
            log.info("ZAI client initialized successfully")
        else:
            log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses")
            state.zai_client = None
    except Exception as e:
        log.error(f"Failed to initialize ZAI client: {e}")
        state.zai_client = None

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")

    try:
        yield
    finally:
        # Cleanup
        log.info("Shutting down DocRAG server...")
        if state.rag_system:
            try:
                await state.rag_system.close()
            except Exception:
                log.exception("Error while closing RAG system")
        log.info("DocRAG server stopped")
|
|
|
|
|
|
# =============================================================================
|
|
# FastAPI Application
|
|
# =============================================================================
|
|
|
|
# Cross-origin settings so browser-based clients such as Open WebUI can call the API.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the server publicly.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(
    title="DocRAG API",
    description="OpenAI-compatible RAG server powered by GLM-4.7-Flash",
    version="1.0.0",
    lifespan=lifespan,
)
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List available models (OpenAI-compatible).

    Advertises the configured ``MODEL_NAME`` and the built-in
    "DocRAG-GLM-4.7" alias, each exactly once.
    """
    # dict.fromkeys keeps order while de-duplicating. MODEL_NAME defaults to
    # "DocRAG-GLM-4.7", which previously made the same id appear twice.
    model_ids = dict.fromkeys((config.MODEL_NAME, "DocRAG-GLM-4.7"))
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
|
|
|
|
|
|
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Describe one model (OpenAI-compatible); unknown ids yield HTTP 404."""
    known_ids = {config.MODEL_NAME, "DocRAG-GLM-4.7"}
    if model_id in known_ids:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Returns an SSE ``StreamingResponse`` when ``request.stream`` is true,
    otherwise a single ``ChatCompletionResponse``.

    Raises:
        HTTPException: any HTTPException raised downstream (for example 400 when
            the conversation has no user message) is re-raised unchanged; every
            other error becomes a 500.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        if request.stream:
            # Errors inside the generator happen after this handler returns
            # and are reported in-band by stream_chat_completion.
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Keep deliberate client errors (e.g. 400) instead of masking them as 500.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# =============================================================================
|
|
# URL Detection and Website Download
|
|
# =============================================================================
|
|
|
|
# Explicit http(s) URLs, up to whitespace or characters that cannot appear in a URL.
_FULL_URL_RE = re.compile(r'https?://[^\s<>"{}|\\^`\[\]]+')

# Bare domains such as "example.com", "www.example.com" or "news.bbc.co.uk/world".
# The leading boundary excludes "@" (e-mail addresses) and "." (the tail of a
# longer host name). An optional "www." is matched but not captured.
_BARE_DOMAIN_RE = re.compile(
    r'(?:^|[^\w/@.-])'
    r'(?:www\.)?'
    r'((?:[a-zA-Z0-9][-a-zA-Z0-9]*\.)+[a-zA-Z]{2,})'  # host: labels + alphabetic TLD
    r'(/[^\s]*)?'  # optional path
    r'(?=[^\w/-]|$)'
)

# Sentence punctuation that greedy URL patterns swallow at the end of a URL.
_URL_TRAILING_PUNCTUATION = ".,;:!?'\""


def _strip_trailing_punctuation(url: str) -> str:
    """Drop trailing sentence punctuation, and any ')' left unbalanced, from *url*."""
    while url:
        last = url[-1]
        if last in _URL_TRAILING_PUNCTUATION:
            url = url[:-1]
        elif last == ")" and url.count(")") > url.count("("):
            # Keeps balanced parentheses, e.g. ".../wiki/Foo_(bar)".
            url = url[:-1]
        else:
            break
    return url


def extract_urls_from_message(message: str) -> list[str]:
    """Extract URLs from a message, including domains written without a scheme.

    Explicit ``http(s)://`` URLs come first, in order of appearance. They are
    then removed from the text so their host parts are not found again as
    bare domains. Bare domains (host length > 4) are normalized to
    ``https://host/path`` with a leading ``www.`` dropped.
    Trailing sentence punctuation is stripped and duplicates are removed while
    preserving first-seen order.

    Args:
        message: Free-form user text; an empty string yields an empty list.

    Returns:
        The extracted URLs.
    """
    urls: list[str] = []
    if not message:
        return urls

    def _add(url: str) -> None:
        if url and url not in urls:
            urls.append(url)

    for raw in _FULL_URL_RE.findall(message):
        _add(_strip_trailing_punctuation(raw))

    # Blank out explicit URLs first: otherwise "https://www.example.com" also
    # yielded a separate "https://example.com" entry.
    remainder = _FULL_URL_RE.sub(" ", message)
    for host, path in _BARE_DOMAIN_RE.findall(remainder):
        # Very short matches are more likely abbreviations than real sites.
        if len(host) <= 4:
            continue
        _add(_strip_trailing_punctuation(f"https://{host}{path}"))

    return urls
|
|
|
|
|
|
def should_download_website(message: str, urls: list[str]) -> bool:
    """Decide whether *message* asks for content from one of *urls*.

    Returns False when *urls* is empty. Otherwise returns True if the
    lowercased message contains any access phrase, a "?", or a question word.
    Matching is plain substring search, so short phrases also hit inside
    longer words (for example "get" inside "forget").
    """
    if not urls:
        return False

    lowered = message.lower()

    intent_phrases = (
        "go to", "visit", "check", "look at", "browse", "open",
        "what is on", "tell me about", "give me", "show me", "get",
        "headlines", "content", "information from", "from the", "from",
        "on the website", "on the site", "the website", "the site", "website",
        "summarize", "extract", "read", "analyze", "find", "search",
        "what does", "what's on", "what is", "tell me", "about",
        "news", "articles", "posts", "page", "pages",
    )
    if any(phrase in lowered for phrase in intent_phrases):
        return True

    # A question mentioning the URL also counts as access intent.
    if "?" in message:
        return True
    return any(word in lowered for word in ("what", "how", "who", "where", "when", "why"))
|
|
|
|
|
|
async def download_website_if_needed(user_message: str, max_pages: int = 30) -> dict[str, Any]:
    """
    Download website if user is asking about one.

    URLs are tried in order and the first successful download wins.

    Args:
        user_message: The latest user message to scan for URLs.
        max_pages: Crawl limit passed to ``download_and_ingest_website``
            (default 30).

    Returns:
        ``{"downloaded": True, "url", "chunks", "pages", "local_path"}`` on
        success, otherwise ``{"downloaded": False, "reason": ...}``.
    """
    urls = extract_urls_from_message(user_message)

    if not should_download_website(user_message, urls):
        return {"downloaded": False, "reason": "No website access intent detected"}

    if not state.rag_system:
        return {"downloaded": False, "reason": "RAG system not initialized"}

    for url in urls:
        try:
            log.info(f"Auto-downloading website: {url}")
            result = await state.rag_system.download_and_ingest_website(
                url=url,
                max_pages=max_pages,
            )

            if result.get("success"):
                log.info(f"Successfully downloaded {url}: {result.get('total_chunks', 0)} chunks")
                return {
                    "downloaded": True,
                    "url": url,
                    "chunks": result.get("total_chunks", 0),
                    "pages": result.get("pages_processed", 0),
                    "local_path": result.get("local_path"),
                }

            # A non-success result used to be dropped silently; record why.
            log.warning(f"Download of {url} did not succeed: {result.get('message', 'no details')}")
        except Exception as e:
            log.warning(f"Failed to download {url}: {e}")

    return {"downloaded": False, "reason": "All download attempts failed"}
|
|
|
|
|
|
# =============================================================================
|
|
# Chat Completion Logic
|
|
# =============================================================================
|
|
|
|
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Process a non-streaming chat completion request.

    Pipeline:
    1. optionally download a website mentioned in the latest user message;
    2. retrieve RAG context;
    3. build the enriched prompt;
    4. call the upstream LLM;
    5. wrap the result in an OpenAI response.

    Raises:
        HTTPException: 400 when the conversation has no non-empty user message.
    """
    messages = request.messages

    # The most recent non-empty user turn drives website detection and retrieval.
    user_message = next(
        (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
        "",
    )
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    # Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG Retrieval (now includes newly downloaded content)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)

    # Step 4: Generate response with upstream LLM.
    # Both fields are Optional in the request; forwarding None upstream
    # (max_tokens defaults to None) overrode generate_response's defaults.
    # An explicit temperature of 0 is valid, so only None falls back.
    temperature = 0.7 if request.temperature is None else request.temperature
    max_tokens = request.max_tokens or 4096
    response_content = await generate_response(
        enhanced_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Step 5: Build and return response.
    # Token counts are a rough ~4 characters/token estimate over message text
    # (not the repr of the message objects); the total was previously left at 0.
    prompt_tokens = sum(len(msg.content or "") for msg in enhanced_messages) // 4
    completion_tokens = len(response_content) // 4

    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
|
|
|
|
|
|
def _sse_chunk(
    request_id: str,
    created: int,
    delta: dict,
    finish_reason: Optional[str] = None,
) -> str:
    """Encode one OpenAI ``chat.completion.chunk`` as a Server-Sent Events frame.

    The payload is built as a plain dict. Multi-line expressions inside an
    f-string replacement field only parse on Python 3.12+.
    """
    payload = {
        "id": request_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": config.MODEL_NAME,
        "choices": [
            {"index": 0, "delta": delta, "finish_reason": finish_reason},
        ],
    }
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion response as OpenAI-style Server-Sent Events.

    Runs the same pre-processing as the non-streaming path (website
    auto-download, then RAG retrieval) before emitting any content. It then
    yields content frames, a final ``finish_reason="stop"`` frame and
    ``data: [DONE]``. Failures are reported in-band as ``{"error": ...}``
    frames because the HTTP status has already been sent.
    """
    messages = request.messages

    # Extract the last user message
    user_message = ""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            user_message = msg.content
            break

    if not user_message:
        yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
        return

    # Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG Retrieval (now includes newly downloaded content)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)

    # Step 4: Stream response from upstream LLM
    created = int(time.time())

    try:
        if state.zai_client:
            # An explicit temperature of 0 is valid; `or 0.7` used to replace it.
            temperature = 0.7 if request.temperature is None else request.temperature

            # The SDK call and its chunk iterator are synchronous. Run them in a
            # worker thread so a long stream does not block the event loop (and
            # every other request) for its whole duration.
            response = await asyncio.to_thread(
                state.zai_client.chat.completions.create,
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                temperature=temperature,
                max_tokens=request.max_tokens or 4096,
                stream=True,
                thinking={"type": "enabled"},
            )

            chunk_iter = iter(response)
            end_of_stream = object()  # sentinel: StopIteration cannot cross to_thread
            while True:
                chunk = await asyncio.to_thread(next, chunk_iter, end_of_stream)
                if chunk is end_of_stream:
                    break
                # Skip chunks that carry no choices instead of failing with IndexError.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta

                # Handle reasoning content (thinking)
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning:
                    # Don't expose thinking to user - this is internal RAG processing
                    log.debug(f"Thinking: {reasoning[:100]}...")
                    continue

                # Stream actual content
                content = getattr(delta, "content", None)
                if content:
                    yield _sse_chunk(request_id, created, {"content": content})

            # Send final chunk
            yield _sse_chunk(request_id, created, {}, "stop")
            yield "data: [DONE]\n\n"

        else:
            # Mock streaming response for testing
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if download_info.get("downloaded"):
                mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n"
                mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n"
            if context:
                mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n"
            mock_response += "\n\n[Demo mode - configure ZAI_API_KEY for full LLM responses]"

            for char in mock_response:
                yield _sse_chunk(request_id, created, {"content": char})
                await asyncio.sleep(0.01)

            yield _sse_chunk(request_id, created, {}, "stop")
            yield "data: [DONE]\n\n"

    except Exception as e:
        log.exception("Streaming failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
|
|
|
|
|
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
    download_info: Optional[dict] = None,
) -> list[ChatMessage]:
    """Build enhanced messages with RAG context.

    Args:
        messages: Incoming conversation. Client-supplied ``system`` messages are
            dropped; all other messages are kept in their original order.
        context: Retrieved knowledge-base text ("" when nothing was retrieved).
        sources: Labels for the retrieved context; at most the first 10 are
            listed, and only when *context* is non-empty.
        download_info: Result of ``download_website_if_needed``. When its
            ``"downloaded"`` flag is truthy, the prompt names the site and
            its page/chunk counts.

    Returns:
        A new list: one system message followed by the non-system messages.
    """
    system_content = (
        "You are a helpful AI assistant with the ability to access and analyze websites on-demand. "
        "When a user asks about a website, you can download and analyze its content directly. "
        "Use the provided context from the knowledge base to give accurate and helpful responses. "
        "If context from a website is provided, use it to answer the user's question directly with specific information. "
        "Be helpful, detailed, and provide the specific information the user is asking for (headlines, summaries, etc.)."
    )

    if download_info and download_info.get("downloaded"):
        system_content += "\n\n--- Website Access ---\n"
        system_content += f"I have successfully downloaded and analyzed the website: {download_info.get('url')}\n"
        system_content += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} text chunks.\n"
        if context:
            system_content += "The context below contains the actual content from this website. Use it to answer the user's question."
        else:
            # Claiming "the context below" when none follows invites the model
            # to invent details about the site.
            system_content += (
                "However, no relevant content from this website was retrieved for this question; "
                "say so instead of guessing."
            )

    if context:
        system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
        if sources:
            system_content += "\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10])

    # Add conversation history (excluding old system messages)
    history = [msg for msg in messages if msg.role != "system"]
    return [ChatMessage(role="system", content=system_content), *history]
|
|
|
|
|
|
def _extract_completion_text(response: Any) -> str:
    """Return the assistant text from an SDK chat-completion result.

    Supports two shapes:
    - a regular (``stream=False``) completion object exposing
      ``choices[0].message.content``;
    - as a fallback, an iterable of streaming chunks exposing
      ``choices[0].delta.content``.

    Returns "" when no text is present.
    """
    choices = getattr(response, "choices", None)
    if choices is not None:
        if not choices:
            return ""
        message = getattr(choices[0], "message", None)
        return getattr(message, "content", None) or ""

    parts = []
    for chunk in response:
        chunk_choices = getattr(chunk, "choices", None)
        if not chunk_choices:
            continue
        text = getattr(getattr(chunk_choices[0], "delta", None), "content", None)
        if text:
            parts.append(text)
    return "".join(parts)


async def generate_response(
    messages: list[ChatMessage],
    temperature: Optional[float] = 0.7,
    max_tokens: Optional[int] = 4096,
) -> str:
    """Generate response using upstream LLM.

    Args:
        messages: Prompt messages; those with empty content are not sent.
        temperature: Sampling temperature; None falls back to 0.7.
        max_tokens: Completion budget; None (or 0) falls back to 4096.

    Returns:
        The assistant text, a fallback apology when the model returned no text,
        an "I encountered an error: ..." message when the upstream call fails,
        or a demo-mode string when no ZAI client is configured.
    """
    if state.zai_client:
        try:
            # The SDK is synchronous; run it in a worker thread so the event loop
            # keeps serving other requests while waiting for the model.
            response = await asyncio.to_thread(
                state.zai_client.chat.completions.create,
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in messages if m.content],
                temperature=0.7 if temperature is None else temperature,
                max_tokens=max_tokens or 4096,
                stream=False,
                thinking={"type": "enabled"},
            )

            # With stream=False the SDK returns one completion object. The old
            # code iterated over it like a chunk stream, which always failed.
            content = _extract_completion_text(response)
            return content or "I apologize, but I couldn't generate a response."

        except Exception as e:
            log.error(f"Upstream LLM call failed: {e}")
            return f"I encountered an error: {str(e)}"

    # Mock response for testing
    user_msg = next(
        (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
        "",
    )
    return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."
|
|
|
|
|
|
# =============================================================================
|
|
# Document Management Endpoints
|
|
# =============================================================================
|
|
|
|
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document to the knowledge base.

    Expects a multipart form with an uploaded file in the ``file`` field.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 400 if no uploaded
            file is present; 500 if processing fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        form = await request.form()
        file = form.get("file")
        # A plain text field named "file" arrives as str, which has no read().
        if not file or not hasattr(file, "read"):
            raise HTTPException(status_code=400, detail="No file provided")

        content = await file.read()
        filename = getattr(file, "filename", None) or "unknown"

        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )

        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}

    except HTTPException:
        # Keep the 400 above; the generic handler used to turn it into a 500.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
class WebsiteDownloadRequest(BaseModel):
    """JSON body accepted by ``POST /v1/documents/website``."""

    url: str  # site to download and ingest
    max_pages: int = 50  # crawl limit handed to the RAG downloader
    threads: int = 6
    download_external_assets: bool = False
    external_domains: list[str] | None = None
|
|
|
|
|
|
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """
    Crawl a website and add its pages to the knowledge base.

    This is the main way content enters the RAG system. All crawl options from
    the request body are forwarded to ``download_and_ingest_website``. A result
    without ``success`` becomes HTTP 500 carrying its ``message``. Unexpected
    errors are logged and also reported as 500.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")

        outcome = await rag.download_and_ingest_website(
            url=request.url,
            max_pages=request.max_pages,
            threads=request.threads,
            download_external_assets=request.download_external_assets,
            external_domains=request.external_domains,
        )

        if not outcome.get("success"):
            raise HTTPException(
                status_code=500,
                detail=outcome.get("message", "Website download failed"),
            )

        summary = {"success": True, "message": f"Website downloaded and ingested: {request.url}"}
        summary["url"] = request.url
        summary["local_path"] = outcome.get("local_path")
        summary["pages_processed"] = outcome.get("pages_processed", 0)
        summary["total_chunks"] = outcome.get("total_chunks", 0)
        return summary

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """Fetch the document at ``request["url"]`` and add it to the knowledge base.

    Responds 503 without a RAG system, 400 when no URL is given, and 500
    (after logging) if fetching or ingestion fails.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    target_url = request.get("url")
    if not target_url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        outcome = await rag.add_document_from_url(target_url)
        added_chunks = outcome.get("chunks", 0)
        return {"success": True, "message": f"Document from {target_url} added", "chunks": added_chunks}
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents")
async def list_documents():
    """Return ``{"documents": [...]}`` for everything stored in the knowledge base."""
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"documents": await rag.list_documents()}
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return ``{"sites": [...]}`` describing every website ingested so far."""
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"sites": await rag.list_downloaded_sites()}
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Return ``{"site": ...}`` for one downloaded site, or 404 if it is unknown.

    The path value is percent-decoded, and ``https://`` is prepended when it
    carries no scheme.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"

        # NOTE(review): unlike the other RAGSystem calls here, get_site_info is
        # not awaited; confirm it is a synchronous method.
        site_info = state.rag_system.get_site_info(site_url)
        if not site_info:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": site_info}

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove the document identified by *doc_id* from the knowledge base."""
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        await rag.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"success": True, "message": f"Document {doc_id} deleted"}
|
|
|
|
|
|
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website together with all of its ingested content.

    The path value is percent-decoded and ``https://`` is prepended when it
    carries no scheme. A result without ``success`` becomes HTTP 404.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"

        outcome = await state.rag_system.delete_site(site_url)
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {site_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# =============================================================================
|
|
# Health and Status Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: uptime in seconds plus which subsystems are available."""
    report = {"status": "healthy", "uptime": time.time() - state.startup_time}
    report.update(
        rag_enabled=state.rag_system is not None,
        tools_enabled=state.tool_manager is not None,
        llm_connected=state.zai_client is not None,
    )
    return report
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and list its main endpoints."""
    endpoints = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "download_website": "/v1/documents/website",
        "list_sites": "/v1/documents/sites",
        "health": "/health",
    }
    description = "OpenAI-compatible RAG server powered by GLM-4.7-Flash. Auto-downloads and analyzes websites when users ask about them."
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": description,
        "endpoints": endpoints,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Main Entry Point
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    if config.DEBUG:
        # uvicorn's auto-reload requires an "module:attr" import string. Derive
        # the module from this file's name instead of assuming "main.py".
        module_name = os.path.splitext(os.path.basename(__file__))[0]
        app_target = f"{module_name}:app"
    else:
        # Without reload, pass the app object directly. When run as a script,
        # this module is "__main__", so an import string would execute the
        # module a second time under another name.
        app_target = app

    uvicorn.run(
        app_target,
        host=config.HOST,
        port=config.PORT,
        reload=config.DEBUG,
    )
|