docrag/main.py
Z User 57228625fc Fix tool calling: switch to native OpenAI tools parameter
Problems fixed:
- 'Mega tool call': LLM outputting multiple tool calls that got bundled
  into one. Now uses native OpenAI tools parameter which handles multiple
  tool calls properly via message.tool_calls array.
- 'Returning nothing': _clean_tool_syntax was too aggressive, stripping
  the entire response. Now only strips code-fence-wrapped blocks.
- Tool results were appended to system message growing it unboundedly;
  now uses proper 'tool' role messages in conversation history.

Key changes:
- generate_response: passes tools/tool_choice to OpenAI API (native
  tool calling), with retry without tool_choice for unsupported models
- generate_response: handles multiple tool_calls per response natively
- generate_response: uses proper 'tool' role for results instead of
  appending to system message
- _parse_tool_calls (was _parse_tool_call): now returns a list, supports
  multiple tool calls, used as fallback for models without native tools
- _clean_tool_syntax: much less aggressive, only strips code-fence
  blocks, no longer removes bare JSON (was eating valid responses)
- System prompt: removed JSON format instructions (native tools handles
  format), simplified rules
2026-03-29 17:57:26 +00:00

1298 lines
47 KiB
Python
Executable File

#!/usr/bin/env python3
"""
DocRAG - OpenAI-Compatible RAG Server
This application presents itself as a standard OpenAI API server that can be used
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
1. Detects URLs in user messages and auto-downloads websites
2. Ingests website content into the RAG knowledge base
3. Retrieves relevant context from the knowledge base
4. Passes the enriched context to OpenRouter for response generation
The user sees a normal chat experience, but the system is actually doing
sophisticated RAG operations in the background.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncIterator, Optional
# Load environment variables from .env file (look in script directory)
from dotenv import load_dotenv
# Directory containing this script; the primary .env file is looked up next to it.
SCRIPT_DIR = Path(__file__).parent.resolve()
ENV_FILE = SCRIPT_DIR / ".env"
load_dotenv(ENV_FILE)
# Also try loading from current working directory
# NOTE(review): presumably values already loaded from ENV_FILE are kept by this
# second call (python-dotenv's default is not to override) — confirm.
load_dotenv()
# Configure logging: INFO level, timestamps rendered as HH:MM:SS only.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    datefmt="%H:%M:%S",
)
# Module-level logger used by every function in this file.
log = logging.getLogger(__name__)
# FastAPI imports
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
# Import RAG components
from rag import RAGSystem, get_rag_system
from rag.document_processor import DocumentProcessor
# Import tools
from tools import ToolManager, get_tool_manager
# Import OpenAI client for OpenRouter
from openai import AsyncOpenAI
# =============================================================================
# Configuration
# =============================================================================
def _env_flag(name: str, default: str) -> bool:
    """Return True when environment variable *name* equals "true" (case-insensitive)."""
    return os.getenv(name, default).lower() == "true"


def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* parsed as an int (ValueError if malformed)."""
    return int(os.getenv(name, default))


class Config:
    """Application settings, read once from the environment when the class is defined.

    Every attribute falls back to the literal default shown here when the
    corresponding environment variable is unset.
    """

    # --- Server ---
    HOST: str = os.getenv("HOST", "0.0.0.0")
    PORT: int = _env_int("PORT", "8000")
    DEBUG: bool = _env_flag("DEBUG", "false")

    # --- Model naming (what clients see vs. what is called upstream) ---
    MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG")
    UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "openrouter/free")

    # --- OpenRouter API ---
    OPENROUTER_API_KEY: str = os.getenv("OPENROUTER_API_KEY", "")
    OPENROUTER_BASE_URL: str = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

    # --- RAG ---
    EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = _env_int("CHUNK_SIZE", "1000")
    CHUNK_OVERLAP: int = _env_int("CHUNK_OVERLAP", "200")
    TOP_K_RESULTS: int = _env_int("TOP_K_RESULTS", "5")

    # --- Tools ---
    ENABLE_TOOLS: bool = _env_flag("ENABLE_TOOLS", "true")
    MAX_TOOL_ITERATIONS: int = _env_int("MAX_TOOL_ITERATIONS", "5")


# Shared settings instance used by the rest of the module.
config = Config()
# =============================================================================
# OpenAI-Compatible Models
# =============================================================================
class ChatMessage(BaseModel):
    """OpenAI chat message format.

    Only ``role`` is required; every other field defaults to ``None``.
    """
    # Speaker of the message; this file checks for "user" and "system" and
    # creates "system" / "assistant" messages itself.
    role: str
    # Text body; may be absent.
    content: Optional[str] = None
    name: Optional[str] = None
    # Tool-calling fields, kept as plain dicts / strings (no nested model).
    tool_calls: Optional[list[dict]] = None
    tool_call_id: Optional[str] = None
class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request format.

    Only ``messages`` is required; the remaining fields carry the defaults
    declared below. Within this file the handlers read ``messages``,
    ``stream``, ``temperature`` and ``max_tokens``; the other fields are
    accepted so that standard OpenAI clients validate cleanly.
    """
    model: str = "DocRAG"
    messages: list[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    # The OpenAI API allows `stop` to be a single string or a list of strings;
    # accepting only a list would reject valid client requests at validation.
    stop: Optional[str | list[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    user: Optional[str] = None
    tools: Optional[list[dict]] = None
    tool_choice: Optional[str | dict] = None
class ChatCompletionChoice(BaseModel):
    """OpenAI chat completion choice: one generated message plus why generation ended."""
    # Position of this choice in the response's `choices` list.
    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"
class ChatCompletionUsage(BaseModel):
    """Token usage statistics.

    All counters default to 0; callers must fill them in explicitly
    (nothing here derives ``total_tokens`` from the other two).
    """
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response format."""
    # Default id: "chatcmpl-" followed by the first 24 hex chars of a fresh UUID4.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
    object: str = "chat.completion"
    # Default creation time: current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Default is read from `config` once, when this class body executes.
    model: str = config.MODEL_NAME
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage = ChatCompletionUsage()
class ModelInfo(BaseModel):
    """OpenAI model info format (one entry of the model list)."""
    id: str
    object: str = "model"
    # Default creation time: current Unix time in whole seconds.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "organization"
class ModelList(BaseModel):
    """OpenAI model list format: a "list" object wrapping ModelInfo entries."""
    object: str = "list"
    data: list[ModelInfo]
# =============================================================================
# Application State
# =============================================================================
class AppState:
    """Global application state.

    Holds the shared service objects that `lifespan` assigns at startup and
    that the request handlers read. Each one stays ``None`` when its
    initialisation is skipped or fails.
    """
    rag_system: Optional[RAGSystem] = None
    tool_manager: Optional[ToolManager] = None
    # Upstream LLM client; None means handlers fall back to mock/demo responses.
    llm_client: Optional[AsyncOpenAI] = None
    # Evaluated once, when this class body executes.
    startup_time: float = time.time()
# Single shared instance used throughout the module.
state = AppState()
# =============================================================================
# Lifespan Management
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan - startup and shutdown.

    Startup initialises, in order, the RAG system, the tool manager and the
    OpenRouter client, storing each on the module-level ``state``. Every step
    is best-effort: a failure is logged and the corresponding attribute is
    left as ``None`` so the server can still start in a degraded mode.

    Shutdown runs in a ``finally`` block so resources are released even when
    the application exits through an exception. It closes the OpenRouter
    client (and with it the httpx client created here) and then the RAG system.
    """
    log.info("Starting DocRAG server...")

    # Initialize RAG system
    try:
        log.info("Initializing RAG system...")
        state.rag_system = await get_rag_system(
            embedding_model=config.EMBEDDING_MODEL,
            vector_store_path=config.VECTOR_STORE_PATH,
            documents_path=config.DOCUMENTS_PATH,
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
        )
        log.info("RAG system initialized successfully")
    except Exception as e:
        log.warning(f"RAG system initialization deferred: {e}")
        state.rag_system = None

    # Initialize tool manager
    try:
        log.info("Initializing tool manager...")
        state.tool_manager = get_tool_manager()
        log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
    except Exception as e:
        log.warning(f"Tool manager initialization failed: {e}")
        state.tool_manager = None

    # Initialize OpenRouter client for upstream LLM
    # Debug: Show .env file status
    log.info(f"Looking for .env file at: {ENV_FILE}")
    log.info(f".env file exists: {ENV_FILE.exists()}")
    api_key = config.OPENROUTER_API_KEY
    if api_key:
        # Only a short prefix/suffix is logged; short keys are fully masked.
        key_preview = f"{api_key[:8]}...{api_key[-4:]}" if len(api_key) > 12 else "***"
        log.info(f"OPENROUTER_API_KEY found: {key_preview}")
    else:
        log.warning("OPENROUTER_API_KEY not found in environment!")
        log.warning(f"Checked .env at: {ENV_FILE}")
        log.warning("Set OPENROUTER_API_KEY in .env file or as environment variable")

    try:
        if api_key:
            log.info("Initializing OpenRouter client...")
            # Create custom httpx client to avoid proxy issues
            import httpx
            http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0, connect=10.0),
            )
            state.llm_client = AsyncOpenAI(
                api_key=api_key,
                base_url=config.OPENROUTER_BASE_URL,
                http_client=http_client,
            )
            log.info(f"OpenRouter client initialized successfully (model: {config.UPSTREAM_MODEL})")
        else:
            log.warning("No OPENROUTER_API_KEY provided - using mock responses")
            state.llm_client = None
    except Exception as e:
        log.error(f"Failed to initialize OpenRouter client: {e}")
        state.llm_client = None

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")
    try:
        yield
    finally:
        # Cleanup
        log.info("Shutting down DocRAG server...")
        if state.llm_client:
            # Closing the OpenAI client also closes the httpx client passed to it,
            # which would otherwise leak its connection pool.
            try:
                await state.llm_client.close()
            except Exception as e:
                log.warning(f"Failed to close OpenRouter client cleanly: {e}")
            state.llm_client = None
        if state.rag_system:
            await state.rag_system.close()
        log.info("DocRAG server stopped")
# =============================================================================
# FastAPI Application
# =============================================================================
# The ASGI application; `lifespan` handles service startup and shutdown.
app = FastAPI(
    title="DocRAG API",
    description="OpenAI-compatible RAG server powered by OpenRouter",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS middleware for Open WebUI compatibility
# NOTE(review): wildcard origins are combined with allow_credentials=True here.
# The FastAPI CORS docs advise listing origins explicitly when credentials are
# allowed — confirm whether credentialed cross-origin requests are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =============================================================================
# OpenAI-Compatible Endpoints
# =============================================================================
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List available models (OpenAI-compatible).

    Advertises the configured ``MODEL_NAME`` and the fixed alias "DocRAG".
    When both are the same (the default configuration) the model is listed
    only once instead of twice.
    """
    # dict.fromkeys de-duplicates while keeping MODEL_NAME first.
    model_ids = list(dict.fromkeys([config.MODEL_NAME, "DocRAG"]))
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Return the model record for *model_id* (OpenAI-compatible).

    Only the configured model name and the "DocRAG" alias are known;
    any other id produces a 404.
    """
    known_ids = [config.MODEL_NAME, "DocRAG"]
    if model_id in known_ids:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Dispatches to the streaming generator when ``request.stream`` is set,
    otherwise awaits the non-streaming handler.

    Raises:
        HTTPException: deliberate HTTP errors raised further down (e.g. the
            400 for a request with no user message) are passed through
            unchanged; any other exception is logged and reported as a 500.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
    try:
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Keep the intended status code instead of rewriting it as a 500.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# =============================================================================
# URL Detection and Website Download
# =============================================================================
def extract_urls_from_message(message: str) -> list[str]:
    """Extract URLs from a message, including domains without scheme.

    Explicit ``http(s)://`` URLs are returned first, in order of appearance,
    with trailing sentence punctuation removed. Bare domains such as
    ``example.com`` found in the rest of the text are then appended as
    ``https://<domain>`` (any path after a bare domain is dropped). The
    result contains no duplicates.

    Args:
        message: Free-form user text.

    Returns:
        Unique URLs in discovery order; empty when none are found.
    """
    def _trim_trailing_punctuation(url: str) -> str:
        # A URL at the end of a sentence or inside parentheses picks up
        # punctuation that is not part of the address.
        while url:
            last = url[-1]
            if last in ".,;:!?'\"":
                url = url[:-1]
            elif last == ")" and url.count(")") > url.count("("):
                # Only drop an unbalanced ")" so "/wiki/Foo_(bar)" survives.
                url = url[:-1]
            else:
                break
        return url

    urls: list[str] = []
    seen: set[str] = set()

    def _add(url: str) -> None:
        if url not in seen:
            seen.add(url)
            urls.append(url)

    # Match full URLs
    url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
    for raw_url in re.findall(url_pattern, message):
        url = _trim_trailing_punctuation(raw_url)
        # Skip matches that are nothing but a scheme after trimming.
        if url.partition("://")[2]:
            _add(url)

    # Blank out the full URLs before the bare-domain scan; otherwise the
    # parent domain of "https://docs.example.com" is matched a second time
    # and reported as a separate "https://example.com".
    remainder = re.sub(url_pattern, " ", message)

    # Match domain names (with or without www)
    domain_pattern = r'(?:^|[^\w/-])(?:www\.)?([a-zA-Z0-9][-a-zA-Z0-9]*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2,})?)(?:/[^\s]*)?(?=[^\w/-]|$)'
    for domain in re.findall(domain_pattern, remainder):
        # Check if it's a valid domain (not just a word)
        if '.' in domain and len(domain) > 4:
            _add(f"https://{domain}")
    return urls
def should_download_website(message: str, urls: list[str]) -> bool:
    """Decide whether *message* expresses a wish to read content from a website.

    Returns False when no URLs were found or when the message looks like an
    automated Open WebUI housekeeping prompt. Otherwise returns True if the
    lower-cased message contains any access keyword, a question mark, or a
    question word (all matched as plain substrings).
    """
    if not urls:
        return False

    lowered = message.lower()

    # Open WebUI sends these prompts itself (titles, tags, follow-ups);
    # they are not real user requests.
    automated_markers = (
        '### task:', 'generate a concise', 'generate 1-3',
        'suggest 3-5 relevant follow-up', 'summarizing the chat history',
        'categorizing the main themes', 'follow-up questions or prompts',
    )
    for marker in automated_markers:
        if marker in lowered:
            log.info("Skipping website download - appears to be an automated task")
            return False

    # Phrases suggesting the user wants the site's content.
    intent_phrases = (
        'go to', 'visit', 'check', 'look at', 'browse', 'open',
        'what is on', 'tell me about', 'give me', 'show me', 'get',
        'headlines', 'content', 'information from', 'from the', 'from',
        'on the website', 'on the site', 'the website', 'the site', 'website',
        'summarize', 'extract', 'read', 'analyze', 'find', 'search',
        'what does', 'what\'s on', 'what is', 'tell me', 'about',
        'news', 'articles', 'posts', 'page', 'pages',
    )
    question_words = ('what', 'how', 'who', 'where', 'when', 'why')

    wants_access = any(phrase in lowered for phrase in intent_phrases)
    if wants_access:
        return True
    if '?' in message:
        return True
    return any(word in lowered for word in question_words)
async def download_website_if_needed(user_message: str) -> dict[str, Any]:
    """Fetch and ingest a website when the user's message asks about one.

    URLs are taken from the message and tried in order. The first URL that
    is already known to the RAG system, or that downloads successfully,
    determines the result. A failure for one URL is logged and the next
    URL is tried.

    Returns:
        ``{"downloaded": True, "url", "chunks", "pages", "local_path"}``
        (plus ``"cached": True`` for an already-ingested site), or
        ``{"downloaded": False, "reason": ...}`` when nothing was fetched.
    """
    def _not_downloaded(reason: str) -> dict[str, Any]:
        return {"downloaded": False, "reason": reason}

    urls = extract_urls_from_message(user_message)
    if not should_download_website(user_message, urls):
        return _not_downloaded("No website access intent detected")
    if not state.rag_system:
        return _not_downloaded("RAG system not initialized")

    for url in urls:
        try:
            # Reuse a previously ingested copy of the site when there is one.
            site_info = state.rag_system.get_site_info(url)
            if site_info:
                log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)")
                cached_result: dict[str, Any] = {"downloaded": True, "url": url}
                cached_result["chunks"] = site_info.get("chunks_ingested", 0)
                cached_result["pages"] = site_info.get("pages_downloaded", 0)
                cached_result["local_path"] = site_info.get("local_path")
                cached_result["cached"] = True
                return cached_result

            log.info(f"Auto-downloading website: {url}")
            result = await state.rag_system.download_and_ingest_website(
                url=url,
                max_pages=30,
            )
            if not result.get("success"):
                # Unsuccessful result: move on to the next candidate URL.
                continue
            log.info(f"Successfully downloaded {url}: {result.get('total_chunks', 0)} chunks")
            fresh_result: dict[str, Any] = {"downloaded": True, "url": url}
            fresh_result["chunks"] = result.get("total_chunks", 0)
            fresh_result["pages"] = result.get("pages_processed", 0)
            fresh_result["local_path"] = result.get("local_path")
            return fresh_result
        except Exception as e:
            log.warning(f"Failed to download {url}: {e}")

    return _not_downloaded("All download attempts failed")
# =============================================================================
# Chat Completion Logic
# =============================================================================
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Process a non-streaming chat completion request.

    Steps: find the latest user message, auto-download a referenced website
    if needed, query the RAG system, build the enhanced prompt, call the
    upstream LLM, and wrap the answer in an OpenAI-style response.

    Args:
        request: The parsed chat completion request.
        request_id: Identifier used as the response ``id`` and in log lines.

    Returns:
        The response with a single assistant choice. Token usage is a rough
        estimate (character count // 4).

    Raises:
        HTTPException: 400 when the request has no non-empty user message.
    """
    log.info(f"=== Starting complete_chat for request {request_id} ===")
    messages = request.messages

    # Extract the last user message
    user_message = ""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            user_message = msg.content
            break
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")
    log.info(f"User message: {user_message[:100]}...")

    # Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG Retrieval (now includes newly downloaded content)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            # Retrieval is best-effort: answer without context if it fails.
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)

    # Step 4: Generate response with upstream LLM.
    # Both request fields are Optional; apply the same fallbacks as the
    # streaming path so None is never forwarded upstream. The explicit None
    # check keeps a valid temperature of 0.
    temperature = request.temperature if request.temperature is not None else 0.7
    max_tokens = request.max_tokens or 4096
    log.info(f"Calling generate_response for request {request_id}")
    response_content = await generate_response(
        enhanced_messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    log.info(f"=== Completed complete_chat for request {request_id} ===")

    # Step 5: Build and return response
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4
    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            # Previously left at its default of 0.
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
def _sse_chunk(
    request_id: str,
    created: int,
    delta: dict,
    finish_reason: Optional[str] = None,
) -> str:
    """Format one OpenAI ``chat.completion.chunk`` as a server-sent-event line."""
    payload = {
        'id': request_id,
        'object': 'chat.completion.chunk',
        'created': created,
        'model': config.MODEL_NAME,
        'choices': [{
            'index': 0,
            'delta': delta,
            'finish_reason': finish_reason,
        }],
    }
    return f"data: {json.dumps(payload)}\n\n"


async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion response as server-sent events.

    Runs the same preparation as the non-streaming path (website download,
    RAG retrieval, prompt building) and then yields ``data: ...`` lines:

    * tools enabled: the full answer is generated first (tool calls cannot
      be streamed) and sent as a single content chunk;
    * tools disabled: upstream deltas are relayed as they arrive;
    * no LLM client: a demo reply is streamed one character at a time.

    A successful stream ends with a ``finish_reason: "stop"`` chunk and
    ``data: [DONE]``. Errors are reported as a ``{"error": ...}`` event.
    """
    messages = request.messages

    # Extract the last user message
    user_message = ""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            user_message = msg.content
            break
    if not user_message:
        yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
        return

    # Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG Retrieval (now includes newly downloaded content)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)

    # Step 4: Stream response from upstream LLM
    created = int(time.time())
    # `temperature or 0.7` would replace an explicit temperature of 0
    # with 0.7, so test for None instead.
    temperature = request.temperature if request.temperature is not None else 0.7
    max_tokens = request.max_tokens or 4096
    try:
        if state.llm_client:
            # For streaming with tools, we need to handle tool calls first
            # Then stream the final response
            if state.tool_manager and config.ENABLE_TOOLS:
                # Use non-streaming for tool calls, then stream the result
                response_content = await generate_response(
                    enhanced_messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                )
                # Stream the final response as a single chunk
                yield _sse_chunk(request_id, created, {'content': response_content})
            else:
                # No tools - use regular streaming
                stream = await state.llm_client.chat.completions.create(
                    model=config.UPSTREAM_MODEL,
                    messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=True,
                )
                async for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta.content:
                        content = chunk.choices[0].delta.content
                        yield _sse_chunk(request_id, created, {'content': content})
            # Send final chunk
            yield _sse_chunk(request_id, created, {}, 'stop')
            yield "data: [DONE]\n\n"
        else:
            # Mock streaming response for testing
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if download_info.get("downloaded"):
                mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n"
                mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n"
            if context:
                mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n"
            mock_response += "\n\n[Demo mode - configure OPENROUTER_API_KEY for full LLM responses]"
            for char in mock_response:
                yield _sse_chunk(request_id, created, {'content': char})
                # Small delay so the demo output visibly streams.
                await asyncio.sleep(0.01)
            yield _sse_chunk(request_id, created, {}, 'stop')
            yield "data: [DONE]\n\n"
    except Exception as e:
        log.exception("Streaming failed")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
    download_info: Optional[dict] = None,
    tool_results: Optional[list[dict]] = None,
) -> list[ChatMessage]:
    """Build enhanced messages with RAG context.

    Produces a new message list that starts with one generated system
    message, followed by every non-system message of the conversation in
    its original order. Client-supplied system messages are dropped.

    Args:
        messages: The conversation as received from the client.
        context: Retrieved knowledge-base text; omitted when empty.
        sources: Source labels for ``context``; at most the first 10 are
            listed, and only when ``context`` is non-empty.
        download_info: Result of the website auto-download step; a summary
            is added only when its ``"downloaded"`` entry is truthy.
        tool_results: Optional earlier tool outputs, each a dict with
            ``"name"`` and ``"result"`` keys, appended to the system message.

    Returns:
        The enhanced message list.
    """
    # Add system message with RAG context and tool instructions
    parts = ["""You are a helpful AI assistant with access to real-time data through various tools.
## AVAILABLE TOOLS
You have access to tools for getting real-time data. Use them whenever you need current information.
## IMPORTANT RULES
1. ALWAYS use your available tools to get CURRENT data - do NOT say you cannot access real-time data
2. When asked about stocks, crypto, weather, or news, you MUST use the appropriate tool
3. After receiving tool results, provide a helpful, natural-language response based on the data
4. Be concise and factual - report exact data from tools
"""]

    if download_info and download_info.get("downloaded"):
        parts.append("\n\n--- Website Access ---\n")
        parts.append(f"Downloaded website: {download_info.get('url')}\n")
        parts.append(f"Pages: {download_info.get('pages')}, Chunks: {download_info.get('chunks')}\n")

    if context:
        parts.append(f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n")
        # Sources are treated as belonging to the retrieved context, so they
        # are listed only when context is present.
        if sources:
            parts.append("\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10]))

    # Add previous tool results if any
    if tool_results:
        parts.append("\n\n--- PREVIOUS TOOL RESULTS ---\n")
        for tr in tool_results:
            parts.append(f"\nTool: {tr['name']}\nResult: {json.dumps(tr['result'], indent=2)}\n")

    enhanced = [ChatMessage(role="system", content="".join(parts))]
    # Add conversation history (excluding old system messages)
    enhanced.extend(msg for msg in messages if msg.role != "system")
    return enhanced
def _build_tool_descriptions() -> str:
    """Render one ``- name(params): description`` line per registered tool.

    Reads the tool manager's schema mapping. Descriptions are cut to 100
    characters and each parameter is shown as ``name: type``, with
    ``(required)`` appended when it is listed as required. Returns
    "No tools available." when there is no tool manager.
    """
    if not state.tool_manager:
        return "No tools available."

    lines = []
    for name, schema in state.tool_manager._schemas.items():
        fn_spec = schema.get("function", {})
        # Truncate long descriptions
        summary = fn_spec.get("description", "")[:100]
        properties = fn_spec.get("parameters", {}).get("properties", {})
        required_names = fn_spec.get("parameters", {}).get("required", [])

        rendered_params = []
        for param_name, param_info in properties.items():
            param_type = param_info.get("type", "any")
            suffix = " (required)" if param_name in required_names else ""
            rendered_params.append(f"{param_name}: {param_type}{suffix}")

        signature = ", ".join(rendered_params) if rendered_params else "none"
        lines.append(f"- {name}({signature}): {summary}")
    return "\n".join(lines)
def _parse_tool_calls(content: str) -> list[dict]:
"""Parse tool calls from LLM response content (fallback for models without native tool support).
Returns a list of tool call dicts, each with 'name' and 'arguments' keys.
Supports multiple tool calls in a single response.
"""
tool_calls = []
def _extract_all_json_objects(text: str, start_key: str) -> list[dict]:
"""Extract ALL JSON objects containing start_key using brace counting."""
results = []
search_start = 0
while True:
idx = text.find(start_key, search_start)
if idx == -1:
break
# Walk backwards to find the opening { of this object
depth = 0
obj_start = -1
for i in range(idx, -1, -1):
if text[i] == '}':
depth += 1
elif text[i] == '{':
if depth == 0:
obj_start = i
break
depth -= 1
if obj_start == -1:
break
# Walk forwards to find the matching closing }
depth = 0
obj_end = -1
for i in range(obj_start, len(text)):
if text[i] == '{':
depth += 1
elif text[i] == '}':
depth -= 1
if depth == 0:
obj_end = i + 1
break
if obj_end == -1:
break
try:
obj = json.loads(text[obj_start:obj_end])
if obj and isinstance(obj, dict):
results.append(obj)
except json.JSONDecodeError:
pass
# Move past this object to find the next one
search_start = obj_end
return results
# Pattern 1: code fence blocks containing tool_call
fence_matches = re.findall(r'```\w*\s*(.*?)\s*```', content, re.DOTALL)
for block_text in fence_matches:
if '"tool_call"' in block_text:
objects = _extract_all_json_objects(block_text, '"tool_call"')
for obj in objects:
if "tool_call" in obj:
tc = obj["tool_call"]
if isinstance(tc, dict) and "name" in tc:
tool_calls.append(tc)
# Pattern 2: bare JSON {"tool_call": {...}} outside code fences
# Strip code fences first to avoid double-parsing
stripped = re.sub(r'```\w*\s*.*?\s*```', '', content, flags=re.DOTALL)
if '"tool_call"' in stripped:
objects = _extract_all_json_objects(stripped, '"tool_call"')
for obj in objects:
if "tool_call" in obj:
tc = obj["tool_call"]
if isinstance(tc, dict) and "name" in tc:
# Avoid duplicates
if not any(
existing.get("name") == tc.get("name") and
existing.get("arguments") == tc.get("arguments")
for existing in tool_calls
):
tool_calls.append(tc)
# Pattern 3: [USE: tool_name args] pattern
bracket_matches = re.findall(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL)
for match in bracket_matches:
name = match[0]
args_str = match[1] or "{}"
try:
args = json.loads(args_str)
except json.JSONDecodeError:
args = {}
tool_calls.append({"name": name, "arguments": args})
return tool_calls
async def _execute_tool_safely(tool_name: Optional[str], tool_args: dict) -> Any:
    """Run one tool in a worker thread, turning any failure into an error result dict."""
    if not state.tool_manager:
        result = {"success": False, "error": "No tool manager available"}
    else:
        try:
            result = await asyncio.to_thread(
                state.tool_manager.execute_tool, tool_name, tool_args
            )
        except Exception as exc:
            # Report the failure to the model instead of aborting the whole reply.
            log.exception(f"Tool {tool_name} raised an exception")
            result = {"success": False, "error": str(exc)}
    success = result.get('success', False) if isinstance(result, dict) else None
    log.info(f"Tool {tool_name} result: success={success}")
    return result


async def generate_response(
    messages: list[ChatMessage],
    temperature: float = 0.7,
    max_tokens: int = 4096,
) -> str:
    """Generate response using upstream LLM via OpenRouter with native tool calling.

    Uses OpenAI-compatible `tools` parameter for reliable tool calling.
    Falls back to content-based parsing if the model doesn't support native tools.

    The function loops up to ``config.MAX_TOOL_ITERATIONS`` times. Each
    iteration calls the LLM. If the reply requests tools (natively or via
    text conventions), the tools are executed, their results are appended to
    the conversation, and the loop continues. The first reply without tool
    calls is returned.

    Args:
        messages: Conversation to send; messages with empty content are skipped.
        temperature: Sampling temperature; ``None`` falls back to 0.7.
        max_tokens: Completion token limit; ``None`` falls back to 4096.

    Returns:
        The final assistant text. This function never raises: upstream
        failures and the iteration limit are reported as explanatory strings.
    """
    if not state.llm_client:
        # Mock response for testing
        user_msg = ""
        for msg in reversed(messages):
            if msg.role == "user" and msg.content:
                user_msg = msg.content
                break
        return f"Demo mode response. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality."

    # Callers may forward Optional request fields unchanged.
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 4096

    try:
        # Convert messages to dict format
        messages_dict = [
            {"role": m.role, "content": m.content} for m in messages if m.content
        ]

        # Prepare native tool schemas for OpenAI API
        native_tools = None
        if state.tool_manager and config.ENABLE_TOOLS:
            schemas = state.tool_manager.get_all_schemas()
            if schemas:
                native_tools = []
                for schema in schemas:
                    if not isinstance(schema, dict):
                        log.warning(f"Skipping non-dict tool schema: {schema}")
                        continue
                    # Ensure correct OpenAI tools format
                    if schema.get("type") == "function" and "function" in schema:
                        native_tools.append(schema)
                    else:
                        # Wrap bare function schema
                        native_tools.append({
                            "type": "function",
                            "function": schema,
                        })
        if native_tools:
            log.info(f"Passing {len(native_tools)} tools to LLM API")
        else:
            log.info("No native tools available, using content-only mode")

        # Once the upstream model rejects `tool_choice`, stop sending it;
        # otherwise every later iteration would fail once and be retried.
        send_tool_choice = True

        # Tool calling loop
        max_iterations = config.MAX_TOOL_ITERATIONS
        iteration = 0
        while iteration < max_iterations:
            iteration += 1
            log.info(f"LLM call iteration {iteration}")

            # Build API call parameters
            api_params = {
                "model": config.UPSTREAM_MODEL,
                "messages": messages_dict,
                "temperature": temperature,
                "max_tokens": max_tokens,
            }
            if native_tools:
                api_params["tools"] = native_tools
                if send_tool_choice:
                    api_params["tool_choice"] = "auto"

            # Call LLM (with retry without tool_choice if model doesn't support it)
            try:
                response = await state.llm_client.chat.completions.create(**api_params)
            except Exception as api_err:
                rejected_tool_choice = (
                    "tool_choice" in api_params
                    and "tool_choice" in str(api_err).lower()
                )
                if not rejected_tool_choice:
                    raise
                log.warning(f"Model doesn't support tool_choice, retrying without it: {api_err}")
                send_tool_choice = False
                del api_params["tool_choice"]
                response = await state.llm_client.chat.completions.create(**api_params)

            if not response.choices:
                log.warning("No choices in response")
                return "I apologize, but I couldn't generate a response."

            choice = response.choices[0]
            message = choice.message
            content = message.content or ""
            finish_reason = choice.finish_reason or "stop"
            log.info(f"LLM response: content_len={len(content)}, finish_reason={finish_reason}")

            # --- Handle native tool calls (preferred path) ---
            native_tool_calls = getattr(message, 'tool_calls', None)
            if native_tool_calls:
                log.info(f"Native tool calls detected: {len(native_tool_calls)}")
                # Build assistant message with tool_calls for conversation history
                messages_dict.append({
                    "role": "assistant",
                    "content": content if content else None,
                    "tool_calls": [
                        {
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments or "{}",
                            },
                        }
                        for tc in native_tool_calls
                    ],
                })
                # Execute each tool and add result messages
                for tc in native_tool_calls:
                    tool_name = tc.function.name
                    try:
                        tool_args = json.loads(tc.function.arguments or "{}")
                    except json.JSONDecodeError:
                        log.warning(f"Failed to parse tool arguments for {tool_name}: {tc.function.arguments}")
                        tool_args = {}
                    if not isinstance(tool_args, dict):
                        # Valid JSON that is not an object cannot be used as arguments.
                        log.warning(f"Ignoring non-object tool arguments for {tool_name}: {tool_args!r}")
                        tool_args = {}
                    log.info(f"Executing native tool: {tool_name} with args: {tool_args}")
                    result = await _execute_tool_safely(tool_name, tool_args)
                    # Add tool result using proper 'tool' role.
                    # default=str keeps one unserialisable value in a tool
                    # result from failing the whole request.
                    messages_dict.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": json.dumps(result, default=str),
                    })
                continue

            # --- Fallback: parse tool calls from content (for models without native tool support) ---
            content_tool_calls = _parse_tool_calls(content)
            if content_tool_calls:
                log.info(f"Content-based tool calls detected: {len(content_tool_calls)}")
                # Add the assistant's raw response to conversation
                messages_dict.append({"role": "assistant", "content": content})
                for tool_call in content_tool_calls:
                    tool_name = tool_call.get("name")
                    tool_args = tool_call.get("arguments", {})
                    if not isinstance(tool_args, dict):
                        # Some models emit the arguments as a JSON string.
                        try:
                            tool_args = json.loads(tool_args)
                        except (json.JSONDecodeError, TypeError):
                            tool_args = {}
                        if not isinstance(tool_args, dict):
                            tool_args = {}
                    log.info(f"Executing content-based tool: {tool_name}")
                    result = await _execute_tool_safely(tool_name, tool_args)
                    # Feed result back as a user message
                    messages_dict.append({
                        "role": "user",
                        "content": f"--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2, default=str)}\n\nNow provide a helpful response based on this data.",
                    })
                continue

            # --- No tool calls - return the final response ---
            # Light cleanup: only strip code-fence-wrapped tool_call blocks
            cleaned_content = _clean_tool_syntax(content)
            log.info(f"Returning final response (len={len(cleaned_content)}, cleaned={len(cleaned_content) != len(content)})")
            return cleaned_content or "I apologize, but I couldn't generate a response."

        # Max iterations reached
        log.warning(f"Max iterations ({max_iterations}) reached")
        return "I reached the maximum number of tool calls. Please try a more specific question."
    except Exception as e:
        # log.exception records the traceback along with the message.
        log.exception(f"OpenRouter LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
def _clean_tool_syntax(content: str) -> str:
"""Remove tool call syntax from response if partially included.
Only strips code-fence-wrapped blocks containing tool_call.
Does NOT strip bare JSON to avoid accidentally removing valid content.
"""
# Remove ```json ... ``` blocks containing tool_call
def remove_code_block(m):
block = m.group(0)
inner = m.group(1)
if '"tool_call"' in inner:
return ''
return block
cleaned = re.sub(r'```\w*\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL)
return cleaned.strip()
# =============================================================================
# Document Management Endpoints
# =============================================================================
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document to the knowledge base.

    Reads the ``file`` field from the submitted form, then hands its bytes and
    filename to ``state.rag_system.add_document``.

    Returns:
        A dict with ``success``, a human-readable ``message`` and the
        ``chunks`` value reported by the RAG system (0 when absent).

    Raises:
        HTTPException: 503 when no RAG system is configured, 400 when the
            form carries no uploaded file, 500 for any other failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        form = await request.form()
        file = form.get("file")
        # Form values are either uploaded files or plain strings. A missing
        # field or a plain-string field means no file was actually attached,
        # so report a client error instead of failing later on `.read()`.
        if not file or isinstance(file, str):
            raise HTTPException(status_code=400, detail="No file provided")
        content = await file.read()
        filename = file.filename or "unknown"
        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )
        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}
    except HTTPException:
        # Let deliberate HTTP errors (such as the 400 above) through unchanged;
        # without this the generic handler below would turn them into a 500.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e))
class WebsiteDownloadRequest(BaseModel):
    """Request model for website download.

    Body of ``POST /v1/documents/website``. The ``download_website`` handler
    forwards each field as the same-named keyword argument to
    ``state.rag_system.download_and_ingest_website``.
    """
    # Address of the website to download (required).
    url: str
    # Forwarded as `max_pages`; presumably an upper bound on pages fetched -- confirm against the downloader.
    max_pages: int = 50
    # Forwarded as `threads`; presumably the downloader's worker count -- confirm against the downloader.
    threads: int = 6
    # Forwarded as `download_external_assets`; exact semantics are defined by the downloader.
    download_external_assets: bool = False
    # Forwarded as `external_domains`; stays None when the client omits it.
    external_domains: Optional[list[str]] = None
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """
    Download a website and ingest it into the knowledge base.

    This is the PRIMARY way to add content to the RAG system. All request
    fields are passed through to
    ``state.rag_system.download_and_ingest_website``.

    Returns a summary dict (``url``, ``local_path``, ``pages_processed``,
    ``total_chunks``) when the result reports success.

    Raises:
        HTTPException: 503 when no RAG system is configured; 500 when the
            result does not report success or any other error occurs.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")
        outcome = await state.rag_system.download_and_ingest_website(
            url=request.url,
            max_pages=request.max_pages,
            threads=request.threads,
            download_external_assets=request.download_external_assets,
            external_domains=request.external_domains,
        )

        # Guard clause: anything other than a reported success is a failure.
        if not outcome.get("success"):
            raise HTTPException(
                status_code=500,
                detail=outcome.get("message", "Website download failed"),
            )

        return {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": outcome.get("local_path"),
            "pages_processed": outcome.get("pages_processed", 0),
            "total_chunks": outcome.get("total_chunks", 0),
        }
    except HTTPException:
        # Keep the status code chosen above instead of re-wrapping it.
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """Add a single document, identified by URL, to the knowledge base.

    Expects a body containing a non-empty ``url`` entry, which is passed to
    ``state.rag_system.add_document_from_url``.

    Raises:
        HTTPException: 503 when no RAG system is configured, 400 when ``url``
            is missing or empty, 500 when ingestion fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        outcome = await state.rag_system.add_document_from_url(url)
        chunk_count = outcome.get("chunks", 0)
        return {
            "success": True,
            "message": f"Document from {url} added",
            "chunks": chunk_count,
        }
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/documents")
async def list_documents():
    """Return the knowledge base's document listing.

    The listing is whatever ``state.rag_system.list_documents`` yields,
    wrapped under the ``documents`` key.

    Raises:
        HTTPException: 503 when no RAG system is configured, 500 when the
            listing fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"documents": await state.rag_system.list_documents()}
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return the downloaded websites known to the knowledge base.

    The result of ``state.rag_system.list_downloaded_sites`` is wrapped under
    the ``sites`` key.

    Raises:
        HTTPException: 503 when no RAG system is configured, 500 when the
            listing fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        downloaded = await state.rag_system.list_downloaded_sites()
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))
    else:
        return {"sites": downloaded}
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Look up the stored details of one downloaded site.

    The path value is percent-decoded and, when it does not already start
    with ``http://`` or ``https://``, prefixed with ``https://`` before being
    passed to ``state.rag_system.get_site_info``.

    Raises:
        HTTPException: 503 when no RAG system is configured, 404 when the
            lookup returns nothing, 500 for any other failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        decoded_url = unquote(url)
        # Default to https when the caller left the scheme off.
        decoded_url = (
            decoded_url
            if decoded_url.startswith(("http://", "https://"))
            else "https://" + decoded_url
        )

        details = state.rag_system.get_site_info(decoded_url)
        if not details:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": details}
    except HTTPException:
        # Keep the 404 above from being re-wrapped as a 500.
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove one document from the knowledge base by its identifier.

    Delegates to ``state.rag_system.delete_document`` and confirms with a
    success message.

    Raises:
        HTTPException: 503 when no RAG system is configured, 500 when the
            deletion fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        await state.rag_system.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
    else:
        return {"success": True, "message": f"Document {doc_id} deleted"}
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website and its content from the knowledge base.

    The path value is percent-decoded and, when it does not already start
    with ``http://`` or ``https://``, prefixed with ``https://`` before being
    passed to ``state.rag_system.delete_site``.

    Returns a dict with ``success``, ``message``, ``deleted_chunks`` and
    ``deleted_path`` when the result reports success.

    Raises:
        HTTPException: 503 when no RAG system is configured, 404 when the
            result does not report success, 500 for any other failure.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        decoded_url = unquote(url)
        # Default to https when the caller left the scheme off.
        decoded_url = (
            decoded_url
            if decoded_url.startswith(("http://", "https://"))
            else "https://" + decoded_url
        )

        outcome = await state.rag_system.delete_site(decoded_url)
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {decoded_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }
    except HTTPException:
        # Keep the 404 above from being re-wrapped as a 500.
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Health and Status Endpoints
# =============================================================================
@app.get("/health")
async def health_check():
    """Report service liveness and which subsystems are available.

    ``uptime`` is the current ``time.time()`` minus ``state.startup_time``.
    Each ``*_enabled`` / ``*_connected`` flag is true when the matching
    attribute on ``state`` is not None.
    """
    uptime = time.time() - state.startup_time
    has_rag = state.rag_system is not None
    has_tools = state.tool_manager is not None
    has_llm = state.llm_client is not None
    return {
        "status": "healthy",
        "uptime": uptime,
        "rag_enabled": has_rag,
        "tools_enabled": has_tools,
        "llm_connected": has_llm,
    }
@app.get("/")
async def root():
    """Describe the API: name, version, summary, and main endpoint paths."""
    endpoint_paths = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "download_website": "/v1/documents/website",
        "list_sites": "/v1/documents/sites",
        "health": "/health",
    }
    summary = "OpenAI-compatible RAG server powered by OpenRouter. Auto-downloads and analyzes websites when users ask about them."
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": summary,
        "endpoints": endpoint_paths,
    }
# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    # The app is named by import string rather than passed as an object;
    # uvicorn needs that form when `reload` is switched on.
    uvicorn.run("main:app", host=config.HOST, port=config.PORT, reload=config.DEBUG)