- Replace z-ai-web-dev-sdk with openai SDK - Add OPENROUTER_API_KEY and OPENROUTER_BASE_URL config - Update AsyncOpenAI client for OpenRouter - Update generate_response and stream_chat_completion - Update .env.example with OpenRouter settings
908 lines
31 KiB
Python
908 lines
31 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
DocRAG - OpenAI-Compatible RAG Server
|
|
|
|
This application presents itself as a standard OpenAI API server that can be used
|
|
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
|
1. Detects URLs in user messages and auto-downloads websites
|
|
2. Ingests website content into the RAG knowledge base
|
|
3. Retrieves relevant context from the knowledge base
|
|
4. Passes the enriched context to the upstream LLM (UPSTREAM_MODEL, served via OpenRouter) for response generation
|
|
|
|
The user sees a normal chat experience, but the system is actually doing
|
|
sophisticated RAG operations in the background.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncIterator, Optional
|
|
|
|
# Load environment variables from .env file
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
# Configure logging: one human-readable line per record, time-of-day stamps only.
_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
_LOG_DATEFMT = "%H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

log = logging.getLogger(__name__)
|
|
|
|
# FastAPI imports
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Import RAG components
|
|
from rag import RAGSystem, get_rag_system
|
|
from rag.document_processor import DocumentProcessor
|
|
|
|
# Import tools
|
|
from tools import ToolManager, get_tool_manager
|
|
|
|
# Import OpenAI client for OpenRouter
|
|
from openai import AsyncOpenAI
|
|
|
|
|
|
# =============================================================================
|
|
# Configuration
|
|
# =============================================================================
|
|
|
|
def _env_int(name: str, default: str) -> int:
    """Read environment variable *name* (or the string *default*) as an int."""
    return int(os.getenv(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when env var *name* (or *default*) is "true", case-insensitively."""
    return os.getenv(name, default).lower() == "true"


class Config:
    """Application settings, resolved once from environment variables at import time."""

    # --- Server ---
    HOST: str = os.getenv("HOST", "0.0.0.0")
    PORT: int = _env_int("PORT", "8000")
    DEBUG: bool = _env_flag("DEBUG", "false")

    # --- Models: public name advertised to clients vs. model used upstream ---
    MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG")
    UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "openrouter/free")

    # --- OpenRouter API ---
    OPENROUTER_API_KEY: str = os.getenv("OPENROUTER_API_KEY", "")
    OPENROUTER_BASE_URL: str = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

    # --- RAG ---
    EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = _env_int("CHUNK_SIZE", "1000")
    CHUNK_OVERLAP: int = _env_int("CHUNK_OVERLAP", "200")
    TOP_K_RESULTS: int = _env_int("TOP_K_RESULTS", "5")

    # --- Tools ---
    ENABLE_TOOLS: bool = _env_flag("ENABLE_TOOLS", "true")
    MAX_TOOL_ITERATIONS: int = _env_int("MAX_TOOL_ITERATIONS", "3")


config = Config()
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Models
|
|
# =============================================================================
|
|
|
|
class ChatMessage(BaseModel):
    """A single chat message in the OpenAI wire format."""

    role: str
    # Plain-text content; may be absent (e.g. assistant turns carrying only tool calls).
    content: str | None = None
    name: str | None = None
    tool_calls: list[dict] | None = None
    tool_call_id: str | None = None
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body accepted by the OpenAI-compatible chat completions endpoint."""

    model: str = "DocRAG"
    messages: list[ChatMessage]
    # Sampling parameters (OpenAI names and defaults).
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    n: int | None = 1
    stream: bool | None = False
    stop: list[str] | None = None
    max_tokens: int | None = None
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0
    user: str | None = None
    # Tool-calling fields are accepted for compatibility.
    tools: list[dict] | None = None
    tool_choice: str | dict | None = None
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """One candidate completion inside a chat completion response."""

    index: int = Field(default=0)
    message: ChatMessage
    finish_reason: str = Field(default="stop")
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
    """Token accounting attached to a chat completion response."""

    prompt_tokens: int = Field(default=0)
    completion_tokens: int = Field(default=0)
    total_tokens: int = Field(default=0)
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming response body in the OpenAI chat completion format."""

    # Fresh id / timestamp per instance unless the caller supplies them.
    id: str = Field(default_factory=lambda: "chatcmpl-" + uuid.uuid4().hex[:24])
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    # Default is read from the loaded config when this class is defined.
    model: str = config.MODEL_NAME
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage = ChatCompletionUsage()
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """A single model entry as returned by the OpenAI models API."""

    id: str
    object: str = Field(default="model")
    # Timestamp taken when the entry is built, not when the model was created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = Field(default="organization")
|
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope for the OpenAI model listing response."""

    object: str = Field(default="list")
    data: list[ModelInfo]
|
|
|
|
|
|
# =============================================================================
|
|
# Application State
|
|
# =============================================================================
|
|
|
|
class AppState:
    """Process-wide service handles, populated by ``lifespan`` at startup."""

    # Each handle stays None until its startup step succeeds.
    rag_system: RAGSystem | None = None
    tool_manager: ToolManager | None = None
    llm_client: AsyncOpenAI | None = None
    # Evaluated once, when this class body runs at import time.
    startup_time: float = time.time()


state = AppState()
|
|
|
|
|
|
# =============================================================================
|
|
# Lifespan Management
|
|
# =============================================================================
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize shared services on startup and release them on shutdown.

    Each subsystem (RAG, tools, upstream LLM client) is initialized
    independently. A failure is logged and leaves the matching ``state``
    attribute as ``None``, so the server still starts in a degraded mode
    (e.g. mock LLM responses when no API key is configured).

    Args:
        app: The FastAPI application (required by the lifespan protocol; unused).
    """
    log.info("Starting DocRAG server...")

    # Initialize RAG system
    try:
        log.info("Initializing RAG system...")
        state.rag_system = await get_rag_system(
            embedding_model=config.EMBEDDING_MODEL,
            vector_store_path=config.VECTOR_STORE_PATH,
            documents_path=config.DOCUMENTS_PATH,
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
        )
        log.info("RAG system initialized successfully")
    except Exception as e:
        log.warning(f"RAG system initialization deferred: {e}")
        state.rag_system = None

    # Initialize tool manager, honoring the ENABLE_TOOLS switch from Config.
    if config.ENABLE_TOOLS:
        try:
            log.info("Initializing tool manager...")
            state.tool_manager = get_tool_manager()
            log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
        except Exception as e:
            log.warning(f"Tool manager initialization failed: {e}")
            state.tool_manager = None
    else:
        log.info("Tool manager disabled (ENABLE_TOOLS is not 'true')")
        state.tool_manager = None

    # Initialize OpenRouter client for upstream LLM
    try:
        if config.OPENROUTER_API_KEY:
            log.info("Initializing OpenRouter client...")
            state.llm_client = AsyncOpenAI(
                api_key=config.OPENROUTER_API_KEY,
                base_url=config.OPENROUTER_BASE_URL,
            )
            log.info(f"OpenRouter client initialized successfully (model: {config.UPSTREAM_MODEL})")
        else:
            log.warning("No OPENROUTER_API_KEY provided - using mock responses")
            state.llm_client = None
    except Exception as e:
        log.error(f"Failed to initialize OpenRouter client: {e}")
        state.llm_client = None

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")

    try:
        yield
    finally:
        # Cleanup also runs when an exception propagates through the lifespan
        # context. Each step is guarded so one failing close() does not skip
        # the others.
        log.info("Shutting down DocRAG server...")
        if state.rag_system:
            try:
                await state.rag_system.close()
            except Exception:
                log.exception("Failed to close RAG system")
        if state.llm_client:
            # AsyncOpenAI holds an HTTP connection pool that should be released.
            try:
                await state.llm_client.close()
            except Exception:
                log.exception("Failed to close OpenRouter client")
            state.llm_client = None
        log.info("DocRAG server stopped")
|
|
|
|
|
|
# =============================================================================
|
|
# FastAPI Application
|
|
# =============================================================================
|
|
|
|
app = FastAPI(
    title="DocRAG API",
    description="OpenAI-compatible RAG server powered by OpenRouter",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware for Open WebUI compatibility: browser-based clients may be
# served from a different origin, so every origin, method and header is allowed.
# NOTE(review): browsers reject credentialed CORS responses that carry a literal
# "*" origin. Confirm how CORSMiddleware handles "*" together with
# allow_credentials=True before relying on cookie-based auth here.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List available models (OpenAI-compatible).

    Advertises the configured ``MODEL_NAME`` plus the canonical ``"DocRAG"``
    alias. Duplicates are removed (order preserved), so with the default
    configuration, where both are ``"DocRAG"``, clients don't show the same
    model twice.
    """
    model_ids = dict.fromkeys((config.MODEL_NAME, "DocRAG"))
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
|
|
|
|
|
|
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Return metadata for *model_id* if this server advertises it; otherwise 404."""
    advertised = (config.MODEL_NAME, "DocRAG")
    if model_id in advertised:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Dispatches to the streaming (SSE) or non-streaming implementation.

    Raises:
        HTTPException: Client errors raised downstream (e.g. 400 when the
            request has no user message) propagate unchanged. Any other
            failure becomes a 500.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Don't turn deliberate 4xx responses into 500s.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# =============================================================================
|
|
# URL Detection and Website Download
|
|
# =============================================================================
|
|
|
|
# Explicit http(s) URLs: everything up to whitespace or a character that
# cannot appear unescaped in a URL.
_FULL_URL_RE = re.compile(r'https?://[^\s<>"{}|\\^`\[\]]+')
# Bare domains (optionally prefixed with "www."); only the host is captured.
_BARE_DOMAIN_RE = re.compile(
    r'(?:^|[^\w/-])(?:www\.)?([a-zA-Z0-9][-a-zA-Z0-9]*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2,})?)(?:/[^\s]*)?(?=[^\w/-]|$)'
)
# Sentence punctuation that commonly trails a URL in prose but is almost
# never meant as part of the URL itself.
_URL_TRAILING_PUNCTUATION = ".,;:!?'"


def _strip_trailing_punctuation(url: str) -> str:
    """Remove prose punctuation glued to the end of *url*.

    A closing ")" is removed only when it is unbalanced, so URLs such as
    ``https://en.wikipedia.org/wiki/Foo_(bar)`` survive intact.
    """
    while url:
        last = url[-1]
        if last in _URL_TRAILING_PUNCTUATION or (
            last == ")" and url.count("(") < url.count(")")
        ):
            url = url[:-1]
        else:
            break
    return url


def extract_urls_from_message(message: str) -> list[str]:
    """Extract URLs from a message, including domains written without a scheme.

    Explicit ``http(s)://`` URLs come first, in order of appearance, with
    trailing sentence punctuation removed. Bare domains follow, normalized to
    ``https://<domain>``; for these only the host is kept and any path is
    dropped.

    Args:
        message: Free-form user text.

    Returns:
        De-duplicated list of URLs (possibly empty).
    """
    urls: list[str] = []

    for raw in _FULL_URL_RE.findall(message):
        url = _strip_trailing_punctuation(raw)
        # Skip degenerate matches (e.g. "https://.") that trim down to a bare scheme.
        if url.partition("://")[2] and url not in urls:
            urls.append(url)

    # Blank out explicit URLs before scanning for bare domains. Otherwise the
    # host inside "https://docs.python.org/3/" would also be picked up as
    # "python.org" (and "https://www.x.com" as "x.com").
    remainder = _FULL_URL_RE.sub(" ", message)
    for domain in _BARE_DOMAIN_RE.findall(remainder):
        # Very short matches are more likely abbreviations than real hosts.
        if len(domain) > 4:
            normalized = f"https://{domain}"
            if normalized not in urls:
                urls.append(normalized)

    return urls
|
|
|
|
|
|
# Phrases suggesting the user wants content fetched from a site. Matching is by
# plain substring on the lower-cased message, so short entries such as "get"
# or "from" match inside larger words too.
_ACCESS_KEYWORDS = (
    "go to", "visit", "check", "look at", "browse", "open",
    "what is on", "tell me about", "give me", "show me", "get",
    "headlines", "content", "information from", "from the", "from",
    "on the website", "on the site", "the website", "the site", "website",
    "summarize", "extract", "read", "analyze", "find", "search",
    "what does", "what's on", "what is", "tell me", "about",
    "news", "articles", "posts", "page", "pages",
)
_QUESTION_WORDS = ("what", "how", "who", "where", "when", "why")


def should_download_website(message: str, urls: list[str]) -> bool:
    """Decide whether *message* asks for content from one of *urls*.

    Returns True when at least one URL is present and the message either
    looks like a question ("?" or a question word anywhere in it) or
    contains one of the access keywords.
    """
    if not urls:
        return False

    lowered = message.lower()
    if "?" in message:
        return True
    if any(word in lowered for word in _QUESTION_WORDS):
        return True
    return any(phrase in lowered for phrase in _ACCESS_KEYWORDS)
|
|
|
|
|
|
async def download_website_if_needed(user_message: str) -> dict[str, Any]:
    """Crawl and ingest the first reachable website the user is asking about.

    Candidate URLs are tried in order. The first download that reports
    ``success`` wins, and any exception for a URL is logged before the next
    one is tried.

    Returns:
        ``{"downloaded": True, "url", "chunks", "pages", "local_path"}`` on
        success, otherwise ``{"downloaded": False, "reason": ...}``.
    """
    candidate_urls = extract_urls_from_message(user_message)

    if not should_download_website(user_message, candidate_urls):
        return {"downloaded": False, "reason": "No website access intent detected"}
    if not state.rag_system:
        return {"downloaded": False, "reason": "RAG system not initialized"}

    for url in candidate_urls:
        try:
            log.info(f"Auto-downloading website: {url}")
            outcome = await state.rag_system.download_and_ingest_website(
                url=url,
                max_pages=30,
            )
            if not outcome.get("success"):
                continue

            log.info(f"Successfully downloaded {url}: {outcome.get('total_chunks', 0)} chunks")
            return {
                "downloaded": True,
                "url": url,
                "chunks": outcome.get("total_chunks", 0),
                "pages": outcome.get("pages_processed", 0),
                "local_path": outcome.get("local_path"),
            }
        except Exception as e:
            log.warning(f"Failed to download {url}: {e}")

    return {"downloaded": False, "reason": "All download attempts failed"}
|
|
|
|
|
|
# =============================================================================
|
|
# Chat Completion Logic
|
|
# =============================================================================
|
|
|
|
# Sampling defaults applied when the client omits (or sends null for) a parameter.
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_MAX_TOKENS = 4096


def _last_user_message(messages: list[ChatMessage]) -> str:
    """Return the content of the most recent non-empty user message, or ""."""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            return msg.content
    return ""


async def _retrieve_context(user_message: str) -> tuple[dict[str, Any], str, list]:
    """Auto-download any website the user asks about, then run RAG retrieval.

    Returns:
        ``(download_info, context, sources)``. Retrieval failures are logged
        and produce an empty context instead of failing the request.
    """
    # Step 1: download the website first so retrieval can see its content.
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG retrieval (now includes newly downloaded content).
    context = ""
    sources: list = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    return download_info, context, sources


def _sampling_params(request: ChatCompletionRequest) -> tuple[float, int]:
    """Resolve temperature and max_tokens, substituting defaults only for missing values.

    ``temperature`` uses an explicit ``None`` check so that a deliberate
    ``temperature=0`` (deterministic sampling) is honored; a plain ``or``
    would silently replace it with the default.
    """
    temperature = (
        _DEFAULT_TEMPERATURE if request.temperature is None else request.temperature
    )
    max_tokens = request.max_tokens or _DEFAULT_MAX_TOKENS
    return temperature, max_tokens


def _stream_chunk(
    request_id: str,
    created: int,
    delta: dict,
    finish_reason: Optional[str] = None,
) -> str:
    """Format one ``chat.completion.chunk`` as a server-sent-events data line."""
    payload = {
        "id": request_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": config.MODEL_NAME,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(payload)}\n\n"


async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Process a non-streaming chat completion request.

    Raises:
        HTTPException: 400 if the conversation has no non-empty user message.
    """
    messages = request.messages

    user_message = _last_user_message(messages)
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    download_info, context, sources = await _retrieve_context(user_message)

    # Step 3: build the prompt with retrieved context.
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)

    # Step 4: generate the response with the upstream LLM.
    temperature, max_tokens = _sampling_params(request)
    response_content = await generate_response(
        enhanced_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Step 5: build the response. Token counts are a rough ~4 chars/token
    # estimate, because generate_response does not surface upstream usage.
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4
    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(role="assistant", content=response_content),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )


async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion as server-sent events.

    Emits one content chunk per upstream delta (or per character in demo
    mode), then a final empty-delta chunk carrying the finish reason, then
    ``data: [DONE]``. Errors are reported as a single ``{"error": ...}`` event.
    """
    messages = request.messages

    user_message = _last_user_message(messages)
    if not user_message:
        yield "data: " + json.dumps({"error": "No user message found"}) + "\n\n"
        return

    download_info, context, sources = await _retrieve_context(user_message)

    # Step 3: build the prompt with retrieved context.
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)

    # Step 4: stream the response from the upstream LLM.
    created = int(time.time())

    try:
        if state.llm_client:
            temperature, max_tokens = _sampling_params(request)
            stream = await state.llm_client.chat.completions.create(
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
            )

            # Forward the upstream finish reason (e.g. "length") when one is sent.
            finish_reason = "stop"
            async for chunk in stream:
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                if choice.delta is not None and choice.delta.content:
                    yield _stream_chunk(request_id, created, {"content": choice.delta.content})
                if choice.finish_reason:
                    finish_reason = choice.finish_reason

            yield _stream_chunk(request_id, created, {}, finish_reason)
            yield "data: [DONE]\n\n"

        else:
            # Mock streaming response for testing (no API key configured).
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if download_info.get("downloaded"):
                mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n"
                mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n"
            if context:
                mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n"
            mock_response += "\n\n[Demo mode - configure OPENROUTER_API_KEY for full LLM responses]"

            for char in mock_response:
                yield _stream_chunk(request_id, created, {"content": char})
                await asyncio.sleep(0.01)

            yield _stream_chunk(request_id, created, {}, "stop")
            yield "data: [DONE]\n\n"

    except Exception as e:
        log.exception("Streaming failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
|
|
|
|
|
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
    download_info: Optional[dict] = None,
) -> list[ChatMessage]:
    """Prepend a RAG-aware system prompt to the client's conversation.

    The system prompt always has a base instruction. It then gets a
    "Website Access" section when ``download_info`` reports a successful
    download, and the retrieved context (plus up to 10 sources) when
    ``context`` is non-empty. Any system messages sent by the client are
    dropped. All other messages keep their original order.
    """
    sections = [
        "You are a helpful AI assistant with the ability to access and analyze websites on-demand. "
        "When a user asks about a website, you can download and analyze its content directly. "
        "Use the provided context from the knowledge base to give accurate and helpful responses. "
        "If context from a website is provided, use it to answer the user's question directly with specific information. "
        "Be helpful, detailed, and provide the specific information the user is asking for (headlines, summaries, etc.)."
    ]

    if download_info and download_info.get("downloaded"):
        sections.append(
            "\n\n--- Website Access ---\n"
            f"I have successfully downloaded and analyzed the website: {download_info.get('url')}\n"
            f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} text chunks.\n"
            "The context below contains the actual content from this website. Use it to answer the user's question."
        )

    if context:
        sections.append(f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n")
        if sources:
            listed = "\n".join(f"- {source}" for source in sources[:10])
            sections.append(f"\n--- Sources ---\n{listed}")

    system_message = ChatMessage(role="system", content="".join(sections))
    return [system_message, *(msg for msg in messages if msg.role != "system")]
|
|
|
|
|
|
async def generate_response(
    messages: list[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 4096,
) -> str:
    """Generate a single assistant reply using the upstream LLM via OpenRouter.

    Messages with empty content are omitted from the upstream request. Upstream
    failures are logged and returned as an apology string instead of raising.
    Without a configured client, a demo-mode message echoing the latest user
    question is returned.
    """
    fallback = "I apologize, but I couldn't generate a response."

    if not state.llm_client:
        # Demo mode: echo (up to 100 chars of) the latest user question.
        latest_question = next(
            (m.content for m in reversed(messages) if m.role == "user" and m.content),
            "",
        )
        return f"Demo mode response. Your question: {latest_question[:100]}... Configure OPENROUTER_API_KEY for full functionality."

    try:
        payload = [{"role": m.role, "content": m.content} for m in messages if m.content]
        completion = await state.llm_client.chat.completions.create(
            model=config.UPSTREAM_MODEL,
            messages=payload,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        if not completion.choices:
            return fallback
        return completion.choices[0].message.content or fallback
    except Exception as e:
        log.error(f"OpenRouter LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
|
|
|
|
|
|
# =============================================================================
|
|
# Document Management Endpoints
|
|
# =============================================================================
|
|
|
|
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document (multipart field ``file``) to the knowledge base.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 400 if no file
            part was sent; 500 if processing fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        form = await request.form()
        upload = form.get("file")
        # A plain text field arrives as ``str``; only file parts can be read().
        if upload is None or isinstance(upload, str):
            raise HTTPException(status_code=400, detail="No file provided")

        content = await upload.read()
        filename = upload.filename or "unknown"

        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )

        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}

    except HTTPException:
        # Keep deliberate client errors (400) instead of converting them to 500.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
class WebsiteDownloadRequest(BaseModel):
    """Body of the website-download endpoint.

    Every field is forwarded as-is to ``download_and_ingest_website``.
    """

    url: str
    max_pages: int = 50
    threads: int = 6
    download_external_assets: bool = False
    external_domains: list[str] | None = None
|
|
|
|
|
|
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """Download a website and ingest it into the knowledge base.

    This is the primary way to add content to the RAG system. The crawl and
    ingestion are delegated to ``RAGSystem.download_and_ingest_website``.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 500 if the
            download reports failure or raises.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")

        outcome = await state.rag_system.download_and_ingest_website(
            url=request.url,
            max_pages=request.max_pages,
            threads=request.threads,
            download_external_assets=request.download_external_assets,
            external_domains=request.external_domains,
        )

        if not outcome.get("success"):
            raise HTTPException(
                status_code=500,
                detail=outcome.get("message", "Website download failed"),
            )

        return {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": outcome.get("local_path"),
            "pages_processed": outcome.get("pages_processed", 0),
            "total_chunks": outcome.get("total_chunks", 0),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """Fetch the document at ``request["url"]`` and add it to the knowledge base."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    source_url = request.get("url")
    if not source_url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        added = await state.rag_system.add_document_from_url(source_url)
        return {
            "success": True,
            "message": f"Document from {source_url} added",
            "chunks": added.get("chunks", 0),
        }
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents")
async def list_documents():
    """Return all documents currently stored in the knowledge base."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        documents = await state.rag_system.list_documents()
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"documents": documents}
|
|
|
|
|
|
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return every website that has been downloaded into the knowledge base."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        downloaded = await state.rag_system.list_downloaded_sites()
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"sites": downloaded}
|
|
|
|
|
|
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Return stored metadata for one downloaded site.

    The path segment is percent-decoded. A URL without an http(s) scheme is
    assumed to be ``https://``.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"

        info = state.rag_system.get_site_info(site_url)
        if not info:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": info}

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove the document identified by *doc_id* from the knowledge base."""
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        await state.rag_system.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"success": True, "message": f"Document {doc_id} deleted"}
|
|
|
|
|
|
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website and all its content from the knowledge base.

    The path segment is percent-decoded. A URL without an http(s) scheme is
    assumed to be ``https://``.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"

        outcome = await state.rag_system.delete_site(site_url)
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {site_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# =============================================================================
|
|
# Health and Status Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness, uptime in seconds, and which subsystems initialized."""
    uptime_seconds = time.time() - state.startup_time
    return {
        "status": "healthy",
        "uptime": uptime_seconds,
        "rag_enabled": state.rag_system is not None,
        "tools_enabled": state.tool_manager is not None,
        "llm_connected": state.llm_client is not None,
    }
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the API and list its main endpoints."""
    endpoints = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "download_website": "/v1/documents/website",
        "list_sites": "/v1/documents/sites",
        "health": "/health",
    }
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": "OpenAI-compatible RAG server powered by OpenRouter. Auto-downloads and analyzes websites when users ask about them.",
        "endpoints": endpoints,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Main Entry Point
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # uvicorn needs an "module:attribute" import string for reload=True.
    # Derive it from this file's name instead of hard-coding "main", so that
    # renaming the file does not break startup. (As before, the module must be
    # importable from the current working directory.)
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    uvicorn.run(
        f"{module_name}:app",
        host=config.HOST,
        port=config.PORT,
        reload=config.DEBUG,
    )
|