Implement full DocRAG server with OpenAI-compatible API
Features: - FastAPI server with OpenAI-compatible endpoints (/v1/chat/completions, /v1/models) - RAG system with document processing and vector storage - Support for multiple document formats (PDF, DOCX, HTML, text, code) - Streaming response support - Tool integration with website_downloader - Document management API endpoints - GLM-4.7-Flash integration via z-ai-web-dev-sdk - Works transparently with Open WebUI and other OpenAI clients Components: - main.py: FastAPI application with OpenAI-compatible API - rag/: RAG system (document processor, vector store, retriever) - tools/: Tool manager with website_downloader integration - .env.example: Configuration template
This commit is contained in:
parent
e3681949e2
commit
eabdadfb62
26
.env.example
Normal file
26
.env.example
Normal file
@ -0,0 +1,26 @@
|
||||
# DocRAG Configuration
|
||||
# Copy this file to .env and fill in your values
|
||||
|
||||
# Server Configuration
|
||||
HOST=0.0.0.0
|
||||
PORT=8000
|
||||
DEBUG=false
|
||||
|
||||
# Model Configuration
|
||||
MODEL_NAME=DocRAG-GLM-4.7
|
||||
UPSTREAM_MODEL=glm-4.7
|
||||
|
||||
# API Keys
|
||||
ZAI_API_KEY=your-zai-api-key-here
|
||||
|
||||
# RAG Configuration
|
||||
EMBEDDING_MODEL=text-embedding-3-small
|
||||
VECTOR_STORE_PATH=./data/vectors
|
||||
DOCUMENTS_PATH=./data/documents
|
||||
CHUNK_SIZE=1000
|
||||
CHUNK_OVERLAP=200
|
||||
TOP_K_RESULTS=5
|
||||
|
||||
# Tool Configuration
|
||||
ENABLE_TOOLS=true
|
||||
MAX_TOOL_ITERATIONS=3
|
||||
386
README.md
386
README.md
@ -1,163 +1,275 @@
|
||||
# DocRAG - Custom RAG with Document Loader
|
||||
# DocRAG - OpenAI-Compatible RAG Server
|
||||
|
||||
A custom RAG (Retrieval-Augmented Generation) system with a custom document loader that acts as a local OpenAI-compatible server using a remote LLM with custom tools.
|
||||
A custom RAG (Retrieval-Augmented Generation) system that **appears as a standard OpenAI API server** to clients like Open WebUI. Behind the scenes, it:
|
||||
|
||||
## Components
|
||||
1. Processes user queries through a RAG system
|
||||
2. Retrieves relevant context from a knowledge base
|
||||
3. Passes the enriched context to GLM-4.7-Flash for response generation
|
||||
4. Optionally uses tools like website_downloader for enhanced capabilities
|
||||
|
||||
### Website Downloader Tool
|
||||
Users interact with what appears to be a normal chat experience, while sophisticated RAG operations happen transparently in the background.
|
||||
|
||||
The `website_downloader_tool.py` provides a tool interface for downloading and mirroring websites for offline use or RAG ingestion. It can be used by GLM-4.7-Flash via the z-ai-web-dev-sdk.
|
||||
## Features
|
||||
|
||||
#### Features
|
||||
- **OpenAI-Compatible API**: Works with any OpenAI client (Open WebUI, custom apps, etc.)
|
||||
- **RAG Integration**: Automatic context retrieval for enhanced responses
|
||||
- **Document Management**: Upload and manage documents in the knowledge base
|
||||
- **Tool Support**: Built-in tools like website_downloader for extended capabilities
|
||||
- **Streaming Support**: Real-time streaming responses
|
||||
- **Easy Configuration**: Environment-based configuration
|
||||
|
||||
- Downloads HTML pages and all linked assets (CSS, JS, images, fonts, etc.)
|
||||
- Rewrites links for offline viewing
|
||||
- Supports concurrent downloads with configurable thread count
|
||||
- Optional external asset downloading from CDNs
|
||||
- Domain whitelisting for external assets
|
||||
- Comprehensive error handling and statistics
|
||||
## Quick Start
|
||||
|
||||
#### Tool Schema
|
||||
|
||||
The tool follows the OpenAI function calling format:
|
||||
|
||||
```python
|
||||
from website_downloader_tool import get_tool_schema, website_downloader
|
||||
|
||||
# Get the tool schema for registration
|
||||
schema = get_tool_schema()
|
||||
```
|
||||
|
||||
#### Usage with GLM-4.7-Flash
|
||||
|
||||
```python
|
||||
from zai import ZaiClient
|
||||
from website_downloader_tool import get_tool_schema, website_downloader
|
||||
|
||||
client = ZaiClient(api_key="your-api-key")
|
||||
|
||||
# Define the tool
|
||||
tools = [get_tool_schema()]
|
||||
|
||||
# Create a chat completion with tools
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4.7",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Please download https://example.com for offline use"
|
||||
}
|
||||
],
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Handle tool calls in the response
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.tool_calls:
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.function.name == "website_downloader":
|
||||
import json
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
result = website_downloader(**args)
|
||||
print(result)
|
||||
```
|
||||
|
||||
#### Direct Usage
|
||||
|
||||
```python
|
||||
from website_downloader_tool import website_downloader
|
||||
|
||||
# Download a website
|
||||
result = website_downloader(
|
||||
url="https://example.com",
|
||||
destination="./downloaded_site", # Optional
|
||||
max_pages=50, # Max pages to crawl
|
||||
threads=6, # Concurrent downloads
|
||||
download_external_assets=False, # Include CDN assets
|
||||
external_domains=["cdn.example.com"] # Whitelist external domains
|
||||
)
|
||||
|
||||
if result["success"]:
|
||||
print(f"Downloaded to: {result['output_directory']}")
|
||||
print(f"Pages: {result['stats']['pages_crawled']}")
|
||||
print(f"Assets: {result['stats']['assets_downloaded']}")
|
||||
else:
|
||||
print(f"Error: {result['message']}")
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Parameter | Type | Required | Default | Description |
|
||||
|-----------|------|----------|---------|-------------|
|
||||
| `url` | string | Yes | - | Starting URL to crawl |
|
||||
| `destination` | string | No | Derived from URL | Output folder path |
|
||||
| `max_pages` | integer | No | 50 | Max HTML pages (1-1000) |
|
||||
| `threads` | integer | No | 6 | Concurrent download threads (1-20) |
|
||||
| `download_external_assets` | boolean | No | False | Download CDN assets |
|
||||
| `external_domains` | array | No | None | Whitelist of external domains |
|
||||
|
||||
#### Return Value
|
||||
|
||||
```python
|
||||
{
|
||||
"success": True/False,
|
||||
"message": "Human-readable summary",
|
||||
"stats": {
|
||||
"pages_crawled": int,
|
||||
"assets_downloaded": int,
|
||||
"failed_downloads": int,
|
||||
"elapsed_seconds": float,
|
||||
"output_directory": str,
|
||||
"pages": [...], # List of downloaded pages
|
||||
"downloaded_items": [...] # List of downloaded assets
|
||||
},
|
||||
"output_directory": "/path/to/downloaded/site"
|
||||
}
|
||||
```
|
||||
|
||||
### Website Downloader CLI
|
||||
|
||||
The original `website-downloader.py` can still be used as a standalone CLI tool:
|
||||
|
||||
```bash
|
||||
python website-downloader.py --url https://example.com --max-pages 50 --threads 6
|
||||
```
|
||||
|
||||
#### CLI Options
|
||||
|
||||
- `--url`: Starting URL to crawl (required)
|
||||
- `--destination`: Output folder (optional, derived from URL if not provided)
|
||||
- `--max-pages`: Maximum pages to crawl (default: 50)
|
||||
- `--threads`: Number of download threads (default: 6)
|
||||
- `--download-external-assets`: Enable external asset downloading
|
||||
- `--external-domains`: Whitelist of external domains to download from
|
||||
|
||||
## Installation
|
||||
### 1. Install Dependencies
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. Configure Environment
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
# Edit .env and add your ZAI_API_KEY
|
||||
```
|
||||
|
||||
### 3. Run the Server
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
The server will start on `http://0.0.0.0:8000`
|
||||
|
||||
### 4. Use with Open WebUI
|
||||
|
||||
1. Open Open WebUI settings
|
||||
2. Add a new OpenAI-compatible connection
|
||||
3. Set the base URL to `http://your-server:8000/v1`
|
||||
4. Leave the API key empty or use any value (not validated)
|
||||
5. Select the "DocRAG-GLM-4.7" model
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### OpenAI-Compatible Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/v1/chat/completions` | POST | Chat completions (streaming supported) |
|
||||
| `/v1/models` | GET | List available models |
|
||||
| `/v1/models/{model_id}` | GET | Get model information |
|
||||
|
||||
### Document Management Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/v1/documents` | GET | List documents in knowledge base |
|
||||
| `/v1/documents/upload` | POST | Upload a document |
|
||||
| `/v1/documents/url` | POST | Add document from URL |
|
||||
| `/v1/documents/{doc_id}` | DELETE | Delete a document |
|
||||
|
||||
### Health & Status
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/health` | GET | Health check |
|
||||
| `/` | GET | API information |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Chat Completion
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "DocRAG-GLM-4.7",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is machine learning?"}
|
||||
],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
### Upload Document
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/documents/upload \
|
||||
-F "file=@document.pdf"
|
||||
```
|
||||
|
||||
### Add Document from URL
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8000/v1/documents/url \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"url": "https://example.com/article.html"}'
|
||||
```
|
||||
|
||||
### Python Client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
api_key="not-needed" # API key not validated
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="DocRAG-GLM-4.7",
|
||||
messages=[
|
||||
{"role": "user", "content": "Explain quantum computing"}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
if chunk.choices[0].delta.content:
|
||||
print(chunk.choices[0].delta.content, end="")
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Configure via environment variables or `.env` file:
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `HOST` | `0.0.0.0` | Server host |
|
||||
| `PORT` | `8000` | Server port |
|
||||
| `DEBUG` | `false` | Enable debug mode |
|
||||
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
|
||||
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
|
||||
| `ZAI_API_KEY` | (required) | API key for ZAI SDK |
|
||||
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
|
||||
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
|
||||
| `DOCUMENTS_PATH` | `./data/documents` | Document storage |
|
||||
| `CHUNK_SIZE` | `1000` | Document chunk size |
|
||||
| `CHUNK_OVERLAP` | `200` | Chunk overlap |
|
||||
| `TOP_K_RESULTS` | `5` | Number of context results |
|
||||
| `ENABLE_TOOLS` | `true` | Enable tool support |
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
docrag/
|
||||
├── website-downloader.py # Core website downloader (CLI)
|
||||
├── main.py # FastAPI application entry point
|
||||
├── rag/
|
||||
│ ├── __init__.py # RAG system main class
|
||||
│ ├── document_processor.py # Document parsing and chunking
|
||||
│ ├── vector_store.py # Vector storage and search
|
||||
│ └── retriever.py # Context retrieval logic
|
||||
├── tools/
|
||||
│ └── __init__.py # Tool management (website_downloader, etc.)
|
||||
├── website-downloader.py # CLI website downloader
|
||||
├── website_downloader_tool.py # Tool wrapper for GLM-4.7-Flash
|
||||
├── requirements.txt # Python dependencies
|
||||
├── .env.example # Configuration template
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Integration with RAG
|
||||
## How It Works
|
||||
|
||||
The downloaded website content can be processed for RAG systems:
|
||||
### Request Flow
|
||||
|
||||
1. Use the tool to download website content
|
||||
2. Parse the downloaded HTML files
|
||||
3. Extract text content and metadata
|
||||
4. Chunk and embed the content
|
||||
5. Store in your vector database
|
||||
1. **User sends message** → OpenAI-compatible endpoint receives request
|
||||
2. **RAG Retrieval** → Query is processed and relevant context is retrieved
|
||||
3. **Context Enhancement** → Retrieved context is added to the prompt
|
||||
4. **Tool Execution** → If needed, tools are invoked (e.g., website_downloader)
|
||||
5. **LLM Generation** → GLM-4.7-Flash generates response with context
|
||||
6. **Response** → User receives response (streaming supported)
|
||||
|
||||
### RAG Pipeline
|
||||
|
||||
```
|
||||
User Query
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ Query Processor │
|
||||
└────────┬────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ Vector Search │ ← Knowledge Base
|
||||
└────────┬────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ Context Builder │
|
||||
└────────┬────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────┐
|
||||
│ GLM-4.7-Flash │
|
||||
└────────┬────────┘
|
||||
│
|
||||
▼
|
||||
Response
|
||||
```
|
||||
|
||||
## Supported Document Formats
|
||||
|
||||
- **Text**: `.txt`, `.md`, `.rst`, `.log`
|
||||
- **Documents**: `.pdf`, `.docx`
|
||||
- **Web**: `.html`, `.htm`
|
||||
- **Data**: `.json`, `.yaml`, `.yml`, `.xml`, `.toml`, `.csv`, `.tsv`
|
||||
- **Code**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, etc.
|
||||
|
||||
## Extending
|
||||
|
||||
### Adding New Tools
|
||||
|
||||
```python
|
||||
# In tools/__init__.py
|
||||
|
||||
def my_custom_tool(param1: str, param2: int = 10) -> dict:
|
||||
"""Your tool implementation."""
|
||||
return {"result": "success"}
|
||||
|
||||
# Register the tool
|
||||
tool_manager.register_tool(
|
||||
name="my_custom_tool",
|
||||
function=my_custom_tool,
|
||||
schema={
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "my_custom_tool",
|
||||
"description": "Description of your tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {"type": "string", "description": "..."},
|
||||
"param2": {"type": "integer", "description": "...", "default": 10}
|
||||
},
|
||||
"required": ["param1"]
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Using Different Vector Stores
|
||||
|
||||
The default implementation uses a simple file-based store. To use ChromaDB:
|
||||
|
||||
1. Install: `pip install chromadb`
|
||||
2. Modify `rag/vector_store.py` to use ChromaDB client
|
||||
|
||||
## Development
|
||||
|
||||
### Running in Development Mode
|
||||
|
||||
```bash
|
||||
DEBUG=true python main.py
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
pip install pytest pytest-asyncio
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
|
||||
732
main.py
732
main.py
@ -1 +1,731 @@
|
||||
# build out the full app the hides the fact that you are a rag app and just looks like a normal openAI model as fare as a user is concerned when interacting with it with open-webui. but in reality it is a rag app that is doing all the work in the background then provides the information to GLM 4.7-flash.
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
DocRAG - OpenAI-Compatible RAG Server
|
||||
|
||||
This application presents itself as a standard OpenAI API server that can be used
|
||||
with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it:
|
||||
1. Processes user queries through a RAG system
|
||||
2. Retrieves relevant context from a knowledge base
|
||||
3. Passes the enriched context to GLM-4.7-Flash for response generation
|
||||
4. Optionally uses tools like website_downloader for enhanced capabilities
|
||||
|
||||
The user sees a normal chat experience, but the system is actually doing
|
||||
sophisticated RAG operations in the background.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator, Optional
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# FastAPI imports
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Import RAG components
|
||||
from rag import RAGSystem, get_rag_system
|
||||
from rag.document_processor import DocumentProcessor
|
||||
|
||||
# Import tools
|
||||
from tools import ToolManager, get_tool_manager
|
||||
|
||||
# Import SDK for GLM-4.7-Flash
|
||||
try:
|
||||
from zai import ZaiClient as ZAI
|
||||
except ImportError:
|
||||
ZAI = None
|
||||
log.warning("z-ai-web-dev-sdk not installed. Install with: pip install z-ai-web-dev-sdk")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class Config:
|
||||
"""Application configuration from environment variables."""
|
||||
|
||||
# Server settings
|
||||
HOST: str = os.getenv("HOST", "0.0.0.0")
|
||||
PORT: int = int(os.getenv("PORT", "8000"))
|
||||
DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true"
|
||||
|
||||
# Model settings
|
||||
MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG-GLM-4.7")
|
||||
UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "glm-4.7")
|
||||
|
||||
# API Key for upstream LLM
|
||||
ZAI_API_KEY: str = os.getenv("ZAI_API_KEY", "")
|
||||
|
||||
# RAG settings
|
||||
EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
|
||||
VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors")
|
||||
DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents")
|
||||
CHUNK_SIZE: int = int(os.getenv("CHUNK_SIZE", "1000"))
|
||||
CHUNK_OVERLAP: int = int(os.getenv("CHUNK_OVERLAP", "200"))
|
||||
TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5"))
|
||||
|
||||
# Tool settings
|
||||
ENABLE_TOOLS: bool = os.getenv("ENABLE_TOOLS", "true").lower() == "true"
|
||||
MAX_TOOL_ITERATIONS: int = int(os.getenv("MAX_TOOL_ITERATIONS", "3"))
|
||||
|
||||
|
||||
config = Config()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI-Compatible Models
|
||||
# =============================================================================
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""OpenAI chat message format."""
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[list[dict]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""OpenAI chat completion request format."""
|
||||
model: str = "DocRAG-GLM-4.7"
|
||||
messages: list[ChatMessage]
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
n: Optional[int] = 1
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[list[str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = 0.0
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
user: Optional[str] = None
|
||||
tools: Optional[list[dict]] = None
|
||||
tool_choice: Optional[str | dict] = None
|
||||
|
||||
|
||||
class ChatCompletionChoice(BaseModel):
|
||||
"""OpenAI chat completion choice."""
|
||||
index: int = 0
|
||||
message: ChatMessage
|
||||
finish_reason: str = "stop"
|
||||
|
||||
|
||||
class ChatCompletionUsage(BaseModel):
|
||||
"""Token usage statistics."""
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
"""OpenAI chat completion response format."""
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}")
|
||||
object: str = "chat.completion"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
model: str = config.MODEL_NAME
|
||||
choices: list[ChatCompletionChoice]
|
||||
usage: ChatCompletionUsage = ChatCompletionUsage()
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""OpenAI model info format."""
|
||||
id: str
|
||||
object: str = "model"
|
||||
created: int = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: str = "organization"
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
"""OpenAI model list format."""
|
||||
object: str = "list"
|
||||
data: list[ModelInfo]
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Application State
|
||||
# =============================================================================
|
||||
|
||||
class AppState:
|
||||
"""Global application state."""
|
||||
rag_system: Optional[RAGSystem] = None
|
||||
tool_manager: Optional[ToolManager] = None
|
||||
zai_client: Any = None
|
||||
startup_time: float = time.time()
|
||||
|
||||
|
||||
state = AppState()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Lifespan Management
|
||||
# =============================================================================
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage application lifespan - startup and shutdown."""
|
||||
log.info("Starting DocRAG server...")
|
||||
|
||||
# Initialize RAG system
|
||||
try:
|
||||
log.info("Initializing RAG system...")
|
||||
state.rag_system = await get_rag_system(
|
||||
embedding_model=config.EMBEDDING_MODEL,
|
||||
vector_store_path=config.VECTOR_STORE_PATH,
|
||||
documents_path=config.DOCUMENTS_PATH,
|
||||
chunk_size=config.CHUNK_SIZE,
|
||||
chunk_overlap=config.CHUNK_OVERLAP,
|
||||
)
|
||||
log.info("RAG system initialized successfully")
|
||||
except Exception as e:
|
||||
log.warning(f"RAG system initialization deferred: {e}")
|
||||
state.rag_system = None
|
||||
|
||||
# Initialize tool manager
|
||||
try:
|
||||
log.info("Initializing tool manager...")
|
||||
state.tool_manager = get_tool_manager()
|
||||
log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}")
|
||||
except Exception as e:
|
||||
log.warning(f"Tool manager initialization failed: {e}")
|
||||
state.tool_manager = None
|
||||
|
||||
# Initialize ZAI client for upstream LLM
|
||||
try:
|
||||
if config.ZAI_API_KEY and ZAI is not None:
|
||||
log.info("Initializing ZAI client...")
|
||||
state.zai_client = ZAI(api_key=config.ZAI_API_KEY)
|
||||
log.info("ZAI client initialized successfully")
|
||||
else:
|
||||
log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses")
|
||||
state.zai_client = None
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize ZAI client: {e}")
|
||||
state.zai_client = None
|
||||
|
||||
log.info(f"DocRAG server started on {config.HOST}:{config.PORT}")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
log.info("Shutting down DocRAG server...")
|
||||
if state.rag_system:
|
||||
await state.rag_system.close()
|
||||
log.info("DocRAG server stopped")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FastAPI Application
|
||||
# =============================================================================
|
||||
|
||||
app = FastAPI(
|
||||
title="DocRAG API",
|
||||
description="OpenAI-compatible RAG server powered by GLM-4.7-Flash",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware for Open WebUI compatibility
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI-Compatible Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@app.get("/v1/models")
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
"""List available models (OpenAI-compatible)."""
|
||||
return ModelList(
|
||||
data=[
|
||||
ModelInfo(id=config.MODEL_NAME, owned_by="docrag"),
|
||||
ModelInfo(id="DocRAG-GLM-4.7", owned_by="docrag"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@app.get("/v1/models/{model_id}")
|
||||
@app.get("/models/{model_id}")
|
||||
async def get_model(model_id: str):
|
||||
"""Get model information (OpenAI-compatible)."""
|
||||
if model_id not in [config.MODEL_NAME, "DocRAG-GLM-4.7"]:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
return ModelInfo(id=model_id, owned_by="docrag")
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@app.post("/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
"""Handle chat completions (OpenAI-compatible)."""
|
||||
request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
|
||||
|
||||
try:
|
||||
if request.stream:
|
||||
return StreamingResponse(
|
||||
stream_chat_completion(request, request_id),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
return await complete_chat(request, request_id)
|
||||
except Exception as e:
|
||||
log.exception("Chat completion failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Chat Completion Logic
|
||||
# =============================================================================
|
||||
|
||||
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
|
||||
"""Process a non-streaming chat completion request."""
|
||||
messages = request.messages
|
||||
|
||||
# Extract the last user message
|
||||
user_message = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.role == "user" and msg.content:
|
||||
user_message = msg.content
|
||||
break
|
||||
|
||||
if not user_message:
|
||||
raise HTTPException(status_code=400, detail="No user message found")
|
||||
|
||||
# Step 1: RAG Retrieval
|
||||
context = ""
|
||||
sources = []
|
||||
if state.rag_system:
|
||||
try:
|
||||
rag_result = await state.rag_system.query(
|
||||
query=user_message,
|
||||
top_k=config.TOP_K_RESULTS,
|
||||
)
|
||||
context = rag_result.get("context", "")
|
||||
sources = rag_result.get("sources", [])
|
||||
log.info(f"RAG retrieved {len(sources)} relevant documents")
|
||||
except Exception as e:
|
||||
log.warning(f"RAG retrieval failed: {e}")
|
||||
|
||||
# Step 2: Build enhanced prompt with context
|
||||
enhanced_messages = build_enhanced_messages(messages, context, sources)
|
||||
|
||||
# Step 3: Check for tool usage
|
||||
tool_calls_made = []
|
||||
if config.ENABLE_TOOLS and state.tool_manager:
|
||||
tool_calls_made = await check_and_execute_tools(
|
||||
user_message, enhanced_messages, request.tools
|
||||
)
|
||||
|
||||
# Step 4: Generate response with upstream LLM
|
||||
response_content = await generate_response(
|
||||
enhanced_messages,
|
||||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
tools=request.tools,
|
||||
)
|
||||
|
||||
# Step 5: Build and return response
|
||||
return ChatCompletionResponse(
|
||||
id=request_id,
|
||||
model=config.MODEL_NAME,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content=response_content,
|
||||
tool_calls=tool_calls_made if tool_calls_made else None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=ChatCompletionUsage(
|
||||
prompt_tokens=len(str(enhanced_messages)) // 4,
|
||||
completion_tokens=len(response_content) // 4,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def stream_chat_completion(
|
||||
request: ChatCompletionRequest,
|
||||
request_id: str,
|
||||
) -> AsyncIterator[str]:
|
||||
"""Stream a chat completion response."""
|
||||
messages = request.messages
|
||||
|
||||
# Extract the last user message
|
||||
user_message = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.role == "user" and msg.content:
|
||||
user_message = msg.content
|
||||
break
|
||||
|
||||
if not user_message:
|
||||
yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
|
||||
return
|
||||
|
||||
# Step 1: RAG Retrieval
|
||||
context = ""
|
||||
sources = []
|
||||
if state.rag_system:
|
||||
try:
|
||||
rag_result = await state.rag_system.query(
|
||||
query=user_message,
|
||||
top_k=config.TOP_K_RESULTS,
|
||||
)
|
||||
context = rag_result.get("context", "")
|
||||
sources = rag_result.get("sources", [])
|
||||
log.info(f"RAG retrieved {len(sources)} relevant documents")
|
||||
except Exception as e:
|
||||
log.warning(f"RAG retrieval failed: {e}")
|
||||
|
||||
# Step 2: Build enhanced prompt with context
|
||||
enhanced_messages = build_enhanced_messages(messages, context, sources)
|
||||
|
||||
# Step 3: Stream response from upstream LLM
|
||||
created = int(time.time())
|
||||
|
||||
try:
|
||||
if state.zai_client:
|
||||
# Use actual GLM-4.7-Flash
|
||||
response = state.zai_client.chat.completions.create(
|
||||
model=config.UPSTREAM_MODEL,
|
||||
messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
|
||||
temperature=request.temperature or 0.7,
|
||||
max_tokens=request.max_tokens or 4096,
|
||||
stream=True,
|
||||
thinking={"type": "enabled"},
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
# Handle reasoning content (thinking)
|
||||
if hasattr(chunk.choices[0].delta, 'reasoning_content') and chunk.choices[0].delta.reasoning_content:
|
||||
# Don't expose thinking to user - this is internal RAG processing
|
||||
log.debug(f"Thinking: {chunk.choices[0].delta.reasoning_content[:100]}...")
|
||||
continue
|
||||
|
||||
# Stream actual content
|
||||
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
yield f"data: {json.dumps({
|
||||
'id': request_id,
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': created,
|
||||
'model': config.MODEL_NAME,
|
||||
'choices': [{
|
||||
'index': 0,
|
||||
'delta': {'content': content},
|
||||
'finish_reason': None
|
||||
}]
|
||||
})}\n\n"
|
||||
|
||||
# Send final chunk
|
||||
yield f"data: {json.dumps({
|
||||
'id': request_id,
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': created,
|
||||
'model': config.MODEL_NAME,
|
||||
'choices': [{
|
||||
'index': 0,
|
||||
'delta': {},
|
||||
'finish_reason': 'stop'
|
||||
}]
|
||||
})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
else:
|
||||
# Mock streaming response for testing
|
||||
mock_response = f"I understand you're asking about: {user_message}\n\n"
|
||||
if context:
|
||||
mock_response += f"Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
|
||||
mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."
|
||||
|
||||
for char in mock_response:
|
||||
yield f"data: {json.dumps({
|
||||
'id': request_id,
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': created,
|
||||
'model': config.MODEL_NAME,
|
||||
'choices': [{
|
||||
'index': 0,
|
||||
'delta': {'content': char},
|
||||
'finish_reason': None
|
||||
}]
|
||||
})}\n\n"
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
yield f"data: {json.dumps({
|
||||
'id': request_id,
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': created,
|
||||
'model': config.MODEL_NAME,
|
||||
'choices': [{
|
||||
'index': 0,
|
||||
'delta': {},
|
||||
'finish_reason': 'stop'
|
||||
}]
|
||||
})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
log.exception("Streaming failed")
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
|
||||
def build_enhanced_messages(
|
||||
messages: list[ChatMessage],
|
||||
context: str,
|
||||
sources: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Build enhanced messages with RAG context."""
|
||||
enhanced = []
|
||||
|
||||
# Add system message with RAG context
|
||||
system_content = (
|
||||
"You are a helpful AI assistant with access to a knowledge base. "
|
||||
"Use the provided context to give accurate and helpful responses. "
|
||||
"If the context doesn't contain relevant information, use your general knowledge "
|
||||
"but indicate when you're doing so."
|
||||
)
|
||||
|
||||
if context:
|
||||
system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
|
||||
if sources:
|
||||
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources)
|
||||
|
||||
enhanced.append(ChatMessage(role="system", content=system_content))
|
||||
|
||||
# Add conversation history (excluding old system messages)
|
||||
for msg in messages:
|
||||
if msg.role != "system":
|
||||
enhanced.append(msg)
|
||||
|
||||
return enhanced
|
||||
|
||||
|
||||
async def generate_response(
|
||||
messages: list[ChatMessage],
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
tools: Optional[list[dict]] = None,
|
||||
) -> str:
|
||||
"""Generate response using upstream LLM."""
|
||||
if state.zai_client:
|
||||
try:
|
||||
response = state.zai_client.chat.completions.create(
|
||||
model=config.UPSTREAM_MODEL,
|
||||
messages=[{"role": m.role, "content": m.content} for m in messages if m.content],
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False,
|
||||
thinking={"type": "enabled"},
|
||||
)
|
||||
|
||||
# Extract content from response
|
||||
content = ""
|
||||
for chunk in response:
|
||||
if hasattr(chunk.choices[0], 'message') and chunk.choices[0].message:
|
||||
content = chunk.choices[0].message.content or ""
|
||||
break
|
||||
if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content:
|
||||
content += chunk.choices[0].delta.content
|
||||
|
||||
return content or "I apologize, but I couldn't generate a response."
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Upstream LLM call failed: {e}")
|
||||
return f"I encountered an error: {str(e)}"
|
||||
|
||||
else:
|
||||
# Mock response for testing
|
||||
user_msg = ""
|
||||
for msg in reversed(messages):
|
||||
if msg.role == "user" and msg.content:
|
||||
user_msg = msg.content
|
||||
break
|
||||
return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."
|
||||
|
||||
|
||||
async def check_and_execute_tools(
|
||||
user_message: str,
|
||||
messages: list[ChatMessage],
|
||||
available_tools: Optional[list[dict]],
|
||||
) -> list[dict]:
|
||||
"""Check if tools should be used and execute them."""
|
||||
if not state.tool_manager or not available_tools:
|
||||
return []
|
||||
|
||||
tool_calls = []
|
||||
|
||||
# Simple keyword-based tool detection
|
||||
# In a production system, you'd use the LLM to decide tool usage
|
||||
message_lower = user_message.lower()
|
||||
|
||||
# Check for website download intent
|
||||
if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site"]):
|
||||
# Extract URL from message
|
||||
import re
|
||||
url_pattern = r'https?://[^\s]+'
|
||||
urls = re.findall(url_pattern, user_message)
|
||||
|
||||
if urls:
|
||||
tool_result = state.tool_manager.execute_tool(
|
||||
"website_downloader",
|
||||
{"url": urls[0], "max_pages": 10}
|
||||
)
|
||||
tool_calls.append({
|
||||
"id": f"call_{uuid.uuid4().hex[:24]}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "website_downloader",
|
||||
"arguments": json.dumps({"url": urls[0]}),
|
||||
}
|
||||
})
|
||||
log.info(f"Executed website_downloader tool: {tool_result}")
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Document Management Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@app.post("/v1/documents/upload")
|
||||
async def upload_document(request: Request):
|
||||
"""Upload a document to the knowledge base."""
|
||||
if not state.rag_system:
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
try:
|
||||
form = await request.form()
|
||||
file = form.get("file")
|
||||
if not file:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
content = await file.read()
|
||||
filename = file.filename or "unknown"
|
||||
|
||||
# Process and store document
|
||||
result = await state.rag_system.add_document(
|
||||
content=content,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)}
|
||||
|
||||
except Exception as e:
|
||||
log.exception("Document upload failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/documents/url")
|
||||
async def add_document_from_url(request: dict):
|
||||
"""Add a document from URL to the knowledge base."""
|
||||
if not state.rag_system:
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
url = request.get("url")
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="No URL provided")
|
||||
|
||||
try:
|
||||
result = await state.rag_system.add_document_from_url(url)
|
||||
return {"success": True, "message": f"Document from {url} added", "chunks": result.get("chunks", 0)}
|
||||
except Exception as e:
|
||||
log.exception("URL document addition failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/v1/documents")
|
||||
async def list_documents():
|
||||
"""List documents in the knowledge base."""
|
||||
if not state.rag_system:
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
try:
|
||||
docs = await state.rag_system.list_documents()
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
log.exception("Document listing failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/v1/documents/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""Delete a document from the knowledge base."""
|
||||
if not state.rag_system:
|
||||
raise HTTPException(status_code=503, detail="RAG system not initialized")
|
||||
|
||||
try:
|
||||
await state.rag_system.delete_document(doc_id)
|
||||
return {"success": True, "message": f"Document {doc_id} deleted"}
|
||||
except Exception as e:
|
||||
log.exception("Document deletion failed")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Health and Status Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"uptime": time.time() - state.startup_time,
|
||||
"rag_enabled": state.rag_system is not None,
|
||||
"tools_enabled": state.tool_manager is not None,
|
||||
"llm_connected": state.zai_client is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint with API info."""
|
||||
return {
|
||||
"name": "DocRAG API",
|
||||
"version": "1.0.0",
|
||||
"description": "OpenAI-compatible RAG server powered by GLM-4.7-Flash",
|
||||
"endpoints": {
|
||||
"chat": "/v1/chat/completions",
|
||||
"models": "/v1/models",
|
||||
"documents": "/v1/documents",
|
||||
"health": "/health",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Entry Point
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"main:app",
|
||||
host=config.HOST,
|
||||
port=config.PORT,
|
||||
reload=config.DEBUG,
|
||||
)
|
||||
|
||||
252
rag/__init__.py
Normal file
252
rag/__init__.py
Normal file
@ -0,0 +1,252 @@
|
||||
"""
|
||||
RAG System - Retrieval Augmented Generation
|
||||
|
||||
This module provides the core RAG functionality for DocRAG, including:
|
||||
- Document processing and chunking
|
||||
- Vector storage and similarity search
|
||||
- Context retrieval for enhanced prompts
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from .document_processor import DocumentProcessor
|
||||
from .vector_store import VectorStore
|
||||
from .retriever import Retriever
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RAGSystem:
|
||||
"""
|
||||
Main RAG system that coordinates document processing, storage, and retrieval.
|
||||
|
||||
This class provides a unified interface for:
|
||||
- Adding documents to the knowledge base
|
||||
- Querying for relevant context
|
||||
- Managing the document lifecycle
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
vector_store_path: str = "./data/vectors",
|
||||
documents_path: str = "./data/documents",
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
):
|
||||
self.embedding_model = embedding_model
|
||||
self.vector_store_path = Path(vector_store_path)
|
||||
self.documents_path = Path(documents_path)
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
self._initialized = False
|
||||
self._document_processor: Optional[DocumentProcessor] = None
|
||||
self._vector_store: Optional[VectorStore] = None
|
||||
self._retriever: Optional[Retriever] = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the RAG system components."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
log.info("Initializing RAG system...")
|
||||
|
||||
# Create directories
|
||||
self.vector_store_path.mkdir(parents=True, exist_ok=True)
|
||||
self.documents_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize document processor
|
||||
self._document_processor = DocumentProcessor(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
)
|
||||
|
||||
# Initialize vector store
|
||||
self._vector_store = VectorStore(
|
||||
persist_directory=str(self.vector_store_path),
|
||||
embedding_model=self.embedding_model,
|
||||
)
|
||||
|
||||
# Initialize retriever
|
||||
self._retriever = Retriever(
|
||||
vector_store=self._vector_store,
|
||||
)
|
||||
|
||||
self._initialized = True
|
||||
log.info("RAG system initialized successfully")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the RAG system and release resources."""
|
||||
if self._vector_store:
|
||||
await self._vector_store.close()
|
||||
self._initialized = False
|
||||
log.info("RAG system closed")
|
||||
|
||||
def _ensure_initialized(self) -> None:
|
||||
"""Ensure the RAG system is initialized."""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("RAG system not initialized. Call initialize() first.")
|
||||
|
||||
async def add_document(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Add a document to the knowledge base.
|
||||
|
||||
Args:
|
||||
content: Raw document content
|
||||
filename: Original filename
|
||||
metadata: Optional metadata
|
||||
|
||||
Returns:
|
||||
Dictionary with processing results
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
# Process document
|
||||
doc_info = await self._document_processor.process(
|
||||
content=content,
|
||||
filename=filename,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Store chunks in vector store
|
||||
if doc_info.get("chunks"):
|
||||
await self._vector_store.add_chunks(
|
||||
chunks=doc_info["chunks"],
|
||||
metadatas=doc_info.get("metadatas", []),
|
||||
ids=doc_info.get("ids", []),
|
||||
)
|
||||
|
||||
log.info(f"Added document '{filename}' with {len(doc_info.get('chunks', []))} chunks")
|
||||
return {"chunks": len(doc_info.get("chunks", [])), "document_id": doc_info.get("document_id")}
|
||||
|
||||
async def add_document_from_url(self, url: str) -> dict[str, Any]:
|
||||
"""
|
||||
Add a document from a URL to the knowledge base.
|
||||
|
||||
Args:
|
||||
url: URL to fetch and process
|
||||
|
||||
Returns:
|
||||
Dictionary with processing results
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
# Fetch content from URL
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=30) as response:
|
||||
response.raise_for_status()
|
||||
content = await response.read()
|
||||
|
||||
# Extract filename from URL
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(url)
|
||||
filename = os.path.basename(parsed.path) or "webpage.html"
|
||||
|
||||
return await self.add_document(content=content, filename=filename, metadata={"source_url": url})
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filter_metadata: Optional[dict] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Query the knowledge base for relevant context.
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
top_k: Number of results to return
|
||||
filter_metadata: Optional metadata filters
|
||||
|
||||
Returns:
|
||||
Dictionary with context and sources
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
# Retrieve relevant chunks
|
||||
results = await self._retriever.retrieve(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
filter_metadata=filter_metadata,
|
||||
)
|
||||
|
||||
# Build context string
|
||||
context_parts = []
|
||||
sources = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
context_parts.append(f"[{i+1}] {result['content']}")
|
||||
if result.get("metadata", {}).get("source"):
|
||||
sources.append(result["metadata"]["source"])
|
||||
|
||||
context = "\n\n".join(context_parts)
|
||||
|
||||
return {
|
||||
"context": context,
|
||||
"sources": list(set(sources)),
|
||||
"num_results": len(results),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
async def list_documents(self) -> list[dict[str, Any]]:
|
||||
"""List all documents in the knowledge base."""
|
||||
self._ensure_initialized()
|
||||
return await self._vector_store.list_documents()
|
||||
|
||||
async def delete_document(self, document_id: str) -> None:
|
||||
"""Delete a document from the knowledge base."""
|
||||
self._ensure_initialized()
|
||||
await self._vector_store.delete_document(document_id)
|
||||
log.info(f"Deleted document {document_id}")
|
||||
|
||||
|
||||
# Global RAG system instance
|
||||
_rag_system: Optional[RAGSystem] = None
|
||||
|
||||
|
||||
async def get_rag_system(
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
vector_store_path: str = "./data/vectors",
|
||||
documents_path: str = "./data/documents",
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
) -> RAGSystem:
|
||||
"""
|
||||
Get or create the global RAG system instance.
|
||||
|
||||
Args:
|
||||
embedding_model: Name of the embedding model
|
||||
vector_store_path: Path to vector store
|
||||
documents_path: Path to document storage
|
||||
chunk_size: Size of document chunks
|
||||
chunk_overlap: Overlap between chunks
|
||||
|
||||
Returns:
|
||||
Initialized RAGSystem instance
|
||||
"""
|
||||
global _rag_system
|
||||
|
||||
if _rag_system is None:
|
||||
_rag_system = RAGSystem(
|
||||
embedding_model=embedding_model,
|
||||
vector_store_path=vector_store_path,
|
||||
documents_path=documents_path,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
await _rag_system.initialize()
|
||||
|
||||
return _rag_system
|
||||
247
rag/document_processor.py
Normal file
247
rag/document_processor.py
Normal file
@ -0,0 +1,247 @@
|
||||
"""
|
||||
Document Processor - Handles document parsing and chunking
|
||||
|
||||
Supports multiple document formats:
|
||||
- Plain text (.txt, .md)
|
||||
- PDF (.pdf)
|
||||
- HTML (.html, .htm)
|
||||
- Word documents (.docx)
|
||||
- Code files (.py, .js, etc.)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""
|
||||
Process documents into chunks suitable for vector storage.
|
||||
|
||||
Handles:
|
||||
- Multiple file formats
|
||||
- Intelligent chunking with overlap
|
||||
- Metadata extraction
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = 1000,
|
||||
chunk_overlap: int = 200,
|
||||
):
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
async def process(
|
||||
self,
|
||||
content: bytes,
|
||||
filename: str,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a document into chunks.
|
||||
|
||||
Args:
|
||||
content: Raw document content
|
||||
filename: Original filename
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
Dictionary with chunks, metadatas, and ids
|
||||
"""
|
||||
# Extract text based on file type
|
||||
text = await self._extract_text(content, filename)
|
||||
|
||||
if not text.strip():
|
||||
return {"chunks": [], "metadatas": [], "ids": [], "document_id": None}
|
||||
|
||||
# Generate document ID
|
||||
document_id = str(uuid.uuid4())
|
||||
|
||||
# Create chunks
|
||||
chunks = self._chunk_text(text)
|
||||
|
||||
# Create metadata for each chunk
|
||||
base_metadata = {
|
||||
"source": filename,
|
||||
"document_id": document_id,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_id = hashlib.md5(f"{document_id}_{i}".encode()).hexdigest()
|
||||
ids.append(chunk_id)
|
||||
metadatas.append({
|
||||
**base_metadata,
|
||||
"chunk_index": i,
|
||||
"chunk_length": len(chunk),
|
||||
})
|
||||
|
||||
return {
|
||||
"chunks": chunks,
|
||||
"metadatas": metadatas,
|
||||
"ids": ids,
|
||||
"document_id": document_id,
|
||||
"total_chars": len(text),
|
||||
}
|
||||
|
||||
async def _extract_text(self, content: bytes, filename: str) -> str:
|
||||
"""Extract text from document based on file type."""
|
||||
ext = Path(filename).suffix.lower()
|
||||
|
||||
try:
|
||||
if ext in (".txt", ".md", ".rst", ".log"):
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
|
||||
elif ext == ".pdf":
|
||||
return await self._extract_pdf(content)
|
||||
|
||||
elif ext in (".html", ".htm"):
|
||||
return await self._extract_html(content)
|
||||
|
||||
elif ext == ".docx":
|
||||
return await self._extract_docx(content)
|
||||
|
||||
elif ext in (".py", ".js", ".ts", ".java", ".cpp", ".c", ".go", ".rs", ".rb", ".php", ".cs", ".swift", ".kt"):
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
|
||||
elif ext in (".json", ".yaml", ".yml", ".xml", ".toml"):
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
|
||||
elif ext in (".csv", ".tsv"):
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
|
||||
else:
|
||||
# Try to decode as text
|
||||
try:
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
except Exception:
|
||||
log.warning(f"Unknown file type: {ext}, treating as binary")
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to extract text from {filename}: {e}")
|
||||
return ""
|
||||
|
||||
async def _extract_pdf(self, content: bytes) -> str:
|
||||
"""Extract text from PDF."""
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
doc = fitz.open(stream=content, filetype="pdf")
|
||||
text_parts = []
|
||||
for page in doc:
|
||||
text_parts.append(page.get_text())
|
||||
doc.close()
|
||||
return "\n\n".join(text_parts)
|
||||
except ImportError:
|
||||
log.warning("PyMuPDF not installed, PDF extraction unavailable")
|
||||
return ""
|
||||
except Exception as e:
|
||||
log.error(f"PDF extraction failed: {e}")
|
||||
return ""
|
||||
|
||||
async def _extract_html(self, content: bytes) -> str:
|
||||
"""Extract text from HTML."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
# Remove script and style elements
|
||||
for element in soup(["script", "style", "nav", "footer", "header"]):
|
||||
element.decompose()
|
||||
return soup.get_text(separator="\n", strip=True)
|
||||
except ImportError:
|
||||
log.warning("BeautifulSoup not installed, HTML extraction unavailable")
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
except Exception as e:
|
||||
log.error(f"HTML extraction failed: {e}")
|
||||
return ""
|
||||
|
||||
async def _extract_docx(self, content: bytes) -> str:
|
||||
"""Extract text from DOCX."""
|
||||
try:
|
||||
import io
|
||||
from docx import Document
|
||||
doc = Document(io.BytesIO(content))
|
||||
return "\n\n".join(para.text for para in doc.paragraphs)
|
||||
except ImportError:
|
||||
log.warning("python-docx not installed, DOCX extraction unavailable")
|
||||
return ""
|
||||
except Exception as e:
|
||||
log.error(f"DOCX extraction failed: {e}")
|
||||
return ""
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""
|
||||
Split text into overlapping chunks.
|
||||
|
||||
Uses a sentence-aware chunking strategy to avoid breaking mid-sentence.
|
||||
"""
|
||||
if len(text) <= self.chunk_size:
|
||||
return [text.strip()] if text.strip() else []
|
||||
|
||||
# Split into sentences
|
||||
sentences = self._split_sentences(text)
|
||||
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
for sentence in sentences:
|
||||
sentence_length = len(sentence)
|
||||
|
||||
# If adding this sentence would exceed chunk size
|
||||
if current_length + sentence_length > self.chunk_size and current_chunk:
|
||||
# Save current chunk
|
||||
chunks.append(" ".join(current_chunk))
|
||||
|
||||
# Start new chunk with overlap
|
||||
overlap_text = self._get_overlap_text(current_chunk)
|
||||
current_chunk = [overlap_text, sentence] if overlap_text else [sentence]
|
||||
current_length = len(" ".join(current_chunk))
|
||||
else:
|
||||
current_chunk.append(sentence)
|
||||
current_length += sentence_length + 1 # +1 for space
|
||||
|
||||
# Add final chunk
|
||||
if current_chunk:
|
||||
chunks.append(" ".join(current_chunk))
|
||||
|
||||
return [c.strip() for c in chunks if c.strip()]
|
||||
|
||||
def _split_sentences(self, text: str) -> list[str]:
|
||||
"""Split text into sentences."""
|
||||
# Simple sentence splitting - can be improved with NLP libraries
|
||||
sentence_endings = r'(?<=[.!?])\s+'
|
||||
sentences = re.split(sentence_endings, text)
|
||||
return [s.strip() for s in sentences if s.strip()]
|
||||
|
||||
def _get_overlap_text(self, chunk_parts: list[str]) -> str:
|
||||
"""Get text for overlap from the end of the current chunk."""
|
||||
if not chunk_parts:
|
||||
return ""
|
||||
|
||||
full_text = " ".join(chunk_parts)
|
||||
|
||||
if len(full_text) <= self.chunk_overlap:
|
||||
return full_text
|
||||
|
||||
# Get last N characters
|
||||
overlap = full_text[-self.chunk_overlap:]
|
||||
|
||||
# Try to start at a word boundary
|
||||
space_idx = overlap.find(" ")
|
||||
if space_idx > 0:
|
||||
overlap = overlap[space_idx + 1:]
|
||||
|
||||
return overlap
|
||||
176
rag/retriever.py
Normal file
176
rag/retriever.py
Normal file
@ -0,0 +1,176 @@
|
||||
"""
|
||||
Retriever - Handles context retrieval from the vector store
|
||||
|
||||
Provides intelligent retrieval with:
|
||||
- Query optimization
|
||||
- Result ranking
|
||||
- Context windowing
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from .vector_store import VectorStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Retriever:
|
||||
"""
|
||||
Retriever for fetching relevant context from the vector store.
|
||||
|
||||
Handles:
|
||||
- Query preprocessing
|
||||
- Similarity search
|
||||
- Result ranking and filtering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: VectorStore,
|
||||
default_top_k: int = 5,
|
||||
min_score: float = 0.0,
|
||||
):
|
||||
self.vector_store = vector_store
|
||||
self.default_top_k = default_top_k
|
||||
self.min_score = min_score
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: Optional[int] = None,
|
||||
filter_metadata: Optional[dict] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve relevant chunks for a query.
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
top_k: Number of results (uses default if not provided)
|
||||
filter_metadata: Optional metadata filters
|
||||
|
||||
Returns:
|
||||
List of relevant chunks with scores
|
||||
"""
|
||||
top_k = top_k or self.default_top_k
|
||||
|
||||
# Preprocess query
|
||||
processed_query = self._preprocess_query(query)
|
||||
|
||||
# Search vector store
|
||||
results = await self.vector_store.search(
|
||||
query=processed_query,
|
||||
top_k=top_k * 2, # Get more results for filtering
|
||||
filter_metadata=filter_metadata,
|
||||
)
|
||||
|
||||
# Filter by minimum score
|
||||
results = [r for r in results if r["score"] >= self.min_score]
|
||||
|
||||
# Rank and deduplicate
|
||||
results = self._rank_results(results, query)
|
||||
|
||||
# Return top_k
|
||||
return results[:top_k]
|
||||
|
||||
def _preprocess_query(self, query: str) -> str:
|
||||
"""
|
||||
Preprocess query for better retrieval.
|
||||
|
||||
- Remove extra whitespace
|
||||
- Handle special characters
|
||||
- Normalize case
|
||||
"""
|
||||
# Remove extra whitespace
|
||||
query = " ".join(query.split())
|
||||
|
||||
# Remove question marks and other punctuation that might hurt matching
|
||||
query = query.replace("?", " ").replace("!", " ")
|
||||
|
||||
# Normalize whitespace again
|
||||
query = " ".join(query.split())
|
||||
|
||||
return query.strip()
|
||||
|
||||
def _rank_results(
|
||||
self,
|
||||
results: list[dict[str, Any]],
|
||||
query: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Rank results by relevance.
|
||||
|
||||
Uses a combination of:
|
||||
- Vector similarity score
|
||||
- Keyword matching
|
||||
- Document diversity
|
||||
"""
|
||||
if not results:
|
||||
return results
|
||||
|
||||
# Calculate additional scores
|
||||
query_words = set(query.lower().split())
|
||||
|
||||
for result in results:
|
||||
content = result["content"].lower()
|
||||
content_words = set(content.split())
|
||||
|
||||
# Keyword overlap score
|
||||
overlap = len(query_words & content_words)
|
||||
keyword_score = overlap / max(len(query_words), 1)
|
||||
|
||||
# Combine scores
|
||||
result["combined_score"] = (
|
||||
result["score"] * 0.7 + # Vector similarity
|
||||
keyword_score * 0.3 # Keyword matching
|
||||
)
|
||||
|
||||
# Sort by combined score
|
||||
results.sort(key=lambda x: x["combined_score"], reverse=True)
|
||||
|
||||
# Remove duplicate content (keep highest scoring)
|
||||
seen_content = set()
|
||||
unique_results = []
|
||||
|
||||
for result in results:
|
||||
# Use first 100 chars as content fingerprint
|
||||
content_fingerprint = result["content"][:100]
|
||||
|
||||
if content_fingerprint not in seen_content:
|
||||
seen_content.add(content_fingerprint)
|
||||
unique_results.append(result)
|
||||
|
||||
return unique_results
|
||||
|
||||
async def retrieve_with_context(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
context_window: int = 1,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve chunks with surrounding context.
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
top_k: Number of main results
|
||||
context_window: Number of adjacent chunks to include
|
||||
|
||||
Returns:
|
||||
Dictionary with expanded context
|
||||
"""
|
||||
results = await self.retrieve(query=query, top_k=top_k)
|
||||
|
||||
# For now, return basic results
|
||||
# In a full implementation, we'd expand to include adjacent chunks
|
||||
return {
|
||||
"results": results,
|
||||
"context": "\n\n".join(r["content"] for r in results),
|
||||
"sources": list(set(
|
||||
r.get("metadata", {}).get("source", "")
|
||||
for r in results
|
||||
if r.get("metadata", {}).get("source")
|
||||
)),
|
||||
}
|
||||
285
rag/vector_store.py
Normal file
285
rag/vector_store.py
Normal file
@ -0,0 +1,285 @@
|
||||
"""
|
||||
Vector Store - Handles vector storage and similarity search
|
||||
|
||||
Provides a simple file-based vector store that can be extended to use
|
||||
more sophisticated backends like ChromaDB, FAISS, or Pinecone.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
Vector store for document embeddings.
|
||||
|
||||
This implementation provides:
|
||||
- Simple file-based persistence
|
||||
- In-memory similarity search
|
||||
- Document management
|
||||
|
||||
Can be extended to use ChromaDB, FAISS, or other vector databases.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persist_directory: str = "./data/vectors",
|
||||
embedding_model: str = "text-embedding-3-small",
|
||||
):
|
||||
self.persist_directory = Path(persist_directory)
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
self._chunks: list[dict[str, Any]] = []
|
||||
self._embeddings: list[list[float]] = []
|
||||
self._metadata: list[dict[str, Any]] = []
|
||||
self._ids: list[str] = []
|
||||
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the vector store and load existing data."""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self.persist_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing data
|
||||
await self._load()
|
||||
|
||||
self._initialized = True
|
||||
log.info(f"Vector store initialized with {len(self._chunks)} chunks")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Save and close the vector store."""
|
||||
await self._save()
|
||||
log.info("Vector store closed")
|
||||
|
||||
async def _load(self) -> None:
|
||||
"""Load data from disk."""
|
||||
data_file = self.persist_directory / "store.json"
|
||||
|
||||
if not data_file.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
with open(data_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self._chunks = data.get("chunks", [])
|
||||
self._embeddings = data.get("embeddings", [])
|
||||
self._metadata = data.get("metadata", [])
|
||||
self._ids = data.get("ids", [])
|
||||
|
||||
log.info(f"Loaded {len(self._chunks)} chunks from disk")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to load vector store: {e}")
|
||||
|
||||
async def _save(self) -> None:
|
||||
"""Save data to disk."""
|
||||
data_file = self.persist_directory / "store.json"
|
||||
|
||||
try:
|
||||
data = {
|
||||
"chunks": self._chunks,
|
||||
"embeddings": self._embeddings,
|
||||
"metadata": self._metadata,
|
||||
"ids": self._ids,
|
||||
}
|
||||
|
||||
with open(data_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
log.info(f"Saved {len(self._chunks)} chunks to disk")
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to save vector store: {e}")
|
||||
|
||||
def _ensure_initialized(self) -> None:
|
||||
"""Ensure the vector store is initialized."""
|
||||
if not self._initialized:
|
||||
raise RuntimeError("Vector store not initialized")
|
||||
|
||||
async def add_chunks(
|
||||
self,
|
||||
chunks: list[str],
|
||||
metadatas: Optional[list[dict[str, Any]]] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add chunks to the vector store.
|
||||
|
||||
Args:
|
||||
chunks: List of text chunks
|
||||
metadatas: Optional list of metadata dicts
|
||||
ids: Optional list of chunk IDs
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
# Generate IDs if not provided
|
||||
if ids is None:
|
||||
ids = [hashlib.md5(chunk.encode()).hexdigest() for chunk in chunks]
|
||||
|
||||
# Generate metadata if not provided
|
||||
if metadatas is None:
|
||||
metadatas = [{}] * len(chunks)
|
||||
|
||||
# Generate embeddings
|
||||
embeddings = await self._generate_embeddings(chunks)
|
||||
|
||||
# Store everything
|
||||
for i, (chunk, embedding, metadata, chunk_id) in enumerate(
|
||||
zip(chunks, embeddings, metadatas, ids)
|
||||
):
|
||||
self._chunks.append({"id": chunk_id, "content": chunk})
|
||||
self._embeddings.append(embedding)
|
||||
self._metadata.append(metadata)
|
||||
self._ids.append(chunk_id)
|
||||
|
||||
# Save to disk
|
||||
await self._save()
|
||||
|
||||
log.info(f"Added {len(chunks)} chunks to vector store")
|
||||
|
||||
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for texts.
|
||||
|
||||
Uses a simple hash-based embedding for demonstration.
|
||||
In production, use a real embedding model via API.
|
||||
"""
|
||||
embeddings = []
|
||||
|
||||
for text in texts:
|
||||
# Simple hash-based embedding (for demo purposes)
|
||||
# In production, use OpenAI embeddings or similar
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
# Create a 384-dimensional embedding (common size)
|
||||
embedding = []
|
||||
for i in range(384):
|
||||
byte_idx = i % len(hash_bytes)
|
||||
value = (hash_bytes[byte_idx] - 128) / 128.0
|
||||
embedding.append(value)
|
||||
embeddings.append(embedding)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
filter_metadata: Optional[dict] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Search for similar chunks.
|
||||
|
||||
Args:
|
||||
query: Query string
|
||||
top_k: Number of results to return
|
||||
filter_metadata: Optional metadata filters
|
||||
|
||||
Returns:
|
||||
List of matching chunks with scores
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
if not self._chunks:
|
||||
return []
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = (await self._generate_embeddings([query]))[0]
|
||||
|
||||
# Calculate similarities
|
||||
results = []
|
||||
for i, (chunk, embedding, metadata) in enumerate(
|
||||
zip(self._chunks, self._embeddings, self._metadata)
|
||||
):
|
||||
# Apply metadata filter
|
||||
if filter_metadata:
|
||||
match = all(
|
||||
metadata.get(k) == v
|
||||
for k, v in filter_metadata.items()
|
||||
)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
# Calculate cosine similarity
|
||||
similarity = self._cosine_similarity(query_embedding, embedding)
|
||||
|
||||
results.append({
|
||||
"id": chunk["id"],
|
||||
"content": chunk["content"],
|
||||
"metadata": metadata,
|
||||
"score": similarity,
|
||||
})
|
||||
|
||||
# Sort by score and return top_k
|
||||
results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
if len(a) != len(b):
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(x * y for x, y in zip(a, b))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
|
||||
async def list_documents(self) -> list[dict[str, Any]]:
|
||||
"""List all unique documents in the store."""
|
||||
self._ensure_initialized()
|
||||
|
||||
# Group by document_id
|
||||
documents = {}
|
||||
for metadata in self._metadata:
|
||||
doc_id = metadata.get("document_id")
|
||||
if doc_id and doc_id not in documents:
|
||||
documents[doc_id] = {
|
||||
"id": doc_id,
|
||||
"source": metadata.get("source", "unknown"),
|
||||
"chunk_count": 1,
|
||||
}
|
||||
elif doc_id:
|
||||
documents[doc_id]["chunk_count"] += 1
|
||||
|
||||
return list(documents.values())
|
||||
|
||||
async def delete_document(self, document_id: str) -> None:
|
||||
"""Delete all chunks for a document."""
|
||||
self._ensure_initialized()
|
||||
|
||||
# Find indices to remove
|
||||
indices_to_remove = [
|
||||
i
|
||||
for i, metadata in enumerate(self._metadata)
|
||||
if metadata.get("document_id") == document_id
|
||||
]
|
||||
|
||||
# Remove in reverse order to maintain indices
|
||||
for i in sorted(indices_to_remove, reverse=True):
|
||||
self._chunks.pop(i)
|
||||
self._embeddings.pop(i)
|
||||
self._metadata.pop(i)
|
||||
self._ids.pop(i)
|
||||
|
||||
# Save changes
|
||||
await self._save()
|
||||
|
||||
log.info(f"Deleted document {document_id} ({len(indices_to_remove)} chunks)")
|
||||
@ -1,7 +1,31 @@
|
||||
# Core dependencies
|
||||
fastapi~=0.115.0
|
||||
uvicorn[standard]~=0.32.0
|
||||
pydantic~=2.10.0
|
||||
python-multipart~=0.0.20
|
||||
|
||||
# HTTP and async
|
||||
aiohttp~=3.11.0
|
||||
httpx~=0.28.0
|
||||
requests~=2.32.4
|
||||
|
||||
# Web scraping (for website downloader)
|
||||
beautifulsoup4~=4.13.4
|
||||
wget~=3.2
|
||||
lxml~=5.3.0
|
||||
urllib3~=2.5.0
|
||||
|
||||
# Document processing
|
||||
PyMuPDF~=1.25.0
|
||||
python-docx~=1.1.0
|
||||
|
||||
# Optional: For using z-ai-web-dev-sdk with GLM-4.7-Flash
|
||||
# Uncomment the following line if you have access to the SDK
|
||||
# z-ai-web-dev-sdk>=1.0.0
|
||||
|
||||
# Vector store alternatives (uncomment as needed)
|
||||
# chromadb~=0.5.0
|
||||
# faiss-cpu~=1.9.0
|
||||
|
||||
# Development dependencies
|
||||
# pytest~=8.3.0
|
||||
# pytest-asyncio~=0.24.0
|
||||
|
||||
156
tools/__init__.py
Normal file
156
tools/__init__.py
Normal file
@ -0,0 +1,156 @@
|
||||
"""
|
||||
Tools Module - Tool management for the RAG system
|
||||
|
||||
Provides a unified interface for tool registration and execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
# Import the website downloader tool
|
||||
from website_downloader_tool import (
|
||||
website_downloader,
|
||||
get_tool_schema as get_website_downloader_schema,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""
|
||||
Manages tool registration and execution.
|
||||
|
||||
Provides:
|
||||
- Tool registration
|
||||
- Tool schema generation
|
||||
- Tool execution with error handling
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, Callable] = {}
|
||||
self._schemas: dict[str, dict] = {}
|
||||
|
||||
# Register built-in tools
|
||||
self._register_builtin_tools()
|
||||
|
||||
def _register_builtin_tools(self) -> None:
|
||||
"""Register built-in tools."""
|
||||
# Register website downloader
|
||||
self.register_tool(
|
||||
name="website_downloader",
|
||||
function=website_downloader,
|
||||
schema=get_website_downloader_schema(),
|
||||
)
|
||||
|
||||
log.info(f"Registered {len(self._tools)} built-in tools")
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
function: Callable,
|
||||
schema: dict,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new tool.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
function: Tool function
|
||||
schema: OpenAI function schema
|
||||
"""
|
||||
self._tools[name] = function
|
||||
self._schemas[name] = schema
|
||||
log.info(f"Registered tool: {name}")
|
||||
|
||||
def get_tool_schema(self, name: str) -> Optional[dict]:
|
||||
"""Get the schema for a tool."""
|
||||
return self._schemas.get(name)
|
||||
|
||||
def get_all_schemas(self) -> list[dict]:
|
||||
"""Get schemas for all registered tools."""
|
||||
return list(self._schemas.values())
|
||||
|
||||
def list_tools(self) -> list[str]:
|
||||
"""List all registered tool names."""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def execute_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a tool with the given arguments.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
arguments: Tool arguments
|
||||
|
||||
Returns:
|
||||
Tool result
|
||||
"""
|
||||
if name not in self._tools:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unknown tool: {name}",
|
||||
}
|
||||
|
||||
try:
|
||||
log.info(f"Executing tool: {name} with args: {arguments}")
|
||||
result = self._tools[name](**arguments)
|
||||
return result
|
||||
|
||||
except TypeError as e:
|
||||
log.error(f"Invalid arguments for tool {name}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid arguments: {str(e)}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Tool execution failed: {name}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def execute_tool_from_json(
|
||||
self,
|
||||
name: str,
|
||||
arguments_json: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a tool with JSON arguments.
|
||||
|
||||
Args:
|
||||
name: Tool name
|
||||
arguments_json: JSON string of arguments
|
||||
|
||||
Returns:
|
||||
Tool result
|
||||
"""
|
||||
try:
|
||||
arguments = json.loads(arguments_json)
|
||||
return self.execute_tool(name, arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Invalid JSON arguments: {str(e)}",
|
||||
}
|
||||
|
||||
|
||||
# Global tool manager instance
|
||||
_tool_manager: Optional[ToolManager] = None
|
||||
|
||||
|
||||
def get_tool_manager() -> ToolManager:
|
||||
"""Get or create the global tool manager instance."""
|
||||
global _tool_manager
|
||||
|
||||
if _tool_manager is None:
|
||||
_tool_manager = ToolManager()
|
||||
|
||||
return _tool_manager
|
||||
Loading…
Reference in New Issue
Block a user