Fix tool call parsing, improve embeddings, and fix async issues
- main.py: Rewrote _parse_tool_call with brace-counting for robust JSON extraction - main.py: Improved _clean_tool_syntax with brace-aware removal of tool_call JSON - main.py: Fixed dict key mismatches (chunks_ingested, pages_downloaded) - main.py: Run tool execution in asyncio.to_thread to avoid blocking event loop - main.py: Always clean tool syntax from responses (handles edge cases) - rag/__init__.py: Wrap blocking website_downloader in run_in_executor - rag/__init__.py: Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - rag/__init__.py: Add add_document_from_url method - rag/vector_store.py: Replace hash-based embeddings with TF-IDF inspired embeddings - rag/vector_store.py: Add embedding dimension mismatch handling in search - README.md: Update API key config documentation
This commit is contained in:
parent
6eb18ce7f3
commit
c03bde8023
@ -140,7 +140,7 @@ Configure via environment variables or `.env` file:
|
||||
| `DEBUG` | `false` | Enable debug mode |
|
||||
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
|
||||
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
|
||||
| `ZAI_API_KEY` | (required) | API key for ZAI SDK |
|
||||
| `ZAI_API_KEY` / `OPENROUTER_API_KEY` | (required) | API key for upstream LLM (OpenRouter) |
|
||||
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
|
||||
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
|
||||
| `DOCUMENTS_PATH` | `./data/documents` | Document storage |
|
||||
|
||||
206
main.py
206
main.py
@ -404,12 +404,12 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
|
||||
# Check if site is already downloaded
|
||||
site_info = state.rag_system.get_site_info(url)
|
||||
if site_info:
|
||||
log.info(f"Site already downloaded: {url} ({site_info.get('chunk_count', 0)} chunks)")
|
||||
log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)")
|
||||
return {
|
||||
"downloaded": True,
|
||||
"url": url,
|
||||
"chunks": site_info.get("chunk_count", 0),
|
||||
"pages": site_info.get("page_count", 0),
|
||||
"chunks": site_info.get("chunks_ingested", 0),
|
||||
"pages": site_info.get("pages_downloaded", 0),
|
||||
"local_path": site_info.get("local_path"),
|
||||
"cached": True,
|
||||
}
|
||||
@ -749,24 +749,58 @@ def _parse_tool_call(content: str) -> Optional[dict]:
|
||||
"""Parse a tool call from LLM response content."""
|
||||
import re
|
||||
|
||||
# Look for JSON tool_call in the response
|
||||
# Pattern 1: ```json {"tool_call": ...} ```
|
||||
json_match = re.search(r'```json\s*(\{.*?"tool_call".*?\})\s*```', content, re.DOTALL)
|
||||
if json_match:
|
||||
def _extract_json_object(text: str, start_key: str) -> Optional[dict]:
|
||||
"""Extract a JSON object containing start_key using brace counting."""
|
||||
# Find the start of the outermost object containing start_key
|
||||
idx = text.find(start_key)
|
||||
if idx == -1:
|
||||
return None
|
||||
# Walk backwards to find the opening { of this object
|
||||
depth = 0
|
||||
obj_start = -1
|
||||
for i in range(idx, -1, -1):
|
||||
if text[i] == '}':
|
||||
depth += 1
|
||||
elif text[i] == '{':
|
||||
if depth == 0:
|
||||
obj_start = i
|
||||
break
|
||||
depth -= 1
|
||||
if obj_start == -1:
|
||||
return None
|
||||
# Walk forwards to find the matching closing }
|
||||
depth = 0
|
||||
obj_end = -1
|
||||
for i in range(obj_start, len(text)):
|
||||
if text[i] == '{':
|
||||
depth += 1
|
||||
elif text[i] == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
obj_end = i + 1
|
||||
break
|
||||
if obj_end == -1:
|
||||
return None
|
||||
try:
|
||||
data = json.loads(json_match.group(1))
|
||||
return data.get("tool_call")
|
||||
return json.loads(text[obj_start:obj_end])
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Pattern 2: {"tool_call": {...}} anywhere in response
|
||||
json_match = re.search(r'\{"tool_call":\s*\{[^}]+\}\s*\}', content)
|
||||
if json_match:
|
||||
try:
|
||||
data = json.loads(json_match.group(0))
|
||||
return None
|
||||
|
||||
# Pattern 1: code fence blocks (```json, ```, ```JSON, etc.)
|
||||
# Match any code fence that might contain a tool_call
|
||||
fence_match = re.search(r'```\w*\s*(.*?)\s*```', content, re.DOTALL)
|
||||
if fence_match:
|
||||
block_text = fence_match.group(1)
|
||||
if '"tool_call"' in block_text:
|
||||
data = _extract_json_object(block_text, '"tool_call"')
|
||||
if data and "tool_call" in data:
|
||||
return data.get("tool_call")
|
||||
|
||||
# Pattern 2: {"tool_call": {...}} anywhere in response (bare JSON)
|
||||
if '"tool_call"' in content:
|
||||
data = _extract_json_object(content, '"tool_call"')
|
||||
if data and "tool_call" in data:
|
||||
return data.get("tool_call")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Pattern 3: Look for tool name pattern like [USE: tool_name args]
|
||||
bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL)
|
||||
@ -836,54 +870,59 @@ async def generate_response(
|
||||
# Check if response contains a tool call
|
||||
tool_call = _parse_tool_call(content)
|
||||
|
||||
if tool_call and state.tool_manager:
|
||||
if tool_call:
|
||||
tool_name = tool_call.get("name")
|
||||
tool_args = tool_call.get("arguments", {})
|
||||
|
||||
log.info(f"Parsed tool call: {tool_name}")
|
||||
|
||||
# Execute the tool
|
||||
if isinstance(tool_args, dict):
|
||||
result = state.tool_manager.execute_tool(tool_name, tool_args)
|
||||
if state.tool_manager:
|
||||
log.info(f"Parsed tool call: {tool_name}")
|
||||
|
||||
# Execute the tool (run in thread pool to avoid blocking the event loop)
|
||||
if isinstance(tool_args, dict):
|
||||
result = await asyncio.to_thread(
|
||||
state.tool_manager.execute_tool, tool_name, tool_args
|
||||
)
|
||||
else:
|
||||
result = await asyncio.to_thread(
|
||||
state.tool_manager.execute_tool_from_json, tool_name, json.dumps(tool_args)
|
||||
)
|
||||
|
||||
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
|
||||
|
||||
# Store tool result
|
||||
tool_results.append({
|
||||
"name": tool_name,
|
||||
"result": result,
|
||||
})
|
||||
|
||||
# Rebuild system message with tool results
|
||||
# Find and update the system message
|
||||
for i, msg in enumerate(messages_dict):
|
||||
if msg["role"] == "system":
|
||||
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
|
||||
messages_dict[i]["content"] += tool_result_text
|
||||
break
|
||||
|
||||
# Add assistant's tool call as a message
|
||||
messages_dict.append({
|
||||
"role": "assistant",
|
||||
"content": f"[Executing tool: {tool_name}]"
|
||||
})
|
||||
|
||||
# Add user prompt to continue
|
||||
messages_dict.append({
|
||||
"role": "user",
|
||||
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
|
||||
})
|
||||
|
||||
continue
|
||||
else:
|
||||
result = state.tool_manager.execute_tool_from_json(tool_name, json.dumps(tool_args))
|
||||
|
||||
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
|
||||
|
||||
# Store tool result
|
||||
tool_results.append({
|
||||
"name": tool_name,
|
||||
"result": result,
|
||||
})
|
||||
|
||||
# Rebuild system message with tool results
|
||||
# Find and update the system message
|
||||
for i, msg in enumerate(messages_dict):
|
||||
if msg["role"] == "system":
|
||||
# Rebuild with tool results
|
||||
# This is a simplified approach - in production you'd want better state management
|
||||
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
|
||||
messages_dict[i]["content"] += tool_result_text
|
||||
break
|
||||
|
||||
# Add assistant's tool call as a message
|
||||
messages_dict.append({
|
||||
"role": "assistant",
|
||||
"content": f"[Executing tool: {tool_name}]"
|
||||
})
|
||||
|
||||
# Add user prompt to continue
|
||||
messages_dict.append({
|
||||
"role": "user",
|
||||
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
|
||||
})
|
||||
|
||||
continue
|
||||
log.warning(f"Tool call detected ({tool_name}) but tool_manager is None! Stripping tool call from response.")
|
||||
|
||||
# No tool call found - return the response
|
||||
# Clean up any partial tool call syntax from response
|
||||
# No tool call found (or tool_manager unavailable) - return the response
|
||||
# ALWAYS run cleanup to strip any residual tool_call JSON from response
|
||||
cleaned_content = _clean_tool_syntax(content)
|
||||
log.info(f"Returning final response")
|
||||
log.info(f"Returning final response (cleaned={len(cleaned_content) != len(content)})")
|
||||
return cleaned_content or "I apologize, but I couldn't generate a response."
|
||||
|
||||
# Max iterations reached
|
||||
@ -900,10 +939,51 @@ async def generate_response(
|
||||
def _clean_tool_syntax(content: str) -> str:
|
||||
"""Remove tool call syntax from response if partially included."""
|
||||
import re
|
||||
|
||||
def _remove_json_containing_key(text: str, key: str) -> str:
|
||||
"""Remove JSON objects containing a specific key from text."""
|
||||
result = text
|
||||
while key in result:
|
||||
idx = result.find(key)
|
||||
# Walk backwards to find opening {
|
||||
depth = 0
|
||||
obj_start = -1
|
||||
for i in range(idx, -1, -1):
|
||||
if result[i] == '}':
|
||||
depth += 1
|
||||
elif result[i] == '{':
|
||||
if depth == 0:
|
||||
obj_start = i
|
||||
break
|
||||
depth -= 1
|
||||
if obj_start == -1:
|
||||
break
|
||||
# Walk forwards to find matching }
|
||||
depth = 0
|
||||
obj_end = -1
|
||||
for i in range(obj_start, len(result)):
|
||||
if result[i] == '{':
|
||||
depth += 1
|
||||
elif result[i] == '}':
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
obj_end = i + 1
|
||||
break
|
||||
if obj_end == -1:
|
||||
break
|
||||
result = result[:obj_start] + result[obj_end:]
|
||||
return result
|
||||
|
||||
# Remove ```json ... ``` blocks containing tool_call
|
||||
cleaned = re.sub(r'```json\s*\{.*?"tool_call".*?\}\s*```', '', content, flags=re.DOTALL)
|
||||
# Remove standalone tool_call JSON
|
||||
cleaned = re.sub(r'\{"tool_call":\s*\{[^}]+\}\s*\}', '', cleaned)
|
||||
def remove_code_block(m):
|
||||
block = m.group(0)
|
||||
inner = m.group(1)
|
||||
if '"tool_call"' in inner:
|
||||
return ''
|
||||
return block
|
||||
|
||||
cleaned = re.sub(r'```json\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL)
|
||||
cleaned = _remove_json_containing_key(cleaned, '"tool_call"')
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
|
||||
@ -164,14 +164,17 @@ class RAGSystem:
|
||||
|
||||
log.info(f"Downloading website: {url}")
|
||||
|
||||
# Use website_downloader_tool to download the site
|
||||
download_result = website_downloader(
|
||||
url=url,
|
||||
destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
|
||||
max_pages=max_pages,
|
||||
threads=threads,
|
||||
download_external_assets=download_external_assets,
|
||||
external_domains=external_domains,
|
||||
# Use website_downloader_tool to download the site (in thread pool to avoid blocking)
|
||||
download_result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: website_downloader(
|
||||
url=url,
|
||||
destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
|
||||
max_pages=max_pages,
|
||||
threads=threads,
|
||||
download_external_assets=download_external_assets,
|
||||
external_domains=external_domains,
|
||||
),
|
||||
)
|
||||
|
||||
if not download_result.get("success"):
|
||||
@ -331,8 +334,63 @@ class RAGSystem:
|
||||
|
||||
def _get_timestamp(self) -> str:
|
||||
"""Get current timestamp in ISO format."""
|
||||
from datetime import datetime
|
||||
return datetime.utcnow().isoformat()
|
||||
from datetime import datetime, timezone
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
async def add_document_from_url(
|
||||
self,
|
||||
url: str,
|
||||
metadata: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Add a document from a URL to the knowledge base.
|
||||
|
||||
Downloads the content from the URL and processes it into chunks.
|
||||
|
||||
Args:
|
||||
url: URL of the document to add
|
||||
metadata: Optional additional metadata
|
||||
|
||||
Returns:
|
||||
Dictionary with processing results
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
# Fallback to requests (sync)
|
||||
import requests
|
||||
try:
|
||||
resp = requests.get(url, timeout=30)
|
||||
resp.raise_for_status()
|
||||
content = resp.content
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to download {url}: {e}")
|
||||
else:
|
||||
# aiohttp available - use async
|
||||
import asyncio
|
||||
try:
|
||||
async def _fetch():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
resp.raise_for_status()
|
||||
return await resp.read()
|
||||
content = asyncio.get_event_loop().run_until_complete(_fetch())
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to download {url}: {e}")
|
||||
|
||||
filename = url.split("/")[-1] or "document.html"
|
||||
doc_metadata = {
|
||||
"source_url": url,
|
||||
**(metadata or {}),
|
||||
}
|
||||
|
||||
return await self.add_document(
|
||||
content=content,
|
||||
filename=filename,
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
|
||||
async def add_document(
|
||||
self,
|
||||
|
||||
@ -10,12 +10,21 @@ from __future__ import annotations
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Default embedding dimension
|
||||
_EMBEDDING_DIM = 256
|
||||
|
||||
# Simple tokenization pattern
|
||||
_WORD_RE = re.compile(r'[a-zA-Z0-9]+' )
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""
|
||||
@ -152,26 +161,89 @@ class VectorStore:
|
||||
|
||||
log.info(f"Added {len(chunks)} chunks to vector store")
|
||||
|
||||
def _tokenize(self, text: str) -> list[str]:
|
||||
"""Simple word tokenization."""
|
||||
return [w.lower() for w in _WORD_RE.findall(text) if len(w) > 1]
|
||||
|
||||
def _build_vocab(self, all_tokenized: list[list[str]], max_vocab: int = 10000) -> dict[str, int]:
|
||||
"""Build vocabulary from tokenized texts with IDF weighting."""
|
||||
doc_freq = Counter()
|
||||
for tokens in all_tokenized:
|
||||
unique_tokens = set(tokens)
|
||||
for t in unique_tokens:
|
||||
doc_freq[t] += 1
|
||||
# Take top tokens by document frequency (most useful for search)
|
||||
vocab = {}
|
||||
for idx, (token, _) in enumerate(doc_freq.most_common(max_vocab)):
|
||||
vocab[token] = idx
|
||||
return vocab
|
||||
|
||||
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
Generate embeddings for texts.
|
||||
Generate TF-IDF inspired embeddings for texts.
|
||||
|
||||
Uses a simple hash-based embedding for demonstration.
|
||||
In production, use a real embedding model via API.
|
||||
Uses a bag-of-words approach with TF-IDF weighting projected into a
|
||||
fixed-dimension space. This produces meaningful cosine similarities
|
||||
between semantically related texts, unlike hash-based embeddings.
|
||||
|
||||
In production, replace with a real embedding model API call.
|
||||
"""
|
||||
embeddings = []
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
for text in texts:
|
||||
# Simple hash-based embedding (for demo purposes)
|
||||
# In production, use OpenAI embeddings or similar
|
||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
||||
# Create a 384-dimensional embedding (common size)
|
||||
embedding = []
|
||||
for i in range(384):
|
||||
byte_idx = i % len(hash_bytes)
|
||||
value = (hash_bytes[byte_idx] - 128) / 128.0
|
||||
embedding.append(value)
|
||||
embeddings.append(embedding)
|
||||
# Tokenize all texts
|
||||
all_tokenized = [self._tokenize(t) for t in texts]
|
||||
|
||||
# Build vocabulary from these texts + existing corpus
|
||||
# Include existing chunks for consistent vocab
|
||||
existing_texts = [c["content"] for c in self._chunks]
|
||||
existing_tokenized = [self._tokenize(t) for t in existing_texts]
|
||||
combined_tokenized = existing_tokenized + all_tokenized
|
||||
|
||||
vocab = self._build_vocab(combined_tokenized)
|
||||
vocab_size = len(vocab)
|
||||
|
||||
if vocab_size == 0:
|
||||
# Fallback: return zero vectors
|
||||
return [[0.0] * _EMBEDDING_DIM for _ in texts]
|
||||
|
||||
# Compute IDF from all texts
|
||||
n_docs = len(combined_tokenized)
|
||||
idf = {}
|
||||
for token, idx in vocab.items():
|
||||
df = sum(1 for tokens in combined_tokenized if token in set(tokens))
|
||||
idf[token] = math.log((n_docs + 1) / (df + 1)) + 1
|
||||
|
||||
# Dimension: project vocab into fixed dimension using hash-based assignment
|
||||
dim = min(_EMBEDDING_DIM, vocab_size)
|
||||
|
||||
embeddings = []
|
||||
for tokens in all_tokenized:
|
||||
if not tokens:
|
||||
embeddings.append([0.0] * dim)
|
||||
continue
|
||||
|
||||
# Compute TF
|
||||
tf = Counter(tokens)
|
||||
max_tf = max(tf.values())
|
||||
|
||||
# Build sparse TF-IDF vector projected to fixed dimension
|
||||
vec = [0.0] * dim
|
||||
for token, count in tf.items():
|
||||
if token not in vocab:
|
||||
continue
|
||||
normalized_tf = 0.5 + 0.5 * (count / max_tf) if max_tf > 0 else 0
|
||||
tfidf = normalized_tf * idf.get(token, 1.0)
|
||||
# Hash token to a dimension index
|
||||
bucket = vocab[token] % dim
|
||||
vec[bucket] += tfidf
|
||||
|
||||
# L2 normalize
|
||||
norm = math.sqrt(sum(v * v for v in vec))
|
||||
if norm > 0:
|
||||
vec = [v / norm for v in vec]
|
||||
|
||||
embeddings.append(vec)
|
||||
|
||||
return embeddings
|
||||
|
||||
@ -197,9 +269,17 @@ class VectorStore:
|
||||
if not self._chunks:
|
||||
return []
|
||||
|
||||
# Generate query embedding
|
||||
# Generate query embedding (use full corpus for consistent vocab)
|
||||
query_embedding = (await self._generate_embeddings([query]))[0]
|
||||
|
||||
# Ensure dimensions match
|
||||
if self._embeddings and len(query_embedding) != len(self._embeddings[0]):
|
||||
log.warning(f"Embedding dimension mismatch: query={len(query_embedding)}, stored={len(self._embeddings[0])}. Using zero-padded query.")
|
||||
if len(query_embedding) < len(self._embeddings[0]):
|
||||
query_embedding = query_embedding + [0.0] * (len(self._embeddings[0]) - len(query_embedding))
|
||||
else:
|
||||
query_embedding = query_embedding[:len(self._embeddings[0])]
|
||||
|
||||
# Calculate similarities
|
||||
results = []
|
||||
for i, (chunk, embedding, metadata) in enumerate(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user