diff --git a/main.py b/main.py index b3559dc..f276d88 100755 --- a/main.py +++ b/main.py @@ -439,6 +439,87 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]: # Chat Completion Logic # ============================================================================= +async def _run_all_tools(user_message: str) -> list[dict]: + """Run ALL tools in parallel. No LLM involved. + + - Tools with no required args: run with defaults. + - Tools with required args: use the user message as the query argument. + - Each tool gets a timeout so slow ones don't block. + """ + if not state.tool_manager: + return [] + + async def _run_one(name: str, kwargs: dict): + try: + result = await asyncio.wait_for( + asyncio.to_thread(state.tool_manager.execute_tool, name, kwargs), + timeout=30, + ) + return {"name": name, "success": True, "result": result} + except asyncio.TimeoutError: + return {"name": name, "success": False, "error": "timeout"} + except Exception as e: + return {"name": name, "success": False, "error": str(e)} + + tasks = [] + for name, schema in state.tool_manager._schemas.items(): + func_schema = schema.get("function", {}) + params = func_schema.get("parameters", {}) + required = set(params.get("required", [])) + props = params.get("properties", {}) + + # Build kwargs: defaults from schema, then fill required from user message + kwargs = {} + for pname, pinfo in props.items(): + if "default" in pinfo: + kwargs[pname] = pinfo["default"] + + for pname in required: + if pname not in kwargs: + # Heuristic: use user_message for common query-like params + if pname in ("query", "q", "search", "search_query", "topic", "title", "question"): + kwargs[pname] = user_message + # Use user_message for specific ID fields that look queryable + elif pname in ("disease",): + kwargs[pname] = user_message + # Skip tools that need specific args we can't guess (symbol, pmid, paper_id, url, etc.) + else: + kwargs = None + break + + if kwargs is not None: + tasks.append(_run_one(name, kwargs)) + else: + log.debug(f"Skipping tool {name}: can't fill required param from user message") + + log.info(f"Running {len(tasks)} tools in parallel...") + results = await asyncio.gather(*tasks) + successes = [r for r in results if r["success"]] + log.info(f"Tool execution complete: {len(successes)}/{len(results)} succeeded") + return results + + +def _build_tool_results_text(tool_results: list[dict]) -> str: + """Build a text block of all tool results for the system prompt.""" + if not tool_results: + return "" + + parts = [] + for tr in tool_results: + name = tr["name"] + if tr["success"]: + result_data = tr.get("result", {}) + # Truncate large results to keep prompt manageable + result_str = json.dumps(result_data, ensure_ascii=False) + if len(result_str) > 3000: + result_str = result_str[:3000] + '..." [TRUNCATED]' + parts.append(f"### {name}\n{result_str}") + else: + parts.append(f"### {name}\n[ERROR: {tr.get('error', 'unknown')}]") + + return "\n\n".join(parts) + + async def complete_chat(request: ChatCompletionRequest, request_id: str) -> ChatCompletionResponse: """Process a non-streaming chat completion request.""" log.info(f"=== Starting complete_chat for request {request_id} ===") @@ -456,12 +537,12 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat log.info(f"User message: {user_message[:100]}...") - # Step 1: Download website if user is asking about one (BEFORE RAG retrieval) + # Step 1: Download website if user is asking about one download_info = await download_website_if_needed(user_message) if download_info.get("downloaded"): log.info(f"Website auto-downloaded: {download_info.get('url')}") - # Step 2: RAG Retrieval (now includes newly downloaded content) + # Step 2: RAG Retrieval context = "" sources = [] if state.rag_system: @@ -476,19 +557,24 @@ async def complete_chat(request: ChatCompletionRequest, request_id: str) -> Chat except Exception as e: log.warning(f"RAG retrieval failed: {e}") - # Step 3: Build enhanced prompt with context - enhanced_messages = build_enhanced_messages(messages, context, sources, download_info) + # Step 3: Run ALL tools in parallel (no LLM needed) + tool_results = [] + if state.tool_manager and config.ENABLE_TOOLS: + tool_results = await _run_all_tools(user_message) - # Step 4: Generate response with upstream LLM - log.info(f"Calling generate_response for request {request_id}") - response_content = await generate_response( + # Step 4: Build system prompt with tool results as context + tool_results_text = _build_tool_results_text(tool_results) + enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text) + + # Step 5: ONE LLM call + log.info(f"Calling LLM (single call) for request {request_id}") + response_content = await call_llm( enhanced_messages, temperature=request.temperature, max_tokens=request.max_tokens, ) log.info(f"=== Completed complete_chat for request {request_id} ===") - # Step 5: Build and return response return ChatCompletionResponse( id=request_id, model=config.MODEL_NAME, @@ -526,12 +612,12 @@ async def stream_chat_completion( yield f"data: {json.dumps({'error': 'No user message found'})}\n\n" return - # Step 1: Download website if user is asking about one (BEFORE RAG retrieval) + # Step 1: Download website if user is asking about one download_info = await download_website_if_needed(user_message) if download_info.get("downloaded"): log.info(f"Website auto-downloaded: {download_info.get('url')}") - # Step 2: RAG Retrieval (now includes newly downloaded content) + # Step 2: RAG Retrieval context = "" sources = [] if state.rag_system: @@ -546,61 +632,44 @@ async def stream_chat_completion( except Exception as e: log.warning(f"RAG retrieval failed: {e}") - # Step 3: Build enhanced prompt with context - enhanced_messages = build_enhanced_messages(messages, context, sources, download_info) + # Step 3: Run ALL tools in parallel (no LLM needed) + tool_results = [] + if state.tool_manager and config.ENABLE_TOOLS: + tool_results = await _run_all_tools(user_message) - # Step 4: Stream response from upstream LLM + # Step 4: Build system prompt with tool results as context + tool_results_text = _build_tool_results_text(tool_results) + enhanced_messages = build_enhanced_messages(messages, context, sources, download_info, tool_results_text) + + # Step 5: ONE LLM call (stream the result) created = int(time.time()) try: if state.llm_client: - # For streaming with tools, we need to handle tool calls first - # Then stream the final response - if state.tool_manager and config.ENABLE_TOOLS: - # Use non-streaming for tool calls, then stream the result - response_content = await generate_response( - enhanced_messages, - temperature=request.temperature or 0.7, - max_tokens=request.max_tokens or 4096, - ) - # Stream the final response as a single chunk - yield f"data: {json.dumps({ - 'id': request_id, - 'object': 'chat.completion.chunk', - 'created': created, - 'model': config.MODEL_NAME, - 'choices': [{ - 'index': 0, - 'delta': {'content': response_content}, - 'finish_reason': None - }] - })}\n\n" - else: - # No tools - use regular streaming - stream = await state.llm_client.chat.completions.create( - model=config.UPSTREAM_MODEL, - messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content], - temperature=request.temperature or 0.7, - max_tokens=request.max_tokens or 4096, - stream=True, - ) + stream = await state.llm_client.chat.completions.create( + model=config.UPSTREAM_MODEL, + messages=[{"role": m.role, "content": m.content} for m in enhanced_messages if m.content], + temperature=request.temperature or 0.7, + max_tokens=request.max_tokens or 4096, + stream=True, + ) - async for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content: - content = chunk.choices[0].delta.content - yield f"data: {json.dumps({ - 'id': request_id, - 'object': 'chat.completion.chunk', - 'created': created, - 'model': config.MODEL_NAME, - 'choices': [{ - 'index': 0, - 'delta': {'content': content}, - 'finish_reason': None - }] - })}\n\n" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + yield f"data: {json.dumps({ + 'id': request_id, + 'object': 'chat.completion.chunk', + 'created': created, + 'model': config.MODEL_NAME, + 'choices': [{ + 'index': 0, + 'delta': {'content': content}, + 'finish_reason': None + }] + })}\n\n" - # Send final chunk + # Final chunk yield f"data: {json.dumps({ 'id': request_id, 'object': 'chat.completion.chunk', @@ -615,13 +684,12 @@ async def stream_chat_completion( yield "data: [DONE]\n\n" else: - # Mock streaming response for testing + # Mock streaming response mock_response = f"I understand you're asking about: {user_message}\n\n" - if download_info.get("downloaded"): - mock_response += f"I have downloaded and analyzed {download_info.get('url')}.\n" - mock_response += f"Processed {download_info.get('pages')} pages into {download_info.get('chunks')} chunks.\n\n" + if tool_results_text: + mock_response += f"I gathered data from {len(tool_results)} tools.\n\n" if context: - mock_response += f"Based on my knowledge base, here's what I found:\n\n{context[:1000]}...\n\n" + mock_response += f"Knowledge base context:\n{context[:1000]}...\n\n" mock_response += "\n\n[Demo mode - configure OPENROUTER_API_KEY for full LLM responses]" for char in mock_response: @@ -661,55 +729,27 @@ def build_enhanced_messages( context: str, sources: list[str], download_info: dict = None, - tool_results: list[dict] = None, + tool_results_text: str = "", ) -> list[ChatMessage]: - """Build enhanced messages with RAG context.""" + """Build enhanced messages with RAG context and tool results in system prompt.""" enhanced = [] - # Build tool descriptions for context - tool_descriptions = _build_tool_descriptions() - - # Add system message with RAG context and tool instructions - system_content = f"""You are a helpful AI assistant with access to real-time data through various tools. + system_content = "You are a helpful AI assistant with access to real-time data.\n" -## AVAILABLE TOOLS -{tool_descriptions} - -## HOW TO USE TOOLS -When you need to use one or more tools, output a SINGLE JSON block containing ALL tool calls as an array. -You MUST bundle every tool call into one response - do NOT respond with just one tool at a time. - -Output EXACTLY this format (nothing else before or after): -```json -{{"tool_calls": [ - {{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}, - {{"name": "another_tool", "arguments": {{"arg2": "value2"}}}} -]}} -``` - -## IMPORTANT RULES -1. ALWAYS use tools to get CURRENT data - do NOT say you cannot access real-time data -2. When asked about stocks, crypto, weather, or news, you MUST use the appropriate tool(s) -3. Bundle ALL needed tool calls into a single `tool_calls` array - include every tool you need in one response -4. After receiving tool results, provide a helpful, natural-language response based on the data -5. Be concise and factual - report exact data from tools -""" + if tool_results_text: + system_content += f"\n## REAL-TIME DATA (from tools)\n{tool_results_text}\n" if download_info and download_info.get("downloaded"): - system_content += f"\n\n--- Website Access ---\n" + system_content += f"\n--- Website Access ---\n" system_content += f"Downloaded website: {download_info.get('url')}\n" system_content += f"Pages: {download_info.get('pages')}, Chunks: {download_info.get('chunks')}\n" if context: - system_content += f"\n\n--- Relevant Context from Knowledge Base ---\n{context}\n" + system_content += f"\n## Relevant Context from Knowledge Base\n{context}\n" if sources: - system_content += f"\n--- Sources ---\n" + "\n".join(f"- {s}" for s in sources[:10]) + system_content += f"\nSources:\n" + "\n".join(f"- {s}" for s in sources[:10]) - # Add previous tool results if any - if tool_results: - system_content += "\n\n--- PREVIOUS TOOL RESULTS ---\n" - for tr in tool_results: - system_content += f"\nTool: {tr['name']}\nResult: {json.dumps(tr['result'], indent=2)}\n" + system_content += "\n\n## INSTRUCTIONS\nUse the data above to answer the user's question. Be concise and factual." enhanced.append(ChatMessage(role="system", content=system_content)) @@ -721,283 +761,38 @@ Output EXACTLY this format (nothing else before or after): return enhanced -def _build_tool_descriptions() -> str: - """Build a concise description of all available tools for the system prompt.""" - if not state.tool_manager: - return "No tools available." - - descriptions = [] - for name, schema in state.tool_manager._schemas.items(): - func = schema.get("function", {}) - desc = func.get("description", "")[:100] # Truncate long descriptions - params = func.get("parameters", {}).get("properties", {}) - required = func.get("parameters", {}).get("required", []) - - # Build param list - param_strs = [] - for pname, pinfo in params.items(): - ptype = pinfo.get("type", "any") - preq = " (required)" if pname in required else "" - param_strs.append(f"{pname}: {ptype}{preq}") - - params_str = ", ".join(param_strs) if param_strs else "none" - descriptions.append(f"- {name}({params_str}): {desc}") - - return "\n".join(descriptions) - - -def _parse_tool_calls(content: str) -> list[dict]: - """Parse tool calls from LLM response content. - - Expects the LLM to output a JSON block like: - {"tool_calls": [{"name": "tool_name", "arguments": {...}}, ...]} - - Returns a list of tool call dicts, each with 'name' and 'arguments' keys. - """ - tool_calls = [] - - def _extract_json_object(text: str, start_key: str) -> Optional[dict]: - """Extract a JSON object containing start_key using brace counting.""" - idx = text.find(start_key) - if idx == -1: - return None - # Walk backwards to find the opening { - depth = 0 - obj_start = -1 - for i in range(idx, -1, -1): - if text[i] == '}': - depth += 1 - elif text[i] == '{': - if depth == 0: - obj_start = i - break - depth -= 1 - if obj_start == -1: - return None - # Walk forwards to find the matching closing } - depth = 0 - obj_end = -1 - for i in range(obj_start, len(text)): - if text[i] == '{': - depth += 1 - elif text[i] == '}': - depth -= 1 - if depth == 0: - obj_end = i + 1 - break - if obj_end == -1: - return None - try: - return json.loads(text[obj_start:obj_end]) - except json.JSONDecodeError: - return None - - # --- Pattern 1: {"tool_calls": [...]} in a code fence block --- - fence_matches = re.findall(r'```\w*\s*(.*?)\s*```', content, re.DOTALL) - for block_text in fence_matches: - obj = _extract_json_object(block_text, '"tool_calls"') - if obj and "tool_calls" in obj and isinstance(obj["tool_calls"], list): - for tc in obj["tool_calls"]: - if isinstance(tc, dict) and "name" in tc: - tool_calls.append(tc) - if tool_calls: - log.info(f"Parsed tool_calls from code fence: {len(tool_calls)} calls") - return tool_calls - - # --- Pattern 2: {"tool_calls": [...]} bare JSON (outside code fences) --- - stripped = re.sub(r'```\w*\s*.*?\s*```', '', content, flags=re.DOTALL) - obj = _extract_json_object(stripped, '"tool_calls"') - if obj and "tool_calls" in obj and isinstance(obj["tool_calls"], list): - for tc in obj["tool_calls"]: - if isinstance(tc, dict) and "name" in tc: - tool_calls.append(tc) - if tool_calls: - log.info(f"Parsed tool_calls from bare JSON: {len(tool_calls)} calls") - return tool_calls - - # --- Pattern 3 (legacy fallback): {"tool_call": {...}} single tool --- - for block_text in fence_matches: - obj = _extract_json_object(block_text, '"tool_call"') - if obj and "tool_call" in obj and isinstance(obj["tool_call"], dict) and "name" in obj["tool_call"]: - tool_calls.append(obj["tool_call"]) - if not tool_calls: - obj = _extract_json_object(stripped, '"tool_call"') - if obj and "tool_call" in obj and isinstance(obj["tool_call"], dict) and "name" in obj["tool_call"]: - tool_calls.append(obj["tool_call"]) - if tool_calls: - log.info(f"Parsed tool_call (legacy format): {len(tool_calls)} calls") - return tool_calls - - # --- Pattern 4 (desperate fallback): try to find any JSON with tool names --- - # Look for patterns like {"name": "some_tool_name", "arguments": {...}} - # This catches LLMs that output the array directly without the wrapper - known_tools = set(state.tool_manager._tools.keys()) if state.tool_manager else set() - if known_tools: - # Try extracting the entire content as JSON - try: - maybe_json = json.loads(content.strip()) - candidates = [] - if isinstance(maybe_json, list): - candidates = maybe_json - elif isinstance(maybe_json, dict) and "tool_calls" in maybe_json: - candidates = maybe_json["tool_calls"] - elif isinstance(maybe_json, dict) and "name" in maybe_json: - candidates = [maybe_json] - for tc in candidates: - if isinstance(tc, dict) and "name" in tc and tc["name"] in known_tools: - tool_calls.append(tc) - except (json.JSONDecodeError, TypeError): - pass - if tool_calls: - log.info(f"Parsed tool_calls (desperate fallback): {len(tool_calls)} calls") - return tool_calls - - return tool_calls - - -async def generate_response( +async def call_llm( messages: list[ChatMessage], temperature: float = 0.7, max_tokens: int = 4096, ) -> str: - """Generate response using upstream LLM via OpenRouter. - - Uses content-based tool calling: the LLM outputs a single JSON block with - all tool calls bundled as a `tool_calls` array. This works around model - limitations on the number of native tool calls per response. - """ + """Single LLM call. No tool logic, just prompt in → response out.""" if not state.llm_client: - # Mock response for testing user_msg = "" for msg in reversed(messages): if msg.role == "user" and msg.content: user_msg = msg.content break - return f"Demo mode response. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality." + return f"Demo mode. Your question: {user_msg[:100]}... Configure OPENROUTER_API_KEY for full functionality." try: - # Convert messages to dict format - messages_dict = [] - for m in messages: - if m.content: - messages_dict.append({"role": m.role, "content": m.content}) - - # Tool calling loop (content-based approach — no `tools` param to API) - max_iterations = config.MAX_TOOL_ITERATIONS - iteration = 0 - - while iteration < max_iterations: - iteration += 1 - log.info(f"LLM call iteration {iteration}") - - # Call LLM WITHOUT tools parameter — tool instructions are in the system prompt - response = await state.llm_client.chat.completions.create( - model=config.UPSTREAM_MODEL, - messages=messages_dict, - temperature=temperature, - max_tokens=max_tokens, - ) - - if not response.choices: - log.warning("No choices in response") - return "I apologize, but I couldn't generate a response." - - content = response.choices[0].message.content or "" - log.info(f"LLM response: content_len={len(content)}") - # Log first 500 chars of response for debugging - log.debug(f"LLM response content preview: {content[:500]}") - if content: - log.info(f"LLM response (first 300 chars): {content[:300]!r}") - - # --- Parse tool calls from content --- - tool_calls = _parse_tool_calls(content) - - if tool_calls: - log.info(f"Parsed {len(tool_calls)} tool calls: {[tc.get('name') for tc in tool_calls]}") - - # Execute ALL tools concurrently - if state.tool_manager: - import asyncio as _asyncio - - async def _run_tool(tc): - name = tc.get("name") - args = tc.get("arguments", {}) - if not isinstance(args, dict): - try: - args = json.loads(args) - except (json.JSONDecodeError, TypeError): - args = {} - result = await _asyncio.to_thread( - state.tool_manager.execute_tool, name, args - ) - return name, result - - results = await _asyncio.gather(*[_run_tool(tc) for tc in tool_calls]) - - # Build a single consolidated results block - results_text = "" - for name, result in results: - log.info(f"Tool {name} result: success={result.get('success', False)}") - results_text += f"\n### Tool: {name}\n{json.dumps(result, indent=2)}\n" - - # Append assistant's tool call message to conversation - messages_dict.append({"role": "assistant", "content": content}) - - # Feed ALL results back in one user message - messages_dict.append({ - "role": "user", - "content": ( - f"--- ALL TOOL RESULTS ---\n" - f"Executed {len(tool_calls)} tool(s). Results:\n{results_text}\n" - f"---\n\n" - f"Now provide a helpful response to the original question using ALL the data above." - ), - }) - - continue - else: - log.warning("Tool call detected but tool_manager is None") - - # --- No tool calls — return the final response --- - cleaned_content = _clean_tool_syntax(content) - - # Safety net: if cleaning stripped everything, return original content - # (better to show raw JSON than the useless "I apologize" message) - if not cleaned_content.strip() and content.strip(): - log.warning(f"_clean_tool_syntax stripped all content! Original had {len(content)} chars. Returning original.") - cleaned_content = content.strip() - - log.info(f"Returning final response (len={len(cleaned_content)})") - return cleaned_content or "I apologize, but I couldn't generate a response." - - # Max iterations reached - log.warning(f"Max iterations ({max_iterations}) reached") - return "I reached the maximum number of tool call rounds. Please try a more specific question." - + messages_dict = [{"role": m.role, "content": m.content} for m in messages if m.content] + response = await state.llm_client.chat.completions.create( + model=config.UPSTREAM_MODEL, + messages=messages_dict, + temperature=temperature, + max_tokens=max_tokens, + ) + if not response.choices: + log.warning("No choices in LLM response") + return "I apologize, but I couldn't generate a response." + content = response.choices[0].message.content or "" + return content or "I apologize, but I couldn't generate a response." except Exception as e: - log.error(f"OpenRouter LLM call failed: {e}") - import traceback - log.error(traceback.format_exc()) + log.error(f"LLM call failed: {e}") return f"I encountered an error: {str(e)}" -def _clean_tool_syntax(content: str) -> str: - """Remove tool call JSON blocks from response text. - - Strips code-fence-wrapped blocks containing "tool_calls" or "tool_call". - Does NOT strip bare JSON to avoid removing valid content. - """ - def remove_code_block(m): - block = m.group(0) - inner = m.group(1) - if '"tool_calls"' in inner or '"tool_call"' in inner: - return '' - return block - - cleaned = re.sub(r'```\w*\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL) - return cleaned.strip() - - # ============================================================================= # Document Management Endpoints # =============================================================================