- Pass all registered tools to LLM during chat completion - Handle tool_calls from LLM response - Execute tools and feed results back to LLM - Loop until LLM returns final response - Updated system prompt to encourage tool use - Updated streaming to handle tool calls - Increased MAX_TOOL_ITERATIONS to 5
433 lines
14 KiB
Python
Executable File
433 lines
14 KiB
Python
Executable File
"""
|
|
Tools Module - Tool management for the RAG system
|
|
|
|
Provides a unified interface for tool registration and execution.
|
|
All tools use completely free APIs with no authentication required.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
import inspect
import json
import logging
from collections.abc import Mapping
from typing import Any, Callable, Optional
|
|
|
|
# Module-level logger used by ToolManager for registration, execution and
# import-failure messages.
log = logging.getLogger(__name__)
|
|
|
|
|
|
class ToolManager:
    """
    Manages tool registration and execution.

    Provides:
    - Tool registration
    - Tool schema generation
    - Tool execution with error handling

    Execution problems (unknown tool, malformed or mismatched arguments, an
    exception raised inside the tool) are returned as
    ``{"success": False, "error": <message>}`` instead of being raised.
    """

    # Built-in tool groups as (module path, ((tool name, schema constant), ...)).
    # Each tool's function is looked up in its module under the tool name, so
    # the registered name and the Python function name cannot drift apart.
    # A group whose module (or any listed name) is missing is skipped with a
    # warning; the remaining groups still register.
    _BUILTIN_TOOL_GROUPS = (
        (
            "tools.wikipedia_tool",
            (
                ("wikipedia_search", "WIKIPEDIA_SEARCH_SCHEMA"),
                ("wikipedia_get_article", "WIKIPEDIA_GET_ARTICLE_SCHEMA"),
                ("wikipedia_get_full_article", "WIKIPEDIA_GET_FULL_ARTICLE_SCHEMA"),
            ),
        ),
        (
            "tools.news_tool",
            (
                ("news_search_hackernews", "NEWS_SEARCH_HACKERNEWS_SCHEMA"),
                ("news_get_top_stories", "NEWS_GET_TOP_STORIES_SCHEMA"),
                ("news_get_reddit", "NEWS_GET_REDDIT_SCHEMA"),
                ("news_search_reddit", "NEWS_SEARCH_REDDIT_SCHEMA"),
                ("news_aggregate", "NEWS_AGGREGATE_SCHEMA"),
            ),
        ),
        (
            "tools.finance_tool",
            (
                ("finance_get_stock_info", "FINANCE_GET_STOCK_INFO_SCHEMA"),
                ("finance_get_stock_history", "FINANCE_GET_STOCK_HISTORY_SCHEMA"),
                ("finance_get_crypto_price", "FINANCE_GET_CRYPTO_PRICE_SCHEMA"),
                ("finance_get_top_cryptos", "FINANCE_GET_TOP_CRYPTOS_SCHEMA"),
                ("finance_get_exchange_rate", "FINANCE_GET_EXCHANGE_RATE_SCHEMA"),
                ("finance_search_crypto", "FINANCE_SEARCH_CRYPTO_SCHEMA"),
            ),
        ),
        (
            "tools.medical_tool",
            (
                ("medical_search_pubmed", "MEDICAL_SEARCH_PUBMED_SCHEMA"),
                ("medical_get_pubmed_abstract", "MEDICAL_GET_PUBMED_ABSTRACT_SCHEMA"),
                ("medical_get_disease_data", "MEDICAL_GET_DISEASE_DATA_SCHEMA"),
                ("medical_get_covid_country", "MEDICAL_GET_COVID_COUNTRY_SCHEMA"),
                ("medical_search_fda", "MEDICAL_SEARCH_FDA_SCHEMA"),
                ("medical_get_health_topics", "MEDICAL_GET_HEALTH_TOPICS_SCHEMA"),
            ),
        ),
        (
            "tools.weather_tool",
            (
                ("weather_get_current", "WEATHER_GET_CURRENT_SCHEMA"),
                ("weather_get_forecast", "WEATHER_GET_FORECAST_SCHEMA"),
                ("weather_get_air_quality", "WEATHER_GET_AIR_QUALITY_SCHEMA"),
            ),
        ),
        (
            "tools.science_tool",
            (
                ("science_search_arxiv", "SCIENCE_SEARCH_ARXIV_SCHEMA"),
                (
                    "science_search_semantic_scholar",
                    "SCIENCE_SEARCH_SEMANTIC_SCHOLAR_SCHEMA",
                ),
                ("science_get_paper_details", "SCIENCE_GET_PAPER_DETAILS_SCHEMA"),
                ("science_search_doaj", "SCIENCE_SEARCH_DOAJ_SCHEMA"),
                ("science_aggregate_search", "SCIENCE_AGGREGATE_SEARCH_SCHEMA"),
            ),
        ),
        (
            "tools.web_tool",
            (
                ("web_search", "WEB_SEARCH_SCHEMA"),
                ("web_instant_answer", "WEB_INSTANT_ANSWER_SCHEMA"),
                ("web_get_page_content", "WEB_GET_PAGE_CONTENT_SCHEMA"),
                ("web_search_and_fetch", "WEB_SEARCH_AND_FETCH_SCHEMA"),
            ),
        ),
    )

    def __init__(self):
        self._tools: dict[str, Callable] = {}
        self._schemas: dict[str, dict] = {}

        # Register built-in tools
        self._register_builtin_tools()

    def _register_builtin_tools(self) -> None:
        """Register every built-in tool whose module can be imported.

        Missing tool modules are logged as warnings and skipped, so the
        manager still starts with whatever tools are available.
        """
        # === Website Downloader Tool ===
        # Lives outside the ``tools`` package and builds its schema through a
        # function call, so it is handled here rather than in the table.
        try:
            from website_downloader_tool import (
                website_downloader,
                get_tool_schema as get_website_downloader_schema,
            )

            self.register_tool(
                name="website_downloader",
                function=website_downloader,
                schema=get_website_downloader_schema(),
            )
        except ImportError as e:
            log.warning("Could not import website_downloader_tool: %s", e)

        # === Wikipedia, News, Finance, Medical, Weather, Science, Web ===
        for module_path, entries in self._BUILTIN_TOOL_GROUPS:
            self._register_tool_group(module_path, entries)

        log.info("Registered %d built-in tools", len(self._tools))

    def _register_tool_group(
        self,
        module_path: str,
        entries: tuple[tuple[str, str], ...],
    ) -> None:
        """Import one tool module and register every tool listed for it.

        Behaves like a single ``from <module_path> import (...)`` statement:
        if the module or any listed name is missing, nothing from the group
        is registered and a warning is logged instead.

        Args:
            module_path: Dotted module path, e.g. ``"tools.wikipedia_tool"``.
            entries: ``(tool name, schema constant name)`` pairs; the tool
                function is looked up in the module under the tool name.
        """
        # Warnings name the module by its last component (e.g. "wikipedia_tool").
        label = module_path.rpartition(".")[2]
        try:
            module = importlib.import_module(module_path)
            # Resolve every name before registering anything so the group is
            # registered all-or-nothing.
            resolved = [
                (
                    tool_name,
                    self._import_name(module, tool_name),
                    self._import_name(module, schema_name),
                )
                for tool_name, schema_name in entries
            ]
        except ImportError as e:
            log.warning("Could not import %s: %s", label, e)
            return

        for tool_name, function, schema in resolved:
            self.register_tool(name=tool_name, function=function, schema=schema)

    @staticmethod
    def _import_name(module: Any, attribute: str) -> Any:
        """Return ``module.attribute``, raising ImportError like ``from m import x``."""
        try:
            return getattr(module, attribute)
        except AttributeError:
            raise ImportError(
                f"cannot import name {attribute!r} from {module.__name__!r}"
            ) from None

    def register_tool(
        self,
        name: str,
        function: Callable,
        schema: dict,
    ) -> None:
        """
        Register a new tool.

        Registering an already-used name replaces the previous function and
        schema; a warning is logged so accidental collisions are visible.

        Args:
            name: Tool name (the key used by ``execute_tool``)
            function: Tool function, called with the tool arguments as
                keyword arguments
            schema: OpenAI function schema
        """
        if name in self._tools:
            log.warning("Replacing already registered tool: %s", name)
        self._tools[name] = function
        self._schemas[name] = schema
        log.info("Registered tool: %s", name)

    def get_tool_schema(self, name: str) -> Optional[dict]:
        """Get the schema for a tool, or ``None`` if it is not registered."""
        return self._schemas.get(name)

    def get_all_schemas(self) -> list[dict]:
        """Get schemas for all registered tools, in registration order."""
        return list(self._schemas.values())

    def list_tools(self) -> list[str]:
        """List all registered tool names, in registration order."""
        return list(self._tools.keys())

    def execute_tool(
        self,
        name: str,
        arguments: Optional[Mapping[str, Any]],
    ) -> dict[str, Any]:
        """
        Execute a tool with the given arguments.

        Args:
            name: Tool name
            arguments: Keyword arguments for the tool; ``None`` means
                "no arguments".

        Returns:
            Whatever the tool returns, or ``{"success": False, "error": ...}``
            when the tool is unknown, the arguments do not fit the tool's
            signature, or the tool raises.
        """
        if name not in self._tools:
            return {
                "success": False,
                "error": f"Unknown tool: {name}",
            }
        function = self._tools[name]

        if arguments is None:
            arguments = {}
        elif not isinstance(arguments, Mapping):
            return self._invalid_arguments(
                name,
                "expected a mapping of argument names to values, "
                f"got {type(arguments).__name__}",
            )

        log.info("Executing tool: %s with args: %s", name, arguments)

        # Check the arguments against the signature *before* calling, so a
        # TypeError raised inside the tool body is reported (with traceback)
        # as a tool failure instead of being mistaken for bad arguments.
        try:
            signature = inspect.signature(function)
        except (TypeError, ValueError):
            signature = None  # not introspectable; the call itself will complain
        if signature is not None:
            try:
                signature.bind(**arguments)
            except TypeError as e:
                return self._invalid_arguments(name, e)

        try:
            return function(**arguments)
        except Exception as e:
            log.exception("Tool execution failed: %s", name)
            # Some exceptions carry no message; fall back to the type name so
            # the caller never receives an empty error string.
            return {
                "success": False,
                "error": str(e) or type(e).__name__,
            }

    @staticmethod
    def _invalid_arguments(name: str, reason: object) -> dict[str, Any]:
        """Log and build the standard "invalid arguments" failure result."""
        log.error("Invalid arguments for tool %s: %s", name, reason)
        return {
            "success": False,
            "error": f"Invalid arguments: {reason}",
        }

    def execute_tool_from_json(
        self,
        name: str,
        arguments_json: str | bytes | Mapping[str, Any] | None,
    ) -> dict[str, Any]:
        """
        Execute a tool with JSON arguments.

        Args:
            name: Tool name
            arguments_json: JSON text encoding the tool's arguments. ``None``
                or blank text means "no arguments"; an already-decoded
                mapping is passed through unchanged.

        Returns:
            Tool result, or ``{"success": False, "error": ...}`` when the JSON
            cannot be decoded or execution fails (see ``execute_tool``).
        """
        if isinstance(arguments_json, Mapping):
            return self.execute_tool(name, arguments_json)

        if arguments_json is None or (
            isinstance(arguments_json, (str, bytes, bytearray))
            and not arguments_json.strip()
        ):
            return self.execute_tool(name, {})

        try:
            arguments = json.loads(arguments_json)
        except (TypeError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and undecodable bytes;
            # TypeError covers inputs that are not str/bytes at all.
            return {
                "success": False,
                "error": f"Invalid JSON arguments: {str(e)}",
            }
        return self.execute_tool(name, arguments)
|
|
|
|
|
|
# Process-wide ToolManager, created lazily by get_tool_manager().
_tool_manager: Optional[ToolManager] = None


def get_tool_manager() -> ToolManager:
    """Return the shared ToolManager, building it on the first call.

    Later calls return the cached instance, so ToolManager construction
    (and its built-in tool registration) only happens while
    ``_tool_manager`` is still ``None``. No lock guards this: concurrent
    first calls could each construct an instance.
    """
    global _tool_manager
    if _tool_manager is not None:
        return _tool_manager
    _tool_manager = ToolManager()
    return _tool_manager
|