# agent.py — commit d5ef142 ("Update agent.py", verified) by Abbasid; 11.8 kB.
# ----------------------------------------------------------
# Section 0: Imports
# ----------------------------------------------------------
import json
import os
import pickle
import re
import subprocess
import textwrap
import base64
import functools # Used to pre-fill arguments for our tool functions
from io import BytesIO
from pathlib import Path
# Third-party libraries
import requests
from cachetools import TTLCache
from PIL import Image
# LangChain and associated libraries
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import Tool, tool # Import Tool for manual tool creation
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
# Environment variable loading
from dotenv import load_dotenv
load_dotenv()
# ----------------------------------------------------------
# Section 1: Constants and Configuration
# ----------------------------------------------------------
# Retrieval corpus (one JSON record per line) and the on-disk cache of the
# FAISS index that is built from it.
JSONL_PATH: Path = Path("metadata.jsonl")
FAISS_CACHE: Path = Path("faiss_index.pkl")

# Sentence-transformers model used to embed the corpus, and how many
# neighbours the retriever returns per query.
EMBED_MODEL: str = "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K: int = 5

# Shared in-memory cache for external API results: at most 256 entries,
# each expiring CACHE_TTL seconds after it was stored.
CACHE_TTL: int = 600
API_CACHE: TTLCache = TTLCache(ttl=CACHE_TTL, maxsize=256)
# Global helper for caching API calls
def cached_get(key: str, fetch_fn, cache=None):
    """Return the cached value for *key*, calling ``fetch_fn()`` to fill it on a miss.

    Args:
        key: Cache key. The tool functions in this file namespace it by source,
            e.g. ``"web:<query>"``.
        fetch_fn: Zero-argument callable that produces the value. If it raises,
            the exception propagates and nothing is cached.
        cache: Mapping to use. Defaults to the module-level ``API_CACHE``.

    Returns:
        The cached value, or the freshly fetched (and now cached) value.
    """
    if cache is None:
        cache = API_CACHE
    # EAFP rather than `if key in cache: return cache[key]`: with a TTL cache an
    # entry can expire between the membership test and the lookup, and the
    # lookup would then raise KeyError. A single lookup has no such window.
    # NOTE(review): cachetools caches are not thread-safe; if tool calls are
    # executed concurrently, consider guarding this with a lock.
    try:
        return cache[key]
    except KeyError:
        pass
    value = fetch_fn()
    cache[key] = value
    return value
# ----------------------------------------------------------
# Section 2: Standalone Tool Functions (No 'self' parameter)
# ----------------------------------------------------------
@tool
def python_repl(code: str) -> str:
    """Executes a string of Python code in a fresh subprocess and returns the stdout/stderr. Print any values you need to see; execution times out after 10 seconds."""
    # (The docstring above is the tool description shown to the model.)
    import sys  # local import: only this tool needs the interpreter path

    # SECURITY: this runs model-generated code with the agent's own privileges
    # and no sandbox. Only expose this tool in a trusted/isolated environment.
    code = textwrap.dedent(code).strip()
    try:
        # sys.executable rather than "python" from PATH: the child then uses the
        # same interpreter (and installed packages) as the agent, and does not
        # fail when no "python" binary is on PATH.
        result = subprocess.run(
            [sys.executable, "-c", code],
            capture_output=True,
            text=True,
            timeout=10,
            check=False,
        )
    except subprocess.TimeoutExpired:
        return "Execution timed out (>10s)."
    if result.returncode == 0:
        return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
    return (
        f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\n"
        f"STDERR:\n```\n{result.stderr}\n```"
    )
# These functions now accept their dependencies (like an llm instance or a cache function) as arguments.
def describe_image_func(image_source: str, vision_llm_instance) -> str:
    """Describe an image from a local file path or an http(s) URL using a vision LLM.

    This is deliberately a plain function, not a ``@tool``: it is meant to have
    ``vision_llm_instance`` pre-bound with ``functools.partial`` and then be
    wrapped in a ``Tool``. A ``@tool``-decorated object is a tool instance, and
    calling it with the extra ``vision_llm_instance=`` keyword fails.

    Args:
        image_source: Local filesystem path, or a URL starting with
            ``http://`` or ``https://``.
        vision_llm_instance: Chat model whose ``invoke`` accepts a list with a
            multimodal ``HumanMessage`` and returns a message with ``.content``.

    Returns:
        The model's description, or ``"Error processing image: ..."`` if the
        download, decoding, or model call fails (errors are returned as text so
        the agent can read them).
    """
    try:
        if image_source.startswith(("http://", "https://")):
            response = requests.get(image_source, timeout=10)
            # Report HTTP errors directly instead of a confusing
            # "cannot identify image file" when an error page is downloaded.
            response.raise_for_status()
            source = BytesIO(response.content)
        else:
            source = image_source
        buffered = BytesIO()
        # Context manager closes the file handle Image.open keeps open.
        with Image.open(source) as img:
            img.convert("RGB").save(buffered, format="JPEG")
        b64_string = base64.b64encode(buffered.getvalue()).decode()
        msg = HumanMessage(
            content=[
                {"type": "text", "text": "Describe this image in detail."},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}},
            ]
        )
        return vision_llm_instance.invoke([msg]).content
    except Exception as e:  # tool boundary: surface any failure to the agent as text
        return f"Error processing image: {e}"
def web_search_func(query: str, cache_func) -> str:
    """Run a Tavily web search for *query* and format up to 5 results.

    Plain function (not a ``@tool``) so that ``cache_func`` can be pre-bound
    with ``functools.partial`` before it is wrapped in a ``Tool``.

    Args:
        query: Search query text.
        cache_func: ``cache_func(key, fetch_fn)`` returning a cached value or the
            result of ``fetch_fn()``. The key is ``"web:<query>"``.

    Returns:
        Results joined by ``---`` separators, each as ``Source: <url>`` plus
        ``Content: <text>``; ``"No results found."`` when the list is empty; or
        the raw string if the search tool returned text instead of a list.
    """
    results = cache_func(
        f"web:{query}",
        lambda: TavilySearchResults(max_results=5).invoke(query),
    )
    if isinstance(results, str):
        # NOTE(review): the Tavily tool appears to return an error description
        # as a plain string on API failure; iterating it would crash below.
        return results
    if not results:
        return "No results found."
    return "\n\n---\n\n".join(
        f"Source: {res['url']}\nContent: {res['content']}" for res in results
    )
def wiki_search_func(query: str, cache_func) -> str:
    """Search Wikipedia for *query* and format the top 2 pages (2000 chars each).

    Plain function (not a ``@tool``) so that ``cache_func`` can be pre-bound
    with ``functools.partial`` before it is wrapped in a ``Tool``.

    Args:
        query: Search query text.
        cache_func: ``cache_func(key, fetch_fn)`` returning a cached value or the
            result of ``fetch_fn()``. The key is ``"wiki:<query>"``.

    Returns:
        Pages joined by ``---`` separators, each as ``Source: <url>`` followed by
        the page text, or ``"No results found."`` when nothing was returned.
    """
    docs = cache_func(
        f"wiki:{query}",
        lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load(),
    )
    if not docs:
        return "No results found."
    return "\n\n---\n\n".join(
        f"Source: {d.metadata.get('source', 'unknown')}\n\n{d.page_content}" for d in docs
    )
def arxiv_search_func(query: str, cache_func) -> str:
    """Search arXiv for *query* and format the top 2 papers.

    Plain function (not a ``@tool``) so that ``cache_func`` can be pre-bound
    with ``functools.partial`` before it is wrapped in a ``Tool``.

    Args:
        query: Search query text.
        cache_func: ``cache_func(key, fetch_fn)`` returning a cached value or the
            result of ``fetch_fn()``. The key is ``"arxiv:<query>"``.

    Returns:
        Papers joined by ``---`` separators with Source/Published/Title/Summary
        fields, or ``"No results found."`` when nothing was returned.
    """
    docs = cache_func(f"arxiv:{query}", lambda: ArxivLoader(query=query, load_max_docs=2).load())
    if not docs:
        return "No results found."
    entries = []
    for d in docs:
        meta = d.metadata
        # ArxivLoader documents do not necessarily carry a "source" key (the
        # previous d.metadata['source'] raised KeyError); fall back to entry_id.
        source = meta.get("source") or meta.get("entry_id") or "arXiv"
        # page_content is the loaded paper text; prefer the real abstract when
        # the "Summary" metadata is present so the label is accurate.
        summary = meta.get("Summary") or d.page_content
        entries.append(
            f"Source: {source}\nPublished: {meta.get('Published', 'unknown')}\n"
            f"Title: {meta.get('Title', 'unknown')}\n\nSummary:\n{summary}"
        )
    return "\n\n---\n\n".join(entries)
# ----------------------------------------------------------
# Section 3: System Prompt
# ----------------------------------------------------------
# Instructions given to the model as a system message: a step-by-step tool-use
# process, the mandatory "FINAL ANSWER:" output format, and the fixed wording to
# use when the question needs a capability the tools lack (video, audio, ...).
SYSTEM_PROMPT = """You are an expert-level research assistant designed to answer questions accurately.
**Your Reasoning Process:**
1. **Think Step-by-Step:** Break down the user's question into logical steps. Plan which tools you need.
2. **Use Your Tools:** Execute your plan by calling one tool at a time. Analyze the results.
3. **Iterate:** If needed, use more tools until you have enough information.
4. **Synthesize:** Formulate a concise final answer based on the information.
**Output Format:**
- Your final response MUST strictly follow this format:
`FINAL ANSWER: [Your concise and accurate answer here]`
**Crucial Instructions:**
- If your tools **cannot possibly answer the question** (e.g., it requires watching a video or listening to audio), you MUST respond by stating the limitation.
- In case of a limitation, your response should be:
`FINAL ANSWER: I am unable to answer this question because it requires a capability I do not possess, such as [describe the missing capability].`
"""
# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def _build_main_llm(provider: str):
    """Return the assistant chat model for *provider* ("google", "groq" or "huggingface").

    Raises:
        ValueError: if *provider* is not one of the supported names.
    """
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
    if provider == "groq":
        return ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
    if provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1)
        )
    raise ValueError("Invalid provider selected")


def _load_or_build_vector_store():
    """Return the FAISS store of solved examples.

    Loads the pickled store from FAISS_CACHE when it exists. Otherwise embeds
    every record of JSONL_PATH with EMBED_MODEL, builds the index, and pickles it
    to FAISS_CACHE for the next run.
    """
    if FAISS_CACHE.exists():
        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only load a cache written by this program, never one from an
        # untrusted source.
        with open(FAISS_CACHE, "rb") as f:
            return pickle.load(f)
    # The embedding model is loaded only when the index must be built: a cached
    # store is unpickled together with the embedding object it was built with.
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    docs = []
    # `with` closes the file even if a record fails to parse.
    with open(JSONL_PATH, "rt", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # json.loads rejects blank lines
            rec = json.loads(line)
            docs.append(
                Document(
                    page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                    metadata={"source": rec["task_id"]},
                )
            )
    vector_store = FAISS.from_documents(docs, embeddings)
    with open(FAISS_CACHE, "wb") as f:
        pickle.dump(vector_store, f)
    return vector_store


def create_agent_executor(provider: str = "groq"):
    """Build and compile the LangGraph agent executor.

    Graph: START -> retriever -> assistant, then assistant <-> tools until the
    assistant answers without requesting a tool call.

    Args:
        provider: Backend for the assistant model: "google", "groq" (default)
            or "huggingface".

    Returns:
        The compiled graph; invoke or stream it with ``{"messages": [...]}``.

    Raises:
        ValueError: if *provider* is not recognised.
    """
    print(f"Initializing agent with provider: {provider}")

    # Step 1: Build LLMs.
    main_llm = _build_main_llm(provider)
    # NOTE: image description always uses this Groq model regardless of
    # *provider*, so Groq credentials are presumably needed for every provider.
    vision_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)

    # Step 2: Build the retriever over the solved-example corpus.
    vector_store = _load_or_build_vector_store()
    retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

    # Step 3: Create the final list of tools.
    # functools.partial "bakes in" each standalone function's dependencies (the
    # vision LLM or the cache helper), leaving a single string argument that the
    # Tool wrapper passes through from the model's call.
    tools_list = [
        python_repl,
        Tool(
            name="describe_image",
            func=functools.partial(describe_image_func, vision_llm_instance=vision_llm),
            description="Describes an image from a local file path or a URL.",
        ),
        Tool(
            name="web_search",
            func=functools.partial(web_search_func, cache_func=cached_get),
            description="Performs a web search using Tavily.",
        ),
        Tool(
            name="wiki_search",
            func=functools.partial(wiki_search_func, cache_func=cached_get),
            description="Searches Wikipedia.",
        ),
        Tool(
            name="arxiv_search",
            func=functools.partial(arxiv_search_func, cache_func=cached_get),
            description="Searches Arxiv for scientific papers.",
        ),
        create_retriever_tool(
            retriever=retriever,
            name="retrieve_examples",
            description="Retrieve solved questions similar to the user's query.",
        ),
    ]
    llm_with_tools = main_llm.bind_tools(tools_list)

    # Step 4: Define graph nodes.
    def retriever_node(state: MessagesState):
        """Emit one SystemMessage: SYSTEM_PROMPT plus solved examples similar to the question."""
        # Runs straight after START, so the last message is the user's input.
        user_query = state["messages"][-1].content
        docs = retriever.invoke(user_query)
        prompt = SYSTEM_PROMPT
        if docs:
            example_text = "\n\n---\n\n".join(d.page_content for d in docs)
            prompt += f"\n\nHere are {len(docs)} similar solved examples for reference:\n\n{example_text}"
        # Return only the new message. MessagesState's reducer merges by message
        # id: messages with known ids are replaced in place and new ones are
        # appended, so re-emitting the history cannot move anything in front of
        # the user's message.
        return {"messages": [SystemMessage(content=prompt)]}

    def assistant_node(state: MessagesState):
        """Call the tool-bound model on the conversation, with system messages first."""
        # In state the system message sits after the user's turn (see
        # retriever_node); chat providers expect it first (some ignore or reject
        # a system message elsewhere), so reorder for this call only.
        messages = state["messages"]
        ordered = [m for m in messages if isinstance(m, SystemMessage)]
        ordered += [m for m in messages if not isinstance(m, SystemMessage)]
        result = llm_with_tools.invoke(ordered)
        return {"messages": [result]}

    # Step 5: Build the graph.
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor created successfully.")
    return agent_executor
# ----------------------------------------------------------
# Section 5: Pre-flight check and Direct Execution Block
# ----------------------------------------------------------
def test_llm_connection(provider: str = "google"):
    """Check that *provider*'s chat model answers a trivial prompt.

    Uses the same model ids as ``create_agent_executor``, so a passing check
    means the agent's own model is reachable.

    Args:
        provider: "google" (default), "groq" or "huggingface".

    Returns:
        A human-readable status line (prefixed with ✅ or ❌) that is also
        printed. Connection errors are caught and reported, never raised.
    """
    print(f"--- Performing pre-flight check for LLM provider: {provider} ---")
    try:
        if provider == "google":
            llm, name = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest"), "Google Gemini"
        elif provider == "groq":
            # Same model the agent uses (previously "llama3-70b-8192" was checked,
            # which said nothing about the model the agent actually calls).
            llm, name = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct"), "Groq"
        elif provider == "huggingface":
            llm, name = (
                ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")),
                "Hugging Face",
            )
        else:
            return "❌ **LLM Status:** Invalid provider configured."
        llm.invoke("hello")
        success_message = f"✅ **LLM Status:** Connection to {name} is successful."
        print(success_message)
        return success_message
    except Exception as e:  # status-check boundary: report any failure as text
        error_message = f"❌ **LLM Status:** FAILED to connect. Check API keys/credits. Details: {e}"
        print(error_message)
        return error_message
if __name__ == "__main__":
    # Direct test of the agent's logic. Set AGENT_PROVIDER to "groq" or
    # "huggingface" to try another backend (default: "google").
    print("--- Running Agent in Test Mode ---")
    agent = create_agent_executor(provider=os.getenv("AGENT_PROVIDER", "google"))
    question = "According to wikipedia, what is the main difference between a llama and an alpaca?"
    print(f"\nTest Question: {question}\n\n--- Agent Thinking... ---\n")
    # Each streamed chunk maps the node that just ran to the update it returned.
    for chunk in agent.stream({"messages": [("user", question)]}):
        for node_name, update in chunk.items():
            messages = (update or {}).get("messages") or []
            if messages and messages[-1].content:
                print(f"--- Node: {node_name} ---\n{messages[-1].content}\n")