From 7a6b6f10868e5dff65874714f724cd20605cdd19 Mon Sep 17 00:00:00 2001 From: Z User Date: Sun, 29 Mar 2026 18:44:14 +0000 Subject: [PATCH] Add local tool selector: keyword parser picks relevant tools, no LLM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _select_tools() parses the user message with keyword matching: - News keywords → news_aggregate, news_get_top_stories, news_get_reddit - Finance/stock keywords → finance_get_stock_info/history (extracts ticker) - Crypto keywords → finance_get_crypto_price (extracts coin name), finance_get_top_cryptos - Weather keywords → weather_get_current/forecast/air_quality (extracts location) - Medical keywords → pubmed, fda, disease data, health topics - Science keywords → science_aggregate_search - Wikipedia keywords → wikipedia_search - Always: web_search + web_instant_answer as general fallback - URL in message → web_get_page_content Entity extractors: - _extract_ticker: maps known company names, handles $TICKER format - _extract_crypto: maps known crypto names to CoinGecko IDs - _extract_location: preposition-based + known locations (prefers longest match) - _extract_subject: strips question patterns, leading articles, trailing punctuation Flow remains: request → select tools → run in parallel → results into system prompt → 1 LLM call --- main.py | 330 +++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 291 insertions(+), 39 deletions(-) diff --git a/main.py b/main.py index f276d88..2d9897c 100755 --- a/main.py +++ b/main.py @@ -440,15 +440,19 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]: # ============================================================================= async def _run_all_tools(user_message: str) -> list[dict]: - """Run ALL tools in parallel. No LLM involved. - - - Tools with no required args: run with defaults. - - Tools with required args: use the user message as the query argument. - - Each tool gets a timeout so slow ones don't block. - """ + """Select relevant tools via local keyword matching, then run them all in parallel.""" if not state.tool_manager: return [] + # Step 1: Determine which tools to call + tool_calls = _select_tools(user_message) + if not tool_calls: + log.info("No tools selected by keyword parser") + return [] + + log.info(f"Selected {len(tool_calls)} tools: {[tc['name'] for tc in tool_calls]}") + + # Step 2: Execute them all in parallel async def _run_one(name: str, kwargs: dict): try: result = await asyncio.wait_for( @@ -461,44 +465,292 @@ async def _run_all_tools(user_message: str) -> list[dict]: except Exception as e: return {"name": name, "success": False, "error": str(e)} - tasks = [] - for name, schema in state.tool_manager._schemas.items(): - func_schema = schema.get("function", {}) - params = func_schema.get("parameters", {}) - required = set(params.get("required", [])) - props = params.get("properties", {}) - - # Build kwargs: defaults from schema, then fill required from user message - kwargs = {} - for pname, pinfo in props.items(): - if "default" in pinfo: - kwargs[pname] = pinfo["default"] - - for pname in required: - if pname not in kwargs: - # Heuristic: use user_message for common query-like params - if pname in ("query", "q", "search", "search_query", "topic", "title", "question"): - kwargs[pname] = user_message - # Use user_message for specific ID fields that look queryable - elif pname in ("disease",): - kwargs[pname] = user_message - # Skip tools that need specific args we can't guess (symbol, pmid, paper_id, url, etc.) - else: - kwargs = None - break - - if kwargs is not None: - tasks.append(_run_one(name, kwargs)) - else: - log.debug(f"Skipping tool {name}: can't fill required param from user message") - - log.info(f"Running {len(tasks)} tools in parallel...") - results = await asyncio.gather(*tasks) + results = await asyncio.gather(*[_run_one(tc["name"], tc["kwargs"]) for tc in tool_calls]) successes = [r for r in results if r["success"]] log.info(f"Tool execution complete: {len(successes)}/{len(results)} succeeded") return results +def _select_tools(user_message: str) -> list[dict]: + """Parse the user message and determine which tools to call. + + Uses keyword/category matching. Returns list of {"name": str, "kwargs": dict}. + """ + msg = user_message.lower() + tools = [] + + # --- Extract useful entities from the message --- + location = _extract_location(user_message) + ticker = _extract_ticker(user_message) + crypto = _extract_crypto(user_message) + url = _extract_url(user_message) + subject = _extract_subject(user_message) # the main topic/query + + # --- News tools --- + if any(kw in msg for kw in ["news", "headline", "headlines", "current event", "current events", + "breaking", "trending", "reddit", "hacker", "story", "stories"]): + tools.append({"name": "news_aggregate", "kwargs": {"query": subject}}) + tools.append({"name": "news_get_top_stories", "kwargs": {}}) + tools.append({"name": "news_get_reddit", "kwargs": {"subreddit": "news"}}) + if any(kw in msg for kw in ["reddit", "subreddit"]): + tools.append({"name": "news_search_reddit", "kwargs": {"query": subject}}) + + # --- Finance / stock tools --- + if any(kw in msg for kw in ["stock", "share", "price", "market", "nasdaq", "nyse", + "dow", "s&p", "dividend", "portfolio", "ipo"]): + if ticker: + tools.append({"name": "finance_get_stock_info", "kwargs": {"symbol": ticker}}) + tools.append({"name": "finance_get_stock_history", "kwargs": {"symbol": ticker}}) + else: + # Try to extract any potential ticker from message + for word in user_message.upper().split(): + word = word.strip(",$.!?;:") + if 1 <= len(word) <= 5 and word.isalpha(): + tools.append({"name": "finance_get_stock_info", "kwargs": {"symbol": word}}) + break + + # --- Crypto tools --- + if any(kw in msg for kw in ["crypto", "bitcoin", "btc", "ethereum", "eth", "solana", + "dogecoin", "memecoin", "altcoin", "blockchain", "coin", "token"]): + if crypto: + tools.append({"name": "finance_get_crypto_price", "kwargs": {"coin_id": crypto}}) + tools.append({"name": "finance_get_top_cryptos", "kwargs": {}}) + + # --- Weather tools --- + if any(kw in msg for kw in ["weather", "temperature", "forecast", "rain", "snow", + "wind", "humid", "sunny", "cloudy", "storm", "air quality", + "aqi", "pollution", "uv index"]): + if location: + tools.append({"name": "weather_get_current", "kwargs": {"location": location}}) + tools.append({"name": "weather_get_forecast", "kwargs": {"location": location}}) + tools.append({"name": "weather_get_air_quality", "kwargs": {"location": location}}) + else: + tools.append({"name": "weather_get_current", "kwargs": {"location": subject}}) + tools.append({"name": "weather_get_forecast", "kwargs": {"location": subject}}) + + # --- Medical tools --- + if any(kw in msg for kw in ["medical", "health", "disease", "symptom", "drug", "medication", + "covid", "vaccine", "fda", "hospital", "pubmed", "clinical", + "treatment", "diagnosis", "doctor", "patient"]): + tools.append({"name": "medical_search_pubmed", "kwargs": {"query": subject}}) + tools.append({"name": "medical_search_fda", "kwargs": {"query": subject}}) + tools.append({"name": "medical_get_disease_data", "kwargs": {"disease": subject}}) + tools.append({"name": "medical_get_health_topics", "kwargs": {"topic": subject}}) + + # --- Science / research tools --- + if any(kw in msg for kw in ["research", "paper", "study", "arxiv", "academic", "journal", + "scholar", "citation", "peer-review", "scientific", "thesis", + "experiment", "theory", "physics", "math"]): + tools.append({"name": "science_aggregate_search", "kwargs": {"query": subject}}) + + # --- Wikipedia --- + if any(kw in msg for kw in ["wikipedia", "wiki", "who is", "what is", "history of", + "explain", "definition", "meaning of", "tell me about"]): + tools.append({"name": "wikipedia_search", "kwargs": {"query": subject}}) + + # --- Web search (always include as general fallback) --- + tools.append({"name": "web_search", "kwargs": {"query": subject}}) + tools.append({"name": "web_instant_answer", "kwargs": {"query": subject}}) + + # --- URL extraction --- + if url: + tools.append({"name": "web_get_page_content", "kwargs": {"url": url}}) + + # Deduplicate by name + seen = set() + unique = [] + for tc in tools: + if tc["name"] not in seen: + seen.add(tc["name"]) + unique.append(tc) + + return unique + + +# --- Entity extractors --- + +# Common stock tickers +_KNOWN_TICKERS = { + "aapl": "AAPL", "apple": "AAPL", + "googl": "GOOGL", "google": "GOOGL", + "msft": "MSFT", "microsoft": "MSFT", + "amzn": "AMZN", "amazon": "AMZN", + "tsla": "TSLA", "tesla": "TSLA", + "meta": "META", "facebook": "META", + "nvda": "NVDA", "nvidia": "NVDA", + "netflix": "NFLX", "nflx": "NFLX", + "amd": "AMD", "intel": "INTC", + "disney": "DIS", "jpmorgan": "JPM", + "ba": "BA", "boeing": "BA", + "walmart": "WMT", "wmt": "WMT", + "pfizer": "PFE", "pfe": "PFE", + "nio": "NIO", "pltr": "PLTR", "palantir": "PLTR", + "coin": "COIN", "coinbase": "COIN", + "roku": "ROKU", "spotify": "SPOT", "shopify": "SHOP", +} + +# Common crypto names +_KNOWN_CRYPTO = { + "bitcoin": "bitcoin", "btc": "bitcoin", + "ethereum": "ethereum", "eth": "ethereum", + "solana": "solana", "sol": "solana", + "dogecoin": "dogecoin", "doge": "dogecoin", + "ripple": "ripple", "xrp": "ripple", + "cardano": "cardano", "ada": "cardano", + "polkadot": "polkadot", "dot": "polkadot", + "litecoin": "litecoin", "ltc": "litecoin", + "chainlink": "chainlink", "link": "chainlink", + "avalanche": "avalanche", "avax": "avalanche", + "polygon": "polygon", "matic": "polygon", + "shiba": "shiba-inu", "shib": "shiba-inu", + "tron": "tron", "trx": "tron", + "usdt": "tether", "tether": "tether", +} + +# US states and common cities for location extraction +_KNOWN_LOCATIONS = [ + "alabama", "alaska", "arizona", "arkansas", "california", "colorado", + "connecticut", "delaware", "florida", "georgia", "hawaii", "idaho", + "illinois", "indiana", "iowa", "kansas", "kentucky", "louisiana", + "maine", "maryland", "massachusetts", "michigan", "minnesota", + "mississippi", "missouri", "montana", "nebraska", "nevada", + "new hampshire", "new jersey", "new mexico", "new york", "north carolina", + "north dakota", "ohio", "oklahoma", "oregon", "pennsylvania", + "rhode island", "south carolina", "south dakota", "tennessee", "texas", + "utah", "vermont", "virginia", "washington", "west virginia", + "wisconsin", "wyoming", "oroville", "chico", "redding", "sacramento", + "los angeles", "san francisco", "san diego", "new york city", "chicago", + "houston", "phoenix", "dallas", "austin", "seattle", "portland", + "denver", "miami", "boston", "atlanta", "london", "paris", "tokyo", + "berlin", "sydney", "toronto", "vancouver", "melbourne", +] + +_LOCATION_PREPOSITIONS = ["in", "at", "for", "near", "around", "outside", "from"] + + +def _extract_ticker(message: str) -> str: + """Extract a stock ticker from the message.""" + words = message.upper().split() + for i, word in enumerate(words): + clean = word.strip(",$.!?;:\"'") + # Check known names + if message.lower().split()[i] in _KNOWN_TICKERS: + return _KNOWN_TICKERS[message.lower().split()[i]] + # Check $TICKER format + if clean.startswith("$") and 1 <= len(clean[1:]) <= 5 and clean[1:].isalpha(): + return clean[1:] + # Check raw uppercase ticker (1-5 alpha chars) + if 2 <= len(clean) <= 5 and clean.isalpha() and clean.isupper(): + return clean + return "" + + +def _extract_crypto(message: str) -> str: + """Extract a cryptocurrency name from the message.""" + msg_lower = message.lower() + for name, coin_id in _KNOWN_CRYPTO.items(): + if name in msg_lower: + return coin_id + return "" + + +def _extract_location(message: str) -> str: + """Extract a location from the message.""" + msg_lower = message.lower() + words = msg_lower.split() + + # Try preposition-based extraction first (more specific) + best_match = "" + for i, word in enumerate(words): + if word in _LOCATION_PREPOSITIONS and i + 1 < len(words): + candidate_words = [] + for j in range(i + 1, min(i + 6, len(words))): + if words[j] in _LOCATION_PREPOSITIONS or words[j] in [",", ".", "?", "!"]: + break + candidate_words.append(words[j].strip(",$.!?;:\"'")) + candidate = " ".join(candidate_words) + if not candidate: + continue + # Find the longest known location that matches within the candidate + matches = sorted( + [loc for loc in _KNOWN_LOCATIONS if loc in candidate], + key=len, reverse=True + ) + if matches: + best_match = matches[0].title() + elif len(candidate) <= 4: + best_match = candidate.title() + if best_match: + return best_match + + # Fallback: check for known locations appearing anywhere (prefer longest) + matches = sorted( + [loc for loc in _KNOWN_LOCATIONS if loc in msg_lower], + key=len, reverse=True + ) + if matches: + return matches[0].title() + + return "" + + +def _extract_url(message: str) -> str: + """Extract a URL from the message.""" + match = re.search(r'https?://[^\s<>"{}|\\^`\[\]]+', message) + return match.group(0) if match else "" + + +def _extract_subject(message: str) -> str: + """Extract the main subject/query from the user message. + + Strips common question patterns to get the core topic. + """ + subject = message.strip() + + # Remove question starters (longest first to avoid partial matches) + starters = [ + "give me all the", "give me all", "give me", + "tell me about", "tell me", + "what is the", "what is", + "what are the", "what are", + "what's the", "what's", + "how is the", "how is", + "how are the", "how are", + "how do", "how does", + "how much", "how many", + "can you", "could you", "would you", + "please ", "i need", "i want", + "show me", "get me", + "find me", "find", + "search for", "look up", "lookup", + "check the", "check", + "what's happening", + "explain", "describe", "summarize", + ] + msg_lower = subject.lower() + for starter in starters: + if msg_lower.startswith(starter): + subject = subject[len(starter):].strip() + msg_lower = subject.lower() + break + + # Strip leading "the", "a", "an" + for article in ["the ", "a ", "an "]: + if msg_lower.startswith(article): + subject = subject[len(article):] + msg_lower = subject.lower() + break + + # Strip trailing punctuation + subject = subject.rstrip("?.!;, ") + + # If still long, take first meaningful chunk + if len(subject) > 200: + subject = subject[:200] + + return subject or message.strip() + + def _build_tool_results_text(tool_results: list[dict]) -> str: """Build a text block of all tool results for the system prompt.""" if not tool_results: