No LLM needed for tool selection. The flow is now: request → run ALL tools in parallel → results into the system prompt → 1 LLM call.

- _run_all_tools: fires every tool concurrently (30 s timeout each)
  - No required args: run with schema defaults
  - Query-like required args (query, topic, title, etc.): use the user message
  - Specific args (symbol, url, pmid): skip (can't guess)
- _build_tool_results_text: formats all results into the system prompt
- build_enhanced_messages: the system prompt now has a real-time data section
- call_llm: dead simple, just prompt → response (replaces generate_response)
- Removed: generate_response, _parse_tool_calls, _clean_tool_syntax, _build_tool_descriptions (all dead code now)
- Streaming path: same flow, runs the tools and then streams the LLM response
- Streaming and non-streaming paths use the identical tool pipeline
1045 lines
36 KiB
Python
Executable File
1045 lines
36 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
DocRAG - OpenAI-Compatible RAG Server

This application presents itself as a standard OpenAI API server that can be
used with any OpenAI-compatible client (such as Open WebUI). Behind the scenes,
for each chat request it:

1. Detects URLs in the user's message and auto-downloads the referenced websites.
2. Ingests the downloaded website content into the RAG knowledge base.
3. Retrieves relevant context from the knowledge base.
4. Runs every available tool in parallel (no LLM-driven tool selection) and
   collects the results.
5. Passes the enriched context and tool results to OpenRouter in a single LLM
   call to generate the response.

The user sees a normal chat experience, while the system performs these RAG
and tool operations in the background.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import Any, AsyncIterator, Optional
|
|
|
|
# Load environment variables from .env file (look in script directory)
|
|
from dotenv import load_dotenv
|
|
SCRIPT_DIR = Path(__file__).parent.resolve()
ENV_FILE = SCRIPT_DIR / ".env"

# Environment loading order. python-dotenv does not override variables that are
# already set (override=False by default), so real environment variables always
# win, and among .env files the FIRST one that defines a key wins. A missing file
# is simply skipped.
#   1. The .env file next to this script.
load_dotenv(ENV_FILE)
#   2. A .env file in the current working directory. This must be passed
#      explicitly: a bare load_dotenv() calls find_dotenv(), which (outside
#      interactive sessions) searches upward from the *calling script's*
#      directory, not from the working directory.
load_dotenv(Path.cwd() / ".env")
#   3. The previous fallback: find_dotenv()'s upward search from this script's
#      directory, kept so .env files in parent directories are still honoured.
load_dotenv()
|
|
|
|
# Configure logging
|
|
_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
_LOG_DATEFMT = "%H:%M:%S"

# Root logging configuration: INFO and above, with a time-only timestamp.
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Logger used by every function in this module.
log = logging.getLogger(__name__)
|
|
|
|
# FastAPI imports
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Import RAG components
|
|
from rag import RAGSystem, get_rag_system
|
|
from rag.document_processor import DocumentProcessor
|
|
|
|
# Import tools
|
|
from tools import ToolManager, get_tool_manager
|
|
|
|
# Import OpenAI client for OpenRouter
|
|
from openai import AsyncOpenAI
|
|
|
|
|
|
# =============================================================================
|
|
# Configuration
|
|
# =============================================================================
|
|
|
|
def _env_str(name: str, default: str) -> str:
    """Read *name* from the environment, falling back to *default* when unset."""
    return os.getenv(name, default)


def _env_int(name: str, default: str) -> int:
    """Read *name* from the environment and convert it with ``int()``."""
    return int(_env_str(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Read *name* as a boolean: only the case-insensitive string "true" is True."""
    return _env_str(name, default).lower() == "true"


class Config:
    """Settings read from environment variables once, when this module is imported."""

    # --- Server ---
    HOST: str = _env_str("HOST", "0.0.0.0")
    PORT: int = _env_int("PORT", "8000")
    DEBUG: bool = _env_flag("DEBUG", "false")  # also drives uvicorn's reload flag

    # --- Models ---
    MODEL_NAME: str = _env_str("MODEL_NAME", "DocRAG")  # id advertised to clients
    UPSTREAM_MODEL: str = _env_str("UPSTREAM_MODEL", "openrouter/free")  # id sent upstream

    # --- OpenRouter API ---
    OPENROUTER_API_KEY: str = _env_str("OPENROUTER_API_KEY", "")  # empty => demo responses
    OPENROUTER_BASE_URL: str = _env_str("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

    # --- RAG ---
    EMBEDDING_MODEL: str = _env_str("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = _env_str("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = _env_str("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = _env_int("CHUNK_SIZE", "1000")
    CHUNK_OVERLAP: int = _env_int("CHUNK_OVERLAP", "200")
    TOP_K_RESULTS: int = _env_int("TOP_K_RESULTS", "5")

    # --- Tools ---
    ENABLE_TOOLS: bool = _env_flag("ENABLE_TOOLS", "true")
    # Not read by any code in this file (the single-call pipeline has no tool loop).
    MAX_TOOL_ITERATIONS: int = _env_int("MAX_TOOL_ITERATIONS", "5")


# Module-wide settings instance.
config = Config()
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Models
|
|
# =============================================================================
|
|
|
|
class ChatMessage(BaseModel):
    """One message in OpenAI chat format, used for request history and replies."""

    role: str  # e.g. "system", "user", "assistant"
    content: str | None = None  # plain-text body; may be absent
    name: str | None = None  # optional participant name
    tool_calls: list[dict] | None = None  # tool calls attached to a message, passed through as-is
    tool_call_id: str | None = None  # id of the tool call a tool message answers
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Body of ``POST /v1/chat/completions`` in OpenAI format.

    Only ``messages``, ``stream``, ``temperature`` and ``max_tokens`` influence
    the response. The other fields (including ``model``) are accepted so that
    standard OpenAI clients validate cleanly, but they are not forwarded
    upstream.
    """

    model: str = "DocRAG"
    messages: list[ChatMessage]
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    n: int | None = 1
    stream: bool | None = False
    # OpenAI allows either a single stop string or a list of them. Accepting only
    # a list made requests such as {"stop": "\n"} fail validation with a 422.
    stop: str | list[str] | None = None
    max_tokens: int | None = None
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0
    user: str | None = None
    tools: list[dict] | None = None
    tool_choice: str | dict | None = None
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """One generated alternative inside a non-streaming chat completion."""

    index: int = 0  # position among choices; this server builds a single choice
    message: ChatMessage  # the assistant's reply
    finish_reason: str = "stop"  # reason generation ended
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
    """Token accounting attached to a completion; every count defaults to 0."""

    prompt_tokens: int = 0  # tokens attributed to the prompt
    completion_tokens: int = 0  # tokens attributed to the generated reply
    total_tokens: int = 0  # caller-supplied total; not derived automatically
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Non-streaming chat completion response in OpenAI format."""

    # Fresh random id per response unless one is supplied.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
    object: str = "chat.completion"
    # Unix seconds at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = config.MODEL_NAME
    choices: list[ChatCompletionChoice]
    # A new, all-zero usage object per response.
    usage: ChatCompletionUsage = Field(default_factory=ChatCompletionUsage)
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Metadata for one entry of the ``/v1/models`` listing (OpenAI shape)."""

    id: str
    object: str = "model"
    # Unix seconds, stamped when the instance is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "organization"
|
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope for the model listing: ``{"object": "list", "data": [...]}``."""

    object: str = "list"
    data: list[ModelInfo]  # one entry per advertised model id
|
|
|
|
|
|
# =============================================================================
|
|
# Application State
|
|
# =============================================================================
|
|
|
|
class AppState:
    """Process-wide mutable state shared by the request handlers.

    The attributes are class-level defaults; ``lifespan`` assigns the service
    handles at startup and leaves any that fail to initialise as None.
    """

    rag_system: RAGSystem | None = None  # knowledge-base backend
    tool_manager: ToolManager | None = None  # registry/executor for tools
    llm_client: AsyncOpenAI | None = None  # upstream client; None => demo replies
    # Evaluated once, when the class body runs at module import time.
    startup_time: float = time.time()


# Single shared instance used by every endpoint.
state = AppState()
|
|
|
|
|
|
# =============================================================================
|
|
# Lifespan Management
|
|
# =============================================================================
|
|
|
|
async def _init_llm_client():
    """Create the OpenRouter (OpenAI-compatible) client and store it on ``state``.

    Logs where the .env file was looked for and whether an API key was found
    (only a masked preview of the key is logged). On any failure, or when no
    key is configured, ``state.llm_client`` is set to None so the chat
    handlers fall back to demo responses.

    Returns:
        The ``httpx.AsyncClient`` handed to the OpenAI client, so the caller
        can close it on shutdown, or None if no client was created.
    """
    log.info(f"Looking for .env file at: {ENV_FILE}")
    log.info(f".env file exists: {ENV_FILE.exists()}")

    api_key = config.OPENROUTER_API_KEY
    if not api_key:
        log.warning("OPENROUTER_API_KEY not found in environment!")
        log.warning(f"Checked .env at: {ENV_FILE}")
        log.warning("Set OPENROUTER_API_KEY in .env file or as environment variable")
        log.warning("No OPENROUTER_API_KEY provided - using mock responses")
        state.llm_client = None
        return None

    # Never log the full key; short keys are masked entirely.
    key_preview = f"{api_key[:8]}...{api_key[-4:]}" if len(api_key) > 12 else "***"
    log.info(f"OPENROUTER_API_KEY found: {key_preview}")

    http_client = None
    try:
        log.info("Initializing OpenRouter client...")
        import httpx

        # Explicit httpx client: 60 s overall timeout, 10 s to connect.
        # NOTE(review): this used to be justified as avoiding "proxy issues";
        # httpx still reads proxy environment variables by default
        # (trust_env=True), so confirm which problem it actually solves.
        http_client = httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0))
        state.llm_client = AsyncOpenAI(
            api_key=api_key,
            base_url=config.OPENROUTER_BASE_URL,
            http_client=http_client,
        )
        log.info(f"OpenRouter client initialized successfully (model: {config.UPSTREAM_MODEL})")
        return http_client
    except Exception as e:
        log.error(f"Failed to initialize OpenRouter client: {e}")
        state.llm_client = None
        # Don't leak the transport if the OpenAI client could not be built.
        if http_client is not None:
            await http_client.aclose()
        return None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise shared services on startup and release them on shutdown.

    Every startup step is best-effort: if one fails, its ``state`` attribute
    stays None and the server still starts. Document endpoints then answer
    503, and chat falls back to demo responses. Shutdown runs in a ``finally``
    block so the RAG system and the upstream HTTP client are released even
    when the server exits with an error.
    """
    log.info("Starting DocRAG server...")

    # Initialize RAG system
    try:
        log.info("Initializing RAG system...")
        state.rag_system = await get_rag_system(
            embedding_model=config.EMBEDDING_MODEL,
            vector_store_path=config.VECTOR_STORE_PATH,
            documents_path=config.DOCUMENTS_PATH,
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
        )
        log.info("RAG system initialized successfully")
    except Exception as e:
        log.warning(f"RAG system initialization deferred: {e}")
        state.rag_system = None

    # Initialize tool manager
    try:
        log.info("Initializing tool manager...")
        state.tool_manager = get_tool_manager()
        log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
    except Exception as e:
        log.warning(f"Tool manager initialization failed: {e}")
        state.tool_manager = None

    # Initialize OpenRouter client for upstream LLM
    http_client = await _init_llm_client()

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")

    try:
        yield
    finally:
        log.info("Shutting down DocRAG server...")
        if state.rag_system:
            try:
                await state.rag_system.close()
            except Exception:
                log.exception("Error while closing RAG system")
        if http_client is not None:
            try:
                await http_client.aclose()
            except Exception:
                log.exception("Error while closing upstream HTTP client")
            state.llm_client = None
        log.info("DocRAG server stopped")
|
|
|
|
|
|
# =============================================================================
|
|
# FastAPI Application
|
|
# =============================================================================
|
|
|
|
app = FastAPI(
    title="DocRAG API",
    description="OpenAI-compatible RAG server powered by OpenRouter",
    version="1.0.0",
    lifespan=lifespan,
)

# Browser origins allowed by CORS, as a comma-separated CORS_ALLOW_ORIGINS value.
# Unset or blank keeps the previous behaviour of allowing any origin ("*").
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# CORS middleware for Open WebUI compatibility.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests (verify for the
# installed version), i.e. any site may make credentialed calls. Set
# CORS_ALLOW_ORIGINS to an explicit list in deployments reachable from browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List the model ids this server answers to (OpenAI-compatible).

    ``config.MODEL_NAME`` defaults to "DocRAG", the same as the fixed alias, so
    the ids are de-duplicated (order preserved) to avoid listing one model twice.
    """
    model_ids = dict.fromkeys((config.MODEL_NAME, "DocRAG"))
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
|
|
|
|
|
|
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Return metadata for an advertised model id (OpenAI-compatible).

    Raises:
        HTTPException: 404 when *model_id* is neither ``config.MODEL_NAME`` nor "DocRAG".
    """
    known_ids = (config.MODEL_NAME, "DocRAG")
    if model_id in known_ids:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Streaming requests get a ``text/event-stream`` response produced by
    ``stream_chat_completion``; errors inside that stream are reported in-band.
    Non-streaming requests are answered by ``complete_chat``.

    Raises:
        HTTPException: the handler's own HTTP errors (e.g. 400 when there is no
            user message) unchanged; any other failure as a 500.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Deliberate HTTP errors must keep their status code; without this the
        # generic handler below turned the 400 from complete_chat into a 500.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# =============================================================================
|
|
# URL Detection and Website Download
|
|
# =============================================================================
|
|
|
|
# Full URLs with an explicit http(s) scheme.
_FULL_URL_RE = re.compile(r'https?://[^\s<>"{}|\\^`\[\]]+')
# Bare domains such as "example.com" or "www.example.co.uk/path". Group 1 is the
# host without a leading "www."; any path is matched but not captured.
_BARE_DOMAIN_RE = re.compile(
    r'(?:^|[^\w/-])(?:www\.)?([a-zA-Z0-9][-a-zA-Z0-9]*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2,})?)(?:/[^\s]*)?(?=[^\w/-]|$)'
)
# Sentence punctuation that often trails a URL in prose but is rarely part of it.
_URL_TRAILING_PUNCTUATION = ".,;:!?'"


def _strip_url_trailing_punctuation(url: str) -> str:
    """Remove prose punctuation glued to the end of *url*.

    A closing parenthesis is removed only when it is unbalanced, so links such
    as ``.../wiki/Foo_(bar)`` keep theirs.
    """
    while url:
        last = url[-1]
        if last in _URL_TRAILING_PUNCTUATION:
            url = url[:-1]
        elif last == ")" and url.count(")") > url.count("("):
            url = url[:-1]
        else:
            break
    return url


def extract_urls_from_message(message: str) -> list[str]:
    """Extract URLs from a message, including bare domains without a scheme.

    Full ``http(s)://`` URLs come first, in order of appearance, with trailing
    sentence punctuation removed. Bare domains found in the remaining text
    follow, normalised to ``https://<domain>``; any "www." prefix and path are
    dropped, and matches of 4 characters or fewer are ignored. The result
    contains no duplicates.

    Args:
        message: Free-form user text.

    Returns:
        The list of URLs, possibly empty.
    """
    urls: list[str] = []

    def _add(url: str) -> None:
        if url and url not in urls:
            urls.append(url)

    for match in _FULL_URL_RE.findall(message):
        _add(_strip_url_trailing_punctuation(match))

    # Blank out full URLs before searching for bare domains. Otherwise fragments
    # inside them are reported as extra sites, e.g. "https://www.example.com"
    # also yields "https://example.com" and ".../file.tar.gz" yields "https://tar.gz".
    remainder = _FULL_URL_RE.sub(" ", message)
    for domain in _BARE_DOMAIN_RE.findall(remainder):
        # Very short matches (e.g. "a.io") are more likely noise than sites.
        if len(domain) > 4:
            _add(f"https://{domain}")

    return urls
|
|
|
|
|
|
def should_download_website(message: str, urls: list[str]) -> bool:
    """Heuristically decide whether *message* asks for the content of *urls*.

    Returns False when there are no URLs, or when the message contains a
    marker of an automated Open WebUI background task (title, tag or
    follow-up generation). Otherwise returns True when the lower-cased message
    contains any access keyword or question word, or a literal "?". All checks
    are plain substring tests, so this is deliberately permissive.
    """
    if not urls:
        return False

    lowered = message.lower()

    automated_markers = (
        '### task:', 'generate a concise', 'generate 1-3',
        'suggest 3-5 relevant follow-up', 'summarizing the chat history',
        'categorizing the main themes', 'follow-up questions or prompts',
    )
    for marker in automated_markers:
        if marker in lowered:
            log.info("Skipping website download - appears to be an automated task")
            return False

    intent_keywords = (
        'go to', 'visit', 'check', 'look at', 'browse', 'open',
        'what is on', 'tell me about', 'give me', 'show me', 'get',
        'headlines', 'content', 'information from', 'from the', 'from',
        'on the website', 'on the site', 'the website', 'the site', 'website',
        'summarize', 'extract', 'read', 'analyze', 'find', 'search',
        'what does', 'what\'s on', 'what is', 'tell me', 'about',
        'news', 'articles', 'posts', 'page', 'pages',
    )
    if any(keyword in lowered for keyword in intent_keywords):
        return True

    question_words = ('what', 'how', 'who', 'where', 'when', 'why')
    return '?' in message or any(word in lowered for word in question_words)
|
|
|
|
|
|
async def download_website_if_needed(user_message: str, *, max_pages: int = 30) -> dict[str, Any]:
    """Download and ingest a website if the user is asking about one.

    URLs are tried in the order ``extract_urls_from_message`` returns them; the
    first one that is already cached or downloads successfully wins.

    Args:
        user_message: The latest user message.
        max_pages: Page budget for a fresh download.

    Returns:
        ``{"downloaded": True, "url", "chunks", "pages", "local_path"}``, plus
        ``"cached": True`` for a site that was already ingested; otherwise
        ``{"downloaded": False, "reason": ...}``.
    """
    urls = extract_urls_from_message(user_message)

    if not should_download_website(user_message, urls):
        return {"downloaded": False, "reason": "No website access intent detected"}

    if not state.rag_system:
        return {"downloaded": False, "reason": "RAG system not initialized"}

    for url in urls:
        try:
            # Reuse a previous ingestion instead of crawling again.
            site_info = state.rag_system.get_site_info(url)
            if site_info:
                log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)")
                return {
                    "downloaded": True,
                    "url": url,
                    "chunks": site_info.get("chunks_ingested", 0),
                    "pages": site_info.get("pages_downloaded", 0),
                    "local_path": site_info.get("local_path"),
                    "cached": True,
                }

            log.info(f"Auto-downloading website: {url}")
            result = await state.rag_system.download_and_ingest_website(
                url=url,
                max_pages=max_pages,
            )

            if result.get("success"):
                log.info(f"Successfully downloaded {url}: {result.get('total_chunks', 0)} chunks")
                return {
                    "downloaded": True,
                    "url": url,
                    "chunks": result.get("total_chunks", 0),
                    "pages": result.get("pages_processed", 0),
                    "local_path": result.get("local_path"),
                }

            # An unsuccessful result used to be skipped silently; record why
            # before trying the next URL.
            log.warning(f"Download of {url} was not successful: {result.get('message', 'no details')}")
        except Exception as e:
            log.warning(f"Failed to download {url}: {e}")

    return {"downloaded": False, "reason": "All download attempts failed"}
|
|
|
|
|
|
# =============================================================================
|
|
# Chat Completion Logic
|
|
# =============================================================================
|
|
|
|
# Required-parameter names that can sensibly be filled with the raw user message.
# Tools needing anything else (symbol, pmid, paper_id, url, ...) are skipped
# because the value cannot be guessed.
_QUERY_LIKE_PARAMS = frozenset(
    {"query", "q", "search", "search_query", "topic", "title", "question", "disease"}
)

# Substrings that mark Open WebUI's automated background prompts. These are the
# same markers should_download_website uses to skip website downloads.
_AUTOMATED_TASK_MARKERS = (
    '### task:', 'generate a concise', 'generate 1-3',
    'suggest 3-5 relevant follow-up', 'summarizing the chat history',
    'categorizing the main themes', 'follow-up questions or prompts',
)


def _build_tool_kwargs(schema: dict, user_message: str) -> dict | None:
    """Build call arguments for one tool from its OpenAI-style function schema.

    Every parameter with a schema ``default`` gets that default. A required
    parameter that is still missing is filled with *user_message* if its name
    is in ``_QUERY_LIKE_PARAMS``. Returns None if any required parameter cannot
    be filled, meaning the tool should be skipped.
    """
    params = schema.get("function", {}).get("parameters", {})
    props = params.get("properties", {})
    kwargs = {name: info["default"] for name, info in props.items() if "default" in info}

    for pname in params.get("required", []):
        if pname in kwargs:
            continue
        if pname not in _QUERY_LIKE_PARAMS:
            return None
        kwargs[pname] = user_message
    return kwargs


async def _run_all_tools(user_message: str, *, timeout: float = 30) -> list[dict]:
    """Run every tool whose arguments can be filled, concurrently. No LLM involved.

    Each tool runs in a worker thread via ``asyncio.to_thread``, bounded by
    *timeout* seconds. Failures never propagate; each outcome is
    ``{"name", "success": True, "result"}`` or
    ``{"name", "success": False, "error"}``.

    Returns an empty list when there is no tool manager, or when *user_message*
    looks like an automated Open WebUI background prompt: running every tool for
    title or follow-up generation only adds latency and irrelevant context.
    """
    if not state.tool_manager:
        return []

    lowered = user_message.lower()
    if any(marker in lowered for marker in _AUTOMATED_TASK_MARKERS):
        log.info("Skipping tools - appears to be an automated task")
        return []

    manager = state.tool_manager

    async def _run_one(name: str, kwargs: dict) -> dict:
        try:
            # NOTE: wait_for cannot stop the worker thread. A timed-out tool keeps
            # running, and occupying a default-executor thread, until it returns.
            result = await asyncio.wait_for(
                asyncio.to_thread(manager.execute_tool, name, kwargs),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            return {"name": name, "success": False, "error": "timeout"}
        except Exception as e:
            return {"name": name, "success": False, "error": str(e)}
        return {"name": name, "success": True, "result": result}

    tasks = []
    # _schemas is private to ToolManager (name -> schema dict with a "function"
    # entry, as iterated here). TODO: expose a public accessor.
    for name, schema in manager._schemas.items():
        kwargs = _build_tool_kwargs(schema, user_message)
        if kwargs is None:
            log.debug(f"Skipping tool {name}: can't fill required param from user message")
            continue
        tasks.append(_run_one(name, kwargs))

    log.info(f"Running {len(tasks)} tools in parallel...")
    results = await asyncio.gather(*tasks)
    succeeded = sum(1 for r in results if r["success"])
    log.info(f"Tool execution complete: {succeeded}/{len(results)} succeeded")
    return results
|
|
|
|
|
|
def _build_tool_results_text(tool_results: list[dict]) -> str:
|
|
"""Build a text block of all tool results for the system prompt."""
|
|
if not tool_results:
|
|
return ""
|
|
|
|
parts = []
|
|
for tr in tool_results:
|
|
name = tr["name"]
|
|
if tr["success"]:
|
|
result_data = tr.get("result", {})
|
|
# Truncate large results to keep prompt manageable
|
|
result_str = json.dumps(result_data, ensure_ascii=False)
|
|
if len(result_str) > 3000:
|
|
result_str = result_str[:3000] + '..." [TRUNCATED]'
|
|
parts.append(f"### {name}\n{result_str}")
|
|
else:
|
|
parts.append(f"### {name}\n[ERROR: {tr.get('error', 'unknown')}]")
|
|
|
|
return "\n\n".join(parts)
|
|
|
|
|
|
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Answer a non-streaming chat request with a single upstream LLM call.

    Pipeline:
    1. Take the latest non-empty user message.
    2. Auto-download any website it asks about.
    3. Retrieve knowledge-base context.
    4. Run all tools.
    5. Fold everything into one system prompt and call the LLM once.

    Args:
        request: The parsed OpenAI-style request.
        request_id: Id echoed back as the response ``id``.

    Returns:
        An OpenAI-format response with one assistant choice. Token usage is an
        estimate of roughly 4 characters per token.

    Raises:
        HTTPException: 400 if no user message with content is present.
    """
    log.info(f"=== Starting complete_chat for request {request_id} ===")
    messages = request.messages

    # The latest user message with content drives retrieval and tools.
    user_message = next(
        (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
        "",
    )
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    log.info(f"User message: {user_message[:100]}...")

    # Step 1: Download website if user is asking about one
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG retrieval (best effort: a failure only loses the extra context)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Run ALL tools in parallel (no LLM needed)
    tool_results = []
    if state.tool_manager and config.ENABLE_TOOLS:
        tool_results = await _run_all_tools(user_message)

    # Step 4: Build system prompt with tool results as context
    tool_results_text = _build_tool_results_text(tool_results)
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text)

    # Step 5: ONE LLM call. Forward only the sampling settings that are set.
    # None (an explicit JSON null, or an omitted max_tokens) falls back to
    # call_llm's defaults, matching the streaming path, instead of being sent
    # upstream as null.
    llm_kwargs = {}
    if request.temperature is not None:
        llm_kwargs["temperature"] = request.temperature
    if request.max_tokens is not None:
        llm_kwargs["max_tokens"] = request.max_tokens

    log.info(f"Calling LLM (single call) for request {request_id}")
    response_content = await call_llm(enhanced_messages, **llm_kwargs)
    log.info(f"=== Completed complete_chat for request {request_id} ===")

    # Rough estimate (~4 characters per token); upstream usage is not propagated.
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4

    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
|
|
|
|
|
|
def _sse_chunk(request_id: str, created: int, delta: dict, finish_reason: str | None = None) -> str:
    """Serialize one OpenAI ``chat.completion.chunk`` as a Server-Sent Events ``data:`` line."""
    payload = {
        'id': request_id,
        'object': 'chat.completion.chunk',
        'created': created,
        'model': config.MODEL_NAME,
        'choices': [{
            'index': 0,
            'delta': delta,
            'finish_reason': finish_reason,
        }],
    }
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion as OpenAI-style Server-Sent Events.

    Runs the same preparation as ``complete_chat``: website download, RAG
    retrieval, all tools, and one system prompt. It then streams the upstream
    LLM's answer chunk by chunk, or a character-by-character demo reply when no
    LLM client is configured. A successful stream ends with a
    ``finish_reason: "stop"`` chunk and ``data: [DONE]``.

    Errors are reported in-band as ``data: {"error": ...}``, because the HTTP
    status is already sent once streaming starts. The whole pipeline sits inside
    the ``try``, so failures while preparing context are reported the same way
    instead of aborting the response.
    """
    messages = request.messages

    # The latest user message with content drives retrieval and tools.
    user_message = next(
        (msg.content for msg in reversed(messages) if msg.role == "user" and msg.content),
        "",
    )
    if not user_message:
        yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
        return

    try:
        # Step 1: Download website if user is asking about one
        download_info = await download_website_if_needed(user_message)
        if download_info.get("downloaded"):
            log.info(f"Website auto-downloaded: {download_info.get('url')}")

        # Step 2: RAG retrieval (best effort)
        context = ""
        sources = []
        if state.rag_system:
            try:
                rag_result = await state.rag_system.query(
                    query=user_message,
                    top_k=config.TOP_K_RESULTS,
                )
                context = rag_result.get("context", "")
                sources = rag_result.get("sources", [])
                log.info(f"RAG retrieved {len(sources)} relevant documents")
            except Exception as e:
                log.warning(f"RAG retrieval failed: {e}")

        # Step 3: Run ALL tools in parallel (no LLM needed)
        tool_results = []
        if state.tool_manager and config.ENABLE_TOOLS:
            tool_results = await _run_all_tools(user_message)

        # Step 4: Build system prompt with tool results as context
        tool_results_text = _build_tool_results_text(tool_results)
        enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text)

        # Step 5: ONE LLM call (stream the result)
        created = int(time.time())

        if state.llm_client:
            stream = await state.llm_client.chat.completions.create(
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                # `is None`, not `or`: an explicit temperature of 0 must be kept.
                temperature=0.7 if request.temperature is None else request.temperature,
                max_tokens=request.max_tokens or 4096,
                stream=True,
            )
            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield _sse_chunk(request_id, created, {'content': chunk.choices[0].delta.content})
        else:
            # Demo mode: describe what was gathered, one character per event.
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if tool_results_text:
                mock_response += f"I gathered data from {len(tool_results)} tools.\n\n"
            if context:
                mock_response += f"Knowledge base context:\n{context[:1000]}...\n\n"
            mock_response += "\n\n[Demo mode - configure OPENROUTER_API_KEY for full LLM responses]"

            for char in mock_response:
                yield _sse_chunk(request_id, created, {'content': char})
                await asyncio.sleep(0.01)

        # Final chunk
        yield _sse_chunk(request_id, created, {}, 'stop')
        yield "data: [DONE]\n\n"

    except Exception as e:
        log.exception("Streaming failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
|
|
|
|
|
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
    download_info: dict | None = None,
    tool_results_text: str = "",
) -> list[ChatMessage]:
    """Build the message list sent upstream.

    The result is one synthesized system message followed by the client's
    non-system messages in their original order. The system message contains,
    when available:
    - the tool results;
    - a summary of an auto-downloaded website;
    - knowledge-base context, with up to 10 sources listed beside it;
    - the client's own system prompt(s), which were previously discarded, so
      OpenAI-compatible clients lost their custom instructions;
    - closing answering instructions.

    Args:
        messages: Conversation from the client request.
        context: Retrieved knowledge-base text ("" if none).
        sources: Source labels for *context*.
        download_info: Result of ``download_website_if_needed``, if any.
        tool_results_text: Output of ``_build_tool_results_text``.

    Returns:
        A new list; *messages* itself is not modified.
    """
    parts = ["You are a helpful AI assistant with access to real-time data.\n"]

    if tool_results_text:
        parts.append(f"\n## REAL-TIME DATA (from tools)\n{tool_results_text}\n")

    if download_info and download_info.get("downloaded"):
        parts.append("\n--- Website Access ---\n")
        parts.append(f"Downloaded website: {download_info.get('url')}\n")
        parts.append(f"Pages: {download_info.get('pages')}, Chunks: {download_info.get('chunks')}\n")

    if context:
        parts.append(f"\n## Relevant Context from Knowledge Base\n{context}\n")
        if sources:
            parts.append("\nSources:\n" + "\n".join(f"- {s}" for s in sources[:10]))

    parts.append("\n\n## INSTRUCTIONS\nUse the data above to answer the user's question. Be concise and factual.")

    # Preserve the client's system prompt(s) instead of silently dropping them;
    # they are merged here because only one system message is sent upstream.
    client_instructions = [m.content for m in messages if m.role == "system" and m.content]
    if client_instructions:
        parts.append("\n\n## CLIENT INSTRUCTIONS\n" + "\n\n".join(client_instructions))

    enhanced = [ChatMessage(role="system", content="".join(parts))]
    enhanced.extend(m for m in messages if m.role != "system")
    return enhanced
|
|
|
|
|
|
async def call_llm(
    messages: list[ChatMessage],
    temperature: float | None = 0.7,
    max_tokens: int | None = 4096,
) -> str:
    """Make one non-streaming upstream LLM call and return the reply text.

    There is no tool logic: prompt in, response out. Messages without content
    are dropped before sending.

    Args:
        messages: Full conversation, typically from ``build_enhanced_messages``.
        temperature: Sampling temperature; None means the default 0.7.
        max_tokens: Completion token cap; None means the default 4096.

    Returns:
        - The assistant's text.
        - In demo mode (no LLM client), a canned message quoting the last user
          message.
        - For an empty upstream answer, a fallback apology.
        - On an upstream error, the error text. This is deliberate best effort:
          the exception is logged, not raised.
    """
    if not state.llm_client:
        user_msg = next(
            (m.content for m in reversed(messages) if m.role == "user" and m.content),
            "",
        )
        return f"Demo mode. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality."

    # None (e.g. a JSON null forwarded from the request) falls back to the
    # defaults instead of being sent upstream as null.
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 4096

    try:
        messages_dict = [{"role": m.role, "content": m.content} for m in messages if m.content]
        response = await state.llm_client.chat.completions.create(
            model=config.UPSTREAM_MODEL,
            messages=messages_dict,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        if not response.choices:
            log.warning("No choices in LLM response")
            return "I apologize, but I couldn't generate a response."
        content = response.choices[0].message.content or ""
        return content or "I apologize, but I couldn't generate a response."
    except Exception as e:
        log.exception(f"LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
|
|
|
|
|
|
# =============================================================================
|
|
# Document Management Endpoints
|
|
# =============================================================================
|
|
|
|
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document (multipart field ``file``) to the knowledge base.

    Returns:
        ``{"success": True, "message": ..., "chunks": <count>}``.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 400 if no file was
            sent, or ``file`` is a plain form field rather than an upload;
            500 if processing fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        form = await request.form()
        file = form.get("file")
        # A missing field, or a text field named "file" (no .read()), is a
        # client error, not a server failure.
        if file is None or not hasattr(file, "read"):
            raise HTTPException(status_code=400, detail="No file provided")

        content = await file.read()
        filename = getattr(file, "filename", None) or "unknown"

        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )

        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}

    except HTTPException:
        # Keep the 400 above intact instead of letting it become a 500 below.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
class WebsiteDownloadRequest(BaseModel):
    """JSON body accepted by ``POST /v1/documents/website``.

    All options are forwarded unchanged to
    ``RAGSystem.download_and_ingest_website``.
    """

    url: str  # site to download
    max_pages: int = 50  # page budget for the crawl
    threads: int = 6  # downloader concurrency setting
    download_external_assets: bool = False
    external_domains: list[str] | None = None  # meaning defined by the RAG system
|
|
|
|
|
|
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """Download a website and ingest it into the knowledge base.

    This is the primary way to add content to the RAG system; the work is
    delegated to ``state.rag_system.download_and_ingest_website``.

    Returns:
        Success summary with ``local_path``, ``pages_processed`` and
        ``total_chunks``.

    Raises:
        HTTPException: 503 if the RAG system is unavailable; 500 if the
            download reports failure or raises.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")

        download_options = {
            "url": request.url,
            "max_pages": request.max_pages,
            "threads": request.threads,
            "download_external_assets": request.download_external_assets,
            "external_domains": request.external_domains,
        }
        outcome = await state.rag_system.download_and_ingest_website(**download_options)

        if not outcome.get("success"):
            raise HTTPException(
                status_code=500,
                detail=outcome.get("message", "Website download failed"),
            )

        return {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": outcome.get("local_path"),
            "pages_processed": outcome.get("pages_processed", 0),
            "total_chunks": outcome.get("total_chunks", 0),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """Add a single document fetched from ``request["url"]`` to the knowledge base.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 400 if no URL was
            given, 500 if fetching or ingestion fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        outcome = await state.rag_system.add_document_from_url(url)
        chunk_count = outcome.get("chunks", 0)
        return {"success": True, "message": f"Document from {url} added", "chunks": chunk_count}
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents")
async def list_documents():
    """Return the knowledge-base documents as ``{"documents": [...]}``.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 500 on listing failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        documents = await state.rag_system.list_documents()
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"documents": documents}
|
|
|
|
|
|
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return every downloaded website as ``{"sites": [...]}``.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 500 on listing failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        downloaded = await state.rag_system.list_downloaded_sites()
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"sites": downloaded}
|
|
|
|
|
|
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Return stored metadata for one downloaded site as ``{"site": ...}``.

    *url* is percent-decoded and given an ``https://`` scheme when it has none.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 404 if the site is
            unknown, 500 on any other failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        # NOTE(review): ASGI path params are normally already percent-decoded,
        # so this may decode a second time; confirm clients rely on it.
        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"

        info = state.rag_system.get_site_info(site_url)
        if not info:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": info}

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove one document, identified by *doc_id*, from the knowledge base.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 500 if deletion fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        await state.rag_system.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"success": True, "message": f"Document {doc_id} deleted"}
|
|
|
|
|
|
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website and all of its content from the knowledge base.

    *url* is percent-decoded and given an ``https://`` scheme when it has none.

    Raises:
        HTTPException: 503 if the RAG system is unavailable, 404 if the RAG
            system reports no success, 500 on any other failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        # NOTE(review): possibly a second percent-decode (see get_site_info).
        site_url = unquote(url)
        if not site_url.startswith(("http://", "https://")):
            site_url = f"https://{site_url}"

        outcome = await state.rag_system.delete_site(site_url)
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {site_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# =============================================================================
|
|
# Health and Status Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness, seconds since ``state.startup_time``, and which services are up."""
    return dict(
        status="healthy",
        uptime=time.time() - state.startup_time,
        rag_enabled=state.rag_system is not None,
        tools_enabled=state.tool_manager is not None,
        llm_connected=state.llm_client is not None,
    )
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the API and point at its main endpoints."""
    endpoint_map = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "download_website": "/v1/documents/website",
        "list_sites": "/v1/documents/sites",
        "health": "/health",
    }
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": "OpenAI-compatible RAG server powered by OpenRouter. Auto-downloads and analyzes websites when users ask about them.",
        "endpoints": endpoint_map,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Main Entry Point
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # uvicorn needs an import string (not the app object) when reload is on.
    # Derive the module name from this file instead of hard-coding "main", so
    # the script still works when saved under another filename.
    uvicorn.run(
        f"{Path(__file__).stem}:app",
        host=config.HOST,
        port=config.PORT,
        reload=config.DEBUG,
    )
|