"""
OpenAI-Compatible API Routes

Implements /v1/chat/completions, /v1/models, and /v1/embeddings
"""
|
|
import json
import time
import uuid
from typing import Optional, List, AsyncGenerator

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel, Field

from config import settings
from core.orchestrator import Orchestrator
from rag.store import RAGStore
|
|
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
# ============================================================================
# Request/Response Models (OpenAI Compatible)
# ============================================================================
|
|
|
|
class ChatMessage(BaseModel):
|
|
"""OpenAI chat message format."""
|
|
role: str
|
|
content: Optional[str] = None
|
|
name: Optional[str] = None
|
|
tool_calls: Optional[List[dict]] = None
|
|
tool_call_id: Optional[str] = None
|
|
|
|
|
|
class ChatCompletionRequest(BaseModel):
|
|
"""OpenAI chat completion request format."""
|
|
model: str = "moxie"
|
|
messages: List[ChatMessage]
|
|
temperature: Optional[float] = 0.7
|
|
top_p: Optional[float] = 1.0
|
|
max_tokens: Optional[int] = None
|
|
stream: Optional[bool] = False
|
|
tools: Optional[List[dict]] = None
|
|
tool_choice: Optional[str] = "auto"
|
|
frequency_penalty: Optional[float] = 0.0
|
|
presence_penalty: Optional[float] = 0.0
|
|
stop: Optional[List[str]] = None
|
|
|
|
|
|
class ChatCompletionChoice(BaseModel):
|
|
"""OpenAI chat completion choice."""
|
|
index: int
|
|
message: ChatMessage
|
|
finish_reason: str
|
|
|
|
|
|
class ChatCompletionUsage(BaseModel):
|
|
"""Token usage information."""
|
|
prompt_tokens: int
|
|
completion_tokens: int
|
|
total_tokens: int
|
|
|
|
|
|
class ChatCompletionResponse(BaseModel):
|
|
"""OpenAI chat completion response."""
|
|
id: str
|
|
object: str = "chat.completion"
|
|
created: int
|
|
model: str
|
|
choices: List[ChatCompletionChoice]
|
|
usage: ChatCompletionUsage
|
|
|
|
|
|
class ModelInfo(BaseModel):
|
|
"""OpenAI model info format."""
|
|
id: str
|
|
object: str = "model"
|
|
created: int
|
|
owned_by: str = "moxie"
|
|
|
|
|
|
class ModelsResponse(BaseModel):
|
|
"""OpenAI models list response."""
|
|
object: str = "list"
|
|
data: List[ModelInfo]
|
|
|
|
|
|
class EmbeddingRequest(BaseModel):
|
|
"""OpenAI embedding request format."""
|
|
model: str = "moxie-embed"
|
|
input: str | List[str]
|
|
encoding_format: Optional[str] = "float"
|
|
|
|
|
|
class EmbeddingData(BaseModel):
|
|
"""Single embedding data."""
|
|
object: str = "embedding"
|
|
embedding: List[float]
|
|
index: int
|
|
|
|
|
|
class EmbeddingResponse(BaseModel):
|
|
"""OpenAI embedding response."""
|
|
object: str = "list"
|
|
data: List[EmbeddingData]
|
|
model: str
|
|
usage: dict
|
|
|
|
|
|
# ============================================================================
# Endpoints
# ============================================================================
|
|
|
|
@router.get("/models", response_model=ModelsResponse)
|
|
async def list_models():
|
|
"""List available models (OpenAI compatible)."""
|
|
models = [
|
|
ModelInfo(id="moxie", created=int(time.time()), owned_by="moxie"),
|
|
ModelInfo(id="moxie-embed", created=int(time.time()), owned_by="moxie"),
|
|
]
|
|
return ModelsResponse(data=models)
|
|
|
|
|
|
@router.get("/models/{model_id}")
|
|
async def get_model(model_id: str):
|
|
"""Get info about a specific model."""
|
|
return ModelInfo(
|
|
id=model_id,
|
|
created=int(time.time()),
|
|
owned_by="moxie"
|
|
)
|
|
|
|
|
|
@router.post("/chat/completions")
|
|
async def chat_completions(
|
|
request: ChatCompletionRequest,
|
|
req: Request
|
|
):
|
|
"""Handle chat completions (OpenAI compatible)."""
|
|
orchestrator: Orchestrator = req.app.state.orchestrator
|
|
|
|
# Convert messages to dict format
|
|
messages = [msg.model_dump(exclude_none=True) for msg in request.messages]
|
|
|
|
if request.stream:
|
|
return StreamingResponse(
|
|
stream_chat_completion(orchestrator, messages, request),
|
|
media_type="text/event-stream",
|
|
headers={
|
|
"Cache-Control": "no-cache",
|
|
"Connection": "keep-alive",
|
|
"X-Accel-Buffering": "no",
|
|
}
|
|
)
|
|
else:
|
|
return await non_stream_chat_completion(orchestrator, messages, request)
|
|
|
|
|
|
async def non_stream_chat_completion(
|
|
orchestrator: Orchestrator,
|
|
messages: List[dict],
|
|
request: ChatCompletionRequest
|
|
) -> ChatCompletionResponse:
|
|
"""Generate a non-streaming chat completion."""
|
|
result = await orchestrator.process(
|
|
messages=messages,
|
|
model=request.model,
|
|
temperature=request.temperature,
|
|
max_tokens=request.max_tokens,
|
|
)
|
|
|
|
return ChatCompletionResponse(
|
|
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
|
created=int(time.time()),
|
|
model=request.model,
|
|
choices=[
|
|
ChatCompletionChoice(
|
|
index=0,
|
|
message=ChatMessage(
|
|
role="assistant",
|
|
content=result["content"]
|
|
),
|
|
finish_reason="stop"
|
|
)
|
|
],
|
|
usage=ChatCompletionUsage(
|
|
prompt_tokens=result.get("prompt_tokens", 0),
|
|
completion_tokens=result.get("completion_tokens", 0),
|
|
total_tokens=result.get("total_tokens", 0)
|
|
)
|
|
)
|
|
|
|
|
|
async def stream_chat_completion(
|
|
orchestrator: Orchestrator,
|
|
messages: List[dict],
|
|
request: ChatCompletionRequest
|
|
) -> AsyncGenerator[str, None]:
|
|
"""Generate a streaming chat completion."""
|
|
completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
|
|
|
|
async for chunk in orchestrator.process_stream(
|
|
messages=messages,
|
|
model=request.model,
|
|
temperature=request.temperature,
|
|
max_tokens=request.max_tokens,
|
|
):
|
|
# Format as SSE
|
|
data = {
|
|
"id": completion_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": int(time.time()),
|
|
"model": request.model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": chunk,
|
|
"finish_reason": None
|
|
}
|
|
]
|
|
}
|
|
yield f"data: {json.dumps(data)}\n\n"
|
|
|
|
# Send final chunk
|
|
final_data = {
|
|
"id": completion_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": int(time.time()),
|
|
"model": request.model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {},
|
|
"finish_reason": "stop"
|
|
}
|
|
]
|
|
}
|
|
yield f"data: {json.dumps(final_data)}\n\n"
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
|
@router.post("/embeddings", response_model=EmbeddingResponse)
|
|
async def create_embeddings(request: EmbeddingRequest, req: Request):
|
|
"""Generate embeddings using Ollama (OpenAI compatible)."""
|
|
rag_store: RAGStore = req.app.state.rag_store
|
|
|
|
# Handle single string or list
|
|
texts = request.input if isinstance(request.input, list) else [request.input]
|
|
|
|
embeddings = []
|
|
for i, text in enumerate(texts):
|
|
embedding = await rag_store.generate_embedding(text)
|
|
embeddings.append(
|
|
EmbeddingData(
|
|
object="embedding",
|
|
embedding=embedding,
|
|
index=i
|
|
)
|
|
)
|
|
|
|
return EmbeddingResponse(
|
|
object="list",
|
|
data=embeddings,
|
|
model=request.model,
|
|
usage={
|
|
"prompt_tokens": sum(len(t.split()) for t in texts),
|
|
"total_tokens": sum(len(t.split()) for t in texts)
|
|
}
|
|
)
|