"""
OpenAI-Compatible API Routes
Implements /v1/chat/completions, /v1/models, and /v1/embeddings
"""
import base64
import json
import struct
import time
import uuid
from typing import Optional, List, AsyncGenerator

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from loguru import logger

from config import settings
from core.orchestrator import Orchestrator
from rag.store import RAGStore

router = APIRouter()


# ============================================================================
# Request/Response Models (OpenAI Compatible)
# ============================================================================

class ChatMessage(BaseModel):
    """OpenAI chat message format."""
    role: str
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    tool_call_id: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request format.

    Only ``model``, ``messages``, ``temperature``, ``max_tokens`` and
    ``stream`` are used by the handlers in this module; the other fields are
    accepted so standard OpenAI clients pass validation, but are not acted on.
    """
    model: str = "moxie"
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[List[dict]] = None
    # OpenAI allows a mode string ("auto", "none", "required") or an object
    # naming a specific function; accept both so such clients don't get a 422.
    tool_choice: Optional[str | dict] = "auto"
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    # OpenAI accepts either a single stop string or a list of them.
    stop: Optional[str | List[str]] = None


class ChatCompletionChoice(BaseModel):
    """OpenAI chat completion choice."""
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionUsage(BaseModel):
    """Token usage information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage


class ModelInfo(BaseModel):
    """OpenAI model info format."""
    id: str
    object: str = "model"
    created: int
    owned_by: str = "moxie"


class ModelsResponse(BaseModel):
    """OpenAI models list response."""
    object: str = "list"
    data: List[ModelInfo]


class EmbeddingRequest(BaseModel):
    """OpenAI embedding request format."""
    model: str = "moxie-embed"
    input: str | List[str]
    # "float" (default) returns plain float lists; "base64" returns packed
    # little-endian float32 bytes encoded as base64.
    encoding_format: Optional[str] = "float"


class EmbeddingData(BaseModel):
    """Single embedding data."""
    object: str = "embedding"
    # A list of floats, or a base64 string when encoding_format="base64".
    embedding: List[float] | str
    index: int


class EmbeddingResponse(BaseModel):
    """OpenAI embedding response."""
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: dict


# ============================================================================
# Helpers
# ============================================================================

def _new_completion_id() -> str:
    """Return a fresh chat completion id of the form ``chatcmpl-<hex>``.

    Uses the full 128-bit uuid4 hex so ids stay unique across many requests.
    """
    return f"chatcmpl-{uuid.uuid4().hex}"


def _sse(payload: dict) -> str:
    """Serialize *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _stream_chunk(
    completion_id: str,
    created: int,
    model: str,
    delta: dict,
    finish_reason: Optional[str],
) -> dict:
    """Build one ``chat.completion.chunk`` object for the streaming response."""
    return {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }
        ],
    }


def _encode_embedding(vector: List[float], encoding_format: Optional[str]) -> List[float] | str:
    """Encode one embedding vector in the representation the client requested.

    For ``"base64"`` the vector is packed as little-endian float32 and
    base64-encoded (the layout OpenAI SDKs decode with a float32 buffer view);
    any other format returns *vector* unchanged.
    """
    if encoding_format == "base64":
        packed = struct.pack(f"<{len(vector)}f", *vector)
        return base64.b64encode(packed).decode("ascii")
    return vector


# ============================================================================
# Endpoints
# ============================================================================

@router.get("/models", response_model=ModelsResponse)
async def list_models():
    """List available models (OpenAI compatible)."""
    now = int(time.time())
    models = [
        ModelInfo(id="moxie", created=now, owned_by="moxie"),
        ModelInfo(id="moxie-embed", created=now, owned_by="moxie"),
    ]
    return ModelsResponse(data=models)


@router.get("/models/{model_id}")
async def get_model(model_id: str):
    """Get info about a specific model.

    Any ``model_id`` is echoed back as a model owned by "moxie"; no lookup
    against a known-model list is performed here.
    """
    return ModelInfo(
        id=model_id,
        created=int(time.time()),
        owned_by="moxie",
    )


@router.post("/chat/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    req: Request,
):
    """Handle chat completions (OpenAI compatible).

    Dispatches to an SSE stream when ``request.stream`` is true, otherwise
    returns a single ``ChatCompletionResponse``.
    """
    orchestrator: Orchestrator = req.app.state.orchestrator

    # Convert messages to plain dicts, dropping unset optional fields.
    messages = [msg.model_dump(exclude_none=True) for msg in request.messages]

    if request.stream:
        return StreamingResponse(
            stream_chat_completion(orchestrator, messages, request),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disable proxy (e.g. nginx) buffering so chunks flush immediately.
                "X-Accel-Buffering": "no",
            },
        )
    return await non_stream_chat_completion(orchestrator, messages, request)


async def non_stream_chat_completion(
    orchestrator: Orchestrator,
    messages: List[dict],
    request: ChatCompletionRequest,
) -> ChatCompletionResponse:
    """Generate a non-streaming chat completion.

    Expects ``orchestrator.process`` to return a mapping with a ``"content"``
    key and optional ``"prompt_tokens"``, ``"completion_tokens"`` and
    ``"total_tokens"`` counts.
    """
    result = await orchestrator.process(
        messages=messages,
        model=request.model,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
    )

    prompt_tokens = result.get("prompt_tokens", 0)
    completion_tokens = result.get("completion_tokens", 0)
    total_tokens = result.get("total_tokens")
    if total_tokens is None:
        # Don't report 0 total when the component counts are known.
        total_tokens = prompt_tokens + completion_tokens

    return ChatCompletionResponse(
        id=_new_completion_id(),
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(
                    role="assistant",
                    content=result["content"],
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        ),
    )


async def stream_chat_completion(
    orchestrator: Orchestrator,
    messages: List[dict],
    request: ChatCompletionRequest,
) -> AsyncGenerator[str, None]:
    """Generate a streaming chat completion as SSE frames.

    Each item yielded by ``orchestrator.process_stream`` is sent verbatim as
    the chunk's ``delta``, followed by a final empty-delta chunk with
    ``finish_reason="stop"`` and the ``[DONE]`` sentinel. If the orchestrator
    raises, the error is logged and an in-band ``{"error": ...}`` event is
    sent before ``[DONE]``, since the HTTP status has already been committed.
    """
    completion_id = _new_completion_id()
    # All chunks of one completion share the same id and created timestamp.
    created = int(time.time())

    try:
        async for chunk in orchestrator.process_stream(
            messages=messages,
            model=request.model,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        ):
            yield _sse(_stream_chunk(completion_id, created, request.model, chunk, None))
    except Exception:
        # Client disconnects surface as CancelledError/GeneratorExit
        # (BaseException) and are intentionally not caught here.
        logger.exception("Streaming chat completion {} failed", completion_id)
        yield _sse(
            {
                "error": {
                    "message": "Internal error while generating the completion.",
                    "type": "server_error",
                }
            }
        )
        yield "data: [DONE]\n\n"
        return

    # Final chunk signals normal termination.
    yield _sse(_stream_chunk(completion_id, created, request.model, {}, "stop"))
    yield "data: [DONE]\n\n"


@router.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest, req: Request):
    """Generate embeddings via the app's RAGStore (OpenAI compatible).

    Accepts a single string or a list of strings; embeddings are generated
    one at a time in input order and encoded per ``request.encoding_format``.
    """
    rag_store: RAGStore = req.app.state.rag_store

    # Normalize a single string into a one-element batch.
    texts = request.input if isinstance(request.input, list) else [request.input]

    data = []
    for index, text in enumerate(texts):
        vector = await rag_store.generate_embedding(text)
        data.append(
            EmbeddingData(
                object="embedding",
                embedding=_encode_embedding(vector, request.encoding_format),
                index=index,
            )
        )

    # Whitespace word count: a rough stand-in for real tokenizer counts.
    approx_tokens = sum(len(t.split()) for t in texts)
    return EmbeddingResponse(
        object="list",
        data=data,
        model=request.model,
        usage={
            "prompt_tokens": approx_tokens,
            "total_tokens": approx_tokens,
        },
    )