Features:
- AI chat interface with file uploads
- Admin panel for managing OpenAI/Ollama endpoints
- User authentication with JWT
- SQLite database backend
- SvelteKit frontend with dark theme
232 lines · 6.7 KiB · Python
import os
|
|
import uuid
|
|
from typing import List, Optional
|
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Form
|
|
from sqlalchemy.orm import Session
|
|
import httpx
|
|
from openai import AsyncOpenAI
|
|
from app.core.database import get_db
|
|
from app.core.auth import get_current_user
|
|
from app.core.config import settings
|
|
from app.models.models import User, AIEndpoint, ChatMessage, UploadedFile
|
|
from app.schemas.schemas import (
|
|
ChatRequest,
|
|
ChatResponse,
|
|
ChatMessageResponse,
|
|
UploadedFileResponse
|
|
)
|
|
|
|
# All chat-related routes (uploads, messages, history) hang off this router.
router = APIRouter(prefix="/chat", tags=["Chat"])

# Create the configured upload directory up front so every handler can
# assume it exists when writing files.
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
@router.post("/upload", response_model=UploadedFileResponse)
async def upload_file(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Accept a single file upload, store it on disk and record it in the DB.

    The file is stored under a UUID-based name (the original name is kept
    only in the DB record), so stored names can never collide or escape
    the upload directory.

    Raises:
        HTTPException(400): if the file exceeds ``settings.MAX_FILE_SIZE``.
    """
    # Read the whole body up front so the size limit can be enforced.
    # NOTE(review): this buffers the entire upload in memory; fine for
    # modest MAX_FILE_SIZE values, revisit if large uploads are expected.
    contents = await file.read()
    if len(contents) > settings.MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"File size exceeds maximum allowed size of {settings.MAX_FILE_SIZE} bytes"
        )

    # file.filename may be None (multipart clients are not required to send
    # one); the original code passed it straight to splitext, which raises
    # TypeError on None and surfaced as a 500. Fall back to an empty name.
    # basename() additionally strips any client-supplied directory parts
    # before the extension is taken.
    file_extension = os.path.splitext(os.path.basename(file.filename or ""))[1]
    unique_filename = f"{uuid.uuid4()}{file_extension}"
    file_path = os.path.join(settings.UPLOAD_DIR, unique_filename)

    # Write to disk first; the DB row below references this path.
    with open(file_path, "wb") as f:
        f.write(contents)

    # Record the upload so it can be listed/deleted later.
    uploaded_file = UploadedFile(
        user_id=current_user.id,
        filename=unique_filename,
        original_filename=file.filename,
        file_path=file_path,
        file_size=len(contents),
        file_type=file.content_type
    )
    db.add(uploaded_file)
    db.commit()
    db.refresh(uploaded_file)

    return uploaded_file
|
|
|
|
|
|
@router.get("/files", response_model=List[UploadedFileResponse])
def list_files(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return every file the authenticated user has uploaded."""
    return (
        db.query(UploadedFile)
        .filter(UploadedFile.user_id == current_user.id)
        .all()
    )
@router.delete("/files/{file_id}")
def delete_file(
    file_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete one of the current user's uploaded files (disk file + DB row).

    Ownership is enforced in the lookup query itself, so a user can never
    delete another user's file.

    Raises:
        HTTPException(404): if the file does not exist or is not owned by
            the current user.
    """
    uploaded_file = db.query(UploadedFile).filter(
        UploadedFile.id == file_id,
        UploadedFile.user_id == current_user.id
    ).first()

    if not uploaded_file:
        raise HTTPException(status_code=404, detail="File not found")

    # Remove the file from disk. EAFP instead of the original
    # exists()+remove() pair, which could race with a concurrent deleter
    # and raise FileNotFoundError between the check and the remove.
    # A file that is already gone is fine either way.
    try:
        os.remove(uploaded_file.file_path)
    except FileNotFoundError:
        pass

    db.delete(uploaded_file)
    db.commit()

    return {"message": "File deleted successfully"}
@router.post("/message", response_model=ChatResponse)
async def send_message(
    request: ChatRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Send a chat message to an AI endpoint and persist both sides.

    Resolves the endpoint (explicit id, or the configured default when
    none is given), stores the user message, calls the backend, stores
    the assistant reply, and returns it.

    Raises:
        HTTPException(400): no matching active endpoint.
        HTTPException(500): the backend call failed.
    """
    # Resolve which endpoint to talk to; only active endpoints qualify.
    active = db.query(AIEndpoint).filter(AIEndpoint.is_active == True)
    if request.endpoint_id:
        endpoint = active.filter(AIEndpoint.id == request.endpoint_id).first()
    else:
        # No explicit choice — fall back to the configured default.
        endpoint = active.filter(AIEndpoint.is_default == True).first()

    if not endpoint:
        raise HTTPException(
            status_code=400,
            detail="No active AI endpoint available"
        )

    # Persist the user's message before calling out, so it survives even
    # if the backend call below fails.
    db.add(ChatMessage(
        user_id=current_user.id,
        role="user",
        content=request.message,
        endpoint_id=endpoint.id
    ))
    db.commit()

    # Build the message list for the API call: prior history (when the
    # client supplied any) followed by the new user turn.
    messages = [
        {"role": msg.role, "content": msg.content}
        for msg in (request.conversation_history or [])
    ]
    messages.append({"role": "user", "content": request.message})

    try:
        response_content = await call_ai_endpoint(endpoint, messages)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error calling AI endpoint: {str(e)}"
        )

    # Persist the assistant's reply alongside the user's message.
    db.add(ChatMessage(
        user_id=current_user.id,
        role="assistant",
        content=response_content,
        endpoint_id=endpoint.id
    ))
    db.commit()

    return ChatResponse(
        response=response_content,
        endpoint_id=endpoint.id,
        model=endpoint.model_name
    )
async def call_ai_endpoint(endpoint: AIEndpoint, messages: List[dict]) -> str:
    """Dispatch a chat request to the backend matching ``endpoint.endpoint_type``.

    Raises:
        ValueError: for an endpoint type with no registered handler.
    """
    # Table-driven dispatch instead of an if/elif chain.
    handlers = {
        "openai": call_openai_compatible,
        "ollama": call_ollama,
    }
    handler = handlers.get(endpoint.endpoint_type)
    if handler is None:
        raise ValueError(f"Unknown endpoint type: {endpoint.endpoint_type}")
    return await handler(endpoint, messages)
async def call_openai_compatible(endpoint: AIEndpoint, messages: List[dict]) -> str:
    """Call an OpenAI-compatible chat-completions API and return the reply text.

    Works with OpenAI itself as well as local servers speaking the same
    protocol; such servers typically ignore the key, hence the placeholder
    when none is configured.
    """
    api_key = endpoint.api_key or "not-needed"
    client = AsyncOpenAI(api_key=api_key, base_url=endpoint.base_url)

    completion = await client.chat.completions.create(
        model=endpoint.model_name,
        messages=messages
    )

    return completion.choices[0].message.content
async def call_ollama(endpoint: AIEndpoint, messages: List[dict]) -> str:
    """Call Ollama's /api/chat endpoint (non-streaming) and return the reply text."""
    url = f"{endpoint.base_url.rstrip('/')}/api/chat"
    payload = {
        "model": endpoint.model_name,
        "messages": messages,
        # Ask for one complete response rather than a chunk stream.
        "stream": False
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(url, json=payload, timeout=60.0)
        response.raise_for_status()
        data = response.json()
    return data["message"]["content"]
@router.get("/history", response_model=List[ChatMessageResponse])
def get_history(
    limit: int = 50,
    endpoint_id: Optional[int] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the user's most recent chat messages, oldest first.

    Fetches the newest ``limit`` rows (optionally restricted to one
    endpoint) and reverses them so the client receives chronological order.
    """
    query = db.query(ChatMessage).filter(ChatMessage.user_id == current_user.id)

    if endpoint_id:
        query = query.filter(ChatMessage.endpoint_id == endpoint_id)

    newest_first = query.order_by(ChatMessage.created_at.desc()).limit(limit).all()
    return newest_first[::-1]
@router.delete("/history")
def clear_history(
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Delete every chat message belonging to the current user."""
    user_messages = db.query(ChatMessage).filter(
        ChatMessage.user_id == current_user.id
    )
    user_messages.delete()
    db.commit()
    return {"message": "Chat history cleared successfully"}