jarvis_gaia_agent/tools/retriever.py
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import PyPDF2
import os
from typing import List


class DocumentRetrieverTool:
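    """Retrieves relevant passages from GAIA task files (CSV, TXT, PDF) via semantic search."""
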
def __init__(self):
self.name = "document_retriever"
self.description = "Retrieves relevant text from GAIA text-heavy files (CSV, TXT, PDF) using semantic search."
self.inputs = {
"task_id": {"type": "string", "description": "GAIA task ID for the file"},
"query": {"type": "string", "description": "Question or query to search for"},
"file_type": {"type": "string", "description": "File type (csv, txt, pdf, default: txt)"}
}
self.output_type = str
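        # all-MiniLM-L6-v2 is a small, fast sentence-embedding model (384-dimensional vectors)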
self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
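        # ~500-character chunks with 50-character overlap preserve context across chunk boundaries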
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=50,
length_function=len
)
        self.chunks: List[str] = []
        self.embeddings = None  # populated with chunk embeddings on each retrieval call

async def aparse(self, task_id: str, query: str, file_type: str = "txt") -> str:
"""
Loads a GAIA file, splits it into chunks, embeds them, and retrieves relevant text for the query.
Supports CSV, TXT, and PDF files.
"""
try:
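            # GAIA attachments are expected to be cached locally as temp_<task_id>.<file_type>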
file_path = f"temp_{task_id}.{file_type}"
if not os.path.exists(file_path):
return f"File not found for task ID {task_id}"
# Load and preprocess file
text = ""
if file_type == "csv":
df = pd.read_csv(file_path)
text = df.to_string()
elif file_type == "txt":
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()
elif file_type == "pdf":
with open(file_path, "rb") as f:
pdf = PyPDF2.PdfReader(f)
text = "".join(page.extract_text() or "" for page in pdf.pages)
else:
return f"Unsupported file type: {file_type}"
# Check if text was extracted
if not text.strip():
return "No extractable text found in file."
# Split text into chunks
self.chunks = self.text_splitter.split_text(text)
if not self.chunks:
return "No content found in file."
# Embed chunks and query
self.embeddings = self.embedder.encode(self.chunks, convert_to_tensor=True)
query_embedding = self.embedder.encode(query, convert_to_tensor=True)
            # Compute cosine similarity between the query and each chunk
            similarities = util.cos_sim(query_embedding, self.embeddings)[0]
            # Return the top 3 most similar chunks (fewer if the document is short)
            top_k = min(3, len(self.chunks))
            top_indices = similarities.argsort(descending=True)[:top_k]
            relevant_chunks = [self.chunks[int(idx)] for idx in top_indices]
# Combine results
return "\n\n".join(relevant_chunks)
except Exception as e:
return f"Error retrieving documents: {str(e)}"
document_retriever_tool = DocumentRetrieverTool()
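

# Example usage (a minimal sketch): this assumes the GAIA attachment for the task has
# already been downloaded to temp_<task_id>.<file_type>; the task ID and query below are
# hypothetical placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        answer = await document_retriever_tool.aparse(
            task_id="example-task",            # hypothetical task ID
            query="What is the total revenue?",
            file_type="csv",
        )
        print(answer)

    asyncio.run(_demo())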