Abbasid committed on
Commit d5ef142 · verified · 1 Parent(s): 83ebef7

Update agent.py

Files changed (1)
  1. agent.py +120 -146
agent.py CHANGED
@@ -8,10 +8,9 @@ import re
  import subprocess
  import textwrap
  import base64
- from datetime import datetime, timedelta
  from io import BytesIO
  from pathlib import Path
- from typing import List

  # Third-party libraries
  import requests
@@ -23,9 +22,9 @@ from langchain.schema import Document
  from langchain.tools.retriever import create_retriever_tool
  from langchain_community.vectorstores import FAISS
  from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader # Added loaders
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
- from langchain_core.tools import tool
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
@@ -42,185 +41,143 @@ load_dotenv()
  JSONL_PATH = Path("metadata.jsonl")
  FAISS_CACHE = Path("faiss_index.pkl")
  EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
- RETRIEVER_K = 5 # Number of similar documents to retrieve
- CACHE_TTL = 600 # Cache API calls for 10 minutes
- # Global cache object for API calls
  API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)

  # ----------------------------------------------------------
- # Section 2: The Agent Class
  # ----------------------------------------------------------
- class MyAgent:
-     """
-     Encapsulates the agent's state, including LLMs, retriever, and tools.
-     This class-based approach ensures clean management of dependencies.
-     """

-     def __init__(self, provider: str = "groq"):
-         """
-         Initializes the agent, setting up LLMs and the FAISS retriever.
-         Args:
-             provider (str): The LLM provider to use ('google', 'groq', 'huggingface').
-         """
-         print(f"Initializing agent with provider: {provider}")
-
-         self.llm = self._build_llm(provider)
-         self.vision_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
-         self.retriever = self._get_retriever()
-
-     def _get_retriever(self):
-         """Builds or loads the FAISS retriever from a local cache."""
-         embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
-
-         if FAISS_CACHE.exists():
-             print(f"Loading FAISS index from cache: {FAISS_CACHE}")
-             with open(FAISS_CACHE, "rb") as f:
-                 vector_store = pickle.load(f)
-         else:
-             print("FAISS cache not found. Building new index from metadata.jsonl...")
-             if not JSONL_PATH.exists():
-                 raise FileNotFoundError(f"{JSONL_PATH} not found. Cannot build vector store.")
-             docs = [Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}", metadata={"source": rec["task_id"]}) for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))]
-             if not docs: raise ValueError("No documents found in metadata.jsonl.")
-             vector_store = FAISS.from_documents(docs, embeddings)
-             with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
-             print(f"FAISS index built and saved to cache: {FAISS_CACHE}")
-
-         return vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
-
-     def _build_llm(self, provider: str):
-         """Helper to build the main text-based LLM based on the chosen provider."""
-         if provider == "google": return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
-         elif provider == "groq": return ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
-         elif provider == "huggingface": return ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0))
-         else: raise ValueError("Provider must be 'google', 'groq', or 'huggingface'")
-
-     def _cached_get(self, key: str, fetch_fn):
-         """Helper for caching API calls."""
-         if key in API_CACHE: return API_CACHE[key]
-         val = fetch_fn()
-         API_CACHE[key] = val
-         return val
-
-     # --- Tool Definitions as Class Methods ---
-
-     @tool
-     def python_repl(self, code: str) -> str:
-         """Executes a string of Python code and returns the stdout/stderr."""
-         code = textwrap.dedent(code).strip()
-         try:
-             result = subprocess.run(["python", "-c", code], capture_output=True, text=True, timeout=10, check=False)
-             if result.returncode == 0: return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
-             else: return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
-         except subprocess.TimeoutExpired: return "Execution timed out (>10s)."
-
-     @tool
-     def describe_image(self, image_source: str) -> str:
-         """Describes an image from a local file path or a URL using Gemini vision."""
-         try:
-             if image_source.startswith("http"):
-                 img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
-             else:
-                 img = Image.open(image_source)
-             buffered = BytesIO()
-             img.convert("RGB").save(buffered, format="JPEG")
-             b64_string = base64.b64encode(buffered.getvalue()).decode()
-             msg = HumanMessage(content=[{"type": "text", "text": "Describe this image in detail."}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}])
-             return self.vision_llm.invoke([msg]).content
-         except Exception as e: return f"Error processing image: {e}"
-
-     @tool
-     def web_search(self, query: str) -> str:
-         """Performs a web search using Tavily and returns a compilation of results."""
-         key = f"web:{query}"
-         results = self._cached_get(key, lambda: TavilySearchResults(max_results=5).invoke(query))
-         return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])
-
-     @tool
-     def wiki_search(self, query: str) -> str:
-         """Searches Wikipedia and returns the top 2 results."""
-         key = f"wiki:{query}"
-         docs = self._cached_get(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
-         return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])
-
-     @tool
-     def arxiv_search(self, query: str) -> str:
-         """Searches Arxiv for scientific papers and returns the top 2 results."""
-         key = f"arxiv:{query}"
-         docs = self._cached_get(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
-         return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])
-
-     def get_tools(self) -> list:
-         """Returns a list of all tools available to the agent."""
-         tools_list = [
-             self.python_repl,
-             self.describe_image,
-             self.web_search,
-             self.wiki_search,
-             self.arxiv_search,
-         ]
-         retriever_tool = create_retriever_tool(
-             retriever=self.retriever,
-             name="retrieve_examples",
-             description="Retrieve solved questions and answers similar to the user's query.",
-         )
-         tools_list.append(retriever_tool)
-         return tools_list

  # ----------------------------------------------------------
  # Section 3: System Prompt
  # ----------------------------------------------------------
-
  SYSTEM_PROMPT = (
  """You are an expert-level research assistant designed to answer questions accurately.

  **Your Reasoning Process:**
- 1. **Think Step-by-Step:** Before answering, break down the user's question into a series of logical steps. Plan which tools you need to use for each step.
- 2. **Use Your Tools:** Execute your plan by calling one tool at a time. Analyze the results from the tool to see if you have enough information.
- 3. **Iterate if Necessary:** If the first tool call doesn't give you the full answer, continue to use other tools until you are confident you have all the necessary information.
- 4. **Synthesize the Final Answer:** Once you have gathered all the information, and only then, formulate a concise final answer.

  **Output Format:**
- - Your final response to the user MUST strictly follow this format and nothing else:
  `FINAL ANSWER: [Your concise and accurate answer here]`

  **Crucial Instructions:**
- - If the tools you have **cannot possibly answer the question** (e.g., the question asks you to listen to an audio file, watch a video, or read a local file you cannot access), you MUST respond by stating the limitation.
  - In case of a limitation, your response should be:
  `FINAL ANSWER: I am unable to answer this question because it requires a capability I do not possess, such as [describe the missing capability].`
-
- **Example of handling a limitation:**
- - User Question: "Please summarize the attached PDF."
- - Your Response: `FINAL ANSWER: I am unable to answer this question because it requires a capability I do not possess, such as reading local PDF files.`
  """
  )

  # ----------------------------------------------------------
  # Section 4: Factory Function for Agent Executor
  # ----------------------------------------------------------
- def create_agent_executor(provider: str = "google"):
-     """Factory function to create and compile the LangGraph agent executor."""
-     my_agent_instance = MyAgent(provider=provider)
-     tools_list = my_agent_instance.get_tools()
-     llm_with_tools = my_agent_instance.llm.bind_tools(tools_list)

      def retriever_node(state: MessagesState):
-         """First node: retrieves examples and prepends them to the message history."""
          user_query = state["messages"][-1].content
-         docs = my_agent_instance.retriever.invoke(user_query)
          messages = [SystemMessage(content=SYSTEM_PROMPT)]
          if docs:
              example_text = "\n\n---\n\n".join(d.page_content for d in docs)
-             example_msg = AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever")
-             messages.append(example_msg)
          messages.extend(state["messages"])
          return {"messages": messages}

      def assistant_node(state: MessagesState):
-         """Main assistant node: calls the LLM with the current state to decide the next action."""
          result = llm_with_tools.invoke(state["messages"])
          return {"messages": [result]}

      builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever_node)
      builder.add_node("assistant", assistant_node)
@@ -236,10 +193,27 @@ def create_agent_executor(provider: str = "google"):
      return agent_executor

  # ----------------------------------------------------------
- # Section 5: Direct Execution Block for Testing
  # ----------------------------------------------------------
  if __name__ == "__main__":
-     """direct testing of the agent's logic."""
      print("--- Running Agent in Test Mode ---")
      agent = create_agent_executor(provider="google")
      question = "According to wikipedia, what is the main difference between a lama and an alpaca?"

  import subprocess
  import textwrap
  import base64
+ import functools # Used to pre-fill arguments for our tool functions
  from io import BytesIO
  from pathlib import Path

  # Third-party libraries
  import requests

  from langchain.tools.retriever import create_retriever_tool
  from langchain_community.vectorstores import FAISS
  from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
+ from langchain_core.tools import Tool, tool # Import Tool for manual tool creation
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
  from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace

  JSONL_PATH = Path("metadata.jsonl")
  FAISS_CACHE = Path("faiss_index.pkl")
  EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
+ RETRIEVER_K = 5
+ CACHE_TTL = 600
  API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)

+ # Global helper for caching API calls
+ def cached_get(key: str, fetch_fn):
+     if key in API_CACHE: return API_CACHE[key]
+     val = fetch_fn()
+     API_CACHE[key] = val
+     return val
+
  # ----------------------------------------------------------
+ # Section 2: Standalone Tool Functions (No 'self' parameter)
  # ----------------------------------------------------------

+ @tool
+ def python_repl(code: str) -> str:
+     """Executes a string of Python code and returns the stdout/stderr."""
+     code = textwrap.dedent(code).strip()
+     try:
+         result = subprocess.run(["python", "-c", code], capture_output=True, text=True, timeout=10, check=False)
+         if result.returncode == 0: return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
+         else: return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
+     except subprocess.TimeoutExpired: return "Execution timed out (>10s)."
+
+ # These functions now accept their dependencies (like an llm instance or a cache function) as arguments.
+ @tool
+ def describe_image_func(image_source: str, vision_llm_instance) -> str:
+     """Describes an image from a local file path or a URL using a provided vision LLM."""
+     try:
+         if image_source.startswith("http"): img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
+         else: img = Image.open(image_source)
+         buffered = BytesIO()
+         img.convert("RGB").save(buffered, format="JPEG")
+         b64_string = base64.b64encode(buffered.getvalue()).decode()
+         msg = HumanMessage(content=[{"type": "text", "text": "Describe this image in detail."}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}])
+         return vision_llm_instance.invoke([msg]).content
+     except Exception as e: return f"Error processing image: {e}"
+ @tool
+ def web_search_func(query: str, cache_func) -> str:
+     """Performs a web search using Tavily and returns a compilation of results."""
+     key = f"web:{query}"
+     results = cache_func(key, lambda: TavilySearchResults(max_results=5).invoke(query))
+     return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])
+ @tool
+ def wiki_search_func(query: str, cache_func) -> str:
+     """Searches Wikipedia and returns the top 2 results."""
+     key = f"wiki:{query}"
+     docs = cache_func(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
+     return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])
+ @tool
+ def arxiv_search_func(query: str, cache_func) -> str:
+     """Searches Arxiv for scientific papers and returns the top 2 results."""
+     key = f"arxiv:{query}"
+     docs = cache_func(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
+     return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])

  # ----------------------------------------------------------
  # Section 3: System Prompt
  # ----------------------------------------------------------
  SYSTEM_PROMPT = (
  """You are an expert-level research assistant designed to answer questions accurately.

  **Your Reasoning Process:**
+ 1. **Think Step-by-Step:** Break down the user's question into logical steps. Plan which tools you need.
+ 2. **Use Your Tools:** Execute your plan by calling one tool at a time. Analyze the results.
+ 3. **Iterate:** If needed, use more tools until you have enough information.
+ 4. **Synthesize:** Formulate a concise final answer based on the information.

  **Output Format:**
+ - Your final response MUST strictly follow this format:
  `FINAL ANSWER: [Your concise and accurate answer here]`

  **Crucial Instructions:**
+ - If your tools **cannot possibly answer the question** (e.g., it requires watching a video or listening to audio), you MUST respond by stating the limitation.
  - In case of a limitation, your response should be:
  `FINAL ANSWER: I am unable to answer this question because it requires a capability I do not possess, such as [describe the missing capability].`
  """
  )

  # ----------------------------------------------------------
  # Section 4: Factory Function for Agent Executor
  # ----------------------------------------------------------
+ def create_agent_executor(provider: str = "groq"):
+     """
+     Factory function to create and compile the LangGraph agent executor.
+     This version creates tools from standalone functions to ensure model compatibility.
+     """
+     print(f"Initializing agent with provider: {provider}")
+
+     # Step 1: Build LLMs
+     if provider == "google": main_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+     elif provider == "groq": main_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
+     elif provider == "huggingface": main_llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1))
+     else: raise ValueError("Invalid provider selected")
+     vision_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
+
+     # Step 2: Build Retriever
+     embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
+     if FAISS_CACHE.exists():
+         with open(FAISS_CACHE, "rb") as f: vector_store = pickle.load(f)
+     else:
+         docs = [Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}", metadata={"source": rec["task_id"]}) for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))]
+         vector_store = FAISS.from_documents(docs, embeddings)
+         with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
+     retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
+
+     # Step 3: Create the final list of tools
+     # We use functools.partial to "bake in" the dependencies (like the LLM or cache) into our standalone functions.
+     # This creates new functions with a simpler signature that the agent can easily call.
+     tools_list = [
+         python_repl,
+         Tool(name="describe_image", func=functools.partial(describe_image_func, vision_llm_instance=vision_llm), description="Describes an image from a local file path or a URL."),
+         Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
+         Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
+         Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
+         create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
+     ]
+
+     llm_with_tools = main_llm.bind_tools(tools_list)

+     # Step 4: Define Graph Nodes
      def retriever_node(state: MessagesState):
          user_query = state["messages"][-1].content
+         docs = retriever.invoke(user_query)
          messages = [SystemMessage(content=SYSTEM_PROMPT)]
          if docs:
              example_text = "\n\n---\n\n".join(d.page_content for d in docs)
+             messages.append(AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever"))
          messages.extend(state["messages"])
          return {"messages": messages}

      def assistant_node(state: MessagesState):
          result = llm_with_tools.invoke(state["messages"])
          return {"messages": [result]}

+     # Step 5: Build Graph
      builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever_node)
      builder.add_node("assistant", assistant_node)

      return agent_executor

  # ----------------------------------------------------------
+ # Section 5: Pre-flight check and Direct Execution Block
  # ----------------------------------------------------------
+ def test_llm_connection(provider: str = "google"):
+     """Performs a quick test to see if the LLM provider is accessible."""
+     print(f"--- Performing pre-flight check for LLM provider: {provider} ---")
+     try:
+         if provider == "google": llm, name = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest"), "Google Gemini"
+         elif provider == "groq": llm, name = ChatGroq(model_name="llama3-70b-8192"), "Groq"
+         elif provider == "huggingface": llm, name = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1")), "Hugging Face"
+         else: return "❌ **LLM Status:** Invalid provider configured."
+         llm.invoke("hello")
+         success_message = f"✅ **LLM Status:** Connection to {name} is successful."
+         print(success_message)
+         return success_message
+     except Exception as e:
+         error_message = f"❌ **LLM Status:** FAILED to connect. Check API keys/credits. Details: {e}"
+         print(error_message)
+         return error_message
+
  if __name__ == "__main__":
+     """Allows for direct testing of the agent's logic."""
      print("--- Running Agent in Test Mode ---")
      agent = create_agent_executor(provider="google")
      question = "According to wikipedia, what is the main difference between a lama and an alpaca?"