# File-viewer residue from the hosting page, kept as comments so the module parses:
#   Abbasid's picture
#   Update agent.py
#   e5bb694 verified
#   raw
#   history blame
#   11.5 kB
# ----------------------------------------------------------
# Section 0: Imports
# ----------------------------------------------------------
import json
import os
import pickle
import re
import subprocess
import textwrap
import base64
from datetime import datetime, timedelta
from io import BytesIO
from pathlib import Path
from typing import List
# Third-party libraries
import requests
from cachetools import TTLCache
from PIL import Image
# LangChain and associated libraries
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader # Added loaders
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Environment variable loading
from dotenv import load_dotenv
load_dotenv()
# ----------------------------------------------------------
# Section 1: Constants and Configuration
# ----------------------------------------------------------
# Dataset of solved questions, and the on-disk cache of the vector index built from it.
JSONL_PATH = Path("metadata.jsonl")
FAISS_CACHE = Path("faiss_index.pkl")
# Sentence-embedding model used to index the dataset.
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Number of similar documents the retriever returns per query.
RETRIEVER_K = 5
# Lifetime of a cached API response, in seconds (10 minutes).
CACHE_TTL = 10 * 60
# Module-wide cache shared by every API-backed lookup.
API_CACHE = TTLCache(ttl=CACHE_TTL, maxsize=256)
# ----------------------------------------------------------
# Section 2: The Agent Class
# ----------------------------------------------------------
class MyAgent:
    """
    Encapsulates the agent's state, including LLMs, retriever, and tools.

    The tool implementations are ordinary instance methods. They are turned
    into LangChain tools in ``get_tools`` by wrapping the *bound* methods,
    because decorating a method with ``@tool`` at class level replaces it with
    a tool object that is not bound to any instance (``self`` would never be
    supplied when the tool runs).

    NOTE: the one-line docstrings of the tool methods are used as the tool
    descriptions shown to the LLM, so extra documentation for those methods is
    kept in ``#`` comments rather than in their docstrings.
    """

    def __init__(self, provider: str = "google"):
        """
        Initializes the agent, setting up LLMs and the FAISS retriever.

        Args:
            provider (str): The LLM provider to use ('google', 'groq', 'huggingface').

        Raises:
            ValueError: If ``provider`` is not one of the supported names.
            FileNotFoundError: If no index cache exists and the dataset file is missing.
        """
        print(f"Initializing agent with provider: {provider}")
        self.llm = self._build_llm(provider)
        self.vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
        self.retriever = self._get_retriever()

    def _get_retriever(self):
        """Builds or loads the FAISS retriever from a local cache.

        Returns:
            A retriever over the vector store that yields ``RETRIEVER_K`` documents.

        Raises:
            FileNotFoundError: If the cache is absent and ``JSONL_PATH`` does not exist.
            ValueError: If ``JSONL_PATH`` contains no records.
        """
        if FAISS_CACHE.exists():
            print(f"Loading FAISS index from cache: {FAISS_CACHE}")
            # SECURITY: unpickling runs arbitrary code. Only load a cache file
            # that this application wrote itself; never a downloaded one.
            with open(FAISS_CACHE, "rb") as f:
                vector_store = pickle.load(f)
        else:
            print("FAISS cache not found. Building new index from metadata.jsonl...")
            if not JSONL_PATH.exists():
                raise FileNotFoundError(f"{JSONL_PATH} not found. Cannot build vector store.")
            # Read inside a context manager so the handle is closed, and skip
            # blank lines (e.g. a trailing newline) that json.loads would reject.
            with open(JSONL_PATH, "rt", encoding="utf-8") as f:
                records = [json.loads(line) for line in f if line.strip()]
            docs = [
                Document(
                    page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                    metadata={"source": rec["task_id"]},
                )
                for rec in records
            ]
            if not docs:
                raise ValueError("No documents found in metadata.jsonl.")
            # The embedding model is only needed when a new index is built.
            embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
            vector_store = FAISS.from_documents(docs, embeddings)
            with open(FAISS_CACHE, "wb") as f:
                pickle.dump(vector_store, f)
            print(f"FAISS index built and saved to cache: {FAISS_CACHE}")
        return vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

    def _build_llm(self, provider: str):
        """Helper to build the main text-based LLM based on the chosen provider.

        Raises:
            ValueError: If ``provider`` is not 'google', 'groq', or 'huggingface'.
        """
        if provider == "google":
            return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
        if provider == "groq":
            return ChatGroq(model_name="llama3-70b-8192", temperature=0)
        if provider == "huggingface":
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0)
            )
        raise ValueError("Provider must be 'google', 'groq', or 'huggingface'")

    def _cached_get(self, key: str, fetch_fn):
        """Helper for caching API calls.

        Args:
            key: Cache key identifying the request.
            fetch_fn: Zero-argument callable invoked on a cache miss.

        Returns:
            The cached value for ``key``, or the fresh result of ``fetch_fn()``.
        """
        # A single lookup instead of "check, then read": an entry in a TTL
        # cache may expire between the membership test and the read.
        try:
            return API_CACHE[key]
        except KeyError:
            pass
        val = fetch_fn()
        API_CACHE[key] = val
        return val

    # --- Tool implementations (wrapped into tools by get_tools) ---

    # Runs the snippet in a separate interpreter process with a 10 s limit.
    # The code is executed as given, without any sandboxing.
    def python_repl(self, code: str) -> str:
        """Executes a string of Python code and returns the stdout/stderr."""
        import sys

        code = textwrap.dedent(code).strip()
        # Use the interpreter running this process so the snippet sees the same
        # environment; fall back to PATH lookup if it is unknown.
        interpreter = sys.executable or "python"
        try:
            result = subprocess.run(
                [interpreter, "-c", code],
                capture_output=True,
                text=True,
                timeout=10,
                check=False,
            )
        except subprocess.TimeoutExpired:
            return "Execution timed out (>10s)."
        if result.returncode == 0:
            return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
        return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"

    # Any failure is reported back as text so the calling model can react to it
    # instead of the whole run aborting.
    def describe_image(self, image_source: str) -> str:
        """Describes an image from a local file path or a URL using Gemini vision."""
        try:
            # Match the scheme explicitly so a local file such as
            # "http_img.png" is not mistaken for a URL.
            if image_source.startswith(("http://", "https://")):
                response = requests.get(image_source, timeout=10)
                # Fail on 4xx/5xx rather than handing an error page to PIL.
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
            else:
                img = Image.open(image_source)
            buffered = BytesIO()
            # JPEG has no alpha channel, hence the conversion to RGB.
            img.convert("RGB").save(buffered, format="JPEG")
            b64_string = base64.b64encode(buffered.getvalue()).decode()
            msg = HumanMessage(
                content=[
                    {"type": "text", "text": "Describe this image in detail."},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}},
                ]
            )
            return self.vision_llm.invoke([msg]).content
        except Exception as e:
            return f"Error processing image: {e}"

    def web_search(self, query: str) -> str:
        """Performs a web search using Tavily and returns a compilation of results."""
        key = f"web:{query}"
        results = self._cached_get(key, lambda: TavilySearchResults(max_results=5).invoke(query))
        # NOTE(review): the Tavily tool appears to return a plain error string
        # instead of a list when the request fails; pass that through as-is.
        if not isinstance(results, list):
            return str(results)
        return "\n\n---\n\n".join(
            f"Source: {res['url']}\nContent: {res['content']}" for res in results
        )

    def wiki_search(self, query: str) -> str:
        """Searches Wikipedia and returns the top 2 results."""
        key = f"wiki:{query}"
        docs = self._cached_get(
            key,
            lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load(),
        )
        return "\n\n---\n\n".join(
            f"Source: {d.metadata.get('source', 'N/A')}\n\n{d.page_content}" for d in docs
        )

    def arxiv_search(self, query: str) -> str:
        """Searches Arxiv for scientific papers and returns the top 2 results."""
        key = f"arxiv:{query}"
        docs = self._cached_get(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
        # Metadata is read defensively: with its default settings the loader
        # does not appear to set a 'source' key, so indexing it would raise.
        return "\n\n---\n\n".join(
            f"Source: {d.metadata.get('source') or d.metadata.get('entry_id', 'N/A')}\n"
            f"Published: {d.metadata.get('Published', 'N/A')}\n"
            f"Title: {d.metadata.get('Title', 'N/A')}\n\n"
            f"Summary:\n{d.page_content}"
            for d in docs
        )

    def get_tools(self) -> list:
        """Returns a list of all tools available to the agent.

        The five tool methods are wrapped as bound methods, so ``self`` is
        already captured and only the remaining parameters are exposed to the
        model. A retriever tool named ``retrieve_examples`` is appended last.
        """
        tools_list = [
            tool(method)
            for method in (
                self.python_repl,
                self.describe_image,
                self.web_search,
                self.wiki_search,
                self.arxiv_search,
            )
        ]
        retriever_tool = create_retriever_tool(
            retriever=self.retriever,
            name="retrieve_examples",
            description="Retrieve solved questions and answers similar to the user's query.",
        )
        tools_list.append(retriever_tool)
        return tools_list
# ----------------------------------------------------------
# Section 3: System Prompt
# ----------------------------------------------------------
# Instruction text for the model: how to work through a question and the
# exact "FINAL ANSWER: ..." format the last reply must use.
SYSTEM_PROMPT = (
    "You are a helpful and expert assistant designed to answer questions accurately and concisely.\n"
    "**Instructions:**\n"
    "1. **Analyze the Question:** Carefully understand what is being asked.\n"
    "2. **Use Tools:** You have a set of tools to find information. Use them logically.\n"
    "3. **Synthesize the Answer:** Based on the information from the tools, formulate your final answer.\n"
    "4. **Format the Output:** Your final response MUST be in the following format and nothing else:\n"
    "FINAL ANSWER: [Your concise and accurate answer here]\n"
    "If the `retrieve_examples` tool provides an answer to an identical question, use that answer. "
    "Otherwise, use your tools to find the correct answer for the current question.\n"
)
# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def create_agent_executor(provider: str = "google"):
    """Factory function to create and compile the LangGraph agent executor.

    The graph runs ``retriever`` once, then loops between ``assistant`` and
    ``tools`` until the assistant stops requesting tool calls.

    Args:
        provider (str): LLM provider name passed to ``MyAgent``.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """
    my_agent_instance = MyAgent(provider=provider)
    tools_list = my_agent_instance.get_tools()
    llm_with_tools = my_agent_instance.llm.bind_tools(tools_list)
    # Tag that marks the retrieved-examples message so it can be found again.
    example_sender = "ExampleRetriever"

    def retriever_node(state: MessagesState):
        """First node: retrieves examples and adds them, with the system prompt, to the state."""
        user_query = state["messages"][-1].content
        docs = my_agent_instance.retriever.invoke(user_query)
        additions = [SystemMessage(content=SYSTEM_PROMPT)]
        if docs:
            example_text = "\n\n---\n\n".join(d.page_content for d in docs)
            additions.append(
                AIMessage(
                    content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}",
                    name=example_sender,
                )
            )
        # Only the new messages are returned. The state's message reducer
        # merges by message id and appends unseen messages, so re-sending the
        # existing history cannot move these in front of it; the ordering the
        # model sees is established in assistant_node instead.
        return {"messages": additions}

    def _order_for_llm(history):
        """Return ``history`` as: system prompt(s), retrieved examples, then the dialogue."""
        system_msgs, example_msgs, dialogue = [], [], []
        for msg in history:
            if isinstance(msg, SystemMessage):
                system_msgs.append(msg)
            elif isinstance(msg, AIMessage) and getattr(msg, "name", None) == example_sender:
                example_msgs.append(msg)
            else:
                # Relative order is kept so tool calls stay next to their results.
                dialogue.append(msg)
        return system_msgs + example_msgs + dialogue

    def assistant_node(state: MessagesState):
        """Main assistant node: calls the LLM with the current state to decide the next action."""
        result = llm_with_tools.invoke(_order_for_llm(state["messages"]))
        return {"messages": [result]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # Route to the tool node when the assistant requested a tool call, otherwise finish.
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor created successfully.")
    return agent_executor
# ----------------------------------------------------------
# Section 5: Direct Execution Block for Testing
# ----------------------------------------------------------
if __name__ == "__main__":
    # Direct testing of the agent's logic: stream one question through the
    # graph and print the latest message produced by each node.
    print("--- Running Agent in Test Mode ---")
    agent = create_agent_executor(provider="google")
    question = "According to wikipedia, what is the main difference between a llama and an alpaca?"
    print(f"\nTest Question: {question}\n\n--- Agent Thinking... ---\n")
    for chunk in agent.stream({"messages": [("user", question)]}):
        for key, value in chunk.items():
            # A node update may carry no messages; skip it rather than fail.
            messages = (value or {}).get("messages")
            if not messages:
                continue
            message = messages[-1]
            if message.content:
                print(f"--- Node: {key} ---\n{message.content}\n")