119 lines
3.6 KiB
Python
Executable File
119 lines
3.6 KiB
Python
Executable File
"""
|
|
Tool Registry
|
|
Manages all available tools and executes them.
|
|
"""
|
|
from typing import Dict, List, Any, Optional, Type
|
|
from loguru import logger
|
|
|
|
from tools.base import BaseTool, ToolResult
|
|
from tools.web_search import WebSearchTool
|
|
from tools.wikipedia import WikipediaTool
|
|
from tools.rag import RAGTool
|
|
from tools.gemini import GeminiTool
|
|
from tools.openrouter import OpenRouterTool
|
|
from tools.comfyui.image import ImageGenerationTool
|
|
from tools.comfyui.video import VideoGenerationTool
|
|
from tools.comfyui.audio import AudioGenerationTool
|
|
|
|
|
|
class ToolRegistry:
    """
    Registry for all tools.

    Handles:
    - Tool registration
    - Tool discovery (returns definitions for Ollama)
    - Tool execution

    Tools are stored in ``self.tools`` keyed by ``tool.name``; registering a
    tool whose name is already present replaces the earlier one.
    """

    def __init__(self, rag_store=None):
        """
        Create the registry and register the default tool set.

        Args:
            rag_store: Optional store handed to ``RAGTool``. When ``None``,
                the RAG tool is not registered.
        """
        self.tools: Dict[str, BaseTool] = {}
        self.rag_store = rag_store

        # Register all tools
        self._register_default_tools()

    def _register_default_tools(self) -> None:
        """Register all default tools (and RAG when a store was supplied)."""
        # Web search (DuckDuckGo - no API key needed)
        self.register(WebSearchTool())

        # Wikipedia
        self.register(WikipediaTool())

        # RAG (only if a store was supplied). Compare against None instead of
        # relying on truthiness: a store object that defines __len__/__bool__
        # (e.g. an empty store) would otherwise be treated as missing.
        if self.rag_store is not None:
            self.register(RAGTool(self.rag_store))

        # External LLM tools (these are hidden from user)
        self.register(GeminiTool())
        self.register(OpenRouterTool())

        # ComfyUI generation tools
        self.register(ImageGenerationTool())
        self.register(VideoGenerationTool())
        self.register(AudioGenerationTool())

        logger.info(f"Registered {len(self.tools)} tools")

    def register(self, tool: BaseTool) -> None:
        """
        Register a tool under ``tool.name``.

        An existing tool with the same name is replaced; a warning is logged
        so accidental name collisions do not go unnoticed.
        """
        if tool.name in self.tools:
            logger.warning(f"Replacing already-registered tool: {tool.name}")
        self.tools[tool.name] = tool
        logger.debug(f"Registered tool: {tool.name}")

    def unregister(self, tool_name: str) -> None:
        """Unregister a tool; does nothing if ``tool_name`` is not registered."""
        if tool_name in self.tools:
            del self.tools[tool_name]
            logger.debug(f"Unregistered tool: {tool_name}")

    def get_tool(self, tool_name: str) -> Optional[BaseTool]:
        """Return the tool registered as ``tool_name``, or ``None``."""
        return self.tools.get(tool_name)

    def get_tool_definitions(self) -> List[Dict[str, Any]]:
        """
        Get tool definitions for Ollama.

        Returns one ``tool.get_definition()`` entry per registered tool, in
        registration order, in the format expected by Ollama's tool calling.
        """
        # Every registered tool is exposed; no availability/configuration
        # filtering happens here.
        return [tool.get_definition() for tool in self.tools.values()]

    async def execute(
        self, tool_name: str, arguments: Optional[Dict[str, Any]]
    ) -> str:
        """
        Execute a tool by name with the given arguments.

        Args:
            tool_name: Name of a registered tool.
            arguments: Keyword arguments forwarded to ``tool.execute``.
                ``None`` is treated as "no arguments".

        Returns:
            ``result.to_string()`` on success; otherwise a string starting
            with ``"Error: "`` (unknown tool, unsuccessful result, or an
            exception raised by the tool). Exceptions are logged with their
            traceback and converted to text rather than propagated, so the
            result can be fed straight back to the LLM.
        """
        tool = self.get_tool(tool_name)

        if tool is None:
            logger.error(f"Tool not found: {tool_name}")
            return f"Error: Tool '{tool_name}' not found"

        try:
            result = await tool.execute(**(arguments or {}))

            if result.success:
                return result.to_string()
            return f"Error: {result.error}"

        except Exception as e:
            # Boundary between model-issued calls and tool code: report the
            # failure to the model as text, but keep the traceback in the logs.
            logger.exception(f"Tool execution failed: {tool_name} - {e}")
            return f"Error: {str(e)}"

    def list_tools(self) -> List[str]:
        """List all registered tool names, in registration order."""
        return list(self.tools.keys())

    def has_tool(self, tool_name: str) -> bool:
        """Check if a tool is registered under ``tool_name``."""
        return tool_name in self.tools
|