naman1102 commited on
Commit
4e8e7db
·
1 Parent(s): d07b7a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -28
app.py CHANGED
@@ -16,6 +16,12 @@ from huggingface_hub import InferenceClient
16
  import io
17
  import mimetypes
18
  import base64
 
 
 
 
 
 
19
 
20
  # -------------------------
21
  # Environment & constants
@@ -72,17 +78,60 @@ def retry_hf_inference(func):
72
 
73
  @retry_hf_inference
74
  def image_qa_bytes(data: bytes, prompt: str) -> str:
75
- """Query MiniGPT-4-V for image-based QA using bytes."""
76
  headers = {"Content-Type": "application/octet-stream"}
77
- return client.post("Vision-CAIR/MiniGPT4-V", data=data, headers=headers)
78
 
79
  @retry_hf_inference
80
  def video_label_bytes(data: bytes) -> str:
81
- """Get video classification using VideoMAE-Base-Short from bytes."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  headers = {"Content-Type": "application/octet-stream"}
83
  preds = client.post(
84
- "MCG-NJU/videomae-base-short-finetuned-ucf101",
85
- data=data,
86
  headers=headers
87
  )
88
  return sorted(preds, key=lambda x: x["score"], reverse=True)[0]["label"]
@@ -100,6 +149,18 @@ def sheet_answer_bytes(data: bytes, question: str) -> str:
100
  label = df.columns[col]
101
  return f"{label}: {value}"
102
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  # -------------------------
104
  # State definition
105
  # -------------------------
@@ -241,6 +302,47 @@ class BasicAgent:
241
  state["current_step"] = "answer"
242
  return state
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def _generate_answer(self, state: AgentState) -> AgentState:
245
  # Collect all tool outputs with clear section headers
246
  materials = []
@@ -349,6 +451,7 @@ Write ANSWER: <answer> on its own line.
349
  sg.add_node("image", self._image_node)
350
  sg.add_node("video", self._video_node)
351
  sg.add_node("sheet", self._sheet_node)
 
352
 
353
  # Add edges
354
  sg.add_edge("analyze", "search")
@@ -357,6 +460,7 @@ Write ANSWER: <answer> on its own line.
357
  sg.add_edge("image", "answer")
358
  sg.add_edge("video", "answer")
359
  sg.add_edge("sheet", "answer")
 
360
 
361
  def router(state: AgentState):
362
  return state["current_step"]
@@ -366,7 +470,8 @@ Write ANSWER: <answer> on its own line.
366
  "answer": "answer",
367
  "image": "image",
368
  "video": "video",
369
- "sheet": "sheet"
 
370
  })
371
 
372
  sg.set_entry_point("analyze")
@@ -465,36 +570,22 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
465
 
466
  for item in questions_data:
467
  task_id = item.get("task_id")
468
- question_text = item.get("question")
469
- if not task_id or question_text is None:
470
- print(f"Skipping item with missing task_id or question: {item}")
471
  continue
472
 
473
  try:
474
- print(f"\nProcessing question {task_id}: {question_text[:50]}...")
475
-
476
- # Initialize state for this question
477
- state: AgentState = {
478
- "question": question_text,
479
- "current_step": "analyze",
480
- "final_answer": "",
481
- "history": [],
482
- "needs_search": False,
483
- "search_query": "",
484
- "task_id": task_id,
485
- "logs": {},
486
- "file_url": ""
487
- }
488
 
489
- # Run the workflow
490
- final_state = agent.workflow.invoke(state)
491
- answer = final_state["final_answer"]
492
 
493
  # Add to results
494
  answers_payload.append({"task_id": task_id, "submitted_answer": answer})
495
  results_log.append({
496
  "Task ID": task_id,
497
- "Question": question_text,
498
  "Submitted Answer": answer
499
  })
500
 
@@ -504,7 +595,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
504
  print(f"Error running agent on task {task_id}: {e}")
505
  results_log.append({
506
  "Task ID": task_id,
507
- "Question": question_text,
508
  "Submitted Answer": f"ERROR: {e}"
509
  })
510
 
 
16
import io
import mimetypes
import base64
import os
import subprocess
import sys
import tempfile
from io import BytesIO

import cv2
import numpy as np
25
 
26
  # -------------------------
27
  # Environment & constants
 
78
 
79
@retry_hf_inference
def image_qa_bytes(data: bytes, prompt: str) -> str:
    """Query LLaVA for image-based QA using bytes.

    Args:
        data: Raw image bytes, sent as an application/octet-stream body.
        prompt: The question to ask about the image.
            NOTE(review): `prompt` is currently unused -- the request body
            carries only the image bytes, so the model never sees the
            question. Confirm how the llava-hf endpoint expects the prompt
            to be passed (likely a JSON payload) and wire it through.

    Returns:
        The raw response from the llava-hf/llava-v1.6-mistral-7b-hf
        inference endpoint (retried on transient failures by the
        @retry_hf_inference decorator).
    """
    headers = {"Content-Type": "application/octet-stream"}
    return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data, headers=headers)
84
 
85
@retry_hf_inference
def video_label_bytes(data: bytes) -> str:
    """Classify a video with VideoMAE and return the top-scoring label.

    The raw bytes are written to a temporary file (OpenCV's VideoCapture
    only accepts a path/URL, not an in-memory buffer), up to 16 frames are
    sampled from roughly the first 8 seconds, re-encoded as a small .mp4
    clip, and sent to the MCG-NJU/videomae-base-finetuned-ucf101 endpoint.

    Args:
        data: Raw video file bytes in any container OpenCV can demux.

    Returns:
        The label string with the highest classification score.

    Raises:
        ValueError: If no frames could be decoded from the input bytes.
    """
    target_frames = 16
    target_duration = 8  # seconds of video to sample from

    # Persist the bytes to disk: cv2.VideoCapture cannot open a BytesIO
    # object (the original code passed one and would fail at runtime).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(data)
        src_path = tmp.name

    frames = []
    try:
        cap = cv2.VideoCapture(src_path)
        try:
            # Guard against containers that report 0 fps metadata.
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Spread target_frames evenly across the first target_duration
            # seconds (or the whole clip, if it is shorter).
            span = min(frame_count, int(fps * target_duration)) or target_frames
            frame_interval = max(1, span // target_frames)

            frame_idx = 0
            while len(frames) < target_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_idx % frame_interval == 0:
                    # Resize to match VideoMAE's expected 224x224 input.
                    frames.append(cv2.resize(frame, (224, 224)))
                frame_idx += 1
        finally:
            cap.release()
    finally:
        os.unlink(src_path)

    if not frames:
        # Without this guard the padding loop below would index frames[-1]
        # on an empty list and raise an opaque IndexError.
        raise ValueError("Could not decode any frames from the provided video bytes")

    # Pad with the last frame so the clip always has target_frames frames.
    while len(frames) < target_frames:
        frames.append(frames[-1])

    # cv2.imencode only encodes single still images (the original call with
    # '.mp4' is invalid), so write the sampled clip with VideoWriter and
    # read the resulting file back into memory.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        out_path = tmp.name
    try:
        writer = cv2.VideoWriter(
            out_path,
            cv2.VideoWriter_fourcc(*"mp4v"),
            target_frames / target_duration,  # 2 fps: 16 frames over 8 s
            (224, 224),
        )
        for frame in frames:
            writer.write(frame)
        writer.release()
        with open(out_path, "rb") as f:
            processed_bytes = f.read()
    finally:
        os.unlink(out_path)

    headers = {"Content-Type": "application/octet-stream"}
    preds = client.post(
        "MCG-NJU/videomae-base-finetuned-ucf101",
        data=processed_bytes,
        headers=headers,
    )
    return sorted(preds, key=lambda x: x["score"], reverse=True)[0]["label"]
 
149
  label = df.columns[col]
150
  return f"{label}: {value}"
151
 
152
# -------------------------
# Code Analysis helpers
# -------------------------

def run_python(code: str) -> str:
    """Execute *code* in a subprocess and return its captured stdout.

    NOTE(security): this executes arbitrary Python with no sandboxing;
    only run code from trusted task payloads.

    Args:
        code: Python source to execute.

    Returns:
        The stripped stdout produced by the script.

    Raises:
        subprocess.CalledProcessError: If the script exits non-zero.
            Stderr is merged into stdout so the traceback is available to
            callers via ``e.output``.
        subprocess.TimeoutExpired: If execution exceeds 10 seconds.
    """
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code)
        script_path = f.name
    try:
        # stderr=STDOUT so failures carry their traceback in e.output
        # (the original dropped stderr, leaving e.output empty on errors).
        out = subprocess.check_output(
            [sys.executable, script_path],
            stderr=subprocess.STDOUT,
            timeout=10,
        )
    finally:
        # The file was created with delete=False; remove it explicitly
        # (the original leaked one temp file per call).
        os.unlink(script_path)
    return out.decode().strip()
163
+
164
  # -------------------------
165
  # State definition
166
  # -------------------------
 
302
  state["current_step"] = "answer"
303
  return state
304
 
305
def _code_analysis_node(self, state: AgentState) -> AgentState:
    """Run Python code embedded in the question and record the outcome.

    Extracts a fenced ```python``` block from the question, executes it
    via run_python, and appends one history entry describing the result
    (success, timeout, non-zero exit, or unexpected failure). Always
    routes the workflow on to the "answer" step.
    """
    try:
        # The question must carry a fenced ```python code block.
        match = re.search(r'```python\n(.*?)\n```', state["question"], re.DOTALL)
        if match is None:
            state["logs"]["code_error"] = "No Python code found in question"
            state["current_step"] = "answer"
            return state

        snippet = match.group(1)

        # Translate every execution outcome into a single history entry
        # so the answer node can reason about what happened.
        try:
            result = run_python(snippet)
            entry = f"Code execution result:\n{result}"
        except subprocess.TimeoutExpired:
            entry = "Code execution timed out after 10 seconds"
        except subprocess.CalledProcessError as err:
            entry = f"Code execution failed with error:\n{err.output.decode()}"
        except Exception as err:
            entry = f"Unexpected error during code execution:\n{str(err)}"
        state["history"].append({"step": "code", "output": entry})

    except Exception as err:
        state["logs"]["code_error"] = str(err)

    state["current_step"] = "answer"
    return state
345
+
346
  def _generate_answer(self, state: AgentState) -> AgentState:
347
  # Collect all tool outputs with clear section headers
348
  materials = []
 
451
  sg.add_node("image", self._image_node)
452
  sg.add_node("video", self._video_node)
453
  sg.add_node("sheet", self._sheet_node)
454
+ sg.add_node("code", self._code_analysis_node)
455
 
456
  # Add edges
457
  sg.add_edge("analyze", "search")
 
460
  sg.add_edge("image", "answer")
461
  sg.add_edge("video", "answer")
462
  sg.add_edge("sheet", "answer")
463
+ sg.add_edge("code", "answer")
464
 
465
  def router(state: AgentState):
466
  return state["current_step"]
 
470
  "answer": "answer",
471
  "image": "image",
472
  "video": "video",
473
+ "sheet": "sheet",
474
+ "code": "code"
475
  })
476
 
477
  sg.set_entry_point("analyze")
 
570
 
571
  for item in questions_data:
572
  task_id = item.get("task_id")
573
+ if not task_id:
574
+ print(f"Skipping item with missing task_id: {item}")
 
575
  continue
576
 
577
  try:
578
+ print(f"\nProcessing question {task_id}...")
 
 
 
 
 
 
 
 
 
 
 
 
 
579
 
580
+ # Pass the entire item as JSON string
581
+ state_question = json.dumps(item)
582
+ answer = agent(state_question, task_id)
583
 
584
  # Add to results
585
  answers_payload.append({"task_id": task_id, "submitted_answer": answer})
586
  results_log.append({
587
  "Task ID": task_id,
588
+ "Question": item.get("question", ""),
589
  "Submitted Answer": answer
590
  })
591
 
 
595
  print(f"Error running agent on task {task_id}: {e}")
596
  results_log.append({
597
  "Task ID": task_id,
598
+ "Question": item.get("question", ""),
599
  "Submitted Answer": f"ERROR: {e}"
600
  })
601