Add local tool selector: keyword parser picks relevant tools, no LLM

_select_tools() parses the user message with keyword matching:
- News keywords → news_aggregate, news_get_top_stories, news_get_reddit
- Finance/stock keywords → finance_get_stock_info/history (extracts ticker)
- Crypto keywords → finance_get_crypto_price (extracts coin name), finance_get_top_cryptos
- Weather keywords → weather_get_current/forecast/air_quality (extracts location)
- Medical keywords → pubmed, fda, disease data, health topics
- Science keywords → science_aggregate_search
- Wikipedia keywords → wikipedia_search
- Always: web_search + web_instant_answer as general fallback
- URL in message → web_get_page_content

Entity extractors:
- _extract_ticker: maps known company names, handles $TICKER format
- _extract_crypto: maps known crypto names to CoinGecko IDs
- _extract_location: preposition-based + known locations (prefers longest match)
- _extract_subject: strips question patterns, leading articles, trailing punctuation

Flow remains: request → select tools → run in parallel → results into system prompt → 1 LLM call
This commit is contained in:
Z User 2026-03-29 18:44:14 +00:00
parent 70109d6889
commit 7a6b6f1086

330
main.py
View File

@ -440,15 +440,19 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
# ============================================================================= # =============================================================================
async def _run_all_tools(user_message: str) -> list[dict]: async def _run_all_tools(user_message: str) -> list[dict]:
"""Run ALL tools in parallel. No LLM involved. """Select relevant tools via local keyword matching, then run them all in parallel."""
- Tools with no required args: run with defaults.
- Tools with required args: use the user message as the query argument.
- Each tool gets a timeout so slow ones don't block.
"""
if not state.tool_manager: if not state.tool_manager:
return [] return []
# Step 1: Determine which tools to call
tool_calls = _select_tools(user_message)
if not tool_calls:
log.info("No tools selected by keyword parser")
return []
log.info(f"Selected {len(tool_calls)} tools: {[tc['name'] for tc in tool_calls]}")
# Step 2: Execute them all in parallel
async def _run_one(name: str, kwargs: dict): async def _run_one(name: str, kwargs: dict):
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
@ -461,44 +465,292 @@ async def _run_all_tools(user_message: str) -> list[dict]:
except Exception as e: except Exception as e:
return {"name": name, "success": False, "error": str(e)} return {"name": name, "success": False, "error": str(e)}
tasks = [] results = await asyncio.gather(*[_run_one(tc["name"], tc["kwargs"]) for tc in tool_calls])
for name, schema in state.tool_manager._schemas.items():
func_schema = schema.get("function", {})
params = func_schema.get("parameters", {})
required = set(params.get("required", []))
props = params.get("properties", {})
# Build kwargs: defaults from schema, then fill required from user message
kwargs = {}
for pname, pinfo in props.items():
if "default" in pinfo:
kwargs[pname] = pinfo["default"]
for pname in required:
if pname not in kwargs:
# Heuristic: use user_message for common query-like params
if pname in ("query", "q", "search", "search_query", "topic", "title", "question"):
kwargs[pname] = user_message
# Use user_message for specific ID fields that look queryable
elif pname in ("disease",):
kwargs[pname] = user_message
# Skip tools that need specific args we can't guess (symbol, pmid, paper_id, url, etc.)
else:
kwargs = None
break
if kwargs is not None:
tasks.append(_run_one(name, kwargs))
else:
log.debug(f"Skipping tool {name}: can't fill required param from user message")
log.info(f"Running {len(tasks)} tools in parallel...")
results = await asyncio.gather(*tasks)
successes = [r for r in results if r["success"]] successes = [r for r in results if r["success"]]
log.info(f"Tool execution complete: {len(successes)}/{len(results)} succeeded") log.info(f"Tool execution complete: {len(successes)}/{len(results)} succeeded")
return results return results
def _select_tools(user_message: str) -> list[dict]:
    """Decide which tools to call for a user message, without using an LLM.

    The message is lower-cased and scanned for category keywords (news,
    finance, crypto, weather, medical, science, Wikipedia). Each category that
    matches contributes its tools; ``web_search`` and ``web_instant_answer``
    are always added as a general fallback, and ``web_get_page_content`` is
    added when the message contains a URL.

    Args:
        user_message: Raw text of the user's request.

    Returns:
        A list of ``{"name": str, "kwargs": dict}`` entries in selection
        order, with at most one entry per tool name (the first one wins).
        Empty when the message is empty or whitespace-only.
    """
    if not user_message or not user_message.strip():
        # Nothing to search for; avoid firing tools with an empty query.
        return []

    msg = user_message.lower()

    # --- Extract useful entities from the message ---
    location = _extract_location(user_message)
    ticker = _extract_ticker(user_message)
    crypto = _extract_crypto(user_message)
    url = _extract_url(user_message)
    subject = _extract_subject(user_message)  # the main topic/query

    selected: list[dict] = []
    seen: set[str] = set()

    def mentions(*keywords: str) -> bool:
        # NOTE(review): plain substring test, so "rain" also fires on
        # "training" and "coin" on "coincidence". Kept as-is because several
        # keywords ("humid", "stock", "coin") rely on matching inside longer
        # words; tightening this needs a reviewed keyword list.
        return any(kw in msg for kw in keywords)

    def add(name: str, **kwargs) -> None:
        # Deduplicate by tool name; the first request for a tool wins.
        if name not in seen:
            seen.add(name)
            selected.append({"name": name, "kwargs": kwargs})

    # --- News tools ---
    if mentions("news", "headline", "headlines", "current event", "current events",
                "breaking", "trending", "reddit", "hacker", "story", "stories"):
        add("news_aggregate", query=subject)
        add("news_get_top_stories")
        add("news_get_reddit", subreddit="news")
    if mentions("reddit", "subreddit"):
        add("news_search_reddit", query=subject)

    # --- Finance / stock tools ---
    if mentions("stock", "share", "price", "market", "nasdaq", "nyse",
                "dow", "s&p", "dividend", "portfolio", "ipo"):
        # Only query when a ticker was actually recognised. Guessing the first
        # short word of the message produced symbols such as "WHAT" or
        # "PRICE", which wastes a call and pollutes the prompt with errors.
        if ticker:
            add("finance_get_stock_info", symbol=ticker)
            add("finance_get_stock_history", symbol=ticker)

    # --- Crypto tools ---
    if mentions("crypto", "bitcoin", "btc", "ethereum", "eth", "solana",
                "dogecoin", "memecoin", "altcoin", "blockchain", "coin", "token"):
        if crypto:
            add("finance_get_crypto_price", coin_id=crypto)
        add("finance_get_top_cryptos")

    # --- Weather tools ---
    if mentions("weather", "temperature", "forecast", "rain", "snow",
                "wind", "humid", "sunny", "cloudy", "storm", "air quality",
                "aqi", "pollution", "uv index"):
        if location:
            add("weather_get_current", location=location)
            add("weather_get_forecast", location=location)
            add("weather_get_air_quality", location=location)
        else:
            # No recognisable place: pass the cleaned-up query as the location.
            add("weather_get_current", location=subject)
            add("weather_get_forecast", location=subject)

    # --- Medical tools ---
    if mentions("medical", "health", "disease", "symptom", "drug", "medication",
                "covid", "vaccine", "fda", "hospital", "pubmed", "clinical",
                "treatment", "diagnosis", "doctor", "patient"):
        add("medical_search_pubmed", query=subject)
        add("medical_search_fda", query=subject)
        add("medical_get_disease_data", disease=subject)
        add("medical_get_health_topics", topic=subject)

    # --- Science / research tools ---
    if mentions("research", "paper", "study", "arxiv", "academic", "journal",
                "scholar", "citation", "peer-review", "scientific", "thesis",
                "experiment", "theory", "physics", "math"):
        add("science_aggregate_search", query=subject)

    # --- Wikipedia ---
    if mentions("wikipedia", "wiki", "who is", "what is", "history of",
                "explain", "definition", "meaning of", "tell me about"):
        add("wikipedia_search", query=subject)

    # --- Web search (always include as general fallback) ---
    add("web_search", query=subject)
    add("web_instant_answer", query=subject)

    # --- URL extraction ---
    if url:
        add("web_get_page_content", url=url)

    return selected
# --- Entity extractors ---

# Lower-case company name or symbol -> exchange ticker symbol.
_KNOWN_TICKERS = {
    "aapl": "AAPL",
    "apple": "AAPL",
    "googl": "GOOGL",
    "google": "GOOGL",
    "msft": "MSFT",
    "microsoft": "MSFT",
    "amzn": "AMZN",
    "amazon": "AMZN",
    "tsla": "TSLA",
    "tesla": "TSLA",
    "meta": "META",
    "facebook": "META",
    "nvda": "NVDA",
    "nvidia": "NVDA",
    "netflix": "NFLX",
    "nflx": "NFLX",
    "amd": "AMD",
    "intel": "INTC",
    "disney": "DIS",
    "jpmorgan": "JPM",
    "ba": "BA",
    "boeing": "BA",
    "walmart": "WMT",
    "wmt": "WMT",
    "pfizer": "PFE",
    "pfe": "PFE",
    "nio": "NIO",
    "pltr": "PLTR",
    "palantir": "PLTR",
    "coin": "COIN",
    "coinbase": "COIN",
    "roku": "ROKU",
    "spotify": "SPOT",
    "shopify": "SHOP",
}

# Lower-case coin name or symbol -> coin id passed to the crypto price tool.
# Insertion order matters to code that iterates this mapping.
_KNOWN_CRYPTO = {
    "bitcoin": "bitcoin",
    "btc": "bitcoin",
    "ethereum": "ethereum",
    "eth": "ethereum",
    "solana": "solana",
    "sol": "solana",
    "dogecoin": "dogecoin",
    "doge": "dogecoin",
    "ripple": "ripple",
    "xrp": "ripple",
    "cardano": "cardano",
    "ada": "cardano",
    "polkadot": "polkadot",
    "dot": "polkadot",
    "litecoin": "litecoin",
    "ltc": "litecoin",
    "chainlink": "chainlink",
    "link": "chainlink",
    "avalanche": "avalanche",
    "avax": "avalanche",
    "polygon": "polygon",
    "matic": "polygon",
    "shiba": "shiba-inu",
    "shib": "shiba-inu",
    "tron": "tron",
    "trx": "tron",
    "usdt": "tether",
    "tether": "tether",
}

# Lower-case place names recognised by the location extractor:
# the 50 US states followed by a selection of cities.
_KNOWN_LOCATIONS = [
    # US states
    "alabama",
    "alaska",
    "arizona",
    "arkansas",
    "california",
    "colorado",
    "connecticut",
    "delaware",
    "florida",
    "georgia",
    "hawaii",
    "idaho",
    "illinois",
    "indiana",
    "iowa",
    "kansas",
    "kentucky",
    "louisiana",
    "maine",
    "maryland",
    "massachusetts",
    "michigan",
    "minnesota",
    "mississippi",
    "missouri",
    "montana",
    "nebraska",
    "nevada",
    "new hampshire",
    "new jersey",
    "new mexico",
    "new york",
    "north carolina",
    "north dakota",
    "ohio",
    "oklahoma",
    "oregon",
    "pennsylvania",
    "rhode island",
    "south carolina",
    "south dakota",
    "tennessee",
    "texas",
    "utah",
    "vermont",
    "virginia",
    "washington",
    "west virginia",
    "wisconsin",
    "wyoming",
    # Cities
    "oroville",
    "chico",
    "redding",
    "sacramento",
    "los angeles",
    "san francisco",
    "san diego",
    "new york city",
    "chicago",
    "houston",
    "phoenix",
    "dallas",
    "austin",
    "seattle",
    "portland",
    "denver",
    "miami",
    "boston",
    "atlanta",
    "london",
    "paris",
    "tokyo",
    "berlin",
    "sydney",
    "toronto",
    "vancouver",
    "melbourne",
]

# Words that introduce a place name ("weather in Boston", "near Chico").
_LOCATION_PREPOSITIONS = [
    "in",
    "at",
    "for",
    "near",
    "around",
    "outside",
    "from",
]
def _extract_ticker(message: str) -> str:
    """Extract a stock ticker symbol from the message.

    Candidates are tried in order of how explicit they are:

    1. A cashtag such as ``$TSLA`` (1-5 letters after the ``$``).
    2. A company name or symbol listed in ``_KNOWN_TICKERS``
       (case-insensitive, surrounding punctuation and a trailing ``'s``
       ignored).
    3. A word the user typed in all capitals, 2-5 ASCII letters long.

    Args:
        message: Raw text of the user's request.

    Returns:
        The upper-case ticker symbol, or ``""`` when none is found.
    """
    punctuation = ",.!?;:\"'()"
    tokens = message.split()

    # 1) Explicit cashtags. The "$" must survive the punctuation strip,
    #    otherwise this check can never succeed.
    for token in tokens:
        cashtag = token.strip(punctuation)
        if cashtag.startswith("$"):
            symbol = cashtag[1:]
            if 1 <= len(symbol) <= 5 and symbol.isalpha():
                return symbol.upper()

    # 2) Known company names / symbols, compared on the cleaned word so that
    #    "Apple?" and "Tesla's" are recognised.
    for token in tokens:
        word = token.strip(punctuation + "$").lower()
        if word.endswith("'s"):
            word = word[:-2]
        if word in _KNOWN_TICKERS:
            return _KNOWN_TICKERS[word]

    # 3) Raw tickers: judged on the ORIGINAL casing. Upper-casing the message
    #    first would turn every short word ("what", "the") into a "ticker".
    # NOTE(review): all-caps acronyms such as "USA" or "IPO" still match here.
    for token in tokens:
        word = token.strip(punctuation + "$")
        if 2 <= len(word) <= 5 and word.isascii() and word.isalpha() and word.isupper():
            return word

    return ""
def _extract_crypto(message: str) -> str:
    """Extract a cryptocurrency from the message.

    The message is split into lower-case alphanumeric words and each word is
    looked up in ``_KNOWN_CRYPTO``. Matching whole words (rather than
    substrings) keeps "Canada" from resolving to cardano via "ada", or
    "method" to ethereum via "eth".

    Args:
        message: Raw text of the user's request.

    Returns:
        The coin id of the first known coin mentioned, in message order, or
        ``""`` when no known coin is mentioned.
    """
    import re  # local import: keeps this block self-contained

    for word in re.findall(r"[a-z0-9]+", message.lower()):
        coin_id = _KNOWN_CRYPTO.get(word)
        if coin_id:
            return coin_id
        # Accept simple plurals of the longer names ("bitcoins", "dogecoins").
        # Short aliases are excluded so "dots"/"links" stay ordinary words.
        if len(word) > 5 and word.endswith("s"):
            coin_id = _KNOWN_CRYPTO.get(word[:-1])
            if coin_id:
                return coin_id
    return ""
def _extract_location(message: str) -> str:
    """Extract a location from the message.

    Two strategies are tried:

    1. Preposition-based: for every word in ``_LOCATION_PREPOSITIONS``, the
       following words (at most five, stopping at the next preposition or a
       stand-alone punctuation token) form a candidate phrase. The longest
       entry of ``_KNOWN_LOCATIONS`` found in the candidate is used; failing
       that, a candidate of four characters or fewer is accepted verbatim.
       When several prepositions yield a result, the last one wins.
    2. Fallback: the longest known location appearing anywhere in the
       message.

    Known locations are matched on word boundaries, so "comparison" does not
    yield "Paris" and "chicory" does not yield "Chico".

    Args:
        message: Raw text of the user's request.

    Returns:
        The location in title case, or ``""`` when none is found.
    """
    import re  # local import: keeps this block self-contained

    def longest_known(text: str) -> str:
        # Whole-word / whole-phrase match; on equal length the entry listed
        # first in _KNOWN_LOCATIONS wins.
        hits = [
            loc for loc in _KNOWN_LOCATIONS
            if re.search(r"\b" + re.escape(loc) + r"\b", text)
        ]
        return max(hits, key=len) if hits else ""

    msg_lower = message.lower()
    words = msg_lower.split()

    # Try preposition-based extraction first (more specific).
    best_match = ""
    for i, word in enumerate(words):
        if word not in _LOCATION_PREPOSITIONS:
            continue
        candidate_words = []
        for following in words[i + 1:i + 6]:
            if following in _LOCATION_PREPOSITIONS or following in (",", ".", "?", "!"):
                break
            cleaned = following.strip(",$.!?;:\"'")
            if cleaned:  # drop tokens that were pure punctuation
                candidate_words.append(cleaned)
        candidate = " ".join(candidate_words)
        if not candidate:
            continue
        known = longest_known(candidate)
        if known:
            best_match = known.title()
        elif len(candidate) <= 4:
            # NOTE(review): presumably meant for short names like "Rome" or
            # "NYC"; it also accepts words such as "me" or "all" - confirm.
            best_match = candidate.title()
    if best_match:
        return best_match

    # Fallback: a known location appearing anywhere (prefer the longest).
    return longest_known(msg_lower).title()
def _extract_url(message: str) -> str:
    """Extract the first http(s) URL from the message.

    Punctuation that belongs to the surrounding sentence rather than to the
    URL is trimmed from the end: ``. , ; : ! ? '`` and a closing parenthesis
    that has no matching opening parenthesis inside the URL.

    Args:
        message: Raw text of the user's request.

    Returns:
        The URL, or ``""`` when the message contains none.
    """
    match = re.search(r'https?://[^\s<>"{}|\\^`\[\]]+', message)
    if not match:
        return ""

    url = match.group(0)
    while url:
        last = url[-1]
        if last in ".,;:!?'":
            url = url[:-1]
        elif last == ")" and url.count(")") > url.count("("):
            # "(see https://example.com/page)" - the ")" closes the sentence's
            # parenthesis; balanced ones (e.g. wiki "Foo_(bar)") are kept.
            url = url[:-1]
        else:
            break

    # Only the scheme left (e.g. "http://...") - that is not a usable URL.
    if url.endswith("://"):
        return ""
    return url
def _extract_subject(message: str) -> str:
    """Extract the main subject/query from the user message.

    Strips one leading question pattern ("what is the", "tell me about", ...),
    then one leading article, then trailing punctuation, and finally caps the
    result at 200 characters.

    A starter is only removed when it ends on a word boundary, so "findings"
    is not cut by "find" and "checkout" is not cut by "check". Starters are
    tried longest first, so "how does" wins over "how do".

    Args:
        message: Raw text of the user's request.

    Returns:
        The cleaned-up query, or the stripped original message when cleaning
        would leave nothing.
    """
    subject = message.strip()

    starters = [
        "give me all the", "give me all", "give me",
        "tell me about", "tell me",
        "what is the", "what is",
        "what are the", "what are",
        "what's the", "what's",
        "how is the", "how is",
        "how are the", "how are",
        "how do", "how does",
        "how much", "how many",
        "can you", "could you", "would you",
        "please", "i need", "i want",
        "show me", "get me",
        "find me", "find",
        "search for", "look up", "lookup",
        "check the", "check",
        "what's happening",
        "explain", "describe", "summarize",
    ]

    msg_lower = subject.lower()
    # Sort here rather than trusting the literal order above: a shorter
    # starter that is a prefix of a longer one must never be tried first.
    for starter in sorted(starters, key=len, reverse=True):
        if not msg_lower.startswith(starter):
            continue
        rest = subject[len(starter):]
        # Require a word boundary after the starter ("what is" must not eat
        # the start of "what isn't", nor "check" the start of "check-in").
        if rest and (rest[0].isalnum() or rest[0] in "'_-"):
            continue
        subject = rest.strip().lstrip(",:; ")
        msg_lower = subject.lower()
        break

    # Strip leading "the", "a", "an"
    for article in ("the ", "a ", "an "):
        if msg_lower.startswith(article):
            subject = subject[len(article):]
            break

    # Strip trailing punctuation
    subject = subject.rstrip("?.!;, ")

    # If still long, keep only the first chunk
    if len(subject) > 200:
        subject = subject[:200]

    return subject or message.strip()
def _build_tool_results_text(tool_results: list[dict]) -> str: def _build_tool_results_text(tool_results: list[dict]) -> str:
"""Build a text block of all tool results for the system prompt.""" """Build a text block of all tool results for the system prompt."""
if not tool_results: if not tool_results: