naman1102 committed
Commit 8286288 · Parent(s): 8476091

Update app.py

Files changed (1):
  app.py  +140 -48
app.py CHANGED
@@ -59,13 +59,17 @@ def image_qa(image_path: str, prompt: str) -> str:
     """Query LLaVA model for image-based QA."""
     with open(image_path, "rb") as f:
         data = {"prompt": prompt, "image": f.read()}
-        return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data)
+        headers = {"Content-Type": "application/octet-stream"}
+        return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data, headers=headers)
 
 def video_label(video_path: str, topk: int = 1) -> str:
     """Get video classification using VideoMAE."""
     with open(video_path, "rb") as f:
+        headers = {"Content-Type": "application/octet-stream"}
         preds = client.post(
-            "MCG-NJU/videomae-base-finetuned-ucf101", data=f.read()
+            "MCG-NJU/videomae-base-finetuned-ucf101",
+            data=f.read(),
+            headers=headers
         )
         preds = sorted(preds, key=lambda x: x["score"], reverse=True)[:topk]
     return preds[0]["label"]
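Note: both helpers now upload raw bytes with an explicit "application/octet-stream" Content-Type. As a rough illustration only (not part of this commit; the token and file name are placeholders), an equivalent raw-bytes request against the hosted Inference API looks like this at the HTTP level:

    import requests

    # Hedged sketch: raw-bytes POST to the hosted Inference API for a
    # video-classification model. HF_TOKEN and clip.mp4 are placeholders.
    API_URL = "https://api-inference.huggingface.co/models/MCG-NJU/videomae-base-finetuned-ucf101"
    headers = {
        "Authorization": "Bearer HF_TOKEN",
        "Content-Type": "application/octet-stream",
    }

    with open("clip.mp4", "rb") as f:
        resp = requests.post(API_URL, headers=headers, data=f.read(), timeout=60)
    resp.raise_for_status()
    print(resp.json())  # e.g. a list of {"label": ..., "score": ...} predictions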
@@ -97,6 +101,7 @@ class AgentState(TypedDict):
     task_id: Annotated[str, override]
     logs: Annotated[Dict[str, Any], merge_dicts]
     code_blocks: Annotated[List[Dict[str, str]], list.__add__]
+    attachment_data: Annotated[Dict[str, bytes], merge_dicts]  # Store downloaded file data
 
 # -------------------------
 # BasicAgent implementation
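Note: attachment_data reuses the merge_dicts reducer already attached to logs. merge_dicts itself is not shown in this diff; a minimal sketch of the kind of reducer such an Annotated state field typically expects (an assumption, not the repository's definition) is:

    from typing import Any, Dict

    def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
        # Later updates win on key collisions, mirroring dict.update semantics.
        return {**left, **right}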
@@ -139,36 +144,54 @@ class BasicAgent:
             state["code_blocks"] = code_blocks
             return state
 
-        # Check for multimodal content
-        q = state["question"].lower()
-        if "video" in q or q.endswith(".mp4"):
-            state["current_step"] = "video"
-        elif q.endswith((".jpg", ".png", ".jpeg")):
-            state["current_step"] = "image"
-        elif q.endswith((".xlsx", ".csv")):
-            state["current_step"] = "sheet"
-        else:
-            # Regular text question analysis
-            prompt = (
-                "You will receive a user question. Think step‑by‑step to decide whether external web search is required. "
-                "Respond ONLY with a valid Python dict literal in the following format and NOTHING else:\n"
-                "{\n 'needs_search': bool,\n 'search_query': str\n} \n\n"
-                f"Question: {state['question']}"
-            )
-            raw = self._call_llm(prompt)
+        # Check for file attachments in the question
+        if "file_url" in state["question"]:
             try:
-                decision = ast.literal_eval(raw)
-                state["needs_search"] = bool(decision.get("needs_search", False))
-                state["search_query"] = decision.get("search_query", state["question"])
-            except Exception:
-                state["needs_search"] = True
-                state["search_query"] = state["question"]
-                decision = {"parse_error": raw}
-            state["logs"] = {
-                "analyze": {"prompt": prompt, "llm_response": raw, "decision": decision}
-            }
-            state["current_step"] = "search" if state["needs_search"] else "answer"
-            state["history"].append({"step": "analyze", "output": decision})
+                # Parse the question to get file URL
+                question_data = json.loads(state["question"])
+                file_url = question_data.get("file_url")
+                if file_url:
+                    # Download the file
+                    file_data = self._download_file(file_url)
+                    # Store in state
+                    state["attachment_data"] = {
+                        "content": file_data,
+                        "type": self._detect_file_type(file_data, file_url)
+                    }
+                    # Set appropriate step based on file type
+                    if state["attachment_data"]["type"] == "video":
+                        state["current_step"] = "video"
+                    elif state["attachment_data"]["type"] == "image":
+                        state["current_step"] = "image"
+                    elif state["attachment_data"]["type"] in ["excel", "csv"]:
+                        state["current_step"] = "sheet"
+                    return state
+            except Exception as e:
+                state["logs"]["file_download_error"] = str(e)
+                state["current_step"] = "answer"
+                return state
+
+        # Regular text question analysis
+        prompt = (
+            "You will receive a user question. Think step‑by‑step to decide whether external web search is required. "
+            "Respond ONLY with a valid Python dict literal in the following format and NOTHING else:\n"
+            "{\n 'needs_search': bool,\n 'search_query': str\n} \n\n"
+            f"Question: {state['question']}"
+        )
+        raw = self._call_llm(prompt)
+        try:
+            decision = ast.literal_eval(raw)
+            state["needs_search"] = bool(decision.get("needs_search", False))
+            state["search_query"] = decision.get("search_query", state["question"])
+        except Exception:
+            state["needs_search"] = True
+            state["search_query"] = state["question"]
+            decision = {"parse_error": raw}
+        state["logs"] = {
+            "analyze": {"prompt": prompt, "llm_response": raw, "decision": decision}
+        }
+        state["current_step"] = "search" if state["needs_search"] else "answer"
+        state["history"].append({"step": "analyze", "output": decision})
         return state
 
     def _extract_code_blocks(self, text: str) -> List[Dict[str, str]]:
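Note: the new attachment branch only fires when state["question"] contains the substring "file_url", and it then assumes the question is a JSON document rather than a bare string, so that json.loads can succeed. A hypothetical payload that would take this path (field names other than "file_url" are illustrative):

    import json

    # Hypothetical task payload; only "file_url" is relied on by the branch above.
    question = json.dumps({
        "question": "What sport is shown in the attached clip?",
        "file_url": "https://example.com/files/clip.mp4",
    })

    data = json.loads(question)
    print(data.get("file_url"))  # -> https://example.com/files/clip.mp4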
@@ -218,11 +241,54 @@ class BasicAgent:
 
         return state
 
+    def _detect_file_type(self, data: bytes, url: str) -> str:
+        """Detect file type from content and URL."""
+        # Check URL extension first
+        url_lower = url.lower()
+        if url_lower.endswith((".mp4", ".avi", ".mov")):
+            return "video"
+        elif url_lower.endswith((".jpg", ".jpeg", ".png", ".gif")):
+            return "image"
+        elif url_lower.endswith(".xlsx"):
+            return "excel"
+        elif url_lower.endswith(".csv"):
+            return "csv"
+
+        # If URL check fails, try content-based detection
+        try:
+            # Try to detect image
+            from PIL import Image
+            Image.open(io.BytesIO(data))
+            return "image"
+        except:
+            pass
+
+        try:
+            # Try to detect Excel
+            pd.read_excel(io.BytesIO(data))
+            return "excel"
+        except:
+            pass
+
+        try:
+            # Try to detect CSV
+            pd.read_csv(io.BytesIO(data))
+            return "csv"
+        except:
+            pass
+
+        return "unknown"
+
     def _image_node(self, state: AgentState) -> AgentState:
         """Handle image-based questions."""
         try:
-            answer = image_qa(state["question"], "What is shown in this image?")
-            state["history"].append({"step": "image", "output": answer})
+            if "attachment_data" in state and "content" in state["attachment_data"]:
+                # Use the downloaded image data
+                image_data = state["attachment_data"]["content"]
+                answer = image_qa(image_data, "What is shown in this image?")
+                state["history"].append({"step": "image", "output": answer})
+            else:
+                raise ValueError("No image data found in state")
             state["current_step"] = "answer"
         except Exception as e:
             state["logs"]["image_error"] = str(e)
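Note: _detect_file_type trusts the URL extension first and only falls back to trying PIL and pandas on the raw bytes. A quick sketch of the extension branch (illustrative only; empty bytes are enough because that branch never inspects the payload, and passing None for self works because the method body does not use it):

    print(BasicAgent._detect_file_type(None, b"", "https://example.com/clip.MP4"))    # -> "video" (lower() makes the check case-insensitive)
    print(BasicAgent._detect_file_type(None, b"", "https://example.com/table.xlsx"))  # -> "excel"
    print(BasicAgent._detect_file_type(None, b"", "https://example.com/data.csv"))    # -> "csv"
    print(BasicAgent._detect_file_type(None, b"", "https://example.com/blob"))        # -> "unknown" (content sniffing fails on empty bytes)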
@@ -232,8 +298,13 @@ class BasicAgent:
     def _video_node(self, state: AgentState) -> AgentState:
         """Handle video-based questions."""
         try:
-            label = video_label(state["question"])
-            state["history"].append({"step": "video", "output": label})
+            if "attachment_data" in state and "content" in state["attachment_data"]:
+                # Use the downloaded video data
+                video_data = state["attachment_data"]["content"]
+                label = video_label(video_data)
+                state["history"].append({"step": "video", "output": label})
+            else:
+                raise ValueError("No video data found in state")
             state["current_step"] = "answer"
         except Exception as e:
             state["logs"]["video_error"] = str(e)
@@ -243,9 +314,13 @@ class BasicAgent:
     def _sheet_node(self, state: AgentState) -> AgentState:
         """Handle spreadsheet-based questions."""
         try:
-            with open(state["question"], "rb") as f:
-                answer = sheet_answer(f.read(), state["question"])
-            state["history"].append({"step": "sheet", "output": answer})
+            if "attachment_data" in state and "content" in state["attachment_data"]:
+                # Use the downloaded spreadsheet data
+                sheet_data = state["attachment_data"]["content"]
+                answer = sheet_answer(sheet_data, state["question"])
+                state["history"].append({"step": "sheet", "output": answer})
+            else:
+                raise ValueError("No spreadsheet data found in state")
             state["current_step"] = "answer"
         except Exception as e:
             state["logs"]["sheet_error"] = str(e)
@@ -299,16 +374,21 @@ class BasicAgent:
         return text.strip()
 
     def _generate_answer(self, state: AgentState) -> AgentState:
-        # Get the last search results with error handling
-        search_block = "No search results available."
-        try:
-            # Find the last search step in history
-            search_steps = [item for item in state["history"] if item.get("step") == "search"]
-            if search_steps and "results" in search_steps[-1]:
-                search_block = "\n".join(search_steps[-1]["results"])
-        except Exception as e:
-            print(f"Error accessing search results: {e}")
-            search_block = "Error retrieving search results."
+        # Collect all relevant tool outputs
+        materials = []
+        for item in state["history"]:
+            if item["step"] in ("search", "image", "video", "sheet", "code_analysis"):
+                # Handle different output formats
+                if item["step"] == "search":
+                    output = item.get("results", [])
+                    if isinstance(output, list):
+                        output = "\n".join(output)
+                else:
+                    output = item.get("output", "")
+                materials.append(str(output))
+
+        # Join all materials with proper formatting
+        search_block = "\n".join(materials) if materials else "No artefacts available."
 
         prompt = f"""
 You are an expert assistant. Use ONLY the materials below to answer.
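Note: _generate_answer now folds every tool output in history into the materials block instead of keeping only the last search step. A small illustration of the history shape it expects (step names and outputs here are made up):

    history = [
        {"step": "analyze", "output": {"needs_search": True}},
        {"step": "search", "results": ["Result A", "Result B"]},
        {"step": "image", "output": "a cat sitting on a mat"},
    ]

    materials = []
    for item in history:
        if item["step"] in ("search", "image", "video", "sheet", "code_analysis"):
            if item["step"] == "search":
                output = item.get("results", [])
                if isinstance(output, list):
                    output = "\n".join(output)
            else:
                output = item.get("output", "")
            materials.append(str(output))

    print("\n".join(materials) if materials else "No artefacts available.")
    # Result A
    # Result B
    # a cat sitting on a mat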
@@ -392,10 +472,21 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
             "task_id": task_id,
             "logs": {},
             "code_blocks": [],
+            "attachment_data": {}
         }
         final_state = self.workflow.invoke(state)
         return final_state["final_answer"]
 
+    def _download_file(self, url: str) -> bytes:
+        """Download a file from a URL."""
+        try:
+            response = requests.get(url, timeout=20)
+            response.raise_for_status()
+            return response.content
+        except Exception as e:
+            print(f"Error downloading file from {url}: {e}")
+            raise
+
 # ----------------------------------------------------------------------------------
 # Gradio Interface & Submission Routines
 # ----------------------------------------------------------------------------------
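Note: together with _detect_file_type above, the download helper gives the analyze step a short path from URL to routed state. A condensed sketch of that flow (illustrative; assumes an already-constructed agent instance and a reachable URL):

    url = "https://example.com/files/table.xlsx"    # placeholder URL
    file_data = agent._download_file(url)           # raw bytes; raises on HTTP errors
    kind = agent._detect_file_type(file_data, url)  # -> "excel" for this extension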
@@ -479,6 +570,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             "task_id": task_id,
             "logs": {},
             "code_blocks": [],
+            "attachment_data": {}
         }
 
         # Run the workflow