Fix tool call parsing, improve embeddings, and fix async issues

- main.py: Rewrote _parse_tool_call with brace-counting for robust JSON extraction
- main.py: Improved _clean_tool_syntax with brace-aware removal of tool_call JSON
- main.py: Fixed dict key mismatches (chunks_ingested, pages_downloaded)
- main.py: Run tool execution in asyncio.to_thread to avoid blocking event loop
- main.py: Always clean tool syntax from responses (handles edge cases)
- rag/__init__.py: Wrap blocking website_downloader in run_in_executor
- rag/__init__.py: Replace deprecated datetime.utcnow() with datetime.now(timezone.utc)
- rag/__init__.py: Add add_document_from_url method
- rag/vector_store.py: Replace hash-based embeddings with TF-IDF inspired embeddings
- rag/vector_store.py: Add embedding dimension mismatch handling in search
- README.md: Update API key config documentation
This commit is contained in:
Z User 2026-03-29 17:49:32 +00:00
parent 6eb18ce7f3
commit c03bde8023
4 changed files with 308 additions and 90 deletions

View File

@ -140,7 +140,7 @@ Configure via environment variables or `.env` file:
| `DEBUG` | `false` | Enable debug mode |
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
| `ZAI_API_KEY` | (required) | API key for ZAI SDK |
| `ZAI_API_KEY` / `OPENROUTER_API_KEY` | (required) | API key for upstream LLM (OpenRouter) |
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
| `DOCUMENTS_PATH` | `./data/documents` | Document storage |

202
main.py
View File

@ -404,12 +404,12 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
# Check if site is already downloaded
site_info = state.rag_system.get_site_info(url)
if site_info:
log.info(f"Site already downloaded: {url} ({site_info.get('chunk_count', 0)} chunks)")
log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)")
return {
"downloaded": True,
"url": url,
"chunks": site_info.get("chunk_count", 0),
"pages": site_info.get("page_count", 0),
"chunks": site_info.get("chunks_ingested", 0),
"pages": site_info.get("pages_downloaded", 0),
"local_path": site_info.get("local_path"),
"cached": True,
}
@ -749,24 +749,58 @@ def _parse_tool_call(content: str) -> Optional[dict]:
"""Parse a tool call from LLM response content."""
import re
# Look for JSON tool_call in the response
# Pattern 1: ```json {"tool_call": ...} ```
json_match = re.search(r'```json\s*(\{.*?"tool_call".*?\})\s*```', content, re.DOTALL)
if json_match:
def _extract_json_object(text: str, start_key: str) -> Optional[dict]:
"""Extract a JSON object containing start_key using brace counting."""
# Find the start of the outermost object containing start_key
idx = text.find(start_key)
if idx == -1:
return None
# Walk backwards to find the opening { of this object
depth = 0
obj_start = -1
for i in range(idx, -1, -1):
if text[i] == '}':
depth += 1
elif text[i] == '{':
if depth == 0:
obj_start = i
break
depth -= 1
if obj_start == -1:
return None
# Walk forwards to find the matching closing }
depth = 0
obj_end = -1
for i in range(obj_start, len(text)):
if text[i] == '{':
depth += 1
elif text[i] == '}':
depth -= 1
if depth == 0:
obj_end = i + 1
break
if obj_end == -1:
return None
try:
data = json.loads(json_match.group(1))
return data.get("tool_call")
return json.loads(text[obj_start:obj_end])
except json.JSONDecodeError:
pass
return None
# Pattern 2: {"tool_call": {...}} anywhere in response
json_match = re.search(r'\{"tool_call":\s*\{[^}]+\}\s*\}', content)
if json_match:
try:
data = json.loads(json_match.group(0))
# Pattern 1: code fence blocks (```json, ```, ```JSON, etc.)
# Match any code fence that might contain a tool_call
fence_match = re.search(r'```\w*\s*(.*?)\s*```', content, re.DOTALL)
if fence_match:
block_text = fence_match.group(1)
if '"tool_call"' in block_text:
data = _extract_json_object(block_text, '"tool_call"')
if data and "tool_call" in data:
return data.get("tool_call")
# Pattern 2: {"tool_call": {...}} anywhere in response (bare JSON)
if '"tool_call"' in content:
data = _extract_json_object(content, '"tool_call"')
if data and "tool_call" in data:
return data.get("tool_call")
except json.JSONDecodeError:
pass
# Pattern 3: Look for tool name pattern like [USE: tool_name args]
bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL)
@ -836,54 +870,59 @@ async def generate_response(
# Check if response contains a tool call
tool_call = _parse_tool_call(content)
if tool_call and state.tool_manager:
if tool_call:
tool_name = tool_call.get("name")
tool_args = tool_call.get("arguments", {})
log.info(f"Parsed tool call: {tool_name}")
if state.tool_manager:
log.info(f"Parsed tool call: {tool_name}")
# Execute the tool
if isinstance(tool_args, dict):
result = state.tool_manager.execute_tool(tool_name, tool_args)
# Execute the tool (run in thread pool to avoid blocking the event loop)
if isinstance(tool_args, dict):
result = await asyncio.to_thread(
state.tool_manager.execute_tool, tool_name, tool_args
)
else:
result = await asyncio.to_thread(
state.tool_manager.execute_tool_from_json, tool_name, json.dumps(tool_args)
)
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
# Store tool result
tool_results.append({
"name": tool_name,
"result": result,
})
# Rebuild system message with tool results
# Find and update the system message
for i, msg in enumerate(messages_dict):
if msg["role"] == "system":
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
messages_dict[i]["content"] += tool_result_text
break
# Add assistant's tool call as a message
messages_dict.append({
"role": "assistant",
"content": f"[Executing tool: {tool_name}]"
})
# Add user prompt to continue
messages_dict.append({
"role": "user",
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
})
continue
else:
result = state.tool_manager.execute_tool_from_json(tool_name, json.dumps(tool_args))
log.warning(f"Tool call detected ({tool_name}) but tool_manager is None! Stripping tool call from response.")
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
# Store tool result
tool_results.append({
"name": tool_name,
"result": result,
})
# Rebuild system message with tool results
# Find and update the system message
for i, msg in enumerate(messages_dict):
if msg["role"] == "system":
# Rebuild with tool results
# This is a simplified approach - in production you'd want better state management
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
messages_dict[i]["content"] += tool_result_text
break
# Add assistant's tool call as a message
messages_dict.append({
"role": "assistant",
"content": f"[Executing tool: {tool_name}]"
})
# Add user prompt to continue
messages_dict.append({
"role": "user",
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
})
continue
# No tool call found - return the response
# Clean up any partial tool call syntax from response
# No tool call found (or tool_manager unavailable) - return the response
# ALWAYS run cleanup to strip any residual tool_call JSON from response
cleaned_content = _clean_tool_syntax(content)
log.info(f"Returning final response")
log.info(f"Returning final response (cleaned={len(cleaned_content) != len(content)})")
return cleaned_content or "I apologize, but I couldn't generate a response."
# Max iterations reached
@ -900,10 +939,51 @@ async def generate_response(
def _clean_tool_syntax(content: str) -> str:
"""Remove tool call syntax from response if partially included."""
import re
def _remove_json_containing_key(text: str, key: str) -> str:
"""Remove JSON objects containing a specific key from text."""
result = text
while key in result:
idx = result.find(key)
# Walk backwards to find opening {
depth = 0
obj_start = -1
for i in range(idx, -1, -1):
if result[i] == '}':
depth += 1
elif result[i] == '{':
if depth == 0:
obj_start = i
break
depth -= 1
if obj_start == -1:
break
# Walk forwards to find matching }
depth = 0
obj_end = -1
for i in range(obj_start, len(result)):
if result[i] == '{':
depth += 1
elif result[i] == '}':
depth -= 1
if depth == 0:
obj_end = i + 1
break
if obj_end == -1:
break
result = result[:obj_start] + result[obj_end:]
return result
# Remove ```json ... ``` blocks containing tool_call
cleaned = re.sub(r'```json\s*\{.*?"tool_call".*?\}\s*```', '', content, flags=re.DOTALL)
# Remove standalone tool_call JSON
cleaned = re.sub(r'\{"tool_call":\s*\{[^}]+\}\s*\}', '', cleaned)
def remove_code_block(m):
block = m.group(0)
inner = m.group(1)
if '"tool_call"' in inner:
return ''
return block
cleaned = re.sub(r'```json\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL)
cleaned = _remove_json_containing_key(cleaned, '"tool_call"')
return cleaned.strip()

View File

@ -164,14 +164,17 @@ class RAGSystem:
log.info(f"Downloading website: {url}")
# Use website_downloader_tool to download the site
download_result = website_downloader(
url=url,
destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
max_pages=max_pages,
threads=threads,
download_external_assets=download_external_assets,
external_domains=external_domains,
# Use website_downloader_tool to download the site (in thread pool to avoid blocking)
download_result = await asyncio.get_event_loop().run_in_executor(
None,
lambda: website_downloader(
url=url,
destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
max_pages=max_pages,
threads=threads,
download_external_assets=download_external_assets,
external_domains=external_domains,
),
)
if not download_result.get("success"):
@ -331,8 +334,63 @@ class RAGSystem:
def _get_timestamp(self) -> str:
"""Get current timestamp in ISO format."""
from datetime import datetime
return datetime.utcnow().isoformat()
from datetime import datetime, timezone
return datetime.now(timezone.utc).isoformat()
async def add_document_from_url(
    self,
    url: str,
    metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """
    Add a document from a URL to the knowledge base.

    Downloads the content from the URL and hands the raw bytes to
    ``add_document`` for processing.

    Args:
        url: URL of the document to add
        metadata: Optional additional metadata; merged over ``source_url``

    Returns:
        Dictionary with processing results (whatever ``add_document`` returns)

    Raises:
        RuntimeError: If the download fails for any reason.
        ImportError: If neither ``aiohttp`` nor ``requests`` is installed.
    """
    import asyncio
    from urllib.parse import urlparse

    self._ensure_initialized()

    try:
        import aiohttp
    except ImportError:
        aiohttp = None

    if aiohttp is not None:
        # This coroutine already runs inside the event loop, so the request
        # must be awaited directly; run_until_complete() here would raise
        # "This event loop is already running".
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
                    resp.raise_for_status()
                    content = await resp.read()
        except Exception as e:
            raise RuntimeError(f"Failed to download {url}: {e}") from e
    else:
        # Fallback to requests (sync), pushed to the default executor so the
        # blocking HTTP call does not stall the event loop.
        import requests

        def _fetch_sync() -> bytes:
            resp = requests.get(url, timeout=30)
            resp.raise_for_status()
            return resp.content

        try:
            content = await asyncio.get_running_loop().run_in_executor(None, _fetch_sync)
        except Exception as e:
            raise RuntimeError(f"Failed to download {url}: {e}") from e

    # Derive the filename from the URL path only, so query strings and bare
    # hostnames ("https://example.com") do not end up as the file name.
    filename = urlparse(url).path.rsplit("/", 1)[-1] or "document.html"

    doc_metadata = {
        "source_url": url,
        **(metadata or {}),
    }

    return await self.add_document(
        content=content,
        filename=filename,
        metadata=doc_metadata,
    )
async def add_document(
self,

View File

@ -10,12 +10,21 @@ from __future__ import annotations
import hashlib
import json
import logging
import math
import os
import re
from collections import Counter
from pathlib import Path
from typing import Any, Optional
log = logging.getLogger(__name__)
# Default embedding dimension: length of the zero-vector fallback and the
# upper bound on the vector length produced by _generate_embeddings.
_EMBEDDING_DIM = 256
# Simple tokenization pattern: runs of ASCII letters and digits, used by
# VectorStore._tokenize to split text into words.
_WORD_RE = re.compile(r'[a-zA-Z0-9]+' )
class VectorStore:
"""
@ -152,26 +161,89 @@ class VectorStore:
log.info(f"Added {len(chunks)} chunks to vector store")
def _tokenize(self, text: str) -> list[str]:
    """Split *text* into lowercase alphanumeric words longer than one character."""
    tokens = []
    for word in _WORD_RE.findall(text):
        # Single-character matches are discarded.
        if len(word) > 1:
            tokens.append(word.lower())
    return tokens
def _build_vocab(self, all_tokenized: list[list[str]], max_vocab: int = 10000) -> dict[str, int]:
"""Build vocabulary from tokenized texts with IDF weighting."""
doc_freq = Counter()
for tokens in all_tokenized:
unique_tokens = set(tokens)
for t in unique_tokens:
doc_freq[t] += 1
# Take top tokens by document frequency (most useful for search)
vocab = {}
for idx, (token, _) in enumerate(doc_freq.most_common(max_vocab)):
vocab[token] = idx
return vocab
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
    """
    Generate TF-IDF inspired embeddings for texts.

    Each text becomes a bag-of-words TF-IDF vector that is folded into a
    fixed ``_EMBEDDING_DIM``-dimensional space by hashing every token to a
    bucket, then L2-normalized. The bucket of a token depends only on the
    token itself, so vectors produced by different calls (stored chunks vs.
    a later query) share the same coordinate system. IDF weights are
    computed over the stored chunks plus *texts*.

    In production, replace with a real embedding model API call.

    Args:
        texts: Texts to embed.

    Returns:
        One vector of length ``_EMBEDDING_DIM`` per input text, in order.
        Texts with no usable tokens map to the zero vector.
    """
    if not texts:
        return []

    # Tokenize the new texts and the existing corpus; IDF is computed over both.
    new_docs = [self._tokenize(t) for t in texts]
    corpus = [self._tokenize(c["content"]) for c in self._chunks] + new_docs

    # The vocabulary only acts as a filter (top tokens by document frequency).
    vocab = self._build_vocab(corpus)
    if not vocab:
        # Fallback: return zero vectors
        return [[0.0] * _EMBEDDING_DIM for _ in texts]

    # Document frequency in a single pass instead of rescanning the corpus
    # once per vocabulary entry.
    doc_freq = Counter()
    for tokens in corpus:
        doc_freq.update(set(tokens))
    n_docs = len(corpus)
    idf = {
        token: math.log((n_docs + 1) / (doc_freq[token] + 1)) + 1
        for token in vocab
    }

    bucket_cache: dict[str, int] = {}

    def _bucket(token: str) -> int:
        """Map a token to a stable dimension index via SHA-256."""
        slot = bucket_cache.get(token)
        if slot is None:
            digest = hashlib.sha256(token.encode("utf-8")).digest()
            slot = int.from_bytes(digest[:8], "big") % _EMBEDDING_DIM
            bucket_cache[token] = slot
        return slot

    embeddings = []
    for tokens in new_docs:
        vec = [0.0] * _EMBEDDING_DIM
        if tokens:
            tf = Counter(tokens)
            max_tf = max(tf.values())
            for token, count in tf.items():
                if token not in vocab:
                    continue
                # Augmented term frequency keeps long texts from dominating.
                normalized_tf = 0.5 + 0.5 * (count / max_tf)
                vec[_bucket(token)] += normalized_tf * idf[token]

            # L2 normalize
            norm = math.sqrt(sum(v * v for v in vec))
            if norm > 0:
                vec = [v / norm for v in vec]
        embeddings.append(vec)

    return embeddings
@ -197,9 +269,17 @@ class VectorStore:
if not self._chunks:
return []
# Generate query embedding
# Generate query embedding (use full corpus for consistent vocab)
query_embedding = (await self._generate_embeddings([query]))[0]
# Ensure dimensions match
if self._embeddings and len(query_embedding) != len(self._embeddings[0]):
log.warning(f"Embedding dimension mismatch: query={len(query_embedding)}, stored={len(self._embeddings[0])}. Using zero-padded query.")
if len(query_embedding) < len(self._embeddings[0]):
query_embedding = query_embedding + [0.0] * (len(self._embeddings[0]) - len(query_embedding))
else:
query_embedding = query_embedding[:len(self._embeddings[0])]
# Calculate similarities
results = []
for i, (chunk, embedding, metadata) in enumerate(