From c03bde8023e71cd031cb4bc2ed5811f209339108 Mon Sep 17 00:00:00 2001 From: Z User Date: Sun, 29 Mar 2026 17:49:32 +0000 Subject: [PATCH] Fix tool call parsing, improve embeddings, and fix async issues - main.py: Rewrote _parse_tool_call with brace-counting for robust JSON extraction - main.py: Improved _clean_tool_syntax with brace-aware removal of tool_call JSON - main.py: Fixed dict key mismatches (chunks_ingested, pages_downloaded) - main.py: Run tool execution in asyncio.to_thread to avoid blocking event loop - main.py: Always clean tool syntax from responses (handles edge cases) - rag/__init__.py: Wrap blocking website_downloader in run_in_executor - rag/__init__.py: Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - rag/__init__.py: Add add_document_from_url method - rag/vector_store.py: Replace hash-based embeddings with TF-IDF inspired embeddings - rag/vector_store.py: Add embedding dimension mismatch handling in search - README.md: Update API key config documentation --- README.md | 2 +- main.py | 206 ++++++++++++++++++++++++++++++-------------- rag/__init__.py | 78 ++++++++++++++--- rag/vector_store.py | 112 ++++++++++++++++++++---- 4 files changed, 308 insertions(+), 90 deletions(-) diff --git a/README.md b/README.md index c97f03b..ab53f85 100755 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ Configure via environment variables or `.env` file: | `DEBUG` | `false` | Enable debug mode | | `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name | | `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use | -| `ZAI_API_KEY` | (required) | API key for ZAI SDK | +| `ZAI_API_KEY` / `OPENROUTER_API_KEY` | (required) | API key for upstream LLM (OpenRouter) | | `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model | | `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location | | `DOCUMENTS_PATH` | `./data/documents` | Document storage | diff --git a/main.py b/main.py index 7729ba8..249e3c3 100755 --- a/main.py +++ b/main.py @@ -404,12 +404,12 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]: # Check if site is already downloaded site_info = state.rag_system.get_site_info(url) if site_info: - log.info(f"Site already downloaded: {url} ({site_info.get('chunk_count', 0)} chunks)") + log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)") return { "downloaded": True, "url": url, - "chunks": site_info.get("chunk_count", 0), - "pages": site_info.get("page_count", 0), + "chunks": site_info.get("chunks_ingested", 0), + "pages": site_info.get("pages_downloaded", 0), "local_path": site_info.get("local_path"), "cached": True, } @@ -749,24 +749,58 @@ def _parse_tool_call(content: str) -> Optional[dict]: """Parse a tool call from LLM response content.""" import re - # Look for JSON tool_call in the response - # Pattern 1: ```json {"tool_call": ...} ``` - json_match = re.search(r'```json\s*(\{.*?"tool_call".*?\})\s*```', content, re.DOTALL) - if json_match: + def _extract_json_object(text: str, start_key: str) -> Optional[dict]: + """Extract a JSON object containing start_key using brace counting.""" + # Find the start of the outermost object containing start_key + idx = text.find(start_key) + if idx == -1: + return None + # Walk backwards to find the opening { of this object + depth = 0 + obj_start = -1 + for i in range(idx, -1, -1): + if text[i] == '}': + depth += 1 + elif text[i] == '{': + if depth == 0: + obj_start = i + break + depth -= 1 + if obj_start == -1: + return None + # Walk forwards to find the matching closing } + depth = 0 + obj_end = -1 + for i in range(obj_start, len(text)): + if text[i] == '{': + depth += 1 + elif text[i] == '}': + depth -= 1 + if depth == 0: + obj_end = i + 1 + break + if obj_end == -1: + return None try: - data = json.loads(json_match.group(1)) - return data.get("tool_call") + return json.loads(text[obj_start:obj_end]) except json.JSONDecodeError: - pass - - # Pattern 2: {"tool_call": {...}} anywhere in response - json_match = re.search(r'\{"tool_call":\s*\{[^}]+\}\s*\}', content) - if json_match: - try: - data = json.loads(json_match.group(0)) + return None + + # Pattern 1: code fence blocks (```json, ```, ```JSON, etc.) + # Match any code fence that might contain a tool_call + fence_match = re.search(r'```\w*\s*(.*?)\s*```', content, re.DOTALL) + if fence_match: + block_text = fence_match.group(1) + if '"tool_call"' in block_text: + data = _extract_json_object(block_text, '"tool_call"') + if data and "tool_call" in data: + return data.get("tool_call") + + # Pattern 2: {"tool_call": {...}} anywhere in response (bare JSON) + if '"tool_call"' in content: + data = _extract_json_object(content, '"tool_call"') + if data and "tool_call" in data: return data.get("tool_call") - except json.JSONDecodeError: - pass # Pattern 3: Look for tool name pattern like [USE: tool_name args] bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL) @@ -836,54 +870,59 @@ async def generate_response( # Check if response contains a tool call tool_call = _parse_tool_call(content) - if tool_call and state.tool_manager: + if tool_call: tool_name = tool_call.get("name") tool_args = tool_call.get("arguments", {}) - log.info(f"Parsed tool call: {tool_name}") - - # Execute the tool - if isinstance(tool_args, dict): - result = state.tool_manager.execute_tool(tool_name, tool_args) + if state.tool_manager: + log.info(f"Parsed tool call: {tool_name}") + + # Execute the tool (run in thread pool to avoid blocking the event loop) + if isinstance(tool_args, dict): + result = await asyncio.to_thread( + state.tool_manager.execute_tool, tool_name, tool_args + ) + else: + result = await asyncio.to_thread( + state.tool_manager.execute_tool_from_json, tool_name, json.dumps(tool_args) + ) + + log.info(f"Tool {tool_name} result: success={result.get('success', False)}") + + # Store tool result + tool_results.append({ + "name": tool_name, + "result": result, + }) + + # Rebuild system message with tool results + # Find and update the system message + for i, msg in enumerate(messages_dict): + if msg["role"] == "system": + tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data." + messages_dict[i]["content"] += tool_result_text + break + + # Add assistant's tool call as a message + messages_dict.append({ + "role": "assistant", + "content": f"[Executing tool: {tool_name}]" + }) + + # Add user prompt to continue + messages_dict.append({ + "role": "user", + "content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data." + }) + + continue else: - result = state.tool_manager.execute_tool_from_json(tool_name, json.dumps(tool_args)) - - log.info(f"Tool {tool_name} result: success={result.get('success', False)}") - - # Store tool result - tool_results.append({ - "name": tool_name, - "result": result, - }) - - # Rebuild system message with tool results - # Find and update the system message - for i, msg in enumerate(messages_dict): - if msg["role"] == "system": - # Rebuild with tool results - # This is a simplified approach - in production you'd want better state management - tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data." - messages_dict[i]["content"] += tool_result_text - break - - # Add assistant's tool call as a message - messages_dict.append({ - "role": "assistant", - "content": f"[Executing tool: {tool_name}]" - }) - - # Add user prompt to continue - messages_dict.append({ - "role": "user", - "content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data." - }) - - continue + log.warning(f"Tool call detected ({tool_name}) but tool_manager is None! Stripping tool call from response.") - # No tool call found - return the response - # Clean up any partial tool call syntax from response + # No tool call found (or tool_manager unavailable) - return the response + # ALWAYS run cleanup to strip any residual tool_call JSON from response cleaned_content = _clean_tool_syntax(content) - log.info(f"Returning final response") + log.info(f"Returning final response (cleaned={len(cleaned_content) != len(content)})") return cleaned_content or "I apologize, but I couldn't generate a response." # Max iterations reached @@ -900,10 +939,51 @@ async def generate_response( def _clean_tool_syntax(content: str) -> str: """Remove tool call syntax from response if partially included.""" import re + + def _remove_json_containing_key(text: str, key: str) -> str: + """Remove JSON objects containing a specific key from text.""" + result = text + while key in result: + idx = result.find(key) + # Walk backwards to find opening { + depth = 0 + obj_start = -1 + for i in range(idx, -1, -1): + if result[i] == '}': + depth += 1 + elif result[i] == '{': + if depth == 0: + obj_start = i + break + depth -= 1 + if obj_start == -1: + break + # Walk forwards to find matching } + depth = 0 + obj_end = -1 + for i in range(obj_start, len(result)): + if result[i] == '{': + depth += 1 + elif result[i] == '}': + depth -= 1 + if depth == 0: + obj_end = i + 1 + break + if obj_end == -1: + break + result = result[:obj_start] + result[obj_end:] + return result + # Remove ```json ... ``` blocks containing tool_call - cleaned = re.sub(r'```json\s*\{.*?"tool_call".*?\}\s*```', '', content, flags=re.DOTALL) - # Remove standalone tool_call JSON - cleaned = re.sub(r'\{"tool_call":\s*\{[^}]+\}\s*\}', '', cleaned) + def remove_code_block(m): + block = m.group(0) + inner = m.group(1) + if '"tool_call"' in inner: + return '' + return block + + cleaned = re.sub(r'```json\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL) + cleaned = _remove_json_containing_key(cleaned, '"tool_call"') return cleaned.strip() diff --git a/rag/__init__.py b/rag/__init__.py index acf1d70..3417ce8 100755 --- a/rag/__init__.py +++ b/rag/__init__.py @@ -164,14 +164,17 @@ class RAGSystem: log.info(f"Downloading website: {url}") - # Use website_downloader_tool to download the site - download_result = website_downloader( - url=url, - destination=str(self.downloaded_sites_path / self._get_site_folder(url)), - max_pages=max_pages, - threads=threads, - download_external_assets=download_external_assets, - external_domains=external_domains, + # Use website_downloader_tool to download the site (in thread pool to avoid blocking) + download_result = await asyncio.get_event_loop().run_in_executor( + None, + lambda: website_downloader( + url=url, + destination=str(self.downloaded_sites_path / self._get_site_folder(url)), + max_pages=max_pages, + threads=threads, + download_external_assets=download_external_assets, + external_domains=external_domains, + ), ) if not download_result.get("success"): @@ -331,8 +334,63 @@ class RAGSystem: def _get_timestamp(self) -> str: """Get current timestamp in ISO format.""" - from datetime import datetime - return datetime.utcnow().isoformat() + from datetime import datetime, timezone + return datetime.now(timezone.utc).isoformat() + + async def add_document_from_url( + self, + url: str, + metadata: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """ + Add a document from a URL to the knowledge base. + + Downloads the content from the URL and processes it into chunks. + + Args: + url: URL of the document to add + metadata: Optional additional metadata + + Returns: + Dictionary with processing results + """ + self._ensure_initialized() + + try: + import aiohttp + except ImportError: + # Fallback to requests (sync) + import requests + try: + resp = requests.get(url, timeout=30) + resp.raise_for_status() + content = resp.content + except Exception as e: + raise RuntimeError(f"Failed to download {url}: {e}") + else: + # aiohttp available - use async + import asyncio + try: + async def _fetch(): + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + return await resp.read() + content = asyncio.get_event_loop().run_until_complete(_fetch()) + except Exception as e: + raise RuntimeError(f"Failed to download {url}: {e}") + + filename = url.split("/")[-1] or "document.html" + doc_metadata = { + "source_url": url, + **(metadata or {}), + } + + return await self.add_document( + content=content, + filename=filename, + metadata=doc_metadata, + ) async def add_document( self, diff --git a/rag/vector_store.py b/rag/vector_store.py index 52eafa3..4b0af4c 100755 --- a/rag/vector_store.py +++ b/rag/vector_store.py @@ -10,12 +10,21 @@ from __future__ import annotations import hashlib import json import logging +import math import os +import re +from collections import Counter from pathlib import Path from typing import Any, Optional log = logging.getLogger(__name__) +# Default embedding dimension +_EMBEDDING_DIM = 256 + +# Simple tokenization pattern +_WORD_RE = re.compile(r'[a-zA-Z0-9]+' ) + class VectorStore: """ @@ -152,26 +161,89 @@ class VectorStore: log.info(f"Added {len(chunks)} chunks to vector store") + def _tokenize(self, text: str) -> list[str]: + """Simple word tokenization.""" + return [w.lower() for w in _WORD_RE.findall(text) if len(w) > 1] + + def _build_vocab(self, all_tokenized: list[list[str]], max_vocab: int = 10000) -> dict[str, int]: + """Build vocabulary from tokenized texts with IDF weighting.""" + doc_freq = Counter() + for tokens in all_tokenized: + unique_tokens = set(tokens) + for t in unique_tokens: + doc_freq[t] += 1 + # Take top tokens by document frequency (most useful for search) + vocab = {} + for idx, (token, _) in enumerate(doc_freq.most_common(max_vocab)): + vocab[token] = idx + return vocab + async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]: """ - Generate embeddings for texts. + Generate TF-IDF inspired embeddings for texts. - Uses a simple hash-based embedding for demonstration. - In production, use a real embedding model via API. + Uses a bag-of-words approach with TF-IDF weighting projected into a + fixed-dimension space. This produces meaningful cosine similarities + between semantically related texts, unlike hash-based embeddings. + + In production, replace with a real embedding model API call. """ - embeddings = [] + if not texts: + return [] - for text in texts: - # Simple hash-based embedding (for demo purposes) - # In production, use OpenAI embeddings or similar - hash_bytes = hashlib.sha256(text.encode()).digest() - # Create a 384-dimensional embedding (common size) - embedding = [] - for i in range(384): - byte_idx = i % len(hash_bytes) - value = (hash_bytes[byte_idx] - 128) / 128.0 - embedding.append(value) - embeddings.append(embedding) + # Tokenize all texts + all_tokenized = [self._tokenize(t) for t in texts] + + # Build vocabulary from these texts + existing corpus + # Include existing chunks for consistent vocab + existing_texts = [c["content"] for c in self._chunks] + existing_tokenized = [self._tokenize(t) for t in existing_texts] + combined_tokenized = existing_tokenized + all_tokenized + + vocab = self._build_vocab(combined_tokenized) + vocab_size = len(vocab) + + if vocab_size == 0: + # Fallback: return zero vectors + return [[0.0] * _EMBEDDING_DIM for _ in texts] + + # Compute IDF from all texts + n_docs = len(combined_tokenized) + idf = {} + for token, idx in vocab.items(): + df = sum(1 for tokens in combined_tokenized if token in set(tokens)) + idf[token] = math.log((n_docs + 1) / (df + 1)) + 1 + + # Dimension: project vocab into fixed dimension using hash-based assignment + dim = min(_EMBEDDING_DIM, vocab_size) + + embeddings = [] + for tokens in all_tokenized: + if not tokens: + embeddings.append([0.0] * dim) + continue + + # Compute TF + tf = Counter(tokens) + max_tf = max(tf.values()) + + # Build sparse TF-IDF vector projected to fixed dimension + vec = [0.0] * dim + for token, count in tf.items(): + if token not in vocab: + continue + normalized_tf = 0.5 + 0.5 * (count / max_tf) if max_tf > 0 else 0 + tfidf = normalized_tf * idf.get(token, 1.0) + # Hash token to a dimension index + bucket = vocab[token] % dim + vec[bucket] += tfidf + + # L2 normalize + norm = math.sqrt(sum(v * v for v in vec)) + if norm > 0: + vec = [v / norm for v in vec] + + embeddings.append(vec) return embeddings @@ -197,9 +269,17 @@ class VectorStore: if not self._chunks: return [] - # Generate query embedding + # Generate query embedding (use full corpus for consistent vocab) query_embedding = (await self._generate_embeddings([query]))[0] + # Ensure dimensions match + if self._embeddings and len(query_embedding) != len(self._embeddings[0]): + log.warning(f"Embedding dimension mismatch: query={len(query_embedding)}, stored={len(self._embeddings[0])}. Using zero-padded query.") + if len(query_embedding) < len(self._embeddings[0]): + query_embedding = query_embedding + [0.0] * (len(self._embeddings[0]) - len(query_embedding)) + else: + query_embedding = query_embedding[:len(self._embeddings[0])] + # Calculate similarities results = [] for i, (chunk, embedding, metadata) in enumerate(