Abbasid committed
Commit e5bb694 · verified · 1 Parent(s): d490d45

Update agent.py

Files changed (1)
  1. agent.py +204 -215
agent.py CHANGED
@@ -1,253 +1,242 @@
 
 
 
  import json
  import os
  import pickle
  import re
  from datetime import datetime, timedelta
  from io import BytesIO
  from pathlib import Path
  from typing import List

  import requests
  from cachetools import TTLCache
  from langchain.schema import Document
  from langchain_community.vectorstores import FAISS
- from langchain_huggingface import HuggingFaceEmbeddings
  from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_core.messages import SystemMessage, HumanMessage, AIMessageChunk
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import ToolNode, tools_condition
- from langchain_core.tools import tool
- from dotenv import load_dotenv

  load_dotenv()

  # ----------------------------------------------------------
- # 0. Constants
  # ----------------------------------------------------------
  JSONL_PATH = Path("metadata.jsonl")
  FAISS_CACHE = Path("faiss_index.pkl")
  EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
- RETRIEVER_K = 5
- CACHE_TTL = 600
- CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
-
- # ----------------------------------------------------------
- # 1. Build / load FAISS retriever
- # ----------------------------------------------------------
- embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
-
- if FAISS_CACHE.exists():
-     with open(FAISS_CACHE, "rb") as f:
-         vector_store = pickle.load(f)
- else:
-     if not JSONL_PATH.exists():
-         raise FileNotFoundError("metadata.jsonl not found")
-     docs = []
-     with open(JSONL_PATH, "rt", encoding="utf-8") as f:
-         for line in f:
-             rec = json.loads(line)
-             content = f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}"
-             docs.append(Document(page_content=content, metadata={"source": rec["task_id"]}))
-     vector_store = FAISS.from_documents(docs, embeddings)
-     with open(FAISS_CACHE, "wb") as f:
-         pickle.dump(vector_store, f)
-
- retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
-
- # ----------------------------------------------------------
- # 2. Caching helper
- # ----------------------------------------------------------
- def cached_get(key: str, fetch_fn):
-     if key in CACHE:
-         return CACHE[key]
-     val = fetch_fn()
-     CACHE[key] = val
-     return val
-
- # ----------------------------------------------------------
- # 3. Tools
- # ----------------------------------------------------------
- @tool
- def python_repl(code: str) -> str:
-     """Execute Python code and return stdout/stderr."""
-     import subprocess, textwrap
-     code = textwrap.dedent(code).strip()
-     try:
-         result = subprocess.run(
-             ["python", "-c", code],
-             capture_output=True,
-             text=True,
-             timeout=5,
-         )
-         return result.stdout if not result.stderr else f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
-     except subprocess.TimeoutExpired:
-         return "Execution timed out (>5s)."
-
- @tool
- def describe_image(image_source: str) -> str:
-     """Describe an image from local path or URL with Gemini vision."""
-     import base64
-     from PIL import Image
-
-     if image_source.startswith("http"):
-         img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
-     else:
-         img = Image.open(image_source)
-
-     buffered = BytesIO()
-     img.convert("RGB").save(buffered, format="JPEG")
-     b64 = base64.b64encode(buffered.getvalue()).decode()
-
-     llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-     msg = HumanMessage(
-         content=[
-             {"type": "text", "text": "Describe this image in detail."},
-             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
          ]
-     )
-     return llm.invoke([msg]).content
-
- @tool
- def web_search(query: str) -> str:
-     """Smart web search with 3 keyword variants, cached."""
-     from langchain_community.tools.tavily_search import TavilySearchResults
-
-     keywords = [query, query.replace(" ", " OR "), f'"{query}"']
-     seen = set()
-     results = []
-     for kw in keywords:
-         key = f"web:{kw}"
-         snippets = cached_get(
-             key,
-             lambda: TavilySearchResults(max_results=3, include_raw_content=True).invoke(kw),
          )
-         for s in snippets:
-             if s["url"] not in seen:
-                 seen.add(s["url"])
-                 results.append(s["content"][:2000])
-             if len(results) >= 5:
-                 break
-     return "\n\n---\n\n".join(results)
-
- @tool
- def wiki_search(query: str) -> str:
-     from langchain_community.document_loaders import WikipediaLoader
-     key = f"wiki:{query}"
-     docs = cached_get(
-         key,
-         lambda: WikipediaLoader(query=query, load_max_docs=2).load(),
-     )
-     return "\n\n---\n\n".join(
-         f'<Document source="{d.metadata.get("source", "")}">\n{d.page_content}\n</Document>'
-         for d in docs
-     )
-
- @tool
- def arxiv_search(query: str) -> str:
-     from langchain_community.document_loaders import ArxivLoader
-     key = f"arxiv:{query}"
-     docs = cached_get(
-         key,
-         lambda: ArxivLoader(query=query, load_max_docs=2).load(),
-     )
-     return "\n\n---\n\n".join(
-         f'<Document source="{d.metadata.get("source", "")}">\n{d.page_content[:2000]}...\n</Document>'
-         for d in docs
-     )

  # ----------------------------------------------------------
- # 4. System prompt
  # ----------------------------------------------------------
  SYSTEM_PROMPT = (
- """You are a helpful assistant tasked with answering questions using a set of tools.
-
- Your final answer must strictly follow this format:
- FINAL ANSWER: [ANSWER]
-
- Only write the answer in that exact format. Do not explain anything. Do not include any other text.

- If you are provided with a similar question and its final answer, and the current question is **exactly the same**, then simply return the same final answer without using any tools.

- Only use tools if the current question is different from the similar one.
-
- Examples:
- - FINAL ANSWER: FunkMonk
- - FINAL ANSWER: Paris
- - FINAL ANSWER: 128
- If you do not follow this format exactly, your response will be considered incorrect.
  """
  )

  # ----------------------------------------------------------
- # 5. Manual LangGraph construction
- # ----------------------------------------------------------
- tools_list = [python_repl, describe_image, web_search, wiki_search, arxiv_search]
-
- # retriever tool
- from langchain.tools.retriever import create_retriever_tool
- tools_list.append(
-     create_retriever_tool(
-         retriever=retriever,
-         name="retrieve_examples",
-         description="Retrieve up to 5 solved questions similar to the user query.",
-     )
- )
-
- # ----------------------------------------------------------
- # provider switcher
- # ----------------------------------------------------------
- def build_llm(provider: str = "groq"):
-     if provider == "google":
-         return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-     elif provider == "groq":
-         return ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
-     elif provider == "huggingface":
-         return ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
-                 temperature=0,
-             )
-         )
-     else:
-         raise ValueError("provider must be 'google', 'groq', or 'huggingface'")
-
- llm = build_llm("google")  # or "groq", "huggingface"
- llm_with_tools = llm.bind_tools(tools_list)
-
- def assistant(state: MessagesState):
-     """LLM node that can call tools."""
-     return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
- def retriever_node(state: MessagesState):
-     """First node: fetch examples and prepend them."""
-     user_query = state["messages"][-1].content
-     docs = retriever.invoke(user_query)
-     if docs:
-         example_text = "\n\n---\n\n".join(d.page_content for d in docs)
-         example_msg = HumanMessage(
-             content=f"Here are {len(docs)} similar solved examples:\n\n{example_text}"
-         )
-         return {"messages": [SYSTEM_PROMPT] + state["messages"] + [example_msg]}
-     return {"messages": [SYSTEM_PROMPT] + state["messages"]}
-
- builder = StateGraph(MessagesState)
- builder.add_node("retriever", retriever_node)
- builder.add_node("assistant", assistant)
- builder.add_node("tools", ToolNode(tools_list))
- builder.add_edge(START, "retriever")
- builder.add_edge("retriever", "assistant")
- builder.add_conditional_edges("assistant", tools_condition)
- builder.add_edge("tools", "assistant")
-
- agent = builder.compile()
-
- # ----------------------------------------------------------
- # 6. Quick streaming test
  # ----------------------------------------------------------
  if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     print("Agent thinking ")
-     for chunk in agent.stream({"messages": [("user", question)]}, stream_mode="values"):
-         last = chunk["messages"][-1]
-         if hasattr(last, "content"):
-             print(last.content, end="", flush=True)
+ # ----------------------------------------------------------
+ # Section 0: Imports
+ # ----------------------------------------------------------
  import json
  import os
  import pickle
  import re
+ import subprocess
+ import textwrap
+ import base64
  from datetime import datetime, timedelta
  from io import BytesIO
  from pathlib import Path
  from typing import List

+ # Third-party libraries
  import requests
  from cachetools import TTLCache
+ from PIL import Image
+
+ # LangChain and associated libraries
  from langchain.schema import Document
+ from langchain.tools.retriever import create_retriever_tool
  from langchain_community.vectorstores import FAISS
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
+ from langchain_core.tools import tool
  from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import ToolNode, tools_condition

+ # Environment variable loading
+ from dotenv import load_dotenv
  load_dotenv()

  # ----------------------------------------------------------
+ # Section 1: Constants and Configuration
  # ----------------------------------------------------------
  JSONL_PATH = Path("metadata.jsonl")
  FAISS_CACHE = Path("faiss_index.pkl")
  EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
+ RETRIEVER_K = 5  # Number of similar documents to retrieve
+ CACHE_TTL = 600  # Cache API calls for 10 minutes
+ # Global cache object for API calls
+ API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
+
+ # ----------------------------------------------------------
+ # Section 2: The Agent Class
+ # ----------------------------------------------------------
+ class MyAgent:
+     """
+     Encapsulates the agent's state, including LLMs, retriever, and tools.
+     This class-based approach ensures clean management of dependencies.
+     """
+
+     def __init__(self, provider: str = "google"):
+         """
+         Initializes the agent, setting up LLMs and the FAISS retriever.
+
+         Args:
+             provider (str): The LLM provider to use ('google', 'groq', 'huggingface').
+         """
+         print(f"Initializing agent with provider: {provider}")
+
+         self.llm = self._build_llm(provider)
+         self.vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
+         self.retriever = self._get_retriever()
+
+     def _get_retriever(self):
+         """Builds or loads the FAISS retriever from a local cache."""
+         embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
+
+         if FAISS_CACHE.exists():
+             print(f"Loading FAISS index from cache: {FAISS_CACHE}")
+             with open(FAISS_CACHE, "rb") as f:
+                 vector_store = pickle.load(f)
+         else:
+             print("FAISS cache not found. Building new index from metadata.jsonl...")
+             if not JSONL_PATH.exists():
+                 raise FileNotFoundError(f"{JSONL_PATH} not found. Cannot build vector store.")
+             # Build one Document per JSONL record; the with-block closes the file handle.
+             with open(JSONL_PATH, "rt", encoding="utf-8") as f:
+                 docs = [
+                     Document(
+                         page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
+                         metadata={"source": rec["task_id"]},
+                     )
+                     for rec in (json.loads(line) for line in f)
+                 ]
+             if not docs:
+                 raise ValueError("No documents found in metadata.jsonl.")
+             vector_store = FAISS.from_documents(docs, embeddings)
+             with open(FAISS_CACHE, "wb") as f:
+                 pickle.dump(vector_store, f)
+             print(f"FAISS index built and saved to cache: {FAISS_CACHE}")
+
+         return vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
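# [Editor's note — not part of this commit] Pickling the whole vector store
# works, but it ties the cache file to the current library versions. FAISS in
# langchain_community also exposes save_local()/load_local(); a minimal sketch,
# assuming the same `embeddings` object and a hypothetical "faiss_index"
# directory:
#
#     if Path("faiss_index").exists():
#         vector_store = FAISS.load_local(
#             "faiss_index",
#             embeddings,
#             allow_dangerous_deserialization=True,  # loader is still pickle-backed
#         )
#     else:
#         vector_store = FAISS.from_documents(docs, embeddings)
#         vector_store.save_local("faiss_index")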
+
+     def _build_llm(self, provider: str):
+         """Helper to build the main text-based LLM based on the chosen provider."""
+         if provider == "google":
+             return ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+         elif provider == "groq":
+             return ChatGroq(model_name="llama3-70b-8192", temperature=0)
+         elif provider == "huggingface":
+             return ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="Qwen/Qwen2.5-Coder-32B-Instruct", temperature=0))
+         else:
+             raise ValueError("Provider must be 'google', 'groq', or 'huggingface'")
+
+     def _cached_get(self, key: str, fetch_fn):
+         """Helper for caching API calls."""
+         if key in API_CACHE:
+             return API_CACHE[key]
+         val = fetch_fn()
+         API_CACHE[key] = val
+         return val
+
+     # --- Tool Definitions as Class Methods ---
+
+     @tool
+     def python_repl(self, code: str) -> str:
+         """Executes a string of Python code and returns the stdout/stderr."""
+         code = textwrap.dedent(code).strip()
+         try:
+             result = subprocess.run(["python", "-c", code], capture_output=True, text=True, timeout=10, check=False)
+             if result.returncode == 0:
+                 return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
+             else:
+                 return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
+         except subprocess.TimeoutExpired:
+             return "Execution timed out (>10s)."
+
+     @tool
+     def describe_image(self, image_source: str) -> str:
+         """Describes an image from a local file path or a URL using Gemini vision."""
+         try:
+             if image_source.startswith("http"):
+                 img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
+             else:
+                 img = Image.open(image_source)
+             buffered = BytesIO()
+             img.convert("RGB").save(buffered, format="JPEG")
+             b64_string = base64.b64encode(buffered.getvalue()).decode()
+             msg = HumanMessage(content=[
+                 {"type": "text", "text": "Describe this image in detail."},
+                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}},
+             ])
+             return self.vision_llm.invoke([msg]).content
+         except Exception as e:
+             return f"Error processing image: {e}"
+
+     @tool
+     def web_search(self, query: str) -> str:
+         """Performs a web search using Tavily and returns a compilation of results."""
+         key = f"web:{query}"
+         results = self._cached_get(key, lambda: TavilySearchResults(max_results=5).invoke(query))
+         return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])
+
+     @tool
+     def wiki_search(self, query: str) -> str:
+         """Searches Wikipedia and returns the top 2 results."""
+         key = f"wiki:{query}"
+         docs = self._cached_get(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
+         return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])
+
+     @tool
+     def arxiv_search(self, query: str) -> str:
+         """Searches Arxiv for scientific papers and returns the top 2 results."""
+         key = f"arxiv:{query}"
+         docs = self._cached_get(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
+         return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])
+
+     def get_tools(self) -> list:
+         """Returns a list of all tools available to the agent."""
+         tools_list = [
+             self.python_repl,
+             self.describe_image,
+             self.web_search,
+             self.wiki_search,
+             self.arxiv_search,
          ]
+         retriever_tool = create_retriever_tool(
+             retriever=self.retriever,
+             name="retrieve_examples",
+             description="Retrieve solved questions and answers similar to the user's query.",
          )
+         tools_list.append(retriever_tool)
+         return tools_list
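# [Editor's note — not part of this commit] Decorating instance methods with
# @tool is fragile: the decorator runs at class-definition time, so `self`
# can end up in the generated argument schema and tool invocation may fail on
# current langchain_core releases. A minimal sketch of one workaround,
# assuming the methods are left undecorated and wrapped only after the
# instance exists (names and descriptions here are illustrative):
#
#     from langchain_core.tools import StructuredTool
#
#     def get_tools(self) -> list:
#         return [
#             StructuredTool.from_function(
#                 func=self.web_search,  # bound method, so `self` is already applied
#                 name="web_search",
#                 description="Tavily web search with caching.",
#             ),
#             # ...wrap the other methods the same way...
#         ]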

  # ----------------------------------------------------------
+ # Section 3: System Prompt
  # ----------------------------------------------------------
  SYSTEM_PROMPT = (
+ """You are a helpful and expert assistant designed to answer questions accurately and concisely.
+
+ **Instructions:**
+ 1. **Analyze the Question:** Carefully understand what is being asked.
+ 2. **Use Tools:** You have a set of tools to find information. Use them logically.
+ 3. **Synthesize the Answer:** Based on the information from the tools, formulate your final answer.
+ 4. **Format the Output:** Your final response MUST be in the following format and nothing else:
+
+ FINAL ANSWER: [Your concise and accurate answer here]

+ If the `retrieve_examples` tool provides an answer to an identical question, use that answer. Otherwise, use your tools to find the correct answer for the current question.
  """
  )

  # ----------------------------------------------------------
+ # Section 4: Factory Function for Agent Executor
+ # ----------------------------------------------------------
+ def create_agent_executor(provider: str = "google"):
+     """Factory function to create and compile the LangGraph agent executor."""
+     my_agent_instance = MyAgent(provider=provider)
+     tools_list = my_agent_instance.get_tools()
+     llm_with_tools = my_agent_instance.llm.bind_tools(tools_list)
+
+     def retriever_node(state: MessagesState):
+         """First node: retrieves examples and prepends them to the message history."""
+         user_query = state["messages"][-1].content
+         docs = my_agent_instance.retriever.invoke(user_query)
+         messages = [SystemMessage(content=SYSTEM_PROMPT)]
+         if docs:
+             example_text = "\n\n---\n\n".join(d.page_content for d in docs)
+             example_msg = AIMessage(
+                 content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}",
+                 name="ExampleRetriever",
+             )
+             messages.append(example_msg)
+         messages.extend(state["messages"])
+         return {"messages": messages}
+
+     def assistant_node(state: MessagesState):
+         """Main assistant node: calls the LLM with the current state to decide the next action."""
+         result = llm_with_tools.invoke(state["messages"])
+         return {"messages": [result]}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever_node)
+     builder.add_node("assistant", assistant_node)
+     builder.add_node("tools", ToolNode(tools_list))
+
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
+     builder.add_edge("tools", "assistant")
+
+     agent_executor = builder.compile()
+     print("Agent Executor created successfully.")
+     return agent_executor
+
+ # ----------------------------------------------------------
+ # Section 5: Direct Execution Block for Testing
  # ----------------------------------------------------------
  if __name__ == "__main__":
+     # Allows direct testing of the agent's logic.
+     print("--- Running Agent in Test Mode ---")
+     agent = create_agent_executor(provider="google")
+     question = "According to Wikipedia, what is the main difference between a llama and an alpaca?"
+     print(f"\nTest Question: {question}\n\n--- Agent Thinking... ---\n")
+
+     for chunk in agent.stream({"messages": [("user", question)]}):
+         for key, value in chunk.items():
+             if value["messages"]:
+                 message = value["messages"][-1]
+                 if message.content:
+                     print(f"--- Node: {key} ---\n{message.content}\n")
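
A minimal usage sketch (editor's addition, not part of the commit): assuming the module is saved as agent.py and the required API keys (Google, Tavily, etc.) are present in .env, the compiled graph can also be invoked synchronously instead of streamed, with the graded answer expected on the last message:

    from agent import create_agent_executor

    agent = create_agent_executor(provider="google")
    result = agent.invoke({"messages": [("user", "What is the capital of France?")]})
    print(result["messages"][-1].content)  # expected to end with "FINAL ANSWER: Paris"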