#!/usr/bin/env python3 """ DocRAG - OpenAI-Compatible RAG Server This application presents itself as a standard OpenAI API server that can be used with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it: 1. Processes user queries through a RAG system 2. Retrieves relevant context from a knowledge base 3. Passes the enriched context to GLM-4.7-Flash for response generation 4. Optionally uses tools like website_downloader for enhanced capabilities The user sees a normal chat experience, but the system is actually doing sophisticated RAG operations in the background. """ from __future__ import annotations import asyncio import json import logging import os import sys import time import uuid from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Optional # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", datefmt="%H:%M:%S", ) log = logging.getLogger(__name__) # FastAPI imports from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel, Field # Import RAG components from rag import RAGSystem, get_rag_system from rag.document_processor import DocumentProcessor # Import tools from tools import ToolManager, get_tool_manager # Import SDK for GLM-4.7-Flash try: from zai import ZaiClient as ZAI except ImportError: ZAI = None log.warning("z-ai-web-dev-sdk not installed. 
Install with: pip install z-ai-web-dev-sdk") # ============================================================================= # Configuration # ============================================================================= class Config: """Application configuration from environment variables.""" # Server settings HOST: str = os.getenv("HOST", "0.0.0.0") PORT: int = int(os.getenv("PORT", "8000")) DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" # Model settings MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG-GLM-4.7") UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "glm-4.7") # API Key for upstream LLM ZAI_API_KEY: str = os.getenv("ZAI_API_KEY", "") # RAG settings EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors") DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents") CHUNK_SIZE: int = int(os.getenv("CHUNK_SIZE", "1000")) CHUNK_OVERLAP: int = int(os.getenv("CHUNK_OVERLAP", "200")) TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5")) # Tool settings ENABLE_TOOLS: bool = os.getenv("ENABLE_TOOLS", "true").lower() == "true" MAX_TOOL_ITERATIONS: int = int(os.getenv("MAX_TOOL_ITERATIONS", "3")) config = Config() # ============================================================================= # OpenAI-Compatible Models # ============================================================================= class ChatMessage(BaseModel): """OpenAI chat message format.""" role: str content: Optional[str] = None name: Optional[str] = None tool_calls: Optional[list[dict]] = None tool_call_id: Optional[str] = None class ChatCompletionRequest(BaseModel): """OpenAI chat completion request format.""" model: str = "DocRAG-GLM-4.7" messages: list[ChatMessage] temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 n: Optional[int] = 1 stream: Optional[bool] = False stop: Optional[list[str]] = None max_tokens: Optional[int] = None presence_penalty: 
Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 user: Optional[str] = None tools: Optional[list[dict]] = None tool_choice: Optional[str | dict] = None class ChatCompletionChoice(BaseModel): """OpenAI chat completion choice.""" index: int = 0 message: ChatMessage finish_reason: str = "stop" class ChatCompletionUsage(BaseModel): """Token usage statistics.""" prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 class ChatCompletionResponse(BaseModel): """OpenAI chat completion response format.""" id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}") object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) model: str = config.MODEL_NAME choices: list[ChatCompletionChoice] usage: ChatCompletionUsage = ChatCompletionUsage() class ModelInfo(BaseModel): """OpenAI model info format.""" id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) owned_by: str = "organization" class ModelList(BaseModel): """OpenAI model list format.""" object: str = "list" data: list[ModelInfo] # ============================================================================= # Application State # ============================================================================= class AppState: """Global application state.""" rag_system: Optional[RAGSystem] = None tool_manager: Optional[ToolManager] = None zai_client: Any = None startup_time: float = time.time() state = AppState() # ============================================================================= # Lifespan Management # ============================================================================= @asynccontextmanager async def lifespan(app: FastAPI): """Manage application lifespan - startup and shutdown.""" log.info("Starting DocRAG server...") # Initialize RAG system try: log.info("Initializing RAG system...") state.rag_system = await get_rag_system( embedding_model=config.EMBEDDING_MODEL, 
vector_store_path=config.VECTOR_STORE_PATH, documents_path=config.DOCUMENTS_PATH, chunk_size=config.CHUNK_SIZE, chunk_overlap=config.CHUNK_OVERLAP, ) log.info("RAG system initialized successfully") except Exception as e: log.warning(f"RAG system initialization deferred: {e}") state.rag_system = None # Initialize tool manager try: log.info("Initializing tool manager...") state.tool_manager = get_tool_manager() log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}") except Exception as e: log.warning(f"Tool manager initialization failed: {e}") state.tool_manager = None # Initialize ZAI client for upstream LLM try: if config.ZAI_API_KEY and ZAI is not None: log.info("Initializing ZAI client...") state.zai_client = ZAI(api_key=config.ZAI_API_KEY) log.info("ZAI client initialized successfully") else: log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses") state.zai_client = None except Exception as e: log.error(f"Failed to initialize ZAI client: {e}") state.zai_client = None log.info(f"DocRAG server started on {config.HOST}:{config.PORT}") yield # Cleanup log.info("Shutting down DocRAG server...") if state.rag_system: await state.rag_system.close() log.info("DocRAG server stopped") # ============================================================================= # FastAPI Application # ============================================================================= app = FastAPI( title="DocRAG API", description="OpenAI-compatible RAG server powered by GLM-4.7-Flash", version="1.0.0", lifespan=lifespan, ) # CORS middleware for Open WebUI compatibility app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============================================================================= # OpenAI-Compatible Endpoints # ============================================================================= @app.get("/v1/models") @app.get("/models") async def 
list_models(): """List available models (OpenAI-compatible).""" return ModelList( data=[ ModelInfo(id=config.MODEL_NAME, owned_by="docrag"), ModelInfo(id="DocRAG-GLM-4.7", owned_by="docrag"), ] ) @app.get("/v1/models/{model_id}") @app.get("/models/{model_id}") async def get_model(model_id: str): """Get model information (OpenAI-compatible).""" if model_id not in [config.MODEL_NAME, "DocRAG-GLM-4.7"]: raise HTTPException(status_code=404, detail="Model not found") return ModelInfo(id=model_id, owned_by="docrag") @app.post("/v1/chat/completions") @app.post("/chat/completions") async def chat_completions(request: ChatCompletionRequest): """Handle chat completions (OpenAI-compatible).""" request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" try: if request.stream: return StreamingResponse( stream_chat_completion(request, request_id), media_type="text/event-stream", ) else: return await complete_chat(request, request_id) except Exception as e: log.exception("Chat completion failed") raise HTTPException(status_code=500, detail=str(e)) # ============================================================================= # Chat Completion Logic # ============================================================================= async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse: """Process a non-streaming chat completion request.""" messages = request.messages # Extract the last user message user_message = "" for msg in reversed(messages): if msg.role == "user" and msg.content: user_message = msg.content break if not user_message: raise HTTPException(status_code=400, detail="No user message found") # Step 1: RAG Retrieval context = "" sources = [] if state.rag_system: try: rag_result = await state.rag_system.query( query=user_message, top_k=config.TOP_K_RESULTS, ) context = rag_result.get("context", "") sources = rag_result.get("sources", []) log.info(f"RAG retrieved {len(sources)} relevant documents") except Exception as e: log.warning(f"RAG 
retrieval failed: {e}") # Step 2: Build enhanced prompt with context enhanced_messages = build_enhanced_messages(messages, context, sources) # Step 3: Check for tool usage tool_calls_made = [] if config.ENABLE_TOOLS and state.tool_manager: tool_calls_made = await check_and_execute_tools( user_message, enhanced_messages, request.tools ) # Step 4: Generate response with upstream LLM response_content = await generate_response( enhanced_messages, temperature=request.temperature, max_tokens=request.max_tokens, tools=request.tools, ) # Step 5: Build and return response return ChatCompletionResponse( id=request_id, model=config.MODEL_NAME, choices=[ ChatCompletionChoice( message=ChatMessage( role="assistant", content=response_content, tool_calls=tool_calls_made if tool_calls_made else None, ), finish_reason="stop", ) ], usage=ChatCompletionUsage( prompt_tokens=len(str(enhanced_messages)) // 4, completion_tokens=len(response_content) // 4, ), ) async def stream_chat_completion( request: ChatCompletionRequest, request_id: str, ) -> AsyncIterator[str]: """Stream a chat completion response.""" messages = request.messages # Extract the last user message user_message = "" for msg in reversed(messages): if msg.role == "user" and msg.content: user_message = msg.content break if not user_message: yield f"data: {json.dumps({'error': 'No user message found'})}\n\n" return # Step 1: RAG Retrieval context = "" sources = [] if state.rag_system: try: rag_result = await state.rag_system.query( query=user_message, top_k=config.TOP_K_RESULTS, ) context = rag_result.get("context", "") sources = rag_result.get("sources", []) log.info(f"RAG retrieved {len(sources)} relevant documents") except Exception as e: log.warning(f"RAG retrieval failed: {e}") # Step 2: Build enhanced prompt with context enhanced_messages = build_enhanced_messages(messages, context, sources) # Step 3: Stream response from upstream LLM created = int(time.time()) try: if state.zai_client: # Use actual GLM-4.7-Flash 
response = state.zai_client.chat.completions.create( model=config.UPSTREAM_MODEL, messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content], temperature=request.temperature or 0.7, max_tokens=request.max_tokens or 4096, stream=True, thinking={"type": "enabled"}, ) for chunk in response: # Handle reasoning content (thinking) if hasattr(chunk.choices[0].delta, 'reasoning_content') and chunk.choices[0].delta.reasoning_content: # Don't expose thinking to user - this is internal RAG processing log.debug(f"Thinking: {chunk.choices[0].delta.reasoning_content[:100]}...") continue # Stream actual content if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: content = chunk.choices[0].delta.content yield f"data: {json.dumps({ 'id': request_id, 'object': 'chat.completion.chunk', 'created': created, 'model': config.MODEL_NAME, 'choices': [{ 'index': 0, 'delta': {'content': content}, 'finish_reason': None }] })}\n\n" # Send final chunk yield f"data: {json.dumps({ 'id': request_id, 'object': 'chat.completion.chunk', 'created': created, 'model': config.MODEL_NAME, 'choices': [{ 'index': 0, 'delta': {}, 'finish_reason': 'stop' }] })}\n\n" yield "data: [DONE]\n\n" else: # Mock streaming response for testing mock_response = f"I understand you're asking about: {user_message}\n\n" if context: mock_response += f"Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n" mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality." 
for char in mock_response: yield f"data: {json.dumps({ 'id': request_id, 'object': 'chat.completion.chunk', 'created': created, 'model': config.MODEL_NAME, 'choices': [{ 'index': 0, 'delta': {'content': char}, 'finish_reason': None }] })}\n\n" await asyncio.sleep(0.01) yield f"data: {json.dumps({ 'id': request_id, 'object': 'chat.completion.chunk', 'created': created, 'model': config.MODEL_NAME, 'choices': [{ 'index': 0, 'delta': {}, 'finish_reason': 'stop' }] })}\n\n" yield "data: [DONE]\n\n" except Exception as e: log.exception("Streaming failed") yield f"data: {json.dumps({'error': str(e)})}\n\n" def build_enhanced_messages( messages: list[ChatMessage], context: str, sources: list[str], ) -> list[ChatMessage]: """Build enhanced messages with RAG context.""" enhanced = [] # Add system message with RAG context system_content = ( "You are a helpful AI assistant with access to a knowledge base. " "Use the provided context to give accurate and helpful responses. " "If the context doesn't contain relevant information, use your general knowledge " "but indicate when you're doing so." 
) if context: system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n" if sources: system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources) enhanced.append(ChatMessage(role="system", content=system_content)) # Add conversation history (excluding old system messages) for msg in messages: if msg.role != "system": enhanced.append(msg) return enhanced async def generate_response( messages: list[ChatMessage], temperature: float = 0.7, max_tokens: int = 4096, tools: Optional[list[dict]] = None, ) -> str: """Generate response using upstream LLM.""" if state.zai_client: try: response = state.zai_client.chat.completions.create( model=config.UPSTREAM_MODEL, messages=[{"role": m.role, "content": m.content} for m in messages if m.content], temperature=temperature, max_tokens=max_tokens, stream=False, thinking={"type": "enabled"}, ) # Extract content from response content = "" for chunk in response: if hasattr(chunk.choices[0], 'message') and chunk.choices[0].message: content = chunk.choices[0].message.content or "" break if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: content += chunk.choices[0].delta.content return content or "I apologize, but I couldn't generate a response." except Exception as e: log.error(f"Upstream LLM call failed: {e}") return f"I encountered an error: {str(e)}" else: # Mock response for testing user_msg = "" for msg in reversed(messages): if msg.role == "user" and msg.content: user_msg = msg.content break return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality." 
async def check_and_execute_tools(
    user_message: str,
    messages: list[ChatMessage],
    available_tools: Optional[list[dict]],
) -> list[dict]:
    """Check if tools should be used and execute them.

    Detection is keyword based: when the message expresses a website-download
    intent and contains an http(s) URL, the ``website_downloader`` tool is run
    on the first URL found.

    Args:
        user_message: Latest user message; scanned for keywords and URLs.
        messages: Conversation so far. Currently unused.
        available_tools: Tools offered by the client; when empty or ``None``
            no tool is ever run.

    Returns:
        A list of OpenAI-style tool-call records for the tools that ran
        successfully; empty when nothing was triggered or the tool failed.
    """
    if not state.tool_manager or not available_tools:
        return []

    # Function-scope import: `re` is only needed on this path.
    import re

    # Simple keyword-based tool detection
    # In a production system, you'd use the LLM to decide tool usage
    message_lower = user_message.lower()

    # Check for website download intent
    download_keywords = ("download website", "mirror site", "crawl", "archive site")
    if not any(kw in message_lower for kw in download_keywords):
        return []

    # Extract URL from message
    urls = re.findall(r"https?://[^\s]+", user_message)
    if not urls:
        return []

    # The pattern runs up to the next whitespace, so punctuation that ends the
    # sentence ("... crawl https://example.com.") would otherwise be captured.
    url = urls[0].rstrip(".,;:!?\"'>")
    if not url:
        return []

    # One dict for both execution and reporting keeps the record truthful.
    arguments = {"url": url, "max_pages": 10}
    try:
        tool_result = state.tool_manager.execute_tool("website_downloader", arguments)
    except Exception:
        # Tools are best-effort, like RAG retrieval: a failing tool must not
        # take the whole chat request down.
        log.exception("website_downloader tool failed")
        return []

    log.info(f"Executed website_downloader tool: {tool_result}")
    return [{
        "id": f"call_{uuid.uuid4().hex[:24]}",
        "type": "function",
        "function": {
            "name": "website_downloader",
            "arguments": json.dumps(arguments),
        },
    }]


# =============================================================================
# Document Management Endpoints
# =============================================================================

@app.post("/v1/documents/upload")
async def upload_document(request: Request):
    """Upload a document to the knowledge base.

    Expects a multipart form with a ``file`` field holding the upload.

    Raises:
        HTTPException: 503 when the RAG system is not initialized, 400 when
            no file was uploaded, 500 when processing the document fails.
    """
    if not state.rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")

    try:
        form = await request.form()
        file = form.get("file")
        # A plain text field named "file" arrives as a str, not an upload.
        if not file or isinstance(file, str):
            raise HTTPException(status_code=400, detail="No file provided")

        content = await file.read()
        filename = file.filename or "unknown"

        # Process and store document
        result = await state.rag_system.add_document(
            content=content,
            filename=filename,
        )

        return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}

    except HTTPException:
        # Keep the intended status (e.g. 400) instead of rewrapping as a 500.
        raise
    except Exception as e:
        log.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
def _require_rag_system() -> RAGSystem:
    """Return the active RAG system or raise 503 when it is not initialized."""
    rag = state.rag_system
    if not rag:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    return rag


@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
    """Add a document from URL to the knowledge base.

    The request body is a JSON object with a ``url`` key.

    Raises:
        HTTPException: 503 without a RAG system, 400 when ``url`` is missing
            or empty, 500 when adding the document fails.
    """
    rag = _require_rag_system()

    url = request.get("url")
    if not url:
        raise HTTPException(status_code=400, detail="No URL provided")

    try:
        result = await rag.add_document_from_url(url)
    except Exception as e:
        log.exception("URL document addition failed")
        raise HTTPException(status_code=500, detail=str(e))

    chunk_count = result.get("chunks", 0)
    return {"success": True, "message": f"Document from {url} added", "chunks": chunk_count}


@app.get("/v1/documents")
async def list_documents():
    """List documents in the knowledge base.

    Raises:
        HTTPException: 503 without a RAG system, 500 when listing fails.
    """
    rag = _require_rag_system()

    try:
        return {"documents": await rag.list_documents()}
    except Exception as e:
        log.exception("Document listing failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
    """Delete a document from the knowledge base.

    Raises:
        HTTPException: 503 without a RAG system, 500 when deletion fails.
    """
    rag = _require_rag_system()

    try:
        await rag.delete_document(doc_id)
    except Exception as e:
        log.exception("Document deletion failed")
        raise HTTPException(status_code=500, detail=str(e))

    return {"success": True, "message": f"Document {doc_id} deleted"}


# =============================================================================
# Health and Status Endpoints
# =============================================================================

@app.get("/health")
async def health_check():
    """Health check endpoint.

    Reports the time elapsed since ``state.startup_time`` and which optional
    components (RAG, tools, upstream LLM client) are currently available.
    """
    uptime_seconds = time.time() - state.startup_time
    return {
        "status": "healthy",
        "uptime": uptime_seconds,
        "rag_enabled": state.rag_system is not None,
        "tools_enabled": state.tool_manager is not None,
        "llm_connected": state.zai_client is not None,
    }


@app.get("/")
async def root():
    """Root endpoint with API info."""
    endpoint_index = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "documents": "/v1/documents",
        "health": "/health",
    }
    return {
        "name": "DocRAG API",
        "version": "1.0.0",
        "description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
        "endpoints": endpoint_index,
    }


# =============================================================================
# Main Entry Point
# =============================================================================

if __name__ == "__main__":
    import uvicorn

    # The app is passed as an import string, the form uvicorn expects when
    # `reload` may be enabled.
    uvicorn.run(
        "main:app",
        host=config.HOST,
        port=config.PORT,
        reload=config.DEBUG,
    )