Abbasid committed
Commit bd93d23 · verified · 1 Parent(s): f9b5dc1

Update agent.py

Files changed (1)
  1. agent.py +212 -45
agent.py CHANGED
@@ -2,12 +2,11 @@
 agent.py
 
 This file defines the core logic for a sophisticated AI agent using LangGraph.
- This version uses a dynamic system prompt to explicitly list the available tools
- for the LLM on every run, designed to combat "tool refusal".
 """
 
 # ----------------------------------------------------------
- # Section 0: Imports and Configuration (Identical to before)
 # ----------------------------------------------------------
 import json
 import os
@@ -19,6 +18,10 @@ import base64
 import functools
 from io import BytesIO
 from pathlib import Path
 
 import requests
 from cachetools import TTLCache
@@ -40,7 +43,7 @@ from langgraph.prebuilt import ToolNode, tools_condition
 from dotenv import load_dotenv
 load_dotenv()
 
- # --- Configuration and Caching (Identical) ---
 JSONL_PATH, FAISS_CACHE, EMBED_MODEL = Path("metadata.jsonl"), Path("faiss_index.pkl"), "sentence-transformers/all-mpnet-base-v2"
 RETRIEVER_K, CACHE_TTL = 5, 600
 API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
@@ -51,7 +54,7 @@ def cached_get(key: str, fetch_fn):
     return val
 
 # ----------------------------------------------------------
- # Section 2: Standalone Tool Functions (Identical to before)
 # ----------------------------------------------------------
 @tool
 def python_repl(code: str) -> str:
@@ -66,14 +69,168 @@ def python_repl(code: str) -> str:
 def describe_image_func(image_source: str, vision_llm_instance) -> str:
     """Describes an image from a local file path or a URL using a provided vision LLM."""
     try:
-         if image_source.startswith("http"): img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
-         else: img = Image.open(image_source)
         buffered = BytesIO()
         img.convert("RGB").save(buffered, format="JPEG")
         b64_string = base64.b64encode(buffered.getvalue()).decode()
-         msg = HumanMessage(content=[{"type": "text", "text": "Describe this image in detail."}, {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}])
-         return vision_llm_instance.invoke([msg]).content
-     except Exception as e: return f"Error processing image: {e}"
 
 def web_search_func(query: str, cache_func) -> str:
     """Performs a web search using Tavily and returns a compilation of results."""
@@ -94,75 +251,88 @@ def arxiv_search_func(query: str, cache_func) -> str:
     return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])
 
 # ----------------------------------------------------------
- # Section 3: NEW DYNAMIC SYSTEM PROMPT
 # ----------------------------------------------------------
- # This is now a template string. The {tools} section will be filled in dynamically.
 SYSTEM_PROMPT_TEMPLATE = (
-     """You are an expert-level research assistant. Your goal is to answer the user's question accurately.
 
 **CRITICAL INSTRUCTIONS:**
- 1. **USE YOUR TOOLS:** You have been given a set of tools to find information. You MUST use them when the answer is not immediately known to you. Do not make up answers. Do not apologize or refuse to use a tool. You must try.
- 2. **AVAILABLE TOOLS:** Here is the exact list of tools you have access to:
 {tools}
- 3. **REASONING:** Think step-by-step. First, analyze the user's question. Second, decide which tool is appropriate. Third, call the tool with the correct parameters. Finally, analyze the tool's output to formulate your answer.
- 4. **LIMITATIONS:** If a question requires a capability you absolutely do not have (e.g., watching a video, listening to audio), you must state that limitation clearly.
- 5. **FINAL ANSWER FORMAT:** Your final response MUST strictly follow this format and nothing else:
- `FINAL ANSWER: [Your concise and accurate answer here]`
 """
 )
 
 # ----------------------------------------------------------
- # Section 4: Factory Function for Agent Executor (MODIFIED)
 # ----------------------------------------------------------
 def create_agent_executor(provider: str = "groq"):
     """
     Factory function to create and compile the LangGraph agent executor.
-     This version dynamically builds the system prompt with the list of tools.
     """
     print(f"Initializing agent with provider: {provider}")
 
-     # Step 1: Build LLMs (Identical)
-     if provider == "google": main_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
-     elif provider == "groq": main_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
-     elif provider == "huggingface": main_llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1))
-     else: raise ValueError("Invalid provider selected")
-     vision_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
 
-     # Step 2: Build Retriever (Identical)
     embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
     if FAISS_CACHE.exists():
         with open(FAISS_CACHE, "rb") as f: vector_store = pickle.load(f)
     else:
-         docs = [Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}", metadata={"source": rec["task_id"]}) for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))]
-         vector_store = FAISS.from_documents(docs, embeddings)
-         with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
     retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
 
-     # Step 3: Create the final list of tools (Identical)
     tools_list = [
         python_repl,
-         Tool(name="describe_image", func=functools.partial(describe_image_func, vision_llm_instance=vision_llm), description="Describes an image from a local file path or a URL."),
         Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
         Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
         Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
         create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
     ]
 
-     # --- THIS PART IS NEW ---
-     # 4a. Format the tool list into a string for the prompt
     tool_definitions = "\n".join([f"- `{tool.name}`: {tool.description}" for tool in tools_list])
-
-     # 4b. Create the final, dynamic system prompt
     final_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(tools=tool_definitions)
-     # --- END NEW PART ---
 
     llm_with_tools = main_llm.bind_tools(tools_list)
 
-     # Step 5: Define Graph Nodes (Modified to use the new prompt)
     def retriever_node(state: MessagesState):
         user_query = state["messages"][-1].content
         docs = retriever.invoke(user_query)
-         # Use the new, dynamic prompt here
         messages = [SystemMessage(content=final_system_prompt)]
         if docs:
             example_text = "\n\n---\n\n".join(d.page_content for d in docs)
@@ -174,7 +344,7 @@ def create_agent_executor(provider: str = "groq"):
         result = llm_with_tools.invoke(state["messages"])
         return {"messages": [result]}
 
-     # Step 6: Build Graph (Identical)
     builder = StateGraph(MessagesState)
     builder.add_node("retriever", retriever_node)
     builder.add_node("assistant", assistant_node)
@@ -187,7 +357,4 @@ def create_agent_executor(provider: str = "groq"):
 
     agent_executor = builder.compile()
     print("Agent Executor created successfully.")
-     return agent_executor
-
- # --- Section 5 (Testing functions) remains the same ---
- # ... (test_llm_connection and __main__ block)
 
 agent.py
 
 This file defines the core logic for a sophisticated AI agent using LangGraph.
+ This version includes proper multimodal support for images, YouTube videos, and audio files.
 """
 
 # ----------------------------------------------------------
+ # Section 0: Imports and Configuration
 # ----------------------------------------------------------
 import json
 import os
 
 import functools
 from io import BytesIO
 from pathlib import Path
+ import tempfile
+ import yt_dlp
+ from pydub import AudioSegment
+ import speech_recognition as sr
 
 import requests
 from cachetools import TTLCache
 
 from dotenv import load_dotenv
 load_dotenv()
 
+ # --- Configuration and Caching ---
 JSONL_PATH, FAISS_CACHE, EMBED_MODEL = Path("metadata.jsonl"), Path("faiss_index.pkl"), "sentence-transformers/all-mpnet-base-v2"
 RETRIEVER_K, CACHE_TTL = 5, 600
 API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
 
     return val
 
 # ----------------------------------------------------------
+ # Section 2: Standalone Tool Functions
 # ----------------------------------------------------------
 @tool
 def python_repl(code: str) -> str:
 
 def describe_image_func(image_source: str, vision_llm_instance) -> str:
     """Describes an image from a local file path or a URL using a provided vision LLM."""
     try:
+         print(f"Processing image: {image_source}")
+
+         # Download and process image
+         if image_source.startswith("http"):
+             response = requests.get(image_source, timeout=10)
+             response.raise_for_status()
+             img = Image.open(BytesIO(response.content))
+         else:
+             img = Image.open(image_source)
+
+         # Convert to base64
         buffered = BytesIO()
         img.convert("RGB").save(buffered, format="JPEG")
         b64_string = base64.b64encode(buffered.getvalue()).decode()
+
+         # Create multimodal message
+         msg = HumanMessage(content=[
+             {"type": "text", "text": "Describe this image in detail. Include all objects, people, text, colors, setting, and any other relevant information you can see."},
+             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_string}"}}
+         ])
+
+         result = vision_llm_instance.invoke([msg])
+         return f"Image description: {result.content}"
+
+     except Exception as e:
+         print(f"Error in describe_image_func: {e}")
+         return f"Error processing image: {e}"
+
+ @tool
+ def process_youtube_video(url: str) -> str:
+     """Downloads and processes a YouTube video, extracting audio and converting to text."""
+     try:
+         print(f"Processing YouTube video: {url}")
+
+         # Create temporary directory
+         with tempfile.TemporaryDirectory() as temp_dir:
+             # Download audio from YouTube video
+             ydl_opts = {
+                 'format': 'bestaudio/best',
+                 'outtmpl': f'{temp_dir}/%(title)s.%(ext)s',
+                 'postprocessors': [{
+                     'key': 'FFmpegExtractAudio',
+                     'preferredcodec': 'wav',
+                 }],
+             }
+
+             with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+                 info = ydl.extract_info(url, download=True)
+                 title = info.get('title', 'Unknown')
+
+             # Find the downloaded audio file
+             audio_files = list(Path(temp_dir).glob("*.wav"))
+             if not audio_files:
+                 return "Error: Could not download audio from YouTube video"
+
+             audio_file = audio_files[0]
+
+             # Convert audio to text using speech recognition
+             r = sr.Recognizer()
+
+             # Load audio file
+             audio = AudioSegment.from_wav(str(audio_file))
+
+             # Convert to mono and set sample rate
+             audio = audio.set_channels(1)
+             audio = audio.set_frame_rate(16000)
+
+             # Convert to smaller chunks for processing (30 seconds each)
+             chunk_length_ms = 30000
+             chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
+
+             transcript_parts = []
+             for i, chunk in enumerate(chunks[:10]):  # Limit to first 5 minutes
+                 chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
+                 chunk.export(chunk_file, format="wav")
+
+                 try:
+                     with sr.AudioFile(str(chunk_file)) as source:
+                         audio_data = r.record(source)
+                         text = r.recognize_google(audio_data)
+                         transcript_parts.append(text)
+                 except sr.UnknownValueError:
+                     transcript_parts.append("[Unintelligible audio]")
+                 except sr.RequestError as e:
+                     transcript_parts.append(f"[Speech recognition error: {e}]")
+
+             transcript = " ".join(transcript_parts)
+
+             return f"YouTube Video: {title}\n\nTranscript (first 5 minutes):\n{transcript}"
+
+     except Exception as e:
+         print(f"Error processing YouTube video: {e}")
+         return f"Error processing YouTube video: {e}"
+
+ @tool
+ def process_audio_file(file_url: str) -> str:
+     """Downloads and processes an audio file (MP3, WAV, etc.) and converts to text."""
+     try:
+         print(f"Processing audio file: {file_url}")
+
+         with tempfile.TemporaryDirectory() as temp_dir:
+             # Download audio file
+             response = requests.get(file_url, timeout=30)
+             response.raise_for_status()
+
+             # Determine file extension from URL or content type
+             if file_url.lower().endswith('.mp3'):
+                 ext = 'mp3'
+             elif file_url.lower().endswith('.wav'):
+                 ext = 'wav'
+             else:
+                 content_type = response.headers.get('content-type', '')
+                 if 'mp3' in content_type:
+                     ext = 'mp3'
+                 elif 'wav' in content_type:
+                     ext = 'wav'
+                 else:
+                     ext = 'mp3'  # Default assumption
+
+             audio_file = Path(temp_dir) / f"audio.{ext}"
+             with open(audio_file, 'wb') as f:
+                 f.write(response.content)
+
+             # Convert to WAV if necessary
+             if ext != 'wav':
+                 audio = AudioSegment.from_file(str(audio_file))
+                 wav_file = Path(temp_dir) / "audio.wav"
+                 audio.export(wav_file, format="wav")
+                 audio_file = wav_file
+
+             # Convert audio to text
+             r = sr.Recognizer()
+
+             # Load and process audio
+             audio = AudioSegment.from_wav(str(audio_file))
+             audio = audio.set_channels(1).set_frame_rate(16000)
+
+             # Process in chunks
+             chunk_length_ms = 30000
+             chunks = [audio[i:i + chunk_length_ms] for i in range(0, len(audio), chunk_length_ms)]
+
+             transcript_parts = []
+             for i, chunk in enumerate(chunks[:20]):  # Limit to first 10 minutes
+                 chunk_file = Path(temp_dir) / f"chunk_{i}.wav"
+                 chunk.export(chunk_file, format="wav")
+
+                 try:
+                     with sr.AudioFile(str(chunk_file)) as source:
+                         audio_data = r.record(source)
+                         text = r.recognize_google(audio_data)
+                         transcript_parts.append(text)
+                 except sr.UnknownValueError:
+                     transcript_parts.append("[Unintelligible audio]")
+                 except sr.RequestError as e:
+                     transcript_parts.append(f"[Speech recognition error: {e}]")
+
+             transcript = " ".join(transcript_parts)
+             return f"Audio file transcript:\n{transcript}"
+
+     except Exception as e:
+         print(f"Error processing audio file: {e}")
+         return f"Error processing audio file: {e}"
 
 def web_search_func(query: str, cache_func) -> str:
     """Performs a web search using Tavily and returns a compilation of results."""
 
     return "\n\n---\n\n".join([f"Source: {d.metadata['source']}\nPublished: {d.metadata['Published']}\nTitle: {d.metadata['Title']}\n\nSummary:\n{d.page_content}" for d in docs])
 
 # ----------------------------------------------------------
+ # Section 3: DYNAMIC SYSTEM PROMPT
 # ----------------------------------------------------------
 SYSTEM_PROMPT_TEMPLATE = (
+     """You are an expert-level multimodal research assistant. Your goal is to answer the user's question accurately using all available tools.
 
 **CRITICAL INSTRUCTIONS:**
+ 1. **USE YOUR TOOLS:** You have been given a set of tools to find information. You MUST use them when the answer is not immediately known to you. Do not make up answers.
+ 2. **MULTIMODAL PROCESSING:** When you encounter URLs or attachments:
+ - For image URLs (jpg, png, gif, etc.): Use the `describe_image` tool
+ - For YouTube URLs: Use the `process_youtube_video` tool
+ - For audio files (mp3, wav, etc.): Use the `process_audio_file` tool
+ - For other content: Use appropriate search tools
+ 3. **AVAILABLE TOOLS:** Here is the exact list of tools you have access to:
 {tools}
+ 4. **REASONING:** Think step-by-step. First, analyze the user's question and any attachments. Second, decide which tools are appropriate. Third, call the tools with correct parameters. Finally, synthesize the results.
+ 5. **URL DETECTION:** Look for URLs in the user's message, especially in brackets like [Attachment URL: ...]. Process these appropriately.
+ 6. **FINAL ANSWER FORMAT:** Your final response MUST strictly follow this format:
+ `FINAL ANSWER: [Your comprehensive answer incorporating all tool results]`
 """
 )
 
 # ----------------------------------------------------------
+ # Section 4: Factory Function for Agent Executor
 # ----------------------------------------------------------
 def create_agent_executor(provider: str = "groq"):
     """
     Factory function to create and compile the LangGraph agent executor.
     """
     print(f"Initializing agent with provider: {provider}")
 
+     # Step 1: Build LLMs - Use Google for vision capabilities
+     if provider == "google":
+         main_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+         vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+     elif provider == "groq":
+         main_llm = ChatGroq(model_name="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
+         # Use Google for vision since Groq's vision support may be limited
+         vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+     elif provider == "huggingface":
+         main_llm = ChatHuggingFace(llm=HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1", temperature=0.1))
+         vision_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0)
+     else:
+         raise ValueError("Invalid provider selected")
 
+     # Step 2: Build Retriever
     embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
     if FAISS_CACHE.exists():
         with open(FAISS_CACHE, "rb") as f: vector_store = pickle.load(f)
     else:
+         if JSONL_PATH.exists():
+             docs = [Document(page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}", metadata={"source": rec["task_id"]}) for rec in (json.loads(line) for line in open(JSONL_PATH, "rt", encoding="utf-8"))]
+             vector_store = FAISS.from_documents(docs, embeddings)
+             with open(FAISS_CACHE, "wb") as f: pickle.dump(vector_store, f)
+         else:
+             # Create empty vector store if no metadata file exists
+             docs = [Document(page_content="Sample document", metadata={"source": "sample"})]
+             vector_store = FAISS.from_documents(docs, embeddings)
+
     retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
 
+     # Step 3: Create the final list of tools
     tools_list = [
         python_repl,
+         Tool(name="describe_image", func=functools.partial(describe_image_func, vision_llm_instance=vision_llm), description="Describes an image from a local file path or a URL. Use this for any image files or image URLs."),
+         process_youtube_video,
+         process_audio_file,
         Tool(name="web_search", func=functools.partial(web_search_func, cache_func=cached_get), description="Performs a web search using Tavily."),
         Tool(name="wiki_search", func=functools.partial(wiki_search_func, cache_func=cached_get), description="Searches Wikipedia."),
         Tool(name="arxiv_search", func=functools.partial(arxiv_search_func, cache_func=cached_get), description="Searches Arxiv for scientific papers."),
         create_retriever_tool(retriever=retriever, name="retrieve_examples", description="Retrieve solved questions similar to the user's query."),
     ]
 
+     # Step 4: Format the tool list into a string for the prompt
     tool_definitions = "\n".join([f"- `{tool.name}`: {tool.description}" for tool in tools_list])
     final_system_prompt = SYSTEM_PROMPT_TEMPLATE.format(tools=tool_definitions)
 
     llm_with_tools = main_llm.bind_tools(tools_list)
 
+     # Step 5: Define Graph Nodes
     def retriever_node(state: MessagesState):
         user_query = state["messages"][-1].content
         docs = retriever.invoke(user_query)
         messages = [SystemMessage(content=final_system_prompt)]
         if docs:
             example_text = "\n\n---\n\n".join(d.page_content for d in docs)
 
         result = llm_with_tools.invoke(state["messages"])
         return {"messages": [result]}
 
+     # Step 6: Build Graph
     builder = StateGraph(MessagesState)
     builder.add_node("retriever", retriever_node)
     builder.add_node("assistant", assistant_node)
 
 
     agent_executor = builder.compile()
     print("Agent Executor created successfully.")
+     return agent_executor