Abbasid committed on
Commit 6a1d4f7 · verified · 1 parent: 73d4bc0

Create agent.py

Files changed (1)
  1. agent.py +239 -0
agent.py ADDED
@@ -0,0 +1,239 @@
"""
Bare-bones improved GAIA agent – manual LangGraph, no DB.
Includes: vision, code-REPL, smarter search, caching, streaming.
"""
import json
import pickle
from io import BytesIO
from pathlib import Path

import requests
from cachetools import TTLCache
from dotenv import load_dotenv
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

load_dotenv()
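
# NOTE: assumes the .env provides GOOGLE_API_KEY (read by ChatGoogleGenerativeAI)
# and TAVILY_API_KEY (read by TavilySearchResults) – the default variable names
# for those LangChain integrations.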

# ----------------------------------------------------------
# 0. Constants
# ----------------------------------------------------------
JSONL_PATH = Path("metadata.jsonl")
FAISS_CACHE = Path("faiss_index.pkl")
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K = 5
CACHE_TTL = 600  # seconds
CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)

# ----------------------------------------------------------
# 1. Build / load FAISS retriever
# ----------------------------------------------------------
embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

if FAISS_CACHE.exists():
    with open(FAISS_CACHE, "rb") as f:
        vector_store = pickle.load(f)
else:
    if not JSONL_PATH.exists():
        raise FileNotFoundError("metadata.jsonl not found")
    docs = []
    with open(JSONL_PATH, "rt", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            content = f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}"
            docs.append(Document(page_content=content, metadata={"source": rec["task_id"]}))
    vector_store = FAISS.from_documents(docs, embeddings)
    with open(FAISS_CACHE, "wb") as f:
        pickle.dump(vector_store, f)

retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
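
# Caveat: pickling the FAISS store is a convenience, not a stable format, and
# pickle.load executes arbitrary code – only load cache files you created
# yourself. Delete faiss_index.pkl to force a rebuild.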

# ----------------------------------------------------------
# 2. Caching helper
# ----------------------------------------------------------
def cached_get(key: str, fetch_fn):
    """Return the cached value for `key`, computing it via `fetch_fn` on a miss."""
    if key in CACHE:
        return CACHE[key]
    val = fetch_fn()
    CACHE[key] = val
    return val

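# Example (hypothetical `fetch`): the first call pays for the fetch, the second
# is served from CACHE until the 600-second TTL expires:
#   cached_get("web:foo", lambda: fetch("foo"))   # miss -> calls fetch
#   cached_get("web:foo", lambda: fetch("foo"))   # hit  -> cached value
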
# ----------------------------------------------------------
# 3. Tools
# ----------------------------------------------------------
@tool
def python_repl(code: str) -> str:
    """Execute Python code in a subprocess and return stdout/stderr."""
    import subprocess
    import textwrap

    code = textwrap.dedent(code).strip()
    try:
        result = subprocess.run(
            ["python", "-c", code],
            capture_output=True,
            text=True,
            timeout=5,
        )
        return result.stdout if not result.stderr else f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
    except subprocess.TimeoutExpired:
        return "Execution timed out (>5s)."

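# Caveat: "python -c" runs with the agent's own privileges; the 5-second
# timeout stops runaway loops but is not a sandbox, so treat model-written
# code as untrusted.
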
90
+ @tool
91
+ def describe_image(image_source: str) -> str:
92
+ """Describe an image from local path or URL with Gemini vision."""
93
+ import base64
94
+ from PIL import Image
95
+
96
+ if image_source.startswith("http"):
97
+ img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
98
+ else:
99
+ img = Image.open(image_source)
100
+
101
+ buffered = BytesIO()
102
+ img.convert("RGB").save(buffered, format="JPEG")
103
+ b64 = base64.b64encode(buffered.getvalue()).decode()
104
+
105
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
106
+ msg = HumanMessage(
107
+ content=[
108
+ {"type": "text", "text": "Describe this image in detail."},
109
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
110
+ ]
111
+ )
112
+ return llm.invoke([msg]).content
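
# Note: convert("RGB") drops any alpha channel, so transparent PNGs are
# flattened before being re-encoded as JPEG for the vision call.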

@tool
def web_search(query: str) -> str:
    """Smart web search over 3 keyword variants, cached and deduplicated by URL."""
    from langchain_community.tools.tavily_search import TavilySearchResults

    keywords = [query, query.replace(" ", " OR "), f'"{query}"']
    seen = set()
    results = []
    for kw in keywords:
        key = f"web:{kw}"
        snippets = cached_get(
            key,
            # bind kw as a default argument so the lambda keeps its own keyword
            lambda kw=kw: TavilySearchResults(max_results=3, include_raw_content=True).invoke(kw),
        )
        for s in snippets:
            if s["url"] not in seen:
                seen.add(s["url"])
                results.append(s["content"][:2000])
        if len(results) >= 5:
            break
    return "\n\n---\n\n".join(results)
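
# Assumes each Tavily hit is a dict with "url" and "content" keys, the shape
# TavilySearchResults returns by default.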

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 2 matching documents."""
    from langchain_community.document_loaders import WikipediaLoader

    key = f"wiki:{query}"
    docs = cached_get(
        key,
        lambda: WikipediaLoader(query=query, load_max_docs=2).load(),
    )
    return "\n\n---\n\n".join(
        f'<Document source="{d.metadata.get("source", "")}">\n{d.page_content}\n</Document>'
        for d in docs
    )

@tool
def arxiv_search(query: str) -> str:
    """Search arXiv and return up to 2 matching papers (truncated)."""
    from langchain_community.document_loaders import ArxivLoader

    key = f"arxiv:{query}"
    docs = cached_get(
        key,
        lambda: ArxivLoader(query=query, load_max_docs=2).load(),
    )
    return "\n\n---\n\n".join(
        f'<Document source="{d.metadata.get("source", "")}">\n{d.page_content[:2000]}...\n</Document>'
        for d in docs
    )
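
# The pseudo-XML <Document source="..."> wrapper keeps each passage paired with
# its source, which makes it easier for the model to attribute what it cites.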

# ----------------------------------------------------------
# 4. System prompt
# ----------------------------------------------------------
SYSTEM_PROMPT = """You are a helpful assistant tasked with answering questions using a set of tools.

Your final answer must strictly follow this format:
FINAL ANSWER: [ANSWER]

Only write the answer in that exact format. Do not explain anything. Do not include any other text.

If you are provided with a similar question and its final answer, and the current question is **exactly the same**, then simply return the same final answer without using any tools.
Only use tools if the current question is different from the similar one.

Examples:
- FINAL ANSWER: FunkMonk
- FINAL ANSWER: Paris
- FINAL ANSWER: 128

If you do not follow this format exactly, your response will be considered incorrect."""

# ----------------------------------------------------------
# 5. Manual LangGraph construction
# ----------------------------------------------------------
tools_list = [python_repl, describe_image, web_search, wiki_search, arxiv_search]

# retriever tool
from langchain.tools.retriever import create_retriever_tool
tools_list.append(
    create_retriever_tool(
        retriever=retriever,
        name="retrieve_examples",
        description="Retrieve up to 5 solved questions similar to the user query.",
    )
)

llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
llm_with_tools = llm.bind_tools(tools_list)

def assistant(state: MessagesState):
    """LLM node that can call tools."""
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

def retriever_node(state: MessagesState):
    """First node: fetch similar solved examples and add them to the conversation."""
    user_query = state["messages"][-1].content
    docs = retriever.invoke(user_query)
    new_messages = [SystemMessage(content=SYSTEM_PROMPT)]
    if docs:
        example_text = "\n\n---\n\n".join(d.page_content for d in docs)
        new_messages.append(
            HumanMessage(content=f"Here are {len(docs)} similar solved examples:\n\n{example_text}")
        )
    # MessagesState's add_messages reducer merges by message ID, so returning
    # only the new messages appends them without duplicating the history.
    return {"messages": new_messages}

builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever_node)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools_list))
builder.add_edge(START, "retriever")
builder.add_edge("retriever", "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

agent = builder.compile()
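
# A minimal sketch (not in the original file) of a one-shot, non-streaming
# call; `extract_final` is a hypothetical helper that parses the mandated
# "FINAL ANSWER: ..." format out of the last message:
#
#   import re
#   def extract_final(text: str) -> str:
#       m = re.search(r"FINAL ANSWER:\s*(.+)", text)
#       return m.group(1).strip() if m else text
#
#   result = agent.invoke({"messages": [("user", "What is 2 + 2?")]})
#   print(extract_final(result["messages"][-1].content))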

# ----------------------------------------------------------
# 6. Quick streaming test
# ----------------------------------------------------------
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    print("Agent thinking …")
    for chunk in agent.stream({"messages": [("user", question)]}, stream_mode="values"):
        last = chunk["messages"][-1]
        if hasattr(last, "content"):
            print(last.content, end="", flush=True)
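
# Note: stream_mode="values" re-emits the full message list after every step,
# so this loop prints each new last message as it appears; LangGraph's
# stream_mode="updates" would instead yield only the per-node deltas.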