Abbasid committed
Commit 1376719 · verified · 1 Parent(s): d59e1c2

Update agent.py

Files changed (1):
  1. agent.py +186 -127
agent.py CHANGED
@@ -2,9 +2,7 @@
agent.py

This file defines the core logic for a sophisticated AI agent using LangGraph.
- ## MODIFICATION: This version introduces a 'multimodal_router' node.
- This node intelligently inspects user input to identify, classify (using HEAD requests),
- and pre-process URLs for images, audio, and video before the main LLM reasoning step.
+ This version includes proper multimodal support for images, YouTube videos, and audio files.
"""

# ----------------------------------------------------------
@@ -38,14 +36,14 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import Tool, tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
- from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint, ChatHuggingFace
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from dotenv import load_dotenv
load_dotenv()

- # --- Configuration and Caching (remains the same) ---
+ # --- Configuration and Caching ---
JSONL_PATH, FAISS_CACHE, EMBED_MODEL = Path("metadata.jsonl"), Path("faiss_index.pkl"), "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K, CACHE_TTL = 5, 600
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
@@ -56,12 +54,11 @@ def cached_get(key: str, fetch_fn):
    return val

# ----------------------------------------------------------
- # Section 2: Standalone Tool Functions (remains the same)
+ # Section 2: Standalone Tool Functions
# ----------------------------------------------------------
@tool
def python_repl(code: str) -> str:
    """Executes a string of Python code and returns the stdout/stderr."""
-     # ... (implementation unchanged)
    code = textwrap.dedent(code).strip()
    try:
        result = subprocess.run(["python", "-c", code], capture_output=True, text=True, timeout=10, check=False)
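For readers who want to poke at this tool in isolation, a minimal sketch: `@tool` wraps the function in a LangChain structured tool, so it can be invoked directly with a dict of arguments. This assumes the file is importable as `agent` and is not something the commit itself adds.

```python
# Minimal smoke test for the python_repl tool (hypothetical usage, not in the diff).
from agent import python_repl  # assumes agent.py is on the import path

result = python_repl.invoke({"code": "print(2 + 2)"})
print(result)  # the returned string should report stdout containing "4"
```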
@@ -69,36 +66,99 @@ def python_repl(code: str) -> str:
        else: return f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```\nSTDERR:\n```\n{result.stderr}\n```"
    except subprocess.TimeoutExpired: return "Execution timed out (>10s)."

+ def describe_image_func(image_source: str, vision_llm_instance) -> str:
+     """Describes an image from a local file path or a URL using a provided vision LLM."""
+     try:
+         print(f"Processing image: {image_source}")
+
+         # Download and process image
+         if image_source.startswith("http"):
+             response = requests.get(image_source, timeout=10)
+             response.raise_for_status()
+             img = Image.open(BytesIO(response.content))
+         else:
+             img = Image.open(image_source)
+
+         # Convert to base64
+         buffered = BytesIO()
+         img.convert("RGB").save(buffered, format="JPEG")
+         b64_string = base64.b64encode(buffered.getvalue()).decode()
+
+         # Create multimodal message
+         msg = HumanMessage(content=[
+             {"type": "text", "text": "Describe this image in detail. Include all objects, people, text, colors, setting, and any other relevant information you can see."},
+             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}
+         ])
+
+         result = vision_llm_instance.invoke([msg])
+         return f"Image description: {result.content}"
+
+     except Exception as e:
+         print(f"Error in describe_image_func: {e}")
+         return f"Error processing image: {e}"

@tool
def process_youtube_video(url: str) -> str:
    """Downloads and processes a YouTube video, extracting audio and converting to text."""
-     # ... (implementation unchanged)
    try:
        print(f"Processing YouTube video: {url}")
+
+         # Create temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
+             # Download audio from YouTube video
            ydl_opts = {
-                 'format': 'bestaudio/best', 'outtmpl': f'{temp_dir}/%(title)s.%(ext)s',
-                 'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'wav'}],
+                 'format': 'bestaudio/best',
+                 'outtmpl': f'{temp_dir}/%(title)s.%(ext)s',
+                 'postprocessors': [{
+                     'key': 'FFmpegExtractAudio',
+                     'preferredcodec': 'wav',
+                 }],
            }
+
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(url, download=True)
                title = info.get('title', 'Unknown')
+
+             # Find the downloaded audio file
            audio_files = list(Path(temp_dir).glob("*.wav"))
-             if not audio_files: return "Error: Could not download audio from YouTube video"
-             r, transcript_parts = sr.Recognizer(), []
-             audio = AudioSegment.from_wav(str(audio_files[0])).set_channels(1).set_frame_rate(16000)
-             chunks = [audio[i:i + 30000] for i in range(0, len(audio), 30000)]
-             for i, chunk in enumerate(chunks[:10]):
+             if not audio_files:
+                 return "Error: Could not download audio from YouTube video"
+
+             audio_file = audio_files[0]
+
+             # Convert audio to text using speech recognition
+             r = sr.Recognizer()
+
+             # Load audio file
+             audio = AudioSegment.from_wav(str(audio_file))
+
+             # Convert to mono and set sample rate
+             audio = audio.set_channels(1)
+             audio = audio.set_frame_rate(16000)
+
+             # Convert to smaller chunks for processing (30 seconds each)
+             chunk_length_ms = 30000
+             chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
+
+             transcript_parts = []
+             for i, chunk in enumerate(chunks[:10]):  # Limit to first 5 minutes
                chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
                chunk.export(chunk_file, format="wav")
+
                try:
                    with sr.AudioFile(str(chunk_file)) as source:
-                         text = r.recognize_google(r.record(source))
+                         audio_data = r.record(source)
+                         text = r.recognize_google(audio_data)
                        transcript_parts.append(text)
-                 except (sr.UnknownValueError, sr.RequestError) as e:
-                     transcript_parts.append(f"[Speech recognition error or unintelligible audio: {e}]")
-             return f"YouTube Video: {title}\n\nTranscript (first 5 minutes):\n{' '.join(transcript_parts)}"
+                 except sr.UnknownValueError:
+                     transcript_parts.append("[Unintelligible audio]")
+                 except sr.RequestError as e:
+                     transcript_parts.append(f"[Speech recognition error: {e}]")
+
+             transcript = " ".join(transcript_parts)
+
+             return f"YouTube Video: {title}\n\nTranscript (first 5 minutes):\n{transcript}"
+
    except Exception as e:
        print(f"Error processing YouTube video: {e}")
        return f"Error processing YouTube video: {e}"
@@ -106,61 +166,110 @@ def process_youtube_video(url: str) -> str:
@tool
def process_audio_file(file_url: str) -> str:
    """Downloads and processes an audio file (MP3, WAV, etc.) and converts to text."""
-     # ... (implementation unchanged)
    try:
        print(f"Processing audio file: {file_url}")
+
        with tempfile.TemporaryDirectory() as temp_dir:
+             # Download audio file
            response = requests.get(file_url, timeout=30)
            response.raise_for_status()
-             ext = os.path.splitext(file_url)[1][1:] or 'mp3'
+
+             # Determine file extension from URL or content type
+             if file_url.lower().endswith('.mp3'):
+                 ext = 'mp3'
+             elif file_url.lower().endswith('.wav'):
+                 ext = 'wav'
+             else:
+                 content_type = response.headers.get('content-type', '')
+                 if 'mp3' in content_type:
+                     ext = 'mp3'
+                 elif 'wav' in content_type:
+                     ext = 'wav'
+                 else:
+                     ext = 'mp3'  # Default assumption
+
            audio_file = Path(temp_dir) / f"audio.{ext}"
-             with open(audio_file, 'wb') as f: f.write(response.content)
-             wav_file = Path(temp_dir) / "audio.wav"
-             AudioSegment.from_file(str(audio_file)).export(wav_file, format="wav")
-             r, transcript_parts = sr.Recognizer(), []
-             audio = AudioSegment.from_wav(str(wav_file)).set_channels(1).set_frame_rate(16000)
-             chunks = [audio[i:i + 30000] for i in range(0, len(audio), 30000)]
-             for i, chunk in enumerate(chunks[:20]):
+             with open(audio_file, 'wb') as f:
+                 f.write(response.content)
+
+             # Convert to WAV if necessary
+             if ext != 'wav':
+                 audio = AudioSegment.from_file(str(audio_file))
+                 wav_file = Path(temp_dir) / "audio.wav"
+                 audio.export(wav_file, format="wav")
+                 audio_file = wav_file
+
+             # Convert audio to text
+             r = sr.Recognizer()
+
+             # Load and process audio
+             audio = AudioSegment.from_wav(str(audio_file))
+             audio = audio.set_channels(1).set_frame_rate(16000)
+
+             # Process in chunks
+             chunk_length_ms = 30000
+             chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
+
+             transcript_parts = []
+             for i, chunk in enumerate(chunks[:20]):  # Limit to first 10 minutes
                chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
                chunk.export(chunk_file, format="wav")
+
                try:
                    with sr.AudioFile(str(chunk_file)) as source:
-                         text = r.recognize_google(r.record(source))
+                         audio_data = r.record(source)
+                         text = r.recognize_google(audio_data)
                        transcript_parts.append(text)
-                 except (sr.UnknownValueError, sr.RequestError) as e:
-                     transcript_parts.append(f"[Speech recognition error or unintelligible audio: {e}]")
-             return f"Audio file transcript:\n{' '.join(transcript_parts)}"
+                 except sr.UnknownValueError:
+                     transcript_parts.append("[Unintelligible audio]")
+                 except sr.RequestError as e:
+                     transcript_parts.append(f"[Speech recognition error: {e}]")
+
+             transcript = " ".join(transcript_parts)
+             return f"Audio file transcript:\n{transcript}"
+
    except Exception as e:
        print(f"Error processing audio file: {e}")
        return f"Error processing audio file: {e}"

-
def web_search_func(query: str, cache_func) -> str:
    """Performs a web search using Tavily and returns a compilation of results."""
-     # ... (implementation unchanged)
    key = f"web:{query}"
    results = cache_func(key, lambda: TavilySearchResults(max_results=5).invoke(query))
    return "\n\n---\n\n".join([f"Source: {res['url']}\nContent: {res['content']}" for res in results])

def wiki_search_func(query: str, cache_func) -> str:
    """Searches Wikipedia and returns the top 2 results."""
-     # ... (implementation unchanged)
    key = f"wiki:{query}"
    docs = cache_func(key, lambda: WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=2000).load())
    return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\n\n{d.page_content}" for d in docs])

def arxiv_search_func(query: str, cache_func) -> str:
    """Searches Arxiv for scientific papers and returns the top 2 results."""
-     # ... (implementation unchanged)
    key = f"arxiv:{query}"
    docs = cache_func(key, lambda: ArxivLoader(query=query, load_max_docs=2).load())
    return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])

# ----------------------------------------------------------
- # Section 3: DYNAMIC SYSTEM PROMPT (remains the same)
+ # Section 3: DYNAMIC SYSTEM PROMPT
# ----------------------------------------------------------
SYSTEM_PROMPT_TEMPLATE = (
-     """You are an expert-level multimodal research assistant...""" # Unchanged
+     """You are an expert-level multimodal research assistant. Your goal is to answer the user's question accurately using all available tools.
+
+ **CRITICAL INSTRUCTIONS:**
+ 1. **USE YOUR TOOLS:** You have been given a set of tools to find information. You MUST use them when the answer is not immediately known to you. Do not make up answers.
+ 2. **MULTIMODAL PROCESSING:** When you encounter URLs or attachments:
+    - For image URLs (jpg, png, gif, etc.): Use the `describe_image` tool
+    - For YouTube URLs: Use the `process_youtube_video` tool
+    - For audio files (mp3, wav, etc.): Use the `process_audio_file` tool
+    - For other content: Use appropriate search tools
+ 3. **AVAILABLE TOOLS:** Here is the exact list of tools you have access to:
+ {tools}
+ 4. **REASONING:** Think step-by-step. First, analyze the user's question and any attachments. Second, decide which tools are appropriate. Third, call the tools with correct parameters. Finally, synthesize the results.
+ 5. **URL DETECTION:** Look for URLs in the user's message, especially in brackets like [Attachment URL: ...]. Process these appropriately.
+ 6. **FINAL ANSWER FORMAT:** Your final response MUST strictly follow this format:
+    `FINAL ANSWER: [Your comprehensive answer incorporating all tool results]`
+ """
)

# ----------------------------------------------------------
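Because the template above pins the model's output to a `FINAL ANSWER:` line, callers will usually want to parse that line back out of the agent's last message; a small sketch (the helper name is ours, not part of the commit):

```python
# Sketch: recover the answer text mandated by the FINAL ANSWER format above.
import re

def extract_final_answer(message_text: str) -> str | None:
    """Return the text after 'FINAL ANSWER:', or None if the marker is missing."""
    match = re.search(r"FINAL ANSWER:\s*(.*)", message_text, flags=re.DOTALL)
    return match.group(1).strip() if match else None

# extract_final_answer("...\nFINAL ANSWER: 42") -> "42"
```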
@@ -172,130 +281,80 @@ def create_agent_executor(provider: str = "groq"):
    """
    print(f"Initializing agent with provider: {provider}")

-     # Step 1: Build LLM (remains the same)
-     if provider == "groq":
-         llm = ChatGroq(model_name="llama-3.1-70b-vision-preview", temperature=0)
+     # Step 1: Build LLMs - Use Google for vision capabilities
+     if provider == "google":
+         main_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+         vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+     elif provider == "groq":
+         main_llm = ChatGroq(model_name="llama-3.2-90b-vision-preview", temperature=0)
+         # Use Google for vision since Groq's vision support may be limited
+         vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+     elif provider == "huggingface":
+         main_llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1))
+         vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
    else:
-         raise ValueError(f"Provider '{provider}' not currently configured for this router.")
-
-     # Step 2: Build Retriever (remains the same, but will be called inside the router)
+         raise ValueError("Invalid provider selected")
+
+     # Step 2: Build Retriever
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    if FAISS_CACHE.exists():
        with open(FAISS_CACHE, "rb") as f: vector_store = pickle.load(f)
    else:
-         # ... logic to build vector_store from JSONL or create empty ...
-         docs = []
        if JSONL_PATH.exists():
            docs = [Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}", metadata={"source": rec["task_id"]}) for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))]
-         if not docs:
+             vector_store = FAISS.from_documents(docs, embeddings)
+             with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
+         else:
+             # Create empty vector store if no metadata file exists
            docs = [Document(page_content="Sample document", metadata={"source": "sample"})]
-         vector_store = FAISS.from_documents(docs, embeddings)
-         with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
+             vector_store = FAISS.from_documents(docs, embeddings)
+
    retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

-     # Step 3: Create the final list of tools (remains the same)
+     # Step 3: Create the final list of tools
    tools_list = [
-         python_repl, process_youtube_video, process_audio_file,
+         python_repl,
+         Tool(name="describe_image", func=functools.partial(describe_image_func, vision_llm_instance=vision_llm), description="Describes an image from a local file path or a URL. Use this for any image files or image URLs."),
+         process_youtube_video,
+         process_audio_file,
        Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
        Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
        Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
        create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
    ]

-     # Step 4: Format prompt and bind tools (remains the same)
+     # Step 4: Format the tool list into a string for the prompt
    tool_definitions = "\n".join([f"- `{tool.name}`: {tool.description}" for tool in tools_list])
    final_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(tools=tool_definitions)
-     llm_with_tools = llm.bind_tools(tools_list)

-     # Step 5: Define Graph Nodes
-
-     ## MODIFICATION: A new, powerful router node that replaces the previous pre-processing.
-     def multimodal_router(state: MessagesState):
-         """
-         Inspects the user's message, classifies URLs, and prepares the state for the LLM.
-         This node acts as a central dispatcher.
-         """
-         print("--- Entering Multimodal Router ---")
-         messages = state["messages"]
-         last_message = messages[-1]
-
-         # 1. Perform knowledge base retrieval first
-         # We consolidate this logic here from the old retriever_node
-         user_query_text = ""
-         if isinstance(last_message.content, str):
-             user_query_text = last_message.content
-         elif isinstance(last_message.content, list):  # For multimodal messages
-             user_query_text = " ".join(item['text'] for item in last_message.content if item['type'] == 'text')
+     llm_with_tools = main_llm.bind_tools(tools_list)

-         docs = retriever.invoke(user_query_text)
-         system_messages = [SystemMessage(content=final_system_prompt)]
+     # Step 5: Define Graph Nodes
+     def retriever_node(state: MessagesState):
+         user_query = state["messages"][-1].content
+         docs = retriever.invoke(user_query)
+         messages = [SystemMessage(content=final_system_prompt)]
        if docs:
            example_text = "\n\n---\n\n".join(d.page_content for d in docs)
-             system_messages.append(AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever"))
-
-         # 2. Extract and classify URLs
-         urls = re.findall(r'(https?://[^\s]+)', user_query_text)
-         image_processed = False
-
-         for url in urls:
-             try:
-                 print(f"Routing URL: {url}")
-                 # Simple classification first
-                 if "youtube.com" in url or "youtu.be" in url:
-                     system_messages.append(SystemMessage(content=f"[System Note: A YouTube URL has been detected. Use the 'process_youtube_video' tool if the user asks about it.]"))
-                     continue
-
-                 # Use a HEAD request for robust classification
-                 headers = requests.head(url, timeout=5, allow_redirects=True).headers
-                 content_type = headers.get('Content-Type', '')
-
-                 if 'image/' in content_type and not image_processed:
-                     print(f" -> Classified as Image. Processing for vision model.")
-                     response = requests.get(url, timeout=10)
-                     response.raise_for_status()
-                     img = Image.open(BytesIO(response.content))
-                     buffered = BytesIO()
-                     img.convert("RGB").save(buffered, format="JPEG")
-                     b64_string = base64.b64encode(buffered.getvalue()).decode()
-
-                     # Embed the image into the last message
-                     new_content = [
-                         {"type": "text", "text": user_query_text},
-                         {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}
-                     ]
-                     messages[-1] = HumanMessage(content=new_content)
-                     image_processed = True  # Process only the first image for now
-
-                 elif 'audio/' in content_type:
-                     print(f" -> Classified as Audio.")
-                     system_messages.append(SystemMessage(content=f"[System Note: An audio URL has been detected. Use the 'process_audio_file' tool if the user asks about it.]"))
-
-                 else:
-                     print(f" -> Classified as Web Page/Other.")
-
-             except Exception as e:
-                 print(f" -> Could not process URL {url}: {e}")
-
-         # Rebuild the final state
-         final_messages = system_messages + messages
-         return {"messages": final_messages}
+             messages.append(AIMessage(content=f"I have found {len(docs)} similar solved examples:\n\n{example_text}", name="ExampleRetriever"))
+         messages.extend(state["messages"])
+         return {"messages": messages}

    def assistant_node(state: MessagesState):
        result = llm_with_tools.invoke(state["messages"])
        return {"messages": [result]}

    # Step 6: Build Graph
-     ## MODIFICATION: The graph is now simpler and more robust.
    builder = StateGraph(MessagesState)
-     builder.add_node("multimodal_router", multimodal_router)  # The new, powerful starting node
+     builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools_list))

-     builder.add_edge(START, "multimodal_router")
-     builder.add_edge("multimodal_router", "assistant")
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition, {"tools": "tools", "__end__": "__end__"})
    builder.add_edge("tools", "assistant")

    agent_executor = builder.compile()
-     print("Agent Executor with Multimodal Router created successfully.")
+     print("Agent Executor created successfully.")
    return agent_executor
 
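A minimal end-to-end usage sketch for the rebuilt executor (not part of the diff; it assumes the relevant API keys such as GROQ_API_KEY, GOOGLE_API_KEY and TAVILY_API_KEY are available via the environment or the .env file loaded at import time, and the attachment URL is a placeholder):

```python
# Build the graph and ask a single multimodal question.
from langchain_core.messages import HumanMessage
from agent import create_agent_executor

agent = create_agent_executor(provider="groq")  # "google" and "huggingface" are also wired up
state = agent.invoke({"messages": [HumanMessage(content=(
    "What is shown in this image? [Attachment URL: https://example.com/chart.png]"
))]})
print(state["messages"][-1].content)  # expected to end with a "FINAL ANSWER: ..." line
```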