Features:
- RAG system now uses website_downloader_tool as primary content ingestion method
- download_and_ingest_website() method for complete website processing
- Stores page pointers (source_url, page_url, local_path) in vector store
- Site registry tracks all downloaded websites with metadata
- New API endpoints for website management:
  - POST /v1/documents/website - Download and ingest a website
  - GET /v1/documents/sites - List all downloaded sites
  - GET /v1/documents/sites/{url} - Get site info
  - DELETE /v1/documents/sites/{url} - Delete a site and its content
Changes:
- rag/__init__.py: Added download_and_ingest_website(), site registry
- rag/document_processor.py: Added extract_text_from_html() public method
- rag/vector_store.py: Added delete_by_source_url(), get_stats()
- main.py: New website endpoints, integrated tool with RAG system
876 lines
29 KiB
Python
876 lines
29 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
DocRAG - OpenAI-Compatible RAG Server
|
|
|
|
This application presents itself as a standard OpenAI API server that can be used
|
|
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
|
1. Processes user queries through a RAG system
|
|
2. Retrieves relevant context from a knowledge base
|
|
3. Passes the enriched context to GLM-4.7-Flash for response generation
|
|
4. Optionally uses tools like website_downloader for enhanced capabilities
|
|
|
|
The user sees a normal chat experience, but the system is actually doing
|
|
sophisticated RAG operations in the background.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncIterator, Optional
|
|
|
|
# Configure logging
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
|
datefmt="%H:%M:%S",
|
|
)
|
|
log = logging.getLogger(__name__)
|
|
|
|
# FastAPI imports
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import StreamingResponse, JSONResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Import RAG components
|
|
from rag import RAGSystem, get_rag_system
|
|
from rag.document_processor import DocumentProcessor
|
|
|
|
# Import tools
|
|
from tools import ToolManager, get_tool_manager
|
|
|
|
# Import SDK for GLM-4.7-Flash
|
|
try:
|
|
from zai import ZaiClient as ZAI
|
|
except ImportError:
|
|
ZAI = None
|
|
log.warning("z-ai-web-dev-sdk not installed. Install with: pip install z-ai-web-dev-sdk")
|
|
|
|
|
|
# =============================================================================
|
|
# Configuration
|
|
# =============================================================================
|
|
|
|
def _env_int(name: str, default: str) -> int:
    """Return environment variable *name* (falling back to *default*) as an int."""
    return int(os.environ.get(name, default))


def _env_flag(name: str, default: str) -> bool:
    """Return True only when the variable's lower-cased value is exactly "true"."""
    return os.environ.get(name, default).lower() == "true"


class Config:
    """Application settings read from environment variables.

    Every attribute is evaluated when the class body executes, so changes made
    to the environment afterwards are not reflected here.
    """

    # Server settings
    HOST: str = os.environ.get("HOST", "0.0.0.0")
    PORT: int = _env_int("PORT", "8000")
    DEBUG: bool = _env_flag("DEBUG", "false")

    # Model settings: the name advertised to clients and the upstream model id
    MODEL_NAME: str = os.environ.get("MODEL_NAME", "DocRAG-GLM-4.7")
    UPSTREAM_MODEL: str = os.environ.get("UPSTREAM_MODEL", "glm-4.7")

    # API key for the upstream LLM (empty string when unset)
    ZAI_API_KEY: str = os.environ.get("ZAI_API_KEY", "")

    # RAG settings
    EMBEDDING_MODEL: str = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-small")
    VECTOR_STORE_PATH: str = os.environ.get("VECTOR_STORE_PATH", "./data/vectors")
    DOCUMENTS_PATH: str = os.environ.get("DOCUMENTS_PATH", "./data/documents")
    CHUNK_SIZE: int = _env_int("CHUNK_SIZE", "1000")
    CHUNK_OVERLAP: int = _env_int("CHUNK_OVERLAP", "200")
    TOP_K_RESULTS: int = _env_int("TOP_K_RESULTS", "5")

    # Tool settings
    ENABLE_TOOLS: bool = _env_flag("ENABLE_TOOLS", "true")
    MAX_TOOL_ITERATIONS: int = _env_int("MAX_TOOL_ITERATIONS", "3")


# Module-wide settings instance.
config = Config()
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Models
|
|
# =============================================================================
|
|
|
|
class ChatMessage(BaseModel):
    """One message of an OpenAI-style chat conversation.

    Only ``role`` is required; every other field defaults to ``None``.
    """

    role: str
    content: str | None = None
    name: str | None = None
    tool_calls: list[dict] | None = None
    tool_call_id: str | None = None
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Request body of an OpenAI-style chat completion call.

    ``messages`` is the only required field; the remaining fields carry the
    defaults declared below.
    """

    model: str = "DocRAG-GLM-4.7"
    messages: list[ChatMessage]

    # Sampling / generation controls
    temperature: float | None = 0.7
    top_p: float | None = 1.0
    n: int | None = 1
    stream: bool | None = False
    stop: list[str] | None = None
    max_tokens: int | None = None
    presence_penalty: float | None = 0.0
    frequency_penalty: float | None = 0.0

    # Caller identification and tool declarations
    user: str | None = None
    tools: list[dict] | None = None
    tool_choice: str | dict | None = None
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
    """A single choice entry of a chat completion response."""

    # Position of this choice in the response's ``choices`` list.
    index: int = 0
    # The generated message.
    message: ChatMessage
    # Why generation ended; defaults to "stop".
    finish_reason: str = "stop"
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
    """Token accounting attached to a chat completion response.

    All counters default to zero.
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Response body of a non-streaming OpenAI-style chat completion."""

    # "chatcmpl-" followed by the first 24 hex characters of a fresh UUID4.
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
    object: str = "chat.completion"
    # Unix timestamp (whole seconds) taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Defaults to the configured public model name.
    model: str = config.MODEL_NAME
    choices: list[ChatCompletionChoice]
    # Zeroed usage statistics unless the caller supplies them.
    usage: ChatCompletionUsage = ChatCompletionUsage()
|
|
|
|
|
|
class ModelInfo(BaseModel):
    """Description of one model in the OpenAI model-listing format."""

    id: str
    object: str = "model"
    # Unix timestamp (whole seconds) taken when the instance is created.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "organization"
|
|
|
|
|
|
class ModelList(BaseModel):
    """Envelope returned by the model-listing endpoints."""

    object: str = "list"
    # The models being listed.
    data: list[ModelInfo]
|
|
|
|
|
|
# =============================================================================
|
|
# Application State
|
|
# =============================================================================
|
|
|
|
class AppState:
    """Holder for the process-wide service objects.

    The attributes are class-level defaults that are reassigned on the shared
    ``state`` instance during application startup.
    """

    # RAG system handle; None until initialised (or when initialisation fails).
    rag_system: Optional[RAGSystem] = None
    # Tool manager handle; None until initialised (or when initialisation fails).
    tool_manager: Optional[ToolManager] = None
    # Upstream LLM client; None means mock responses are served.
    zai_client: Any = None
    # Captured when this class body executes, i.e. at module import time.
    startup_time: float = time.time()


# Shared state instance used by the lifespan handler and the endpoints.
state = AppState()
|
|
|
|
|
|
# =============================================================================
|
|
# Lifespan Management
|
|
# =============================================================================
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Bring up the RAG system, tool manager and LLM client, then tear down.

    Each component is initialised independently: a failure is logged and the
    matching ``state`` attribute is left as ``None`` so the server still
    starts. After the application stops, the RAG system (if any) is closed.
    """
    log.info("Starting DocRAG server...")

    # --- RAG system -------------------------------------------------------
    try:
        log.info("Initializing RAG system...")
        rag_settings = {
            "embedding_model": config.EMBEDDING_MODEL,
            "vector_store_path": config.VECTOR_STORE_PATH,
            "documents_path": config.DOCUMENTS_PATH,
            "chunk_size": config.CHUNK_SIZE,
            "chunk_overlap": config.CHUNK_OVERLAP,
        }
        state.rag_system = await get_rag_system(**rag_settings)
        log.info("RAG system initialized successfully")
    except Exception as exc:
        log.warning(f"RAG system initialization deferred: {exc}")
        state.rag_system = None

    # --- Tool manager -----------------------------------------------------
    try:
        log.info("Initializing tool manager...")
        state.tool_manager = get_tool_manager()
        log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
    except Exception as exc:
        log.warning(f"Tool manager initialization failed: {exc}")
        state.tool_manager = None

    # --- Upstream LLM client ----------------------------------------------
    try:
        if not config.ZAI_API_KEY or ZAI is None:
            # Without a key or the SDK the server falls back to mock replies.
            log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses")
            state.zai_client = None
        else:
            log.info("Initializing ZAI client...")
            state.zai_client = ZAI(api_key=config.ZAI_API_KEY)
            log.info("ZAI client initialized successfully")
    except Exception as exc:
        log.error(f"Failed to initialize ZAI client: {exc}")
        state.zai_client = None

    log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")

    yield

    # --- Shutdown ---------------------------------------------------------
    log.info("Shutting down DocRAG server...")
    if state.rag_system:
        await state.rag_system.close()
    log.info("DocRAG server stopped")
|
|
|
|
|
|
# =============================================================================
|
|
# FastAPI Application
|
|
# =============================================================================
|
|
|
|
# The ASGI application; startup and shutdown are handled by ``lifespan``.
app = FastAPI(
    title="DocRAG API",
    description="OpenAI-compatible RAG server powered by GLM-4.7-Flash",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware for Open WebUI compatibility.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this permissive policy is intended for the deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
# =============================================================================
|
|
# OpenAI-Compatible Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/v1/models")
@app.get("/models")
async def list_models():
    """List available models (OpenAI-compatible).

    Returns:
        A ``ModelList`` containing the configured model name and the built-in
        "DocRAG-GLM-4.7" id, each listed once.
    """
    # With the default configuration both ids are the same string; dict.fromkeys
    # removes the duplicate while preserving order so clients do not see the
    # same model twice.
    model_ids = dict.fromkeys((config.MODEL_NAME, "DocRAG-GLM-4.7"))
    return ModelList(
        data=[ModelInfo(id=model_id, owned_by="docrag") for model_id in model_ids]
    )
|
|
|
|
|
|
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
    """Return the description of one model (OpenAI-compatible).

    Raises:
        HTTPException: 404 when ``model_id`` is neither the configured model
            name nor "DocRAG-GLM-4.7".
    """
    known_ids = [config.MODEL_NAME, "DocRAG-GLM-4.7"]
    if model_id in known_ids:
        return ModelInfo(id=model_id, owned_by="docrag")
    raise HTTPException(status_code=404, detail="Model not found")
|
|
|
|
|
|
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Streaming requests get a ``text/event-stream`` response; all others get
    the completed response object.

    Raises:
        HTTPException: Whatever status the completion logic raised (e.g. 400
            when no user message is present), or 500 for unexpected failures.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 for a missing user message)
        # must reach the client unchanged instead of being rewrapped as 500.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
# =============================================================================
|
|
# Chat Completion Logic
|
|
# =============================================================================
|
|
|
|
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Process a non-streaming chat completion request.

    Steps: find the latest user message, retrieve RAG context for it, build
    the context-enriched message list, optionally run tools, ask the upstream
    LLM for a reply and wrap it in an OpenAI-style response.

    Args:
        request: The parsed chat completion request.
        request_id: Identifier to report as the response ``id``.

    Returns:
        The completed response with a rough token-usage estimate.

    Raises:
        HTTPException: 400 when the conversation has no non-empty user message.
    """
    messages = request.messages

    # Latest non-empty user message drives retrieval and tool detection.
    user_message = next(
        (m.content for m in reversed(messages) if m.role == "user" and m.content),
        "",
    )
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    # Step 1: RAG retrieval (best effort - a failure only loses the context)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 2: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources)

    # Step 3: Check for tool usage
    tool_calls_made = []
    if config.ENABLE_TOOLS and state.tool_manager:
        tool_calls_made = await check_and_execute_tools(
            user_message, enhanced_messages, request.tools
        )

    # Step 4: Generate response with upstream LLM.
    # Both request fields may be None (omitted or explicit null); substitute
    # the same fallbacks the streaming path uses instead of forwarding None.
    temperature = request.temperature if request.temperature is not None else 0.7
    max_tokens = request.max_tokens if request.max_tokens is not None else 4096
    response_content = await generate_response(
        enhanced_messages,
        temperature=temperature,
        max_tokens=max_tokens,
        tools=request.tools,
    )

    # Step 5: Build and return response.
    # Token counts are a crude estimate of roughly four characters per token.
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4
    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                    tool_calls=tool_calls_made if tool_calls_made else None,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            # Previously left at its default of 0.
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
|
|
|
|
|
|
# Sentinel returned by next() when the upstream stream is exhausted.
_STREAM_END = object()


def _sse_event(payload: dict) -> str:
    """Serialise *payload* as one Server-Sent-Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _completion_chunk(
    request_id: str,
    created: int,
    delta: dict,
    finish_reason: Optional[str] = None,
) -> str:
    """Build the SSE frame for one ``chat.completion.chunk`` object."""
    return _sse_event({
        "id": request_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": config.MODEL_NAME,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    })


async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion response as Server-Sent Events.

    Retrieves RAG context for the latest user message, then relays content
    deltas from the upstream LLM (or a canned demo reply when no client is
    configured) as ``chat.completion.chunk`` frames, followed by a final
    "stop" chunk and ``[DONE]``.

    Args:
        request: The parsed chat completion request.
        request_id: Identifier reported as ``id`` in every chunk.

    Yields:
        SSE ``data:`` frames. Errors are reported in-band as an ``error``
        frame followed by ``[DONE]`` so clients stop waiting.
    """
    messages = request.messages

    # Extract the last user message
    user_message = ""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            user_message = msg.content
            break

    if not user_message:
        yield _sse_event({'error': 'No user message found'})
        yield "data: [DONE]\n\n"
        return

    # Step 1: RAG retrieval (best effort - a failure only loses the context)
    context = ""
    sources = []
    if state.rag_system:
        try:
            rag_result = await state.rag_system.query(
                query=user_message,
                top_k=config.TOP_K_RESULTS,
            )
            context = rag_result.get("context", "")
            sources = rag_result.get("sources", [])
            log.info(f"RAG retrieved {len(sources)} relevant documents")
        except Exception as e:
            log.warning(f"RAG retrieval failed: {e}")

    # Step 2: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources)

    # Step 3: Stream response from upstream LLM
    created = int(time.time())

    try:
        if state.zai_client:
            # The SDK call and its iterator are synchronous; run them in a
            # worker thread so other requests are not stalled meanwhile.
            response = await asyncio.to_thread(
                state.zai_client.chat.completions.create,
                model=config.UPSTREAM_MODEL,
                messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
                # "is not None" keeps an explicit temperature of 0.
                temperature=request.temperature if request.temperature is not None else 0.7,
                max_tokens=request.max_tokens or 4096,
                stream=True,
                thinking={"type": "enabled"},
            )

            chunk_iter = iter(response)
            while True:
                # next() with a default avoids raising StopIteration inside
                # the worker thread.
                chunk = await asyncio.to_thread(next, chunk_iter, _STREAM_END)
                if chunk is _STREAM_END:
                    break

                # Skip chunks that carry no choice or no delta.
                chunk_choices = getattr(chunk, "choices", None)
                if not chunk_choices:
                    continue
                delta = getattr(chunk_choices[0], "delta", None)
                if delta is None:
                    continue

                # Reasoning ("thinking") output is logged, never sent to the user.
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning:
                    log.debug(f"Thinking: {reasoning[:100]}...")
                    continue

                content = getattr(delta, "content", None)
                if content:
                    yield _completion_chunk(request_id, created, {"content": content})
        else:
            # Mock streaming response for testing
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if context:
                mock_response += "Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
            mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."

            for char in mock_response:
                yield _completion_chunk(request_id, created, {"content": char})
                await asyncio.sleep(0.01)

        # Final chunk and stream terminator (identical for both branches)
        yield _completion_chunk(request_id, created, {}, "stop")
        yield "data: [DONE]\n\n"

    except Exception as e:
        log.exception("Streaming failed")
        yield _sse_event({'error': str(e)})
        yield "data: [DONE]\n\n"
|
|
|
|
|
|
def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
) -> list[ChatMessage]:
    """Prefix the conversation with a system prompt that carries RAG context.

    Args:
        messages: The conversation as received from the client.
        context: Retrieved knowledge-base text; nothing is appended when empty.
        sources: Source labels listed after the context when both are present.

    Returns:
        A new list: the generated system message followed by every incoming
        message whose role is not "system".
    """
    prompt_parts = [
        "You are a helpful AI assistant with access to a knowledge base. "
        "Use the provided context to give accurate and helpful responses. "
        "If the context doesn't contain relevant information, use your general knowledge "
        "but indicate when you're doing so."
    ]

    if context:
        prompt_parts.append(f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n")
        if sources:
            source_lines = "\n".join(f"- {s}" for s in sources)
            prompt_parts.append("\n--- Sources ---\n" + source_lines)

    system_message = ChatMessage(role="system", content="".join(prompt_parts))

    # Client-supplied system messages are dropped in favour of the one above.
    history = [msg for msg in messages if msg.role != "system"]
    return [system_message, *history]
|
|
|
|
|
|
def _extract_completion_text(response: Any) -> str:
    """Pull the assistant text out of an upstream completion result.

    Handles both a single response object exposing ``choices`` directly and an
    iterable of chunk objects.
    """
    if hasattr(response, "choices"):
        # Plain (non-streaming) response object.
        direct_choices = response.choices
        if not direct_choices:
            return ""
        message = getattr(direct_choices[0], "message", None)
        return (getattr(message, "content", None) or "") if message else ""

    pieces = []
    for chunk in response:
        chunk_choices = getattr(chunk, "choices", None)
        if not chunk_choices:
            continue
        choice = chunk_choices[0]
        message = getattr(choice, "message", None)
        if message:
            # A full message supersedes any deltas gathered so far.
            return message.content or ""
        piece = getattr(getattr(choice, "delta", None), "content", None)
        if piece:
            pieces.append(piece)
    return "".join(pieces)


async def generate_response(
    messages: list[ChatMessage],
    temperature: Optional[float] = 0.7,
    max_tokens: Optional[int] = 4096,
    tools: Optional[list[dict]] = None,
) -> str:
    """Generate response using upstream LLM.

    Args:
        messages: Conversation to send; messages without content are skipped.
        temperature: Sampling temperature; ``None`` falls back to 0.7.
        max_tokens: Completion token limit; ``None`` falls back to 4096.
        tools: Accepted for interface compatibility; not forwarded upstream.

    Returns:
        The generated text, an apology when the upstream reply is empty, an
        error description when the upstream call fails, or a demo-mode notice
        when no LLM client is configured.
    """
    if not state.zai_client:
        # Mock response for testing
        user_msg = next(
            (m.content for m in reversed(messages) if m.role == "user" and m.content),
            "",
        )
        return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."

    # Callers may pass None explicitly (the request model allows it); use the
    # documented defaults rather than forwarding None upstream.
    if temperature is None:
        temperature = 0.7
    if max_tokens is None:
        max_tokens = 4096

    try:
        # The SDK call is synchronous; run it in a worker thread so the event
        # loop keeps serving other requests while the LLM is generating.
        response = await asyncio.to_thread(
            state.zai_client.chat.completions.create,
            model=config.UPSTREAM_MODEL,
            messages=[{"role": m.role, "content": m.content} for m in messages if m.content],
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            thinking={"type": "enabled"},
        )
        content = _extract_completion_text(response)
        return content or "I apologize, but I couldn't generate a response."
    except Exception as e:
        # Best effort: report the failure to the user as the reply text.
        log.error(f"Upstream LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
|
|
|
|
|
|
async def check_and_execute_tools(
    user_message: str,
    messages: list[ChatMessage],
    available_tools: Optional[list[dict]],
) -> list[dict]:
    """Check if tools should be used and execute them.

    Detection is keyword based: when the message mentions a website-download
    intent and contains a URL, the first URL is downloaded and ingested
    through the RAG system.

    Args:
        user_message: Latest user message text.
        messages: Conversation messages (currently unused).
        available_tools: Tools declared by the client; nothing runs when this
            is empty or ``None``.

    Returns:
        OpenAI-style tool-call dicts describing what was executed (possibly
        empty). Download failures are logged and produce no entry.
    """
    if not state.tool_manager or not available_tools:
        return []

    tool_calls = []

    # Simple keyword-based tool detection
    # In a production system, you'd use the LLM to decide tool usage
    message_lower = user_message.lower()
    download_keywords = ("download website", "mirror site", "crawl", "archive site", "ingest site")
    if not any(kw in message_lower for kw in download_keywords):
        return tool_calls

    # Extract URLs from the message. The pattern runs to the next whitespace,
    # so punctuation that merely ends the sentence ("... https://x.com.") is
    # trimmed to avoid requesting a URL that does not exist.
    import re
    urls = [
        match.rstrip(".,;:!?)]}>\"'")
        for match in re.findall(r'https?://[^\s]+', user_message)
    ]

    if urls and state.rag_system:
        target_url = urls[0]
        # Use RAG system's integrated website downloader
        try:
            result = await state.rag_system.download_and_ingest_website(
                url=target_url,
                max_pages=20,  # Reasonable default
            )

            tool_calls.append({
                "id": f"call_{uuid.uuid4().hex[:24]}",
                "type": "function",
                "function": {
                    "name": "website_downloader",
                    "arguments": json.dumps({
                        "url": target_url,
                        "success": result.get("success"),
                        "chunks_ingested": result.get("total_chunks", 0),
                    }),
                },
            })
            log.info(f"Downloaded and ingested website: {target_url} -> {result.get('total_chunks', 0)} chunks")
        except Exception as e:
            log.error(f"Website download failed: {e}")

    return tool_calls
|
|
|
|
|
|
# =============================================================================
|
|
# Document Management Endpoints
|
|
# =============================================================================
|
|
|
|
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document to the knowledge base.

    Expects a form submission whose ``file`` field is a file upload.

    Returns:
        A dict with ``success``, a ``message`` and the number of ``chunks``
        reported by the RAG system.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 400 when no
            usable file was submitted, 500 for unexpected failures.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        form = await request.form()
        file = form.get("file")
        if not file:
            raise HTTPException(status_code=400, detail="No file provided")
        # A plain text form field arrives as str and has no read(); reject it
        # as a client error instead of failing with a 500.
        if not hasattr(file, "read"):
            raise HTTPException(status_code=400, detail="Form field 'file' must be a file upload")

        content = await file.read()
        filename = file.filename or "unknown"

        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )

        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}

    except HTTPException:
        # Keep deliberate 4xx responses intact (they were previously rewrapped
        # as 500 by the generic handler below).
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
class WebsiteDownloadRequest(BaseModel):
    """Body of a website download-and-ingest request.

    Only ``url`` is required; the other fields carry the defaults below and
    are forwarded to the RAG system's website downloader.
    """

    url: str
    max_pages: int = 50
    threads: int = 6
    download_external_assets: bool = False
    external_domains: list[str] | None = None
|
|
|
|
|
|
@app.post("/v1/documents/website")
async def download_website(request: WebsiteDownloadRequest):
    """
    Download a website and ingest it into the knowledge base.

    This is the PRIMARY way to add content to the RAG system; the work is
    delegated to the RAG system's ``download_and_ingest_website``.

    Raises:
        HTTPException: 503 when the RAG system is unavailable; 500 when the
            download reports failure or an unexpected error occurs.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        log.info(f"Downloading website: {request.url}")

        result = await state.rag_system.download_and_ingest_website(
            url=request.url,
            max_pages=request.max_pages,
            threads=request.threads,
            download_external_assets=request.download_external_assets,
            external_domains=request.external_domains,
        )

        # Guard clause: surface a reported failure before building the reply.
        if not result.get("success"):
            raise HTTPException(
                status_code=500,
                detail=result.get("message", "Website download failed"),
            )

        return {
            "success": True,
            "message": f"Website downloaded and ingested: {request.url}",
            "url": request.url,
            "local_path": result.get("local_path"),
            "pages_processed": result.get("pages_processed", 0),
            "total_chunks": result.get("total_chunks", 0),
        }

    except HTTPException:
        # Pass deliberate HTTP errors through untouched.
        raise
    except Exception as exc:
        log.exception("Website download failed")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """
    Add a single document fetched from a URL to the knowledge base.

    NOTE: For websites, prefer using /v1/documents/website instead
    as it downloads the entire site and provides better context.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 400 when the
            body has no ``url``, 500 when ingestion fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    target_url = request.get("url")
    if not target_url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        outcome = await state.rag_system.add_document_from_url(target_url)
    except Exception as exc:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(exc))

    try:
        chunk_count = outcome.get("chunks", 0)
    except Exception as exc:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(exc))

    return {"success": True, "message": f"Document from {target_url} added", "chunks": chunk_count}
|
|
|
|
|
|
@app.get("/v1/documents")
async def list_documents():
    """Return the documents known to the RAG system under the key ``documents``.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 500 when the
            listing fails.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"documents": await rag.list_documents()}
    except Exception as exc:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@app.get("/v1/documents/sites")
async def list_downloaded_sites():
    """Return every downloaded website under the key ``sites``.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 500 when the
            listing fails.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        return {"sites": await rag.list_downloaded_sites()}
    except Exception as exc:
        log.exception("Site listing failed")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@app.get("/v1/documents/sites/{url:path}")
async def get_site_info(url: str):
    """Look up the registry entry of one downloaded site.

    The path parameter is percent-decoded and given an ``https://`` prefix
    when it carries no http(s) scheme before the lookup.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 404 when the
            site is unknown, 500 for unexpected failures.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        # URL will be passed as path parameter, need to decode
        from urllib.parse import unquote

        site_url = unquote(url)
        has_scheme = site_url.startswith(("http://", "https://"))
        if not has_scheme:
            site_url = f"https://{site_url}"

        # NOTE(review): called without await, unlike delete_site - confirm
        # that get_site_info is a synchronous method.
        site_info = state.rag_system.get_site_info(site_url)
        if not site_info:
            raise HTTPException(status_code=404, detail="Site not found")
        return {"site": site_info}
    except HTTPException:
        # Pass deliberate HTTP errors through untouched.
        raise
    except Exception as exc:
        log.exception("Site info retrieval failed")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Remove the document identified by ``doc_id`` from the knowledge base.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 500 when the
            deletion fails.
    """
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        await rag.delete_document(doc_id)
    except Exception as exc:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(exc))

    return {"success": True, "message": f"Document {doc_id} deleted"}
|
|
|
|
|
|
@app.delete("/v1/documents/sites/{url:path}")
async def delete_site(url: str):
    """Delete a downloaded website and its content from the knowledge base.

    The path parameter is percent-decoded and given an ``https://`` prefix
    when it carries no http(s) scheme before the deletion is requested.

    Raises:
        HTTPException: 503 when the RAG system is unavailable, 404 when the
            deletion does not report success, 500 for unexpected failures.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        # URL will be passed as path parameter, need to decode
        from urllib.parse import unquote

        site_url = unquote(url)
        has_scheme = site_url.startswith(("http://", "https://"))
        if not has_scheme:
            site_url = f"https://{site_url}"

        outcome = await state.rag_system.delete_site(site_url)

        # Guard clause: anything but a reported success is answered with 404.
        if not outcome.get("success"):
            raise HTTPException(status_code=404, detail=outcome.get("message", "Site not found"))

        return {
            "success": True,
            "message": f"Site {site_url} deleted",
            "deleted_chunks": outcome.get("deleted_chunks", 0),
            "deleted_path": outcome.get("deleted_path"),
        }

    except HTTPException:
        # Pass deliberate HTTP errors through untouched.
        raise
    except Exception as exc:
        log.exception("Site deletion failed")
        raise HTTPException(status_code=500, detail=str(exc))
|
|
|
|
|
|
# =============================================================================
|
|
# Health and Status Endpoints
|
|
# =============================================================================
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness plus which optional components are available.

    ``uptime`` is the number of seconds elapsed since ``state.startup_time``.
    """
    uptime_seconds = time.time() - state.startup_time
    report = {"status": "healthy", "uptime": uptime_seconds}
    # Component flags: True when the corresponding state attribute is set.
    report["rag_enabled"] = state.rag_system is not None
    report["tools_enabled"] = state.tool_manager is not None
    report["llm_connected"] = state.zai_client is not None
    return report
|
|
|
|
|
|
@app.get("/")
async def root():
    """Return static API metadata and a map of the main endpoints."""
    endpoint_map = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "download_website": "/v1/documents/website",
        "list_sites": "/v1/documents/sites",
        "health": "/health",
    }
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
        "endpoints": endpoint_map,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Main Entry Point
|
|
# =============================================================================
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Serve the app by import string using the configured host, port and
    # auto-reload flag.
    server_options = {
        "host": config.HOST,
        "port": config.PORT,
        "reload": config.DEBUG,
    }
    uvicorn.run("main:app", **server_options)
|