Update app.py
app.py CHANGED
@@ -54,27 +54,23 @@ def tighten(q: str) -> str:
 # Multimodal helpers
 # -------------------------
 
-def image_qa(path: str, prompt: str) -> str:
-    """Query MiniGPT-4-V for image-based QA."""
-    with open(path, "rb") as f:
-        data = {"prompt": prompt, "image": f.read()}
-    headers = {"Content-Type": "application/octet-stream"}
+def image_qa_bytes(data: bytes, prompt: str) -> str:
+    """Query MiniGPT-4-V for image-based QA using bytes."""
+    headers = {"Content-Type": "application/octet-stream"}
     return client.post("Vision-CAIR/MiniGPT4-V", data=data, headers=headers)
 
-def video_label(path: str, topk: int = 1) -> str:
-    """Get video classification using VideoMAE-Base-Short."""
-    with open(path, "rb") as f:
-        data = f.read()
-    headers = {"Content-Type": "application/octet-stream"}
-    preds = client.post(
-        "MCG-NJU/videomae-base-short-finetuned-ucf101",
-        data=data,
-        headers=headers)
-    preds = sorted(preds, key=lambda x: x["score"], reverse=True)[:topk]
-    return preds[0]["label"]
+def video_label_bytes(data: bytes) -> str:
+    """Get video classification using VideoMAE-Base-Short from bytes."""
+    headers = {"Content-Type": "application/octet-stream"}
+    preds = client.post(
+        "MCG-NJU/videomae-base-short-finetuned-ucf101",
+        data=data,
+        headers=headers
+    )
+    return sorted(preds, key=lambda x: x["score"], reverse=True)[0]["label"]
 
-def sheet_answer(data: bytes, question: str) -> str:
-    """Process spreadsheet data and answer questions."""
+def sheet_answer_bytes(data: bytes, question: str) -> str:
+    """Process spreadsheet data from bytes and answer questions."""
     if mimetypes.guess_type("x.xlsx")[0] == "text/csv" or question.endswith(".csv"):
         df = pd.read_csv(io.BytesIO(data))
     else:
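Note: these helpers assume `client` (created elsewhere in app.py, outside this diff) exposes a raw `post` that returns the parsed prediction payload. A minimal smoke test of the rewritten bytes-based API, with the file name purely illustrative:

    with open("sample.png", "rb") as f:  # hypothetical local file
        print(image_qa_bytes(f.read(), "What is shown in this image?"))

Worth a follow-up in `sheet_answer_bytes`: the unchanged guard `mimetypes.guess_type("x.xlsx")[0] == "text/csv"` always inspects the literal name "x.xlsx", so it can never be true, and CSV detection effectively rests on `question.endswith(".csv")` alone.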
@@ -175,16 +171,14 @@ class BasicAgent:
             # Download the file
             file_data = self._download_file(file_url)
             # Store in state
-            state["attachment_data"] = {
-                "type": self._detect_file_type(file_url),
-                "content": file_data
-            }
-            # Set appropriate step based on file type
-            if state["attachment_data"]["type"] == "video":
+            state["attachment_data"] = file_data
+            # Detect type and set appropriate step
+            file_type = self._detect_file_type(file_url)
+            if file_type == "video":
                 state["current_step"] = "video"
-            elif state["attachment_data"]["type"] == "image":
+            elif file_type == "image":
                 state["current_step"] = "image"
-            elif state["attachment_data"]["type"] == "sheet":
+            elif file_type in ["excel", "csv"]:
                 state["current_step"] = "sheet"
             return state
         except (json.JSONDecodeError, KeyError):
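The behavioural change here: `state["attachment_data"]` now holds the raw bytes from `_download_file` instead of a `{"type": ..., "content": ...}` dict, which is why the node methods below index it directly. `AgentState` itself is not part of this diff; a minimal sketch consistent with the fields the hunks touch (the actual definition in app.py may differ):

    from typing import Any, TypedDict

    class AgentState(TypedDict, total=False):
        question: str
        current_step: str
        attachment_data: bytes  # raw download, no longer a {"type", "content"} dict
        history: list[dict[str, Any]]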
@@ -257,22 +251,23 @@ class BasicAgent:
 
         return state
 
-    def _detect_file_type(self, url):
+    def _detect_file_type(self, url: str) -> str:
         """Detect file type from URL extension."""
         ext = url.split(".")[-1].lower()
-        if ext == "mp4":
-            return "video"
-        if ext in ["jpg", "jpeg", "png"]:
-            return "image"
-        return "sheet"
+        return {
+            "mp4": "video",
+            "jpg": "image",
+            "jpeg": "image",
+            "png": "image",
+            "xlsx": "excel",
+            "csv": "csv"
+        }.get(ext, "unknown")
 
     def _image_node(self, state: AgentState) -> AgentState:
         """Handle image-based questions."""
         try:
-            if "attachment_data" in state and state["attachment_data"]:
-                image_data = state["attachment_data"]["content"]
-                answer = image_qa(image_data, "What is shown in this image?")
+            if "attachment_data" in state:
+                answer = image_qa_bytes(state["attachment_data"], "What is shown in this image?")
                 state["history"].append({"step": "image", "output": answer})
             else:
                 raise ValueError("No image data found in state")
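One caveat with `url.split(".")[-1]`: a URL with a query string, e.g. ".../clip.mp4?sig=abc", yields the extension "mp4?sig=abc" and falls through to "unknown". If such URLs occur, parsing the path component first is more robust; a sketch (the helper name is ours, not in the diff):

    import os
    from urllib.parse import urlparse

    def url_ext(url: str) -> str:
        # ".../clip.mp4?sig=abc" -> "mp4"
        return os.path.splitext(urlparse(url).path)[1].lstrip(".").lower()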
@@ -285,10 +280,8 @@ class BasicAgent:
     def _video_node(self, state: AgentState) -> AgentState:
         """Handle video-based questions."""
         try:
-            if "attachment_data" in state and state["attachment_data"]:
-                video_data = state["attachment_data"]["content"]
-                label = video_label(video_data)
+            if "attachment_data" in state:
+                label = video_label_bytes(state["attachment_data"])
                 state["history"].append({"step": "video", "output": label})
             else:
                 raise ValueError("No video data found in state")
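`video_label_bytes` assumes the endpoint returns the standard classification payload, a list of {"label": ..., "score": ...} dicts. Since only the top entry is kept, `max` states that intent without a full sort; an equivalent sketch:

    def top_label(preds: list[dict]) -> str:
        # Same result as sorted(preds, key=..., reverse=True)[0]["label"]
        return max(preds, key=lambda p: p["score"])["label"]

An empty prediction list raises either way (IndexError vs ValueError), so behaviour under failure is unchanged.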
@@ -301,10 +294,8 @@ class BasicAgent:
     def _sheet_node(self, state: AgentState) -> AgentState:
         """Handle spreadsheet-based questions."""
         try:
-            if "attachment_data" in state and state["attachment_data"]:
-                sheet_data = state["attachment_data"]["content"]
-                answer = sheet_answer(sheet_data, state["question"])
+            if "attachment_data" in state:
+                answer = sheet_answer_bytes(state["attachment_data"], state["question"])
                 state["history"].append({"step": "sheet", "output": answer})
             else:
                 raise ValueError("No spreadsheet data found in state")
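A quick round-trip check for `sheet_answer_bytes` (in-memory Excel bytes; assumes `openpyxl` is installed for `to_excel`, and the question string is illustrative):

    import io
    import pandas as pd

    buf = io.BytesIO()
    pd.DataFrame({"sales": [10, 20, 30]}).to_excel(buf, index=False)
    print(sheet_answer_bytes(buf.getvalue(), "What is the total of the sales column?"))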
@@ -357,17 +348,9 @@ class BasicAgent:
     def _generate_answer(self, state: AgentState) -> AgentState:
         # Collect all relevant tool outputs
         materials = []
-        for step in state["history"]:
-            if step["step"] in {"search", "video", "image",
-                                "sheet", "code_analysis"}:
-                if step["step"] == "search":
-                    output = step.get("results", [])
-                    if isinstance(output, list):
-                        output = "\n".join(output)
-                else:
-                    output = step.get("output", "")
-                # Format the output as JSON for better readability
-                materials.append(json.dumps(output, indent=2))
+        for h in state["history"]:
+            if h["step"] in {"search", "video", "image", "sheet", "code_analysis"}:
+                materials.append(json.dumps(h.get("output") or h.get("results"), indent=2))
 
         # Join all materials with proper formatting
         search_block = "\n".join(materials) if materials else "No artefacts available."
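One subtlety in the rewritten loop: `h.get("output") or h.get("results")` falls through on any falsy value, so an empty string or empty list is silently replaced by the other key, or by None, which json.dumps renders as "null". If empty artefacts should survive into the materials block, an explicit key check avoids that; a sketch:

    for h in state["history"]:
        if h["step"] in {"search", "video", "image", "sheet", "code_analysis"}:
            payload = h["output"] if "output" in h else h.get("results", [])
            materials.append(json.dumps(payload, indent=2))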