vishwamgupta commited on
Commit
cf4ef96
·
verified ·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +335 -0
agent.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Basic Agent Definition ---
2
+ import asyncio
3
+ import os
4
+ import sys
5
+ import logging
6
+ import random
7
+ import pandas as pd
8
+ import requests
9
+ import wikipedia as wiki
10
+ from markdownify import markdownify as to_markdown
11
+ from typing import Any
12
+ from dotenv import load_dotenv
13
+ from google.generativeai import types, configure
14
+
15
+ from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool
16
+
17
+ # Load environment and configure Gemini
18
+ load_dotenv()
19
+ configure(api_key=os.getenv("GOOGLE_API_KEY"))
20
+
21
+ # Logging
22
+ #logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
23
+ #logger = logging.getLogger(__name__)
24
+
25
+ # --- Model Configuration ---
26
+ GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
27
+ OPENAI_MODEL_NAME = "openai/gpt-4o"
28
+ GROQ_MODEL_NAME = "groq/llama3-70b-8192"
29
+ DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
30
+ HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
31
+
32
+ # --- Tool Definitions ---
33
+ class MathSolver(Tool):
34
+ name = "math_solver"
35
+ description = "Safely evaluate basic math expressions."
36
+ inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
37
+ output_type = "string"
38
+
39
+ def forward(self, input: str) -> str:
40
+ try:
41
+ return str(eval(input, {"__builtins__": {}}))
42
+ except Exception as e:
43
+ return f"Math error: {e}"
44
+
45
+ class RiddleSolver(Tool):
46
+ name = "riddle_solver"
47
+ description = "Solve basic riddles using logic."
48
+ inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
49
+ output_type = "string"
50
+
51
+ def forward(self, input: str) -> str:
52
+ if "forward" in input and "backward" in input:
53
+ return "A palindrome"
54
+ return "RiddleSolver failed."
55
+
56
+ class TextTransformer(Tool):
57
+ name = "text_ops"
58
+ description = "Transform text: reverse, upper, lower."
59
+ inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
60
+ output_type = "string"
61
+
62
+ def forward(self, input: str) -> str:
63
+ if input.startswith("reverse:"):
64
+ reversed_text = input[8:].strip()[::-1]
65
+ if 'left' in reversed_text.lower():
66
+ return "right"
67
+ return reversed_text
68
+ if input.startswith("upper:"):
69
+ return input[6:].strip().upper()
70
+ if input.startswith("lower:"):
71
+ return input[6:].strip().lower()
72
+ return "Unknown transformation."
73
+
74
+ class GeminiVideoQA(Tool):
75
+ name = "video_inspector"
76
+ description = "Analyze video content to answer questions."
77
+ inputs = {
78
+ "video_url": {"type": "string", "description": "URL of video."},
79
+ "user_query": {"type": "string", "description": "Question about video."}
80
+ }
81
+ output_type = "string"
82
+
83
+ def __init__(self, model_name, *args, **kwargs):
84
+ super().__init__(*args, **kwargs)
85
+ self.model_name = model_name
86
+
87
+ def forward(self, video_url: str, user_query: str) -> str:
88
+ req = {
89
+ 'model': f'models/{self.model_name}',
90
+ 'contents': [{
91
+ "parts": [
92
+ {"fileData": {"fileUri": video_url}},
93
+ {"text": f"Please watch the video and answer the question: {user_query}"}
94
+ ]
95
+ }]
96
+ }
97
+ url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={os.getenv("GOOGLE_API_KEY")}'
98
+ res = requests.post(url, json=req, headers={'Content-Type': 'application/json'})
99
+ if res.status_code != 200:
100
+ return f"Video error {res.status_code}: {res.text}"
101
+ parts = res.json()['candidates'][0]['content']['parts']
102
+ return "".join([p.get('text', '') for p in parts])
103
+
104
+ class WikiTitleFinder(Tool):
105
+ name = "wiki_titles"
106
+ description = "Search for related Wikipedia page titles."
107
+ inputs = {"query": {"type": "string", "description": "Search query."}}
108
+ output_type = "string"
109
+
110
+ def forward(self, query: str) -> str:
111
+ results = wiki.search(query)
112
+ return ", ".join(results) if results else "No results."
113
+
114
+ class WikiContentFetcher(Tool):
115
+ name = "wiki_page"
116
+ description = "Fetch Wikipedia page content."
117
+ inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
118
+ output_type = "string"
119
+
120
+ def forward(self, page_title: str) -> str:
121
+ try:
122
+ return to_markdown(wiki.page(page_title).html())
123
+ except wiki.exceptions.PageError:
124
+ return f"'{page_title}' not found."
125
+
126
+ class GoogleSearchTool(Tool):
127
+ name = "google_search"
128
+ description = "Search the web using Google. Returns top summary from the web."
129
+ inputs = {"query": {"type": "string", "description": "Search query."}}
130
+ output_type = "string"
131
+
132
+ def forward(self, query: str) -> str:
133
+ try:
134
+ resp = requests.get("https://www.googleapis.com/customsearch/v1", params={
135
+ "q": query,
136
+ "key": os.getenv("GOOGLE_SEARCH_API_KEY"),
137
+ "cx": os.getenv("GOOGLE_SEARCH_ENGINE_ID"),
138
+ "num": 1
139
+ })
140
+ data = resp.json()
141
+ return data["items"][0]["snippet"] if "items" in data else "No results found."
142
+ except Exception as e:
143
+ return f"GoogleSearch error: {e}"
144
+
145
+
146
+ class FileAttachmentQueryTool(Tool):
147
+ name = "run_query_with_file"
148
+ description = """
149
+ Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
150
+ This assumes the file is 20MB or less.
151
+ """
152
+ inputs = {
153
+ "task_id": {
154
+ "type": "string",
155
+ "description": "A unique identifier for the task related to this file, used to download it.",
156
+ "nullable": True
157
+ },
158
+ "user_query": {
159
+ "type": "string",
160
+ "description": "The question to answer about the file."
161
+ }
162
+ }
163
+ output_type = "string"
164
+
165
+ def forward(self, task_id: str | None, user_query: str) -> str:
166
+ file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
167
+ file_response = requests.get(file_url)
168
+ if file_response.status_code != 200:
169
+ return f"Failed to download file: {file_response.status_code} - {file_response.text}"
170
+ file_data = file_response.content
171
+ from google.generativeai import GenerativeModel
172
+ model = GenerativeModel(self.model_name)
173
+ response = model.generate_content([
174
+ types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
175
+ user_query
176
+ ])
177
+
178
+ return response.text
179
+
180
+ # --- Basic Agent Definition ---
181
+ class BasicAgent:
182
+ def __init__(self, provider="deepseek"):
183
+ print("BasicAgent initialized.")
184
+ model = self.select_model(provider)
185
+ client = InferenceClientModel()
186
+ tools = [
187
+ GoogleSearchTool(),
188
+ DuckDuckGoSearchTool(),
189
+ GeminiVideoQA(GEMINI_MODEL_NAME),
190
+ WikiTitleFinder(),
191
+ WikiContentFetcher(),
192
+ MathSolver(),
193
+ RiddleSolver(),
194
+ TextTransformer(),
195
+ FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
196
+ ]
197
+ self.agent = CodeAgent(
198
+ model=model,
199
+ tools=tools,
200
+ add_base_tools=False,
201
+ max_steps=10,
202
+ )
203
+ self.agent.system_prompt = (
204
+ """
205
+ You are a GAIA benchmark AI assistant, you are very precise, no nonense. Your sole purpose is to output the minimal, final answer in the format:
206
+ [ANSWER]
207
+ You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.
208
+ Your behavior must be governed by these rules:
209
+ 1. **Format**:
210
+ - limit the token used (within 65536 tokens).
211
+ - Output ONLY the final answer.
212
+ - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
213
+ - No follow-ups, justifications, or clarifications.
214
+ 2. **Numerical Answers**:
215
+ - Use **digits only**, e.g., `4` not `four`.
216
+ - No commas, symbols, or units unless explicitly required.
217
+ - Never use approximate words like "around", "roughly", "about".
218
+ 3. **String Answers**:
219
+ - Omit **articles** ("a", "the").
220
+ - Use **full words**; no abbreviations unless explicitly requested.
221
+ - For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
222
+ - For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.
223
+ 4. **Lists**:
224
+ - Output in **comma-separated** format with no conjunctions.
225
+ - Sort **alphabetically** or **numerically** depending on type.
226
+ - No braces or brackets unless explicitly asked.
227
+ 5. **Sources**:
228
+ - For Wikipedia or web tools, extract only the precise fact that answers the question.
229
+ - Ignore any unrelated content.
230
+ 6. **File Analysis**:
231
+ - Use the run_query_with_file tool, append the taskid to the url.
232
+ - Only include the exact answer to the question.
233
+ - Do not summarize, quote excessively, or interpret beyond the prompt.
234
+ 7. **Video**:
235
+ - Use the relevant video tool.
236
+ - Only include the exact answer to the question.
237
+ - Do not summarize, quote excessively, or interpret beyond the prompt.
238
+ 8. **Minimalism**:
239
+ - Do not make assumptions unless the prompt logically demands it.
240
+ - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
241
+ - If the answer is not found, say `[ANSWER] - unknown`.
242
+ ---
243
+ You must follow the examples (These answers are correct in case you see the similar questions):
244
+ Q: What is 2 + 2?
245
+ A: 4
246
+ Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
247
+ A: 3
248
+ Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
249
+ A: b, e
250
+ Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
251
+ A: 519
252
+ """
253
+ )
254
+
255
+ def select_model(self, provider: str):
256
+ if provider == "openai":
257
+ return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
258
+ elif provider == "groq":
259
+ return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=os.getenv("GROQ_API_KEY"))
260
+ elif provider == "deepseek":
261
+ return LiteLLMModel(model_id=DEEPSEEK_MODEL_NAME, api_key=os.getenv("DEEPSEEK_API_KEY"))
262
+ elif provider == "hf":
263
+ return InferenceClientModel()
264
+ else:
265
+ return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))
266
+
267
+ def __call__(self, question: str) -> str:
268
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
269
+ result = self.agent.run(question)
270
+ final_str = str(result).strip()
271
+
272
+ return final_str
273
+
274
+ def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
275
+ import pandas as pd
276
+ from rich.table import Table
277
+ from rich.console import Console
278
+
279
+ df = pd.read_csv(csv_path)
280
+ if not {"question", "answer"}.issubset(df.columns):
281
+ print("CSV must contain 'question' and 'answer' columns.")
282
+ print("Found columns:", df.columns.tolist())
283
+ return
284
+
285
+ samples = df.sample(n=sample_size)
286
+ records = []
287
+ correct_count = 0
288
+
289
+ for _, row in samples.iterrows():
290
+ taskid = row["taskid"].strip()
291
+ question = row["question"].strip()
292
+ expected = str(row['answer']).strip()
293
+ agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()
294
+
295
+ is_correct = (expected == agent_answer)
296
+ correct_count += is_correct
297
+ records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))
298
+
299
+ if show_steps:
300
+ print("---")
301
+ print("Question:", question)
302
+ print("Expected:", expected)
303
+ print("Agent:", agent_answer)
304
+ print("Correct:", is_correct)
305
+
306
+ # Print result table
307
+ console = Console()
308
+ table = Table(show_lines=True)
309
+ table.add_column("Question", overflow="fold")
310
+ table.add_column("Expected")
311
+ table.add_column("Agent")
312
+ table.add_column("Correct")
313
+
314
+ for question, expected, agent_ans, correct in records:
315
+ table.add_row(question, expected, agent_ans, correct)
316
+
317
+ console.print(table)
318
+ percent = (correct_count / sample_size) * 100
319
+ print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
320
+
321
+
322
+ if __name__ == "__main__":
323
+ args = sys.argv[1:]
324
+ if not args or args[0] in {"-h", "--help"}:
325
+ print("Usage: python agent.py [question | dev]")
326
+ print(" - Provide a question to get a GAIA-style answer.")
327
+ print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
328
+ sys.exit(0)
329
+
330
+ q = " ".join(args)
331
+ agent = BasicAgent()
332
+ if q == "dev":
333
+ agent.evaluate_random_questions()
334
+ else:
335
+ print(agent(q))