- Pass all registered tools to LLM during chat completion - Handle tool_calls from LLM response - Execute tools and feed results back to LLM - Loop until LLM returns final response - Updated system prompt to encourage tool use - Updated streaming to handle tool calls - Increased MAX_TOOL_ITERATIONS to 5
177 lines
4.7 KiB
Python
Executable File
177 lines
4.7 KiB
Python
Executable File
"""
Retriever - Handles context retrieval from the vector store.

Provides retrieval with:
- Query preprocessing (whitespace collapsing, removal of ``?`` and ``!``)
- Result ranking (vector similarity blended with keyword overlap)
- Removal of duplicate chunks
- Assembly of retrieved chunks into a single context string

Context windowing (expanding a hit with its neighbouring chunks) is planned
but not implemented yet; only the matched chunks are returned.
"""
|
|
|
|
from __future__ import annotations

import logging
import re
from typing import Any, Optional

from .vector_store import VectorStore
|
|
|
|
# Module-level logger, named after this module so its output can be
# configured through the standard logging hierarchy.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
class Retriever:
    """
    Retriever for fetching relevant context from the vector store.

    Handles:
    - Query preprocessing (whitespace collapsing, removal of ``?`` and ``!``)
    - Similarity search, over-fetching to leave room for filtering
    - Score filtering, hybrid re-ranking and duplicate removal
    """

    # Weights for the hybrid ranking score. They sum to 1.0, so the combined
    # score stays on the same scale as the vector similarity score.
    VECTOR_WEIGHT: float = 0.7
    KEYWORD_WEIGHT: float = 0.3

    # Multiple of top_k requested from the vector store, leaving headroom for
    # results dropped by min_score filtering and de-duplication.
    OVERFETCH_FACTOR: int = 2

    # Number of leading characters of a chunk used as its duplicate fingerprint.
    FINGERPRINT_CHARS: int = 100

    # Word tokenizer for keyword matching. Punctuation is never part of a
    # token, so "python?" in a query matches "Python." in a chunk.
    _WORD_RE = re.compile(r"\w+")

    def __init__(
        self,
        vector_store: VectorStore,
        default_top_k: int = 5,
        min_score: float = 0.0,
    ):
        """
        Args:
            vector_store: Store whose async ``search`` method is queried
            default_top_k: Result count used when ``retrieve`` gets no top_k
            min_score: Results whose vector ``score`` is below this are dropped
        """
        self.vector_store = vector_store
        self.default_top_k = default_top_k
        self.min_score = min_score

    async def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter_metadata: Optional[dict] = None,
    ) -> list[dict[str, Any]]:
        """
        Retrieve relevant chunks for a query.

        Steps:
        1. Preprocess the query.
        2. Fetch ``top_k * OVERFETCH_FACTOR`` candidates from the vector store.
        3. Drop candidates scoring below ``min_score``.
        4. Re-rank and de-duplicate the remainder.
        5. Return the best ``top_k``.

        Args:
            query: Query string
            top_k: Number of results. ``None`` uses ``default_top_k``.
                Values <= 0 return an empty list without querying the store.
            filter_metadata: Optional metadata filters, passed unchanged to
                the vector store

        Returns:
            Up to ``top_k`` result dicts, best first. Each keeps the store's
            fields (``"content"`` and ``"score"`` are read here) plus the
            ``"combined_score"`` added during ranking.
        """
        # Only fall back to the default when top_k was omitted: an explicit 0
        # must not silently become default_top_k (as `top_k or ...` did).
        if top_k is None:
            top_k = self.default_top_k
        if top_k <= 0:
            return []

        processed_query = self._preprocess_query(query)

        # Over-fetch so that filtering/dedup still leaves top_k results.
        results = await self.vector_store.search(
            query=processed_query,
            top_k=top_k * self.OVERFETCH_FACTOR,
            filter_metadata=filter_metadata,
        )

        results = [r for r in results if r["score"] >= self.min_score]
        results = self._rank_results(results, query)
        return results[:top_k]

    def _preprocess_query(self, query: str) -> str:
        """
        Normalize a query before it is sent to the vector store.

        - Replaces ``?`` and ``!`` with spaces
        - Collapses runs of whitespace into single spaces and trims both ends

        Case is preserved.
        """
        without_marks = query.replace("?", " ").replace("!", " ")
        return " ".join(without_marks.split())

    def _tokenize(self, text: str) -> set[str]:
        """Return the set of lowercase word tokens in ``text``."""
        return set(self._WORD_RE.findall(text.lower()))

    def _rank_results(
        self,
        results: list[dict[str, Any]],
        query: str,
    ) -> list[dict[str, Any]]:
        """
        Re-rank results and drop duplicates.

        Each result dict gets a ``"combined_score"`` key, written in place:

            VECTOR_WEIGHT * score + KEYWORD_WEIGHT * keyword_overlap

        Here ``keyword_overlap`` is the fraction of distinct query words that
        also appear in the chunk content. Words are case-insensitive and
        ignore punctuation.

        Results are sorted by combined score, highest first. A result is then
        dropped if its first ``FINGERPRINT_CHARS`` characters match an earlier,
        higher-ranked result.

        Args:
            results: Result dicts with at least ``"content"`` and ``"score"``
            query: Query string used for keyword matching

        Returns:
            A new list of the surviving results in ranked order.
        """
        if not results:
            return []

        query_words = self._tokenize(query)
        # max(..., 1) avoids dividing by zero for a query with no words.
        denominator = max(len(query_words), 1)

        for result in results:
            overlap = len(query_words & self._tokenize(result["content"]))
            keyword_score = overlap / denominator
            result["combined_score"] = (
                result["score"] * self.VECTOR_WEIGHT
                + keyword_score * self.KEYWORD_WEIGHT
            )

        ranked = sorted(results, key=lambda r: r["combined_score"], reverse=True)

        # Iterating in ranked order means the highest-scoring copy of each
        # duplicate is the one that is kept.
        seen_fingerprints: set[str] = set()
        unique_results = []
        for result in ranked:
            fingerprint = result["content"][: self.FINGERPRINT_CHARS]
            if fingerprint not in seen_fingerprints:
                seen_fingerprints.add(fingerprint)
                unique_results.append(result)

        return unique_results

    async def retrieve_with_context(
        self,
        query: str,
        top_k: int = 5,
        context_window: int = 1,
        filter_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        """
        Retrieve chunks and assemble them into a single context.

        Args:
            query: Query string
            top_k: Number of main results
            context_window: Number of adjacent chunks to include. Currently
                ignored: expansion to neighbouring chunks is not implemented,
                so only the matched chunks are returned.
            filter_metadata: Optional metadata filters passed to ``retrieve``

        Returns:
            Dictionary with:
            - ``"results"``: ranked result dicts from ``retrieve``
            - ``"context"``: the chunks' ``"content"`` joined by blank lines
            - ``"sources"``: distinct non-empty ``metadata["source"]`` values,
              in rank order
        """
        results = await self.retrieve(
            query=query, top_k=top_k, filter_metadata=filter_metadata
        )

        # TODO: expand each hit with `context_window` adjacent chunks.
        sources = []
        for result in results:
            # `or {}` also tolerates an explicit `"metadata": None`.
            source = (result.get("metadata") or {}).get("source")
            if source:
                sources.append(source)

        return {
            "results": results,
            "context": "\n\n".join(r["content"] for r in results),
            # dict.fromkeys de-duplicates while keeping relevance order;
            # a set would make the order arbitrary.
            "sources": list(dict.fromkeys(sources)),
        }
|