_select_tools() parses the user message with keyword matching: - News keywords → news_aggregate, news_get_top_stories, news_get_reddit - Finance/stock keywords → finance_get_stock_info/history (extracts ticker) - Crypto keywords → finance_get_crypto_price (extracts coin name), finance_get_top_cryptos - Weather keywords → weather_get_current/forecast/air_quality (extracts location) - Medical keywords → pubmed, fda, disease data, health topics - Science keywords → science_aggregate_search - Wikipedia keywords → wikipedia_search - Always: web_search + web_instant_answer as general fallback - URL in message → web_get_page_content Entity extractors: - _extract_ticker: maps known company names, handles $TICKER format - _extract_crypto: maps known crypto names to CoinGecko IDs - _extract_location: preposition-based + known locations (prefers longest match) - _extract_subject: strips question patterns, leading articles, trailing punctuation Flow remains: request → select tools → run in parallel → results into system prompt → 1 LLM call
1297 lines
47 KiB
Python
Executable File
1297 lines
47 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
DocRAG - OpenAI-Compatible RAG Server
|
|
|
|
This application presents itself as a standard OpenAI API server that can be used
|
|
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
|
1. Detects URLs in user messages and auto-downloads websites
|
|
2. Ingests website content into the RAG knowledge base
|
|
3. Retrieves relevant context from the knowledge base
|
|
4. Passes the enriched context to OpenRouter for response generation
|
|
|
|
The user sees a normal chat experience, but the system is actually doing
|
|
sophisticated RAG operations in the background.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from typing import Any, AsyncIterator, Optional
|
|
|
|
# Load environment variables from .env file (look in script directory)
|
|
from dotenv import load_dotenv
|
|
# Directory that holds this script; the primary .env file is expected next to it.
SCRIPT_DIR = Path(__file__).parent.resolve()
ENV_FILE = SCRIPT_DIR / ".env"

# Load the script-local .env first, then let python-dotenv run its default
# lookup as a second attempt (e.g. a .env near the working directory).
load_dotenv(ENV_FILE)
load_dotenv()

# Process-wide logging: INFO level, "time | LEVEL | logger | message" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    datefmt="%H:%M:%S",
)

# Module-level logger used throughout this file.
log = logging.getLogger(__name__)
|
|
|
|
# FastAPI imports
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Import RAG components
|
|
from rag import RAGSystem, get_rag_system
|
|
from rag.document_processor import DocumentProcessor
|
|
|
|
# Import tools
|
|
from tools import ToolManager, get_tool_manager
|
|
|
|
# Import OpenAI client for OpenRouter
|
|
from openai import AsyncOpenAI
|
|
|
|
|
|
# =============================================================================
|
|
# Configuration
|
|
# =============================================================================
|
|
|
|
def _env_int(name: str, default: str) -> int:
    """Read environment variable *name* (or *default*) and convert it to int."""
    return int(os.getenv(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when environment variable *name* equals "true" (any case)."""
    return os.getenv(name, default).lower() == "true"


class Config:
    """Settings resolved from environment variables when the class body runs.

    Every attribute has a built-in default, so the server can start without
    any environment configuration. Integer settings raise ``ValueError`` at
    import time if the variable is set to a non-numeric value.
    """

    # --- HTTP server ---
    HOST: str = os.getenv("HOST", "0.0.0.0")
    PORT: int = _env_int("PORT", "8000")
    DEBUG: bool = _env_flag("DEBUG", "false")

    # --- Model naming: the id exposed to clients vs. the upstream model used ---
    MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG")
    UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "openrouter/free")

    # --- OpenRouter API access ---
    OPENROUTER_API_KEY: str = os.getenv("OPENROUTER_API_KEY", "")
    OPENROUTER_BASE_URL: str = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")

    # --- RAG: embedding model, storage locations, chunking, retrieval depth ---
    EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = _env_int("CHUNK_SIZE", "1000")
    CHUNK_OVERLAP: int = _env_int("CHUNK_OVERLAP", "200")
    TOP_K_RESULTS: int = _env_int("TOP_K_RESULTS", "5")

    # --- Tooling ---
    ENABLE_TOOLS: bool = _env_flag("ENABLE_TOOLS", "true")
    MAX_TOOL_ITERATIONS: int = _env_int("MAX_TOOL_ITERATIONS", "5")


# Shared configuration instance used by the rest of the module.
config = Config()
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Models
|
|
# =============================================================================
|
|
|
|
class ChatMessage(BaseModel):
    """A single message in the OpenAI chat format.

    Only ``role`` is required; all other fields default to ``None``.
    """

    role: str
    content: Optional[str] = None
    name: Optional[str] = None
    # Tool-calling fields are accepted on input as part of the message schema.
    tool_calls: Optional[list[dict]] = None
    tool_call_id: Optional[str] = None
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body for the chat-completions endpoints (OpenAI format).

    ``messages`` is the only field without a default.
    """

    model: str = "DocRAG"
    messages: list[ChatMessage]

    # Sampling controls
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1

    # Output controls
    stream: Optional[bool] = False
    stop: Optional[list[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

    # Caller identification and client-supplied tool definitions
    user: Optional[str] = None
    tools: Optional[list[dict]] = None
    tool_choice: Optional[str | dict] = None
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """One entry of the ``choices`` list in a chat completion response."""

    index: int = 0
    message: ChatMessage
    finish_reason: str = "stop"
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
    """Token accounting attached to a chat completion response.

    All counters default to zero.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Response body for a non-streaming chat completion (OpenAI format)."""

    # "chatcmpl-" followed by 24 hex characters of a random UUID.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
    object: str = "chat.completion"
    # Unix timestamp (whole seconds) taken when the response object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = config.MODEL_NAME
    choices: list[ChatCompletionChoice]
    usage: ChatCompletionUsage = ChatCompletionUsage()
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Description of one model as returned by the model-listing endpoints."""

    id: str
    object: str = "model"
    # Unix timestamp (whole seconds) taken when this object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "organization"
|
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope for a list of ``ModelInfo`` entries (OpenAI "list" object)."""

    object: str = "list"
    data: list[ModelInfo]
|
|
|
|
|
|
# =============================================================================
|
|
# Application State
|
|
# =============================================================================
|
|
|
|
class AppState:
    """Container for process-wide service handles.

    The three handles start as ``None`` and are filled in by ``lifespan``
    during startup; a handle left as ``None`` means that service is
    unavailable.
    """

    rag_system: Optional[RAGSystem] = None
    tool_manager: Optional[ToolManager] = None
    llm_client: Optional[AsyncOpenAI] = None
    # Evaluated once, when the class body runs at import time.
    startup_time: float = time.time()


# The single shared state object used by the handlers in this module.
state = AppState()
|
|
|
|
|
|
# =============================================================================
|
|
# Lifespan Management
|
|
# =============================================================================
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan - startup and shutdown.

    Startup initialises, in order, the RAG system, the tool manager and the
    OpenRouter client. Each step is best-effort: a failure is logged and the
    corresponding ``state`` attribute is left as ``None`` so the server can
    still start in a degraded mode.

    Shutdown closes the RAG system and the LLM client. The cleanup runs in a
    ``finally`` block so it also executes when the application exits with an
    error.
    """
    log.info("Starting DocRAG server...")

    # Initialize RAG system
    try:
        log.info("Initializing RAG system...")
        state.rag_system = await get_rag_system(
            embedding_model=config.EMBEDDING_MODEL,
            vector_store_path=config.VECTOR_STORE_PATH,
            documents_path=config.DOCUMENTS_PATH,
            chunk_size=config.CHUNK_SIZE,
            chunk_overlap=config.CHUNK_OVERLAP,
        )
        log.info("RAG system initialized successfully")
    except Exception as e:
        log.warning(f"RAG system initialization deferred: {e}")
        state.rag_system = None

    # Initialize tool manager
    try:
        log.info("Initializing tool manager...")
        state.tool_manager = get_tool_manager()
        log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
    except Exception as e:
        log.warning(f"Tool manager initialization failed: {e}")
        state.tool_manager = None

    # Initialize OpenRouter client for upstream LLM
    # Debug: Show .env file status
    log.info(f"Looking for .env file at: {ENV_FILE}")
    log.info(f".env file exists: {ENV_FILE.exists()}")

    api_key = config.OPENROUTER_API_KEY
    if api_key:
        # Only a short prefix/suffix of the key is logged, never the full value.
        key_preview = f"{api_key[:8]}...{api_key[-4:]}" if len(api_key) > 12 else "***"
        log.info(f"OPENROUTER_API_KEY found: {key_preview}")
    else:
        log.warning("OPENROUTER_API_KEY not found in environment!")
        log.warning(f"Checked .env at: {ENV_FILE}")
        log.warning("Set OPENROUTER_API_KEY in .env file or as environment variable")

    try:
        if api_key:
            log.info("Initializing OpenRouter client...")
            # Create custom httpx client to avoid proxy issues
            import httpx

            http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(60.0, connect=10.0),
            )
            state.llm_client = AsyncOpenAI(
                api_key=api_key,
                base_url=config.OPENROUTER_BASE_URL,
                http_client=http_client,
            )
            log.info(f"OpenRouter client initialized successfully (model: {config.UPSTREAM_MODEL})")
        else:
            log.warning("No OPENROUTER_API_KEY provided - using mock responses")
            state.llm_client = None
    except Exception as e:
        log.error(f"Failed to initialize OpenRouter client: {e}")
        state.llm_client = None

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")

    try:
        yield
    finally:
        # Cleanup
        log.info("Shutting down DocRAG server...")
        try:
            if state.rag_system:
                await state.rag_system.close()
        finally:
            # Close the LLM client even if the RAG shutdown failed; this also
            # releases the httpx connection pool created during startup.
            if state.llm_client:
                await state.llm_client.close()
                state.llm_client = None
        log.info("DocRAG server stopped")
|
|
|
|
|
|
# =============================================================================
|
|
# FastAPI Application
|
|
# =============================================================================
|
|
|
|
# The ASGI application; startup and shutdown are delegated to `lifespan`.
app = FastAPI(
    title="DocRAG API",
    description="OpenAI-compatible RAG server powered by OpenRouter",
    version="1.0.0",
    lifespan=lifespan,
)

# Fully permissive CORS so browser front-ends such as Open WebUI can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# broad - confirm this is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List available models (OpenAI-compatible).

    Returns the configured model name plus the fixed "DocRAG" alias. The ids
    are de-duplicated (order preserved) so that the default configuration,
    where both are "DocRAG", does not list the same model twice.
    """
    model_ids = dict.fromkeys((config.MODEL_NAME, "DocRAG"))
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
|
|
|
|
|
|
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Return the description of one model, or 404 if the id is unknown.

    The only accepted ids are the configured model name and the "DocRAG" alias.
    """
    accepted_ids = (config.MODEL_NAME, "DocRAG")
    if model_id in accepted_ids:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Dispatches to the streaming or non-streaming implementation depending on
    ``request.stream``.

    Raises:
        HTTPException: passed through unchanged when raised by the
            implementation (e.g. 400 for a request without a user message);
            any other exception is logged and reported as a 500.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        if request.stream:
            # The generator is only created here; errors raised while the
            # stream is being consumed occur after this function has returned.
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Keep deliberate status codes (such as 400) instead of masking them as 500.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# =============================================================================
|
|
# URL Detection and Website Download
|
|
# =============================================================================
|
|
|
|
def extract_urls_from_message(message: str) -> list[str]:
    """Extract URLs from a message, including domains without scheme.

    Two passes are made over *message*:

    1. Explicit ``http://`` / ``https://`` URLs. Trailing sentence punctuation
       (``. , ; : ! ? '``) and an unbalanced closing parenthesis are removed,
       so "see https://example.com." yields "https://example.com".
    2. Bare domain names such as "example.com" or "www.example.com", which are
       normalised to ``https://<domain>`` (a leading "www." and any path are
       not part of the result).

    Returns:
        The URLs in order of discovery, without duplicates.
    """
    trailing_punctuation = ".,;:!?'"
    url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+'
    domain_pattern = r'(?:^|[^\w/-])(?:www\.)?([a-zA-Z0-9][-a-zA-Z0-9]*\.[a-zA-Z]{2,}(?:\.[a-zA-Z]{2,})?)(?:/[^\s]*)?(?=[^\w/-]|$)'

    urls: list[str] = []

    # Pass 1: URLs that carry an explicit scheme.
    for raw_url in re.findall(url_pattern, message):
        url = raw_url.rstrip(trailing_punctuation)
        # A ")" with no matching "(" inside the URL belongs to the surrounding text.
        while url.endswith(")") and url.count(")") > url.count("("):
            url = url[:-1].rstrip(trailing_punctuation)
        has_host = bool(url.partition("://")[2])
        if has_host and url not in urls:
            urls.append(url)

    # Pass 2: bare domains (the pattern captures the domain only).
    for domain in re.findall(domain_pattern, message):
        # Check if it's a valid domain (not just a word)
        if "." in domain and len(domain) > 4:
            normalized = f"https://{domain}"
            if normalized not in urls:
                urls.append(normalized)

    return urls
|
|
|
|
|
|
def should_download_website(message: str, urls: list[str]) -> bool:
    """Decide whether *message* asks for the content of one of *urls*.

    Returns False when there are no URLs or when the message looks like an
    automated Open WebUI task (title/tag/follow-up generation). Otherwise
    returns True if the lower-cased message contains any access-intent phrase,
    a question mark, or a question word. All phrase checks are plain substring
    tests.
    """
    if not urls:
        return False

    text = message.lower()

    # Prompts generated by Open WebUI itself, not typed by a user.
    automation_markers = (
        '### task:', 'generate a concise', 'generate 1-3',
        'suggest 3-5 relevant follow-up', 'summarizing the chat history',
        'categorizing the main themes', 'follow-up questions or prompts',
    )
    for marker in automation_markers:
        if marker in text:
            log.info("Skipping website download - appears to be an automated task")
            return False

    # Phrases suggesting the user wants something from the site.
    intent_phrases = (
        'go to', 'visit', 'check', 'look at', 'browse', 'open',
        'what is on', 'tell me about', 'give me', 'show me', 'get',
        'headlines', 'content', 'information from', 'from the', 'from',
        'on the website', 'on the site', 'the website', 'the site', 'website',
        'summarize', 'extract', 'read', 'analyze', 'find', 'search',
        'what does', 'what\'s on', 'what is', 'tell me', 'about',
        'news', 'articles', 'posts', 'page', 'pages',
    )
    if any(phrase in text for phrase in intent_phrases):
        return True

    # A URL accompanied by a question also counts as intent.
    if '?' in message:
        return True
    question_words = ('what', 'how', 'who', 'where', 'when', 'why')
    return any(word in text for word in question_words)
|
|
|
|
|
|
async def download_website_if_needed(user_message: str) -> dict[str, Any]:
    """Download and ingest the first usable website referenced in *user_message*.

    URLs are tried in the order they were extracted. A site the RAG system
    already knows is reported immediately (with ``"cached": True``); otherwise
    a download of up to 30 pages is attempted. A failed or unsuccessful
    attempt moves on to the next URL.

    Returns:
        ``{"downloaded": True, "url", "chunks", "pages", "local_path", ...}``
        on success, or ``{"downloaded": False, "reason": ...}`` otherwise.
    """
    candidate_urls = extract_urls_from_message(user_message)

    if not should_download_website(user_message, candidate_urls):
        return {"downloaded": False, "reason": "No website access intent detected"}

    if not state.rag_system:
        return {"downloaded": False, "reason": "RAG system not initialized"}

    for url in candidate_urls:
        try:
            # Reuse a previous download when the RAG system already has the site.
            known_site = state.rag_system.get_site_info(url)
            if known_site:
                log.info(f"Site already downloaded: {url} ({known_site.get('chunks_ingested', 0)} chunks)")
                return {
                    "downloaded": True,
                    "url": url,
                    "chunks": known_site.get("chunks_ingested", 0),
                    "pages": known_site.get("pages_downloaded", 0),
                    "local_path": known_site.get("local_path"),
                    "cached": True,
                }

            log.info(f"Auto-downloading website: {url}")
            outcome = await state.rag_system.download_and_ingest_website(
                url=url,
                max_pages=30,
            )

            if outcome.get("success"):
                log.info(f"Successfully downloaded {url}: {outcome.get('total_chunks', 0)} chunks")
                return {
                    "downloaded": True,
                    "url": url,
                    "chunks": outcome.get("total_chunks", 0),
                    "pages": outcome.get("pages_processed", 0),
                    "local_path": outcome.get("local_path"),
                }
        except Exception as e:
            # Best effort: log and fall through to the next candidate URL.
            log.warning(f"Failed to download {url}: {e}")

    return {"downloaded": False, "reason": "All download attempts failed"}
|
|
|
|
|
|
# =============================================================================
|
|
# Chat Completion Logic
|
|
# =============================================================================
|
|
|
|
async def _run_all_tools(user_message: str) -> list[dict]:
    """Pick tools for *user_message* by keyword matching and run them concurrently.

    Each tool runs in a worker thread with a 30 second timeout. Failures never
    propagate: every selected tool produces one result dict, either
    ``{"name", "success": True, "result"}`` or
    ``{"name", "success": False, "error"}``.

    Returns:
        One result dict per selected tool, or an empty list when there is no
        tool manager or nothing was selected.
    """
    if not state.tool_manager:
        return []

    # Step 1: decide which tools apply to this message.
    selected = _select_tools(user_message)
    if not selected:
        log.info("No tools selected by keyword parser")
        return []

    selected_names = [call['name'] for call in selected]
    log.info(f"Selected {len(selected)} tools: {selected_names}")

    # Step 2: run every selected tool concurrently.
    async def _run_one(name: str, kwargs: dict):
        try:
            # execute_tool is run through asyncio.to_thread, i.e. off the event loop.
            outcome = await asyncio.wait_for(
                asyncio.to_thread(state.tool_manager.execute_tool, name, kwargs),
                timeout=30,
            )
        except asyncio.TimeoutError:
            return {"name": name, "success": False, "error": "timeout"}
        except Exception as e:
            return {"name": name, "success": False, "error": str(e)}
        return {"name": name, "success": True, "result": outcome}

    pending = [_run_one(call["name"], call["kwargs"]) for call in selected]
    results = await asyncio.gather(*pending)

    succeeded = sum(1 for item in results if item["success"])
    log.info(f"Tool execution complete: {succeeded}/{len(results)} succeeded")
    return results
|
|
|
|
|
|
def _select_tools(user_message: str) -> list[dict]:
    """Parse the user message and determine which tools to call.

    Uses keyword/category matching on the lower-cased message. Keywords are
    matched at the start of a word, so inflections still hit ("rain" matches
    "raining") but unrelated words that merely contain a keyword do not
    ("story" no longer matches "history"). Keywords of three characters or
    fewer ("eth", "dow", "fda", ...) must match a whole word, optionally plural.

    ``web_search`` and ``web_instant_answer`` are always included as a general
    fallback; ``web_get_page_content`` is added when the message has a URL.

    Returns:
        A list of ``{"name": str, "kwargs": dict}`` entries with at most one
        entry per tool name (the first occurrence wins).
    """
    msg = user_message.lower()
    tools = []

    def mentions(*keywords: str) -> bool:
        """Return True if any keyword occurs in the message at a word start."""
        for keyword in keywords:
            pattern = r"(?<![a-z0-9])" + re.escape(keyword)
            if len(keyword) <= 3:
                # Short acronyms are too ambiguous for prefix matching
                # ("dow" vs "download"), so require a whole word.
                pattern += r"s?(?![a-z0-9])"
            if re.search(pattern, msg):
                return True
        return False

    # --- Extract useful entities from the message ---
    location = _extract_location(user_message)
    ticker = _extract_ticker(user_message)
    crypto = _extract_crypto(user_message)
    url = _extract_url(user_message)
    subject = _extract_subject(user_message)  # the main topic/query

    # --- News tools ---
    # "subreddit" is listed explicitly because word-start matching no longer
    # finds "reddit" inside it.
    if mentions("news", "headline", "headlines", "current event", "current events",
                "breaking", "trending", "reddit", "subreddit", "hacker", "story",
                "stories"):
        tools.append({"name": "news_aggregate", "kwargs": {"query": subject}})
        tools.append({"name": "news_get_top_stories", "kwargs": {}})
        tools.append({"name": "news_get_reddit", "kwargs": {"subreddit": "news"}})
    if mentions("reddit", "subreddit"):
        tools.append({"name": "news_search_reddit", "kwargs": {"query": subject}})

    # --- Finance / stock tools ---
    # Stock lookups need a ticker from _extract_ticker. Guessing one from an
    # arbitrary short word would send symbols such as "WHAT" to the finance API.
    if ticker and mentions("stock", "share", "price", "market", "nasdaq", "nyse",
                           "dow", "s&p", "dividend", "portfolio", "ipo"):
        tools.append({"name": "finance_get_stock_info", "kwargs": {"symbol": ticker}})
        tools.append({"name": "finance_get_stock_history", "kwargs": {"symbol": ticker}})

    # --- Crypto tools ---
    # "litecoin"/"stablecoin" are explicit because "coin" only matches at a word start.
    if mentions("crypto", "bitcoin", "btc", "ethereum", "eth", "solana",
                "dogecoin", "litecoin", "memecoin", "altcoin", "stablecoin",
                "blockchain", "coin", "token"):
        if crypto:
            tools.append({"name": "finance_get_crypto_price", "kwargs": {"coin_id": crypto}})
        tools.append({"name": "finance_get_top_cryptos", "kwargs": {}})

    # --- Weather tools ---
    if mentions("weather", "temperature", "forecast", "rain", "snow",
                "wind", "humid", "sunny", "cloudy", "storm", "thunderstorm",
                "air quality", "aqi", "pollution", "uv index"):
        if location:
            tools.append({"name": "weather_get_current", "kwargs": {"location": location}})
            tools.append({"name": "weather_get_forecast", "kwargs": {"location": location}})
            tools.append({"name": "weather_get_air_quality", "kwargs": {"location": location}})
        else:
            # No recognisable location: let the weather tool try the subject text.
            tools.append({"name": "weather_get_current", "kwargs": {"location": subject}})
            tools.append({"name": "weather_get_forecast", "kwargs": {"location": subject}})

    # --- Medical tools ---
    if mentions("medical", "health", "disease", "symptom", "drug", "medication",
                "covid", "vaccine", "fda", "hospital", "pubmed", "clinical",
                "treatment", "diagnosis", "doctor", "patient"):
        tools.append({"name": "medical_search_pubmed", "kwargs": {"query": subject}})
        tools.append({"name": "medical_search_fda", "kwargs": {"query": subject}})
        tools.append({"name": "medical_get_disease_data", "kwargs": {"disease": subject}})
        tools.append({"name": "medical_get_health_topics", "kwargs": {"topic": subject}})

    # --- Science / research tools ---
    if mentions("research", "paper", "study", "arxiv", "academic", "journal",
                "scholar", "citation", "peer-review", "scientific", "thesis",
                "experiment", "theory", "physics", "math"):
        tools.append({"name": "science_aggregate_search", "kwargs": {"query": subject}})

    # --- Wikipedia ---
    if mentions("wikipedia", "wiki", "who is", "what is", "history of",
                "explain", "definition", "meaning of", "tell me about"):
        tools.append({"name": "wikipedia_search", "kwargs": {"query": subject}})

    # --- Web search (always include as general fallback) ---
    tools.append({"name": "web_search", "kwargs": {"query": subject}})
    tools.append({"name": "web_instant_answer", "kwargs": {"query": subject}})

    # --- URL extraction ---
    if url:
        tools.append({"name": "web_get_page_content", "kwargs": {"url": url}})

    # Deduplicate by name, keeping the first occurrence.
    seen = set()
    unique = []
    for tc in tools:
        if tc["name"] not in seen:
            seen.add(tc["name"])
            unique.append(tc)

    return unique
|
|
|
|
|
|
# --- Entity extractors ---
|
|
|
|
# Stock lookup table: lower-case company name or symbol -> exchange ticker.
_KNOWN_TICKERS = {
    "aapl": "AAPL",
    "apple": "AAPL",
    "googl": "GOOGL",
    "google": "GOOGL",
    "msft": "MSFT",
    "microsoft": "MSFT",
    "amzn": "AMZN",
    "amazon": "AMZN",
    "tsla": "TSLA",
    "tesla": "TSLA",
    "meta": "META",
    "facebook": "META",
    "nvda": "NVDA",
    "nvidia": "NVDA",
    "netflix": "NFLX",
    "nflx": "NFLX",
    "amd": "AMD",
    "intel": "INTC",
    "disney": "DIS",
    "jpmorgan": "JPM",
    "ba": "BA",
    "boeing": "BA",
    "walmart": "WMT",
    "wmt": "WMT",
    "pfizer": "PFE",
    "pfe": "PFE",
    "nio": "NIO",
    "pltr": "PLTR",
    "palantir": "PLTR",
    "coin": "COIN",
    "coinbase": "COIN",
    "roku": "ROKU",
    "spotify": "SPOT",
    "shopify": "SHOP",
}

# Crypto lookup table: lower-case coin name or symbol -> coin id passed to the
# crypto price tool. Insertion order matters to lookups that scan this dict.
_KNOWN_CRYPTO = {
    "bitcoin": "bitcoin",
    "btc": "bitcoin",
    "ethereum": "ethereum",
    "eth": "ethereum",
    "solana": "solana",
    "sol": "solana",
    "dogecoin": "dogecoin",
    "doge": "dogecoin",
    "ripple": "ripple",
    "xrp": "ripple",
    "cardano": "cardano",
    "ada": "cardano",
    "polkadot": "polkadot",
    "dot": "polkadot",
    "litecoin": "litecoin",
    "ltc": "litecoin",
    "chainlink": "chainlink",
    "link": "chainlink",
    "avalanche": "avalanche",
    "avax": "avalanche",
    "polygon": "polygon",
    "matic": "polygon",
    "shiba": "shiba-inu",
    "shib": "shiba-inu",
    "tron": "tron",
    "trx": "tron",
    "usdt": "tether",
    "tether": "tether",
}

# Lower-case place names recognised by the location extractor.
_KNOWN_LOCATIONS = [
    # US states
    "alabama", "alaska", "arizona", "arkansas", "california",
    "colorado", "connecticut", "delaware", "florida", "georgia",
    "hawaii", "idaho", "illinois", "indiana", "iowa",
    "kansas", "kentucky", "louisiana", "maine", "maryland",
    "massachusetts", "michigan", "minnesota", "mississippi", "missouri",
    "montana", "nebraska", "nevada", "new hampshire", "new jersey",
    "new mexico", "new york", "north carolina", "north dakota", "ohio",
    "oklahoma", "oregon", "pennsylvania", "rhode island", "south carolina",
    "south dakota", "tennessee", "texas", "utah", "vermont",
    "virginia", "washington", "west virginia", "wisconsin", "wyoming",
    # Cities
    "oroville", "chico", "redding", "sacramento", "los angeles",
    "san francisco", "san diego", "new york city", "chicago", "houston",
    "phoenix", "dallas", "austin", "seattle", "portland",
    "denver", "miami", "boston", "atlanta", "london",
    "paris", "tokyo", "berlin", "sydney", "toronto",
    "vancouver", "melbourne",
]

# Words after which a place name is looked for.
_LOCATION_PREPOSITIONS = ["in", "at", "for", "near", "around", "outside", "from"]
|
|
|
|
|
|
def _extract_ticker(message: str) -> str:
    """Extract a stock ticker from the message.

    Candidates are tried in order of reliability:

    1. An explicit ``$TICKER`` (1-5 letters after a dollar sign).
    2. A known company name or symbol from ``_KNOWN_TICKERS`` (case-insensitive,
       surrounding punctuation and a trailing "'s" ignored).
    3. A 2-5 letter word the user typed entirely in upper case.

    The casing check in step 3 uses the message as typed; upper-casing the
    message first would make every short word look like a ticker.

    Returns:
        The upper-case ticker symbol, or "" when none is found.
    """
    # 1. Explicit $TICKER notation.
    dollar_match = re.search(r"\$([A-Za-z]{1,5})(?![A-Za-z])", message)
    if dollar_match:
        return dollar_match.group(1).upper()

    tokens = [raw.strip(",$.!?;:\"'") for raw in message.split()]

    # 2. Known company names / symbols.
    for token in tokens:
        name = token.lower()
        if name.endswith("'s"):
            name = name[:-2]
        if name in _KNOWN_TICKERS:
            return _KNOWN_TICKERS[name]

    # 3. Raw upper-case ticker (2-5 alpha chars) exactly as typed.
    for token in tokens:
        if 2 <= len(token) <= 5 and token.isalpha() and token.isupper():
            return token

    return ""
|
|
|
|
|
|
def _extract_crypto(message: str) -> str:
    """Extract a cryptocurrency name from the message.

    The message is split into lower-case alphanumeric words and each known
    name in ``_KNOWN_CRYPTO`` is compared against whole words only, so short
    symbols such as "eth", "sol" or "ada" no longer match inside unrelated
    words ("method", "solution", "adapt"). When several coins are mentioned,
    the one listed first in ``_KNOWN_CRYPTO`` wins.

    Returns:
        The coin id for the first known coin mentioned, or "" if none is.
    """
    words = set(re.findall(r"[a-z0-9]+", message.lower()))
    if not words:
        return ""
    for name, coin_id in _KNOWN_CRYPTO.items():
        if name in words:
            return coin_id
    return ""
|
|
|
|
|
|
def _extract_location(message: str) -> str:
    """Extract a location from the message.

    Resolution order:

    1. A known location inside the phrase (up to five words) following a
       location preposition; if several prepositions yield one, the last wins.
    2. A known location anywhere in the message.
    3. A short (<= 4 characters) unknown phrase after a preposition, kept as a
       last-resort guess for abbreviations.

    Known locations are matched on word boundaries and the longest match is
    preferred, so "new york city" beats "new york" and "paris" is not found
    inside "comparison".

    Returns:
        The location in title case, or "" when nothing plausible is found.
    """
    msg_lower = message.lower()
    words = msg_lower.split()
    if not words:
        return ""

    def longest_known(text: str) -> str:
        """Return the longest known location occurring in *text* as whole words."""
        hits = [
            loc for loc in _KNOWN_LOCATIONS
            if re.search(r"(?<![a-z])" + re.escape(loc) + r"(?![a-z])", text)
        ]
        return max(hits, key=len) if hits else ""

    known_match = ""
    short_guess = ""
    for i, word in enumerate(words):
        if word not in _LOCATION_PREPOSITIONS or i + 1 >= len(words):
            continue

        # Collect the words after the preposition, stopping at the next
        # preposition or a stand-alone punctuation token.
        candidate_words = []
        for following in words[i + 1:i + 6]:
            if following in _LOCATION_PREPOSITIONS or following in (",", ".", "?", "!"):
                break
            candidate_words.append(following.strip(",$.!?;:\"'"))
        candidate = " ".join(candidate_words)
        if not candidate:
            continue

        hit = longest_known(candidate)
        if hit:
            known_match = hit.title()
        elif len(candidate) <= 4:
            # Kept separately so a guess like "now" can never replace a
            # known location found after another preposition.
            short_guess = candidate.title()

    if known_match:
        return known_match

    # Fallback: check for known locations appearing anywhere (prefer longest)
    anywhere = longest_known(msg_lower)
    if anywhere:
        return anywhere.title()

    return short_guess
|
|
|
|
|
|
def _extract_url(message: str) -> str:
    """Extract the first http(s) URL from the message.

    Trailing sentence punctuation (``. , ; : ! ? '``) and a closing
    parenthesis with no matching opener inside the URL are treated as part of
    the surrounding text and removed.

    Returns:
        The cleaned URL, or "" when the message contains none.
    """
    trailing_punctuation = ".,;:!?'"
    match = re.search(r'https?://[^\s<>"{}|\\^`\[\]]+', message)
    if not match:
        return ""

    url = match.group(0).rstrip(trailing_punctuation)
    while url.endswith(")") and url.count(")") > url.count("("):
        url = url[:-1].rstrip(trailing_punctuation)

    # Nothing left after the scheme means there was no real URL.
    return url if url.partition("://")[2] else ""
|
|
|
|
|
|
def _extract_subject(message: str) -> str:
    """Extract the main subject/query from the user message.

    Strips one leading question/request phrase ("what is the", "tell me
    about", ...), then one leading article, then trailing punctuation, and
    finally caps the result at 200 characters.

    Starters are tried longest first and must end at a word boundary, so
    "how does X" removes "how does" (not "how do", which would leave "es X")
    and "finding X" is left untouched by the starter "find".

    Returns:
        The cleaned subject, or the stripped original message if cleaning
        would leave nothing.
    """
    subject = message.strip()

    # Question/request starters; the order here is irrelevant because they are
    # sorted by length before matching.
    starters = [
        "give me all the", "give me all", "give me",
        "tell me about", "tell me",
        "what is the", "what is",
        "what are the", "what are",
        "what's the", "what's",
        "how is the", "how is",
        "how are the", "how are",
        "how do", "how does",
        "how much", "how many",
        "can you", "could you", "would you",
        "please", "i need", "i want",
        "show me", "get me",
        "find me", "find",
        "search for", "look up", "lookup",
        "check the", "check",
        "what's happening",
        "explain", "describe", "summarize",
    ]
    msg_lower = subject.lower()
    for starter in sorted(starters, key=len, reverse=True):
        if not msg_lower.startswith(starter):
            continue
        remainder = subject[len(starter):]
        # Require a word boundary: "find" must not strip the start of "finding".
        if remainder and remainder[0].isalnum():
            continue
        subject = remainder.strip()
        msg_lower = subject.lower()
        break

    # Strip leading "the", "a", "an"
    for article in ("the ", "a ", "an "):
        if msg_lower.startswith(article):
            subject = subject[len(article):]
            break

    # Strip trailing punctuation
    subject = subject.rstrip("?.!;, ")

    # If still long, take first meaningful chunk
    if len(subject) > 200:
        subject = subject[:200]

    return subject or message.strip()
|
|
|
|
|
|
def _build_tool_results_text(tool_results: list[dict]) -> str:
    """Build a text block of all tool results for the system prompt.

    Each result becomes a "### <tool name>" section holding either the
    JSON-encoded result (cut to 3000 characters plus a truncation marker) or
    an "[ERROR: ...]" line. Sections are separated by blank lines.

    Values that JSON cannot encode natively (dates, custom objects, ...) are
    written via ``str()`` so that one unusual tool result cannot abort the
    whole prompt build.

    Returns:
        The combined text, or "" when *tool_results* is empty.
    """
    if not tool_results:
        return ""

    parts = []
    for tr in tool_results:
        name = tr["name"]
        if tr["success"]:
            result_data = tr.get("result", {})
            result_str = json.dumps(result_data, ensure_ascii=False, default=str)
            # Truncate large results to keep prompt manageable
            if len(result_str) > 3000:
                result_str = result_str[:3000] + '..." [TRUNCATED]'
            parts.append(f"### {name}\n{result_str}")
        else:
            parts.append(f"### {name}\n[ERROR: {tr.get('error', 'unknown')}]")

    return "\n\n".join(parts)
|
|
|
|
|
|
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Process a non-streaming chat completion request.

    Pipeline: take the latest user message, optionally download a referenced
    website, retrieve RAG context, run the keyword-selected tools, build the
    enhanced prompt and make a single LLM call.

    Token usage is an estimate (about four characters per token);
    ``total_tokens`` is the sum of the prompt and completion estimates.

    Raises:
        HTTPException: 400 when the request contains no non-empty user message.
    """
    log.info(f"=== Starting complete_chat for request {request_id} ===")
    messages = request.messages

    # Extract the last user message
    user_message = ""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            user_message = msg.content
            break

    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    log.info(f"User message: {user_message[:100]}...")

    # Step 1: Download website if user is asking about one
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG Retrieval (best effort: a failure leaves the context empty)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Run ALL tools in parallel (no LLM needed)
    tool_results = []
    if state.tool_manager and config.ENABLE_TOOLS:
        tool_results = await _run_all_tools(user_message)

    # Step 4: Build system prompt with tool results as context
    tool_results_text = _build_tool_results_text(tool_results)
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text)

    # Step 5: ONE LLM call
    log.info(f"Calling LLM (single call) for request {request_id}")
    response_content = await call_llm(
        enhanced_messages,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
    )
    log.info(f"=== Completed complete_chat for request {request_id} ===")

    # Rough token estimate: roughly four characters per token.
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4

    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
|
|
|
|
|
|
async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion response as server-sent events.

    Runs the same pipeline as the non-streaming path (optional website
    download, RAG retrieval, tool execution, prompt building) and then
    yields ``data: ...`` lines: one ``chat.completion.chunk`` per content
    delta, a final chunk with ``finish_reason == "stop"``, and ``[DONE]``.

    Args:
        request: The incoming chat completion request; ``messages``,
            ``temperature`` and ``max_tokens`` are read from it.
        request_id: Identifier echoed in the ``id`` field of every chunk.

    Yields:
        SSE-formatted strings. If no user message is present, or if the
        upstream call raises, a single ``{"error": ...}`` event is emitted
        instead of (or after) the regular chunks.
    """

    def sse(payload: dict) -> str:
        # One server-sent event carrying a JSON payload.
        return f"data: {json.dumps(payload)}\n\n"

    messages = request.messages

    # Extract the last user message
    user_message = ""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            user_message = msg.content
            break

    if not user_message:
        yield sse({"error": "No user message found"})
        return

    # Step 1: Download website if user is asking about one
    download_info = await download_website_if_needed(user_message)
    if download_info.get("downloaded"):
        log.info(f"Website auto-downloaded: {download_info.get('url')}")

    # Step 2: RAG Retrieval
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            # Retrieval is best-effort: continue without knowledge-base context.
            log.warning(f"RAG retrieval failed: {e}")

    # Step 3: Run ALL tools in parallel (no LLM needed)
    tool_results = []
    if state.tool_manager and config.ENABLE_TOOLS:
        tool_results = await _run_all_tools(user_message)

    # Step 4: Build system prompt with tool results as context
    tool_results_text = _build_tool_results_text(tool_results)
    enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text)

    # Step 5: ONE LLM call (stream the result)
    created = int(time.time())

    def chunk_event(delta: dict, finish_reason: Optional[str] = None) -> str:
        # Every streamed chunk shares the same envelope; only delta/finish_reason vary.
        return sse({
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": config.MODEL_NAME,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        })

    try:
        if state.llm_client:
            stream = await state.llm_client.chat.completions.create(
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                # Only fall back when the client sent nothing: 0 is a valid
                # temperature and must not be replaced by the default.
                temperature=0.7 if request.temperature is None else request.temperature,
                max_tokens=request.max_tokens or 4096,
                stream=True,
            )

            async for chunk in stream:
                if chunk.choices and chunk.choices[0].delta.content:
                    yield chunk_event({"content": chunk.choices[0].delta.content})

            # Final chunk
            yield chunk_event({}, "stop")
            yield "data: [DONE]\n\n"

        else:
            # Mock streaming response
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if tool_results_text:
                mock_response += f"I gathered data from {len(tool_results)} tools.\n\n"
            if context:
                mock_response += f"Knowledge base context:\n{context[:1000]}...\n\n"
            mock_response += "\n\n[Demo mode - configure OPENROUTER_API_KEY for full LLM responses]"

            # Emit one character per chunk with a short pause between chunks.
            for char in mock_response:
                yield chunk_event({"content": char})
                await asyncio.sleep(0.01)

            yield chunk_event({}, "stop")
            yield "data: [DONE]\n\n"

    except Exception as e:
        log.exception("Streaming failed")
        yield sse({"error": str(e)})
|
|
|
|
|
|
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
    download_info: dict = None,
    tool_results_text: str = "",
) -> list[ChatMessage]:
    """Assemble the message list sent to the LLM.

    A single fresh system message is built from, in order: the tool output,
    a note about an auto-downloaded website, the knowledge-base context
    (with up to ten sources), and the answering instructions. It is followed
    by every incoming message whose role is not ``"system"``.
    """
    sections = ["You are a helpful AI assistant with access to real-time data.\n"]

    if tool_results_text:
        sections.append(f"\n## REAL-TIME DATA (from tools)\n{tool_results_text}\n")

    if download_info and download_info.get("downloaded"):
        sections.append("\n--- Website Access ---\n")
        sections.append(f"Downloaded website: {download_info.get('url')}\n")
        sections.append(
            f"Pages: {download_info.get('pages')}, Chunks: {download_info.get('chunks')}\n"
        )

    if context:
        sections.append(f"\n## Relevant Context from Knowledge Base\n{context}\n")
        if sources:
            # Cap the listing at the first ten sources.
            listed = "\n".join(f"- {s}" for s in sources[:10])
            sections.append("\nSources:\n" + listed)

    sections.append(
        "\n\n## INSTRUCTIONS\nUse the data above to answer the user's question. Be concise and factual."
    )

    system_message = ChatMessage(role="system", content="".join(sections))

    # Conversation history follows, with any earlier system messages left out.
    history = [m for m in messages if m.role != "system"]
    return [system_message, *history]
|
|
|
|
|
|
async def call_llm(
    messages: list[ChatMessage],
    temperature: Optional[float] = 0.7,
    max_tokens: Optional[int] = 4096,
) -> str:
    """Single LLM call. No tool logic, just prompt in → response out.

    Args:
        messages: Conversation to send; messages with empty content are dropped.
        temperature: Sampling temperature. ``None`` falls back to 0.7.
        max_tokens: Completion token limit. ``None`` falls back to 4096.

    Returns:
        The assistant text. In demo mode (no ``state.llm_client``) a canned
        string echoing the last user message; on an empty upstream response
        an apology string; on an exception a string describing the error.
    """
    if not state.llm_client:
        user_msg = ""
        for msg in reversed(messages):
            if msg.role == "user" and msg.content:
                user_msg = msg.content
                break
        return f"Demo mode. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality."

    # The non-streaming caller forwards request.temperature / request.max_tokens
    # as-is, so an omitted value arrives here as None and would otherwise
    # override the defaults. Mirror the fallbacks used by the streaming path.
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 4096

    try:
        messages_dict = [{"role": m.role, "content": m.content} for m in messages if m.content]
        response = await state.llm_client.chat.completions.create(
            model=config.UPSTREAM_MODEL,
            messages=messages_dict,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        if not response.choices:
            log.warning("No choices in LLM response")
            return "I apologize, but I couldn't generate a response."
        content = response.choices[0].message.content or ""
        return content or "I apologize, but I couldn't generate a response."
    except Exception as e:
        # Best-effort: report the failure as the reply, but keep the traceback in the log.
        log.exception(f"LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
|
|
|
|
|
|
# =============================================================================
|
|
# Document Management Endpoints
|
|
# =============================================================================
|
|
|
|
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document to the knowledge base.

    Reads the ``file`` field of the submitted form and hands its bytes and
    filename to ``state.rag_system.add_document``.

    Raises:
        HTTPException: 503 when no RAG system is configured, 400 when the
            form carries no readable ``file`` entry, 500 when reading or
            ingesting the document fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        form = await request.form()
        file = form.get("file")
        # A form value that is not an upload object (no read()) cannot be
        # ingested; treat it like a missing file rather than failing later.
        if not file or not hasattr(file, "read"):
            raise HTTPException(status_code=400, detail="No file provided")

        content = await file.read()
        filename = file.filename or "unknown"

        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )

        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}

    except HTTPException:
        # Keep deliberate client errors (the 400 above) from being
        # rewritten into a 500 by the generic handler below.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
class WebsiteDownloadRequest(BaseModel):
    """Request model for website download.

    Each field is forwarded unchanged by the ``/v1/documents/website``
    endpoint as the same-named keyword argument of
    ``state.rag_system.download_and_ingest_website``.
    """
    # Required: address of the website to download.
    url: str
    # Defaults to 50; presumably an upper bound on crawled pages — TODO confirm in the downloader.
    max_pages: int = 50
    # Defaults to 6; presumably the downloader's worker count — TODO confirm.
    threads: int = 6
    # Off by default.
    download_external_assets: bool = False
    # Optional list of domains; None when the client omits it.
    external_domains: Optional[list[str]] = None
|
|
|
|
|
|
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """
    Download a website and ingest it into the knowledge base.

    This is the PRIMARY way to add content to the RAG system.
    Uses the website_downloader_tool to download and process websites.

    Responds with 503 when the RAG system is unavailable and with 500 when
    the download reports failure or raises.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")

        outcome = await state.rag_system.download_and_ingest_website(
            url=request.url,
            max_pages=request.max_pages,
            threads=request.threads,
            download_external_assets=request.download_external_assets,
            external_domains=request.external_domains,
        )

        # Guard clause: a reported failure becomes a 500 carrying the tool's message.
        if not outcome.get("success"):
            raise HTTPException(
                status_code=500,
                detail=outcome.get("message", "Website download failed"),
            )

        summary = {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": outcome.get("local_path"),
            "pages_processed": outcome.get("pages_processed", 0),
            "total_chunks": outcome.get("total_chunks", 0),
        }
        return summary

    except HTTPException:
        # Already a well-formed HTTP error; pass it through untouched.
        raise
    except Exception as e:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """Ingest the document found at the ``url`` key of the request body.

    Responds with 503 without a RAG system, 400 when ``url`` is missing or
    empty, and 500 when ingestion raises.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        ingest_result = await state.rag_system.add_document_from_url(url)
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e))
    else:
        # Building the reply stays outside the guarded call on purpose:
        # it cannot fail unless the result itself is malformed.
        try:
            chunk_count = ingest_result.get("chunks", 0)
        except Exception as e:
            log.exception("URL document addition failed")
            raise HTTPException(status_code=500, detail=str(e))
        return {"success": True, "message": f"Document from {url} added", "chunks": chunk_count}
|
|
|
|
|
|
@app.get("/v1/documents")
async def list_documents():
    """Return every document known to the knowledge base under ``documents``.

    Responds with 503 without a RAG system and 500 when the listing raises.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"documents": await rag.list_documents()}
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return all downloaded websites in the knowledge base under ``sites``.

    Responds with 503 without a RAG system and 500 when the listing raises.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"sites": await rag.list_downloaded_sites()}
    except Exception as e:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Look up one downloaded site by URL.

    The path value is percent-decoded and given an ``https://`` prefix when
    it has no http(s) scheme. Responds with 404 when nothing is stored for
    that URL, 503 without a RAG system, and 500 on unexpected errors.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        site_url = unquote(url)
        has_scheme = site_url.startswith(("http://", "https://"))
        if not has_scheme:
            site_url = "https://" + site_url

        # NOTE(review): unlike the other rag_system calls in this module this
        # one is not awaited — confirm get_site_info is a synchronous method.
        info = state.rag_system.get_site_info(site_url)
        if not info:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": info}
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove the document identified by ``doc_id`` from the knowledge base.

    Responds with 503 without a RAG system and 500 when deletion raises.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        await rag.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"success": True, "message": f"Document {doc_id} deleted"}
|
|
|
|
|
|
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website and all of its content from the knowledge base.

    The path value is percent-decoded and given an ``https://`` prefix when
    it has no http(s) scheme. Responds with 404 when the deletion reports
    failure, 503 without a RAG system, and 500 on unexpected errors.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        from urllib.parse import unquote

        site_url = unquote(url)
        has_scheme = site_url.startswith(("http://", "https://"))
        if not has_scheme:
            site_url = "https://" + site_url

        outcome = await state.rag_system.delete_site(site_url)

        # Guard clause: an unsuccessful deletion is reported as "not found".
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {site_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }

    except HTTPException:
        raise
    except Exception as e:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# =============================================================================
|
|
# Health and Status Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness, uptime in seconds, and which subsystems are configured."""
    uptime_seconds = time.time() - state.startup_time
    report = {
        "status": "healthy",
        "uptime": uptime_seconds,
        "rag_enabled": state.rag_system is not None,
        "tools_enabled": state.tool_manager is not None,
        "llm_connected": state.llm_client is not None,
    }
    return report
|
|
|
|
|
|
@app.get("/")
async def root():
    """Describe the API: name, version, summary, and the main endpoint paths."""
    endpoint_paths = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "download_website": "/v1/documents/website",
        "list_sites": "/v1/documents/sites",
        "health": "/health",
    }
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": "OpenAI-compatible RAG server powered by OpenRouter. Auto-downloads and analyzes websites when users ask about them.",
        "endpoints": endpoint_paths,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Main Entry Point
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # Host, port and auto-reload all come from the module's config object.
    server_options = {
        "host": config.HOST,
        "port": config.PORT,
        "reload": config.DEBUG,
    }
    uvicorn.run("main:app", **server_options)
|