onisj commited on
Commit
4701375
·
1 Parent(s): 488dc3e

feat(advance): Deploy corrected app.py and tools fo advance functions

Browse files
README.md CHANGED
@@ -1,39 +1,79 @@
1
  ---
2
- title: Jarvis Gaia Agent
3
  emoji: 🐢
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
- short_description: The JARVIS (Just A Rather Very Intelligent System) project
10
  ---
11
 
12
- # Jarvis Gaia Agent
13
 
14
- A Python-based AI agent leveraging `langchain`, `duckduckgo-search`, and `pytesseract` to perform web searches, document parsing, and multi-hop query refinement. Deployed as a Hugging Face Space for interactive use.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  ## Features
17
 
18
- - **Web Search**: Performs asynchronous searches using DuckDuckGo.
19
- - **Multi-Hop Search**: Refines complex queries iteratively with OpenAI's GPT-4o.
20
- - **Document Parsing**: Extracts text from PDFs and images using `PyPDF2` and `pytesseract`.
21
- - **Modular Tools**: Includes calculator, file parser, and document retriever.
22
- - **Observability**: Integrated with Langfuse for monitoring.
 
 
23
 
24
  ## Prerequisites
25
 
26
  - Python 3.11
27
  - Tesseract OCR (`brew install tesseract` on macOS)
28
- - API keys for:
29
- - OpenAI (`OPENAI_API_KEY`)
30
- - Hugging Face (`HUGGINGFACEHUB_API_TOKEN`)
31
- - Groq (`GROQ_API_KEY`)
32
- - Langfuse (`LANGFUSE_PUBLIC_KEY`, `LANGFUSE_SECRET_KEY`, `LANGFUSE_HOST`)
33
 
34
  ## Setup
35
 
36
  1. **Clone the Repository**:
37
  ```bash
38
- git clone https://github.com/your-username/jarvis_gaia_agent.git
39
- cd jarvis_gaia_agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: JARVIS Gaia Agent
3
  emoji: 🐢
4
  colorFrom: indigo
5
  colorTo: green
6
  sdk: docker
7
  pinned: false
8
  license: mit
9
+ short_description: Enhanced JARVIS AI agent for GAIA benchmark
10
  ---
11
 
12
+ # Evolved JARVIS Gaia Agent
13
 
14
+ An advanced Python-based AI agent combining `langchain`, `smolagents`, SERPAPI, and OCR for web searches, file parsing, and data retrieval. Deployed as a Hugging Face Space for GAIA benchmark evaluation.
15
+
16
+ #### Directory Structure
17
+ ```
18
+ jarvis_gaia_agent/
19
+ ├── app.py # Main application with Gradio interface and agent logic
20
+ ├── state.py # Defines JARVISState for state management
21
+ ├── retriever.py # Guest info retriever tool
22
+ ├── tools/ # Directory for all tools
23
+ │ ├── __init__.py # Exports all tools
24
+ │ ├── search.py # Web search tools (SERPAPI-based)
25
+ │ ├── file_parser.py # File parsing tool (CSV, TXT, PDF, Excel)
26
+ │ ├── image_parser.py # Image parsing tool (OCR)
27
+ │ ├── calculator.py # Calculator tool
28
+ │ ├── document_retriever.py # Document retrieval tool
29
+ │ ├── duckduckgo_search.py # DuckDuckGo search tool (from smolagents)
30
+ │ ├── weather_info.py # Weather info tool (OpenWeatherMap)
31
+ │ ├── hub_stats.py # Hugging Face Hub stats tool
32
+ │ ├── guest_info.py # Guest info retriever tool (moved from retriever.py)
33
+ ├── requirements.txt # Python dependencies
34
+ ├── Dockerfile # Docker configuration
35
+ ├── README.md # Project documentation
36
+ ├── .env # Environment variables (not committed)
37
+ ```
38
 
39
  ## Features
40
 
41
+ - **Web Search**: SERPAPI and DuckDuckGo for robust searches.
42
+ - **File Parsing**: Handles CSV, TXT, PDF, and Excel files.
43
+ - **Image Parsing**: OCR with `easyocr` for image-based questions.
44
+ - **Data Retrieval**: Guest info retriever for structured data.
45
+ - **External APIs**: Weather (OpenWeatherMap), Hugging Face Hub stats.
46
+ - **State Management**: `langgraph` for multi-step reasoning.
47
+ - **Exact-Match Answers**: Optimized for GAIA Level 1 questions.
48
 
49
  ## Prerequisites
50
 
51
  - Python 3.11
52
  - Tesseract OCR (`brew install tesseract` on macOS)
53
+ - API keys in `.env`:
54
+ - `HUGGINGFACEHUB_API_TOKEN`
55
+ - `SERPAPI_API_KEY`
56
+ - `OPENWEATHERMAP_API_KEY`
57
+ - `SPACE_ID`
58
 
59
  ## Setup
60
 
61
  1. **Clone the Repository**:
62
  ```bash
63
+ git clone https://huggingface.co/spaces/onisj/jarvis_gaia_agent
64
+ cd jarvis_gaia_agent
65
+ ```
66
+
67
+ 2. **Set Up Environment Variables**:
68
+ Create a `.env` file with your API keys.
69
+
70
+ 3. **Run Locally**:
71
+ ```bash
72
+ pip install -r requirements.txt
73
+ python app.py
74
+ ```
75
+
76
+ 4. **Deploy to Hugging Face Space**:
77
+ - Push code to your Space.
78
+ - Set environment variables in Space settings.
79
+ - Run evaluation via Gradio interface.
__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,25 +1,28 @@
1
  import os
2
- import gradio as gr
3
- import requests
4
- import aiohttp
5
- import asyncio
6
  import json
 
 
 
7
  import nest_asyncio
8
- from langgraph.graph import StateGraph, END
9
- from langgraph.checkpoint.memory import MemorySaver
10
- from langchain_huggingface import HuggingFacePipeline
11
- from transformers import pipeline
12
- from langchain_core.messages import SystemMessage, HumanMessage
13
- from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
14
- from tools.search import initialize_search_tools
15
- from state import JARVISState
16
  import pandas as pd
 
 
 
 
 
 
17
  from dotenv import load_dotenv
18
- import logging
19
- from langfuse.callback import CallbackHandler
 
 
 
 
 
20
 
21
- # Set up logging
22
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
23
  logger = logging.getLogger(__name__)
24
 
25
  # Apply nest_asyncio
@@ -27,252 +30,253 @@ nest_asyncio.apply()
27
 
28
  # Load environment variables
29
  load_dotenv()
 
 
 
 
30
 
31
  # Verify environment variables
32
- required_env_vars = ["SPACE_ID", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
33
- for var in required_env_vars:
34
- if not os.getenv(var):
35
- raise ValueError(f"Environment variable {var} is not set")
36
- logger.info(f"Environment variables loaded: SPACE_ID={os.getenv('SPACE_ID')[:10]}..., LANGFUSE_HOST={os.getenv('LANGFUSE_HOST', 'https://cloud.langfuse.com')}")
37
 
38
- # Initialize Hugging Face model
39
  try:
40
- hf_pipeline = pipeline(
41
- "text-generation",
42
- model="mistralai/Mixtral-7B-Instruct-v0.1",
43
- device_map="auto",
44
- max_new_tokens=512,
45
- do_sample=True,
46
- temperature=0.7
47
  )
48
- llm = HuggingFacePipeline(pipeline=hf_pipeline)
49
- logger.info("HuggingFace model initialized: mistralai/Mixtral-7B-Instruct-v0.1")
50
  except Exception as e:
51
- logger.error(f"Failed to initialize HuggingFace model: {e}")
52
  llm = None
53
 
54
- # Initialize search tools with LLM
55
  try:
56
- initialize_search_tools(llm)
57
- logger.info("Search tools initialized")
58
  except Exception as e:
59
- logger.error(f"Failed to initialize search tools: {e}")
60
-
61
- # Initialize Langfuse
62
- try:
63
- langfuse = CallbackHandler(
64
- public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
65
- secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
66
- host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
67
- )
68
- logger.info("Langfuse initialized successfully")
69
- except Exception as e:
70
- logger.warning(f"Failed to initialize Langfuse: {e}")
71
- langfuse = None
72
-
73
- # Initialize MemorySaver
74
- memory = MemorySaver()
75
- use_checkpointing = True
76
-
77
- # --- Constants ---
78
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space/api"
79
- GAIA_FILE_URL = "https://api.gaia-benchmark.com/files/"
80
 
81
  # --- Helper Functions ---
82
- def log_state(task_id: str, state: JARVISState):
83
- """Log intermediate state to state_log.json"""
84
- try:
85
- log_entry = {
86
- "task_id": task_id,
87
- "question": state["question"],
88
- "tools_needed": state["tools_needed"],
89
- "web_results": state["web_results"],
90
- "file_results": state["file_results"],
91
- "image_results": state["image_results"],
92
- "calculation_results": state["calculation_results"],
93
- "document_results": state["document_results"],
94
- "answer": state["answer"]
95
- }
96
- with open("state_log.json", "a") as f:
97
- json.dump(log_entry, f, indent=2)
98
- f.write("\n")
99
- except Exception as e:
100
- logger.error(f"Error logging state for task {task_id}: {e}")
101
-
102
- async def test_gaia_api(task_id: str) -> bool:
103
- """Test connectivity to GAIA file API"""
104
  try:
105
- async with aiohttp.ClientSession() as session:
106
- async with session.head(f"{GAIA_FILE_URL}{task_id}", timeout=5) as resp:
107
- return resp.status in [200, 403, 404]
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- logger.warning(f"GAIA API test failed: {e}")
110
- return False
111
 
112
  # --- Node Functions ---
113
- async def parse_question(state: JARVISState) -> JARVISState:
 
114
  try:
115
  question = state["question"]
116
- prompt = f"""Analyze this GAIA question: {question}
117
- Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
118
- Return a JSON list of tool names."""
119
  if llm:
120
- response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  try:
122
- tools_needed = json.loads(response.content)
123
- except json.JSONDecodeError as je:
124
- logger.warning(f"Invalid JSON in LLM response for task {state['task_id']}: {je}")
125
- tools_needed = ["web_search"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  else:
127
- logger.warning("No LLM available, using default tools")
128
- tools_needed = ["web_search"]
129
- state["tools_needed"] = tools_needed
130
- log_state(state["task_id"], state)
131
  return state
132
  except Exception as e:
133
- logger.error(f"Error parsing question for task {state['task_id']}: {e}")
134
- state["tools_needed"] = []
135
- log_state(state["task_id"], state)
136
  return state
137
 
138
  async def tool_dispatcher(state: JARVISState) -> JARVISState:
 
139
  try:
140
- tools_needed = state["tools_needed"]
141
  updated_state = state.copy()
142
- can_download_files = await test_gaia_api(updated_state["task_id"])
 
 
 
 
 
 
143
 
144
- for tool in tools_needed:
145
  try:
146
- if tool == "web_search" or tool == "multi_hop_search":
147
- result = await web_search_agent(updated_state)
148
- updated_state["web_results"].extend(result["web_results"])
149
- elif tool == "file_parser" and can_download_files:
150
- result = await file_parser_agent(updated_state)
151
- updated_state["file_results"] = result["file_results"]
152
- elif tool == "image_parser" and can_download_files:
153
- result = await image_parser_agent(updated_state)
154
- updated_state["image_results"] = result["image_results"]
155
- elif tool == "calculator":
156
- result = await calculator_agent(updated_state)
157
- updated_state["calculation_results"] = result["calculation_results"]
158
- elif tool == "document_retriever" and can_download_files:
159
- result = await document_retriever_agent(updated_state)
160
- updated_state["document_results"] = result["document_results"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  except Exception as e:
162
- logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {e}")
163
-
164
- log_state(updated_state["task_id"], updated_state)
 
165
  return updated_state
166
  except Exception as e:
167
- logger.error(f"Error in tool dispatcher for task {state['task_id']}: {e}")
168
- log_state(state["task_id"], state)
169
  return state
170
 
171
- async def web_search_agent(state: JARVISState) -> JARVISState:
172
- try:
173
- results = []
174
- if "web_search" in state["tools_needed"]:
175
- result = await search_tool.invoke({"query": state["question"]})
176
- results.append(result)
177
- if "multi_hop_search" in state["tools_needed"]:
178
- result = await multi_hop_search_tool.invoke({"query": state["question"], "steps": 3})
179
- results.append(result)
180
- return {"web_results": results}
181
- except Exception as e:
182
- logger.error(f"Error in web search for task {state['task_id']}: {e}")
183
- return {"web_results": []}
184
-
185
- async def file_parser_agent(state: JARVISState) -> JARVISState:
186
- try:
187
- if "file_parser" in state["tools_needed"]:
188
- file_type = "csv" if "data" in state["question"].lower() else "txt"
189
- result = await file_parser_tool.aparse(state["task_id"], file_type=file_type)
190
- return {"file_results": result}
191
- return {"file_results": ""}
192
- except Exception as e:
193
- logger.error(f"Error in file parser for task {state['task_id']}: {e}")
194
- return {"file_results": "File parsing failed"}
195
-
196
- async def image_parser_agent(state: JARVISState) -> JARVISState:
197
- try:
198
- if "image_parser" in state["tools_needed"]:
199
- task = "match" if "fruits" in state["question"].lower() else "describe"
200
- match_query = "fruits" if task == "match" else ""
201
- file_path = f"temp_{state['task_id']}.jpg"
202
- if not os.path.exists(file_path):
203
- logger.warning(f"Image file not found for task {state['task_id']}")
204
- return {"image_results": "Image file not found"}
205
- result = await image_parser_tool.aparse(
206
- file_path, task=task, match_query=match_query
207
- )
208
- return {"image_results": result}
209
- return {"image_results": ""}
210
- except Exception as e:
211
- logger.error(f"Error in image parser for task {state['task_id']}: {e}")
212
- return {"image_results": "Image parsing failed"}
213
-
214
- async def calculator_agent(state: JARVISState) -> JARVISState:
215
- try:
216
- if "calculator" in state["tools_needed"]:
217
- prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
218
- if llm:
219
- response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
220
- expression = response.content
221
- else:
222
- expression = "0"
223
- result = await calculator_tool.aparse(expression)
224
- return {"calculation_results": result}
225
- return {"calculation_results": ""}
226
- except Exception as e:
227
- logger.error(f"Error in calculator for task {state['task_id']}: {e}")
228
- return {"calculation_results": "Calculation failed"}
229
-
230
- async def document_retriever_agent(state: JARVISState) -> JARVISState:
231
- try:
232
- if "document_retriever" in state["tools_needed"]:
233
- file_type = "txt" if "menu" in state["question"].lower() else "csv"
234
- if "report" in state["question"].lower() or "document" in state["question"].lower():
235
- file_type = "pdf"
236
- result = await document_retriever_tool.aparse(
237
- state["task_id"], state["question"], file_type=file_type
238
- )
239
- return {"document_results": result}
240
- return {"document_results": ""}
241
- except Exception as e:
242
- logger.error(f"Error in document retriever for task {state['task_id']}: {e}")
243
- return {"document_results": "Document retrieval failed"}
244
-
245
- async def reasoning_agent(state: JARVISState) -> JARVISState:
246
  try:
247
- prompt = f"""Question: {state['question']}
248
- Web Results: {state['web_results']}
249
- File Results: {state['file_results']}
250
- Image Results: {state['image_results']}
251
- Calculation Results: {state['calculation_results']}
252
- Document Results: {state['document_results']}
253
- Synthesize an exact-match answer for the GAIA benchmark.
254
- Output only the answer (e.g., '90', 'White;5876')."""
255
- if llm:
256
- response = await llm.ainvoke(
257
- [
258
- SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
259
- HumanMessage(content=prompt)
260
- ],
261
- config={"callbacks": [langfuse] if langfuse else []}
262
- )
263
- answer = response.content.strip()
264
- else:
265
- answer = "Unknown"
266
- state["answer"] = answer
267
- log_state(state["task_id"], state)
268
- return state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  except Exception as e:
270
- logger.error(f"Error in reasoning for task {state['task_id']}: {e}")
271
- state["answer"] = "Error in reasoning"
272
- log_state(state["task_id"], state)
273
- return state
274
 
275
  def router(state: JARVISState) -> str:
 
276
  if state["tools_needed"]:
277
  return "tool_dispatcher"
278
  return "reasoning"
@@ -281,8 +285,7 @@ def router(state: JARVISState) -> str:
281
  workflow = StateGraph(JARVISState)
282
  workflow.add_node("parse", parse_question)
283
  workflow.add_node("tool_dispatcher", tool_dispatcher)
284
- workflow.add_node("reasoning", reasoning_agent)
285
-
286
  workflow.set_entry_point("parse")
287
  workflow.add_conditional_edges(
288
  "parse",
@@ -294,97 +297,95 @@ workflow.add_conditional_edges(
294
  )
295
  workflow.add_edge("tool_dispatcher", "reasoning")
296
  workflow.add_edge("reasoning", END)
 
297
 
298
- # Compile graph
299
- graph = workflow.compile(checkpointer=memory if use_checkpointing else None)
300
-
301
- # --- Basic Agent Definition ---
302
  class BasicAgent:
303
  def __init__(self):
304
  logger.info("BasicAgent initialized.")
305
 
306
  async def process_question(self, task_id: str, question: str) -> str:
 
307
  file_type = "jpg" if "image" in question.lower() else "txt"
308
- if "menu" in question.lower() or "report" in question.lower() or "document" in question.lower():
309
  file_type = "pdf"
310
  elif "data" in question.lower():
311
- file_type = "csv"
312
 
313
  file_path = f"temp_{task_id}.{file_type}"
314
- if await test_gaia_api(task_id):
 
315
  try:
316
  async with aiohttp.ClientSession() as session:
317
- async with session.get(f"{GAIA_FILE_URL}{task_id}") as resp:
318
  if resp.status == 200:
319
  with open(file_path, "wb") as f:
320
  f.write(await resp.read())
321
  else:
322
- logger.warning(f"Failed to download file for task {task_id}: HTTP {resp.status}")
323
  except Exception as e:
324
- logger.error(f"Error downloading file for task {task_id}: {e}")
325
 
326
  state = JARVISState(
327
  task_id=task_id,
328
  question=question,
329
- tools_needed=[],
330
  web_results=[],
331
  file_results="",
332
  image_results="",
333
  calculation_results="",
334
  document_results="",
335
- messages=[],
336
  answer=""
337
  )
338
  try:
339
- config = {"configurable": {"thread_id": task_id}} if use_checkpointing else {}
340
- result = await graph.ainvoke(state, config=config)
341
- return result["answer"] or "No answer generated"
 
342
  except Exception as e:
343
  logger.error(f"Error processing task {task_id}: {e}")
344
  return f"Error: {str(e)}"
345
  finally:
346
- if os.path.exists(file_path):
347
- try:
348
- os.remove(file_path)
349
- except Exception as e:
350
- logger.error(f"Error removing file {file_path}: {e}")
 
 
351
 
352
  async def async_call(self, question: str, task_id: str) -> str:
353
- return await self.process_question(task_id, question)
354
 
355
  def __call__(self, question: str, task_id: str = None) -> str:
356
- logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
357
  if task_id is None:
358
- logger.warning("task_id not provided, using placeholder")
359
- task_id = "placeholder_task_id"
360
  try:
361
- try:
362
- loop = asyncio.get_event_loop()
363
- except RuntimeError:
364
- loop = asyncio.new_event_loop()
365
- asyncio.set_event_loop(loop)
366
- return loop.run_until_complete(self.async_call(question, task_id))
367
- finally:
368
- pass
369
 
370
- # --- Main Function ---
371
  def run_and_submit_all(profile: gr.OAuthProfile | None):
372
- space_id = os.getenv("SPACE_ID")
373
  if not profile:
374
  logger.error("User not logged in.")
375
- return "Please Login to Hugging Face with the button.", None
376
  username = f"{profile.username}"
377
  logger.info(f"User logged in: {username}")
378
 
379
- api_url = DEFAULT_API_URL
380
- questions_url = f"{api_url}/questions"
381
- submit_url = f"{api_url}/submit"
382
- agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
383
 
384
  try:
385
  agent = BasicAgent()
386
  except Exception as e:
387
- logger.error(f"Error instantiating agent: {e}")
388
  return f"Error initializing agent: {e}", None
389
 
390
  logger.info(f"Fetching questions from: {questions_url}")
@@ -393,8 +394,8 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
393
  response.raise_for_status()
394
  questions_data = response.json()
395
  if not questions_data:
396
- logger.error("Fetched questions list is empty.")
397
- return "Fetched questions list is empty or invalid format.", None
398
  logger.info(f"Fetched {len(questions_data)} questions.")
399
  except Exception as e:
400
  logger.error(f"Error fetching questions: {e}")
@@ -402,24 +403,24 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
402
 
403
  results_log = []
404
  answers_payload = []
405
- logger.info(f"Running agent on {len(questions_data)} questions...")
406
  for item in questions_data:
407
  task_id = item.get("task_id")
408
  question_text = item.get("question")
409
  if not task_id or question_text is None:
410
- logger.warning(f"Skipping item with missing task_id or question: {item}")
411
  continue
412
  try:
413
  submitted_answer = agent(question_text, task_id)
414
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
415
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
416
  except Exception as e:
417
- logger.error(f"Error running agent on task {task_id}: {e}")
418
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
419
 
420
  if not answers_payload:
421
- logger.error("Agent did not produce any answers to submit.")
422
- return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
423
 
424
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
425
  logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
@@ -427,7 +428,6 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
427
  response = requests.post(submit_url, json=submission_data, timeout=120)
428
  response.raise_for_status()
429
  result_data = response.json()
430
- logger.info(f"Server response: {result_data}")
431
  final_status = (
432
  f"Submission Successful!\n"
433
  f"User: {result_data.get('username')}\n"
@@ -442,19 +442,19 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
442
  results_df = pd.DataFrame(results_log)
443
  return f"Submission Failed: {e}", results_df
444
 
445
- # --- Build Gradio Interface ---
446
  with gr.Blocks() as demo:
447
- gr.Markdown("# JARVIS Agent Evaluation Runner")
448
  gr.Markdown(
449
  """
450
  **Instructions:**
451
 
452
- 1. Log in to your Hugging Face account using the button below.
453
- 2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the JARVIS agent, and submit answers.
454
 
455
  ---
456
  **Disclaimers:**
457
- The agent uses a local Hugging Face model (Mixtral-7B) and async tools for the GAIA benchmark.
458
  """
459
  )
460
 
@@ -463,16 +463,16 @@ with gr.Blocks() as demo:
463
  run_button = gr.Button("Run Evaluation & Submit All Answers")
464
 
465
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
466
- results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
467
 
468
  run_button.click(
469
  fn=run_and_submit_all,
470
  outputs=[status_output, results_table]
471
  )
472
 
 
473
  if __name__ == "__main__":
474
  logger.info("\n" + "-"*30 + " App Starting " + "-"*30)
475
- space_id = os.getenv("SPACE_ID")
476
- logger.info(f"SPACE_ID: {space_id}")
477
  logger.info("Launching Gradio Interface...")
478
  demo.launch(debug=True, share=False)
 
1
  import os
 
 
 
 
2
  import json
3
+ import logging
4
+ import asyncio
5
+ import aiohttp
6
  import nest_asyncio
7
+ import requests
 
 
 
 
 
 
 
8
  import pandas as pd
9
+ from typing import Dict, Any, List
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from langchain_core.messages import SystemMessage, HumanMessage
12
+ from langgraph.graph import StateGraph, END
13
+ from sentence_transformers import SentenceTransformer
14
+ import gradio as gr
15
  from dotenv import load_dotenv
16
+ from huggingface_hub import InferenceClient
17
+ from state import JARVISState
18
+ from tools import (
19
+ search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool,
20
+ calculator_tool, document_retriever_tool, duckduckgo_search_tool,
21
+ weather_info_tool, hub_stats_tool, guest_info_retriever_tool
22
+ )
23
 
24
+ # Setup logging
25
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
26
  logger = logging.getLogger(__name__)
27
 
28
  # Apply nest_asyncio
 
30
 
31
  # Load environment variables
32
  load_dotenv()
33
+ SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
34
+ GAIA_API_URL = "https://agents-course-unit4-scoring.hf.space"
35
+ GAIA_FILE_URL = f"{GAIA_API_URL}/files/"
36
+ HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
37
 
38
  # Verify environment variables
39
+ if not SPACE_ID:
40
+ raise ValueError("SPACE_ID not set")
41
+ if not HF_TOKEN:
42
+ raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
43
+ logger.info(f"SPACE_ID: {SPACE_ID}")
44
 
45
+ # Initialize models
46
  try:
47
+ llm = InferenceClient(
48
+ model="meta-llama/Meta-Llama-3-8B-Instruct",
49
+ token=HF_TOKEN,
50
+ timeout=30
 
 
 
51
  )
52
+ logger.info("Hugging Face Inference LLM initialized")
 
53
  except Exception as e:
54
+ logger.error(f"Failed to initialize LLM: {e}")
55
  llm = None
56
 
 
57
  try:
58
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
59
+ logger.info("Sentence transformer initialized")
60
  except Exception as e:
61
+ logger.error(f"Failed to initialize embedder: {e}")
62
+ embedder = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  # --- Helper Functions ---
65
+ async def test_gaia_api(task_id: str, file_type: str = "txt") -> tuple[bool, str | None]:
66
+ """Test if a file exists for the task ID."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
68
+ for ext in [file_type, "txt", "csv", "xlsx", "jpg", "pdf"]:
69
+ async with aiohttp.ClientSession() as session:
70
+ async with session.get(f"{GAIA_FILE_URL}{task_id}.{ext}", timeout=5) as resp:
71
+ logger.info(f"GAIA API test for task {task_id} with .{ext}: HTTP {resp.status}")
72
+ if resp.status == 200:
73
+ file_path = f"temp_{task_id}.{ext}"
74
+ with open(file_path, "wb") as f:
75
+ f.write(await resp.read())
76
+ return True, ext
77
+ logger.info(f"No file found for task {task_id}")
78
+ return False, None
79
  except Exception as e:
80
+ logger.warning(f"GAIA API test failed: {str(e)}")
81
+ return False, None
82
 
83
  # --- Node Functions ---
84
+ async def parse_question(state: Dict[str, Any]) -> Dict[str, Any]:
85
+ """Parse the question to select appropriate tools."""
86
  try:
87
  question = state["question"]
88
+ task_id = state["task_id"]
89
+ tools_needed = ["search_tool"]
90
+
91
  if llm:
92
+ prompt = ChatPromptTemplate.from_messages([
93
+ SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
94
+ Return JSON list, e.g., ["search_tool", "file_parser_tool"].
95
+ Rules:
96
+ - Always include "search_tool" unless purely computational.
97
+ - Use "multi_hop_search_tool" for complex queries (over 20 words).
98
+ - Use "file_parser_tool" for data, tables, or Excel.
99
+ - Use "image_parser_tool" for images/videos.
100
+ - Use "calculator_tool" for math calculations.
101
+ - Use "document_retriever_tool" for documents/PDFs.
102
+ - Use "duckduckgo_search_tool" for additional search capability.
103
+ - Use "weather_info_tool" for weather-related queries.
104
+ - Use "hub_stats_tool" for Hugging Face Hub queries.
105
+ - Use "guest_info_retriever_tool" for guest-related queries.
106
+ - Output ONLY valid JSON."""),
107
+ HumanMessage(content=f"Query: {question}")
108
+ ])
109
  try:
110
+ response = llm.chat_completion(
111
+ messages=[
112
+ {"role": "system", "content": prompt[0].content},
113
+ {"role": "user", "content": prompt[1].content}
114
+ ],
115
+ max_tokens=512,
116
+ temperature=0.7
117
+ )
118
+ tools_needed = json.loads(response["choices"][0]["message"]["content"].strip())
119
+ valid_tools = {
120
+ "search_tool", "multi_hop_search_tool", "file_parser_tool", "image_parser_tool",
121
+ "calculator_tool", "document_retriever_tool", "duckduckgo_search_tool",
122
+ "weather_info_tool", "hub_stats_tool", "guest_info_retriever_tool"
123
+ }
124
+ tools_needed = [tool for tool in tools_needed if tool in valid_tools]
125
+ except Exception as e:
126
+ logger.warning(f"Task {task_id} failed: JSON parse error: {e}")
127
+ tools_needed = ["search_tool"]
128
+
129
+ # Keyword-based fallback
130
+ question_lower = question.lower()
131
+ if any(word in question_lower for word in ["image", "video"]):
132
+ tools_needed.append("image_parser_tool")
133
+ if any(word in question_lower for word in ["data", "table", "excel"]):
134
+ tools_needed.append("file_parser_tool")
135
+ if any(word in question_lower for word in ["calculate", "math"]):
136
+ tools_needed.append("calculator_tool")
137
+ if any(word in question_lower for word in ["document", "pdf"]):
138
+ tools_needed.append("document_retriever_tool")
139
+ if any(word in question_lower for word in ["weather"]):
140
+ tools_needed.append("weather_info_tool")
141
+ if any(word in question_lower for word in ["model", "huggingface"]):
142
+ tools_needed.append("hub_stats_tool")
143
+ if any(word in question_lower for word in ["guest", "name", "relation"]):
144
+ tools_needed.append("guest_info_retriever_tool")
145
+ if len(question.split()) > 20:
146
+ tools_needed.append("multi_hop_search_tool")
147
+
148
+ file_available, file_ext = await test_gaia_api(task_id)
149
+ if file_available:
150
+ if "file_parser_tool" not in tools_needed and any(word in question_lower for word in ["data", "table", "excel"]):
151
+ tools_needed.append("file_parser_tool")
152
+ if "image_parser_tool" not in tools_needed and "image" in question_lower:
153
+ tools_needed.append("image_parser_tool")
154
+ if "document_retriever_tool" not in tools_needed and file_ext == "pdf":
155
+ tools_needed.append("document_retriever_tool")
156
  else:
157
+ tools_needed = [tool for tool in tools_needed if tool not in ["file_parser_tool", "image_parser_tool", "document_retriever_tool"]]
158
+
159
+ state["tools_needed"] = list(set(tools_needed)) # Remove duplicates
160
+ logger.info(f"Task {task_id}: Selected tools: {tools_needed}")
161
  return state
162
  except Exception as e:
163
+ logger.error(f"Error parsing task {task_id}: {e}")
164
+ state["tools_needed"] = ["search_tool"]
 
165
  return state
166
 
167
  async def tool_dispatcher(state: JARVISState) -> JARVISState:
168
+ """Dispatch selected tools to process the state."""
169
  try:
 
170
  updated_state = state.copy()
171
+ file_type = "jpg" if "image" in state["question"].lower() else "txt"
172
+ if "menu" in state["question"].lower() or "report" in state["question"].lower():
173
+ file_type = "pdf"
174
+ elif "data" in state["question"].lower():
175
+ file_type = "xlsx"
176
+
177
+ can_download, file_ext = await test_gaia_api(updated_state["task_id"], file_type)
178
 
179
+ for tool in updated_state["tools_needed"]:
180
  try:
181
+ if tool == "search_tool":
182
+ result = await search_tool.ainvoke({"query": updated_state["question"]})
183
+ updated_state["web_results"].extend([r["content"] for r in result])
184
+ elif tool == "multi_hop_search_tool":
185
+ result = await multi_hop_search_tool.ainvoke({"query": updated_state["question"], "steps": 3})
186
+ updated_state["web_results"].extend([r["content"] for r in result])
187
+ await asyncio.sleep(2) # Rate limit
188
+ elif tool == "file_parser_tool" and can_download:
189
+ result = await file_parser_tool.ainvoke({"task_id": updated_state["task_id"], "file_type": file_ext})
190
+ updated_state["file_results"] = str(result)
191
+ elif tool == "image_parser_tool" and can_download:
192
+ result = await image_parser_tool.ainvoke({
193
+ "file_path": f"temp_{updated_state['task_id']}.{file_ext}",
194
+ "task": "describe"
195
+ })
196
+ updated_state["image_results"] = str(result)
197
+ elif tool == "calculator_tool":
198
+ result = await calculator_tool.ainvoke({"expression": updated_state.get("question", "")})
199
+ updated_state["calculation_results"] = str(result)
200
+ elif tool == "document_retriever_tool" and can_download:
201
+ result = await document_retriever_tool.ainvoke({
202
+ "task_id": updated_state["task_id"],
203
+ "query": updated_state["question"],
204
+ "file_type": file_ext
205
+ })
206
+ updated_state["document_results"] = str(result)
207
+ elif tool == "duckduckgo_search_tool":
208
+ result = await duckduckgo_search_tool.run(updated_state["question"])
209
+ updated_state["web_results"].append(str(result))
210
+ elif tool == "weather_info_tool":
211
+ location = updated_state["question"].split("weather in ")[1].split()[0] if "weather in" in updated_state["question"].lower() else "Unknown"
212
+ result = await weather_info_tool.ainvoke({"location": location})
213
+ updated_state["web_results"].append(str(result))
214
+ elif tool == "hub_stats_tool":
215
+ author = updated_state["question"].split("by ")[1].split()[0] if "by" in updated_state["question"].lower() else "Unknown"
216
+ result = await hub_stats_tool.ainvoke({"author": author})
217
+ updated_state["web_results"].append(str(result))
218
+ elif tool == "guest_info_retriever_tool":
219
+ query = updated_state["question"].split("about ")[1] if "about" in updated_state["question"].lower() else updated_state["question"]
220
+ result = await guest_info_retriever_tool.ainvoke({"query": query})
221
+ updated_state["web_results"].append(str(result))
222
  except Exception as e:
223
+ logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {str(e)}")
224
+ updated_state[f"{tool}_results"] = f"Error: {str(e)}"
225
+
226
+ logger.info(f"Task {updated_state['task_id']}: Tool results: {updated_state}")
227
  return updated_state
228
  except Exception as e:
229
+ logger.error(f"Tool dispatch failed for task {state['task_id']}: {e}")
 
230
  return state
231
 
232
+ async def reasoning(state: JARVISState) -> Dict[str, Any]:
233
+ """Generate exact-match answer with specific formatting."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  try:
235
+ if not llm:
236
+ return {"answer": "LLM unavailable"}
237
+ prompt = ChatPromptTemplate.from_messages([
238
+ SystemMessage(content="""Provide ONLY the exact answer (e.g., '90', 'HUE'). For USD, use two decimal places (e.g., '1234.00'). For lists, use comma-separated values (e.g., 'Smith, Lee'). For IOC codes, use three-letter codes (e.g., 'ARG'). No explanations or conversational text."""),
239
+ HumanMessage(content="""Question: {question}
240
+ Web results: {web_results}
241
+ File results: {file_results}
242
+ Image results: {image_results}
243
+ Calculation results: {calculation_results}
244
+ Document results: {document_results}""")
245
+ ])
246
+ response = llm.chat_completion(
247
+ messages=[
248
+ {"role": "system", "content": prompt[0].content},
249
+ {"role": "user", "content": prompt[1].content.format(
250
+ question=state["question"],
251
+ web_results="\n".join(state["web_results"]),
252
+ file_results=state["file_results"],
253
+ image_results=state["image_results"],
254
+ calculation_results=state["calculation_results"],
255
+ document_results=state["document_results"]
256
+ )}
257
+ ],
258
+ max_tokens=512,
259
+ temperature=0.7
260
+ )
261
+ answer = response["choices"][0]["message"]["content"].strip()
262
+ # Clean answer for specific formats
263
+ if "USD" in state["question"].lower():
264
+ try:
265
+ answer = f"{float(answer):.2f}"
266
+ except ValueError:
267
+ pass
268
+ if "before and after" in state["question"].lower():
269
+ answer = answer.replace(" and ", ", ")
270
+ elif "IOC code" in state["question"].lower():
271
+ answer = answer.upper()[:3]
272
+ logger.info(f"Task {state['task_id']}: Answer: {answer}")
273
+ return {"answer": answer}
274
  except Exception as e:
275
+ logger.error(f"Reasoning failed for task {state['task_id']}: {e}")
276
+ return {"answer": f"Error: {str(e)}"}
 
 
277
 
278
  def router(state: JARVISState) -> str:
279
+ """Route based on tools needed."""
280
  if state["tools_needed"]:
281
  return "tool_dispatcher"
282
  return "reasoning"
 
285
  workflow = StateGraph(JARVISState)
286
  workflow.add_node("parse", parse_question)
287
  workflow.add_node("tool_dispatcher", tool_dispatcher)
288
+ workflow.add_node("reasoning", reasoning)
 
289
  workflow.set_entry_point("parse")
290
  workflow.add_conditional_edges(
291
  "parse",
 
297
  )
298
  workflow.add_edge("tool_dispatcher", "reasoning")
299
  workflow.add_edge("reasoning", END)
300
+ graph = workflow.compile()
301
 
302
+ # --- Basic Agent ---
 
 
 
303
  class BasicAgent:
304
  def __init__(self):
305
  logger.info("BasicAgent initialized.")
306
 
307
  async def process_question(self, task_id: str, question: str) -> str:
308
+ """Process a single question with file handling."""
309
  file_type = "jpg" if "image" in question.lower() else "txt"
310
+ if "menu" in question.lower() or "report" in question.lower():
311
  file_type = "pdf"
312
  elif "data" in question.lower():
313
+ file_type = "xlsx"
314
 
315
  file_path = f"temp_{task_id}.{file_type}"
316
+ file_available, file_ext = await test_gaia_api(task_id, file_type)
317
+ if file_available:
318
  try:
319
  async with aiohttp.ClientSession() as session:
320
+ async with session.get(f"{GAIA_FILE_URL}{task_id}.{file_ext}") as resp:
321
  if resp.status == 200:
322
  with open(file_path, "wb") as f:
323
  f.write(await resp.read())
324
  else:
325
+ logger.warning(f"Failed to fetch file for {task_id}: HTTP {resp.status}")
326
  except Exception as e:
327
+ logger.error(f"Error downloading file for task {task_id}: {str(e)}")
328
 
329
  state = JARVISState(
330
  task_id=task_id,
331
  question=question,
332
+ tools_needed=["search_tool"],
333
  web_results=[],
334
  file_results="",
335
  image_results="",
336
  calculation_results="",
337
  document_results="",
338
+ messages=[HumanMessage(content=question)],
339
  answer=""
340
  )
341
  try:
342
+ result = await graph.ainvoke(state)
343
+ answer = result["answer"] or "Unknown"
344
+ logger.info(f"Task {task_id}: Final answer generated: {answer}")
345
+ return answer
346
  except Exception as e:
347
  logger.error(f"Error processing task {task_id}: {e}")
348
  return f"Error: {str(e)}"
349
  finally:
350
+ for ext in ["txt", "csv", "xlsx", "jpg", "pdf"]:
351
+ file_path = f"temp_{task_id}.{ext}"
352
+ if os.path.exists(file_path):
353
+ try:
354
+ os.remove(file_path)
355
+ except Exception as e:
356
+ logger.error(f"Error removing file {file_path}: {e}")
357
 
358
  async def async_call(self, question: str, task_id: str) -> str:
359
+ return await self.process_question(question, task_id)
360
 
361
  def __call__(self, question: str, task_id: str = None) -> str:
362
+ logger.info(f"Processing question: {question[:50]}...")
363
  if task_id is None:
364
+ task_id = "unknown_task_id"
 
365
  try:
366
+ loop = asyncio.get_event_loop()
367
+ except RuntimeError:
368
+ loop = asyncio.new_event_loop()
369
+ asyncio.set_event_loop(loop)
370
+ return loop.run_until_complete(self.async_call(question, task_id))
 
 
 
371
 
372
+ # --- Evaluation and Submission ---
373
  def run_and_submit_all(profile: gr.OAuthProfile | None):
374
+ """Run evaluation and submit answers to GAIA API."""
375
  if not profile:
376
  logger.error("User not logged in.")
377
+ return "Please Login to Hugging Face.", None
378
  username = f"{profile.username}"
379
  logger.info(f"User logged in: {username}")
380
 
381
+ questions_url = f"{GAIA_API_URL}/questions"
382
+ submit_url = f"{GAIA_API_URL}/submit"
383
+ agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
 
384
 
385
  try:
386
  agent = BasicAgent()
387
  except Exception as e:
388
+ logger.error(f"Agent initialization failed: {e}")
389
  return f"Error initializing agent: {e}", None
390
 
391
  logger.info(f"Fetching questions from: {questions_url}")
 
394
  response.raise_for_status()
395
  questions_data = response.json()
396
  if not questions_data:
397
+ logger.error("Empty questions list.")
398
+ return "No questions fetched.", None
399
  logger.info(f"Fetched {len(questions_data)} questions.")
400
  except Exception as e:
401
  logger.error(f"Error fetching questions: {e}")
 
403
 
404
  results_log = []
405
  answers_payload = []
406
+ logger.info(f"Processing {len(questions_data)} questions...")
407
  for item in questions_data:
408
  task_id = item.get("task_id")
409
  question_text = item.get("question")
410
  if not task_id or question_text is None:
411
+ logger.warning(f"Skipping invalid item: {item}")
412
  continue
413
  try:
414
  submitted_answer = agent(question_text, task_id)
415
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
416
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
417
  except Exception as e:
418
+ logger.error(f"Error for task {task_id}: {e}")
419
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
420
 
421
  if not answers_payload:
422
+ logger.error("No answers generated.")
423
+ return "No answers to submit.", pd.DataFrame(results_log)
424
 
425
  submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
426
  logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
 
428
  response = requests.post(submit_url, json=submission_data, timeout=120)
429
  response.raise_for_status()
430
  result_data = response.json()
 
431
  final_status = (
432
  f"Submission Successful!\n"
433
  f"User: {result_data.get('username')}\n"
 
442
  results_df = pd.DataFrame(results_log)
443
  return f"Submission Failed: {e}", results_df
444
 
445
+ # --- Gradio Interface ---
446
  with gr.Blocks() as demo:
447
+ gr.Markdown("# Evolved JARVIS Agent Evaluation")
448
  gr.Markdown(
449
  """
450
  **Instructions:**
451
 
452
+ 1. Log in to Hugging Face using the button below.
453
+ 2. Click 'Run Evaluation & Submit All Answers' to process GAIA questions and submit.
454
 
455
  ---
456
  **Disclaimers:**
457
+ Uses Hugging Face Inference, SERPAPI, and OpenWeatherMap for GAIA benchmark.
458
  """
459
  )
460
 
 
463
  run_button = gr.Button("Run Evaluation & Submit All Answers")
464
 
465
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
466
+ results_table = gr.DataFrame(label="Questions and Answers", wrap=True)
467
 
468
  run_button.click(
469
  fn=run_and_submit_all,
470
  outputs=[status_output, results_table]
471
  )
472
 
473
+ # --- Main ---
474
  if __name__ == "__main__":
475
  logger.info("\n" + "-"*30 + " App Starting " + "-"*30)
476
+ logger.info(f"SPACE_ID: {SPACE_ID}")
 
477
  logger.info("Launching Gradio Interface...")
478
  demo.launch(debug=True, share=False)
dockerfile CHANGED
@@ -2,23 +2,20 @@ FROM python:3.11-slim
2
 
3
  WORKDIR /app
4
 
5
- # Install system dependencies
6
  RUN apt-get update && apt-get install -y \
7
  libgl1-mesa-glx \
8
  libglib2.0-0 \
 
9
  tesseract-ocr \
10
  libtesseract-dev \
11
  && rm -rf /var/lib/apt/lists/*
12
 
13
- # Copy project files
14
  COPY requirements.txt .
15
  COPY app.py .
16
- COPY graph.py .
17
  COPY state.py .
 
18
  COPY tools/ tools/
19
 
20
- # Install Python dependencies
21
  RUN pip install --no-cache-dir -r requirements.txt
22
 
23
- # Run the application
24
- CMD ["python", "app.py"]
 
2
 
3
  WORKDIR /app
4
 
 
5
  RUN apt-get update && apt-get install -y \
6
  libgl1-mesa-glx \
7
  libglib2.0-0 \
8
+ python3-dev \
9
  tesseract-ocr \
10
  libtesseract-dev \
11
  && rm -rf /var/lib/apt/lists/*
12
 
 
13
  COPY requirements.txt .
14
  COPY app.py .
 
15
  COPY state.py .
16
+ COPY retriever.py .
17
  COPY tools/ tools/
18
 
 
19
  RUN pip install --no-cache-dir -r requirements.txt
20
 
21
+ CMD ["python3", "app.py"]
 
graph.py DELETED
@@ -1,143 +0,0 @@
1
- from langgraph.graph import StateGraph, END
2
- from langgraph.checkpoint.memory import MemorySaver
3
- from state import JARVISState
4
- from langchain_openai import ChatOpenAI
5
- from langchain_core.messages import SystemMessage, HumanMessage
6
- from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
7
- from langfuse.callback import LangfuseCallbackHandler
8
- import json
9
- import os
10
- from dotenv import load_dotenv
11
-
12
- # Load environment variables
13
- load_dotenv()
14
- # Debug: Verify environment variables
15
- print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
16
- print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")
17
-
18
- # Initialize LLM and Langfuse
19
- api_key = os.getenv("OPENAI_API_KEY")
20
- if not api_key:
21
- raise ValueError("OPENAI_API_KEY environment variable not set")
22
- llm = ChatOpenAI(model="gpt-4o", api_key=api_key)
23
- langfuse = LangfuseCallbackHandler(
24
- public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
25
- secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
26
- host=os.getenv("LANGFUSE_HOST")
27
- )
28
- memory = MemorySaver()
29
-
30
- # Question Parser Node
31
- async def parse_question(state: JARVISState) -> JARVISState:
32
- question = state["question"]
33
- prompt = f"""Analyze this GAIA question: {question}
34
- Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
35
- Return a JSON list of tool names."""
36
- response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
37
- tools_needed = json.loads(response.content)
38
- return {"messages": state["messages"] + [response], "tools_needed": tools_needed}
39
-
40
- # Web Search Agent Node
41
- async def web_search_agent(state: JARVISState) -> JARVISState:
42
- results = []
43
- if "web_search" in state["tools_needed"]:
44
- result = await search_tool.arun(state["question"])
45
- results.append(result)
46
- if "multi_hop_search" in state["tools_needed"]:
47
- result = await multi_hop_search_tool.aparse(state["question"], steps=3)
48
- results.append(result)
49
- return {"web_results": results}
50
-
51
- # File Parser Agent Node
52
- async def file_parser_agent(state: JARVISState) -> JARVISState:
53
- if "file_parser" in state["tools_needed"]:
54
- result = await file_parser_tool.aparse(state["task_id"])
55
- return {"file_results": result}
56
- return {"file_results": ""}
57
-
58
- # Image Parser Agent Node
59
- async def image_parser_agent(state: JARVISState) -> JARVISState:
60
- if "image_parser" in state["tools_needed"]:
61
- task = "match" if "fruits" in state["question"].lower() else "describe"
62
- match_query = "fruits" if task == "match" else ""
63
- result = await image_parser_tool.aparse(
64
- f"temp_{state['task_id']}.jpg", task=task, match_query=match_query
65
- )
66
- return {"image_results": result}
67
- return {"image_results": ""}
68
-
69
- # Calculator Agent Node
70
- async def calculator_agent(state: JARVISState) -> JARVISState:
71
- if "calculator" in state["tools_needed"]:
72
- prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
73
- response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
74
- expression = response.content
75
- result = await calculator_tool.aparse(expression)
76
- return {"calculation_results": result}
77
- return {"calculation_results": ""}
78
-
79
- # Document Retriever Agent Node
80
- async def document_retriever_agent(state: JARVISState) -> JARVISState:
81
- if "document_retriever" in state["tools_needed"]:
82
- file_type = "txt" if "menu" in state["question"].lower() else "csv"
83
- if "report" in state["question"].lower() or "document" in state["question"].lower():
84
- file_type = "pdf"
85
- result = await document_retriever_tool.aparse(
86
- state["task_id"], state["question"], file_type=file_type
87
- )
88
- return {"document_results": result}
89
- return {"document_results": ""}
90
-
91
- # Reasoning Agent Node
92
- async def reasoning_agent(state: JARVISState) -> JARVISState:
93
- prompt = f"""Question: {state['question']}
94
- Web Results: {state['web_results']}
95
- File Results: {state['file_results']}
96
- Image Results: {state['image_results']}
97
- Calculation Results: {state['calculation_results']}
98
- Document Results: {state['document_results']}
99
-
100
- Synthesize an exact-match answer for the GAIA benchmark.
101
- Output only the answer (e.g., '90', 'White;5876')."""
102
- response = await llm.ainvoke(
103
- [
104
- SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
105
- HumanMessage(content=prompt)
106
- ],
107
- config={"callbacks": [langfuse]}
108
- )
109
- return {"answer": response.content, "messages": state["messages"] + [response]}
110
-
111
- # Conditional Edge Router
112
- def router(state: JARVISState) -> str:
113
- if state["tools_needed"]:
114
- return "tools"
115
- return "reasoning"
116
-
117
- # Build Graph
118
- workflow = StateGraph(JARVISState)
119
- workflow.add_node("parse", parse_question)
120
- workflow.add_node("web_search", web_search_agent)
121
- workflow.add_node("file_parser", file_parser_agent)
122
- workflow.add_node("image_parser", image_parser_agent)
123
- workflow.add_node("calculator", calculator_agent)
124
- workflow.add_node("document_retriever", document_retriever_agent)
125
- workflow.add_node("reasoning", reasoning_agent)
126
-
127
- workflow.set_entry_point("parse")
128
- workflow.add_conditional_edges(
129
- "parse",
130
- router,
131
- {
132
- "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
133
- "reasoning": "reasoning"
134
- }
135
- )
136
- workflow.add_edge("web_search", "reasoning")
137
- workflow.add_edge("file_parser", "reasoning")
138
- workflow.add_edge("image_parser", "reasoning")
139
- workflow.add_edge("calculator", "reasoning")
140
- workflow.add_edge("document_retriever", "reasoning")
141
- workflow.add_edge("reasoning", END)
142
-
143
- graph = workflow.compile(checkpointer=memory)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
project_struct.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ jarvis_gaia_agent/
2
+ ├── app.py # Main application with Gradio interface and agent logic
3
+ ├── state.py # Defines JARVISState for state management
4
+ ├── retriever.py # Guest info retriever tool
5
+ ├── tools/ # Directory for all tools
6
+ │ ├── __init__.py # Exports all tools
7
+ │ ├── search.py # Web search tools (SERPAPI-based)
8
+ │ ├── file_parser.py # File parsing tool (CSV, TXT, PDF, Excel)
9
+ │ ├── image_parser.py # Image parsing tool (OCR)
10
+ │ ├── calculator.py # Calculator tool
11
+ │ ├── document_retriever.py # Document retrieval tool
12
+ │ ├── duckduckgo_search.py # DuckDuckGo search tool (from smolagents)
13
+ │ ├── weather_info.py # Weather info tool (OpenWeatherMap)
14
+ │ ├── hub_stats.py # Hugging Face Hub stats tool
15
+ │ ├── guest_info.py # Guest info retriever tool (moved from retriever.py)
16
+ ├── requirements.txt # Python dependencies
17
+ ├── Dockerfile # Docker configuration
18
+ ├── README.md # Project documentation
19
+ ├── .env # Environment variables (not committed)
20
+
21
+ 2 directories, 17 files
requirements.txt CHANGED
@@ -1,89 +1,18 @@
1
- aiohttp==3.8.6
2
- aiosignal==1.3.1
3
- annotated-types==0.7.0
4
- anyio==4.4.0
5
- attrs==23.2.0
6
- backoff==2.2.1
7
- certifi==2024.7.4
8
- charset-normalizer==3.3.2
9
- click==8.1.7
10
- dataclasses-json==0.6.7
11
- distro==1.9.0
12
- duckduckgo_search==6.2.4
13
- filelock==3.15.4
14
- frozenlist==1.4.1
15
- fsspec==2024.6.1
16
- greenlet==3.0.3
17
- h11==0.14.0
18
- httpcore==1.0.5
19
- httpx==0.27.0
20
- httpx-sse==0.4.0
21
- huggingface-hub==0.23.4
22
- idna==3.7
23
- Jinja2==3.1.4
24
- jiter==0.5.0
25
- joblib==1.4.2
26
- jsonpatch==1.33
27
- jsonpointer==3.0.0
28
- langchain==0.2.11
29
- langchain-community==0.2.10
30
- langchain-core==0.2.23
31
- langchain-openai==0.1.17
32
- langchain-text-splitters==0.2.2
33
- langfuse==2.36.1
34
- langgraph==0.1.15
35
- langgraph-checkpoint==1.0.2
36
- langsmith==0.1.93
37
- lxml==5.2.2
38
- markdown-it-py==3.0.0
39
- MarkupSafe==2.1.5
40
- marshmallow==3.21.3
41
- mdurl==0.1.2
42
- mpmath==1.3.0
43
- msgpack==1.0.8
44
- multidict==6.0.5
45
- mypy_extensions==1.0.0
46
- networkx==3.3
47
- numpy==1.26.4
48
- openai==1.35.13
49
- orjson==3.10.6
50
- packaging==23.2
51
- pandas==2.2.2
52
- pillow==10.4.0
53
- primp==0.15.0
54
- pydantic==2.8.2
55
- pydantic_core==2.20.1
56
- Pygments==2.18.0
57
- PyPDF2==3.0.1
58
- pytesseract==0.3.10
59
- python-dateutil==2.9.0.post0
60
- python-dotenv==1.0.1
61
- pytz==2024.1
62
- PyYAML==6.0.1
63
- regex==2024.7.24
64
- requests==2.32.3
65
- requests-toolbelt==1.0.0
66
- rich==13.7.1
67
- safetensors==0.4.3
68
- scikit-learn==1.5.1
69
- scipy==1.14.0
70
- sentence-transformers==3.0.1
71
- six==1.16.0
72
- sniffio==1.3.1
73
- SQLAlchemy==2.0.31
74
- sympy==1.13.1
75
- tenacity==8.5.0
76
- threadpoolctl==3.5.0
77
- tiktoken==0.7.0
78
- tokenizers==0.19.1
79
- torch==2.2.2
80
- tqdm==4.66.4
81
- transformers==4.42.4
82
- typing-inspect==0.9.0
83
- typing_extensions==4.12.2
84
- tzdata==2024.1
85
- urllib3==2.2.2
86
- wrapt==1.16.0
87
- xxhash==3.4.1
88
- yarl==1.9.4
89
- gradio[oauth]==4.44.1
 
1
+ gradio
2
+ requests
3
+ pandas
4
+ PyPDF2
5
+ easyocr
6
+ langchain
7
+ langchain-community
8
+ langgraph
9
+ sentence-transformers
10
+ huggingface_hub
11
+ python-dotenv
12
+ aiohttp
13
+ nest-asyncio
14
+ sympy
15
+ openpyxl
16
+ smolagents
17
+ datasets
18
+ asyncio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
retriever.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ from langchain.docstore.document import Document
3
+ from langchain_community.retrievers import BM25Retriever
4
+ from smolagents import Tool
5
+
6
+ def load_guest_dataset():
7
+ try:
8
+ guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
9
+ docs = [
10
+ Document(
11
+ page_content="\n".join([
12
+ f"Name: {guest['name']}",
13
+ f"Relation: {guest['relation']}",
14
+ f"Description: {guest['description']}",
15
+ f"Email: {guest['email']}"
16
+ ]),
17
+ metadata={"name": guest["name"]}
18
+ )
19
+ for guest in guest_dataset
20
+ ]
21
+ except Exception as e:
22
+ # Fallback mock dataset
23
+ docs = [
24
+ Document(
25
+ page_content="\n".join([
26
+ "Name: Dr. Nikola Tesla",
27
+ "Relation: old friend from university days",
28
+ "Description: Dr. Nikola Tesla is an old friend from your university days. He's recently patented a new wireless energy transmission system.",
29
+ "Email: nikola.tesla@gmail.com"
30
+ ]),
31
+ metadata={"name": "Dr. Nikola Tesla"}
32
+ )
33
+ ]
34
+ return docs
state_log.json ADDED
The diff for this file is too large to render. See raw diff
 
tools/__init__.py CHANGED
@@ -2,4 +2,8 @@ from .search import search_tool, multi_hop_search_tool
2
  from .file_parser import file_parser_tool
3
  from .image_parser import image_parser_tool
4
  from .calculator import calculator_tool
5
- from .document_retriever import document_retriever_tool
 
 
 
 
 
2
  from .file_parser import file_parser_tool
3
  from .image_parser import image_parser_tool
4
  from .calculator import calculator_tool
5
+ from .document_retriever import document_retriever_tool
6
+ from .duckduckgo_search import duckduckgo_search_tool
7
+ from .weather_info import weather_info_tool
8
+ from .hub_stats import hub_stats_tool
9
+ from .guest_info import guest_info_retriever_tool
tools/calculator.py CHANGED
@@ -6,7 +6,7 @@ logger = logging.getLogger(__name__)
6
 
7
  @tool
8
  async def calculator_tool(expression: str) -> str:
9
- """Evaluate a mathematical expression"""
10
  try:
11
  result = sympify(expression)
12
  return str(result)
 
6
 
7
  @tool
8
  async def calculator_tool(expression: str) -> str:
9
+ """Evaluate a mathematical expression."""
10
  try:
11
  result = sympify(expression)
12
  return str(result)
tools/document_retriever.py CHANGED
@@ -7,7 +7,7 @@ logger = logging.getLogger(__name__)
7
 
8
  @tool
9
  async def document_retriever_tool(task_id: str, query: str, file_type: str) -> str:
10
- """Retrieve content from a document"""
11
  try:
12
  file_path = f"temp_{task_id}.{file_type}"
13
  if not os.path.exists(file_path):
 
7
 
8
  @tool
9
  async def document_retriever_tool(task_id: str, query: str, file_type: str) -> str:
10
+ """Retrieve content from a document."""
11
  try:
12
  file_path = f"temp_{task_id}.{file_type}"
13
  if not os.path.exists(file_path):
tools/duckduckgo_search.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from smolagents import Tool, DuckDuckGoSearchTool
2
+ import logging
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+ duckduckgo_search_tool = DuckDuckGoSearchTool()
tools/file_parser.py CHANGED
@@ -8,7 +8,7 @@ logger = logging.getLogger(__name__)
8
 
9
  @tool
10
  async def file_parser_tool(task_id: str, file_type: str) -> str:
11
- """Parse a file based on task_id and file_type"""
12
  try:
13
  file_path = f"temp_{task_id}.{file_type}"
14
  if not os.path.exists(file_path):
@@ -26,6 +26,9 @@ async def file_parser_tool(task_id: str, file_type: str) -> str:
26
  reader = PyPDF2.PdfReader(f)
27
  text = "".join(page.extract_text() for page in reader.pages)
28
  return text
 
 
 
29
  else:
30
  return f"Unsupported file type: {file_type}"
31
  except Exception as e:
 
8
 
9
  @tool
10
  async def file_parser_tool(task_id: str, file_type: str) -> str:
11
+ """Parse a file based on task_id and file_type."""
12
  try:
13
  file_path = f"temp_{task_id}.{file_type}"
14
  if not os.path.exists(file_path):
 
26
  reader = PyPDF2.PdfReader(f)
27
  text = "".join(page.extract_text() for page in reader.pages)
28
  return text
29
+ elif file_type in ["xlsx", "xls"]:
30
+ df = pd.read_excel(file_path, engine="openpyxl")
31
+ return df.to_string()
32
  else:
33
  return f"Unsupported file type: {file_type}"
34
  except Exception as e:
tools/guest_info.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from retriever import load_guest_dataset
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ @tool
8
+ async def guest_info_retriever_tool(query: str) -> str:
9
+ """Retrieve detailed information about gala guests based on their name or relation."""
10
+ try:
11
+ docs = load_guest_dataset()
12
+ from langchain_community.retrievers import BM25Retriever
13
+ retriever = BM25Retriever.from_documents(docs)
14
+ results = retriever.get_relevant_documents(query)
15
+ if results:
16
+ return "\n\n".join([doc.page_content for doc in results[:3]])
17
+ return "No matching guest information found."
18
+ except Exception as e:
19
+ logger.error(f"Error retrieving guest info for query '{query}': {e}")
20
+ return f"Error: {str(e)}"
tools/hub_stats.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from huggingface_hub import list_models
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ @tool
8
+ async def hub_stats_tool(author: str) -> str:
9
+ """Fetch the most downloaded model from a specific author on Hugging Face Hub."""
10
+ try:
11
+ models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
12
+ if models:
13
+ model = models[0]
14
+ return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
15
+ return f"No models found for author {author}."
16
+ except Exception as e:
17
+ logger.error(f"Error fetching models for {author}: {e}")
18
+ return f"Error: {str(e)}"
tools/image_parser.py CHANGED
@@ -4,12 +4,11 @@ import logging
4
  import os
5
 
6
  logger = logging.getLogger(__name__)
7
-
8
  reader = easyocr.Reader(['en'])
9
 
10
  @tool
11
  async def image_parser_tool(file_path: str, task: str = "describe", match_query: str = "") -> str:
12
- """Parse text from an image"""
13
  try:
14
  if not os.path.exists(file_path):
15
  logger.warning(f"Image not found: {file_path}")
 
4
  import os
5
 
6
  logger = logging.getLogger(__name__)
 
7
  reader = easyocr.Reader(['en'])
8
 
9
  @tool
10
  async def image_parser_tool(file_path: str, task: str = "describe", match_query: str = "") -> str:
11
+ """Parse text from an image."""
12
  try:
13
  if not os.path.exists(file_path):
14
  logger.warning(f"Image not found: {file_path}")
tools/retriever.py DELETED
@@ -1,80 +0,0 @@
1
- from langchain.text_splitter import RecursiveCharacterTextSplitter
2
- from sentence_transformers import SentenceTransformer
3
- import numpy as np
4
- import pandas as pd
5
- import PyPDF2
6
- import os
7
- from typing import List, Dict
8
-
9
- class DocumentRetrieverTool:
10
- def __init__(self):
11
- self.name = "document_retriever"
12
- self.description = "Retrieves relevant text from GAIA text-heavy files (CSV, TXT, PDF) using semantic search."
13
- self.inputs = {
14
- "task_id": {"type": "string", "description": "GAIA task ID for the file"},
15
- "query": {"type": "string", "description": "Question or query to search for"},
16
- "file_type": {"type": "string", "description": "File type (csv, txt, pdf, default: txt)"}
17
- }
18
- self.output_type = str
19
- self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
20
- self.text_splitter = RecursiveCharacterTextSplitter(
21
- chunk_size=500,
22
- chunk_overlap=50,
23
- length_function=len
24
- )
25
- self.chunks: List[str] = []
26
- self.embeddings: np.ndarray = None
27
-
28
- async def aparse(self, task_id: str, query: str, file_type: str = "txt") -> str:
29
- """
30
- Loads a GAIA file, splits it into chunks, embeds them, and retrieves relevant text for the query.
31
- Supports CSV, TXT, and PDF files.
32
- """
33
- try:
34
- file_path = f"temp_{task_id}.{file_type}"
35
- if not os.path.exists(file_path):
36
- return f"File not found for task ID {task_id}"
37
-
38
- # Load and preprocess file
39
- text = ""
40
- if file_type == "csv":
41
- df = pd.read_csv(file_path)
42
- text = df.to_string()
43
- elif file_type == "txt":
44
- with open(file_path, "r", encoding="utf-8") as f:
45
- text = f.read()
46
- elif file_type == "pdf":
47
- with open(file_path, "rb") as f:
48
- pdf = PyPDF2.PdfReader(f)
49
- text = "".join(page.extract_text() or "" for page in pdf.pages)
50
- else:
51
- return f"Unsupported file type: {file_type}"
52
-
53
- # Check if text was extracted
54
- if not text.strip():
55
- return "No extractable text found in file."
56
-
57
- # Split text into chunks
58
- self.chunks = self.text_splitter.split_text(text)
59
- if not self.chunks:
60
- return "No content found in file."
61
-
62
- # Embed chunks and query
63
- self.embeddings = self.embedder.encode(self.chunks, convert_to_tensor=True)
64
- query_embedding = self.embedder.encode(query, convert_to_tensor=True)
65
-
66
- # Compute cosine similarities
67
- from sentence_transformers import util
68
- similarities = util.cos_sim(query_embedding, self.embeddings)[0]
69
-
70
- # Get top 3 most relevant chunks
71
- top_k = min(3, len(self.chunks))
72
- top_indices = similarities.argsort(descending=True)[:top_k]
73
- relevant_chunks = [self.chunks[idx] for idx in top_indices]
74
-
75
- # Combine results
76
- return "\n\n".join(relevant_chunks)
77
- except Exception as e:
78
- return f"Error retrieving documents: {str(e)}"
79
-
80
- document_retriever_tool = DocumentRetrieverTool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/search.py CHANGED
@@ -1,91 +1,46 @@
1
  from langchain_core.tools import tool
2
- from langchain_huggingface import HuggingFacePipeline
3
- from sentence_transformers import SentenceTransformer
4
  import logging
5
- from typing import List, Dict, Any
6
  import requests
7
  import os
 
 
8
 
9
  logger = logging.getLogger(__name__)
10
-
11
- # Initialize embedding model (free, open-source)
12
- try:
13
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
14
- except Exception as e:
15
- logger.error(f"Failed to initialize embedding model: {e}")
16
- embedder = None
17
-
18
- # Global LLM instance
19
- search_llm = None
20
-
21
- def initialize_search_tools(llm: HuggingFacePipeline) -> None:
22
- """Initialize search tools with the provided LLM"""
23
- global search_llm
24
- search_llm = llm
25
- logger.info("Search tools initialized with HuggingFace LLM")
26
 
27
  @tool
28
  async def search_tool(query: str) -> List[Dict[str, Any]]:
29
- """Perform a web search using the query"""
30
  try:
31
- if not search_llm:
32
- logger.warning("Search LLM not initialized")
33
- return [{"content": "Search unavailable", "url": ""}]
34
-
35
- # Refine query using LLM
36
- prompt = f"Refine this search query for better results: {query}"
37
- response = await search_llm.ainvoke(prompt)
38
- refined_query = response.content.strip()
39
-
40
- # Check for SerpAPI key (free tier available)
41
  serpapi_key = os.getenv("SERPAPI_API_KEY")
42
- if serpapi_key:
43
- try:
44
- params = {"q": refined_query, "api_key": serpapi_key}
45
- response = requests.get("https://serpapi.com/search", params=params)
46
- response.raise_for_status()
47
- results = response.json().get("organic_results", [])
48
- return [{"content": r.get("snippet", ""), "url": r.get("link", "")} for r in results]
49
- except Exception as e:
50
- logger.warning(f"SerpAPI failed: {e}, falling back to mock search")
51
-
52
- # Mock search if no API key or API fails
53
- if embedder:
54
- query_embedding = embedder.encode(refined_query)
55
- results = [
56
- {"content": f"Mock result for {refined_query}", "url": "https://example.com"},
57
- {"content": f"Another mock result for {refined_query}", "url": "https://example.org"}
58
- ]
59
- else:
60
- results = [{"content": "Embedding model unavailable", "url": ""}]
61
-
62
- logger.info(f"Search results for query '{refined_query}': {len(results)} items")
63
- return results
64
  except Exception as e:
65
  logger.error(f"Error in search_tool: {e}")
66
  return [{"content": f"Search failed: {str(e)}", "url": ""}]
67
 
68
  @tool
69
  async def multi_hop_search_tool(query: str, steps: int = 3) -> List[Dict[str, Any]]:
70
- """Perform a multi-hop search by iteratively refining the query"""
71
  try:
72
- if not search_llm:
73
- logger.warning("Search LLM not initialized")
74
- return [{"content": "Multi-hop search unavailable", "url": ""}]
75
-
76
  results = []
77
  current_query = query
78
  for step in range(steps):
79
- prompt = f"Based on the query '{current_query}', generate a follow-up question to deepen the search."
80
- response = await search_llm.ainvoke(prompt)
81
- next_query = response.content.strip()
82
-
83
- step_results = await search_tool.invoke({"query": next_query})
84
  results.extend(step_results)
85
- current_query = next_query
86
- logger.info(f"Multi-hop step {step + 1}: {next_query}")
87
-
88
- return results
89
  except Exception as e:
90
  logger.error(f"Error in multi_hop_search_tool: {e}")
91
  return [{"content": f"Multi-hop search failed: {str(e)}", "url": ""}]
 
1
  from langchain_core.tools import tool
 
 
2
  import logging
 
3
  import requests
4
  import os
5
+ from typing import List, Dict, Any
6
+ from dotenv import load_dotenv
7
 
8
  logger = logging.getLogger(__name__)
9
+ load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  @tool
12
  async def search_tool(query: str) -> List[Dict[str, Any]]:
13
+ """Perform a web search using SERPAPI."""
14
  try:
 
 
 
 
 
 
 
 
 
 
15
  serpapi_key = os.getenv("SERPAPI_API_KEY")
16
+ if not serpapi_key:
17
+ logger.error("SERPAPI_API_KEY not set")
18
+ return [{"content": "Search unavailable: API key missing", "url": ""}]
19
+
20
+ params = {"q": query, "api_key": serpapi_key}
21
+ response = requests.get("https://serpapi.com/search", params=params, timeout=10)
22
+ response.raise_for_status()
23
+ results = response.json().get("organic_results", [])
24
+ logger.info(f"Search results for query '{query}': {len(results)} items")
25
+ search_results = [{"content": r.get("snippet", ""), "url": r.get("link", "")} for r in results]
26
+ return search_results or [{"content": "No search results", "url": ""}]
 
 
 
 
 
 
 
 
 
 
 
27
  except Exception as e:
28
  logger.error(f"Error in search_tool: {e}")
29
  return [{"content": f"Search failed: {str(e)}", "url": ""}]
30
 
31
  @tool
32
  async def multi_hop_search_tool(query: str, steps: int = 3) -> List[Dict[str, Any]]:
33
+ """Perform a multi-hop search."""
34
  try:
 
 
 
 
35
  results = []
36
  current_query = query
37
  for step in range(steps):
38
+ step_results = await search_tool.invoke({"query": current_query})
 
 
 
 
39
  results.extend(step_results)
40
+ current_query = f"{current_query} more details"
41
+ logger.info(f"Multi-hop step {step + 1}: {current_query}")
42
+ await asyncio.sleep(2) # Avoid rate limits
43
+ return results or [{"content": "No multi-hop results", "url": ""}]
44
  except Exception as e:
45
  logger.error(f"Error in multi_hop_search_tool: {e}")
46
  return [{"content": f"Multi-hop search failed: {str(e)}", "url": ""}]
tools/weather_info.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ import requests
3
+ import logging
4
+ import os
5
+ from dotenv import load_dotenv
6
+
7
+ logger = logging.getLogger(__name__)
8
+ load_dotenv()
9
+
10
+ @tool
11
+ async def weather_info_tool(location: str) -> str:
12
+ """Fetch real weather information for a given location."""
13
+ try:
14
+ api_key = os.getenv("OPENWEATHERMAP_API_KEY")
15
+ if not api_key:
16
+ logger.error("OPENWEATHERMAP_API_KEY not set")
17
+ return "Weather unavailable: API key missing"
18
+
19
+ url = f"http://api.openweathermap.org/data/2.5/weather?q={location}&appid={api_key}&units=metric"
20
+ response = requests.get(url).json()
21
+ if response.get("cod") == 200:
22
+ condition = response["weather"][0]["description"]
23
+ temp = response["main"]["temp"]
24
+ return f"Weather in {location}: {condition}, {temp}°C"
25
+ return f"Unable to fetch weather for {location}."
26
+ except Exception as e:
27
+ logger.error(f"Error fetching weather for {location}: {e}")
28
+ return f"Error: {str(e)}"