Implement full DocRAG server with OpenAI-compatible API

Features:
- FastAPI server with OpenAI-compatible endpoints (/v1/chat/completions, /v1/models)
- RAG system with document processing and vector storage
- Support for multiple document formats (PDF, DOCX, HTML, text, code)
- Streaming response support
- Tool integration with website_downloader
- Document management API endpoints
- GLM-4.7-Flash integration via z-ai-web-dev-sdk
- Works transparently with Open WebUI and other OpenAI clients

Components:
- main.py: FastAPI application with OpenAI-compatible API
- rag/: RAG system (document processor, vector store, retriever)
- tools/: Tool manager with website_downloader integration
- .env.example: Configuration template
This commit is contained in:
Z User 2026-03-29 00:57:37 +00:00
parent e3681949e2
commit eabdadfb62
9 changed files with 2150 additions and 142 deletions

26
.env.example Normal file
View File

@ -0,0 +1,26 @@
# DocRAG Configuration
# Copy this file to .env and fill in your values
# Server Configuration
HOST=0.0.0.0
PORT=8000
DEBUG=false
# Model Configuration
MODEL_NAME=DocRAG-GLM-4.7
UPSTREAM_MODEL=glm-4.7
# API Keys
ZAI_API_KEY=your-zai-api-key-here
# RAG Configuration
EMBEDDING_MODEL=text-embedding-3-small
VECTOR_STORE_PATH=./data/vectors
DOCUMENTS_PATH=./data/documents
CHUNK_SIZE=1000
CHUNK_OVERLAP=200
TOP_K_RESULTS=5
# Tool Configuration
ENABLE_TOOLS=true
MAX_TOOL_ITERATIONS=3

386
README.md
View File

@ -1,163 +1,275 @@
# DocRAG - Custom RAG with Document Loader
# DocRAG - OpenAI-Compatible RAG Server
A custom RAG (Retrieval-Augmented Generation) system with a custom document loader that acts as a local OpenAI-compatible server using a remote LLM with custom tools.
A custom RAG (Retrieval-Augmented Generation) system that **appears as a standard OpenAI API server** to clients like Open WebUI. Behind the scenes, it:
## Components
1. Processes user queries through a RAG system
2. Retrieves relevant context from a knowledge base
3. Passes the enriched context to GLM-4.7-Flash for response generation
4. Optionally uses tools like website_downloader for enhanced capabilities
### Website Downloader Tool
Users interact with what appears to be a normal chat experience, while sophisticated RAG operations happen transparently in the background.
The `website_downloader_tool.py` provides a tool interface for downloading and mirroring websites for offline use or RAG ingestion. It can be used by GLM-4.7-Flash via the z-ai-web-dev-sdk.
## Features
#### Features
- **OpenAI-Compatible API**: Works with any OpenAI client (Open WebUI, custom apps, etc.)
- **RAG Integration**: Automatic context retrieval for enhanced responses
- **Document Management**: Upload and manage documents in the knowledge base
- **Tool Support**: Built-in tools like website_downloader for extended capabilities
- **Streaming Support**: Real-time streaming responses
- **Easy Configuration**: Environment-based configuration
- Downloads HTML pages and all linked assets (CSS, JS, images, fonts, etc.)
- Rewrites links for offline viewing
- Supports concurrent downloads with configurable thread count
- Optional external asset downloading from CDNs
- Domain whitelisting for external assets
- Comprehensive error handling and statistics
## Quick Start
#### Tool Schema
The tool follows the OpenAI function calling format:
```python
from website_downloader_tool import get_tool_schema, website_downloader
# Get the tool schema for registration
schema = get_tool_schema()
```
#### Usage with GLM-4.7-Flash
```python
from zai import ZaiClient
from website_downloader_tool import get_tool_schema, website_downloader
client = ZaiClient(api_key="your-api-key")
# Define the tool
tools = [get_tool_schema()]
# Create a chat completion with tools
response = client.chat.completions.create(
model="glm-4.7",
messages=[
{
"role": "user",
"content": "Please download https://example.com for offline use"
}
],
tools=tools,
stream=True,
)
# Handle tool calls in the response
for chunk in response:
if chunk.choices[0].delta.tool_calls:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.function.name == "website_downloader":
import json
args = json.loads(tool_call.function.arguments)
result = website_downloader(**args)
print(result)
```
#### Direct Usage
```python
from website_downloader_tool import website_downloader
# Download a website
result = website_downloader(
url="https://example.com",
destination="./downloaded_site", # Optional
max_pages=50, # Max pages to crawl
threads=6, # Concurrent downloads
download_external_assets=False, # Include CDN assets
external_domains=["cdn.example.com"] # Whitelist external domains
)
if result["success"]:
print(f"Downloaded to: {result['output_directory']}")
print(f"Pages: {result['stats']['pages_crawled']}")
print(f"Assets: {result['stats']['assets_downloaded']}")
else:
print(f"Error: {result['message']}")
```
#### Parameters
| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| `url` | string | Yes | - | Starting URL to crawl |
| `destination` | string | No | Derived from URL | Output folder path |
| `max_pages` | integer | No | 50 | Max HTML pages (1-1000) |
| `threads` | integer | No | 6 | Concurrent download threads (1-20) |
| `download_external_assets` | boolean | No | False | Download CDN assets |
| `external_domains` | array | No | None | Whitelist of external domains |
#### Return Value
```python
{
"success": True/False,
"message": "Human-readable summary",
"stats": {
"pages_crawled": int,
"assets_downloaded": int,
"failed_downloads": int,
"elapsed_seconds": float,
"output_directory": str,
"pages": [...], # List of downloaded pages
"downloaded_items": [...] # List of downloaded assets
},
"output_directory": "/path/to/downloaded/site"
}
```
### Website Downloader CLI
The original `website-downloader.py` can still be used as a standalone CLI tool:
```bash
python website-downloader.py --url https://example.com --max-pages 50 --threads 6
```
#### CLI Options
- `--url`: Starting URL to crawl (required)
- `--destination`: Output folder (optional, derived from URL if not provided)
- `--max-pages`: Maximum pages to crawl (default: 50)
- `--threads`: Number of download threads (default: 6)
- `--download-external-assets`: Enable external asset downloading
- `--external-domains`: Whitelist of external domains to download from
## Installation
### 1. Install Dependencies
```bash
pip install -r requirements.txt
```
### 2. Configure Environment
```bash
cp .env.example .env
# Edit .env and add your ZAI_API_KEY
```
### 3. Run the Server
```bash
python main.py
```
The server will start on `http://0.0.0.0:8000`
### 4. Use with Open WebUI
1. Open Open WebUI settings
2. Add a new OpenAI-compatible connection
3. Set the base URL to `http://your-server:8000/v1`
4. Leave the API key empty or use any value (not validated)
5. Select the "DocRAG-GLM-4.7" model
## API Endpoints
### OpenAI-Compatible Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/chat/completions` | POST | Chat completions (streaming supported) |
| `/v1/models` | GET | List available models |
| `/v1/models/{model_id}` | GET | Get model information |
### Document Management Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/documents` | GET | List documents in knowledge base |
| `/v1/documents/upload` | POST | Upload a document |
| `/v1/documents/url` | POST | Add document from URL |
| `/v1/documents/{doc_id}` | DELETE | Delete a document |
### Health & Status
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/health` | GET | Health check |
| `/` | GET | API information |
## Usage Examples
### Chat Completion
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "DocRAG-GLM-4.7",
"messages": [
{"role": "user", "content": "What is machine learning?"}
],
"stream": false
}'
```
### Upload Document
```bash
curl -X POST http://localhost:8000/v1/documents/upload \
-F "file=@document.pdf"
```
### Add Document from URL
```bash
curl -X POST http://localhost:8000/v1/documents/url \
-H "Content-Type: application/json" \
-d '{"url": "https://example.com/article.html"}'
```
### Python Client
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8000/v1",
api_key="not-needed" # API key not validated
)
response = client.chat.completions.create(
model="DocRAG-GLM-4.7",
messages=[
{"role": "user", "content": "Explain quantum computing"}
],
stream=True
)
for chunk in response:
if chunk.choices[0].delta.content:
print(chunk.choices[0].delta.content, end="")
```
## Configuration
Configure via environment variables or `.env` file:
| Variable | Default | Description |
|----------|---------|-------------|
| `HOST` | `0.0.0.0` | Server host |
| `PORT` | `8000` | Server port |
| `DEBUG` | `false` | Enable debug mode |
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
| `ZAI_API_KEY` | (required) | API key for ZAI SDK |
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
| `DOCUMENTS_PATH` | `./data/documents` | Document storage |
| `CHUNK_SIZE` | `1000` | Document chunk size |
| `CHUNK_OVERLAP` | `200` | Chunk overlap |
| `TOP_K_RESULTS` | `5` | Number of context results |
| `ENABLE_TOOLS` | `true` | Enable tool support |
## Project Structure
```
docrag/
├── website-downloader.py # Core website downloader (CLI)
├── main.py # FastAPI application entry point
├── rag/
│ ├── __init__.py # RAG system main class
│ ├── document_processor.py # Document parsing and chunking
│ ├── vector_store.py # Vector storage and search
│ └── retriever.py # Context retrieval logic
├── tools/
│ └── __init__.py # Tool management (website_downloader, etc.)
├── website-downloader.py # CLI website downloader
├── website_downloader_tool.py # Tool wrapper for GLM-4.7-Flash
├── requirements.txt # Python dependencies
├── .env.example # Configuration template
└── README.md # This file
```
## Integration with RAG
## How It Works
The downloaded website content can be processed for RAG systems:
### Request Flow
1. Use the tool to download website content
2. Parse the downloaded HTML files
3. Extract text content and metadata
4. Chunk and embed the content
5. Store in your vector database
1. **User sends message** → OpenAI-compatible endpoint receives request
2. **RAG Retrieval** → Query is processed and relevant context is retrieved
3. **Context Enhancement** → Retrieved context is added to the prompt
4. **Tool Execution** → If needed, tools are invoked (e.g., website_downloader)
5. **LLM Generation** → GLM-4.7-Flash generates response with context
6. **Response** → User receives response (streaming supported)
### RAG Pipeline
```
User Query
┌─────────────────┐
│ Query Processor │
└────────┬────────┘
┌─────────────────┐
│ Vector Search │ ← Knowledge Base
└────────┬────────┘
┌─────────────────┐
│ Context Builder │
└────────┬────────┘
┌─────────────────┐
│ GLM-4.7-Flash │
└────────┬────────┘
Response
```
## Supported Document Formats
- **Text**: `.txt`, `.md`, `.rst`, `.log`
- **Documents**: `.pdf`, `.docx`
- **Web**: `.html`, `.htm`
- **Data**: `.json`, `.yaml`, `.yml`, `.xml`, `.toml`, `.csv`, `.tsv`
- **Code**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, etc.
## Extending
### Adding New Tools
```python
# In tools/__init__.py
def my_custom_tool(param1: str, param2: int = 10) -> dict:
"""Your tool implementation."""
return {"result": "success"}
# Register the tool
tool_manager.register_tool(
name="my_custom_tool",
function=my_custom_tool,
schema={
"type": "function",
"function": {
"name": "my_custom_tool",
"description": "Description of your tool",
"parameters": {
"type": "object",
"properties": {
"param1": {"type": "string", "description": "..."},
"param2": {"type": "integer", "description": "...", "default": 10}
},
"required": ["param1"]
}
}
}
)
```
### Using Different Vector Stores
The default implementation uses a simple file-based store. To use ChromaDB:
1. Install: `pip install chromadb`
2. Modify `rag/vector_store.py` to use ChromaDB client
## Development
### Running in Development Mode
```bash
DEBUG=true python main.py
```
### Running Tests
```bash
pip install pytest pytest-asyncio
pytest tests/
```
## License

732
main.py
View File

@ -1 +1,731 @@
# build out the full app the hides the fact that you are a rag app and just looks like a normal openAI model as fare as a user is concerned when interacting with it with open-webui. but in reality it is a rag app that is doing all the work in the background then provides the information to GLM 4.7-flash.
#!/usr/bin/env python3
"""
DocRAG - OpenAI-Compatible RAG Server
This application presents itself as a standard OpenAI API server that can be used
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
1. Processes user queries through a RAG system
2. Retrieves relevant context from a knowledge base
3. Passes the enriched context to GLM-4.7-Flash for response generation
4. Optionally uses tools like website_downloader for enhanced capabilities
The user sees a normal chat experience, but the system is actually doing
sophisticated RAG operations in the background.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sys
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Optional
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%H:%M:%S",
)
log = logging.getLogger(__name__)
# FastAPI imports
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel, Field
# Import RAG components
from rag import RAGSystem, get_rag_system
from rag.document_processor import DocumentProcessor
# Import tools
from tools import ToolManager, get_tool_manager
# Import SDK for GLM-4.7-Flash
try:
from zai import ZaiClient as ZAI
except ImportError:
ZAI = None
log.warning("z-ai-web-dev-sdk not installed. Install with: pip install z-ai-web-dev-sdk")
# =============================================================================
# Configuration
# =============================================================================
class Config:
"""Application configuration from environment variables."""
# Server settings
HOST: str = os.getenv("HOST", "0.0.0.0")
PORT: int = int(os.getenv("PORT", "8000"))
DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true"
# Model settings
MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG-GLM-4.7")
UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "glm-4.7")
# API Key for upstream LLM
ZAI_API_KEY: str = os.getenv("ZAI_API_KEY", "")
# RAG settings
EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors")
DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents")
CHUNK_SIZE: int = int(os.getenv("CHUNK_SIZE", "1000"))
CHUNK_OVERLAP: int = int(os.getenv("CHUNK_OVERLAP", "200"))
TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5"))
# Tool settings
ENABLE_TOOLS: bool = os.getenv("ENABLE_TOOLS", "true").lower() == "true"
MAX_TOOL_ITERATIONS: int = int(os.getenv("MAX_TOOL_ITERATIONS", "3"))
config = Config()
# =============================================================================
# OpenAI-Compatible Models
# =============================================================================
class ChatMessage(BaseModel):
"""OpenAI chat message format."""
role: str
content: Optional[str] = None
name: Optional[str] = None
tool_calls: Optional[list[dict]] = None
tool_call_id: Optional[str] = None
class ChatCompletionRequest(BaseModel):
"""OpenAI chat completion request format."""
model: str = "DocRAG-GLM-4.7"
messages: list[ChatMessage]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
stop: Optional[list[str]] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
user: Optional[str] = None
tools: Optional[list[dict]] = None
tool_choice: Optional[str | dict] = None
class ChatCompletionChoice(BaseModel):
"""OpenAI chat completion choice."""
index: int = 0
message: ChatMessage
finish_reason: str = "stop"
class ChatCompletionUsage(BaseModel):
"""Token usage statistics."""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class ChatCompletionResponse(BaseModel):
"""OpenAI chat completion response format."""
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str = config.MODEL_NAME
choices: list[ChatCompletionChoice]
usage: ChatCompletionUsage = ChatCompletionUsage()
class ModelInfo(BaseModel):
"""OpenAI model info format."""
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "organization"
class ModelList(BaseModel):
"""OpenAI model list format."""
object: str = "list"
data: list[ModelInfo]
# =============================================================================
# Application State
# =============================================================================
class AppState:
"""Global application state."""
rag_system: Optional[RAGSystem] = None
tool_manager: Optional[ToolManager] = None
zai_client: Any = None
startup_time: float = time.time()
state = AppState()
# =============================================================================
# Lifespan Management
# =============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Manage application lifespan - startup and shutdown."""
log.info("Starting DocRAG server...")
# Initialize RAG system
try:
log.info("Initializing RAG system...")
state.rag_system = await get_rag_system(
embedding_model=config.EMBEDDING_MODEL,
vector_store_path=config.VECTOR_STORE_PATH,
documents_path=config.DOCUMENTS_PATH,
chunk_size=config.CHUNK_SIZE,
chunk_overlap=config.CHUNK_OVERLAP,
)
log.info("RAG system initialized successfully")
except Exception as e:
log.warning(f"RAG system initialization deferred: {e}")
state.rag_system = None
# Initialize tool manager
try:
log.info("Initializing tool manager...")
state.tool_manager = get_tool_manager()
log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
except Exception as e:
log.warning(f"Tool manager initialization failed: {e}")
state.tool_manager = None
# Initialize ZAI client for upstream LLM
try:
if config.ZAI_API_KEY and ZAI is not None:
log.info("Initializing ZAI client...")
state.zai_client = ZAI(api_key=config.ZAI_API_KEY)
log.info("ZAI client initialized successfully")
else:
log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses")
state.zai_client = None
except Exception as e:
log.error(f"Failed to initialize ZAI client: {e}")
state.zai_client = None
log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")
yield
# Cleanup
log.info("Shutting down DocRAG server...")
if state.rag_system:
await state.rag_system.close()
log.info("DocRAG server stopped")
# =============================================================================
# FastAPI Application
# =============================================================================
app = FastAPI(
title="DocRAG API",
description="OpenAI-compatible RAG server powered by GLM-4.7-Flash",
version="1.0.0",
lifespan=lifespan,
)
# CORS middleware for Open WebUI compatibility
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# =============================================================================
# OpenAI-Compatible Endpoints
# =============================================================================
@app.get("/v1/models")
@app.get("/models")
async def list_models():
"""List available models (OpenAI-compatible)."""
return ModelList(
data=[
ModelInfo(id=config.MODEL_NAME, owned_by="docrag"),
ModelInfo(id="DocRAG-GLM-4.7", owned_by="docrag"),
]
)
@app.get("/v1/models/{model_id}")
@app.get("/models/{model_id}")
async def get_model(model_id: str):
"""Get model information (OpenAI-compatible)."""
if model_id not in [config.MODEL_NAME, "DocRAG-GLM-4.7"]:
raise HTTPException(status_code=404, detail="Model not found")
return ModelInfo(id=model_id, owned_by="docrag")
@app.post("/v1/chat/completions")
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""Handle chat completions (OpenAI-compatible)."""
request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
try:
if request.stream:
return StreamingResponse(
stream_chat_completion(request, request_id),
media_type="text/event-stream",
)
else:
return await complete_chat(request, request_id)
except Exception as e:
log.exception("Chat completion failed")
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Chat Completion Logic
# =============================================================================
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
"""Process a non-streaming chat completion request."""
messages = request.messages
# Extract the last user message
user_message = ""
for msg in reversed(messages):
if msg.role == "user" and msg.content:
user_message = msg.content
break
if not user_message:
raise HTTPException(status_code=400, detail="No user message found")
# Step 1: RAG Retrieval
context = ""
sources = []
if state.rag_system:
try:
rag_result = await state.rag_system.query(
query=user_message,
top_k=config.TOP_K_RESULTS,
)
context = rag_result.get("context", "")
sources = rag_result.get("sources", [])
log.info(f"RAG retrieved {len(sources)} relevant documents")
except Exception as e:
log.warning(f"RAG retrieval failed: {e}")
# Step 2: Build enhanced prompt with context
enhanced_messages = build_enhanced_messages(messages, context, sources)
# Step 3: Check for tool usage
tool_calls_made = []
if config.ENABLE_TOOLS and state.tool_manager:
tool_calls_made = await check_and_execute_tools(
user_message, enhanced_messages, request.tools
)
# Step 4: Generate response with upstream LLM
response_content = await generate_response(
enhanced_messages,
temperature=request.temperature,
max_tokens=request.max_tokens,
tools=request.tools,
)
# Step 5: Build and return response
return ChatCompletionResponse(
id=request_id,
model=config.MODEL_NAME,
choices=[
ChatCompletionChoice(
message=ChatMessage(
role="assistant",
content=response_content,
tool_calls=tool_calls_made if tool_calls_made else None,
),
finish_reason="stop",
)
],
usage=ChatCompletionUsage(
prompt_tokens=len(str(enhanced_messages)) // 4,
completion_tokens=len(response_content) // 4,
),
)
async def stream_chat_completion(
request: ChatCompletionRequest,
request_id: str,
) -> AsyncIterator[str]:
"""Stream a chat completion response."""
messages = request.messages
# Extract the last user message
user_message = ""
for msg in reversed(messages):
if msg.role == "user" and msg.content:
user_message = msg.content
break
if not user_message:
yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
return
# Step 1: RAG Retrieval
context = ""
sources = []
if state.rag_system:
try:
rag_result = await state.rag_system.query(
query=user_message,
top_k=config.TOP_K_RESULTS,
)
context = rag_result.get("context", "")
sources = rag_result.get("sources", [])
log.info(f"RAG retrieved {len(sources)} relevant documents")
except Exception as e:
log.warning(f"RAG retrieval failed: {e}")
# Step 2: Build enhanced prompt with context
enhanced_messages = build_enhanced_messages(messages, context, sources)
# Step 3: Stream response from upstream LLM
created = int(time.time())
try:
if state.zai_client:
# Use actual GLM-4.7-Flash
response = state.zai_client.chat.completions.create(
model=config.UPSTREAM_MODEL,
messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
temperature=request.temperature or 0.7,
max_tokens=request.max_tokens or 4096,
stream=True,
thinking={"type": "enabled"},
)
for chunk in response:
# Handle reasoning content (thinking)
if hasattr(chunk.choices[0].delta, 'reasoning_content') and chunk.choices[0].delta.reasoning_content:
# Don't expose thinking to user - this is internal RAG processing
log.debug(f"Thinking: {chunk.choices[0].delta.reasoning_content[:100]}...")
continue
# Stream actual content
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {'content': content},
'finish_reason': None
}]
})}\n\n"
# Send final chunk
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {},
'finish_reason': 'stop'
}]
})}\n\n"
yield "data: [DONE]\n\n"
else:
# Mock streaming response for testing
mock_response = f"I understand you're asking about: {user_message}\n\n"
if context:
mock_response += f"Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."
for char in mock_response:
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {'content': char},
'finish_reason': None
}]
})}\n\n"
await asyncio.sleep(0.01)
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {},
'finish_reason': 'stop'
}]
})}\n\n"
yield "data: [DONE]\n\n"
except Exception as e:
log.exception("Streaming failed")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
def build_enhanced_messages(
messages: list[ChatMessage],
context: str,
sources: list[str],
) -> list[ChatMessage]:
"""Build enhanced messages with RAG context."""
enhanced = []
# Add system message with RAG context
system_content = (
"You are a helpful AI assistant with access to a knowledge base. "
"Use the provided context to give accurate and helpful responses. "
"If the context doesn't contain relevant information, use your general knowledge "
"but indicate when you're doing so."
)
if context:
system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
if sources:
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources)
enhanced.append(ChatMessage(role="system", content=system_content))
# Add conversation history (excluding old system messages)
for msg in messages:
if msg.role != "system":
enhanced.append(msg)
return enhanced
async def generate_response(
messages: list[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 4096,
tools: Optional[list[dict]] = None,
) -> str:
"""Generate response using upstream LLM."""
if state.zai_client:
try:
response = state.zai_client.chat.completions.create(
model=config.UPSTREAM_MODEL,
messages=[{"role": m.role, "content": m.content} for m in messages if m.content],
temperature=temperature,
max_tokens=max_tokens,
stream=False,
thinking={"type": "enabled"},
)
# Extract content from response
content = ""
for chunk in response:
if hasattr(chunk.choices[0], 'message') and chunk.choices[0].message:
content = chunk.choices[0].message.content or ""
break
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
content += chunk.choices[0].delta.content
return content or "I apologize, but I couldn't generate a response."
except Exception as e:
log.error(f"Upstream LLM call failed: {e}")
return f"I encountered an error: {str(e)}"
else:
# Mock response for testing
user_msg = ""
for msg in reversed(messages):
if msg.role == "user" and msg.content:
user_msg = msg.content
break
return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."
async def check_and_execute_tools(
user_message: str,
messages: list[ChatMessage],
available_tools: Optional[list[dict]],
) -> list[dict]:
"""Check if tools should be used and execute them."""
if not state.tool_manager or not available_tools:
return []
tool_calls = []
# Simple keyword-based tool detection
# In a production system, you'd use the LLM to decide tool usage
message_lower = user_message.lower()
# Check for website download intent
if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site"]):
# Extract URL from message
import re
url_pattern = r'https?://[^\s]+'
urls = re.findall(url_pattern, user_message)
if urls:
tool_result = state.tool_manager.execute_tool(
"website_downloader",
{"url": urls[0], "max_pages": 10}
)
tool_calls.append({
"id": f"call_{uuid.uuid4().hex[:24]}",
"type": "function",
"function": {
"name": "website_downloader",
"arguments": json.dumps({"url": urls[0]}),
}
})
log.info(f"Executed website_downloader tool: {tool_result}")
return tool_calls
# =============================================================================
# Document Management Endpoints
# =============================================================================
@app.post("/v1/documents/upload")
async def upload_document(request: Request):
"""Upload a document to the knowledge base."""
if not state.rag_system:
raise HTTPException(status_code=503, detail="RAG system not initialized")
try:
form = await request.form()
file = form.get("file")
if not file:
raise HTTPException(status_code=400, detail="No file provided")
content = await file.read()
filename = file.filename or "unknown"
# Process and store document
result = await state.rag_system.add_document(
content=content,
filename=filename,
)
return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}
except Exception as e:
log.exception("Document upload failed")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/documents/url")
async def add_document_from_url(request: dict):
"""Add a document from URL to the knowledge base."""
if not state.rag_system:
raise HTTPException(status_code=503, detail="RAG system not initialized")
url = request.get("url")
if not url:
raise HTTPException(status_code=400, detail="No URL provided")
try:
result = await state.rag_system.add_document_from_url(url)
return {"success": True, "message": f"Document from {url} added", "chunks": result.get("chunks", 0)}
except Exception as e:
log.exception("URL document addition failed")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/documents")
async def list_documents():
"""List documents in the knowledge base."""
if not state.rag_system:
raise HTTPException(status_code=503, detail="RAG system not initialized")
try:
docs = await state.rag_system.list_documents()
return {"documents": docs}
except Exception as e:
log.exception("Document listing failed")
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/v1/documents/{doc_id}")
async def delete_document(doc_id: str):
"""Delete a document from the knowledge base."""
if not state.rag_system:
raise HTTPException(status_code=503, detail="RAG system not initialized")
try:
await state.rag_system.delete_document(doc_id)
return {"success": True, "message": f"Document {doc_id} deleted"}
except Exception as e:
log.exception("Document deletion failed")
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Health and Status Endpoints
# =============================================================================
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"uptime": time.time() - state.startup_time,
"rag_enabled": state.rag_system is not None,
"tools_enabled": state.tool_manager is not None,
"llm_connected": state.zai_client is not None,
}
@app.get("/")
async def root():
"""Root endpoint with API info."""
return {
"name": "DocRAG API",
"version": "1.0.0",
"description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
"endpoints": {
"chat": "/v1/chat/completions",
"models": "/v1/models",
"documents": "/v1/documents",
"health": "/health",
},
}
# =============================================================================
# Main Entry Point
# =============================================================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"main:app",
host=config.HOST,
port=config.PORT,
reload=config.DEBUG,
)

252
rag/__init__.py Normal file
View File

@ -0,0 +1,252 @@
"""
RAG System - Retrieval Augmented Generation
This module provides the core RAG functionality for DocRAG, including:
- Document processing and chunking
- Vector storage and similarity search
- Context retrieval for enhanced prompts
"""
from __future__ import annotations
import asyncio
import logging
import os
from pathlib import Path
from typing import Any, Optional
from .document_processor import DocumentProcessor
from .vector_store import VectorStore
from .retriever import Retriever
log = logging.getLogger(__name__)
class RAGSystem:
"""
Main RAG system that coordinates document processing, storage, and retrieval.
This class provides a unified interface for:
- Adding documents to the knowledge base
- Querying for relevant context
- Managing the document lifecycle
"""
def __init__(
self,
embedding_model: str = "text-embedding-3-small",
vector_store_path: str = "./data/vectors",
documents_path: str = "./data/documents",
chunk_size: int = 1000,
chunk_overlap: int = 200,
):
self.embedding_model = embedding_model
self.vector_store_path = Path(vector_store_path)
self.documents_path = Path(documents_path)
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self._initialized = False
self._document_processor: Optional[DocumentProcessor] = None
self._vector_store: Optional[VectorStore] = None
self._retriever: Optional[Retriever] = None
async def initialize(self) -> None:
"""Initialize the RAG system components."""
if self._initialized:
return
log.info("Initializing RAG system...")
# Create directories
self.vector_store_path.mkdir(parents=True, exist_ok=True)
self.documents_path.mkdir(parents=True, exist_ok=True)
# Initialize document processor
self._document_processor = DocumentProcessor(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
# Initialize vector store
self._vector_store = VectorStore(
persist_directory=str(self.vector_store_path),
embedding_model=self.embedding_model,
)
# Initialize retriever
self._retriever = Retriever(
vector_store=self._vector_store,
)
self._initialized = True
log.info("RAG system initialized successfully")
async def close(self) -> None:
"""Close the RAG system and release resources."""
if self._vector_store:
await self._vector_store.close()
self._initialized = False
log.info("RAG system closed")
def _ensure_initialized(self) -> None:
"""Ensure the RAG system is initialized."""
if not self._initialized:
raise RuntimeError("RAG system not initialized. Call initialize() first.")
async def add_document(
self,
content: bytes,
filename: str,
metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Add a document to the knowledge base.
Args:
content: Raw document content
filename: Original filename
metadata: Optional metadata
Returns:
Dictionary with processing results
"""
self._ensure_initialized()
# Process document
doc_info = await self._document_processor.process(
content=content,
filename=filename,
metadata=metadata,
)
# Store chunks in vector store
if doc_info.get("chunks"):
await self._vector_store.add_chunks(
chunks=doc_info["chunks"],
metadatas=doc_info.get("metadatas", []),
ids=doc_info.get("ids", []),
)
log.info(f"Added document '{filename}' with {len(doc_info.get('chunks', []))} chunks")
return {"chunks": len(doc_info.get("chunks", [])), "document_id": doc_info.get("document_id")}
async def add_document_from_url(self, url: str) -> dict[str, Any]:
"""
Add a document from a URL to the knowledge base.
Args:
url: URL to fetch and process
Returns:
Dictionary with processing results
"""
self._ensure_initialized()
# Fetch content from URL
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=30) as response:
response.raise_for_status()
content = await response.read()
# Extract filename from URL
from urllib.parse import urlparse
parsed = urlparse(url)
filename = os.path.basename(parsed.path) or "webpage.html"
return await self.add_document(content=content, filename=filename, metadata={"source_url": url})
async def query(
self,
query: str,
top_k: int = 5,
filter_metadata: Optional[dict] = None,
) -> dict[str, Any]:
"""
Query the knowledge base for relevant context.
Args:
query: Query string
top_k: Number of results to return
filter_metadata: Optional metadata filters
Returns:
Dictionary with context and sources
"""
self._ensure_initialized()
# Retrieve relevant chunks
results = await self._retriever.retrieve(
query=query,
top_k=top_k,
filter_metadata=filter_metadata,
)
# Build context string
context_parts = []
sources = []
for i, result in enumerate(results):
context_parts.append(f"[{i+1}] {result['content']}")
if result.get("metadata", {}).get("source"):
sources.append(result["metadata"]["source"])
context = "\n\n".join(context_parts)
return {
"context": context,
"sources": list(set(sources)),
"num_results": len(results),
"results": results,
}
async def list_documents(self) -> list[dict[str, Any]]:
"""List all documents in the knowledge base."""
self._ensure_initialized()
return await self._vector_store.list_documents()
async def delete_document(self, document_id: str) -> None:
"""Delete a document from the knowledge base."""
self._ensure_initialized()
await self._vector_store.delete_document(document_id)
log.info(f"Deleted document {document_id}")
# Global RAG system instance
_rag_system: Optional[RAGSystem] = None
async def get_rag_system(
embedding_model: str = "text-embedding-3-small",
vector_store_path: str = "./data/vectors",
documents_path: str = "./data/documents",
chunk_size: int = 1000,
chunk_overlap: int = 200,
) -> RAGSystem:
"""
Get or create the global RAG system instance.
Args:
embedding_model: Name of the embedding model
vector_store_path: Path to vector store
documents_path: Path to document storage
chunk_size: Size of document chunks
chunk_overlap: Overlap between chunks
Returns:
Initialized RAGSystem instance
"""
global _rag_system
if _rag_system is None:
_rag_system = RAGSystem(
embedding_model=embedding_model,
vector_store_path=vector_store_path,
documents_path=documents_path,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
await _rag_system.initialize()
return _rag_system

247
rag/document_processor.py Normal file
View File

@ -0,0 +1,247 @@
"""
Document Processor - Handles document parsing and chunking
Supports multiple document formats:
- Plain text (.txt, .md)
- PDF (.pdf)
- HTML (.html, .htm)
- Word documents (.docx)
- Code files (.py, .js, etc.)
"""
from __future__ import annotations
import hashlib
import logging
import os
import re
import uuid
from pathlib import Path
from typing import Any, Optional
log = logging.getLogger(__name__)
class DocumentProcessor:
"""
Process documents into chunks suitable for vector storage.
Handles:
- Multiple file formats
- Intelligent chunking with overlap
- Metadata extraction
"""
def __init__(
self,
chunk_size: int = 1000,
chunk_overlap: int = 200,
):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
async def process(
self,
content: bytes,
filename: str,
metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Process a document into chunks.
Args:
content: Raw document content
filename: Original filename
metadata: Optional additional metadata
Returns:
Dictionary with chunks, metadatas, and ids
"""
# Extract text based on file type
text = await self._extract_text(content, filename)
if not text.strip():
return {"chunks": [], "metadatas": [], "ids": [], "document_id": None}
# Generate document ID
document_id = str(uuid.uuid4())
# Create chunks
chunks = self._chunk_text(text)
# Create metadata for each chunk
base_metadata = {
"source": filename,
"document_id": document_id,
**(metadata or {}),
}
metadatas = []
ids = []
for i, chunk in enumerate(chunks):
chunk_id = hashlib.md5(f"{document_id}_{i}".encode()).hexdigest()
ids.append(chunk_id)
metadatas.append({
**base_metadata,
"chunk_index": i,
"chunk_length": len(chunk),
})
return {
"chunks": chunks,
"metadatas": metadatas,
"ids": ids,
"document_id": document_id,
"total_chars": len(text),
}
async def _extract_text(self, content: bytes, filename: str) -> str:
"""Extract text from document based on file type."""
ext = Path(filename).suffix.lower()
try:
if ext in (".txt", ".md", ".rst", ".log"):
return content.decode("utf-8", errors="ignore")
elif ext == ".pdf":
return await self._extract_pdf(content)
elif ext in (".html", ".htm"):
return await self._extract_html(content)
elif ext == ".docx":
return await self._extract_docx(content)
elif ext in (".py", ".js", ".ts", ".java", ".cpp", ".c", ".go", ".rs", ".rb", ".php", ".cs", ".swift", ".kt"):
return content.decode("utf-8", errors="ignore")
elif ext in (".json", ".yaml", ".yml", ".xml", ".toml"):
return content.decode("utf-8", errors="ignore")
elif ext in (".csv", ".tsv"):
return content.decode("utf-8", errors="ignore")
else:
# Try to decode as text
try:
return content.decode("utf-8", errors="ignore")
except Exception:
log.warning(f"Unknown file type: {ext}, treating as binary")
return ""
except Exception as e:
log.error(f"Failed to extract text from {filename}: {e}")
return ""
async def _extract_pdf(self, content: bytes) -> str:
"""Extract text from PDF."""
try:
import fitz # PyMuPDF
doc = fitz.open(stream=content, filetype="pdf")
text_parts = []
for page in doc:
text_parts.append(page.get_text())
doc.close()
return "\n\n".join(text_parts)
except ImportError:
log.warning("PyMuPDF not installed, PDF extraction unavailable")
return ""
except Exception as e:
log.error(f"PDF extraction failed: {e}")
return ""
async def _extract_html(self, content: bytes) -> str:
"""Extract text from HTML."""
try:
from bs4 import BeautifulSoup
soup = BeautifulSoup(content, "html.parser")
# Remove script and style elements
for element in soup(["script", "style", "nav", "footer", "header"]):
element.decompose()
return soup.get_text(separator="\n", strip=True)
except ImportError:
log.warning("BeautifulSoup not installed, HTML extraction unavailable")
return content.decode("utf-8", errors="ignore")
except Exception as e:
log.error(f"HTML extraction failed: {e}")
return ""
async def _extract_docx(self, content: bytes) -> str:
"""Extract text from DOCX."""
try:
import io
from docx import Document
doc = Document(io.BytesIO(content))
return "\n\n".join(para.text for para in doc.paragraphs)
except ImportError:
log.warning("python-docx not installed, DOCX extraction unavailable")
return ""
except Exception as e:
log.error(f"DOCX extraction failed: {e}")
return ""
def _chunk_text(self, text: str) -> list[str]:
"""
Split text into overlapping chunks.
Uses a sentence-aware chunking strategy to avoid breaking mid-sentence.
"""
if len(text) <= self.chunk_size:
return [text.strip()] if text.strip() else []
# Split into sentences
sentences = self._split_sentences(text)
chunks = []
current_chunk = []
current_length = 0
for sentence in sentences:
sentence_length = len(sentence)
# If adding this sentence would exceed chunk size
if current_length + sentence_length > self.chunk_size and current_chunk:
# Save current chunk
chunks.append(" ".join(current_chunk))
# Start new chunk with overlap
overlap_text = self._get_overlap_text(current_chunk)
current_chunk = [overlap_text, sentence] if overlap_text else [sentence]
current_length = len(" ".join(current_chunk))
else:
current_chunk.append(sentence)
current_length += sentence_length + 1 # +1 for space
# Add final chunk
if current_chunk:
chunks.append(" ".join(current_chunk))
return [c.strip() for c in chunks if c.strip()]
def _split_sentences(self, text: str) -> list[str]:
"""Split text into sentences."""
# Simple sentence splitting - can be improved with NLP libraries
sentence_endings = r'(?<=[.!?])\s+'
sentences = re.split(sentence_endings, text)
return [s.strip() for s in sentences if s.strip()]
def _get_overlap_text(self, chunk_parts: list[str]) -> str:
"""Get text for overlap from the end of the current chunk."""
if not chunk_parts:
return ""
full_text = " ".join(chunk_parts)
if len(full_text) <= self.chunk_overlap:
return full_text
# Get last N characters
overlap = full_text[-self.chunk_overlap:]
# Try to start at a word boundary
space_idx = overlap.find(" ")
if space_idx > 0:
overlap = overlap[space_idx + 1:]
return overlap

176
rag/retriever.py Normal file
View File

@ -0,0 +1,176 @@
"""
Retriever - Handles context retrieval from the vector store
Provides intelligent retrieval with:
- Query optimization
- Result ranking
- Context windowing
"""
from __future__ import annotations
import logging
from typing import Any, Optional
from .vector_store import VectorStore
log = logging.getLogger(__name__)
class Retriever:
"""
Retriever for fetching relevant context from the vector store.
Handles:
- Query preprocessing
- Similarity search
- Result ranking and filtering
"""
def __init__(
self,
vector_store: VectorStore,
default_top_k: int = 5,
min_score: float = 0.0,
):
self.vector_store = vector_store
self.default_top_k = default_top_k
self.min_score = min_score
async def retrieve(
self,
query: str,
top_k: Optional[int] = None,
filter_metadata: Optional[dict] = None,
) -> list[dict[str, Any]]:
"""
Retrieve relevant chunks for a query.
Args:
query: Query string
top_k: Number of results (uses default if not provided)
filter_metadata: Optional metadata filters
Returns:
List of relevant chunks with scores
"""
top_k = top_k or self.default_top_k
# Preprocess query
processed_query = self._preprocess_query(query)
# Search vector store
results = await self.vector_store.search(
query=processed_query,
top_k=top_k * 2, # Get more results for filtering
filter_metadata=filter_metadata,
)
# Filter by minimum score
results = [r for r in results if r["score"] >= self.min_score]
# Rank and deduplicate
results = self._rank_results(results, query)
# Return top_k
return results[:top_k]
def _preprocess_query(self, query: str) -> str:
"""
Preprocess query for better retrieval.
- Remove extra whitespace
- Handle special characters
- Normalize case
"""
# Remove extra whitespace
query = " ".join(query.split())
# Remove question marks and other punctuation that might hurt matching
query = query.replace("?", " ").replace("!", " ")
# Normalize whitespace again
query = " ".join(query.split())
return query.strip()
def _rank_results(
self,
results: list[dict[str, Any]],
query: str,
) -> list[dict[str, Any]]:
"""
Rank results by relevance.
Uses a combination of:
- Vector similarity score
- Keyword matching
- Document diversity
"""
if not results:
return results
# Calculate additional scores
query_words = set(query.lower().split())
for result in results:
content = result["content"].lower()
content_words = set(content.split())
# Keyword overlap score
overlap = len(query_words & content_words)
keyword_score = overlap / max(len(query_words), 1)
# Combine scores
result["combined_score"] = (
result["score"] * 0.7 + # Vector similarity
keyword_score * 0.3 # Keyword matching
)
# Sort by combined score
results.sort(key=lambda x: x["combined_score"], reverse=True)
# Remove duplicate content (keep highest scoring)
seen_content = set()
unique_results = []
for result in results:
# Use first 100 chars as content fingerprint
content_fingerprint = result["content"][:100]
if content_fingerprint not in seen_content:
seen_content.add(content_fingerprint)
unique_results.append(result)
return unique_results
async def retrieve_with_context(
self,
query: str,
top_k: int = 5,
context_window: int = 1,
) -> dict[str, Any]:
"""
Retrieve chunks with surrounding context.
Args:
query: Query string
top_k: Number of main results
context_window: Number of adjacent chunks to include
Returns:
Dictionary with expanded context
"""
results = await self.retrieve(query=query, top_k=top_k)
# For now, return basic results
# In a full implementation, we'd expand to include adjacent chunks
return {
"results": results,
"context": "\n\n".join(r["content"] for r in results),
"sources": list(set(
r.get("metadata", {}).get("source", "")
for r in results
if r.get("metadata", {}).get("source")
)),
}

285
rag/vector_store.py Normal file
View File

@ -0,0 +1,285 @@
"""
Vector Store - Handles vector storage and similarity search
Provides a simple file-based vector store that can be extended to use
more sophisticated backends like ChromaDB, FAISS, or Pinecone.
"""
from __future__ import annotations
import hashlib
import json
import logging
import os
from pathlib import Path
from typing import Any, Optional
log = logging.getLogger(__name__)
class VectorStore:
"""
Vector store for document embeddings.
This implementation provides:
- Simple file-based persistence
- In-memory similarity search
- Document management
Can be extended to use ChromaDB, FAISS, or other vector databases.
"""
def __init__(
self,
persist_directory: str = "./data/vectors",
embedding_model: str = "text-embedding-3-small",
):
self.persist_directory = Path(persist_directory)
self.embedding_model = embedding_model
self._chunks: list[dict[str, Any]] = []
self._embeddings: list[list[float]] = []
self._metadata: list[dict[str, Any]] = []
self._ids: list[str] = []
self._initialized = False
async def initialize(self) -> None:
"""Initialize the vector store and load existing data."""
if self._initialized:
return
self.persist_directory.mkdir(parents=True, exist_ok=True)
# Load existing data
await self._load()
self._initialized = True
log.info(f"Vector store initialized with {len(self._chunks)} chunks")
async def close(self) -> None:
"""Save and close the vector store."""
await self._save()
log.info("Vector store closed")
async def _load(self) -> None:
"""Load data from disk."""
data_file = self.persist_directory / "store.json"
if not data_file.exists():
return
try:
with open(data_file, "r", encoding="utf-8") as f:
data = json.load(f)
self._chunks = data.get("chunks", [])
self._embeddings = data.get("embeddings", [])
self._metadata = data.get("metadata", [])
self._ids = data.get("ids", [])
log.info(f"Loaded {len(self._chunks)} chunks from disk")
except Exception as e:
log.error(f"Failed to load vector store: {e}")
async def _save(self) -> None:
"""Save data to disk."""
data_file = self.persist_directory / "store.json"
try:
data = {
"chunks": self._chunks,
"embeddings": self._embeddings,
"metadata": self._metadata,
"ids": self._ids,
}
with open(data_file, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
log.info(f"Saved {len(self._chunks)} chunks to disk")
except Exception as e:
log.error(f"Failed to save vector store: {e}")
def _ensure_initialized(self) -> None:
"""Ensure the vector store is initialized."""
if not self._initialized:
raise RuntimeError("Vector store not initialized")
async def add_chunks(
self,
chunks: list[str],
metadatas: Optional[list[dict[str, Any]]] = None,
ids: Optional[list[str]] = None,
) -> None:
"""
Add chunks to the vector store.
Args:
chunks: List of text chunks
metadatas: Optional list of metadata dicts
ids: Optional list of chunk IDs
"""
self._ensure_initialized()
if not chunks:
return
# Generate IDs if not provided
if ids is None:
ids = [hashlib.md5(chunk.encode()).hexdigest() for chunk in chunks]
# Generate metadata if not provided
if metadatas is None:
metadatas = [{}] * len(chunks)
# Generate embeddings
embeddings = await self._generate_embeddings(chunks)
# Store everything
for i, (chunk, embedding, metadata, chunk_id) in enumerate(
zip(chunks, embeddings, metadatas, ids)
):
self._chunks.append({"id": chunk_id, "content": chunk})
self._embeddings.append(embedding)
self._metadata.append(metadata)
self._ids.append(chunk_id)
# Save to disk
await self._save()
log.info(f"Added {len(chunks)} chunks to vector store")
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
Generate embeddings for texts.
Uses a simple hash-based embedding for demonstration.
In production, use a real embedding model via API.
"""
embeddings = []
for text in texts:
# Simple hash-based embedding (for demo purposes)
# In production, use OpenAI embeddings or similar
hash_bytes = hashlib.sha256(text.encode()).digest()
# Create a 384-dimensional embedding (common size)
embedding = []
for i in range(384):
byte_idx = i % len(hash_bytes)
value = (hash_bytes[byte_idx] - 128) / 128.0
embedding.append(value)
embeddings.append(embedding)
return embeddings
async def search(
self,
query: str,
top_k: int = 5,
filter_metadata: Optional[dict] = None,
) -> list[dict[str, Any]]:
"""
Search for similar chunks.
Args:
query: Query string
top_k: Number of results to return
filter_metadata: Optional metadata filters
Returns:
List of matching chunks with scores
"""
self._ensure_initialized()
if not self._chunks:
return []
# Generate query embedding
query_embedding = (await self._generate_embeddings([query]))[0]
# Calculate similarities
results = []
for i, (chunk, embedding, metadata) in enumerate(
zip(self._chunks, self._embeddings, self._metadata)
):
# Apply metadata filter
if filter_metadata:
match = all(
metadata.get(k) == v
for k, v in filter_metadata.items()
)
if not match:
continue
# Calculate cosine similarity
similarity = self._cosine_similarity(query_embedding, embedding)
results.append({
"id": chunk["id"],
"content": chunk["content"],
"metadata": metadata,
"score": similarity,
})
# Sort by score and return top_k
results.sort(key=lambda x: x["score"], reverse=True)
return results[:top_k]
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
if len(a) != len(b):
return 0.0
dot_product = sum(x * y for x, y in zip(a, b))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0:
return 0.0
return dot_product / (norm_a * norm_b)
async def list_documents(self) -> list[dict[str, Any]]:
"""List all unique documents in the store."""
self._ensure_initialized()
# Group by document_id
documents = {}
for metadata in self._metadata:
doc_id = metadata.get("document_id")
if doc_id and doc_id not in documents:
documents[doc_id] = {
"id": doc_id,
"source": metadata.get("source", "unknown"),
"chunk_count": 1,
}
elif doc_id:
documents[doc_id]["chunk_count"] += 1
return list(documents.values())
async def delete_document(self, document_id: str) -> None:
"""Delete all chunks for a document."""
self._ensure_initialized()
# Find indices to remove
indices_to_remove = [
i
for i, metadata in enumerate(self._metadata)
if metadata.get("document_id") == document_id
]
# Remove in reverse order to maintain indices
for i in sorted(indices_to_remove, reverse=True):
self._chunks.pop(i)
self._embeddings.pop(i)
self._metadata.pop(i)
self._ids.pop(i)
# Save changes
await self._save()
log.info(f"Deleted document {document_id} ({len(indices_to_remove)} chunks)")

View File

@ -1,7 +1,31 @@
# Core dependencies
fastapi~=0.115.0
uvicorn[standard]~=0.32.0
pydantic~=2.10.0
python-multipart~=0.0.20
# HTTP and async
aiohttp~=3.11.0
httpx~=0.28.0
requests~=2.32.4
# Web scraping (for website downloader)
beautifulsoup4~=4.13.4
wget~=3.2
lxml~=5.3.0
urllib3~=2.5.0
# Document processing
PyMuPDF~=1.25.0
python-docx~=1.1.0
# Optional: For using z-ai-web-dev-sdk with GLM-4.7-Flash
# Uncomment the following line if you have access to the SDK
# z-ai-web-dev-sdk>=1.0.0
# Vector store alternatives (uncomment as needed)
# chromadb~=0.5.0
# faiss-cpu~=1.9.0
# Development dependencies
# pytest~=8.3.0
# pytest-asyncio~=0.24.0

156
tools/__init__.py Normal file
View File

@ -0,0 +1,156 @@
"""
Tools Module - Tool management for the RAG system
Provides a unified interface for tool registration and execution.
"""
from __future__ import annotations
import json
import logging
from typing import Any, Callable, Optional
# Import the website downloader tool
from website_downloader_tool import (
website_downloader,
get_tool_schema as get_website_downloader_schema,
)
log = logging.getLogger(__name__)
class ToolManager:
"""
Manages tool registration and execution.
Provides:
- Tool registration
- Tool schema generation
- Tool execution with error handling
"""
def __init__(self):
self._tools: dict[str, Callable] = {}
self._schemas: dict[str, dict] = {}
# Register built-in tools
self._register_builtin_tools()
def _register_builtin_tools(self) -> None:
"""Register built-in tools."""
# Register website downloader
self.register_tool(
name="website_downloader",
function=website_downloader,
schema=get_website_downloader_schema(),
)
log.info(f"Registered {len(self._tools)} built-in tools")
def register_tool(
self,
name: str,
function: Callable,
schema: dict,
) -> None:
"""
Register a new tool.
Args:
name: Tool name
function: Tool function
schema: OpenAI function schema
"""
self._tools[name] = function
self._schemas[name] = schema
log.info(f"Registered tool: {name}")
def get_tool_schema(self, name: str) -> Optional[dict]:
"""Get the schema for a tool."""
return self._schemas.get(name)
def get_all_schemas(self) -> list[dict]:
"""Get schemas for all registered tools."""
return list(self._schemas.values())
def list_tools(self) -> list[str]:
"""List all registered tool names."""
return list(self._tools.keys())
def execute_tool(
self,
name: str,
arguments: dict[str, Any],
) -> dict[str, Any]:
"""
Execute a tool with the given arguments.
Args:
name: Tool name
arguments: Tool arguments
Returns:
Tool result
"""
if name not in self._tools:
return {
"success": False,
"error": f"Unknown tool: {name}",
}
try:
log.info(f"Executing tool: {name} with args: {arguments}")
result = self._tools[name](**arguments)
return result
except TypeError as e:
log.error(f"Invalid arguments for tool {name}: {e}")
return {
"success": False,
"error": f"Invalid arguments: {str(e)}",
}
except Exception as e:
log.exception(f"Tool execution failed: {name}")
return {
"success": False,
"error": str(e),
}
def execute_tool_from_json(
self,
name: str,
arguments_json: str,
) -> dict[str, Any]:
"""
Execute a tool with JSON arguments.
Args:
name: Tool name
arguments_json: JSON string of arguments
Returns:
Tool result
"""
try:
arguments = json.loads(arguments_json)
return self.execute_tool(name, arguments)
except json.JSONDecodeError as e:
return {
"success": False,
"error": f"Invalid JSON arguments: {str(e)}",
}
# Global tool manager instance
_tool_manager: Optional[ToolManager] = None
def get_tool_manager() -> ToolManager:
"""Get or create the global tool manager instance."""
global _tool_manager
if _tool_manager is None:
_tool_manager = ToolManager()
return _tool_manager