176 lines
7.2 KiB
Python
176 lines
7.2 KiB
Python
import argparse
|
|
import json
|
|
import os
|
|
import logging
|
|
import shutil
|
|
from pathlib import Path
|
|
from pyrag3.repo_manager import RepoManager
|
|
from pyrag3.retrieval_service import RetrievalService
|
|
from pyrag3.search_manager import SearchManager
|
|
from pyrag3.crawler import Crawler
|
|
|
|
# Root logging for the CLI: INFO and above, timestamped records.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by DocumentRetrieveApp.
logger = logging.getLogger(__name__)
|
|
|
|
class DocumentRetrieveApp:
    """Hybrid document retrieval over a shared local index plus on-demand web discovery.

    Collaborators created in ``__init__``:

    * ``repo`` (``RepoManager``) -- uploads discovered pages via ``commit_and_push``.
    * ``retrieval`` (``RetrievalService``) -- chunk index configured with
      ``db_path = <data_root>/index.db``.
    * ``searcher`` (``SearchManager``) -- web search used for discovery.
    * ``crawler`` (``Crawler``) -- downloads pages into ``<data_root>/temp_web``
      before they are uploaded and indexed.
    """

    def __init__(self, data_root="rag_data"):
        """Create the collaborators rooted at *data_root*.

        Args:
            data_root: Base directory; the index database path and the
                crawler's temporary download directory are derived from it.
        """
        self.data_root = Path(data_root)
        self.db_path = self.data_root / "index.db"

        # Initialize modules
        self.repo = RepoManager()
        self.retrieval = RetrievalService(db_path=self.db_path)
        self.searcher = SearchManager()
        # Crawler still uses a local temp dir for downloading before upload
        self.crawler = Crawler(root_dir=self.data_root / "temp_web")

    def sync(self):
        """Sync with the remote shared index, then re-verify file metadata.

        Returns:
            True if both steps completed, False if either raised. Failures are
            logged (with traceback) and reported via the return value rather
            than raised.
        """
        logger.info("Synchronizing shared index and remote metadata...")
        try:
            # 1. Pull the shared 'brain' (index.db) if it exists
            self.retrieval.pull_index_db()
            # 2. Re-verify the file lists to catch any missing raw documents
            self.retrieval.sync_and_reindex()
        except Exception as e:
            logger.exception("Sync failed: %s", e)
            return False
        return True

    def query(self, text, limit=5, fallback=True, similarity_threshold=0.35,
              perfect_threshold=0.70, num_discovery=5):
        """Perform hybrid chunk-level retrieval (local index + web discovery).

        The local index is searched first. If *fallback* is enabled and the
        best local score is below *perfect_threshold*, a web search is run.
        Pages not already represented in the local hits are downloaded,
        uploaded and indexed, and then the local index is searched again.

        Args:
            text: Query text.
            limit: Maximum number of chunks requested from each local search.
            fallback: Whether web discovery may run at all.
            similarity_threshold: Minimum chunk ``score`` kept in the result.
            perfect_threshold: Best local score at or above which discovery
                is skipped.
            num_discovery: Number of web search results requested during
                discovery.

        Returns:
            The result dicts (as returned by ``RetrievalService.search``) whose
            ``score`` is at least *similarity_threshold*; possibly empty.

        Exceptions from the search, crawl, upload or indexing collaborators are
        not caught here. Pages that were already downloaded are still removed
        from the crawler's temp directory.
        """
        logger.info("Hybrid Query (Chunk Level): %s", text)

        # 1. Initial local search for top chunks
        results = self.retrieval.search(text, limit=limit) or []
        # max() rather than results[0]: does not depend on the hits being sorted.
        best_score = max((r.get('score', 0) for r in results), default=0)

        # 2. Proactive discovery (if no sufficiently good chunk was found locally)
        if fallback and best_score < perfect_threshold:
            if self._discover(text, results, num_discovery):
                # Final search to pick up the newly indexed high-relevance chunks
                results = self.retrieval.search(text, limit=limit) or []

        # 3. Final filtering based on chunk similarity score
        final_results = [r for r in results if r.get('score', 0) >= similarity_threshold]
        logger.info("Returning %d unified chunk results.", len(final_results))
        return final_results

    def _discover(self, text, local_results, num_results):
        # Web-search *text*, download unseen pages, then upload/index them; True if anything was indexed.
        logger.info("Triggering search discovery for higher-relevance context...")
        search_urls = self.searcher.search(text, num_results=num_results) or []

        # Skip URLs already represented in the local chunks and repeats within
        # the search results themselves; the set is built once for O(1) lookups.
        seen = {r.get('url') for r in local_results}
        new_files = []
        try:
            for url in search_urls:
                if url in seen:
                    continue
                seen.add(url)

                logger.info("Archiving fresh context: %s...", url)
                local_path = self.crawler.download_page(url)
                if local_path:
                    new_files.append(local_path)

            if not new_files:
                return False
            return self._ingest_discoveries(text, new_files)
        finally:
            # The temp dir only stages pages before upload: remove it whether
            # or not upload/indexing succeeded so failed runs leave nothing behind.
            if new_files:
                shutil.rmtree(self.crawler.root_dir, ignore_errors=True)

    def _ingest_discoveries(self, text, new_files):
        # Upload downloaded pages, index each one as chunks, push the shared index; True if anything was indexed.
        logger.info("Ingesting %d new discoveries...", len(new_files))
        uploaded_urls = self.repo.commit_and_push(new_files, message=f"Discovery: {text}")
        if not uploaded_urls:
            logger.warning(
                "commit_and_push returned no URLs for %d discovered pages; nothing ingested.",
                len(new_files),
            )
            return False

        ingested = 0
        for local_p in new_files:
            d_url = self._match_uploaded_url(local_p.name, uploaded_urls)
            if not d_url:
                logger.warning("No uploaded URL found for %s; skipping.", local_p.name)
                continue
            # Index as granular chunks immediately
            self.retrieval.ingest_document(local_p, f"web_results/{local_p.name}", d_url)
            ingested += 1

        if ingested:
            # Share the updated index only when it actually changed.
            self.retrieval.push_index_db()
        return ingested > 0

    @staticmethod
    def _match_uploaded_url(file_name, uploaded_urls):
        # Find the uploaded URL for *file_name*: exact last-path-segment match first, substring match as fallback.
        # A bare substring test alone would wrongly map "a.html" to ".../ba.html".
        for url in uploaded_urls:
            if url.rstrip('/').rsplit('/', 1)[-1] == file_name:
                return url
        return next((url for url in uploaded_urls if file_name in url), None)
|
|
|
|
def _build_parser():
    # Build the argparse CLI with the query / add / update / reset sub-commands.
    parser = argparse.ArgumentParser(description="Storage-Efficient Document Retrieval Tool")
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Query command
    query_parser = subparsers.add_parser("query", help="Query the document index")
    query_parser.add_argument("text", help="The query text")
    query_parser.add_argument("--limit", type=int, default=5, help="Number of results to return")
    query_parser.add_argument("--no-fallback", action="store_true", help="Disable web search fallback")
    query_parser.add_argument("--format", choices=["json", "text"], default="json", help="Output format")

    # Add command
    add_parser = subparsers.add_parser("add", help="Add a local file or directory to the repo and index")
    add_parser.add_argument("path", help="Path to local file or directory")

    # Update command
    subparsers.add_parser("update", help="Sync remote inventory and re-index")

    # Reset command
    subparsers.add_parser("reset", help="DANGEROUS: Purge all data from repo and index")

    return parser


def _cmd_update(app):
    # Run app.sync() and report the outcome; return an exit status.
    if app.sync():
        print("Successfully synced remote metadata and updated index.")
        return 0
    print("Update failed.")
    return 1


def _cmd_reset(app):
    # Ask for confirmation, then purge via app.retrieval.reset_all(); return an exit status.
    try:
        confirm = input("⚠️ CAUTION: This will permanently DELETE all files in Gitea and the local index. Are you sure? (y/n): ")
    except EOFError:
        # Non-interactive / closed stdin: treat as "no" instead of crashing.
        confirm = ""
    if confirm.strip().lower() not in ("y", "yes"):
        print("Reset cancelled.")
        return 0
    if app.retrieval.reset_all():
        print("Successfully purged all knowledge base assets.")
        return 0
    print("Reset failed.")
    return 1


def _cmd_add(app, raw_path):
    # Add one file, or every regular file directly inside a directory; return an exit status.
    path = Path(raw_path)
    if path.is_file():
        if app.retrieval.add_local_file(str(path)):
            print(f"Successfully added {path.name} to the distributed knowledge base.")
            return 0
        print(f"Failed to add {path.name}.")
        return 1

    if path.is_dir():
        label = path.name or str(path)  # Path(".").name is ""
        # Top level only (no recursion); filter first so the announced count
        # matches what is actually attempted.
        files = sorted(p for p in path.glob("*") if p.is_file())
        print(f"Adding {len(files)} files from {label}...")
        count = 0
        for f in files:
            if app.retrieval.add_local_file(str(f)):
                count += 1
        print(f"Successfully added {count} files from {label}.")
        return 0 if count == len(files) else 1

    print(f"Path not found: {raw_path}")
    return 1


def _cmd_query(app, args):
    # Run app.query() and print results as JSON or human-readable text; return an exit status.
    results = app.query(args.text, limit=args.limit, fallback=(not args.no_fallback))

    if args.format == "json":
        # Includes the full 'content' field for the LLM
        print(json.dumps(results, indent=2))
        return 0

    if not results:
        print("No results found.")
    for r in results:
        # .get(): a hit missing a display field should not abort the listing.
        print(f"[{r.get('score', 0):.4f}] {r.get('title', '')}")
        print(f"URL: {r.get('url', '')}")
        content_preview = (r.get('content') or '')[:200].replace('\n', ' ')
        print(f"Preview: {content_preview}...")
        print("-" * 40)
    return 0


def main(argv=None):
    """Parse command-line arguments and dispatch to the requested sub-command.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success or when only help was printed,
        1 when the requested operation reported failure.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Without a sub-command there is nothing to do; don't construct the app
    # (and its repo/index/search/crawler collaborators) just to print help.
    if args.command is None:
        parser.print_help()
        return 0

    app = DocumentRetrieveApp()

    if args.command == "update":
        return _cmd_update(app)
    if args.command == "reset":
        return _cmd_reset(app)
    if args.command == "add":
        return _cmd_add(app, args.path)
    if args.command == "query":
        return _cmd_query(app, args)

    parser.print_help()
    return 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status (None -> 0).
    raise SystemExit(main())
|