Abbasid committed on
Commit bf3da7a · verified · 1 Parent(s): 6a09b39

Update agent.py

Files changed (1)
  1. agent.py +254 -253
agent.py CHANGED
@@ -2,7 +2,7 @@
  agent.py

  This file defines the core logic for a sophisticated AI agent using LangGraph.
- This version includes proper multimodal support for images, YouTube videos, and audio files.
  """

  # ----------------------------------------------------------
@@ -14,18 +14,12 @@ import pickle
  import re
  import subprocess
  import textwrap
- import base64
  import functools
- from io import BytesIO
  from pathlib import Path
- import tempfile
- import yt_dlp
- from pydub import AudioSegment
- import speech_recognition as sr

  import requests
  from cachetools import TTLCache
- from PIL import Image

  from langchain.schema import Document
  from langchain.tools.retriever import create_retriever_tool
@@ -34,9 +28,8 @@ from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_core.tools import Tool, tool
- from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import ToolNode, tools_condition

@@ -47,230 +40,205 @@ load_dotenv()
  JSONL_PATH, FAISS_CACHE, EMBED_MODEL = Path("metadata.jsonl"), Path("faiss_index.pkl"), "sentence-transformers/all-mpnet-base-v2"
  RETRIEVER_K, CACHE_TTL = 5, 600
  API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
  def cached_get(key: str, fetch_fn):
- if key in API_CACHE: return API_CACHE[key]
  val = fetch_fn()
  API_CACHE[key] = val
  return val

  # ----------------------------------------------------------
- # Section 2: Standalone Tool Functions
  # ----------------------------------------------------------
  @tool
  def python_repl(code: str) -> str:
  """Executes a string of Python code and returns the stdout/stderr."""
  code = textwrap.dedent(code).strip()
  try:
- result = subprocess.run(["python", "-c", code], capture_output=True, text=True, timeout=10, check=False)
- if result.returncode == 0: return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
- else: return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
- except subprocess.TimeoutExpired: return "Execution timed out (>10s)."

- def describe_image_func(image_source: str, vision_llm_instance) -> str:
- """Describes an image from a local file path or a URL using a provided vision LLM."""
  try:
- print(f"Processing image: {image_source}")
-
- # Download and process image
- if image_source.startswith("http"):
- response = requests.get(image_source, timeout=10)
- response.raise_for_status()
- img = Image.open(BytesIO(response.content))
- else:
- img = Image.open(image_source)
-
- # Convert to base64
- buffered = BytesIO()
- img.convert("RGB").save(buffered, format="JPEG")
- b64_string = base64.b64encode(buffered.getvalue()).decode()
-
- # Create multimodal message
- msg = HumanMessage(content=[
- {"type": "text", "text": "Describe this image in detail. Include all objects, people, text, colors, setting, and any other relevant information you can see."},
- {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}
- ])

- result = vision_llm_instance.invoke([msg])
- return f"Image description: {result.content}"

- except Exception as e:
- print(f"Error in describe_image_func: {e}")
- return f"Error processing image: {e}"

- @tool
- def process_youtube_video(url: str) -> str:
- """Downloads and processes a YouTube video, extracting audio and converting to text."""
  try:
- print(f"Processing YouTube video: {url}")

- # Create temporary directory
- with tempfile.TemporaryDirectory() as temp_dir:
- # Download audio from YouTube video
- ydl_opts = {
- 'format': 'bestaudio/best',
- 'outtmpl': f'{temp_dir}/%(title)s.%(ext)s',
- 'postprocessors': [{
- 'key': 'FFmpegExtractAudio',
- 'preferredcodec': 'wav',
- }],
- }
-
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
- info = ydl.extract_info(url, download=True)
- title = info.get('title', 'Unknown')
-
- # Find the downloaded audio file
- audio_files = list(Path(temp_dir).glob("*.wav"))
- if not audio_files:
- return "Error: Could not download audio from YouTube video"
-
- audio_file = audio_files[0]
-
- # Convert audio to text using speech recognition
- r = sr.Recognizer()
-
- # Load audio file
- audio = AudioSegment.from_wav(str(audio_file))
-
- # Convert to mono and set sample rate
- audio = audio.set_channels(1)
- audio = audio.set_frame_rate(16000)
-
- # Convert to smaller chunks for processing (30 seconds each)
- chunk_length_ms = 30000
- chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
-
- transcript_parts = []
- for i, chunk in enumerate(chunks[:10]): # Limit to first 5 minutes
- chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
- chunk.export(chunk_file, format="wav")
-
- try:
- with sr.AudioFile(str(chunk_file)) as source:
- audio_data = r.record(source)
- text = r.recognize_google(audio_data)
- transcript_parts.append(text)
- except sr.UnknownValueError:
- transcript_parts.append("[Unintelligible audio]")
- except sr.RequestError as e:
- transcript_parts.append(f"[Speech recognition error: {e}]")
-
- transcript = " ".join(transcript_parts)
-
- return f"YouTube Video: {title}\n\nTranscript (first 5 minutes):\n{transcript}"
-
  except Exception as e:
- print(f"Error processing YouTube video: {e}")
- return f"Error processing YouTube video: {e}"

- @tool
- def process_audio_file(file_url: str) -> str:
- """Downloads and processes an audio file (MP3, WAV, etc.) and converts to text."""
  try:
- print(f"Processing audio file: {file_url}")

- with tempfile.TemporaryDirectory() as temp_dir:
- # Download audio file
- response = requests.get(file_url, timeout=30)
- response.raise_for_status()
-
- # Determine file extension from URL or content type
- if file_url.lower().endswith('.mp3'):
- ext = 'mp3'
- elif file_url.lower().endswith('.wav'):
- ext = 'wav'
- else:
- content_type = response.headers.get('content-type', '')
- if 'mp3' in content_type:
- ext = 'mp3'
- elif 'wav' in content_type:
- ext = 'wav'
- else:
- ext = 'mp3' # Default assumption
-
- audio_file = Path(temp_dir) / f"audio.{ext}"
- with open(audio_file, 'wb') as f:
- f.write(response.content)
-
- # Convert to WAV if necessary
- if ext != 'wav':
- audio = AudioSegment.from_file(str(audio_file))
- wav_file = Path(temp_dir) / "audio.wav"
- audio.export(wav_file, format="wav")
- audio_file = wav_file
-
- # Convert audio to text
- r = sr.Recognizer()
-
- # Load and process audio
- audio = AudioSegment.from_wav(str(audio_file))
- audio = audio.set_channels(1).set_frame_rate(16000)
-
- # Process in chunks
- chunk_length_ms = 30000
- chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
-
- transcript_parts = []
- for i, chunk in enumerate(chunks[:20]): # Limit to first 10 minutes
- chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
- chunk.export(chunk_file, format="wav")
-
- try:
- with sr.AudioFile(str(chunk_file)) as source:
- audio_data = r.record(source)
- text = r.recognize_google(audio_data)
- transcript_parts.append(text)
- except sr.UnknownValueError:
- transcript_parts.append("[Unintelligible audio]")
- except sr.RequestError as e:
- transcript_parts.append(f"[Speech recognition error: {e}]")
-
- transcript = " ".join(transcript_parts)
- return f"Audio file transcript:\n{transcript}"
-
  except Exception as e:
- print(f"Error processing audio file: {e}")
- return f"Error processing audio file: {e}"

- def web_search_func(query: str, cache_func) -> str:
- """Performs a web search using Tavily and returns a compilation of results."""
- key = f"web:{query}"
- results = cache_func(key, lambda: TavilySearchResults(max_results=5).invoke(query))
- return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])

- def wiki_search_func(query: str, cache_func) -> str:
- """Searches Wikipedia and returns the top 2 results."""
- key = f"wiki:{query}"
- docs = cache_func(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
- return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])

- def arxiv_search_func(query: str, cache_func) -> str:
- """Searches Arxiv for scientific papers and returns the top 2 results."""
- key = f"arxiv:{query}"
- docs = cache_func(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
- return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])

  # ----------------------------------------------------------
- # Section 3: DYNAMIC SYSTEM PROMPT
  # ----------------------------------------------------------
- SYSTEM_PROMPT_TEMPLATE = (
- """You are an expert-level multimodal research assistant. Your goal is to answer the user's question accurately using all available tools.

- **CRITICAL INSTRUCTIONS:**
- 1. **USE YOUR TOOLS:** You have been given a set of tools to find information. You MUST use them when the answer is not immediately known to you. Do not make up answers.
- 2. **MULTIMODAL PROCESSING:** When you encounter URLs or attachments:
- - For image URLs (jpg, png, gif, etc.): Use the `describe_image` tool
- - For YouTube URLs: Use the `process_youtube_video` tool
- - For audio files (mp3, wav, etc.): Use the `process_audio_file` tool
- - For other content: Use appropriate search tools
- 3. **AVAILABLE TOOLS:** Here is the exact list of tools you have access to:
- {tools}
- 4. **REASONING:** Think step-by-step. First, analyze the user's question and any attachments. Second, decide which tools are appropriate. Third, call the tools with correct parameters. Finally, synthesize the results.
- 5. **URL DETECTION:** Look for URLs in the user's message, especially in brackets like [Attachment URL: ...]. Process these appropriately.
- 6. **FINAL ANSWER FORMAT:** Your final response MUST strictly follow this format:
- `FINAL ANSWER: [Your comprehensive answer incorporating all tool results]`
  """
- )

  # ----------------------------------------------------------
  # Section 4: Factory Function for Agent Executor
@@ -281,80 +249,113 @@ def create_agent_executor(provider: str = "groq"):
  """
  print(f"Initializing agent with provider: {provider}")

- # Step 1: Build LLMs - Use Google for vision capabilities
- if provider == "google":
- main_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
- vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
- elif provider == "groq":
- main_llm = ChatGroq(model_name="llama-3.2-90b-vision-preview", temperature=0)
- # Use Google for vision since Groq's vision support may be limited
- vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
- elif provider == "huggingface":
- main_llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1))
- vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
- else:
- raise ValueError("Invalid provider selected")

- # Step 2: Build Retriever
  embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
  if FAISS_CACHE.exists():
- with open(FAISS_CACHE, "rb") as f: vector_store = pickle.load(f)
  else:
  if JSONL_PATH.exists():
- docs = [Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}", metadata={"source": rec["task_id"]}) for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))]
  vector_store = FAISS.from_documents(docs, embeddings)
- with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
  else:
- # Create empty vector store if no metadata file exists
  docs = [Document(page_content="Sample document", metadata={"source": "sample"})]
  vector_store = FAISS.from_documents(docs, embeddings)

  retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

- # Step 3: Create the final list of tools
  tools_list = [
  python_repl,
- Tool(name="describe_image", func=functools.partial(describe_image_func, vision_llm_instance=vision_llm), description="Describes an image from a local file path or a URL. Use this for any image files or image URLs."),
- process_youtube_video,
- process_audio_file,
- Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
- Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
- Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
- create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
  ]

- # Step 4: Format the tool list into a string for the prompt
- tool_definitions = "\n".join([f"- `{tool.name}`: {tool.description}" for tool in tools_list])
- final_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(tools=tool_definitions)
-
- llm_with_tools = main_llm.bind_tools(tools_list)
-
- # Step 5: Define Graph Nodes
- def retriever_node(state: MessagesState):
- user_query = state["messages"][-1].content
- docs = retriever.invoke(user_query)
- messages = [SystemMessage(content=final_system_prompt)]
- if docs:
- example_text = "\n\n---\n\n".join(d.page_content for d in docs)
- messages.append(AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever"))
- messages.extend(state["messages"])
- return {"messages": messages}

  def assistant_node(state: MessagesState):
- result = llm_with_tools.invoke(state["messages"])
- return {"messages": [result]}

- # Step 6: Build Graph
  builder = StateGraph(MessagesState)
- builder.add_node("retriever", retriever_node)
  builder.add_node("assistant", assistant_node)
- builder.add_node("tools", ToolNode(tools_list))

- builder.add_edge(START, "retriever")
- builder.add_edge("retriever", "assistant")
- builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
  builder.add_edge("tools", "assistant")

  agent_executor = builder.compile()
- print("Agent Executor created successfully.")
  return agent_executor
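For context, a minimal caller-side sketch (illustrative only, not part of the committed file): with the describe_image, YouTube, and audio tools removed, an image question is expected to reach the model as a standard LangChain multimodal message. The question text and URL below are placeholders, and this assumes the Groq vision model accepts image_url content blocks.

    from langchain_core.messages import HumanMessage

    # Hypothetical payload: the image URL goes straight to the vision LLM
    # instead of through a separate describe_image tool.
    question = HumanMessage(content=[
        {"type": "text", "text": "What text appears on the sign in this photo?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/sign.jpg"}},  # placeholder URL
    ])
    payload = {"messages": [question]}  # what the compiled graph's invoke() would receive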
 
  agent.py

  This file defines the core logic for a sophisticated AI agent using LangGraph.
+ This version uses Groq's vision-capable models and includes proper reasoning steps.
  """

  # ----------------------------------------------------------

  import re
  import subprocess
  import textwrap
  import functools
  from pathlib import Path
+ from typing import Dict, Any

  import requests
  from cachetools import TTLCache

  from langchain.schema import Document
  from langchain.tools.retriever import create_retriever_tool

  from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_core.tools import Tool, tool
  from langchain_groq import ChatGroq
+ from langchain_huggingface import HuggingFaceEmbeddings
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import ToolNode, tools_condition

  JSONL_PATH, FAISS_CACHE, EMBED_MODEL = Path("metadata.jsonl"), Path("faiss_index.pkl"), "sentence-transformers/all-mpnet-base-v2"
  RETRIEVER_K, CACHE_TTL = 5, 600
  API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
+
  def cached_get(key: str, fetch_fn):
+ if key in API_CACHE:
+ return API_CACHE[key]
  val = fetch_fn()
  API_CACHE[key] = val
  return val

  # ----------------------------------------------------------
+ # Section 2: Tool Functions
  # ----------------------------------------------------------
  @tool
  def python_repl(code: str) -> str:
  """Executes a string of Python code and returns the stdout/stderr."""
  code = textwrap.dedent(code).strip()
  try:
+ result = subprocess.run(
+ ["python", "-c", code],
+ capture_output=True,
+ text=True,
+ timeout=10,
+ check=False
+ )
+ if result.returncode == 0:
+ return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
+ else:
+ return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
+ except subprocess.TimeoutExpired:
+ return "Execution timed out (>10s)."

+ def web_search_func(query: str, cache_func) -> str:
+ """Performs a web search using Tavily and returns a compilation of results."""
+ if not query or not query.strip():
+ return "Error: Empty search query"
+
+ key = f"web:{query}"
  try:
+ results = cache_func(key, lambda: TavilySearchResults(max_results=5).invoke(query))
+ if not results:
+ return "No search results found"

+ formatted_results = []
+ for res in results:
+ if isinstance(res, dict) and 'url' in res and 'content' in res:
+ formatted_results.append(f"Source: {res['url']}\nContent: {res['content']}")

+ return "\n\n---\n\n".join(formatted_results) if formatted_results else "No valid results found"
+ except Exception as e:
+ return f"Search error: {e}"

+ def wiki_search_func(query: str, cache_func) -> str:
+ """Searches Wikipedia and returns the top 2 results."""
+ if not query or not query.strip():
+ return "Error: Empty search query"
+
+ key = f"wiki:{query}"
  try:
+ docs = cache_func(key, lambda: WikipediaLoader(
+ query=query,
+ load_max_docs=2,
+ doc_content_chars_max=2000
+ ).load())
+
+ if not docs:
+ return "No Wikipedia articles found"

+ return "\n\n---\n\n".join([
+ f"Source: {d.metadata.get('source', 'Unknown')}\n\n{d.page_content}"
+ for d in docs
+ ])
  except Exception as e:
+ return f"Wikipedia search error: {e}"

+ def arxiv_search_func(query: str, cache_func) -> str:
+ """Searches Arxiv for scientific papers and returns the top 2 results."""
+ if not query or not query.strip():
+ return "Error: Empty search query"
+
+ key = f"arxiv:{query}"
  try:
+ docs = cache_func(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
+
+ if not docs:
+ return "No Arxiv papers found"

+ return "\n\n---\n\n".join([
+ f"Source: {d.metadata.get('source', 'Unknown')}\n"
+ f"Published: {d.metadata.get('Published', 'Unknown')}\n"
+ f"Title: {d.metadata.get('Title', 'Unknown')}\n\n"
+ f"Summary:\n{d.page_content}"
+ for d in docs
+ ])
  except Exception as e:
+ return f"Arxiv search error: {e}"

+ @tool
+ def analyze_task_and_reason(task_description: str) -> str:
+ """
+ Analyzes the task and provides reasoning about what approach to take.
+ This tool helps determine what other tools might be needed.
+ """
+ analysis = {
+ "task_type": "unknown",
+ "has_image": False,
+ "needs_search": False,
+ "needs_computation": False,
+ "approach": "Direct answer"
+ }
+
+ task_lower = task_description.lower()
+
+ # Check for image-related content
+ if any(keyword in task_lower for keyword in [
+ 'image', 'picture', 'photo', 'visual', 'see in', 'shown in',
+ 'attachment analysis', 'url:', 'http', '.jpg', '.png', '.gif'
+ ]):
+ analysis["has_image"] = True
+ analysis["task_type"] = "image_analysis"
+ analysis["approach"] = "Process image with vision model, then analyze content"
+
+ # Check for search needs
+ if any(keyword in task_lower for keyword in [
+ 'current', 'recent', 'latest', 'news', 'today', 'what is',
+ 'who is', 'when did', 'research', 'find information'
+ ]):
+ analysis["needs_search"] = True
+ if analysis["task_type"] == "unknown":
+ analysis["task_type"] = "information_search"
+ analysis["approach"] = "Search for current information"
+
+ # Check for computation needs
+ if any(keyword in task_lower for keyword in [
+ 'calculate', 'compute', 'math', 'formula', 'equation',
+ 'algorithm', 'code', 'program', 'python'
+ ]):
+ analysis["needs_computation"] = True
+ if analysis["task_type"] == "unknown":
+ analysis["task_type"] = "computation"
+ analysis["approach"] = "Use Python for calculations"
+
+ reasoning = f"""TASK ANALYSIS COMPLETE:

+ Task Type: {analysis['task_type']}
+ Has Image: {analysis['has_image']}
+ Needs Search: {analysis['needs_search']}
+ Needs Computation: {analysis['needs_computation']}

+ RECOMMENDED APPROACH: {analysis['approach']}
+
+ REASONING:
+ - If this involves an image, I should process it directly with my vision capabilities
+ - If this needs current information, I should use web search or Wikipedia
+ - If this needs calculations, I should use the Python tool
+ - I should always provide a comprehensive final answer
+
+ NEXT STEPS: Proceed with the identified approach and use appropriate tools."""
+
+ return reasoning

  # ----------------------------------------------------------
+ # Section 3: SYSTEM PROMPT
  # ----------------------------------------------------------
+ SYSTEM_PROMPT = """You are an expert multimodal AI assistant with vision capabilities and access to various tools.
+
+ **CORE CAPABILITIES:**
+ 1. **Vision Processing**: You can directly process and analyze images from URLs
+ 2. **Web Search**: Access current information via web search and Wikipedia
+ 3. **Computation**: Execute Python code for calculations and data processing
+ 4. **Research**: Search academic papers and retrieve similar examples
+
+ **CRITICAL WORKFLOW:**
+ 1. **ANALYZE FIRST**: Always start by using the 'analyze_task_and_reason' tool to understand what you're being asked to do
+ 2. **PROCESS IMAGES DIRECTLY**: When you encounter image URLs, process them directly with your vision model - DO NOT use separate image tools
+ 3. **USE TOOLS STRATEGICALLY**: Based on your analysis, use appropriate tools (web search, Python, etc.)
+ 4. **VALIDATE PARAMETERS**: Always check that you're passing correct parameters to tools
+ 5. **SYNTHESIZE**: Combine all information into a comprehensive answer
+
+ **IMAGE HANDLING:**
+ - You have native vision capabilities - process image URLs directly
+ - Look for image URLs in the task description
+ - When you see an image URL, examine it carefully and describe what you see
+ - Relate your visual observations to the question being asked

+ **TOOL USAGE RULES:**
+ - Always use 'analyze_task_and_reason' first to plan your approach
+ - Use web_search for current events, factual information, or research
+ - Use python_repl for calculations, data processing, or code execution
+ - Use wiki_search for encyclopedic information
+ - Use arxiv_search for academic/scientific papers
+ - Use retrieve_examples for similar solved problems
+
+ **OUTPUT FORMAT:**
+ Always end your response with: FINAL ANSWER: [Your comprehensive answer]
+
+ **PARAMETER VALIDATION:**
+ - Check that search queries are meaningful and specific
+ - Ensure Python code is safe and well-formed
+ - Verify image URLs are accessible before processing
  """

  # ----------------------------------------------------------
  # Section 4: Factory Function for Agent Executor

  """
  print(f"Initializing agent with provider: {provider}")

+ # Step 1: Initialize LLM with vision capabilities
+ if provider == "groq":
+ # Use Groq's vision-capable model
+ try:
+ llm = ChatGroq(
+ model_name="llama-3.2-90b-vision-preview", # Vision-capable model
+ temperature=0.1,
+ max_tokens=4000
+ )
+ print("Initialized Groq LLM with vision capabilities")
+ except Exception as e:
+ print(f"Error initializing Groq: {e}")
+ raise
+ else:
+ raise ValueError(f"Provider '{provider}' not supported in this version")

+ # Step 2: Build Retriever (if metadata exists)
  embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
  if FAISS_CACHE.exists():
+ with open(FAISS_CACHE, "rb") as f:
+ vector_store = pickle.load(f)
+ print("Loaded existing FAISS index")
  else:
  if JSONL_PATH.exists():
+ docs = []
+ with open(JSONL_PATH, "rt", encoding="utf-8") as f:
+ for line in f:
+ rec = json.loads(line)
+ docs.append(Document(
+ page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
+ metadata={"source": rec["task_id"]}
+ ))
  vector_store = FAISS.from_documents(docs, embeddings)
+ with open(FAISS_CACHE, "wb") as f:
+ pickle.dump(vector_store, f)
+ print(f"Created new FAISS index with {len(docs)} documents")
  else:
+ # Create minimal vector store
  docs = [Document(page_content="Sample document", metadata={"source": "sample"})]
  vector_store = FAISS.from_documents(docs, embeddings)
+ print("Created minimal FAISS index")

  retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

+ # Step 3: Create tools list
  tools_list = [
+ analyze_task_and_reason,
+ Tool(
+ name="web_search",
+ func=functools.partial(web_search_func, cache_func=cached_get),
+ description="Search the web for current information. Use specific, focused queries."
+ ),
+ Tool(
+ name="wiki_search",
+ func=functools.partial(wiki_search_func, cache_func=cached_get),
+ description="Search Wikipedia for encyclopedic information."
+ ),
+ Tool(
+ name="arxiv_search",
+ func=functools.partial(arxiv_search_func, cache_func=cached_get),
+ description="Search Arxiv for academic papers and research."
+ ),
  python_repl,
+ create_retriever_tool(
+ retriever=retriever,
+ name="retrieve_examples",
+ description="Retrieve similar solved examples from the knowledge base."
+ ),
  ]

+ llm_with_tools = llm.bind_tools(tools_list)

+ # Step 4: Define Graph Nodes
  def assistant_node(state: MessagesState):
+ """Main assistant node that processes user input and tool responses."""
+ messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"]
+ try:
+ result = llm_with_tools.invoke(messages)
+ return {"messages": [result]}
+ except Exception as e:
+ error_msg = f"LLM Error: {e}"
+ print(error_msg)
+ return {"messages": [AIMessage(content=f"I encountered an error: {error_msg}")]}
+
+ def tools_node_wrapper(state: MessagesState):
+ """Wrapper for tool execution with error handling."""
+ try:
+ tool_node = ToolNode(tools_list)
+ return tool_node.invoke(state)
+ except Exception as e:
+ error_msg = f"Tool execution error: {e}"
+ print(error_msg)
+ return {"messages": [AIMessage(content=error_msg)]}

+ # Step 5: Build Graph
  builder = StateGraph(MessagesState)
  builder.add_node("assistant", assistant_node)
+ builder.add_node("tools", tools_node_wrapper)

+ builder.add_edge(START, "assistant")
+ builder.add_conditional_edges(
+ "assistant",
+ tools_condition,
+ {"tools": "tools", "__end__": "__end__"}
+ )
  builder.add_edge("tools", "assistant")

  agent_executor = builder.compile()
+ print("Agent Executor created successfully with vision capabilities")
  return agent_executor
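A minimal usage sketch for reviewers (illustrative only, not part of the diff): it assumes the module is importable as agent and that the required API keys (e.g. GROQ_API_KEY, TAVILY_API_KEY) are set in the environment.

    from langchain_core.messages import HumanMessage
    from agent import create_agent_executor

    # Build the graph once, then drive it with a MessagesState-style payload.
    agent_executor = create_agent_executor(provider="groq")
    result = agent_executor.invoke({"messages": [HumanMessage(content="What is 17 * 23?")]})

    # The last message carries the reply, which should end with "FINAL ANSWER: ...".
    print(result["messages"][-1].content)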