121 lines
4.2 KiB
Python
Executable File
121 lines
4.2 KiB
Python
Executable File
"""
Gemini Tool

Calls the Google Gemini API for "deep reasoning" tasks.

The tool is registered under the generic name "deep_reasoning"; its name and
description do not mention Gemini.
"""
|
|
from typing import Dict, Any, Optional
|
|
import httpx
|
|
from loguru import logger
|
|
|
|
from config import load_config_from_db, settings
|
|
from tools.base import BaseTool, ToolResult
|
|
|
|
|
|
class GeminiTool(BaseTool):
    """Call the Gemini API for complex reasoning tasks.

    Registered under the generic tool name ``deep_reasoning``. The API key and
    model name are re-read from the DB-backed config on every call, so changes
    made in the admin panel take effect without a restart.
    """

    # Model used when the config has no (or an empty) ``gemini_model`` entry.
    DEFAULT_MODEL = "gemini-1.5-flash"
    # Base URL of the Generative Language REST API.
    API_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
    # Timeout (seconds) passed to httpx for the generateContent request.
    REQUEST_TIMEOUT = 60.0
    # Sampling settings sent as ``generationConfig``.
    TEMPERATURE = 0.7
    MAX_OUTPUT_TOKENS = 2048
    # Error reported when the response body does not have the expected shape.
    _UNEXPECTED_FORMAT = "Unexpected response format from Gemini"

    @property
    def name(self) -> str:
        """Tool name exposed to the agent."""
        return "deep_reasoning"

    @property
    def description(self) -> str:
        """Description telling the agent when to use this tool."""
        return "Perform deep reasoning and analysis for complex problems. Use this for difficult questions that require careful thought, math, coding, or multi-step reasoning."

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON Schema for the arguments accepted by ``execute``."""
        return {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "The problem or question to reason about"
                }
            },
            "required": ["prompt"]
        }

    def _validate_config(self) -> None:
        """Load the Gemini API key and model name from the DB-backed config.

        Sets ``self.api_key`` (``None`` or empty if not configured) and
        ``self.model``. Despite the name, a missing key does not raise here;
        ``execute`` checks ``self.api_key`` and reports it.
        """
        config = load_config_from_db()
        self.api_key = config.get("gemini_api_key")

        # ``or`` rather than a ``.get`` default, so a stored None/"" still
        # falls back to the default model instead of yielding a broken URL.
        model = str(config.get("gemini_model") or self.DEFAULT_MODEL).strip()
        # The URL template already contains "models/"; accept names copied
        # from the list-models API ("models/<id>") without doubling it.
        if model.startswith("models/"):
            model = model[len("models/"):]
        self.model = model or self.DEFAULT_MODEL

    def _build_payload(self, prompt: str) -> Dict[str, Any]:
        """Build the generateContent request body for a single-turn prompt."""
        return {
            "contents": [
                {"parts": [{"text": prompt}]}
            ],
            "generationConfig": {
                "temperature": self.TEMPERATURE,
                "maxOutputTokens": self.MAX_OUTPUT_TOKENS,
            },
        }

    @staticmethod
    def _extract_error_message(response: "httpx.Response") -> str:
        """Return ``error.message`` from a non-200 response body.

        Falls back to ``"API error: <status>"`` when the body is not JSON or
        has no non-empty ``error.message`` field.
        """
        fallback = f"API error: {response.status_code}"
        try:
            error_data = response.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError
            return fallback

        if isinstance(error_data, dict):
            error = error_data.get("error")
            if isinstance(error, dict):
                message = error.get("message")
                if message:
                    return str(message)
        return fallback

    @classmethod
    def _extract_text(cls, data: Any) -> str:
        """Return the concatenated text parts of the first candidate.

        Raises:
            ValueError: if the prompt was blocked, no candidate was returned,
                the body has an unexpected shape, or the first candidate
                contains no text. The message explains which.
        """
        if not isinstance(data, dict):
            raise ValueError(cls._UNEXPECTED_FORMAT)

        candidates = data.get("candidates")
        if not isinstance(candidates, list) or not candidates:
            # When the prompt itself is blocked there are no candidates and
            # the reason is reported in promptFeedback.blockReason.
            feedback = data.get("promptFeedback")
            block_reason = feedback.get("blockReason") if isinstance(feedback, dict) else None
            if block_reason:
                raise ValueError(f"Gemini blocked the prompt (blockReason: {block_reason})")
            raise ValueError(cls._UNEXPECTED_FORMAT)

        candidate = candidates[0]
        if not isinstance(candidate, dict):
            raise ValueError(cls._UNEXPECTED_FORMAT)

        content = candidate.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        text = ""
        if isinstance(parts, list):
            text = "".join(
                part.get("text", "")
                for part in parts
                if isinstance(part, dict)
            )
        if text:
            return text

        # No usable text: surface why generation stopped (e.g. MAX_TOKENS,
        # SAFETY) instead of returning an empty "successful" answer.
        finish_reason = candidate.get("finishReason")
        if finish_reason:
            raise ValueError(f"Gemini returned no text (finishReason: {finish_reason})")
        raise ValueError(cls._UNEXPECTED_FORMAT)

    async def execute(self, prompt: str, **kwargs) -> ToolResult:
        """Send ``prompt`` to Gemini and return the generated text.

        Args:
            prompt: The problem or question to reason about.
            **kwargs: Ignored; accepted so extra tool-call arguments are harmless.

        Returns:
            ``ToolResult(success=True, data=<text>)`` with the first candidate's
            text, or ``ToolResult(success=False, error=<message>)``. Config-load,
            network, HTTP, and response-format failures are all reported through
            the result rather than raised.
        """
        if not isinstance(prompt, str) or not prompt.strip():
            return ToolResult(success=False, error="A non-empty prompt is required.")

        self._log_execution({"prompt": prompt[:100]})

        # Reload config in case it was updated
        try:
            self._validate_config()
        except Exception as e:
            error_msg = f"Failed to load Gemini configuration: {e}"
            self._log_error(error_msg)
            return ToolResult(success=False, error=error_msg)

        if not self.api_key:
            return ToolResult(
                success=False,
                error="Gemini API key not configured. Please configure it in the admin panel."
            )

        url = f"{self.API_BASE_URL}/models/{self.model}:generateContent"
        # Send the key as a header, not a ?key= query parameter: URLs are
        # written to httpx's request log and to proxy/access logs.
        headers = {"x-goog-api-key": self.api_key}

        try:
            async with httpx.AsyncClient(timeout=self.REQUEST_TIMEOUT) as client:
                response = await client.post(
                    url,
                    json=self._build_payload(prompt),
                    headers=headers,
                )

            if response.status_code != 200:
                error_msg = self._extract_error_message(response)
                self._log_error(error_msg)
                return ToolResult(success=False, error=error_msg)

            text = self._extract_text(response.json())
            self._log_success(text[:100])
            return ToolResult(success=True, data=text)

        except httpx.TimeoutException:
            self._log_error("Request timed out")
            return ToolResult(success=False, error="Request timed out")

        except Exception as e:
            # Includes ValueError raised by _extract_text / response.json().
            self._log_error(str(e))
            return ToolResult(success=False, error=str(e))
|