Update app.py
app.py CHANGED
@@ -16,6 +16,12 @@ from huggingface_hub import InferenceClient
 import io
 import mimetypes
 import base64
+import cv2
+import numpy as np
+from io import BytesIO
+import tempfile
+import subprocess
+import sys
 
 # -------------------------
 # Environment & constants

@@ -72,17 +78,60 @@ def retry_hf_inference(func):
 
 @retry_hf_inference
 def image_qa_bytes(data: bytes, prompt: str) -> str:
-    """Query …
+    """Query LLaVA for image-based QA using bytes."""
     headers = {"Content-Type": "application/octet-stream"}
-    return client.post("…
+    return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data, headers=headers)
 
 @retry_hf_inference
 def video_label_bytes(data: bytes) -> str:
-    """Get video classification using VideoMAE-Base …
+    """Get video classification using VideoMAE-Base from bytes."""
+    # Process video to get first 8 seconds, 16 frames
+
+    # Read video from bytes
+    video_bytes = BytesIO(data)
+    cap = cv2.VideoCapture()
+    cap.open(video_bytes)
+
+    # Get video properties
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Calculate frames to extract (16 frames over 8 seconds)
+    target_frames = 16
+    target_duration = 8  # seconds
+    frame_interval = max(1, int(frame_count / (fps * target_duration)))
+
+    frames = []
+    frame_idx = 0
+
+    while len(frames) < target_frames and frame_idx < frame_count:
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        if frame_idx % frame_interval == 0:
+            # Resize frame to match VideoMAE's expected input
+            frame = cv2.resize(frame, (224, 224))
+            frames.append(frame)
+
+        frame_idx += 1
+
+    cap.release()
+
+    # If we don't have enough frames, duplicate the last frame
+    while len(frames) < target_frames:
+        frames.append(frames[-1])
+
+    # Stack frames and convert to bytes
+    video_array = np.stack(frames)
+    _, buffer = cv2.imencode('.mp4', video_array)
+    processed_bytes = buffer.tobytes()
+
+    # Send to VideoMAE
     headers = {"Content-Type": "application/octet-stream"}
     preds = client.post(
-        "MCG-NJU/videomae-base-…
-        data=…
+        "MCG-NJU/videomae-base-finetuned-ucf101",
+        data=processed_bytes,
         headers=headers
     )
     return sorted(preds, key=lambda x: x["score"], reverse=True)[0]["label"]

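A note on the new video path: cv2.VideoCapture.open() expects a filename or device index, not a BytesIO object, and cv2.imencode() encodes still-image formats rather than .mp4, so this hunk will most likely fail at runtime. A minimal sketch of the usual workaround, spilling the bytes to a temporary file before decoding; read_frames_from_bytes is a hypothetical helper, not part of this commit:

import tempfile

import cv2
import numpy as np

def read_frames_from_bytes(data: bytes, max_frames: int = 16) -> np.ndarray:
    """Decode up to max_frames resized frames from raw video bytes via a temp file."""
    with tempfile.NamedTemporaryFile(suffix=".mp4") as tmp:
        tmp.write(data)
        tmp.flush()
        cap = cv2.VideoCapture(tmp.name)  # open by path; a BytesIO is not accepted
        frames = []
        while len(frames) < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(cv2.resize(frame, (224, 224)))
        cap.release()
    return np.stack(frames) if frames else np.empty((0, 224, 224, 3))
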
@@ -100,6 +149,18 @@ def sheet_answer_bytes(data: bytes, question: str) -> str:
     label = df.columns[col]
     return f"{label}: {value}"
 
+# -------------------------
+# Code Analysis helpers
+# -------------------------
+
+def run_python(code: str) -> str:
+    """Quick & dirty evaluator for Python code."""
+    with tempfile.NamedTemporaryFile("w+", suffix=".py", delete=False) as f:
+        f.write(code)
+        f.flush()
+        out = subprocess.check_output([sys.executable, f.name], timeout=10)
+    return out.decode().strip()
+
 # -------------------------
 # State definition
 # -------------------------

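run_python relies on subprocess.check_output, which raises CalledProcessError on a non-zero exit status and TimeoutExpired after 10 seconds; the graph node added below catches both. Because the temp file is created with delete=False, each call also leaves a .py file behind. A small sketch, not part of this commit, that folds stderr into the captured output and removes the file afterwards:

import os
import subprocess
import sys
import tempfile

def run_python_checked(code: str, timeout: int = 10) -> str:
    """Hypothetical variant of run_python: captures stderr and cleans up the temp file."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code)
        path = f.name
    try:
        out = subprocess.check_output(
            [sys.executable, path],
            stderr=subprocess.STDOUT,  # include tracebacks in the captured output
            timeout=timeout,
        )
        return out.decode().strip()
    finally:
        os.unlink(path)  # avoid accumulating temp files
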
@@ -241,6 +302,47 @@ class BasicAgent:
         state["current_step"] = "answer"
         return state
 
+    def _code_analysis_node(self, state: AgentState) -> AgentState:
+        """Handle code analysis questions."""
+        try:
+            # Extract code from the question
+            code_match = re.search(r'```python\n(.*?)\n```', state["question"], re.DOTALL)
+            if not code_match:
+                state["logs"]["code_error"] = "No Python code found in question"
+                state["current_step"] = "answer"
+                return state
+
+            code = code_match.group(1)
+
+            # Run the code and get output
+            try:
+                output = run_python(code)
+                state["history"].append({
+                    "step": "code",
+                    "output": f"Code execution result:\n{output}"
+                })
+            except subprocess.TimeoutExpired:
+                state["history"].append({
+                    "step": "code",
+                    "output": "Code execution timed out after 10 seconds"
+                })
+            except subprocess.CalledProcessError as e:
+                state["history"].append({
+                    "step": "code",
+                    "output": f"Code execution failed with error:\n{e.output.decode()}"
+                })
+            except Exception as e:
+                state["history"].append({
+                    "step": "code",
+                    "output": f"Unexpected error during code execution:\n{str(e)}"
+                })
+
+        except Exception as e:
+            state["logs"]["code_error"] = str(e)
+
+        state["current_step"] = "answer"
+        return state
+
     def _generate_answer(self, state: AgentState) -> AgentState:
         # Collect all tool outputs with clear section headers
         materials = []

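The new node only fires when the question embeds a fenced ```python block in exactly the form the regex expects; otherwise it records a code_error and falls through to the answer node. Note also that e.output in the CalledProcessError branch holds captured stdout only, since stderr is not redirected. A quick illustration of what the pattern matches; the example text is invented:

import re

question = "What does this script print?\n```python\nprint(2 + 2)\n```"
match = re.search(r'```python\n(.*?)\n```', question, re.DOTALL)
print(match.group(1) if match else "no code block found")  # -> print(2 + 2)
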
@@ -349,6 +451,7 @@ Write ANSWER: <answer> on its own line.
         sg.add_node("image", self._image_node)
         sg.add_node("video", self._video_node)
         sg.add_node("sheet", self._sheet_node)
+        sg.add_node("code", self._code_analysis_node)
 
         # Add edges
         sg.add_edge("analyze", "search")

@@ -357,6 +460,7 @@ Write ANSWER: <answer> on its own line.
         sg.add_edge("image", "answer")
         sg.add_edge("video", "answer")
         sg.add_edge("sheet", "answer")
+        sg.add_edge("code", "answer")
 
         def router(state: AgentState):
             return state["current_step"]

@@ -366,7 +470,8 @@ Write ANSWER: <answer> on its own line.
             "answer": "answer",
             "image": "image",
             "video": "video",
-            "sheet": "sheet"
+            "sheet": "sheet",
+            "code": "code"
         })
 
         sg.set_entry_point("analyze")

@@ -465,36 +570,22 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
 
     for item in questions_data:
         task_id = item.get("task_id")
-        …
-        …
-            print(f"Skipping item with missing task_id or question: {item}")
+        if not task_id:
+            print(f"Skipping item with missing task_id: {item}")
             continue
 
         try:
-            print(f"\nProcessing question {task_id}…
-            …
-            # Initialize state for this question
-            state: AgentState = {
-                "question": question_text,
-                "current_step": "analyze",
-                "final_answer": "",
-                "history": [],
-                "needs_search": False,
-                "search_query": "",
-                "task_id": task_id,
-                "logs": {},
-                "file_url": ""
-            }
+            print(f"\nProcessing question {task_id}...")
 
-            # …
-            …
-            answer = …
+            # Pass the entire item as JSON string
+            state_question = json.dumps(item)
+            answer = agent(state_question, task_id)
 
             # Add to results
             answers_payload.append({"task_id": task_id, "submitted_answer": answer})
             results_log.append({
                 "Task ID": task_id,
-                "Question": …
+                "Question": item.get("question", ""),
                 "Submitted Answer": answer
             })
 
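With this change the agent receives json.dumps(item) as its question, so any node that pattern-matches on state["question"] sees JSON-escaped text: real newlines become the two characters \n, which would keep the ```python fence regex above from matching. A short sketch of recovering the raw question; the example data is invented:

import json

item = {"task_id": "abc123", "question": "Run this:\n```python\nprint(1)\n```"}
state_question = json.dumps(item)                      # what the agent is given here
raw_question = json.loads(state_question)["question"]  # real newlines restored
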
@@ -504,7 +595,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             print(f"Error running agent on task {task_id}: {e}")
             results_log.append({
                 "Task ID": task_id,
-                "Question": …
+                "Question": item.get("question", ""),
                 "Submitted Answer": f"ERROR: {e}"
             })
 