from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, util
import numpy as np
import pandas as pd
import PyPDF2
import os
from typing import List, Dict
class DocumentRetrieverTool:
    def __init__(self):
        self.name = "document_retriever"
        self.description = "Retrieves relevant text from GAIA text-heavy files (CSV, TXT, PDF) using semantic search."
        self.inputs = {
            "task_id": {"type": "string", "description": "GAIA task ID for the file"},
            "query": {"type": "string", "description": "Question or query to search for"},
            "file_type": {"type": "string", "description": "File type (csv, txt, or pdf; default: txt)"},
        }
        self.output_type = str
        self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
            length_function=len,
        )
        self.chunks: List[str] = []
        self.embeddings = None  # chunk embeddings (torch tensor), recomputed on each call
    async def aparse(self, task_id: str, query: str, file_type: str = "txt") -> str:
        """
        Loads a GAIA file, splits it into chunks, embeds them, and retrieves relevant text for the query.
        Supports CSV, TXT, and PDF files.
        """
        try:
            file_path = f"temp_{task_id}.{file_type}"
            if not os.path.exists(file_path):
                return f"File not found for task ID {task_id}"

            # Load and preprocess the file into a single text string
            text = ""
            if file_type == "csv":
                df = pd.read_csv(file_path)
                text = df.to_string()
            elif file_type == "txt":
                with open(file_path, "r", encoding="utf-8") as f:
                    text = f.read()
            elif file_type == "pdf":
                with open(file_path, "rb") as f:
                    pdf = PyPDF2.PdfReader(f)
                    text = "".join(page.extract_text() or "" for page in pdf.pages)
            else:
                return f"Unsupported file type: {file_type}"
            # Check that text was extracted
            if not text.strip():
                return "No extractable text found in file."

            # Split text into overlapping chunks
            self.chunks = self.text_splitter.split_text(text)
            if not self.chunks:
                return "No content found in file."

            # Embed the chunks and the query
            self.embeddings = self.embedder.encode(self.chunks, convert_to_tensor=True)
            query_embedding = self.embedder.encode(query, convert_to_tensor=True)

            # Compute cosine similarities between the query and every chunk
            similarities = util.cos_sim(query_embedding, self.embeddings)[0]

            # Select the top 3 most relevant chunks
            top_k = min(3, len(self.chunks))
            top_indices = similarities.argsort(descending=True)[:top_k]
            relevant_chunks = [self.chunks[int(idx)] for idx in top_indices]

            # Combine and return the results
            return "\n\n".join(relevant_chunks)
        except Exception as e:
            return f"Error retrieving documents: {str(e)}"
document_retriever_tool = DocumentRetrieverTool() | |
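

# A minimal usage sketch, not part of the tool itself: it assumes the surrounding
# agent framework awaits the coroutine, and that a file named temp_<task_id>.txt
# exists in the working directory (here we create one so the sketch is self-contained).
# The task_id, query, and sample text below are hypothetical placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        # Write a small sample file so the retriever has something to search.
        sample_task_id = "demo-task"
        with open(f"temp_{sample_task_id}.txt", "w", encoding="utf-8") as f:
            f.write(
                "The Eiffel Tower is located in Paris.\n"
                "It was completed in 1889 and is about 330 metres tall.\n"
            )
        # Retrieve the chunks most similar to the question.
        result = await document_retriever_tool.aparse(
            task_id=sample_task_id,
            query="How tall is the Eiffel Tower?",
            file_type="txt",
        )
        print(result)

    asyncio.run(_demo())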