diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..3c15871 --- /dev/null +++ b/.env.example @@ -0,0 +1,26 @@ +# DocRAG Configuration +# Copy this file to .env and fill in your values + +# Server Configuration +HOST=0.0.0.0 +PORT=8000 +DEBUG=false + +# Model Configuration +MODEL_NAME=DocRAG-GLM-4.7 +UPSTREAM_MODEL=glm-4.7 + +# API Keys +ZAI_API_KEY=your-zai-api-key-here + +# RAG Configuration +EMBEDDING_MODEL=text-embedding-3-small +VECTOR_STORE_PATH=./data/vectors +DOCUMENTS_PATH=./data/documents +CHUNK_SIZE=1000 +CHUNK_OVERLAP=200 +TOP_K_RESULTS=5 + +# Tool Configuration +ENABLE_TOOLS=true +MAX_TOOL_ITERATIONS=3 diff --git a/README.md b/README.md index 549998b..c97f03b 100644 --- a/README.md +++ b/README.md @@ -1,163 +1,275 @@ -# DocRAG - Custom RAG with Document Loader +# DocRAG - OpenAI-Compatible RAG Server -A custom RAG (Retrieval-Augmented Generation) system with a custom document loader that acts as a local OpenAI-compatible server using a remote LLM with custom tools. +A custom RAG (Retrieval-Augmented Generation) system that **appears as a standard OpenAI API server** to clients like Open WebUI. Behind the scenes, it: -## Components +1. Processes user queries through a RAG system +2. Retrieves relevant context from a knowledge base +3. Passes the enriched context to GLM-4.7-Flash for response generation +4. Optionally uses tools like website_downloader for enhanced capabilities -### Website Downloader Tool +Users interact with what appears to be a normal chat experience, while sophisticated RAG operations happen transparently in the background. -The `website_downloader_tool.py` provides a tool interface for downloading and mirroring websites for offline use or RAG ingestion. It can be used by GLM-4.7-Flash via the z-ai-web-dev-sdk. +## Features -#### Features +- **OpenAI-Compatible API**: Works with any OpenAI client (Open WebUI, custom apps, etc.) 
+- **RAG Integration**: Automatic context retrieval for enhanced responses +- **Document Management**: Upload and manage documents in the knowledge base +- **Tool Support**: Built-in tools like website_downloader for extended capabilities +- **Streaming Support**: Real-time streaming responses +- **Easy Configuration**: Environment-based configuration -- Downloads HTML pages and all linked assets (CSS, JS, images, fonts, etc.) -- Rewrites links for offline viewing -- Supports concurrent downloads with configurable thread count -- Optional external asset downloading from CDNs -- Domain whitelisting for external assets -- Comprehensive error handling and statistics +## Quick Start -#### Tool Schema - -The tool follows the OpenAI function calling format: - -```python -from website_downloader_tool import get_tool_schema, website_downloader - -# Get the tool schema for registration -schema = get_tool_schema() -``` - -#### Usage with GLM-4.7-Flash - -```python -from zai import ZaiClient -from website_downloader_tool import get_tool_schema, website_downloader - -client = ZaiClient(api_key="your-api-key") - -# Define the tool -tools = [get_tool_schema()] - -# Create a chat completion with tools -response = client.chat.completions.create( - model="glm-4.7", - messages=[ - { - "role": "user", - "content": "Please download https://example.com for offline use" - } - ], - tools=tools, - stream=True, -) - -# Handle tool calls in the response -for chunk in response: - if chunk.choices[0].delta.tool_calls: - tool_call = chunk.choices[0].delta.tool_calls[0] - if tool_call.function.name == "website_downloader": - import json - args = json.loads(tool_call.function.arguments) - result = website_downloader(**args) - print(result) -``` - -#### Direct Usage - -```python -from website_downloader_tool import website_downloader - -# Download a website -result = website_downloader( - url="https://example.com", - destination="./downloaded_site", # Optional - max_pages=50, # Max pages to crawl 
- threads=6, # Concurrent downloads - download_external_assets=False, # Include CDN assets - external_domains=["cdn.example.com"] # Whitelist external domains -) - -if result["success"]: - print(f"Downloaded to: {result['output_directory']}") - print(f"Pages: {result['stats']['pages_crawled']}") - print(f"Assets: {result['stats']['assets_downloaded']}") -else: - print(f"Error: {result['message']}") -``` - -#### Parameters - -| Parameter | Type | Required | Default | Description | -|-----------|------|----------|---------|-------------| -| `url` | string | Yes | - | Starting URL to crawl | -| `destination` | string | No | Derived from URL | Output folder path | -| `max_pages` | integer | No | 50 | Max HTML pages (1-1000) | -| `threads` | integer | No | 6 | Concurrent download threads (1-20) | -| `download_external_assets` | boolean | No | False | Download CDN assets | -| `external_domains` | array | No | None | Whitelist of external domains | - -#### Return Value - -```python -{ - "success": True/False, - "message": "Human-readable summary", - "stats": { - "pages_crawled": int, - "assets_downloaded": int, - "failed_downloads": int, - "elapsed_seconds": float, - "output_directory": str, - "pages": [...], # List of downloaded pages - "downloaded_items": [...] 
# List of downloaded assets - }, - "output_directory": "/path/to/downloaded/site" -} -``` - -### Website Downloader CLI - -The original `website-downloader.py` can still be used as a standalone CLI tool: - -```bash -python website-downloader.py --url https://example.com --max-pages 50 --threads 6 -``` - -#### CLI Options - -- `--url`: Starting URL to crawl (required) -- `--destination`: Output folder (optional, derived from URL if not provided) -- `--max-pages`: Maximum pages to crawl (default: 50) -- `--threads`: Number of download threads (default: 6) -- `--download-external-assets`: Enable external asset downloading -- `--external-domains`: Whitelist of external domains to download from - -## Installation +### 1. Install Dependencies ```bash pip install -r requirements.txt ``` +### 2. Configure Environment + +```bash +cp .env.example .env +# Edit .env and add your ZAI_API_KEY +``` + +### 3. Run the Server + +```bash +python main.py +``` + +The server will start on `http://0.0.0.0:8000` + +### 4. Use with Open WebUI + +1. Open Open WebUI settings +2. Add a new OpenAI-compatible connection +3. Set the base URL to `http://your-server:8000/v1` +4. Leave the API key empty or use any value (not validated) +5. 
Select the "DocRAG-GLM-4.7" model + +## API Endpoints + +### OpenAI-Compatible Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/v1/chat/completions` | POST | Chat completions (streaming supported) | +| `/v1/models` | GET | List available models | +| `/v1/models/{model_id}` | GET | Get model information | + +### Document Management Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/v1/documents` | GET | List documents in knowledge base | +| `/v1/documents/upload` | POST | Upload a document | +| `/v1/documents/url` | POST | Add document from URL | +| `/v1/documents/{doc_id}` | DELETE | Delete a document | + +### Health & Status + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/health` | GET | Health check | +| `/` | GET | API information | + +## Usage Examples + +### Chat Completion + +```bash +curl -X POST http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "DocRAG-GLM-4.7", + "messages": [ + {"role": "user", "content": "What is machine learning?"} + ], + "stream": false + }' +``` + +### Upload Document + +```bash +curl -X POST http://localhost:8000/v1/documents/upload \ + -F "file=@document.pdf" +``` + +### Add Document from URL + +```bash +curl -X POST http://localhost:8000/v1/documents/url \ + -H "Content-Type: application/json" \ + -d '{"url": "https://example.com/article.html"}' +``` + +### Python Client + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="not-needed" # API key not validated +) + +response = client.chat.completions.create( + model="DocRAG-GLM-4.7", + messages=[ + {"role": "user", "content": "Explain quantum computing"} + ], + stream=True +) + +for chunk in response: + if chunk.choices[0].delta.content: + print(chunk.choices[0].delta.content, end="") +``` + +## Configuration + +Configure via environment variables or `.env` file: 
+ +| Variable | Default | Description | +|----------|---------|-------------| +| `HOST` | `0.0.0.0` | Server host | +| `PORT` | `8000` | Server port | +| `DEBUG` | `false` | Enable debug mode | +| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name | +| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use | +| `ZAI_API_KEY` | (required) | API key for ZAI SDK | +| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model | +| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location | +| `DOCUMENTS_PATH` | `./data/documents` | Document storage | +| `CHUNK_SIZE` | `1000` | Document chunk size | +| `CHUNK_OVERLAP` | `200` | Chunk overlap | +| `TOP_K_RESULTS` | `5` | Number of context results | +| `ENABLE_TOOLS` | `true` | Enable tool support | + ## Project Structure ``` docrag/ -├── website-downloader.py # Core website downloader (CLI) +├── main.py # FastAPI application entry point +├── rag/ +│ ├── __init__.py # RAG system main class +│ ├── document_processor.py # Document parsing and chunking +│ ├── vector_store.py # Vector storage and search +│ └── retriever.py # Context retrieval logic +├── tools/ +│ └── __init__.py # Tool management (website_downloader, etc.) +├── website-downloader.py # CLI website downloader ├── website_downloader_tool.py # Tool wrapper for GLM-4.7-Flash -├── requirements.txt # Python dependencies -└── README.md # This file +├── requirements.txt # Python dependencies +├── .env.example # Configuration template +└── README.md # This file ``` -## Integration with RAG +## How It Works -The downloaded website content can be processed for RAG systems: +### Request Flow -1. Use the tool to download website content -2. Parse the downloaded HTML files -3. Extract text content and metadata -4. Chunk and embed the content -5. Store in your vector database +1. **User sends message** → OpenAI-compatible endpoint receives request +2. **RAG Retrieval** → Query is processed and relevant context is retrieved +3. 
**Context Enhancement** → Retrieved context is added to the prompt +4. **Tool Execution** → If needed, tools are invoked (e.g., website_downloader) +5. **LLM Generation** → GLM-4.7-Flash generates response with context +6. **Response** → User receives response (streaming supported) + +### RAG Pipeline + +``` +User Query + │ + ▼ +┌─────────────────┐ +│ Query Processor │ +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Vector Search │ ← Knowledge Base +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ Context Builder │ +└────────┬────────┘ + │ + ▼ +┌─────────────────┐ +│ GLM-4.7-Flash │ +└────────┬────────┘ + │ + ▼ + Response +``` + +## Supported Document Formats + +- **Text**: `.txt`, `.md`, `.rst`, `.log` +- **Documents**: `.pdf`, `.docx` +- **Web**: `.html`, `.htm` +- **Data**: `.json`, `.yaml`, `.yml`, `.xml`, `.toml`, `.csv`, `.tsv` +- **Code**: `.py`, `.js`, `.ts`, `.java`, `.cpp`, `.c`, `.go`, `.rs`, `.rb`, `.php`, etc. + +## Extending + +### Adding New Tools + +```python +# In tools/__init__.py + +def my_custom_tool(param1: str, param2: int = 10) -> dict: + """Your tool implementation.""" + return {"result": "success"} + +# Register the tool +tool_manager.register_tool( + name="my_custom_tool", + function=my_custom_tool, + schema={ + "type": "function", + "function": { + "name": "my_custom_tool", + "description": "Description of your tool", + "parameters": { + "type": "object", + "properties": { + "param1": {"type": "string", "description": "..."}, + "param2": {"type": "integer", "description": "...", "default": 10} + }, + "required": ["param1"] + } + } + } +) +``` + +### Using Different Vector Stores + +The default implementation uses a simple file-based store. To use ChromaDB: + +1. Install: `pip install chromadb` +2. 
Modify `rag/vector_store.py` to use ChromaDB client + +## Development + +### Running in Development Mode + +```bash +DEBUG=true python main.py +``` + +### Running Tests + +```bash +pip install pytest pytest-asyncio +pytest tests/ +``` ## License diff --git a/main.py b/main.py index ee79b19..3aff528 100644 --- a/main.py +++ b/main.py @@ -1 +1,731 @@ -# build out the full app the hides the fact that you are a rag app and just looks like a normal openAI model as fare as a user is concerned when interacting with it with open-webui. but in reality it is a rag app that is doing all the work in the background then provides the information to GLM 4.7-flash. \ No newline at end of file +#!/usr/bin/env python3 +""" +DocRAG - OpenAI-Compatible RAG Server + +This application presents itself as a standard OpenAI API server that can be used +with any OpenAI-compatible client (like Open WebUI). Behind the scenes, it: +1. Processes user queries through a RAG system +2. Retrieves relevant context from a knowledge base +3. Passes the enriched context to GLM-4.7-Flash for response generation +4. Optionally uses tools like website_downloader for enhanced capabilities + +The user sees a normal chat experience, but the system is actually doing +sophisticated RAG operations in the background. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import sys +import time +import uuid +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator, Optional + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s", + datefmt="%H:%M:%S", +) +log = logging.getLogger(__name__) + +# FastAPI imports +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse, JSONResponse +from pydantic import BaseModel, Field + +# Import RAG components +from rag import RAGSystem, get_rag_system +from rag.document_processor import DocumentProcessor + +# Import tools +from tools import ToolManager, get_tool_manager + +# Import SDK for GLM-4.7-Flash +try: + from zai import ZaiClient as ZAI +except ImportError: + ZAI = None + log.warning("z-ai-web-dev-sdk not installed. Install with: pip install z-ai-web-dev-sdk") + + +# ============================================================================= +# Configuration +# ============================================================================= + +class Config: + """Application configuration from environment variables.""" + + # Server settings + HOST: str = os.getenv("HOST", "0.0.0.0") + PORT: int = int(os.getenv("PORT", "8000")) + DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" + + # Model settings + MODEL_NAME: str = os.getenv("MODEL_NAME", "DocRAG-GLM-4.7") + UPSTREAM_MODEL: str = os.getenv("UPSTREAM_MODEL", "glm-4.7") + + # API Key for upstream LLM + ZAI_API_KEY: str = os.getenv("ZAI_API_KEY", "") + + # RAG settings + EMBEDDING_MODEL: str = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small") + VECTOR_STORE_PATH: str = os.getenv("VECTOR_STORE_PATH", "./data/vectors") + DOCUMENTS_PATH: str = os.getenv("DOCUMENTS_PATH", "./data/documents") + CHUNK_SIZE: int = 
int(os.getenv("CHUNK_SIZE", "1000")) + CHUNK_OVERLAP: int = int(os.getenv("CHUNK_OVERLAP", "200")) + TOP_K_RESULTS: int = int(os.getenv("TOP_K_RESULTS", "5")) + + # Tool settings + ENABLE_TOOLS: bool = os.getenv("ENABLE_TOOLS", "true").lower() == "true" + MAX_TOOL_ITERATIONS: int = int(os.getenv("MAX_TOOL_ITERATIONS", "3")) + + +config = Config() + + +# ============================================================================= +# OpenAI-Compatible Models +# ============================================================================= + +class ChatMessage(BaseModel): + """OpenAI chat message format.""" + role: str + content: Optional[str] = None + name: Optional[str] = None + tool_calls: Optional[list[dict]] = None + tool_call_id: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + """OpenAI chat completion request format.""" + model: str = "DocRAG-GLM-4.7" + messages: list[ChatMessage] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stream: Optional[bool] = False + stop: Optional[list[str]] = None + max_tokens: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + user: Optional[str] = None + tools: Optional[list[dict]] = None + tool_choice: Optional[str | dict] = None + + +class ChatCompletionChoice(BaseModel): + """OpenAI chat completion choice.""" + index: int = 0 + message: ChatMessage + finish_reason: str = "stop" + + +class ChatCompletionUsage(BaseModel): + """Token usage statistics.""" + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + + +class ChatCompletionResponse(BaseModel): + """OpenAI chat completion response format.""" + id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:24]}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str = config.MODEL_NAME + choices: list[ChatCompletionChoice] + usage: ChatCompletionUsage = 
ChatCompletionUsage() + + +class ModelInfo(BaseModel): + """OpenAI model info format.""" + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "organization" + + +class ModelList(BaseModel): + """OpenAI model list format.""" + object: str = "list" + data: list[ModelInfo] + + +# ============================================================================= +# Application State +# ============================================================================= + +class AppState: + """Global application state.""" + rag_system: Optional[RAGSystem] = None + tool_manager: Optional[ToolManager] = None + zai_client: Any = None + startup_time: float = time.time() + + +state = AppState() + + +# ============================================================================= +# Lifespan Management +# ============================================================================= + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Manage application lifespan - startup and shutdown.""" + log.info("Starting DocRAG server...") + + # Initialize RAG system + try: + log.info("Initializing RAG system...") + state.rag_system = await get_rag_system( + embedding_model=config.EMBEDDING_MODEL, + vector_store_path=config.VECTOR_STORE_PATH, + documents_path=config.DOCUMENTS_PATH, + chunk_size=config.CHUNK_SIZE, + chunk_overlap=config.CHUNK_OVERLAP, + ) + log.info("RAG system initialized successfully") + except Exception as e: + log.warning(f"RAG system initialization deferred: {e}") + state.rag_system = None + + # Initialize tool manager + try: + log.info("Initializing tool manager...") + state.tool_manager = get_tool_manager() + log.info(f"Tool manager initialized with tools: {state.tool_manager.list_tools()}") + except Exception as e: + log.warning(f"Tool manager initialization failed: {e}") + state.tool_manager = None + + # Initialize ZAI client for upstream LLM + try: + if config.ZAI_API_KEY and ZAI is not None: + 
log.info("Initializing ZAI client...") + state.zai_client = ZAI(api_key=config.ZAI_API_KEY) + log.info("ZAI client initialized successfully") + else: + log.warning("No ZAI_API_KEY provided or SDK not installed - using mock responses") + state.zai_client = None + except Exception as e: + log.error(f"Failed to initialize ZAI client: {e}") + state.zai_client = None + + log.info(f"DocRAG server started on {config.HOST}:{config.PORT}") + + yield + + # Cleanup + log.info("Shutting down DocRAG server...") + if state.rag_system: + await state.rag_system.close() + log.info("DocRAG server stopped") + + +# ============================================================================= +# FastAPI Application +# ============================================================================= + +app = FastAPI( + title="DocRAG API", + description="OpenAI-compatible RAG server powered by GLM-4.7-Flash", + version="1.0.0", + lifespan=lifespan, +) + +# CORS middleware for Open WebUI compatibility +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ============================================================================= +# OpenAI-Compatible Endpoints +# ============================================================================= + +@app.get("/v1/models") +@app.get("/models") +async def list_models(): + """List available models (OpenAI-compatible).""" + return ModelList( + data=[ + ModelInfo(id=config.MODEL_NAME, owned_by="docrag"), + ModelInfo(id="DocRAG-GLM-4.7", owned_by="docrag"), + ] + ) + + +@app.get("/v1/models/{model_id}") +@app.get("/models/{model_id}") +async def get_model(model_id: str): + """Get model information (OpenAI-compatible).""" + if model_id not in [config.MODEL_NAME, "DocRAG-GLM-4.7"]: + raise HTTPException(status_code=404, detail="Model not found") + return ModelInfo(id=model_id, owned_by="docrag") + + +@app.post("/v1/chat/completions") +@app.post("/chat/completions") 
async def chat_completions(request: ChatCompletionRequest):
    """Handle chat completions (OpenAI-compatible).

    Dispatches to the streaming or the non-streaming implementation
    depending on ``request.stream``.

    Args:
        request: Parsed OpenAI-style chat completion request.

    Returns:
        A ``StreamingResponse`` of server-sent events when streaming was
        requested, otherwise the ``ChatCompletionResponse`` built by
        ``complete_chat``.

    Raises:
        HTTPException: Re-raised unchanged when the handler itself produced
            one (e.g. 400 for a request without a user message); any other
            failure is reported as a 500 carrying the error text.
    """
    request_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"

    try:
        if request.stream:
            return StreamingResponse(
                stream_chat_completion(request, request_id),
                media_type="text/event-stream",
            )
        return await complete_chat(request, request_id)
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the
        # deliberate 400 from complete_chat would be rewrapped as a 500.
        raise
    except Exception as e:
        log.exception("Chat completion failed")
        raise HTTPException(status_code=500, detail=str(e)) from e


# =============================================================================
# Chat Completion Logic
# =============================================================================

# Fallback sampling parameters used when the client omits a value.
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_MAX_TOKENS = 4096


def _last_user_message(messages: list[ChatMessage]) -> str:
    """Return the content of the most recent non-empty user message, or ""."""
    for msg in reversed(messages):
        if msg.role == "user" and msg.content:
            return msg.content
    return ""


async def _retrieve_context(user_message: str) -> tuple[str, list[str]]:
    """Query the RAG system for *user_message*; return ``(context, sources)``.

    Retrieval is best-effort: when RAG is disabled or the query fails, the
    request proceeds with an empty context instead of failing.
    """
    if not state.rag_system:
        return "", []
    try:
        rag_result = await state.rag_system.query(
            query=user_message,
            top_k=config.TOP_K_RESULTS,
        )
        context = rag_result.get("context", "")
        sources = rag_result.get("sources", [])
        log.info(f"RAG retrieved {len(sources)} relevant documents")
        return context, sources
    except Exception as e:
        log.warning(f"RAG retrieval failed: {e}")
        return "", []


def _to_upstream_messages(messages: list[ChatMessage]) -> list[dict[str, Any]]:
    """Convert messages to role/content dicts, dropping messages without content."""
    return [{"role": m.role, "content": m.content} for m in messages if m.content]


def _sse_event(payload: dict[str, Any]) -> str:
    """Serialize *payload* as one server-sent-event ``data:`` frame."""
    # Built outside the f-string braces: a multi-line expression inside a
    # single-quoted f-string is a SyntaxError before Python 3.12.
    body = json.dumps(payload)
    return f"data: {body}\n\n"


def _sse_chunk(
    request_id: str,
    created: int,
    delta: dict[str, Any],
    finish_reason: Optional[str] = None,
) -> str:
    """Build one ``chat.completion.chunk`` SSE frame for the given delta."""
    return _sse_event({
        "id": request_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": config.MODEL_NAME,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    })


def _extract_completion_text(response: Any) -> str:
    """Pull the assistant text out of an upstream reply.

    Accepts either a single completion object exposing ``choices`` or an
    iterable of chunk objects. A ``message`` on the first choice wins
    outright; otherwise ``delta.content`` pieces are concatenated. Chunks
    without choices are skipped.
    """
    # NOTE(review): which of the two shapes the SDK returns for stream=False
    # is not visible in this file - both are handled; confirm against the SDK.
    parts = [response] if getattr(response, "choices", None) is not None else response
    content = ""
    for chunk in parts:
        choices = getattr(chunk, "choices", None)
        if not choices:
            continue
        choice = choices[0]
        message = getattr(choice, "message", None)
        if message:
            return message.content or ""
        piece = getattr(getattr(choice, "delta", None), "content", None)
        if piece:
            content += piece
    return content


async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
    """Process a non-streaming chat completion request.

    Steps: find the latest user message, retrieve RAG context for it, build
    the context-enriched message list, run keyword-triggered tools when
    enabled, then ask the upstream LLM for the answer.

    Args:
        request: Parsed chat completion request.
        request_id: Identifier reported back as the completion ``id``.

    Returns:
        The completed response. Token usage is a rough estimate (character
        count divided by four), not a tokenizer count.

    Raises:
        HTTPException: 400 when the request contains no user message with
            content.
    """
    messages = request.messages

    user_message = _last_user_message(messages)
    if not user_message:
        raise HTTPException(status_code=400, detail="No user message found")

    # Step 1: RAG Retrieval
    context, sources = await _retrieve_context(user_message)

    # Step 2: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources)

    # Step 3: Check for tool usage
    tool_calls_made = []
    if config.ENABLE_TOOLS and state.tool_manager:
        tool_calls_made = await check_and_execute_tools(
            user_message, enhanced_messages, request.tools
        )

    # Step 4: Generate response with upstream LLM
    # (None values are replaced by defaults inside generate_response.)
    response_content = await generate_response(
        enhanced_messages,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
        tools=request.tools,
    )

    # Step 5: Build and return response
    prompt_tokens = len(str(enhanced_messages)) // 4
    completion_tokens = len(response_content) // 4
    return ChatCompletionResponse(
        id=request_id,
        model=config.MODEL_NAME,
        choices=[
            ChatCompletionChoice(
                message=ChatMessage(
                    role="assistant",
                    content=response_content,
                    tool_calls=tool_calls_made if tool_calls_made else None,
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            # Previously left at its default of 0.
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )


async def stream_chat_completion(
    request: ChatCompletionRequest,
    request_id: str,
) -> AsyncIterator[str]:
    """Stream a chat completion response as server-sent events.

    Yields ``chat.completion.chunk`` frames carrying content deltas, then a
    final frame with ``finish_reason="stop"`` and the ``[DONE]`` sentinel.
    Failures are reported in-band as a ``{"error": ...}`` frame because the
    HTTP status has already been sent by the time the generator runs.

    Args:
        request: Parsed chat completion request.
        request_id: Identifier repeated as ``id`` on every chunk.
    """
    messages = request.messages

    user_message = _last_user_message(messages)
    if not user_message:
        yield _sse_event({"error": "No user message found"})
        return

    # Step 1: RAG Retrieval
    context, sources = await _retrieve_context(user_message)

    # Step 2: Build enhanced prompt with context
    enhanced_messages = build_enhanced_messages(messages, context, sources)

    # Step 3: Stream response from upstream LLM
    created = int(time.time())

    try:
        if state.zai_client:
            # `is None` rather than `or`: an explicit temperature of 0 is a
            # valid request and must not be replaced by the default.
            temperature = (
                _DEFAULT_TEMPERATURE if request.temperature is None else request.temperature
            )
            # NOTE(review): the SDK is called and iterated synchronously here,
            # so the event loop is blocked while waiting on upstream chunks.
            response = state.zai_client.chat.completions.create(
                model=config.UPSTREAM_MODEL,
                messages=_to_upstream_messages(enhanced_messages),
                temperature=temperature,
                max_tokens=request.max_tokens or _DEFAULT_MAX_TOKENS,
                stream=True,
                thinking={"type": "enabled"},
            )

            for chunk in response:
                choices = getattr(chunk, "choices", None)
                if not choices:
                    continue
                delta = choices[0].delta

                # Reasoning ("thinking") output is internal - log it, never
                # forward it to the client.
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning:
                    log.debug(f"Thinking: {reasoning[:100]}...")
                    continue

                content = getattr(delta, "content", None)
                if content:
                    yield _sse_chunk(request_id, created, {"content": content})
        else:
            # Mock streaming response for testing
            mock_response = f"I understand you're asking about: {user_message}\n\n"
            if context:
                mock_response += "Based on my knowledge base, I found relevant information that I'm using to help answer your question.\n\n"
            mock_response += "However, I'm currently running in demo mode without an upstream LLM connection. Please configure ZAI_API_KEY for full functionality."

            for char in mock_response:
                yield _sse_chunk(request_id, created, {"content": char})
                await asyncio.sleep(0.01)

        # Final chunk and end-of-stream sentinel (same for both branches)
        yield _sse_chunk(request_id, created, {}, finish_reason="stop")
        yield "data: [DONE]\n\n"

    except Exception as e:
        log.exception("Streaming failed")
        yield _sse_event({"error": str(e)})


def build_enhanced_messages(
    messages: list[ChatMessage],
    context: str,
    sources: list[str],
) -> list[ChatMessage]:
    """Build enhanced messages with RAG context.

    The result starts with a single server-generated system message (with
    the retrieved context and source list appended when present), followed
    by the conversation history. System messages sent by the client are
    dropped.

    Args:
        messages: Conversation history from the request.
        context: Retrieved context text; empty when nothing was retrieved.
        sources: Source labels for the retrieved context.

    Returns:
        A new list; the input list is not modified.
    """
    system_content = (
        "You are a helpful AI assistant with access to a knowledge base. "
        "Use the provided context to give accurate and helpful responses. "
        "If the context doesn't contain relevant information, use your general knowledge "
        "but indicate when you're doing so."
    )

    if context:
        system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
        if sources:
            system_content += "\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources)

    enhanced = [ChatMessage(role="system", content=system_content)]

    # Add conversation history (excluding old system messages)
    enhanced.extend(msg for msg in messages if msg.role != "system")

    return enhanced


async def generate_response(
    messages: list[ChatMessage],
    temperature: Optional[float] = 0.7,
    max_tokens: Optional[int] = 4096,
    tools: Optional[list[dict]] = None,
) -> str:
    """Generate response using upstream LLM.

    Args:
        messages: Messages to send; entries without content are skipped.
        temperature: Sampling temperature; ``None`` falls back to 0.7.
        max_tokens: Completion token limit; ``None`` falls back to 4096.
        tools: Accepted for interface compatibility; not forwarded upstream.

    Returns:
        The assistant text. Upstream failures are not raised: they are
        logged and returned as an "I encountered an error" message. Without
        a configured client a demo-mode message is returned.
    """
    if not state.zai_client:
        # Mock response for testing
        user_msg = _last_user_message(messages)
        return f"Demo mode response. Your question: {user_msg[:100]}... Configure ZAI_API_KEY for full functionality."

    # Callers pass the request's Optional fields straight through, so None
    # must be mapped to the defaults instead of being sent upstream.
    if temperature is None:
        temperature = _DEFAULT_TEMPERATURE
    if max_tokens is None:
        max_tokens = _DEFAULT_MAX_TOKENS

    try:
        response = state.zai_client.chat.completions.create(
            model=config.UPSTREAM_MODEL,
            messages=_to_upstream_messages(messages),
            temperature=temperature,
            max_tokens=max_tokens,
            stream=False,
            thinking={"type": "enabled"},
        )

        content = _extract_completion_text(response)
        return content or "I apologize, but I couldn't generate a response."

    except Exception as e:
        log.error(f"Upstream LLM call failed: {e}")
        return f"I encountered an error: {str(e)}"
+ + +async def check_and_execute_tools( + user_message: str, + messages: list[ChatMessage], + available_tools: Optional[list[dict]], +) -> list[dict]: + """Check if tools should be used and execute them.""" + if not state.tool_manager or not available_tools: + return [] + + tool_calls = [] + + # Simple keyword-based tool detection + # In a production system, you'd use the LLM to decide tool usage + message_lower = user_message.lower() + + # Check for website download intent + if any(kw in message_lower for kw in ["download website", "mirror site", "crawl", "archive site"]): + # Extract URL from message + import re + url_pattern = r'https?://[^\s]+' + urls = re.findall(url_pattern, user_message) + + if urls: + tool_result = state.tool_manager.execute_tool( + "website_downloader", + {"url": urls[0], "max_pages": 10} + ) + tool_calls.append({ + "id": f"call_{uuid.uuid4().hex[:24]}", + "type": "function", + "function": { + "name": "website_downloader", + "arguments": json.dumps({"url": urls[0]}), + } + }) + log.info(f"Executed website_downloader tool: {tool_result}") + + return tool_calls + + +# ============================================================================= +# Document Management Endpoints +# ============================================================================= + +@app.post("/v1/documents/upload") +async def upload_document(request: Request): + """Upload a document to the knowledge base.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + form = await request.form() + file = form.get("file") + if not file: + raise HTTPException(status_code=400, detail="No file provided") + + content = await file.read() + filename = file.filename or "unknown" + + # Process and store document + result = await state.rag_system.add_document( + content=content, + filename=filename, + ) + + return {"success": True, "message": f"Document '{filename}' added", "chunks": result.get("chunks", 0)} + + except 
Exception as e: + log.exception("Document upload failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/v1/documents/url") +async def add_document_from_url(request: dict): + """Add a document from URL to the knowledge base.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + url = request.get("url") + if not url: + raise HTTPException(status_code=400, detail="No URL provided") + + try: + result = await state.rag_system.add_document_from_url(url) + return {"success": True, "message": f"Document from {url} added", "chunks": result.get("chunks", 0)} + except Exception as e: + log.exception("URL document addition failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/v1/documents") +async def list_documents(): + """List documents in the knowledge base.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + docs = await state.rag_system.list_documents() + return {"documents": docs} + except Exception as e: + log.exception("Document listing failed") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.delete("/v1/documents/{doc_id}") +async def delete_document(doc_id: str): + """Delete a document from the knowledge base.""" + if not state.rag_system: + raise HTTPException(status_code=503, detail="RAG system not initialized") + + try: + await state.rag_system.delete_document(doc_id) + return {"success": True, "message": f"Document {doc_id} deleted"} + except Exception as e: + log.exception("Document deletion failed") + raise HTTPException(status_code=500, detail=str(e)) + + +# ============================================================================= +# Health and Status Endpoints +# ============================================================================= + +@app.get("/health") +async def health_check(): + """Health check endpoint.""" + return { + "status": "healthy", + "uptime": 
class RAGSystem:
    """
    Main RAG system that coordinates document processing, storage, and retrieval.

    This class provides a unified interface for:
    - Adding documents to the knowledge base
    - Querying for relevant context
    - Managing the document lifecycle
    """

    def __init__(
        self,
        embedding_model: str = "text-embedding-3-small",
        vector_store_path: str = "./data/vectors",
        documents_path: str = "./data/documents",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        self.embedding_model = embedding_model
        self.vector_store_path = Path(vector_store_path)
        self.documents_path = Path(documents_path)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Components are created lazily by initialize().
        self._initialized = False
        self._document_processor: Optional[DocumentProcessor] = None
        self._vector_store: Optional[VectorStore] = None
        self._retriever: Optional[Retriever] = None

    async def initialize(self) -> None:
        """Create directories and components; safe to call more than once."""
        if self._initialized:
            return

        log.info("Initializing RAG system...")

        self.vector_store_path.mkdir(parents=True, exist_ok=True)
        self.documents_path.mkdir(parents=True, exist_ok=True)

        self._document_processor = DocumentProcessor(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )

        self._vector_store = VectorStore(
            persist_directory=str(self.vector_store_path),
            embedding_model=self.embedding_model,
        )
        # The store refuses every operation until it is initialised, and
        # this is also what loads previously persisted chunks from disk.
        await self._vector_store.initialize()

        self._retriever = Retriever(
            vector_store=self._vector_store,
        )

        self._initialized = True
        log.info("RAG system initialized successfully")

    async def close(self) -> None:
        """Close the RAG system and release resources."""
        if self._vector_store:
            await self._vector_store.close()
        self._initialized = False
        log.info("RAG system closed")

    def _ensure_initialized(self) -> None:
        """Raise RuntimeError unless initialize() has completed."""
        if not self._initialized:
            raise RuntimeError("RAG system not initialized. Call initialize() first.")

    async def add_document(
        self,
        content: bytes,
        filename: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Add a document to the knowledge base.

        Args:
            content: Raw document content
            filename: Original filename (its extension selects the parser)
            metadata: Optional metadata attached to every chunk

        Returns:
            Dictionary with ``chunks`` (number stored) and ``document_id``

        Raises:
            RuntimeError: If the system is not initialized
        """
        self._ensure_initialized()

        doc_info = await self._document_processor.process(
            content=content,
            filename=filename,
            metadata=metadata,
        )

        chunks = doc_info.get("chunks", [])
        if chunks:
            await self._vector_store.add_chunks(
                chunks=chunks,
                metadatas=doc_info.get("metadatas", []),
                ids=doc_info.get("ids", []),
            )

        log.info(f"Added document '{filename}' with {len(chunks)} chunks")
        return {"chunks": len(chunks), "document_id": doc_info.get("document_id")}

    async def add_document_from_url(self, url: str) -> dict[str, Any]:
        """
        Add a document from a URL to the knowledge base.

        Args:
            url: http(s) URL to fetch and process

        Returns:
            Dictionary with processing results

        Raises:
            RuntimeError: If the system is not initialized
            ValueError: If the URL scheme is not http or https
        """
        self._ensure_initialized()

        import aiohttp
        from urllib.parse import urlparse

        parsed = urlparse(url)
        # NOTE(security): the URL is caller-supplied and fetched server-side.
        if parsed.scheme not in ("http", "https"):
            raise ValueError(f"Unsupported URL scheme: {parsed.scheme or '(none)'}")

        # aiohttp expects a ClientTimeout; bare numbers are legacy usage.
        timeout = aiohttp.ClientTimeout(total=30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url) as response:
                response.raise_for_status()
                content = await response.read()
                content_type = response.headers.get("Content-Type", "")

        filename = os.path.basename(parsed.path) or "webpage.html"
        # Pages served from extension-less paths would otherwise be indexed
        # as raw markup, because the parser is chosen by file extension.
        if not Path(filename).suffix and "html" in content_type.lower():
            filename += ".html"

        return await self.add_document(
            content=content,
            filename=filename,
            metadata={"source_url": url},
        )

    async def query(
        self,
        query: str,
        top_k: int = 5,
        filter_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        """
        Query the knowledge base for relevant context.

        Args:
            query: Query string
            top_k: Number of results to return
            filter_metadata: Optional metadata filters

        Returns:
            Dictionary with ``context`` (numbered chunks joined by blank
            lines), ``sources`` (unique, in ranking order), ``num_results``
            and the raw ``results``

        Raises:
            RuntimeError: If the system is not initialized
        """
        self._ensure_initialized()

        results = await self._retriever.retrieve(
            query=query,
            top_k=top_k,
            filter_metadata=filter_metadata,
        )

        context = "\n\n".join(
            f"[{position}] {result['content']}"
            for position, result in enumerate(results, start=1)
        )
        found_sources = (result.get("metadata", {}).get("source") for result in results)
        # dict.fromkeys de-duplicates while keeping ranking order; a set
        # would make the order vary between runs.
        sources = list(dict.fromkeys(source for source in found_sources if source))

        return {
            "context": context,
            "sources": sources,
            "num_results": len(results),
            "results": results,
        }

    async def list_documents(self) -> list[dict[str, Any]]:
        """List all documents in the knowledge base."""
        self._ensure_initialized()
        return await self._vector_store.list_documents()

    async def delete_document(self, document_id: str) -> None:
        """Delete a document from the knowledge base."""
        self._ensure_initialized()
        await self._vector_store.delete_document(document_id)
        log.info(f"Deleted document {document_id}")


# Global RAG system instance
_rag_system: Optional[RAGSystem] = None


async def get_rag_system(
    embedding_model: str = "text-embedding-3-small",
    vector_store_path: str = "./data/vectors",
    documents_path: str = "./data/documents",
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
) -> RAGSystem:
    """
    Get or create the global RAG system instance.

    The arguments only take effect on the call that creates the instance.

    Args:
        embedding_model: Name of the embedding model
        vector_store_path: Path to vector store
        documents_path: Path to document storage
        chunk_size: Size of document chunks
        chunk_overlap: Overlap between chunks

    Returns:
        Initialized RAGSystem instance
    """
    global _rag_system

    if _rag_system is None:
        candidate = RAGSystem(
            embedding_model=embedding_model,
            vector_store_path=vector_store_path,
            documents_path=documents_path,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        # Publish only after a successful initialize(): a failed start-up
        # must not leave an unusable instance cached for later callers.
        await candidate.initialize()
        if _rag_system is None:
            _rag_system = candidate

    return _rag_system
class DocumentProcessor:
    """
    Process documents into chunks suitable for vector storage.

    Handles:
    - Multiple file formats
    - Sentence-aware chunking with overlap
    - Per-chunk metadata
    """

    # Extensions decoded directly as UTF-8 text without binary sniffing.
    _TEXT_EXTENSIONS = frozenset({
        ".txt", ".md", ".rst", ".log",
        ".py", ".js", ".ts", ".java", ".cpp", ".c", ".go", ".rs", ".rb",
        ".php", ".cs", ".swift", ".kt",
        ".json", ".yaml", ".yml", ".xml", ".toml",
        ".csv", ".tsv",
    })

    def __init__(
        self,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    async def process(
        self,
        content: bytes,
        filename: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Process a document into chunks.

        Args:
            content: Raw document content
            filename: Original filename (extension selects the extractor)
            metadata: Optional additional metadata copied onto every chunk

        Returns:
            Dictionary with ``chunks``, ``metadatas``, ``ids`` and
            ``document_id``, plus ``total_chars`` when text was extracted.
            When no text is found the lists are empty and ``document_id``
            is None.
        """
        text = await self._extract_text(content, filename)

        if not text.strip():
            return {"chunks": [], "metadatas": [], "ids": [], "document_id": None}

        document_id = str(uuid.uuid4())
        chunks = self._chunk_text(text)

        base_metadata = {
            "source": filename,
            **(metadata or {}),
            # Set last so caller metadata cannot overwrite the generated id,
            # which is what gets returned to the caller.
            "document_id": document_id,
        }

        metadatas = []
        ids = []
        for index, chunk in enumerate(chunks):
            # md5 is only a stable identifier here, not a security measure.
            ids.append(hashlib.md5(f"{document_id}_{index}".encode()).hexdigest())
            metadatas.append({
                **base_metadata,
                "chunk_index": index,
                "chunk_length": len(chunk),
            })

        return {
            "chunks": chunks,
            "metadatas": metadatas,
            "ids": ids,
            "document_id": document_id,
            "total_chars": len(text),
        }

    async def _extract_text(self, content: bytes, filename: str) -> str:
        """Extract text based on the file extension; return "" on failure."""
        ext = Path(filename).suffix.lower()

        try:
            if ext == ".pdf":
                return await self._extract_pdf(content)
            if ext in (".html", ".htm"):
                return await self._extract_html(content)
            if ext == ".docx":
                return await self._extract_docx(content)

            # decode(errors="ignore") never raises, so binary data must be
            # detected explicitly: a NUL byte is a cheap, reliable signal.
            if ext not in self._TEXT_EXTENSIONS and b"\x00" in content[:8192]:
                log.warning(f"Unknown file type: {ext}, treating as binary")
                return ""

            return content.decode("utf-8", errors="ignore")

        except Exception as e:
            log.error(f"Failed to extract text from {filename}: {e}")
            return ""

    async def _extract_pdf(self, content: bytes) -> str:
        """Extract text from a PDF; "" if PyMuPDF is missing or parsing fails."""
        try:
            import fitz  # PyMuPDF
        except ImportError:
            log.warning("PyMuPDF not installed, PDF extraction unavailable")
            return ""

        try:
            doc = fitz.open(stream=content, filetype="pdf")
            try:
                return "\n\n".join(page.get_text() for page in doc)
            finally:
                # Release the document even when a page fails to parse.
                doc.close()
        except Exception as e:
            log.error(f"PDF extraction failed: {e}")
            return ""

    async def _extract_html(self, content: bytes) -> str:
        """Extract visible text from HTML, dropping scripts and page chrome."""
        try:
            from bs4 import BeautifulSoup
            soup = BeautifulSoup(content, "html.parser")
            for element in soup(["script", "style", "nav", "footer", "header"]):
                element.decompose()
            return soup.get_text(separator="\n", strip=True)
        except ImportError:
            # Without the parser, fall back to the raw markup.
            log.warning("BeautifulSoup not installed, HTML extraction unavailable")
            return content.decode("utf-8", errors="ignore")
        except Exception as e:
            log.error(f"HTML extraction failed: {e}")
            return ""

    async def _extract_docx(self, content: bytes) -> str:
        """Extract paragraph text from DOCX; "" if python-docx is missing."""
        try:
            import io
            from docx import Document
            doc = Document(io.BytesIO(content))
            return "\n\n".join(para.text for para in doc.paragraphs)
        except ImportError:
            log.warning("python-docx not installed, DOCX extraction unavailable")
            return ""
        except Exception as e:
            log.error(f"DOCX extraction failed: {e}")
            return ""

    def _chunk_text(self, text: str) -> list[str]:
        """
        Split text into overlapping chunks.

        Sentences are packed greedily up to ``chunk_size`` characters. A
        sentence longer than ``chunk_size`` (typical for code or CSV) is
        first cut into pieces of at most ``chunk_size``. Because the overlap
        is prepended to the next chunk, a chunk may exceed ``chunk_size`` by
        up to ``chunk_overlap`` characters.
        """
        if len(text) <= self.chunk_size:
            stripped = text.strip()
            return [stripped] if stripped else []

        segments: list[str] = []
        for sentence in self._split_sentences(text):
            segments.extend(self._split_long_segment(sentence))

        chunks = []
        current: list[str] = []
        current_length = 0

        for segment in segments:
            if current and current_length + len(segment) > self.chunk_size:
                chunks.append(" ".join(current))

                # Seed the next chunk with the tail of the one just closed.
                overlap_text = self._get_overlap_text(current)
                current = [overlap_text, segment] if overlap_text else [segment]
                current_length = len(" ".join(current))
            else:
                current.append(segment)
                current_length += len(segment) + 1  # +1 for the joining space

        if current:
            chunks.append(" ".join(current))

        return [chunk.strip() for chunk in chunks if chunk.strip()]

    def _split_long_segment(self, segment: str) -> list[str]:
        """Cut a segment into pieces of at most chunk_size, at whitespace if possible."""
        limit = self.chunk_size
        if limit <= 0 or len(segment) <= limit:
            return [segment]

        pieces = []
        rest = segment
        while len(rest) > limit:
            cut = max(rest.rfind(ch, 0, limit + 1) for ch in (" ", "\n", "\t"))
            if cut <= 0:
                cut = limit  # no whitespace available: hard cut
            piece = rest[:cut].strip()
            if piece:
                pieces.append(piece)
            rest = rest[cut:].strip()
        if rest:
            pieces.append(rest)
        return pieces

    def _split_sentences(self, text: str) -> list[str]:
        """Split on whitespace that follows '.', '!' or '?'."""
        # Simple sentence splitting - can be improved with NLP libraries
        parts = re.split(r'(?<=[.!?])\s+', text)
        return [part.strip() for part in parts if part.strip()]

    def _get_overlap_text(self, chunk_parts: list[str]) -> str:
        """Return up to chunk_overlap trailing characters of the joined parts."""
        # text[-0:] is the WHOLE string, so a zero overlap must be handled
        # explicitly or every chunk would repeat the previous one.
        if not chunk_parts or self.chunk_overlap <= 0:
            return ""

        full_text = " ".join(chunk_parts)
        if len(full_text) <= self.chunk_overlap:
            return full_text

        overlap = full_text[-self.chunk_overlap:]

        # Try to start at a word boundary
        space_idx = overlap.find(" ")
        if space_idx > 0:
            overlap = overlap[space_idx + 1:]

        return overlap
class Retriever:
    """
    Retriever for fetching relevant context from the vector store.

    Handles:
    - Query preprocessing
    - Similarity search
    - Result ranking and de-duplication
    """

    # Punctuation trimmed from word edges before keyword comparison.
    _EDGE_PUNCTUATION = ".,;:!?\"'()[]{}<>"

    def __init__(
        self,
        vector_store: "VectorStore",
        default_top_k: int = 5,
        min_score: float = 0.0,
    ):
        self.vector_store = vector_store
        self.default_top_k = default_top_k
        self.min_score = min_score

    async def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter_metadata: Optional[dict] = None,
    ) -> list[dict[str, Any]]:
        """
        Retrieve relevant chunks for a query.

        Args:
            query: Query string
            top_k: Number of results (uses default if not provided or 0)
            filter_metadata: Optional metadata filters

        Returns:
            Up to ``top_k`` chunks, best first, each with ``score`` and
            ``combined_score``
        """
        top_k = top_k or self.default_top_k

        processed_query = self._preprocess_query(query)

        # Over-fetch so that score filtering and de-duplication can drop
        # entries and still leave top_k results.
        candidates = await self.vector_store.search(
            query=processed_query,
            top_k=top_k * 2,
            filter_metadata=filter_metadata,
        )

        candidates = [hit for hit in candidates if hit["score"] >= self.min_score]
        ranked = self._rank_results(candidates, query)

        return ranked[:top_k]

    def _preprocess_query(self, query: str) -> str:
        """
        Preprocess query for better retrieval.

        - Replace '?' and '!' with spaces
        - Collapse runs of whitespace and trim the ends

        Letter case is left untouched.
        """
        without_marks = query.replace("?", " ").replace("!", " ")
        return " ".join(without_marks.split())

    @classmethod
    def _tokenize(cls, text: str) -> set[str]:
        """Lower-case words with surrounding punctuation removed."""
        words = {word.strip(cls._EDGE_PUNCTUATION) for word in text.lower().split()}
        words.discard("")
        return words

    def _rank_results(
        self,
        results: list[dict[str, Any]],
        query: str,
    ) -> list[dict[str, Any]]:
        """
        Rank results by relevance and drop exact duplicates.

        ``combined_score`` = 0.7 * vector score + 0.3 * fraction of query
        words present in the chunk. The score is written onto each result
        dict and the input list is sorted in place.
        """
        if not results:
            return results

        # Compare on punctuation-free tokens: "python?" must match "Python.".
        query_words = self._tokenize(query)

        for result in results:
            overlap = len(query_words & self._tokenize(result["content"]))
            keyword_score = overlap / max(len(query_words), 1)

            result["combined_score"] = (
                result["score"] * 0.7 +  # Vector similarity
                keyword_score * 0.3      # Keyword matching
            )

        results.sort(key=lambda hit: hit["combined_score"], reverse=True)

        # Keep the best-scoring copy of identical content. Comparing whole
        # chunks (not a prefix) keeps distinct chunks that merely share a
        # common header.
        seen_content = set()
        unique_results = []
        for result in results:
            content = result["content"]
            if content not in seen_content:
                seen_content.add(content)
                unique_results.append(result)

        return unique_results

    async def retrieve_with_context(
        self,
        query: str,
        top_k: int = 5,
        context_window: int = 1,
    ) -> dict[str, Any]:
        """
        Retrieve chunks with surrounding context.

        Args:
            query: Query string
            top_k: Number of main results
            context_window: Number of adjacent chunks to include. Accepted
                for forward compatibility; not used yet.

        Returns:
            Dictionary with ``results``, the joined ``context`` and unique
            ``sources`` in ranking order
        """
        results = await self.retrieve(query=query, top_k=top_k)

        # For now, return basic results
        # In a full implementation, we'd expand to include adjacent chunks
        found_sources = (hit.get("metadata", {}).get("source") for hit in results)
        return {
            "results": results,
            "context": "\n\n".join(hit["content"] for hit in results),
            "sources": list(dict.fromkeys(source for source in found_sources if source)),
        }
class VectorStore:
    """
    Vector store for document embeddings.

    This implementation provides:
    - Simple file-based persistence (one ``store.json`` per directory)
    - In-memory similarity search
    - Document management

    Can be extended to use ChromaDB, FAISS, or other vector databases.
    """

    def __init__(
        self,
        persist_directory: str = "./data/vectors",
        embedding_model: str = "text-embedding-3-small",
    ):
        self.persist_directory = Path(persist_directory)
        self.embedding_model = embedding_model

        # Four parallel lists: index i describes the same chunk in each.
        self._chunks: list[dict[str, Any]] = []
        self._embeddings: list[list[float]] = []
        self._metadata: list[dict[str, Any]] = []
        self._ids: list[str] = []

        self._initialized = False

    async def initialize(self) -> None:
        """Initialize the vector store and load existing data."""
        if self._initialized:
            return

        self.persist_directory.mkdir(parents=True, exist_ok=True)
        await self._load()

        self._initialized = True
        log.info(f"Vector store initialized with {len(self._chunks)} chunks")

    async def close(self) -> None:
        """Save and close the vector store."""
        # A store that never loaded its data holds empty lists; saving it
        # would overwrite whatever is already persisted on disk.
        if self._initialized:
            await self._save()
        log.info("Vector store closed")

    async def _load(self) -> None:
        """Load data from disk; on any failure log it and stay empty."""
        data_file = self.persist_directory / "store.json"

        if not data_file.exists():
            return

        try:
            with open(data_file, "r", encoding="utf-8") as f:
                data = json.load(f)

            chunks = data.get("chunks", [])
            embeddings = data.get("embeddings", [])
            metadata = data.get("metadata", [])
            ids = data.get("ids", [])

            # Misaligned lists would pair chunks with the wrong embeddings.
            if not (len(chunks) == len(embeddings) == len(metadata) == len(ids)):
                raise ValueError("persisted lists have different lengths")

            self._chunks = chunks
            self._embeddings = embeddings
            self._metadata = metadata
            self._ids = ids

            log.info(f"Loaded {len(self._chunks)} chunks from disk")

        except Exception as e:
            log.error(f"Failed to load vector store: {e}")

    async def _save(self) -> None:
        """Save data to disk (best effort: failures are logged, not raised)."""
        data_file = self.persist_directory / "store.json"
        temp_file = data_file.with_suffix(".json.tmp")

        try:
            data = {
                "chunks": self._chunks,
                "embeddings": self._embeddings,
                "metadata": self._metadata,
                "ids": self._ids,
            }

            # Write beside the target and swap it in, so a crash mid-write
            # cannot leave a truncated store.json behind.
            with open(temp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(temp_file, data_file)

            log.info(f"Saved {len(self._chunks)} chunks to disk")

        except Exception as e:
            log.error(f"Failed to save vector store: {e}")

    def _ensure_initialized(self) -> None:
        """Raise RuntimeError unless initialize() has completed."""
        if not self._initialized:
            raise RuntimeError("Vector store not initialized")

    async def add_chunks(
        self,
        chunks: list[str],
        metadatas: Optional[list[dict[str, Any]]] = None,
        ids: Optional[list[str]] = None,
    ) -> None:
        """
        Add chunks to the vector store.

        Args:
            chunks: List of text chunks
            metadatas: Optional list of metadata dicts, one per chunk
            ids: Optional list of chunk IDs, one per chunk (defaults to the
                md5 of each chunk's text)

        Raises:
            RuntimeError: If the store is not initialized
            ValueError: If ``metadatas`` or ``ids`` differ in length from
                ``chunks``
        """
        self._ensure_initialized()

        if not chunks:
            return

        if ids is None:
            ids = [hashlib.md5(chunk.encode()).hexdigest() for chunk in chunks]

        if metadatas is None:
            # One dict per chunk; [{}] * n would alias a single shared dict.
            metadatas = [{} for _ in chunks]

        # zip() would silently drop the surplus and lose chunks.
        if len(metadatas) != len(chunks) or len(ids) != len(chunks):
            raise ValueError(
                f"chunks, metadatas and ids must have the same length "
                f"(got {len(chunks)}, {len(metadatas)}, {len(ids)})"
            )

        embeddings = await self._generate_embeddings(chunks)

        for chunk, embedding, metadata, chunk_id in zip(chunks, embeddings, metadatas, ids):
            self._chunks.append({"id": chunk_id, "content": chunk})
            self._embeddings.append(embedding)
            self._metadata.append(metadata)
            self._ids.append(chunk_id)

        await self._save()

        log.info(f"Added {len(chunks)} chunks to vector store")

    async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
        """
        Generate embeddings for texts.

        Uses a simple hash-based embedding for demonstration: the SHA-256
        digest bytes, scaled to [-1, 1) and repeated to 384 dimensions.
        Only identical texts get similar vectors. In production, use a real
        embedding model via API.
        """
        embeddings = []

        for text in texts:
            digest = hashlib.sha256(text.encode()).digest()
            embeddings.append(
                [(digest[i % len(digest)] - 128) / 128.0 for i in range(384)]
            )

        return embeddings

    async def search(
        self,
        query: str,
        top_k: int = 5,
        filter_metadata: Optional[dict] = None,
    ) -> list[dict[str, Any]]:
        """
        Search for similar chunks.

        Args:
            query: Query string
            top_k: Number of results to return; values <= 0 return nothing
            filter_metadata: Optional metadata filters; a chunk matches when
                every key equals the given value

        Returns:
            List of matching chunks with scores, best first

        Raises:
            RuntimeError: If the store is not initialized
        """
        self._ensure_initialized()

        if not self._chunks or top_k <= 0:
            return []

        query_embedding = (await self._generate_embeddings([query]))[0]

        results = []
        for chunk, embedding, metadata in zip(self._chunks, self._embeddings, self._metadata):
            if filter_metadata and any(
                metadata.get(key) != value for key, value in filter_metadata.items()
            ):
                continue

            results.append({
                "id": chunk["id"],
                "content": chunk["content"],
                "metadata": metadata,
                "score": self._cosine_similarity(query_embedding, embedding),
            })

        results.sort(key=lambda hit: hit["score"], reverse=True)
        return results[:top_k]

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Cosine similarity; 0.0 for mismatched lengths or a zero vector."""
        if len(a) != len(b):
            return 0.0

        dot_product = sum(x * y for x, y in zip(a, b))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot_product / (norm_a * norm_b)

    async def list_documents(self) -> list[dict[str, Any]]:
        """List unique documents (by ``document_id``) with their chunk counts."""
        self._ensure_initialized()

        documents: dict[Any, dict[str, Any]] = {}
        for metadata in self._metadata:
            doc_id = metadata.get("document_id")
            if not doc_id:
                continue
            # "source" comes from the first chunk seen for the document.
            entry = documents.setdefault(doc_id, {
                "id": doc_id,
                "source": metadata.get("source", "unknown"),
                "chunk_count": 0,
            })
            entry["chunk_count"] += 1

        return list(documents.values())

    async def delete_document(self, document_id: str) -> None:
        """Delete all chunks for a document; unknown ids are a no-op."""
        self._ensure_initialized()

        keep = [metadata.get("document_id") != document_id for metadata in self._metadata]
        removed = len(keep) - sum(keep)

        if removed:
            # Rebuild the parallel lists in one pass to keep them aligned.
            self._chunks = [item for item, kept in zip(self._chunks, keep) if kept]
            self._embeddings = [item for item, kept in zip(self._embeddings, keep) if kept]
            self._metadata = [item for item, kept in zip(self._metadata, keep) if kept]
            self._ids = [item for item, kept in zip(self._ids, keep) if kept]

        await self._save()

        log.info(f"Deleted document {document_id} ({removed} chunks)")
class ToolManager:
    """
    Manages tool registration and execution.

    Provides:
    - Tool registration
    - Tool schema generation
    - Tool execution with error handling
    """

    def __init__(self):
        self._tools: dict[str, Callable] = {}
        self._schemas: dict[str, dict] = {}

        # Register built-in tools
        self._register_builtin_tools()

    def _register_builtin_tools(self) -> None:
        """Register built-in tools."""
        self.register_tool(
            name="website_downloader",
            function=website_downloader,
            schema=get_website_downloader_schema(),
        )

        log.info(f"Registered {len(self._tools)} built-in tools")

    def register_tool(
        self,
        name: str,
        function: Callable,
        schema: dict,
    ) -> None:
        """
        Register a new tool, replacing any tool already using the name.

        Args:
            name: Tool name
            function: Tool function, called with keyword arguments
            schema: OpenAI function schema
        """
        self._tools[name] = function
        self._schemas[name] = schema
        log.info(f"Registered tool: {name}")

    def get_tool_schema(self, name: str) -> Optional[dict]:
        """Get the schema for a tool, or None if it is not registered."""
        return self._schemas.get(name)

    def get_all_schemas(self) -> list[dict]:
        """Get schemas for all registered tools."""
        return list(self._schemas.values())

    def list_tools(self) -> list[str]:
        """List all registered tool names."""
        return list(self._tools)

    def execute_tool(
        self,
        name: str,
        arguments: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Execute a tool with the given arguments.

        Never raises: every failure is reported as
        ``{"success": False, "error": ...}``.

        Args:
            name: Tool name
            arguments: Tool arguments, passed as keyword arguments

        Returns:
            The tool's own result, or an error dictionary
        """
        # Local import: the module's import block is outside this definition.
        import inspect

        tool = self._tools.get(name)
        if tool is None:
            return {
                "success": False,
                "error": f"Unknown tool: {name}",
            }

        if not isinstance(arguments, dict):
            return {
                "success": False,
                "error": f"Invalid arguments: expected an object, got {type(arguments).__name__}",
            }

        # Check the call shape up front, so a TypeError raised inside the
        # tool is not misreported as the caller passing bad arguments.
        try:
            inspect.signature(tool).bind(**arguments)
        except TypeError as e:
            log.error(f"Invalid arguments for tool {name}: {e}")
            return {
                "success": False,
                "error": f"Invalid arguments: {str(e)}",
            }
        except ValueError:
            # Some callables expose no signature; let the call itself decide.
            pass

        try:
            log.info(f"Executing tool: {name} with args: {arguments}")
            return tool(**arguments)
        except Exception as e:
            log.exception(f"Tool execution failed: {name}")
            return {
                "success": False,
                "error": str(e),
            }

    def execute_tool_from_json(
        self,
        name: str,
        arguments_json: str,
    ) -> dict[str, Any]:
        """
        Execute a tool with JSON arguments.

        Args:
            name: Tool name
            arguments_json: JSON object of arguments; an empty string means
                "no arguments"

        Returns:
            Tool result, or an error dictionary for malformed JSON
        """
        if arguments_json and arguments_json.strip():
            try:
                arguments = json.loads(arguments_json)
            except json.JSONDecodeError as e:
                return {
                    "success": False,
                    "error": f"Invalid JSON arguments: {str(e)}",
                }
        else:
            arguments = {}

        return self.execute_tool(name, arguments)


# Global tool manager instance
_tool_manager: Optional[ToolManager] = None


def get_tool_manager() -> ToolManager:
    """Get or create the global tool manager instance."""
    global _tool_manager

    if _tool_manager is None:
        _tool_manager = ToolManager()

    return _tool_manager