# test/moxie/api/routes.py
"""
OpenAI-Compatible API Routes
Implements /v1/chat/completions, /v1/models, and /v1/embeddings
"""

import json
import time
import uuid
from typing import Optional, List, AsyncGenerator

from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field

from config import settings
from core.orchestrator import Orchestrator
from rag.store import RAGStore

router = APIRouter()
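
# The routes below are registered without the /v1 prefix named in the module
# docstring, so the application is expected to mount this router under /v1.
# A minimal sketch of that wiring (the state attribute values are assumptions
# about the rest of the app, not confirmed by this file):
#
#     from fastapi import FastAPI
#     app = FastAPI()
#     app.include_router(router, prefix="/v1")
#     app.state.orchestrator = Orchestrator(...)  # used by /chat/completions
#     app.state.rag_store = RAGStore(...)         # used by /embeddings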
# ============================================================================
# Request/Response Models (OpenAI Compatible)
# ============================================================================
class ChatMessage(BaseModel):
    """OpenAI chat message format."""
    role: str
    content: Optional[str] = None
    name: Optional[str] = None
    tool_calls: Optional[List[dict]] = None
    tool_call_id: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    """OpenAI chat completion request format."""
    model: str = "moxie"
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[List[dict]] = None
    tool_choice: Optional[str] = "auto"
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    stop: Optional[List[str]] = None
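
# An illustrative request body this model accepts (values are examples only).
# Note that only model, temperature, and max_tokens are forwarded to the
# orchestrator below; top_p, penalties, stop, and tools are parsed but unused:
#
#     {
#         "model": "moxie",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "temperature": 0.7,
#         "stream": false
#     }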


class ChatCompletionChoice(BaseModel):
    """OpenAI chat completion choice."""
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionUsage(BaseModel):
    """Token usage information."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    """OpenAI chat completion response."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]
    usage: ChatCompletionUsage


class ModelInfo(BaseModel):
    """OpenAI model info format."""
    id: str
    object: str = "model"
    created: int
    owned_by: str = "moxie"


class ModelsResponse(BaseModel):
    """OpenAI models list response."""
    object: str = "list"
    data: List[ModelInfo]


class EmbeddingRequest(BaseModel):
    """OpenAI embedding request format."""
    model: str = "moxie-embed"
    input: str | List[str]
    encoding_format: Optional[str] = "float"


class EmbeddingData(BaseModel):
    """Single embedding data."""
    object: str = "embedding"
    embedding: List[float]
    index: int


class EmbeddingResponse(BaseModel):
    """OpenAI embedding response."""
    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: dict

# ============================================================================
# Endpoints
# ============================================================================
@router.get("/models", response_model=ModelsResponse)
async def list_models():
"""List available models (OpenAI compatible)."""
models = [
ModelInfo(id="moxie", created=int(time.time()), owned_by="moxie"),
ModelInfo(id="moxie-embed", created=int(time.time()), owned_by="moxie"),
]
return ModelsResponse(data=models)
@router.get("/models/{model_id}")
async def get_model(model_id: str):
"""Get info about a specific model."""
return ModelInfo(
id=model_id,
created=int(time.time()),
owned_by="moxie"
)
@router.post("/chat/completions")
async def chat_completions(
request: ChatCompletionRequest,
req: Request
):
"""Handle chat completions (OpenAI compatible)."""
orchestrator: Orchestrator = req.app.state.orchestrator
# Convert messages to dict format
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
if request.stream:
return StreamingResponse(
stream_chat_completion(orchestrator, messages, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
}
)
else:
return await non_stream_chat_completion(orchestrator, messages, request)
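
# Because this endpoint speaks the OpenAI wire format, stock OpenAI clients
# can call it unmodified. A minimal sketch with the openai Python SDK (the
# host, port, and /v1 mount point are assumptions about the deployment):
#
#     from openai import OpenAI
#     client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#     resp = client.chat.completions.create(
#         model="moxie",
#         messages=[{"role": "user", "content": "Hello"}],
#     )
#     print(resp.choices[0].message.content)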


async def non_stream_chat_completion(
    orchestrator: Orchestrator,
    messages: List[dict],
    request: ChatCompletionRequest,
) -> ChatCompletionResponse:
    """Generate a non-streaming chat completion."""
    result = await orchestrator.process(
        messages=messages,
        model=request.model,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
    )
    return ChatCompletionResponse(
        id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
        created=int(time.time()),
        model=request.model,
        choices=[
            ChatCompletionChoice(
                index=0,
                message=ChatMessage(
                    role="assistant",
                    content=result["content"],
                ),
                finish_reason="stop",
            )
        ],
        usage=ChatCompletionUsage(
            prompt_tokens=result.get("prompt_tokens", 0),
            completion_tokens=result.get("completion_tokens", 0),
            total_tokens=result.get("total_tokens", 0),
        ),
    )


async def stream_chat_completion(
    orchestrator: Orchestrator,
    messages: List[dict],
    request: ChatCompletionRequest,
) -> AsyncGenerator[str, None]:
    """Generate a streaming chat completion."""
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    async for chunk in orchestrator.process_stream(
        messages=messages,
        model=request.model,
        temperature=request.temperature,
        max_tokens=request.max_tokens,
    ):
        # Format as SSE
        data = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "delta": chunk,
                    "finish_reason": None,
                }
            ],
        }
        yield f"data: {json.dumps(data)}\n\n"

    # Send final chunk
    final_data = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": request.model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "finish_reason": "stop",
            }
        ],
    }
    yield f"data: {json.dumps(final_data)}\n\n"
    yield "data: [DONE]\n\n"
@router.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest, req: Request):
"""Generate embeddings using Ollama (OpenAI compatible)."""
rag_store: RAGStore = req.app.state.rag_store
# Handle single string or list
texts = request.input if isinstance(request.input, list) else [request.input]
embeddings = []
for i, text in enumerate(texts):
embedding = await rag_store.generate_embedding(text)
embeddings.append(
EmbeddingData(
object="embedding",
embedding=embedding,
index=i
)
)
return EmbeddingResponse(
object="list",
data=embeddings,
model=request.model,
usage={
"prompt_tokens": sum(len(t.split()) for t in texts),
"total_tokens": sum(len(t.split()) for t in texts)
}
)
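
# An illustrative call (the URL again assumes the /v1 mount noted at the top):
#
#     curl http://localhost:8000/v1/embeddings \
#         -H "Content-Type: application/json" \
#         -d '{"model": "moxie-embed", "input": ["hello world", "goodbye"]}'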