# ----------------------------------------------------------
# Section 0: Imports
# ----------------------------------------------------------
import json
import os
import pickle
import re
import subprocess
import sys
import textwrap
import base64
from datetime import datetime, timedelta
from io import BytesIO
from pathlib import Path
from typing import List

# Third-party libraries
import requests
from cachetools import TTLCache
from PIL import Image

# LangChain and associated libraries
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader  # Added loaders
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

# Environment variable loading
from dotenv import load_dotenv

load_dotenv()

# ----------------------------------------------------------
# Section 1: Constants and Configuration
# ----------------------------------------------------------
JSONL_PATH = Path("metadata.jsonl")
# Holds the bytes produced by FAISS.serialize_to_bytes(); written by this module only.
FAISS_CACHE = Path("faiss_index.pkl")
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K = 5  # Number of similar documents to retrieve
CACHE_TTL = 600  # Cache API calls for 10 minutes

# Global cache object for API calls (shared by every MyAgent instance in the process).
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)


# ----------------------------------------------------------
# Section 2: The Agent Class
# ----------------------------------------------------------
class MyAgent:
    """
    Encapsulates the agent's state, including LLMs, retriever, and tools.

    The tool implementations are ordinary methods; ``get_tools`` wraps them
    into LangChain tools bound to this instance.
    """

    def __init__(self, provider: str = "google"):
        """
        Initializes the agent, setting up LLMs and the FAISS retriever.

        Args:
            provider (str): The LLM provider to use ('google', 'groq', 'huggingface').

        Raises:
            ValueError: If ``provider`` is not one of the supported names.
        """
        print(f"Initializing agent with provider: {provider}")
        self.llm = self._build_llm(provider)
        # Image description always goes through Gemini, regardless of ``provider``.
        self.vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
        self.retriever = self._get_retriever()

    @staticmethod
    def _load_documents() -> List[Document]:
        """Read metadata.jsonl into one Document per solved question.

        Raises:
            FileNotFoundError: If JSONL_PATH does not exist.
            ValueError: If the file contains no records.
        """
        if not JSONL_PATH.exists():
            raise FileNotFoundError(f"{JSONL_PATH} not found. Cannot build vector store.")
        docs = []
        # ``with`` guarantees the file handle is closed (the old generator leaked it).
        with JSONL_PATH.open("rt", encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():  # tolerate blank / trailing lines
                    continue
                rec = json.loads(line)
                docs.append(
                    Document(
                        page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                        metadata={"source": rec["task_id"]},
                    )
                )
        if not docs:
            raise ValueError("No documents found in metadata.jsonl.")
        return docs

    def _get_retriever(self):
        """Builds or loads the FAISS retriever from a local cache.

        The cache stores ``FAISS.serialize_to_bytes()`` output, which the
        LangChain FAISS wrapper reloads together with a fresh embeddings object.
        An unreadable or old-format cache is discarded and rebuilt.
        """
        embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
        vector_store = None
        if FAISS_CACHE.exists():
            print(f"Loading FAISS index from cache: {FAISS_CACHE}")
            try:
                vector_store = FAISS.deserialize_from_bytes(
                    serialized=FAISS_CACHE.read_bytes(),
                    embeddings=embeddings,
                    # SECURITY: this unpickles the file. It is only safe because the
                    # file is written by this module below; never point FAISS_CACHE
                    # at data from an untrusted source.
                    allow_dangerous_deserialization=True,
                )
            except Exception as err:  # best-effort cache: fall back to a rebuild
                print(f"Could not load FAISS cache ({err}); rebuilding index.")
                vector_store = None
        if vector_store is None:
            print("FAISS cache not found. Building new index from metadata.jsonl...")
            vector_store = FAISS.from_documents(self._load_documents(), embeddings)
            FAISS_CACHE.write_bytes(vector_store.serialize_to_bytes())
            print(f"FAISS index built and saved to cache: {FAISS_CACHE}")
        return vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

    def _build_llm(self, provider: str):
        """Helper to build the main text-based LLM based on the chosen provider."""
        if provider == "google":
            return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
        elif provider == "groq":
            return ChatGroq(model_name="llama3-70b-8192", temperature=0)
        elif provider == "huggingface":
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0)
            )
        else:
            raise ValueError("Provider must be 'google', 'groq', or 'huggingface'")

    def _cached_get(self, key: str, fetch_fn):
        """Return ``API_CACHE[key]``, computing and storing ``fetch_fn()`` on a miss."""
        # EAFP: a TTL entry can expire between an ``in`` check and the lookup.
        try:
            return API_CACHE[key]
        except KeyError:
            pass
        val = fetch_fn()
        API_CACHE[key] = val
        return val

    # --- Tool implementations (wrapped as LangChain tools in get_tools) ---

    def python_repl(self, code: str) -> str:
        """Executes a string of Python code and returns the stdout/stderr."""
        # SECURITY: this runs model-generated code with this process's privileges;
        # the only limit is the 10s timeout. Sandbox it before exposing the agent
        # to untrusted input.
        code = textwrap.dedent(code).strip()
        try:
            # sys.executable: run the same interpreter, not whatever "python" is on PATH.
            result = subprocess.run(
                [sys.executable, "-c", code],
                capture_output=True,
                text=True,
                timeout=10,
                check=False,
            )
        except subprocess.TimeoutExpired:
            return "Execution timed out (>10s)."
        if result.returncode == 0:
            return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
        return (
            f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```"
            f"\nSTDERR:\n```\n{result.stderr}\n```"
        )

    def describe_image(self, image_source: str) -> str:
        """Describes an image from a local file path or a URL using Gemini vision.

        Errors are returned as text (not raised) so the LLM can see what failed.
        """
        try:
            if image_source.startswith(("http://", "https://")):
                response = requests.get(image_source, timeout=10)
                response.raise_for_status()  # don't try to decode an HTML error page
                source = BytesIO(response.content)
            else:
                source = image_source
            buffered = BytesIO()
            with Image.open(source) as img:
                img.convert("RGB").save(buffered, format="JPEG")
            b64_string = base64.b64encode(buffered.getvalue()).decode()
            msg = HumanMessage(
                content=[
                    {"type": "text", "text": "Describe this image in detail."},
                    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}},
                ]
            )
            return self.vision_llm.invoke([msg]).content
        except Exception as e:
            return f"Error processing image: {e}"

    def web_search(self, query: str) -> str:
        """Performs a web search using Tavily and returns a compilation of results."""
        key = f"web:{query}"
        results = self._cached_get(key, lambda: TavilySearchResults(max_results=5).invoke(query))
        if not isinstance(results, list):
            # NOTE(review): TavilySearchResults appears to report failures as a plain
            # string; iterating it as results would crash. Return it, and don't cache it.
            API_CACHE.pop(key, None)
            return str(results)
        return "\n\n---\n\n".join(
            f"Source: {res.get('url', '')}\nContent: {res.get('content', '')}" for res in results
        )

    def wiki_search(self, query: str) -> str:
        """Searches Wikipedia and returns the top 2 results."""
        key = f"wiki:{query}"
        docs = self._cached_get(
            key,
            lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load(),
        )
        return "\n\n---\n\n".join(
            f"Source: {d.metadata.get('source', 'Wikipedia')}\n\n{d.page_content}" for d in docs
        )

    def arxiv_search(self, query: str) -> str:
        """Searches Arxiv for scientific papers and returns the top 2 results."""
        key = f"arxiv:{query}"
        docs = self._cached_get(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())

        def _format(d: Document) -> str:
            meta = d.metadata
            # ArxivLoader does not set a 'source' key by default, so fall back
            # gracefully instead of raising KeyError.
            return (
                f"Source: {meta.get('entry_id') or meta.get('source', 'arXiv')}\n"
                f"Published: {meta.get('Published', 'unknown')}\n"
                f"Title: {meta.get('Title', 'unknown')}\n\n"
                f"Summary:\n{meta.get('Summary') or d.page_content}"
            )

        return "\n\n---\n\n".join(_format(d) for d in docs)

    def get_tools(self) -> list:
        """Returns a list of all tools available to the agent.

        Each tool is a closure over this instance. Applying ``@tool`` directly to
        a method would put ``self`` into the tool's argument schema and create a
        class-level tool with no access to this instance's LLMs or cache.
        """

        @tool
        def python_repl(code: str) -> str:
            """Executes a string of Python code and returns the stdout/stderr."""
            return self.python_repl(code)

        @tool
        def describe_image(image_source: str) -> str:
            """Describes an image from a local file path or a URL using Gemini vision."""
            return self.describe_image(image_source)

        @tool
        def web_search(query: str) -> str:
            """Performs a web search using Tavily and returns a compilation of results."""
            return self.web_search(query)

        @tool
        def wiki_search(query: str) -> str:
            """Searches Wikipedia and returns the top 2 results."""
            return self.wiki_search(query)

        @tool
        def arxiv_search(query: str) -> str:
            """Searches Arxiv for scientific papers and returns the top 2 results."""
            return self.arxiv_search(query)

        retriever_tool = create_retriever_tool(
            retriever=self.retriever,
            name="retrieve_examples",
            description="Retrieve solved questions and answers similar to the user's query.",
        )
        return [python_repl, describe_image, web_search, wiki_search, arxiv_search, retriever_tool]


# ----------------------------------------------------------
# Section 3: System Prompt
# ----------------------------------------------------------
SYSTEM_PROMPT = (
    """You are a helpful and expert assistant designed to answer questions accurately and concisely.

**Instructions:**
1. **Analyze the Question:** Carefully understand what is being asked.
2. **Use Tools:** You have a set of tools to find information. Use them logically.
3. **Synthesize the Answer:** Based on the information from the tools, formulate your final answer.
4. **Format the Output:** Your final response MUST be in the following format and nothing else:
FINAL ANSWER: [Your concise and accurate answer here]

If the `retrieve_examples` tool provides an answer to an identical question, use that answer.
Otherwise, use your tools to find the correct answer for the current question.
"""
)


# ----------------------------------------------------------
# Section 4: Factory Function for Agent Executor
# ----------------------------------------------------------
def create_agent_executor(provider: str = "google"):
    """Factory function to create and compile the LangGraph agent executor.

    Graph: START -> retriever -> assistant <-> tools.

    Args:
        provider: LLM provider name passed to ``MyAgent``.

    Returns:
        The compiled LangGraph graph (accepts ``{"messages": [...]}``).
    """
    my_agent_instance = MyAgent(provider=provider)
    tools_list = my_agent_instance.get_tools()
    llm_with_tools = my_agent_instance.llm.bind_tools(tools_list)

    def retriever_node(state: MessagesState):
        """First node: retrieve similar solved examples and emit the system prompt.

        The ``add_messages`` reducer appends the returned SystemMessage after the
        existing history. ``assistant_node`` therefore moves system messages to
        the front before calling the LLM.
        """
        content = state["messages"][-1].content
        user_query = content if isinstance(content, str) else str(content)
        docs = my_agent_instance.retriever.invoke(user_query)
        system_text = SYSTEM_PROMPT
        if docs:
            example_text = "\n\n---\n\n".join(d.page_content for d in docs)
            system_text += f"\n\nHere are {len(docs)} similar solved examples:\n\n{example_text}"
        return {"messages": [SystemMessage(content=system_text)]}

    def assistant_node(state: MessagesState):
        """Main assistant node: calls the LLM with the current state to decide the next action."""
        history = state["messages"]
        system_parts = [str(m.content) for m in history if isinstance(m, SystemMessage)]
        conversation = [m for m in history if not isinstance(m, SystemMessage)]
        # Providers expect one system message at the very start of the prompt.
        prompt = ([SystemMessage(content="\n\n".join(system_parts))] if system_parts else []) + conversation
        result = llm_with_tools.invoke(prompt)
        return {"messages": [result]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor created successfully.")
    return agent_executor


# ----------------------------------------------------------
# Section 5: Direct Execution Block for Testing
# ----------------------------------------------------------
if __name__ == "__main__":
    # Direct testing of the agent's logic.
    print("--- Running Agent in Test Mode ---")
    agent = create_agent_executor(provider="google")
    question = "According to wikipedia, what is the main difference between a lama and an alpaca?"
    print(f"\nTest Question: {question}\n\n--- Agent Thinking... ---\n")
    for chunk in agent.stream({"messages": [("user", question)]}):
        for key, value in chunk.items():
            messages = (value or {}).get("messages")
            if messages:
                message = messages[-1]
                if message.content:
                    print(f"--- Node: {key} ---\n{message.content}\n")