Add local tool selector: keyword parser picks relevant tools, no LLM
_select_tools() parses the user message with keyword matching:
- News keywords → news_aggregate, news_get_top_stories, news_get_reddit
- Finance/stock keywords → finance_get_stock_info/history (extracts ticker)
- Crypto keywords → finance_get_crypto_price (extracts coin name), finance_get_top_cryptos
- Weather keywords → weather_get_current/forecast/air_quality (extracts location)
- Medical keywords → pubmed, fda, disease data, health topics
- Science keywords → science_aggregate_search
- Wikipedia keywords → wikipedia_search
- Always: web_search + web_instant_answer as general fallback
- URL in message → web_get_page_content

Entity extractors:
- _extract_ticker: maps known company names, handles $TICKER format
- _extract_crypto: maps known crypto names to CoinGecko IDs
- _extract_location: preposition-based + known locations (prefers longest match)
- _extract_subject: strips question patterns, leading articles, trailing punctuation

Flow remains: request → select tools → run in parallel → results into system prompt → 1 LLM call
This commit is contained in:
parent
70109d6889
commit
7a6b6f1086
330
main.py
330
main.py
@ -440,15 +440,19 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
|
||||
# =============================================================================
|
||||
|
||||
async def _run_all_tools(user_message: str) -> list[dict]:
|
||||
"""Run ALL tools in parallel. No LLM involved.
|
||||
|
||||
- Tools with no required args: run with defaults.
|
||||
- Tools with required args: use the user message as the query argument.
|
||||
- Each tool gets a timeout so slow ones don't block.
|
||||
"""
|
||||
"""Select relevant tools via local keyword matching, then run them all in parallel."""
|
||||
if not state.tool_manager:
|
||||
return []
|
||||
|
||||
# Step 1: Determine which tools to call
|
||||
tool_calls = _select_tools(user_message)
|
||||
if not tool_calls:
|
||||
log.info("No tools selected by keyword parser")
|
||||
return []
|
||||
|
||||
log.info(f"Selected {len(tool_calls)} tools: {[tc['name'] for tc in tool_calls]}")
|
||||
|
||||
# Step 2: Execute them all in parallel
|
||||
async def _run_one(name: str, kwargs: dict):
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
@ -461,44 +465,292 @@ async def _run_all_tools(user_message: str) -> list[dict]:
|
||||
except Exception as e:
|
||||
return {"name": name, "success": False, "error": str(e)}
|
||||
|
||||
tasks = []
|
||||
for name, schema in state.tool_manager._schemas.items():
|
||||
func_schema = schema.get("function", {})
|
||||
params = func_schema.get("parameters", {})
|
||||
required = set(params.get("required", []))
|
||||
props = params.get("properties", {})
|
||||
|
||||
# Build kwargs: defaults from schema, then fill required from user message
|
||||
kwargs = {}
|
||||
for pname, pinfo in props.items():
|
||||
if "default" in pinfo:
|
||||
kwargs[pname] = pinfo["default"]
|
||||
|
||||
for pname in required:
|
||||
if pname not in kwargs:
|
||||
# Heuristic: use user_message for common query-like params
|
||||
if pname in ("query", "q", "search", "search_query", "topic", "title", "question"):
|
||||
kwargs[pname] = user_message
|
||||
# Use user_message for specific ID fields that look queryable
|
||||
elif pname in ("disease",):
|
||||
kwargs[pname] = user_message
|
||||
# Skip tools that need specific args we can't guess (symbol, pmid, paper_id, url, etc.)
|
||||
else:
|
||||
kwargs = None
|
||||
break
|
||||
|
||||
if kwargs is not None:
|
||||
tasks.append(_run_one(name, kwargs))
|
||||
else:
|
||||
log.debug(f"Skipping tool {name}: can't fill required param from user message")
|
||||
|
||||
log.info(f"Running {len(tasks)} tools in parallel...")
|
||||
results = await asyncio.gather(*tasks)
|
||||
results = await asyncio.gather(*[_run_one(tc["name"], tc["kwargs"]) for tc in tool_calls])
|
||||
successes = [r for r in results if r["success"]]
|
||||
log.info(f"Tool execution complete: {len(successes)}/{len(results)} succeeded")
|
||||
return results
|
||||
|
||||
|
||||
def _select_tools(user_message: str) -> list[dict]:
    """Pick the tools to run for a user message using local keyword matching.

    The lower-cased message is scanned for per-category keywords (news,
    finance, crypto, weather, medical, science, Wikipedia).  Keywords are
    matched as substrings, so "rain" also fires on "training".  The two web
    search tools are always added as a general fallback, and a page fetch is
    added when the message contains a URL.

    Args:
        user_message: Raw text of the user's request.

    Returns:
        A list of ``{"name": str, "kwargs": dict}`` entries in selection
        order, with at most one entry per tool name (first request wins).
    """
    msg = user_message.lower()

    # --- Extract useful entities from the message ---
    location = _extract_location(user_message)
    ticker = _extract_ticker(user_message)
    crypto = _extract_crypto(user_message)
    url = _extract_url(user_message)
    subject = _extract_subject(user_message)  # the main topic/query

    selected: list[dict] = []
    seen: set[str] = set()

    def mentions(*keywords: str) -> bool:
        """Return True if any keyword occurs anywhere in the lower-cased message."""
        return any(kw in msg for kw in keywords)

    def add(name: str, **kwargs) -> None:
        """Queue a tool call once; repeated requests for the same tool are ignored."""
        if name not in seen:
            seen.add(name)
            selected.append({"name": name, "kwargs": kwargs})

    # --- News tools ---
    if mentions("news", "headline", "headlines", "current event", "current events",
                "breaking", "trending", "reddit", "hacker", "story", "stories"):
        add("news_aggregate", query=subject)
        add("news_get_top_stories")
        add("news_get_reddit", subreddit="news")
    if mentions("reddit", "subreddit"):
        add("news_search_reddit", query=subject)

    # --- Finance / stock tools ---
    if mentions("stock", "share", "price", "market", "nasdaq", "nyse",
                "dow", "s&p", "dividend", "portfolio", "ipo"):
        if ticker:
            add("finance_get_stock_info", symbol=ticker)
            add("finance_get_stock_history", symbol=ticker)
        else:
            # Fallback guess.  Only trust tokens the user marked as a symbol:
            # "$xyz", or a 2-5 letter word typed in capitals.  Upper-casing the
            # whole message first would turn "what" or "is" into a "ticker".
            for word in user_message.split():
                token = word.strip(",.!?;:")
                symbol = token.strip("$")
                if not symbol.isalpha() or len(symbol) > 5:
                    continue
                if token.startswith("$") or (len(symbol) >= 2 and symbol.isupper()):
                    add("finance_get_stock_info", symbol=symbol.upper())
                    break

    # --- Crypto tools ---
    if mentions("crypto", "bitcoin", "btc", "ethereum", "eth", "solana",
                "dogecoin", "memecoin", "altcoin", "blockchain", "coin", "token"):
        if crypto:
            add("finance_get_crypto_price", coin_id=crypto)
        add("finance_get_top_cryptos")

    # --- Weather tools ---
    if mentions("weather", "temperature", "forecast", "rain", "snow",
                "wind", "humid", "sunny", "cloudy", "storm", "air quality",
                "aqi", "pollution", "uv index"):
        if location:
            add("weather_get_current", location=location)
            add("weather_get_forecast", location=location)
            add("weather_get_air_quality", location=location)
        else:
            # No location recognised: pass the extracted subject instead.
            add("weather_get_current", location=subject)
            add("weather_get_forecast", location=subject)

    # --- Medical tools ---
    if mentions("medical", "health", "disease", "symptom", "drug", "medication",
                "covid", "vaccine", "fda", "hospital", "pubmed", "clinical",
                "treatment", "diagnosis", "doctor", "patient"):
        add("medical_search_pubmed", query=subject)
        add("medical_search_fda", query=subject)
        add("medical_get_disease_data", disease=subject)
        add("medical_get_health_topics", topic=subject)

    # --- Science / research tools ---
    if mentions("research", "paper", "study", "arxiv", "academic", "journal",
                "scholar", "citation", "peer-review", "scientific", "thesis",
                "experiment", "theory", "physics", "math"):
        add("science_aggregate_search", query=subject)

    # --- Wikipedia ---
    if mentions("wikipedia", "wiki", "who is", "what is", "history of",
                "explain", "definition", "meaning of", "tell me about"):
        add("wikipedia_search", query=subject)

    # --- Web search (always include as general fallback) ---
    add("web_search", query=subject)
    add("web_instant_answer", query=subject)

    # --- URL extraction ---
    if url:
        add("web_get_page_content", url=url)

    return selected
|
||||
|
||||
|
||||
# --- Entity extractors ---
|
||||
|
||||
# Common stock tickers
|
||||
_KNOWN_TICKERS = {
|
||||
"aapl": "AAPL", "apple": "AAPL",
|
||||
"googl": "GOOGL", "google": "GOOGL",
|
||||
"msft": "MSFT", "microsoft": "MSFT",
|
||||
"amzn": "AMZN", "amazon": "AMZN",
|
||||
"tsla": "TSLA", "tesla": "TSLA",
|
||||
"meta": "META", "facebook": "META",
|
||||
"nvda": "NVDA", "nvidia": "NVDA",
|
||||
"netflix": "NFLX", "nflx": "NFLX",
|
||||
"amd": "AMD", "intel": "INTC",
|
||||
"disney": "DIS", "jpmorgan": "JPM",
|
||||
"ba": "BA", "boeing": "BA",
|
||||
"walmart": "WMT", "wmt": "WMT",
|
||||
"pfizer": "PFE", "pfe": "PFE",
|
||||
"nio": "NIO", "pltr": "PLTR", "palantir": "PLTR",
|
||||
"coin": "COIN", "coinbase": "COIN",
|
||||
"roku": "ROKU", "spotify": "SPOT", "shopify": "SHOP",
|
||||
}
|
||||
|
||||
# Common crypto names
|
||||
_KNOWN_CRYPTO = {
|
||||
"bitcoin": "bitcoin", "btc": "bitcoin",
|
||||
"ethereum": "ethereum", "eth": "ethereum",
|
||||
"solana": "solana", "sol": "solana",
|
||||
"dogecoin": "dogecoin", "doge": "dogecoin",
|
||||
"ripple": "ripple", "xrp": "ripple",
|
||||
"cardano": "cardano", "ada": "cardano",
|
||||
"polkadot": "polkadot", "dot": "polkadot",
|
||||
"litecoin": "litecoin", "ltc": "litecoin",
|
||||
"chainlink": "chainlink", "link": "chainlink",
|
||||
"avalanche": "avalanche", "avax": "avalanche",
|
||||
"polygon": "polygon", "matic": "polygon",
|
||||
"shiba": "shiba-inu", "shib": "shiba-inu",
|
||||
"tron": "tron", "trx": "tron",
|
||||
"usdt": "tether", "tether": "tether",
|
||||
}
|
||||
|
||||
# US states and common cities for location extraction
|
||||
_KNOWN_LOCATIONS = [
|
||||
"alabama", "alaska", "arizona", "arkansas", "california", "colorado",
|
||||
"connecticut", "delaware", "florida", "georgia", "hawaii", "idaho",
|
||||
"illinois", "indiana", "iowa", "kansas", "kentucky", "louisiana",
|
||||
"maine", "maryland", "massachusetts", "michigan", "minnesota",
|
||||
"mississippi", "missouri", "montana", "nebraska", "nevada",
|
||||
"new hampshire", "new jersey", "new mexico", "new york", "north carolina",
|
||||
"north dakota", "ohio", "oklahoma", "oregon", "pennsylvania",
|
||||
"rhode island", "south carolina", "south dakota", "tennessee", "texas",
|
||||
"utah", "vermont", "virginia", "washington", "west virginia",
|
||||
"wisconsin", "wyoming", "oroville", "chico", "redding", "sacramento",
|
||||
"los angeles", "san francisco", "san diego", "new york city", "chicago",
|
||||
"houston", "phoenix", "dallas", "austin", "seattle", "portland",
|
||||
"denver", "miami", "boston", "atlanta", "london", "paris", "tokyo",
|
||||
"berlin", "sydney", "toronto", "vancouver", "melbourne",
|
||||
]
|
||||
|
||||
_LOCATION_PREPOSITIONS = ["in", "at", "for", "near", "around", "outside", "from"]
|
||||
|
||||
|
||||
def _extract_ticker(message: str) -> str:
    """Extract a stock ticker symbol from the message.

    Words are examined in order after trimming surrounding punctuation and a
    trailing possessive ("Tesla's").  The first word that is either

    * a key of ``_KNOWN_TICKERS`` (compared case-insensitively), or
    * written as ``$XYZ`` with 1-5 letters after the dollar sign

    decides the result.  If no such word exists, the first bare 2-5 letter
    word that the user typed in capitals (e.g. ``IBM``) is used.  Capitalised
    abbreviations that are not tickers ("US", "CEO") are therefore still
    picked up by this last rule.

    Args:
        message: Raw user message.

    Returns:
        The ticker symbol in upper case, or ``""`` when none is found.
    """
    uppercase_candidate = ""
    for word in message.split():
        # "$" is deliberately not trimmed here so the $TICKER form stays visible.
        token = word.strip(",.!?;:\"'")
        has_dollar = token.startswith("$")
        symbol = token.strip("$")
        for possessive in ("'s", "’s"):
            if symbol.lower().endswith(possessive):
                symbol = symbol[:-len(possessive)]
                break

        # Check known names
        known = _KNOWN_TICKERS.get(symbol.lower())
        if known:
            return known

        if not symbol.isalpha() or len(symbol) > 5:
            continue

        # Check $TICKER format
        if has_dollar:
            return symbol.upper()

        # Remember the first raw uppercase ticker (2-5 letters).  Case is
        # tested on the original text; testing an upper-cased copy would
        # accept every short word.
        if not uppercase_candidate and len(symbol) >= 2 and symbol.isupper():
            uppercase_candidate = symbol

    return uppercase_candidate
|
||||
|
||||
|
||||
def _extract_crypto(message: str) -> str:
    """Extract a cryptocurrency from the message.

    The message is lower-cased and split into alphanumeric tokens, and each
    key of ``_KNOWN_CRYPTO`` is compared against whole tokens.  Matching whole
    tokens rather than substrings keeps short symbols from firing inside
    unrelated words ("eth" in "method", "ada" in "Canada").  Plural forms such
    as "bitcoins" are not recognised.

    Args:
        message: Raw user message.

    Returns:
        The mapped coin id of the first ``_KNOWN_CRYPTO`` key (in table order)
        that appears in the message, or ``""`` when none does.
    """
    # Turn every non-alphanumeric character into a separator so "BTC/USD",
    # "bitcoin's" and "(eth)" all yield clean tokens.
    separated = "".join(ch if ch.isalnum() else " " for ch in message.lower())
    tokens = set(separated.split())

    for name, coin_id in _KNOWN_CRYPTO.items():
        if name in tokens:
            return coin_id
    return ""
|
||||
|
||||
|
||||
def _extract_location(message: str) -> str:
    """Extract a location from the message.

    Candidates are tried in this order; the first hit is returned title-cased:

    1. A known location inside the phrase (up to five words) that follows a
       location preposition.  The phrase stops at the next preposition or at a
       standalone punctuation token.  Prepositions are scanned left to right.
    2. A known location anywhere in the message.
    3. The first preposition phrase of at most four characters (e.g. "la"),
       as a guess for short names that are not in the known list.

    Known locations are matched as whole words, so "paris" is not found inside
    "comparison".  When several known locations match, the longest wins
    ("new york city" over "new york").

    Args:
        message: Raw user message.

    Returns:
        The location in title case, or ``""`` when nothing plausible is found.
    """

    def normalize(text: str) -> str:
        """Lower-case text with each non-alphanumeric run reduced to one space."""
        separated = "".join(ch if ch.isalnum() else " " for ch in text.lower())
        return " ".join(separated.split())

    def longest_known(text: str) -> str:
        """Return the longest known location present as whole words, else ""."""
        padded = f" {normalize(text)} "
        hits = [loc for loc in _KNOWN_LOCATIONS if f" {loc} " in padded]
        return max(hits, key=len).title() if hits else ""

    words = message.lower().split()
    short_guess = ""

    # Try preposition-based extraction first (more specific)
    for i, word in enumerate(words):
        if word not in _LOCATION_PREPOSITIONS:
            continue

        phrase_words = []
        for following in words[i + 1:i + 6]:
            if following in _LOCATION_PREPOSITIONS or following in (",", ".", "?", "!"):
                break
            phrase_words.append(following.strip(",$.!?;:\"'"))
        candidate = " ".join(w for w in phrase_words if w)
        if not candidate:
            continue

        known = longest_known(candidate)
        if known:
            return known

        # Keep only the first short phrase, and only as a last resort: a real
        # known location elsewhere in the message should beat "for me".
        if not short_guess and len(candidate) <= 4:
            short_guess = candidate.title()

    # Fallback: known locations appearing anywhere (prefer longest), then the
    # short guess.
    return longest_known(message) or short_guess
|
||||
|
||||
|
||||
def _extract_url(message: str) -> str:
|
||||
"""Extract a URL from the message."""
|
||||
match = re.search(r'https?://[^\s<>"{}|\\^`\[\]]+', message)
|
||||
return match.group(0) if match else ""
|
||||
|
||||
|
||||
def _extract_subject(message: str) -> str:
|
||||
"""Extract the main subject/query from the user message.
|
||||
|
||||
Strips common question patterns to get the core topic.
|
||||
"""
|
||||
subject = message.strip()
|
||||
|
||||
# Remove question starters (longest first to avoid partial matches)
|
||||
starters = [
|
||||
"give me all the", "give me all", "give me",
|
||||
"tell me about", "tell me",
|
||||
"what is the", "what is",
|
||||
"what are the", "what are",
|
||||
"what's the", "what's",
|
||||
"how is the", "how is",
|
||||
"how are the", "how are",
|
||||
"how do", "how does",
|
||||
"how much", "how many",
|
||||
"can you", "could you", "would you",
|
||||
"please ", "i need", "i want",
|
||||
"show me", "get me",
|
||||
"find me", "find",
|
||||
"search for", "look up", "lookup",
|
||||
"check the", "check",
|
||||
"what's happening",
|
||||
"explain", "describe", "summarize",
|
||||
]
|
||||
msg_lower = subject.lower()
|
||||
for starter in starters:
|
||||
if msg_lower.startswith(starter):
|
||||
subject = subject[len(starter):].strip()
|
||||
msg_lower = subject.lower()
|
||||
break
|
||||
|
||||
# Strip leading "the", "a", "an"
|
||||
for article in ["the ", "a ", "an "]:
|
||||
if msg_lower.startswith(article):
|
||||
subject = subject[len(article):]
|
||||
msg_lower = subject.lower()
|
||||
break
|
||||
|
||||
# Strip trailing punctuation
|
||||
subject = subject.rstrip("?.!;, ")
|
||||
|
||||
# If still long, take first meaningful chunk
|
||||
if len(subject) > 200:
|
||||
subject = subject[:200]
|
||||
|
||||
return subject or message.strip()
|
||||
|
||||
|
||||
def _build_tool_results_text(tool_results: list[dict]) -> str:
|
||||
"""Build a text block of all tool results for the system prompt."""
|
||||
if not tool_results:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user