Rewrite: firehose all tools in parallel, then single LLM call

No LLM needed for tool selection. Flow is now:
  Request → run ALL tools in parallel → results into system prompt → 1 LLM call

- _run_all_tools: fires every tool concurrently (30s timeout each)
  - No required args: run with schema defaults
  - Query-like required args (query, topic, title, etc): use user message
  - Specific args (symbol, url, pmid): skip (can't guess)
- _build_tool_results_text: formats all results into system prompt
- build_enhanced_messages: system prompt now has real-time data section
- call_llm: dead simple, just prompt → response (replaces generate_response)
- Removed: generate_response, _parse_tool_calls, _clean_tool_syntax,
  _build_tool_descriptions (all dead code now)
- Streaming path: same flow, runs tools then streams the LLM response
- Both streaming and non-streaming use identical tool pipeline
This commit is contained in:
Z User 2026-03-29 18:36:37 +00:00
parent 8a46a78a4e
commit 70109d6889

517
main.py
View File

@ -439,6 +439,87 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
# Chat Completion Logic
# =============================================================================
async def _run_all_tools(user_message: str, timeout: float = 30) -> list[dict]:
    """Run every runnable tool concurrently, without involving the LLM.

    Argument selection per tool, based on its JSON schema:
    - Every property that declares a ``default`` is passed with that default.
    - A required parameter without a default is filled with ``user_message``
      when its name is query-like (``query``, ``topic``, ``disease``, ...).
    - If any required parameter cannot be filled this way (e.g. ``symbol``,
      ``pmid``, ``url``), the whole tool is skipped.

    Args:
        user_message: Latest user message, used as the value for query-like
            required parameters.
        timeout: Per-tool time limit in seconds.

    Returns:
        One dict per executed tool. Successful runs look like
        ``{"name", "success": True, "result"}``; failures look like
        ``{"name", "success": False, "error"}`` (plus ``"result"`` when the
        tool itself reported the failure). Empty list when no tool manager
        is configured.
    """
    if not state.tool_manager:
        return []

    # Required parameter names that can sensibly receive the raw user message.
    message_params = frozenset(
        ("query", "q", "search", "search_query", "topic", "title", "question", "disease")
    )

    async def _run_one(name: str, kwargs: dict) -> dict:
        try:
            # NOTE: on timeout only the await is abandoned; the worker thread
            # started by to_thread keeps running until the tool returns.
            result = await asyncio.wait_for(
                asyncio.to_thread(state.tool_manager.execute_tool, name, kwargs),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            return {"name": name, "success": False, "error": "timeout"}
        except Exception as e:
            # Best effort: one failing tool must not take down the others.
            return {"name": name, "success": False, "error": str(e)}

        # A tool can return normally yet flag its own failure in the payload;
        # don't present that to the LLM as real data.
        if isinstance(result, dict) and result.get("success") is False:
            return {
                "name": name,
                "success": False,
                "error": str(result.get("error") or "tool reported failure"),
                "result": result,
            }
        return {"name": name, "success": True, "result": result}

    tasks = []
    for name, schema in state.tool_manager._schemas.items():
        params = schema.get("function", {}).get("parameters", {})
        props = params.get("properties", {})

        # Start from schema defaults, then fill what is still required.
        kwargs = {pname: pinfo["default"] for pname, pinfo in props.items() if "default" in pinfo}
        missing = [pname for pname in params.get("required", []) if pname not in kwargs]

        if any(pname not in message_params for pname in missing):
            log.debug(f"Skipping tool {name}: can't fill required param from user message")
            continue

        for pname in missing:
            kwargs[pname] = user_message
        tasks.append(_run_one(name, kwargs))

    log.info(f"Running {len(tasks)} tools in parallel...")
    results = await asyncio.gather(*tasks)
    succeeded = sum(1 for r in results if r["success"])
    log.info(f"Tool execution complete: {succeeded}/{len(results)} succeeded")
    return results
def _build_tool_results_text(tool_results: list[dict], max_chars: int = 3000) -> str:
    """Format tool results as a text block for the system prompt.

    Each tool becomes a ``### <name>`` section containing either its result
    serialized as JSON or an ``[ERROR: ...]`` line. Sections are separated by
    a blank line.

    Args:
        tool_results: Dicts with ``name`` and ``success`` keys, plus
            ``result`` (on success) or ``error`` (on failure).
        max_chars: Serialized results longer than this are cut to this many
            characters and marked as truncated, to keep the prompt manageable.

    Returns:
        The formatted text, or ``""`` when there are no results.
    """
    if not tool_results:
        return ""

    parts = []
    for tr in tool_results:
        name = tr["name"]
        if not tr["success"]:
            parts.append(f"### {name}\n[ERROR: {tr.get('error', 'unknown')}]")
            continue

        # default=str: tool payloads may hold values json cannot encode
        # (datetime, set, bytes, ...). This text is only read by the LLM, so a
        # string rendering is preferable to failing the whole request.
        result_str = json.dumps(tr.get("result", {}), ensure_ascii=False, default=str)
        if len(result_str) > max_chars:
            result_str = result_str[:max_chars] + '..." [TRUNCATED]'
        parts.append(f"### {name}\n{result_str}")

    return "\n\n".join(parts)
async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse:
"""Process a non-streaming chat completion request."""
log.info(f"=== Starting complete_chat for request {request_id} ===")
@ -456,12 +537,12 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
log.info(f"User message: {user_message[:100]}...")
# Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
# Step 1: Download website if user is asking about one
download_info = await download_website_if_needed(user_message)
if download_info.get("downloaded"):
log.info(f"Website auto-downloaded: {download_info.get('url')}")
# Step 2: RAG Retrieval (now includes newly downloaded content)
# Step 2: RAG Retrieval
context = ""
sources = []
if state.rag_system:
@ -476,19 +557,24 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat
except Exception as e:
log.warning(f"RAG retrieval failed: {e}")
# Step 3: Build enhanced prompt with context
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)
# Step 3: Run ALL tools in parallel (no LLM needed)
tool_results = []
if state.tool_manager and config.ENABLE_TOOLS:
tool_results = await _run_all_tools(user_message)
# Step 4: Generate response with upstream LLM
log.info(f"Calling generate_response for request {request_id}")
response_content = await generate_response(
# Step 4: Build system prompt with tool results as context
tool_results_text = _build_tool_results_text(tool_results)
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text)
# Step 5: ONE LLM call
log.info(f"Calling LLM (single call) for request {request_id}")
response_content = await call_llm(
enhanced_messages,
temperature=request.temperature,
max_tokens=request.max_tokens,
)
log.info(f"=== Completed complete_chat for request {request_id} ===")
# Step 5: Build and return response
return ChatCompletionResponse(
id=request_id,
model=config.MODEL_NAME,
@ -526,12 +612,12 @@ async def stream_chat_completion(
yield f"data: {json.dumps({'error': 'No user message found'})}\n\n"
return
# Step 1: Download website if user is asking about one (BEFORE RAG retrieval)
# Step 1: Download website if user is asking about one
download_info = await download_website_if_needed(user_message)
if download_info.get("downloaded"):
log.info(f"Website auto-downloaded: {download_info.get('url')}")
# Step 2: RAG Retrieval (now includes newly downloaded content)
# Step 2: RAG Retrieval
context = ""
sources = []
if state.rag_system:
@ -546,61 +632,44 @@ async def stream_chat_completion(
except Exception as e:
log.warning(f"RAG retrieval failed: {e}")
# Step 3: Build enhanced prompt with context
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info)
# Step 3: Run ALL tools in parallel (no LLM needed)
tool_results = []
if state.tool_manager and config.ENABLE_TOOLS:
tool_results = await _run_all_tools(user_message)
# Step 4: Stream response from upstream LLM
# Step 4: Build system prompt with tool results as context
tool_results_text = _build_tool_results_text(tool_results)
enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text)
# Step 5: ONE LLM call (stream the result)
created = int(time.time())
try:
if state.llm_client:
# For streaming with tools, we need to handle tool calls first
# Then stream the final response
if state.tool_manager and config.ENABLE_TOOLS:
# Use non-streaming for tool calls, then stream the result
response_content = await generate_response(
enhanced_messages,
temperature=request.temperature or 0.7,
max_tokens=request.max_tokens or 4096,
)
# Stream the final response as a single chunk
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {'content': response_content},
'finish_reason': None
}]
})}\n\n"
else:
# No tools - use regular streaming
stream = await state.llm_client.chat.completions.create(
model=config.UPSTREAM_MODEL,
messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
temperature=request.temperature or 0.7,
max_tokens=request.max_tokens or 4096,
stream=True,
)
stream = await state.llm_client.chat.completions.create(
model=config.UPSTREAM_MODEL,
messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content],
temperature=request.temperature or 0.7,
max_tokens=request.max_tokens or 4096,
stream=True,
)
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {'content': content},
'finish_reason': None
}]
})}\n\n"
async for chunk in stream:
if chunk.choices and chunk.choices[0].delta.content:
content = chunk.choices[0].delta.content
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
'created': created,
'model': config.MODEL_NAME,
'choices': [{
'index': 0,
'delta': {'content': content},
'finish_reason': None
}]
})}\n\n"
# Send final chunk
# Final chunk
yield f"data: {json.dumps({
'id': request_id,
'object': 'chat.completion.chunk',
@ -615,13 +684,12 @@ async def stream_chat_completion(
yield "data: [DONE]\n\n"
else:
# Mock streaming response for testing
# Mock streaming response
mock_response = f"I understand you're asking about: {user_message}\n\n"
if download_info.get("downloaded"):
mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n"
mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n"
if tool_results_text:
mock_response += f"I gathered data from {len(tool_results)} tools.\n\n"
if context:
mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n"
mock_response += f"Knowledge base context:\n{context[:1000]}...\n\n"
mock_response += "\n\n[Demo mode - configure OPENROUTER_API_KEY for full LLM responses]"
for char in mock_response:
@ -661,55 +729,27 @@ def build_enhanced_messages(
context: str,
sources: list[str],
download_info: dict = None,
tool_results: list[dict] = None,
tool_results_text: str = "",
) -> list[ChatMessage]:
"""Build enhanced messages with RAG context."""
"""Build enhanced messages with RAG context and tool results in system prompt."""
enhanced = []
# Build tool descriptions for context
tool_descriptions = _build_tool_descriptions()
# Add system message with RAG context and tool instructions
system_content = f"""You are a helpful AI assistant with access to real-time data through various tools.
system_content = "You are a helpful AI assistant with access to real-time data.\n"
## AVAILABLE TOOLS
{tool_descriptions}
## HOW TO USE TOOLS
When you need to use one or more tools, output a SINGLE JSON block containing ALL tool calls as an array.
You MUST bundle every tool call into one response - do NOT respond with just one tool at a time.
Output EXACTLY this format (nothing else before or after):
```json
{{"tool_calls": [
{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}},
{{"name": "another_tool", "arguments": {{"arg2": "value2"}}}}
]}}
```
## IMPORTANT RULES
1. ALWAYS use tools to get CURRENT data - do NOT say you cannot access real-time data
2. When asked about stocks, crypto, weather, or news, you MUST use the appropriate tool(s)
3. Bundle ALL needed tool calls into a single `tool_calls` array - include every tool you need in one response
4. After receiving tool results, provide a helpful, natural-language response based on the data
5. Be concise and factual - report exact data from tools
"""
if tool_results_text:
system_content += f"\n## REAL-TIME DATA (from tools)\n{tool_results_text}\n"
if download_info and download_info.get("downloaded"):
system_content += f"\n\n--- Website Access ---\n"
system_content += f"\n--- Website Access ---\n"
system_content += f"Downloaded website: {download_info.get('url')}\n"
system_content += f"Pages: {download_info.get('pages')}, Chunks: {download_info.get('chunks')}\n"
if context:
system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n"
system_content += f"\n## Relevant Context from Knowledge Base\n{context}\n"
if sources:
system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10])
system_content += f"\nSources:\n" + "\n".join(f"- {s}" for s in sources[:10])
# Add previous tool results if any
if tool_results:
system_content += "\n\n--- PREVIOUS TOOL RESULTS ---\n"
for tr in tool_results:
system_content += f"\nTool: {tr['name']}\nResult: {json.dumps(tr['result'], indent=2)}\n"
system_content += "\n\n## INSTRUCTIONS\nUse the data above to answer the user's question. Be concise and factual."
enhanced.append(ChatMessage(role="system", content=system_content))
@ -721,283 +761,38 @@ Output EXACTLY this format (nothing else before or after):
return enhanced
def _build_tool_descriptions() -> str:
    """Return a one-line-per-tool summary of the registered tools.

    Each line has the form ``- name(param: type (required), ...): description``
    with the description cut to 100 characters. Returns
    ``"No tools available."`` when no tool manager is configured.
    """
    manager = state.tool_manager
    if not manager:
        return "No tools available."

    def _render_param(pname: str, pinfo: dict, mandatory: list) -> str:
        # "<name>: <type>" with a " (required)" marker where applicable.
        marker = " (required)" if pname in mandatory else ""
        return f"{pname}: {pinfo.get('type', 'any')}{marker}"

    lines = []
    for tool_name, tool_schema in manager._schemas.items():
        fn_spec = tool_schema.get("function", {})
        summary = fn_spec.get("description", "")[:100]  # keep the prompt short
        properties = fn_spec.get("parameters", {}).get("properties", {})
        mandatory = fn_spec.get("parameters", {}).get("required", [])

        rendered = [_render_param(pname, pinfo, mandatory) for pname, pinfo in properties.items()]
        signature = ", ".join(rendered) if rendered else "none"
        lines.append(f"- {tool_name}({signature}): {summary}")

    return "\n".join(lines)
def _parse_tool_calls(content: str) -> list[dict]:
    """Extract tool calls from the text of an LLM response.

    The preferred shape is a JSON object such as
    ``{"tool_calls": [{"name": "tool_name", "arguments": {...}}, ...]}``.
    Formats are tried in order, and the first one that yields calls wins:

    1. ``tool_calls`` object inside a code fence.
    2. ``tool_calls`` object as bare JSON outside code fences.
    3. Legacy single ``{"tool_call": {...}}`` object (fenced, then bare).
    4. The whole response parsed as JSON, keeping only known tool names.

    Returns a list of dicts, each carrying at least a ``name`` key; the list
    is empty when nothing could be parsed.
    """

    def _extract_json_object(text: str, start_key: str) -> Optional[dict]:
        """Parse the brace-delimited object that encloses ``start_key``."""
        key_pos = text.find(start_key)
        if key_pos == -1:
            return None

        # Scan left for the '{' that opens the object holding the key,
        # stepping over any complete {...} groups met on the way.
        begin = -1
        pending_closers = 0
        pos = key_pos
        while pos >= 0:
            ch = text[pos]
            if ch == '}':
                pending_closers += 1
            elif ch == '{':
                if pending_closers == 0:
                    begin = pos
                    break
                pending_closers -= 1
            pos -= 1
        if begin == -1:
            return None

        # Scan right for the '}' that balances the opening brace.
        finish = -1
        level = 0
        for offset in range(begin, len(text)):
            ch = text[offset]
            if ch == '{':
                level += 1
            elif ch == '}':
                level -= 1
                if level == 0:
                    finish = offset + 1
                    break
        if finish == -1:
            return None

        try:
            return json.loads(text[begin:finish])
        except json.JSONDecodeError:
            return None

    def _bundled_calls(obj: Optional[dict]) -> list[dict]:
        """Named dict entries of ``obj["tool_calls"]``, or [] if absent/invalid."""
        if not (obj and "tool_calls" in obj and isinstance(obj["tool_calls"], list)):
            return []
        return [tc for tc in obj["tool_calls"] if isinstance(tc, dict) and "name" in tc]

    def _legacy_call(obj: Optional[dict]) -> Optional[dict]:
        """The dict under ``obj["tool_call"]`` when it has a name, else None."""
        if obj and "tool_call" in obj and isinstance(obj["tool_call"], dict) and "name" in obj["tool_call"]:
            return obj["tool_call"]
        return None

    # --- Format 1: bundled calls inside code fences ---
    fenced_blocks = re.findall(r'```\w*\s*(.*?)\s*```', content, re.DOTALL)
    found: list[dict] = []
    for block in fenced_blocks:
        found.extend(_bundled_calls(_extract_json_object(block, '"tool_calls"')))
    if found:
        log.info(f"Parsed tool_calls from code fence: {len(found)} calls")
        return found

    # --- Format 2: bundled calls as bare JSON, fences removed first ---
    unfenced = re.sub(r'```\w*\s*.*?\s*```', '', content, flags=re.DOTALL)
    found.extend(_bundled_calls(_extract_json_object(unfenced, '"tool_calls"')))
    if found:
        log.info(f"Parsed tool_calls from bare JSON: {len(found)} calls")
        return found

    # --- Format 3: legacy single-call wrapper, fenced first, then bare ---
    for block in fenced_blocks:
        single = _legacy_call(_extract_json_object(block, '"tool_call"'))
        if single is not None:
            found.append(single)
    if not found:
        single = _legacy_call(_extract_json_object(unfenced, '"tool_call"'))
        if single is not None:
            found.append(single)
    if found:
        log.info(f"Parsed tool_call (legacy format): {len(found)} calls")
        return found

    # --- Format 4: last resort, the whole response is JSON ---
    # Catches responses that are a bare list of calls or a single call object;
    # only names of registered tools are accepted here.
    registered = set(state.tool_manager._tools.keys()) if state.tool_manager else set()
    if registered:
        try:
            parsed = json.loads(content.strip())
            if isinstance(parsed, list):
                candidates = parsed
            elif isinstance(parsed, dict) and "tool_calls" in parsed:
                candidates = parsed["tool_calls"]
            elif isinstance(parsed, dict) and "name" in parsed:
                candidates = [parsed]
            else:
                candidates = []
            for tc in candidates:
                if isinstance(tc, dict) and "name" in tc and tc["name"] in registered:
                    found.append(tc)
        except (json.JSONDecodeError, TypeError):
            pass
        if found:
            log.info(f"Parsed tool_calls (desperate fallback): {len(found)} calls")
            return found

    return found
async def generate_response(
async def call_llm(
messages: list[ChatMessage],
temperature: float = 0.7,
max_tokens: int = 4096,
) -> str:
"""Generate response using upstream LLM via OpenRouter.
Uses content-based tool calling: the LLM outputs a single JSON block with
all tool calls bundled as a `tool_calls` array. This works around model
limitations on the number of native tool calls per response.
"""
"""Single LLM call. No tool logic, just prompt in → response out."""
if not state.llm_client:
# Mock response for testing
user_msg = ""
for msg in reversed(messages):
if msg.role == "user" and msg.content:
user_msg = msg.content
break
return f"Demo mode response. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality."
return f"Demo mode. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality."
try:
# Convert messages to dict format
messages_dict = []
for m in messages:
if m.content:
messages_dict.append({"role": m.role, "content": m.content})
# Tool calling loop (content-based approach — no `tools` param to API)
max_iterations = config.MAX_TOOL_ITERATIONS
iteration = 0
while iteration < max_iterations:
iteration += 1
log.info(f"LLM call iteration {iteration}")
# Call LLM WITHOUT tools parameter — tool instructions are in the system prompt
response = await state.llm_client.chat.completions.create(
model=config.UPSTREAM_MODEL,
messages=messages_dict,
temperature=temperature,
max_tokens=max_tokens,
)
if not response.choices:
log.warning("No choices in response")
return "I apologize, but I couldn't generate a response."
content = response.choices[0].message.content or ""
log.info(f"LLM response: content_len={len(content)}")
# Log first 500 chars of response for debugging
log.debug(f"LLM response content preview: {content[:500]}")
if content:
log.info(f"LLM response (first 300 chars): {content[:300]!r}")
# --- Parse tool calls from content ---
tool_calls = _parse_tool_calls(content)
if tool_calls:
log.info(f"Parsed {len(tool_calls)} tool calls: {[tc.get('name') for tc in tool_calls]}")
# Execute ALL tools concurrently
if state.tool_manager:
import asyncio as _asyncio
async def _run_tool(tc):
name = tc.get("name")
args = tc.get("arguments", {})
if not isinstance(args, dict):
try:
args = json.loads(args)
except (json.JSONDecodeError, TypeError):
args = {}
result = await _asyncio.to_thread(
state.tool_manager.execute_tool, name, args
)
return name, result
results = await _asyncio.gather(*[_run_tool(tc) for tc in tool_calls])
# Build a single consolidated results block
results_text = ""
for name, result in results:
log.info(f"Tool {name} result: success={result.get('success', False)}")
results_text += f"\n### Tool: {name}\n{json.dumps(result, indent=2)}\n"
# Append assistant's tool call message to conversation
messages_dict.append({"role": "assistant", "content": content})
# Feed ALL results back in one user message
messages_dict.append({
"role": "user",
"content": (
f"--- ALL TOOL RESULTS ---\n"
f"Executed {len(tool_calls)} tool(s). Results:\n{results_text}\n"
f"---\n\n"
f"Now provide a helpful response to the original question using ALL the data above."
),
})
continue
else:
log.warning("Tool call detected but tool_manager is None")
# --- No tool calls — return the final response ---
cleaned_content = _clean_tool_syntax(content)
# Safety net: if cleaning stripped everything, return original content
# (better to show raw JSON than the useless "I apologize" message)
if not cleaned_content.strip() and content.strip():
log.warning(f"_clean_tool_syntax stripped all content! Original had {len(content)} chars. Returning original.")
cleaned_content = content.strip()
log.info(f"Returning final response (len={len(cleaned_content)})")
return cleaned_content or "I apologize, but I couldn't generate a response."
# Max iterations reached
log.warning(f"Max iterations ({max_iterations}) reached")
return "I reached the maximum number of tool call rounds. Please try a more specific question."
messages_dict = [{"role": m.role, "content": m.content} for m in messages if m.content]
response = await state.llm_client.chat.completions.create(
model=config.UPSTREAM_MODEL,
messages=messages_dict,
temperature=temperature,
max_tokens=max_tokens,
)
if not response.choices:
log.warning("No choices in LLM response")
return "I apologize, but I couldn't generate a response."
content = response.choices[0].message.content or ""
return content or "I apologize, but I couldn't generate a response."
except Exception as e:
log.error(f"OpenRouter LLM call failed: {e}")
import traceback
log.error(traceback.format_exc())
log.error(f"LLM call failed: {e}")
return f"I encountered an error: {str(e)}"
def _clean_tool_syntax(content: str) -> str:
    """Strip fenced tool-call JSON from a response and trim whitespace.

    A code-fence block is removed when its body mentions ``"tool_calls"`` or
    ``"tool_call"``; every other fence is kept verbatim. Bare JSON outside
    fences is left alone so legitimate content is never dropped.
    """
    markers = ('"tool_calls"', '"tool_call"')

    def _keep_or_drop(match) -> str:
        payload = match.group(1)
        if any(marker in payload for marker in markers):
            return ''
        return match.group(0)

    without_calls = re.sub(r'```\w*\s*(.*?)\s*```', _keep_or_drop, content, flags=re.DOTALL)
    return without_calls.strip()
# =============================================================================
# Document Management Endpoints
# =============================================================================