# ----------------------------------------------------------
# Section 0: Imports
# ----------------------------------------------------------
import base64
import json
import pickle
import subprocess
import textwrap
from io import BytesIO
from pathlib import Path

# Third-party libraries
import requests
from cachetools import TTLCache
from PIL import Image

# LangChain and associated libraries
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

# Environment variable loading
from dotenv import load_dotenv

load_dotenv()
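
# NOTE: the providers and tools below typically read their credentials from
# environment variables, e.g. GOOGLE_API_KEY (Gemini), GROQ_API_KEY (Groq),
# TAVILY_API_KEY (Tavily search), and HUGGINGFACEHUB_API_TOKEN (HF endpoint),
# so put them in a .env file for load_dotenv() to pick up.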
# ----------------------------------------------------------
# Section 1: Constants and Configuration
# ----------------------------------------------------------
JSONL_PATH = Path("metadata.jsonl")
FAISS_CACHE = Path("faiss_index.pkl")
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K = 5    # Number of similar documents to retrieve
CACHE_TTL = 600    # Cache API calls for 10 minutes

# Global cache object for API calls
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
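
# Each line of metadata.jsonl is expected to be a JSON object with at least
# the fields read below, e.g. (illustrative values only):
#   {"task_id": "abc123", "Question": "What is ...?", "Final answer": "42"}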
# ----------------------------------------------------------
# Section 2: The Agent Class
# ----------------------------------------------------------
class MyAgent:
    """
    Encapsulates the agent's state, including LLMs, retriever, and tools.
    This class-based approach ensures clean management of dependencies.
    """

    def __init__(self, provider: str = "google"):
        """
        Initializes the agent, setting up LLMs and the FAISS retriever.

        Args:
            provider (str): The LLM provider to use ('google', 'groq', 'huggingface').
        """
        print(f"Initializing agent with provider: {provider}")
        self.llm = self._build_llm(provider)
        self.vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
        self.retriever = self._get_retriever()
    def _get_retriever(self):
        """Builds or loads the FAISS retriever from a local cache."""
        embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
        if FAISS_CACHE.exists():
            print(f"Loading FAISS index from cache: {FAISS_CACHE}")
            # Only unpickle cache files this process created itself;
            # pickle.load() executes code embedded in untrusted files.
            with open(FAISS_CACHE, "rb") as f:
                vector_store = pickle.load(f)
        else:
            print("FAISS cache not found. Building new index from metadata.jsonl...")
            if not JSONL_PATH.exists():
                raise FileNotFoundError(f"{JSONL_PATH} not found. Cannot build vector store.")
            # Build one document per solved Q&A record, closing the file when done.
            docs = []
            with open(JSONL_PATH, "rt", encoding="utf-8") as f:
                for line in f:
                    rec = json.loads(line)
                    docs.append(Document(
                        page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                        metadata={"source": rec["task_id"]},
                    ))
            if not docs:
                raise ValueError("No documents found in metadata.jsonl.")
            vector_store = FAISS.from_documents(docs, embeddings)
            with open(FAISS_CACHE, "wb") as f:
                pickle.dump(vector_store, f)
            print(f"FAISS index built and saved to cache: {FAISS_CACHE}")
        return vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
    def _build_llm(self, provider: str):
        """Helper to build the main text LLM for the chosen provider."""
        if provider == "google":
            return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
        elif provider == "groq":
            return ChatGroq(model_name="llama3-70b-8192", temperature=0)
        elif provider == "huggingface":
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0)
            )
        else:
            raise ValueError("Provider must be 'google', 'groq', or 'huggingface'")
    def _cached_get(self, key: str, fetch_fn):
        """Helper for caching API calls."""
        if key in API_CACHE:
            return API_CACHE[key]
        val = fetch_fn()
        API_CACHE[key] = val
        return val
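
    # Example: self._cached_get("wiki:llamas", lambda: loader.load()) runs the
    # loader once, then serves the cached value for CACHE_TTL seconds on repeats.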
    # --- Tool Definitions as Class Methods ---
    def python_repl(self, code: str) -> str:
        """Executes a string of Python code and returns the stdout/stderr."""
        code = textwrap.dedent(code).strip()
        try:
            result = subprocess.run(
                ["python", "-c", code],
                capture_output=True, text=True, timeout=10, check=False,
            )
        except subprocess.TimeoutExpired:
            return "Execution timed out (>10s)."
        if result.returncode == 0:
            return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
        return (
            f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\n"
            f"STDERR:\n```\n{result.stderr}\n```"
        )
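
    # Caution: python_repl runs arbitrary code in a subprocess with only the
    # 10-second timeout as a guard; do not expose it to untrusted input.
    # Illustrative call (hypothetical): python_repl("print(2 + 3)") returns the
    # captured stdout ("5") wrapped in the success template above.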
    def describe_image(self, image_source: str) -> str:
        """Describes an image from a local file path or a URL using Gemini vision."""
        try:
            if image_source.startswith("http"):
                resp = requests.get(image_source, timeout=10)
                resp.raise_for_status()  # surface HTTP errors instead of parsing an error page
                img = Image.open(BytesIO(resp.content))
            else:
                img = Image.open(image_source)
            # Re-encode as JPEG and send inline as a base64 data URL.
            buffered = BytesIO()
            img.convert("RGB").save(buffered, format="JPEG")
            b64_string = base64.b64encode(buffered.getvalue()).decode()
            msg = HumanMessage(content=[
                {"type": "text", "text": "Describe this image in detail."},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}},
            ])
            return self.vision_llm.invoke([msg]).content
        except Exception as e:
            return f"Error processing image: {e}"
    def web_search(self, query: str) -> str:
        """Performs a web search using Tavily and returns a compilation of results."""
        key = f"web:{query}"
        results = self._cached_get(key, lambda: TavilySearchResults(max_results=5).invoke(query))
        return "\n\n---\n\n".join(
            f"Source: {res['url']}\nContent: {res['content']}" for res in results
        )

    def wiki_search(self, query: str) -> str:
        """Searches Wikipedia and returns the top 2 results."""
        key = f"wiki:{query}"
        docs = self._cached_get(
            key,
            lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load(),
        )
        return "\n\n---\n\n".join(
            f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs
        )
    def arxiv_search(self, query: str) -> str:
        """Searches arXiv for scientific papers and returns the top 2 results."""
        key = f"arxiv:{query}"
        docs = self._cached_get(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
        # ArxivLoader provides 'Published' and 'Title' metadata, but not necessarily 'source'.
        return "\n\n---\n\n".join(
            f"Source: {d.metadata.get('source', 'arXiv')}\nPublished: {d.metadata['Published']}\n"
            f"Title: {d.metadata['Title']}\n\nSummary:\n{d.page_content}"
            for d in docs
        )
    def get_tools(self) -> list:
        """Returns a list of all tools available to the agent."""
        tools_list = [
            self.python_repl,
            self.describe_image,
            self.web_search,
            self.wiki_search,
            self.arxiv_search,
        ]
        retriever_tool = create_retriever_tool(
            retriever=self.retriever,
            name="retrieve_examples",
            description="Retrieve solved questions and answers similar to the user's query.",
        )
        tools_list.append(retriever_tool)
        return tools_list
# ----------------------------------------------------------
# Section 3: System Prompt
# ----------------------------------------------------------
SYSTEM_PROMPT = (
    """You are a helpful and expert assistant designed to answer questions accurately and concisely.

**Instructions:**
1. **Analyze the Question:** Carefully understand what is being asked.
2. **Use Tools:** You have a set of tools to find information. Use them logically.
3. **Synthesize the Answer:** Based on the information from the tools, formulate your final answer.
4. **Format the Output:** Your final response MUST be in the following format and nothing else:
FINAL ANSWER: [Your concise and accurate answer here]

If the `retrieve_examples` tool provides an answer to an identical question, use that answer. Otherwise, use your tools to find the correct answer for the current question.
"""
)
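
# Illustrative example of a conforming reply: for the question "What is the
# capital of France?", the model should output exactly "FINAL ANSWER: Paris".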
# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def create_agent_executor(provider: str = "google"):
    """Factory function to create and compile the LangGraph agent executor."""
    my_agent_instance = MyAgent(provider=provider)
    tools_list = my_agent_instance.get_tools()
    llm_with_tools = my_agent_instance.llm.bind_tools(tools_list)

    def retriever_node(state: MessagesState):
        """First node: retrieves similar solved examples and adds them to the history."""
        user_query = state["messages"][-1].content
        docs = my_agent_instance.retriever.invoke(user_query)
        if not docs:
            return {"messages": []}
        example_text = "\n\n---\n\n".join(d.page_content for d in docs)
        example_msg = AIMessage(
            content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}",
            name="ExampleRetriever",
        )
        return {"messages": [example_msg]}

    def assistant_node(state: MessagesState):
        """Main assistant node: calls the LLM with the current state to decide the next action.

        The system prompt is prepended at call time: MessagesState appends any
        returned messages, so a SystemMessage stored into state by a node would
        land after the user's message instead of first.
        """
        result = llm_with_tools.invoke([SystemMessage(content=SYSTEM_PROMPT)] + state["messages"])
        return {"messages": [result]}

    # Graph wiring: START -> retriever -> assistant, then loop assistant <-> tools
    # until the LLM stops emitting tool calls.
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor created successfully.")
    return agent_executor
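
# Minimal non-streaming usage sketch (assumes the relevant API keys are set):
#   executor = create_agent_executor(provider="google")
#   result = executor.invoke({"messages": [("user", "What is 2 + 2?")]})
#   print(result["messages"][-1].content)  # expected to end with "FINAL ANSWER: ..."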
# ----------------------------------------------------------
# Section 5: Direct Execution Block for Testing
# ----------------------------------------------------------
if __name__ == "__main__":
    # Direct testing of the agent's logic.
    print("--- Running Agent in Test Mode ---")
    agent = create_agent_executor(provider="google")
    question = "According to Wikipedia, what is the main difference between a llama and an alpaca?"
    print(f"\nTest Question: {question}\n\n--- Agent Thinking... ---\n")
    # Stream state updates node by node and print any textual content.
    for chunk in agent.stream({"messages": [("user", question)]}):
        for key, value in chunk.items():
            if value["messages"]:
                message = value["messages"][-1]
                if message.content:
                    print(f"--- Node: {key} ---\n{message.content}\n")