Fix tool call parsing, improve embeddings, and fix async issues
- main.py: Rewrote _parse_tool_call with brace-counting for robust JSON extraction
- main.py: Improved _clean_tool_syntax with brace-aware removal of tool_call JSON
- main.py: Fixed dict key mismatches (chunks_ingested, pages_downloaded)
- main.py: Run tool execution in asyncio.to_thread to avoid blocking event loop
- main.py: Always clean tool syntax from responses (handles edge cases)
- rag/__init__.py: Wrap blocking website_downloader in run_in_executor
- rag/__init__.py: Replace deprecated datetime.utcnow() with datetime.now(timezone.utc)
- rag/__init__.py: Add add_document_from_url method
- rag/vector_store.py: Replace hash-based embeddings with TF-IDF inspired embeddings
- rag/vector_store.py: Add embedding dimension mismatch handling in search
- README.md: Update API key config documentation
This commit is contained in:
parent
6eb18ce7f3
commit
c03bde8023
@ -140,7 +140,7 @@ Configure via environment variables or `.env` file:
|
|||||||
| `DEBUG` | `false` | Enable debug mode |
|
| `DEBUG` | `false` | Enable debug mode |
|
||||||
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
|
| `MODEL_NAME` | `DocRAG-GLM-4.7` | Display model name |
|
||||||
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
|
| `UPSTREAM_MODEL` | `glm-4.7` | Upstream model to use |
|
||||||
| `ZAI_API_KEY` | (required) | API key for ZAI SDK |
|
| `ZAI_API_KEY` / `OPENROUTER_API_KEY` | (required) | API key for upstream LLM (OpenRouter) |
|
||||||
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
|
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model |
|
||||||
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
|
| `VECTOR_STORE_PATH` | `./data/vectors` | Vector store location |
|
||||||
| `DOCUMENTS_PATH` | `./data/documents` | Document storage |
|
| `DOCUMENTS_PATH` | `./data/documents` | Document storage |
|
||||||
|
|||||||
206
main.py
206
main.py
@ -404,12 +404,12 @@ async def download_website_if_needed(user_message: str) -> dict[str, Any]:
|
|||||||
# Check if site is already downloaded
|
# Check if site is already downloaded
|
||||||
site_info = state.rag_system.get_site_info(url)
|
site_info = state.rag_system.get_site_info(url)
|
||||||
if site_info:
|
if site_info:
|
||||||
log.info(f"Site already downloaded: {url} ({site_info.get('chunk_count', 0)} chunks)")
|
log.info(f"Site already downloaded: {url} ({site_info.get('chunks_ingested', 0)} chunks)")
|
||||||
return {
|
return {
|
||||||
"downloaded": True,
|
"downloaded": True,
|
||||||
"url": url,
|
"url": url,
|
||||||
"chunks": site_info.get("chunk_count", 0),
|
"chunks": site_info.get("chunks_ingested", 0),
|
||||||
"pages": site_info.get("page_count", 0),
|
"pages": site_info.get("pages_downloaded", 0),
|
||||||
"local_path": site_info.get("local_path"),
|
"local_path": site_info.get("local_path"),
|
||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
@ -749,24 +749,58 @@ def _parse_tool_call(content: str) -> Optional[dict]:
|
|||||||
"""Parse a tool call from LLM response content."""
|
"""Parse a tool call from LLM response content."""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# Look for JSON tool_call in the response
|
def _extract_json_object(text: str, start_key: str) -> Optional[dict]:
|
||||||
# Pattern 1: ```json {"tool_call": ...} ```
|
"""Extract a JSON object containing start_key using brace counting."""
|
||||||
json_match = re.search(r'```json\s*(\{.*?"tool_call".*?\})\s*```', content, re.DOTALL)
|
# Find the start of the outermost object containing start_key
|
||||||
if json_match:
|
idx = text.find(start_key)
|
||||||
|
if idx == -1:
|
||||||
|
return None
|
||||||
|
# Walk backwards to find the opening { of this object
|
||||||
|
depth = 0
|
||||||
|
obj_start = -1
|
||||||
|
for i in range(idx, -1, -1):
|
||||||
|
if text[i] == '}':
|
||||||
|
depth += 1
|
||||||
|
elif text[i] == '{':
|
||||||
|
if depth == 0:
|
||||||
|
obj_start = i
|
||||||
|
break
|
||||||
|
depth -= 1
|
||||||
|
if obj_start == -1:
|
||||||
|
return None
|
||||||
|
# Walk forwards to find the matching closing }
|
||||||
|
depth = 0
|
||||||
|
obj_end = -1
|
||||||
|
for i in range(obj_start, len(text)):
|
||||||
|
if text[i] == '{':
|
||||||
|
depth += 1
|
||||||
|
elif text[i] == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
obj_end = i + 1
|
||||||
|
break
|
||||||
|
if obj_end == -1:
|
||||||
|
return None
|
||||||
try:
|
try:
|
||||||
data = json.loads(json_match.group(1))
|
return json.loads(text[obj_start:obj_end])
|
||||||
return data.get("tool_call")
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
return None
|
||||||
|
|
||||||
# Pattern 2: {"tool_call": {...}} anywhere in response
|
# Pattern 1: code fence blocks (```json, ```, ```JSON, etc.)
|
||||||
json_match = re.search(r'\{"tool_call":\s*\{[^}]+\}\s*\}', content)
|
# Match any code fence that might contain a tool_call
|
||||||
if json_match:
|
fence_match = re.search(r'```\w*\s*(.*?)\s*```', content, re.DOTALL)
|
||||||
try:
|
if fence_match:
|
||||||
data = json.loads(json_match.group(0))
|
block_text = fence_match.group(1)
|
||||||
|
if '"tool_call"' in block_text:
|
||||||
|
data = _extract_json_object(block_text, '"tool_call"')
|
||||||
|
if data and "tool_call" in data:
|
||||||
|
return data.get("tool_call")
|
||||||
|
|
||||||
|
# Pattern 2: {"tool_call": {...}} anywhere in response (bare JSON)
|
||||||
|
if '"tool_call"' in content:
|
||||||
|
data = _extract_json_object(content, '"tool_call"')
|
||||||
|
if data and "tool_call" in data:
|
||||||
return data.get("tool_call")
|
return data.get("tool_call")
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Pattern 3: Look for tool name pattern like [USE: tool_name args]
|
# Pattern 3: Look for tool name pattern like [USE: tool_name args]
|
||||||
bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL)
|
bracket_match = re.search(r'\[USE:\s*(\w+)\s*(?:args:\s*(\{.*?\}))?\s*\]', content, re.DOTALL)
|
||||||
@ -836,54 +870,59 @@ async def generate_response(
|
|||||||
# Check if response contains a tool call
|
# Check if response contains a tool call
|
||||||
tool_call = _parse_tool_call(content)
|
tool_call = _parse_tool_call(content)
|
||||||
|
|
||||||
if tool_call and state.tool_manager:
|
if tool_call:
|
||||||
tool_name = tool_call.get("name")
|
tool_name = tool_call.get("name")
|
||||||
tool_args = tool_call.get("arguments", {})
|
tool_args = tool_call.get("arguments", {})
|
||||||
|
|
||||||
log.info(f"Parsed tool call: {tool_name}")
|
if state.tool_manager:
|
||||||
|
log.info(f"Parsed tool call: {tool_name}")
|
||||||
# Execute the tool
|
|
||||||
if isinstance(tool_args, dict):
|
# Execute the tool (run in thread pool to avoid blocking the event loop)
|
||||||
result = state.tool_manager.execute_tool(tool_name, tool_args)
|
if isinstance(tool_args, dict):
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
state.tool_manager.execute_tool, tool_name, tool_args
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
state.tool_manager.execute_tool_from_json, tool_name, json.dumps(tool_args)
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
|
||||||
|
|
||||||
|
# Store tool result
|
||||||
|
tool_results.append({
|
||||||
|
"name": tool_name,
|
||||||
|
"result": result,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Rebuild system message with tool results
|
||||||
|
# Find and update the system message
|
||||||
|
for i, msg in enumerate(messages_dict):
|
||||||
|
if msg["role"] == "system":
|
||||||
|
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
|
||||||
|
messages_dict[i]["content"] += tool_result_text
|
||||||
|
break
|
||||||
|
|
||||||
|
# Add assistant's tool call as a message
|
||||||
|
messages_dict.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": f"[Executing tool: {tool_name}]"
|
||||||
|
})
|
||||||
|
|
||||||
|
# Add user prompt to continue
|
||||||
|
messages_dict.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
|
||||||
|
})
|
||||||
|
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
result = state.tool_manager.execute_tool_from_json(tool_name, json.dumps(tool_args))
|
log.warning(f"Tool call detected ({tool_name}) but tool_manager is None! Stripping tool call from response.")
|
||||||
|
|
||||||
log.info(f"Tool {tool_name} result: success={result.get('success', False)}")
|
|
||||||
|
|
||||||
# Store tool result
|
|
||||||
tool_results.append({
|
|
||||||
"name": tool_name,
|
|
||||||
"result": result,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Rebuild system message with tool results
|
|
||||||
# Find and update the system message
|
|
||||||
for i, msg in enumerate(messages_dict):
|
|
||||||
if msg["role"] == "system":
|
|
||||||
# Rebuild with tool results
|
|
||||||
# This is a simplified approach - in production you'd want better state management
|
|
||||||
tool_result_text = f"\n\n--- TOOL RESULT ---\nTool: {tool_name}\nResult: {json.dumps(result, indent=2)}\n\nNow provide a helpful response based on this data."
|
|
||||||
messages_dict[i]["content"] += tool_result_text
|
|
||||||
break
|
|
||||||
|
|
||||||
# Add assistant's tool call as a message
|
|
||||||
messages_dict.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": f"[Executing tool: {tool_name}]"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add user prompt to continue
|
|
||||||
messages_dict.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": f"The tool {tool_name} returned the above result. Please provide your response to the original question using this data."
|
|
||||||
})
|
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
# No tool call found - return the response
|
# No tool call found (or tool_manager unavailable) - return the response
|
||||||
# Clean up any partial tool call syntax from response
|
# ALWAYS run cleanup to strip any residual tool_call JSON from response
|
||||||
cleaned_content = _clean_tool_syntax(content)
|
cleaned_content = _clean_tool_syntax(content)
|
||||||
log.info(f"Returning final response")
|
log.info(f"Returning final response (cleaned={len(cleaned_content) != len(content)})")
|
||||||
return cleaned_content or "I apologize, but I couldn't generate a response."
|
return cleaned_content or "I apologize, but I couldn't generate a response."
|
||||||
|
|
||||||
# Max iterations reached
|
# Max iterations reached
|
||||||
@ -900,10 +939,51 @@ async def generate_response(
|
|||||||
def _clean_tool_syntax(content: str) -> str:
|
def _clean_tool_syntax(content: str) -> str:
|
||||||
"""Remove tool call syntax from response if partially included."""
|
"""Remove tool call syntax from response if partially included."""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
def _remove_json_containing_key(text: str, key: str) -> str:
|
||||||
|
"""Remove JSON objects containing a specific key from text."""
|
||||||
|
result = text
|
||||||
|
while key in result:
|
||||||
|
idx = result.find(key)
|
||||||
|
# Walk backwards to find opening {
|
||||||
|
depth = 0
|
||||||
|
obj_start = -1
|
||||||
|
for i in range(idx, -1, -1):
|
||||||
|
if result[i] == '}':
|
||||||
|
depth += 1
|
||||||
|
elif result[i] == '{':
|
||||||
|
if depth == 0:
|
||||||
|
obj_start = i
|
||||||
|
break
|
||||||
|
depth -= 1
|
||||||
|
if obj_start == -1:
|
||||||
|
break
|
||||||
|
# Walk forwards to find matching }
|
||||||
|
depth = 0
|
||||||
|
obj_end = -1
|
||||||
|
for i in range(obj_start, len(result)):
|
||||||
|
if result[i] == '{':
|
||||||
|
depth += 1
|
||||||
|
elif result[i] == '}':
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
obj_end = i + 1
|
||||||
|
break
|
||||||
|
if obj_end == -1:
|
||||||
|
break
|
||||||
|
result = result[:obj_start] + result[obj_end:]
|
||||||
|
return result
|
||||||
|
|
||||||
# Remove ```json ... ``` blocks containing tool_call
|
# Remove ```json ... ``` blocks containing tool_call
|
||||||
cleaned = re.sub(r'```json\s*\{.*?"tool_call".*?\}\s*```', '', content, flags=re.DOTALL)
|
def remove_code_block(m):
|
||||||
# Remove standalone tool_call JSON
|
block = m.group(0)
|
||||||
cleaned = re.sub(r'\{"tool_call":\s*\{[^}]+\}\s*\}', '', cleaned)
|
inner = m.group(1)
|
||||||
|
if '"tool_call"' in inner:
|
||||||
|
return ''
|
||||||
|
return block
|
||||||
|
|
||||||
|
cleaned = re.sub(r'```json\s*(.*?)\s*```', remove_code_block, content, flags=re.DOTALL)
|
||||||
|
cleaned = _remove_json_containing_key(cleaned, '"tool_call"')
|
||||||
return cleaned.strip()
|
return cleaned.strip()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -164,14 +164,17 @@ class RAGSystem:
|
|||||||
|
|
||||||
log.info(f"Downloading website: {url}")
|
log.info(f"Downloading website: {url}")
|
||||||
|
|
||||||
# Use website_downloader_tool to download the site
|
# Use website_downloader_tool to download the site (in thread pool to avoid blocking)
|
||||||
download_result = website_downloader(
|
download_result = await asyncio.get_event_loop().run_in_executor(
|
||||||
url=url,
|
None,
|
||||||
destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
|
lambda: website_downloader(
|
||||||
max_pages=max_pages,
|
url=url,
|
||||||
threads=threads,
|
destination=str(self.downloaded_sites_path / self._get_site_folder(url)),
|
||||||
download_external_assets=download_external_assets,
|
max_pages=max_pages,
|
||||||
external_domains=external_domains,
|
threads=threads,
|
||||||
|
download_external_assets=download_external_assets,
|
||||||
|
external_domains=external_domains,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not download_result.get("success"):
|
if not download_result.get("success"):
|
||||||
@ -331,8 +334,63 @@ class RAGSystem:
|
|||||||
|
|
||||||
def _get_timestamp(self) -> str:
|
def _get_timestamp(self) -> str:
|
||||||
"""Get current timestamp in ISO format."""
|
"""Get current timestamp in ISO format."""
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
return datetime.utcnow().isoformat()
|
return datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
|
async def add_document_from_url(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
metadata: Optional[dict[str, Any]] = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Add a document from a URL to the knowledge base.
|
||||||
|
|
||||||
|
Downloads the content from the URL and processes it into chunks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL of the document to add
|
||||||
|
metadata: Optional additional metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with processing results
|
||||||
|
"""
|
||||||
|
self._ensure_initialized()
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
except ImportError:
|
||||||
|
# Fallback to requests (sync)
|
||||||
|
import requests
|
||||||
|
try:
|
||||||
|
resp = requests.get(url, timeout=30)
|
||||||
|
resp.raise_for_status()
|
||||||
|
content = resp.content
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to download {url}: {e}")
|
||||||
|
else:
|
||||||
|
# aiohttp available - use async
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
async def _fetch():
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
return await resp.read()
|
||||||
|
content = asyncio.get_event_loop().run_until_complete(_fetch())
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to download {url}: {e}")
|
||||||
|
|
||||||
|
filename = url.split("/")[-1] or "document.html"
|
||||||
|
doc_metadata = {
|
||||||
|
"source_url": url,
|
||||||
|
**(metadata or {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
return await self.add_document(
|
||||||
|
content=content,
|
||||||
|
filename=filename,
|
||||||
|
metadata=doc_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
async def add_document(
|
async def add_document(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -10,12 +10,21 @@ from __future__ import annotations
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
from collections import Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default embedding dimension
|
||||||
|
_EMBEDDING_DIM = 256
|
||||||
|
|
||||||
|
# Simple tokenization pattern
|
||||||
|
_WORD_RE = re.compile(r'[a-zA-Z0-9]+' )
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
class VectorStore:
|
||||||
"""
|
"""
|
||||||
@ -152,26 +161,89 @@ class VectorStore:
|
|||||||
|
|
||||||
log.info(f"Added {len(chunks)} chunks to vector store")
|
log.info(f"Added {len(chunks)} chunks to vector store")
|
||||||
|
|
||||||
|
def _tokenize(self, text: str) -> list[str]:
|
||||||
|
"""Simple word tokenization."""
|
||||||
|
return [w.lower() for w in _WORD_RE.findall(text) if len(w) > 1]
|
||||||
|
|
||||||
|
def _build_vocab(self, all_tokenized: list[list[str]], max_vocab: int = 10000) -> dict[str, int]:
|
||||||
|
"""Build vocabulary from tokenized texts with IDF weighting."""
|
||||||
|
doc_freq = Counter()
|
||||||
|
for tokens in all_tokenized:
|
||||||
|
unique_tokens = set(tokens)
|
||||||
|
for t in unique_tokens:
|
||||||
|
doc_freq[t] += 1
|
||||||
|
# Take top tokens by document frequency (most useful for search)
|
||||||
|
vocab = {}
|
||||||
|
for idx, (token, _) in enumerate(doc_freq.most_common(max_vocab)):
|
||||||
|
vocab[token] = idx
|
||||||
|
return vocab
|
||||||
|
|
||||||
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
async def _generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||||
"""
|
"""
|
||||||
Generate embeddings for texts.
|
Generate TF-IDF inspired embeddings for texts.
|
||||||
|
|
||||||
Uses a simple hash-based embedding for demonstration.
|
Uses a bag-of-words approach with TF-IDF weighting projected into a
|
||||||
In production, use a real embedding model via API.
|
fixed-dimension space. This produces meaningful cosine similarities
|
||||||
|
between semantically related texts, unlike hash-based embeddings.
|
||||||
|
|
||||||
|
In production, replace with a real embedding model API call.
|
||||||
"""
|
"""
|
||||||
embeddings = []
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
for text in texts:
|
# Tokenize all texts
|
||||||
# Simple hash-based embedding (for demo purposes)
|
all_tokenized = [self._tokenize(t) for t in texts]
|
||||||
# In production, use OpenAI embeddings or similar
|
|
||||||
hash_bytes = hashlib.sha256(text.encode()).digest()
|
# Build vocabulary from these texts + existing corpus
|
||||||
# Create a 384-dimensional embedding (common size)
|
# Include existing chunks for consistent vocab
|
||||||
embedding = []
|
existing_texts = [c["content"] for c in self._chunks]
|
||||||
for i in range(384):
|
existing_tokenized = [self._tokenize(t) for t in existing_texts]
|
||||||
byte_idx = i % len(hash_bytes)
|
combined_tokenized = existing_tokenized + all_tokenized
|
||||||
value = (hash_bytes[byte_idx] - 128) / 128.0
|
|
||||||
embedding.append(value)
|
vocab = self._build_vocab(combined_tokenized)
|
||||||
embeddings.append(embedding)
|
vocab_size = len(vocab)
|
||||||
|
|
||||||
|
if vocab_size == 0:
|
||||||
|
# Fallback: return zero vectors
|
||||||
|
return [[0.0] * _EMBEDDING_DIM for _ in texts]
|
||||||
|
|
||||||
|
# Compute IDF from all texts
|
||||||
|
n_docs = len(combined_tokenized)
|
||||||
|
idf = {}
|
||||||
|
for token, idx in vocab.items():
|
||||||
|
df = sum(1 for tokens in combined_tokenized if token in set(tokens))
|
||||||
|
idf[token] = math.log((n_docs + 1) / (df + 1)) + 1
|
||||||
|
|
||||||
|
# Dimension: project vocab into fixed dimension using hash-based assignment
|
||||||
|
dim = min(_EMBEDDING_DIM, vocab_size)
|
||||||
|
|
||||||
|
embeddings = []
|
||||||
|
for tokens in all_tokenized:
|
||||||
|
if not tokens:
|
||||||
|
embeddings.append([0.0] * dim)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Compute TF
|
||||||
|
tf = Counter(tokens)
|
||||||
|
max_tf = max(tf.values())
|
||||||
|
|
||||||
|
# Build sparse TF-IDF vector projected to fixed dimension
|
||||||
|
vec = [0.0] * dim
|
||||||
|
for token, count in tf.items():
|
||||||
|
if token not in vocab:
|
||||||
|
continue
|
||||||
|
normalized_tf = 0.5 + 0.5 * (count / max_tf) if max_tf > 0 else 0
|
||||||
|
tfidf = normalized_tf * idf.get(token, 1.0)
|
||||||
|
# Hash token to a dimension index
|
||||||
|
bucket = vocab[token] % dim
|
||||||
|
vec[bucket] += tfidf
|
||||||
|
|
||||||
|
# L2 normalize
|
||||||
|
norm = math.sqrt(sum(v * v for v in vec))
|
||||||
|
if norm > 0:
|
||||||
|
vec = [v / norm for v in vec]
|
||||||
|
|
||||||
|
embeddings.append(vec)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@ -197,9 +269,17 @@ class VectorStore:
|
|||||||
if not self._chunks:
|
if not self._chunks:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Generate query embedding
|
# Generate query embedding (use full corpus for consistent vocab)
|
||||||
query_embedding = (await self._generate_embeddings([query]))[0]
|
query_embedding = (await self._generate_embeddings([query]))[0]
|
||||||
|
|
||||||
|
# Ensure dimensions match
|
||||||
|
if self._embeddings and len(query_embedding) != len(self._embeddings[0]):
|
||||||
|
log.warning(f"Embedding dimension mismatch: query={len(query_embedding)}, stored={len(self._embeddings[0])}. Using zero-padded query.")
|
||||||
|
if len(query_embedding) < len(self._embeddings[0]):
|
||||||
|
query_embedding = query_embedding + [0.0] * (len(self._embeddings[0]) - len(query_embedding))
|
||||||
|
else:
|
||||||
|
query_embedding = query_embedding[:len(self._embeddings[0])]
|
||||||
|
|
||||||
# Calculate similarities
|
# Calculate similarities
|
||||||
results = []
|
results = []
|
||||||
for i, (chunk, embedding, metadata) in enumerate(
|
for i, (chunk, embedding, metadata) in enumerate(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user