Update app.py
app.py
CHANGED
@@ -59,13 +59,17 @@ def image_qa(image_path: str, prompt: str) -> str:
     """Query LLaVA model for image-based QA."""
     with open(image_path, "rb") as f:
         data = {"prompt": prompt, "image": f.read()}
-
+        headers = {"Content-Type": "application/octet-stream"}
+        return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data, headers=headers)

 def video_label(video_path: str, topk: int = 1) -> str:
     """Get video classification using VideoMAE."""
     with open(video_path, "rb") as f:
+        headers = {"Content-Type": "application/octet-stream"}
         preds = client.post(
-            "MCG-NJU/videomae-base-finetuned-ucf101",
+            "MCG-NJU/videomae-base-finetuned-ucf101",
+            data=f.read(),
+            headers=headers
         )
     preds = sorted(preds, key=lambda x: x["score"], reverse=True)[:topk]
     return preds[0]["label"]
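Both helpers still take a local file path and read the bytes themselves before posting. A minimal local smoke test of the updated functions might look like the following; the file names are placeholders and `client` is assumed to be the inference client already configured earlier in app.py (not shown in this hunk).

# Hypothetical local check of the two updated helpers (paths are placeholders).
if __name__ == "__main__":
    print(image_qa("sample.jpg", "What is shown in this image?"))
    print(video_label("sample.mp4"))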
@@ -97,6 +101,7 @@ class AgentState(TypedDict):
     task_id: Annotated[str, override]
     logs: Annotated[Dict[str, Any], merge_dicts]
     code_blocks: Annotated[List[Dict[str, str]], list.__add__]
+    attachment_data: Annotated[Dict[str, bytes], merge_dicts]  # Store downloaded file data

 # -------------------------
 # BasicAgent implementation
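The new `attachment_data` field reuses the `merge_dicts` reducer already attached to `logs`. That reducer is defined earlier in app.py and is not part of this diff; a minimal sketch of the behaviour it is assumed to have (a plain dict union, later values winning) is:

from typing import Any, Dict

# Assumed shape of the merge_dicts reducer referenced above (not shown in this diff).
def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(left or {})
    merged.update(right or {})
    return merged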
@@ -139,36 +144,54 @@ class BasicAgent:
         state["code_blocks"] = code_blocks
         return state

-        # Check for
-
-        if "video" in q or q.endswith(".mp4"):
-            state["current_step"] = "video"
-        elif q.endswith((".jpg", ".png", ".jpeg")):
-            state["current_step"] = "image"
-        elif q.endswith((".xlsx", ".csv")):
-            state["current_step"] = "sheet"
-        else:
-            # Regular text question analysis
-            prompt = (
-                "You will receive a user question. Think step‑by‑step to decide whether external web search is required. "
-                "Respond ONLY with a valid Python dict literal in the following format and NOTHING else:\n"
-                "{\n 'needs_search': bool,\n 'search_query': str\n} \n\n"
-                f"Question: {state['question']}"
-            )
-            raw = self._call_llm(prompt)
+        # Check for file attachments in the question
+        if "file_url" in state["question"]:
             try:
-
-
-
-
-
-
-
-
-
-
-
-
+                # Parse the question to get file URL
+                question_data = json.loads(state["question"])
+                file_url = question_data.get("file_url")
+                if file_url:
+                    # Download the file
+                    file_data = self._download_file(file_url)
+                    # Store in state
+                    state["attachment_data"] = {
+                        "content": file_data,
+                        "type": self._detect_file_type(file_data, file_url)
+                    }
+                    # Set appropriate step based on file type
+                    if state["attachment_data"]["type"] == "video":
+                        state["current_step"] = "video"
+                    elif state["attachment_data"]["type"] == "image":
+                        state["current_step"] = "image"
+                    elif state["attachment_data"]["type"] in ["excel", "csv"]:
+                        state["current_step"] = "sheet"
+                    return state
+            except Exception as e:
+                state["logs"]["file_download_error"] = str(e)
+                state["current_step"] = "answer"
+                return state
+
+        # Regular text question analysis
+        prompt = (
+            "You will receive a user question. Think step‑by‑step to decide whether external web search is required. "
+            "Respond ONLY with a valid Python dict literal in the following format and NOTHING else:\n"
+            "{\n 'needs_search': bool,\n 'search_query': str\n} \n\n"
+            f"Question: {state['question']}"
+        )
+        raw = self._call_llm(prompt)
+        try:
+            decision = ast.literal_eval(raw)
+            state["needs_search"] = bool(decision.get("needs_search", False))
+            state["search_query"] = decision.get("search_query", state["question"])
+        except Exception:
+            state["needs_search"] = True
+            state["search_query"] = state["question"]
+            decision = {"parse_error": raw}
+        state["logs"] = {
+            "analyze": {"prompt": prompt, "llm_response": raw, "decision": decision}
+        }
+        state["current_step"] = "search" if state["needs_search"] else "answer"
+        state["history"].append({"step": "analyze", "output": decision})
         return state

     def _extract_code_blocks(self, text: str) -> List[Dict[str, str]]:
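The routing added above only fires when the literal substring "file_url" appears in the question, and it then expects the question to be a JSON object. A payload shaped like the following (field values are made up for illustration) would be downloaded, typed, and routed to the image branch:

import json

# Hypothetical question payload; only "file_url" is actually inspected by the node above.
question = json.dumps({
    "question": "What objects appear in the attached picture?",
    "file_url": "https://example.com/files/sample.png",
})
# The node above would download sample.png, detect it as an image,
# and set state["current_step"] = "image".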
@@ -218,11 +241,54 @@ class BasicAgent:

         return state

+    def _detect_file_type(self, data: bytes, url: str) -> str:
+        """Detect file type from content and URL."""
+        # Check URL extension first
+        url_lower = url.lower()
+        if url_lower.endswith((".mp4", ".avi", ".mov")):
+            return "video"
+        elif url_lower.endswith((".jpg", ".jpeg", ".png", ".gif")):
+            return "image"
+        elif url_lower.endswith(".xlsx"):
+            return "excel"
+        elif url_lower.endswith(".csv"):
+            return "csv"
+
+        # If URL check fails, try content-based detection
+        try:
+            # Try to detect image
+            from PIL import Image
+            Image.open(io.BytesIO(data))
+            return "image"
+        except:
+            pass
+
+        try:
+            # Try to detect Excel
+            pd.read_excel(io.BytesIO(data))
+            return "excel"
+        except:
+            pass
+
+        try:
+            # Try to detect CSV
+            pd.read_csv(io.BytesIO(data))
+            return "csv"
+        except:
+            pass
+
+        return "unknown"
+
     def _image_node(self, state: AgentState) -> AgentState:
         """Handle image-based questions."""
         try:
-
-
+            if "attachment_data" in state and "content" in state["attachment_data"]:
+                # Use the downloaded image data
+                image_data = state["attachment_data"]["content"]
+                answer = image_qa(image_data, "What is shown in this image?")
+                state["history"].append({"step": "image", "output": answer})
+            else:
+                raise ValueError("No image data found in state")
             state["current_step"] = "answer"
         except Exception as e:
             state["logs"]["image_error"] = str(e)
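`_detect_file_type` depends on `io`, `pandas` (as `pd`), and Pillow being importable in app.py. A quick illustration of the detection order, using tiny in-memory samples and assuming `agent` is an already-constructed BasicAgent:

import io
from PIL import Image

# Build small in-memory samples to exercise the same checks as above.
buf = io.BytesIO()
Image.new("RGB", (4, 4)).save(buf, format="PNG")
png_bytes = buf.getvalue()
csv_bytes = b"a,b\n1,2\n"

print(agent._detect_file_type(png_bytes, "https://example.com/blob"))      # "image" via Pillow sniffing
print(agent._detect_file_type(csv_bytes, "https://example.com/data.csv"))  # "csv" via the URL extension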
@@ -232,8 +298,13 @@ class BasicAgent:
     def _video_node(self, state: AgentState) -> AgentState:
         """Handle video-based questions."""
         try:
-
-
+            if "attachment_data" in state and "content" in state["attachment_data"]:
+                # Use the downloaded video data
+                video_data = state["attachment_data"]["content"]
+                label = video_label(video_data)
+                state["history"].append({"step": "video", "output": label})
+            else:
+                raise ValueError("No video data found in state")
             state["current_step"] = "answer"
         except Exception as e:
             state["logs"]["video_error"] = str(e)
@@ -243,9 +314,13 @@ class BasicAgent:
     def _sheet_node(self, state: AgentState) -> AgentState:
         """Handle spreadsheet-based questions."""
         try:
-
-
-
+            if "attachment_data" in state and "content" in state["attachment_data"]:
+                # Use the downloaded spreadsheet data
+                sheet_data = state["attachment_data"]["content"]
+                answer = sheet_answer(sheet_data, state["question"])
+                state["history"].append({"step": "sheet", "output": answer})
+            else:
+                raise ValueError("No spreadsheet data found in state")
             state["current_step"] = "answer"
         except Exception as e:
             state["logs"]["sheet_error"] = str(e)
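`sheet_answer` itself is defined earlier in app.py and is untouched by this commit; from the call site above it receives the raw spreadsheet bytes plus the question text. A hedged sketch of a helper with that contract, not the actual implementation, could be:

import io
import pandas as pd

# Sketch only: mirrors the (data, question) contract used by _sheet_node above.
def sheet_answer_sketch(data: bytes, question: str) -> str:
    try:
        df = pd.read_excel(io.BytesIO(data))
    except Exception:
        df = pd.read_csv(io.BytesIO(data))
    # Placeholder summary so the answer node has material to reason over.
    return f"Rows: {len(df)}, columns: {list(df.columns)}. Question: {question}"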
@@ -299,16 +374,21 @@ class BasicAgent:
         return text.strip()

     def _generate_answer(self, state: AgentState) -> AgentState:
-        #
-
-
-
-
-
-
-
-
-
+        # Collect all relevant tool outputs
+        materials = []
+        for item in state["history"]:
+            if item["step"] in ("search", "image", "video", "sheet", "code_analysis"):
+                # Handle different output formats
+                if item["step"] == "search":
+                    output = item.get("results", [])
+                    if isinstance(output, list):
+                        output = "\n".join(output)
+                else:
+                    output = item.get("output", "")
+                materials.append(str(output))
+
+        # Join all materials with proper formatting
+        search_block = "\n".join(materials) if materials else "No artefacts available."

         prompt = f"""
 You are an expert assistant. Use ONLY the materials below to answer.
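For illustration, given a history like the one below (sample values made up), the loop above keeps only the tool steps, flattens the search results list, and joins everything into `search_block`:

# Hypothetical history entries and the resulting materials block.
history = [
    {"step": "analyze", "output": {"needs_search": True, "search_query": "capital of France"}},
    {"step": "search", "results": ["Paris is the capital of France."]},
    {"step": "sheet", "output": "Total revenue: 42"},
]
# "analyze" is skipped; "search" contributes its joined "results"; "sheet" its "output":
# search_block == "Paris is the capital of France.\nTotal revenue: 42"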
@@ -392,10 +472,21 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
             "task_id": task_id,
             "logs": {},
             "code_blocks": [],
+            "attachment_data": {}
         }
         final_state = self.workflow.invoke(state)
         return final_state["final_answer"]

+    def _download_file(self, url: str) -> bytes:
+        """Download a file from a URL."""
+        try:
+            response = requests.get(url, timeout=20)
+            response.raise_for_status()
+            return response.content
+        except Exception as e:
+            print(f"Error downloading file from {url}: {e}")
+            raise
+
 # ----------------------------------------------------------------------------------
 # Gradio Interface & Submission Routines
 # ----------------------------------------------------------------------------------
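`_download_file` is a thin wrapper around `requests.get` with a 20-second timeout that re-raises on failure, so the try/except in the analyze node is what actually records the error. A direct use, assuming `agent` is an already-constructed BasicAgent and with a placeholder URL:

# Hypothetical direct use of the new helper (URL is a placeholder).
url = "https://example.com/files/report.xlsx"
data = agent._download_file(url)           # raises on HTTP errors or timeouts
kind = agent._detect_file_type(data, url)  # "excel" via the URL-extension check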
@@ -479,6 +570,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             "task_id": task_id,
             "logs": {},
             "code_blocks": [],
+            "attachment_data": {}
         }

         # Run the workflow