Fix tool call parsing, improve embeddings, and fix async issues

- main.py: Rewrote _parse_tool_call with brace-counting for robust JSON extraction
- main.py: Improved _clean_tool_syntax with brace-aware removal of tool_call JSON
- main.py: Fixed dict key mismatches (chunks_ingested, pages_downloaded)
- main.py: Run tool execution in asyncio.to_thread to avoid blocking event loop
- main.py: Always clean tool syntax from responses (handles edge cases)
- rag/__init__.py: Wrap blocking website_downloader in run_in_executor
- rag/__init__.py: Replace deprecated datetime.utcnow() with datetime.now(timezone.utc)
- rag/__init__.py: Add add_document_from_url method
- rag/vector_store.py: Replace hash-based embeddings with TF-IDF inspired embeddings
- rag/vector_store.py: Add embedding dimension mismatch handling in search
- README.md: Update API key config documentation
This commit is contained in:
Z User 2026-03-29 17:49:32 +00:00
parent 6eb18ce7f3
commit c03bde8023
4 changed files with 308 additions and 90 deletions

View File

@ -140,7 +140,7 @@ Configure via environment variables or `.env` file:
| `DEBUG` | `false` | Enable debug mode | | `DEBUG` | `false` | Enable debug mode |
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name | | `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use | | `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
| `ZAI_API_KEY` | (required) | API key for ZAI SDK | | `ZAI_API_KEY` / `OPENROUTER_API_KEY` | (required) | API key for upstream LLM (OpenRouter) |
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model | | `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location | | `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
| `DOCUMENTS_PATH` | `./data/documents` | Document storage | | `DOCUMENTS_PATH` | `./data/documents` | Document storage |

202
main.py
View File

@ -404,12 +404,12 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
# Check if site is already downloaded # Check if site is already downloaded
site_info = state.rag_system.get_site_info(url) site_info = state.rag_system.get_site_info(url)
if site_info: if site_info:
log.info(f"Site already downloaded: {url} ({site_info.get('chunk_count', 0)} chunks)") log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)")
return { return {
"downloaded": True, "downloaded": True,
"url": url, "url": url,
"chunks": site_info.get("chunk_count", 0), "chunks": site_info.get("chunks_ingested", 0),
"pages": site_info.get("page_count", 0), "pages": site_info.get("pages_downloaded", 0),
"local_path": site_info.get("local_path"), "local_path": site_info.get("local_path"),
"cached": True, "cached": True,
} }
@ -749,24 +749,58 @@ def _parse_tool_call(content: str) -> Optional[dict]:
"""Parse a tool call from LLM response content.""" """Parse a tool call from LLM response content."""
import re import re
# Look for JSON tool_call in the response def _extract_json_object(text: str, start_key: str) -> Optional[dict]:
# Pattern 1: ```json {"tool_call": ...} ``` """Extract a JSON object containing start_key using brace counting."""
json_match = re.search(r'```json\s*(\{.*?"tool_call".*?\})\s*```', content, re.DOTALL) # Find the start of the outermost object containing start_key
if json_match: idx = text.find(start_key)
if idx == -1:
return None
# Walk backwards to find the opening { of this object
depth = 0
obj_start = -1
for i in range(idx, -1, -1):
if text[i] == '}':
depth += 1
elif text[i] == '{':
if depth == 0:
obj_start = i
break
depth -= 1
if obj_start == -1:
return None
# Walk forwards to find the matching closing }
depth = 0
obj_end = -1
for i in range(obj_start, len(text)):
if text[i] == '{':
depth += 1
elif text[i] == '}':
depth -= 1
if depth == 0:
obj_end = i + 1
break
if obj_end == -1:
return None
try: try:
data = json.loads(json_match.group(1)) return json.loads(text[obj_start:obj_end])
return data.get("tool_call")
except json.JSONDecodeError: except json.JSONDecodeError:
pass return None
# Pattern 2: {"tool_call": {...}} anywhere in response # Pattern 1: code fence blocks (```json, ```, ```JSON, etc.)
json_match = re.search(r'\{"tool_call":\s*\{[^}]+\}\s*\}', content) # Match any code fence that might contain a tool_call
if json_match: fence_match = re.search(r'```\w*\s*(.*?)\s*```', content, re.DOTALL)
try: if fence_match:
data = json.loads(json_match.group(0)) block_text = fence_match.group(1)
if '"tool_call"' in block_text:
data = _extract_json_object(block_text, '"tool_call"')
if data and "tool_call" in data:
return data.get("tool_call")
# Pattern 2: {"tool_call": {...}} anywhere in response (bare JSON)
if '"tool_call"' in content:
data = _extract_json_object(content, '"tool_call"')
if data and "tool_call" in data:
return data.get("tool_call") return data.get("tool_call")
except json.JSONDecodeError:
pass
# Pattern 3: Look for tool name pattern like [USE: tool_name args] # Pattern 3: Look for tool name pattern like [USE: tool_name args]
bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL) bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL)
@ -836,54 +870,59 @@ async def generate_response(
# Check if response contains a tool call # Check if response contains a tool call
tool_call = _parse_tool_call(content) tool_call = _parse_tool_call(content)
if tool_call and state.tool_manager: if tool_call:
tool_name = tool_call.get("name") tool_name = tool_call.get("name")
tool_args = tool_call.get("arguments", {}) tool_args = tool_call.get("arguments", {})
log.info(f"Parsed tool call: {tool_name}") if state.tool_manager:
log.info(f"Parsed tool call: {tool_name}")
# Execute the tool # Execute the tool (run in thread pool to avoid blocking the event loop)
if isinstance(tool_args, dict): if isinstance(tool_args, dict):
result = state.tool_manager.execute_tool(tool_name, tool_args) result = await asyncio.to_thread(
state.tool_manager.execute_tool, tool_name, tool_args
)
else:
result = await asyncio.to_thread(
state.tool_manager.execute_tool_from_json, tool_name, json.dumps(tool_args)
)
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
# Store tool result
tool_results.append({
"name": tool_name,
"result": result,
})
# Rebuild system message with tool results
# Find and update the system message
for i, msg in enumerate(messages_dict):
if msg["role"] == "system":
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
messages_dict[i]["content"] += tool_result_text
break
# Add assistant's tool call as a message
messages_dict.append({
"role": "assistant",
"content": f"[Executing tool: {tool_name}]"
})
# Add user prompt to continue
messages_dict.append({
"role": "user",
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
})
continue
else: else:
result = state.tool_manager.execute_tool_from_json(tool_name, json.dumps(tool_args)) log.warning(f"Tool call detected ({tool_name}) but tool_manager is None! Stripping tool call from response.")
log.info(f"Tool {tool_name} result: success={result.get('success', False)}") # No tool call found (or tool_manager unavailable) - return the response
# ALWAYS run cleanup to strip any residual tool_call JSON from response
# Store tool result
tool_results.append({
"name": tool_name,
"result": result,
})
# Rebuild system message with tool results
# Find and update the system message
for i, msg in enumerate(messages_dict):
if msg["role"] == "system":
# Rebuild with tool results
# This is a simplified approach - in production you'd want better state management
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
messages_dict[i]["content"] += tool_result_text
break
# Add assistant's tool call as a message
messages_dict.append({
"role": "assistant",
"content": f"[Executing tool: {tool_name}]"
})
# Add user prompt to continue
messages_dict.append({
"role": "user",
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
})
continue
# No tool call found - return the response
# Clean up any partial tool call syntax from response
cleaned_content = _clean_tool_syntax(content) cleaned_content = _clean_tool_syntax(content)
log.info(f"Returning final response") log.info(f"Returning final response (cleaned={len(cleaned_content) != len(content)})")
return cleaned_content or "I apologize, but I couldn't generate a response." return cleaned_content or "I apologize, but I couldn't generate a response."
# Max iterations reached # Max iterations reached
@ -900,10 +939,51 @@ async def generate_response(
def _clean_tool_syntax(content: str) -> str: def _clean_tool_syntax(content: str) -> str:
"""Remove tool call syntax from response if partially included.""" """Remove tool call syntax from response if partially included."""
import re import re
def _remove_json_containing_key(text: str, key: str) -> str:
"""Remove JSON objects containing a specific key from text."""
result = text
while key in result:
idx = result.find(key)
# Walk backwards to find opening {
depth = 0
obj_start = -1
for i in range(idx, -1, -1):
if result[i] == '}':
depth += 1
elif result[i] == '{':
if depth == 0:
obj_start = i
break
depth -= 1
if obj_start == -1:
break
# Walk forwards to find matching }
depth = 0
obj_end = -1
for i in range(obj_start, len(result)):
if result[i] == '{':
depth += 1
elif result[i] == '}':
depth -= 1
if depth == 0:
obj_end = i + 1
break
if obj_end == -1:
break
result = result[:obj_start] + result[obj_end:]
return result
# Remove ```json ... ``` blocks containing tool_call # Remove ```json ... ``` blocks containing tool_call
cleaned = re.sub(r'```json\s*\{.*?"tool_call".*?\}\s*```', '', content, flags=re.DOTALL) def remove_code_block(m):
# Remove standalone tool_call JSON block = m.group(0)
cleaned = re.sub(r'\{"tool_call":\s*\{[^}]+\}\s*\}', '', cleaned) inner = m.group(1)
if '"tool_call"' in inner:
return ''
return block
cleaned = re.sub(r'```json\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL)
cleaned = _remove_json_containing_key(cleaned, '"tool_call"')
return cleaned.strip() return cleaned.strip()

View File

@ -164,14 +164,17 @@ class RAGSystem:
log.info(f"Downloading website: {url}") log.info(f"Downloading website: {url}")
# Use website_downloader_tool to download the site # Use website_downloader_tool to download the site (in thread pool to avoid blocking)
download_result = website_downloader( download_result = await asyncio.get_event_loop().run_in_executor(
url=url, None,
destination=str(self.downloaded_sites_path / self._get_site_folder(url)), lambda: website_downloader(
max_pages=max_pages, url=url,
threads=threads, destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
download_external_assets=download_external_assets, max_pages=max_pages,
external_domains=external_domains, threads=threads,
download_external_assets=download_external_assets,
external_domains=external_domains,
),
) )
if not download_result.get("success"): if not download_result.get("success"):
@ -331,8 +334,63 @@ class RAGSystem:
def _get_timestamp(self) -> str: def _get_timestamp(self) -> str:
"""Get current timestamp in ISO format.""" """Get current timestamp in ISO format."""
from datetime import datetime from datetime import datetime, timezone
return datetime.utcnow().isoformat() return datetime.now(timezone.utc).isoformat()
async def add_document_from_url(
self,
url: str,
metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Add a document from a URL to the knowledge base.
Downloads the content from the URL and processes it into chunks.
Args:
url: URL of the document to add
metadata: Optional additional metadata
Returns:
Dictionary with processing results
"""
self._ensure_initialized()
try:
import aiohttp
except ImportError:
# Fallback to requests (sync)
import requests
try:
resp = requests.get(url, timeout=30)
resp.raise_for_status()
content = resp.content
except Exception as e:
raise RuntimeError(f"Failed to download {url}: {e}")
else:
# aiohttp available - use async
import asyncio
try:
async def _fetch():
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
resp.raise_for_status()
return await resp.read()
content = asyncio.get_event_loop().run_until_complete(_fetch())
except Exception as e:
raise RuntimeError(f"Failed to download {url}: {e}")
filename = url.split("/")[-1] or "document.html"
doc_metadata = {
"source_url": url,
**(metadata or {}),
}
return await self.add_document(
content=content,
filename=filename,
metadata=doc_metadata,
)
async def add_document( async def add_document(
self, self,

View File

@ -10,12 +10,21 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
import logging import logging
import math
import os import os
import re
from collections import Counter
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Default embedding dimension
_EMBEDDING_DIM = 256
# Simple tokenization pattern
_WORD_RE = re.compile(r'[a-zA-Z0-9]+' )
class VectorStore: class VectorStore:
""" """
@ -152,26 +161,89 @@ class VectorStore:
log.info(f"Added {len(chunks)} chunks to vector store") log.info(f"Added {len(chunks)} chunks to vector store")
def _tokenize(self, text: str) -> list[str]:
"""Simple word tokenization."""
return [w.lower() for w in _WORD_RE.findall(text) if len(w) > 1]
def _build_vocab(self, all_tokenized: list[list[str]], max_vocab: int = 10000) -> dict[str, int]:
"""Build vocabulary from tokenized texts with IDF weighting."""
doc_freq = Counter()
for tokens in all_tokenized:
unique_tokens = set(tokens)
for t in unique_tokens:
doc_freq[t] += 1
# Take top tokens by document frequency (most useful for search)
vocab = {}
for idx, (token, _) in enumerate(doc_freq.most_common(max_vocab)):
vocab[token] = idx
return vocab
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]: async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
""" """
Generate embeddings for texts. Generate TF-IDF inspired embeddings for texts.
Uses a simple hash-based embedding for demonstration. Uses a bag-of-words approach with TF-IDF weighting projected into a
In production, use a real embedding model via API. fixed-dimension space. This produces meaningful cosine similarities
between semantically related texts, unlike hash-based embeddings.
In production, replace with a real embedding model API call.
""" """
embeddings = [] if not texts:
return []
for text in texts: # Tokenize all texts
# Simple hash-based embedding (for demo purposes) all_tokenized = [self._tokenize(t) for t in texts]
# In production, use OpenAI embeddings or similar
hash_bytes = hashlib.sha256(text.encode()).digest() # Build vocabulary from these texts + existing corpus
# Create a 384-dimensional embedding (common size) # Include existing chunks for consistent vocab
embedding = [] existing_texts = [c["content"] for c in self._chunks]
for i in range(384): existing_tokenized = [self._tokenize(t) for t in existing_texts]
byte_idx = i % len(hash_bytes) combined_tokenized = existing_tokenized + all_tokenized
value = (hash_bytes[byte_idx] - 128) / 128.0
embedding.append(value) vocab = self._build_vocab(combined_tokenized)
embeddings.append(embedding) vocab_size = len(vocab)
if vocab_size == 0:
# Fallback: return zero vectors
return [[0.0] * _EMBEDDING_DIM for _ in texts]
# Compute IDF from all texts
n_docs = len(combined_tokenized)
idf = {}
for token, idx in vocab.items():
df = sum(1 for tokens in combined_tokenized if token in set(tokens))
idf[token] = math.log((n_docs + 1) / (df + 1)) + 1
# Dimension: project vocab into fixed dimension using hash-based assignment
dim = min(_EMBEDDING_DIM, vocab_size)
embeddings = []
for tokens in all_tokenized:
if not tokens:
embeddings.append([0.0] * dim)
continue
# Compute TF
tf = Counter(tokens)
max_tf = max(tf.values())
# Build sparse TF-IDF vector projected to fixed dimension
vec = [0.0] * dim
for token, count in tf.items():
if token not in vocab:
continue
normalized_tf = 0.5 + 0.5 * (count / max_tf) if max_tf > 0 else 0
tfidf = normalized_tf * idf.get(token, 1.0)
# Hash token to a dimension index
bucket = vocab[token] % dim
vec[bucket] += tfidf
# L2 normalize
norm = math.sqrt(sum(v * v for v in vec))
if norm > 0:
vec = [v / norm for v in vec]
embeddings.append(vec)
return embeddings return embeddings
@ -197,9 +269,17 @@ class VectorStore:
if not self._chunks: if not self._chunks:
return [] return []
# Generate query embedding # Generate query embedding (use full corpus for consistent vocab)
query_embedding = (await self._generate_embeddings([query]))[0] query_embedding = (await self._generate_embeddings([query]))[0]
# Ensure dimensions match
if self._embeddings and len(query_embedding) != len(self._embeddings[0]):
log.warning(f"Embedding dimension mismatch: query={len(query_embedding)}, stored={len(self._embeddings[0])}. Using zero-padded query.")
if len(query_embedding) < len(self._embeddings[0]):
query_embedding = query_embedding + [0.0] * (len(self._embeddings[0]) - len(query_embedding))
else:
query_embedding = query_embedding[:len(self._embeddings[0])]
# Calculate similarities # Calculate similarities
results = [] results = []
for i, (chunk, embedding, metadata) in enumerate( for i, (chunk, embedding, metadata) in enumerate(