Update app.py
Browse files
app.py
CHANGED
@@ -53,11 +53,30 @@ def tighten(q: str) -> str:
|
|
53 |
# Multimodal helpers
|
54 |
# -------------------------
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def image_qa_bytes(data: bytes, prompt: str) -> str:
    """Query MiniGPT-4-V for image-based QA using bytes.

    Args:
        data: Raw image bytes, sent verbatim as the request body.
        prompt: Question text about the image.
            NOTE(review): `prompt` is accepted but never forwarded to the
            API call below — confirm whether it should be part of the
            request payload.

    Returns:
        The raw response from the hosted inference endpoint.
    """
    headers = {"Content-Type": "application/octet-stream"}
    # `client` is a module-level inference client defined elsewhere in app.py.
    return client.post("Vision-CAIR/MiniGPT4-V", data=data, headers=headers)
|
60 |
|
|
|
61 |
def video_label_bytes(data: bytes) -> str:
|
62 |
"""Get video classification using VideoMAE-Base-Short from bytes."""
|
63 |
headers = {"Content-Type": "application/octet-stream"}
|
@@ -94,7 +113,7 @@ class AgentState(TypedDict):
|
|
94 |
search_query: Annotated[str, override]
|
95 |
task_id: Annotated[str, override]
|
96 |
logs: Annotated[Dict[str, Any], merge_dicts]
|
97 |
-
|
98 |
|
99 |
# -------------------------
|
100 |
# BasicAgent implementation
|
@@ -128,10 +147,8 @@ class BasicAgent:
|
|
128 |
try:
|
129 |
question_data = json.loads(state["question"])
|
130 |
if "file_url" in question_data:
|
131 |
-
file_url = question_data["file_url"]
|
132 |
-
|
133 |
-
state["attachment_data"] = file_data
|
134 |
-
file_type = self._detect_file_type(file_url)
|
135 |
if file_type == "video":
|
136 |
state["current_step"] = "video"
|
137 |
elif file_type == "image":
|
@@ -180,43 +197,34 @@ class BasicAgent:
|
|
180 |
def _image_node(self, state: AgentState) -> AgentState:
|
181 |
"""Handle image-based questions."""
|
182 |
try:
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
else:
|
187 |
-
raise ValueError("No image data found in state")
|
188 |
-
state["current_step"] = "answer"
|
189 |
except Exception as e:
|
190 |
state["logs"]["image_error"] = str(e)
|
191 |
-
|
192 |
return state
|
193 |
|
194 |
def _video_node(self, state: AgentState) -> AgentState:
|
195 |
"""Handle video-based questions."""
|
196 |
try:
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
else:
|
201 |
-
raise ValueError("No video data found in state")
|
202 |
-
state["current_step"] = "answer"
|
203 |
except Exception as e:
|
204 |
state["logs"]["video_error"] = str(e)
|
205 |
-
|
206 |
return state
|
207 |
|
208 |
def _sheet_node(self, state: AgentState) -> AgentState:
|
209 |
"""Handle spreadsheet-based questions."""
|
210 |
try:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
else:
|
215 |
-
raise ValueError("No spreadsheet data found in state")
|
216 |
-
state["current_step"] = "answer"
|
217 |
except Exception as e:
|
218 |
state["logs"]["sheet_error"] = str(e)
|
219 |
-
|
220 |
return state
|
221 |
|
222 |
def _perform_search(self, state: AgentState) -> AgentState:
|
@@ -243,13 +251,24 @@ class BasicAgent:
|
|
243 |
return state
|
244 |
|
245 |
def _generate_answer(self, state: AgentState) -> AgentState:
|
246 |
-
# Collect all
|
247 |
materials = []
|
248 |
for h in state["history"]:
|
249 |
-
if h["step"]
|
250 |
-
materials.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
|
252 |
-
search_block = "\n".join(materials) if materials else "No
|
253 |
|
254 |
prompt = f"""
|
255 |
Answer this question using ONLY the materials provided.
|
@@ -325,7 +344,7 @@ Write ANSWER: <answer> on its own line.
|
|
325 |
"search_query": "",
|
326 |
"task_id": task_id,
|
327 |
"logs": {},
|
328 |
-
"
|
329 |
}
|
330 |
final_state = self.workflow.invoke(state)
|
331 |
return final_state["final_answer"]
|
@@ -418,7 +437,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
418 |
"search_query": "",
|
419 |
"task_id": task_id,
|
420 |
"logs": {},
|
421 |
-
"
|
422 |
}
|
423 |
|
424 |
# Run the workflow
|
|
|
53 |
# Multimodal helpers
|
54 |
# -------------------------
|
55 |
|
56 |
+
def retry_hf_inference(func):
    """Decorator to retry HF Inference API calls with backoff.

    Calls ``func`` up to ``max_retries`` additional times after the first
    failure, sleeping ``base_delay * attempt_number`` seconds between tries
    (linear backoff). The final failure is re-raised unchanged so callers
    see the original exception type.
    """
    from functools import wraps  # function-scope import keeps the block self-contained

    @wraps(func)  # preserve the wrapped function's __name__/__doc__ for logs and debugging
    def wrapper(*args, **kwargs):
        max_retries = 2
        base_delay = 7  # seconds; delay grows linearly with each attempt

        for attempt in range(max_retries + 1):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if attempt == max_retries:
                    # Out of retries: surface the original error to the caller.
                    raise
                delay = base_delay * (attempt + 1)
                print(f"HF API error: {str(e)}. Retrying in {delay}s...")
                time.sleep(delay)
    return wrapper
|
72 |
+
|
73 |
+
@retry_hf_inference
def image_qa_bytes(data: bytes, prompt: str) -> str:
    """Query MiniGPT-4-V for image-based QA using bytes.

    Retried automatically on failure via ``retry_hf_inference``.

    Args:
        data: Raw image bytes, sent verbatim as the request body.
        prompt: Question text about the image.
            NOTE(review): `prompt` is accepted but never forwarded to the
            API call below — confirm whether it should be part of the
            request payload.

    Returns:
        The raw response from the hosted inference endpoint.
    """
    headers = {"Content-Type": "application/octet-stream"}
    # `client` is a module-level inference client defined elsewhere in app.py.
    return client.post("Vision-CAIR/MiniGPT4-V", data=data, headers=headers)
|
78 |
|
79 |
+
@retry_hf_inference
|
80 |
def video_label_bytes(data: bytes) -> str:
|
81 |
"""Get video classification using VideoMAE-Base-Short from bytes."""
|
82 |
headers = {"Content-Type": "application/octet-stream"}
|
|
|
113 |
search_query: Annotated[str, override]
|
114 |
task_id: Annotated[str, override]
|
115 |
logs: Annotated[Dict[str, Any], merge_dicts]
|
116 |
+
file_url: Annotated[str, override]
|
117 |
|
118 |
# -------------------------
|
119 |
# BasicAgent implementation
|
|
|
147 |
try:
|
148 |
question_data = json.loads(state["question"])
|
149 |
if "file_url" in question_data:
|
150 |
+
state["file_url"] = question_data["file_url"]
|
151 |
+
file_type = self._detect_file_type(state["file_url"])
|
|
|
|
|
152 |
if file_type == "video":
|
153 |
state["current_step"] = "video"
|
154 |
elif file_type == "image":
|
|
|
197 |
def _image_node(self, state: AgentState) -> AgentState:
    """Handle image-based questions.

    Downloads the attachment named by ``state["file_url"]``, runs image QA
    on it, and appends the result to ``state["history"]``. Failures are
    recorded in ``state["logs"]`` instead of being raised, and the workflow
    is routed to the "answer" step either way.
    """
    try:
        data = self._download_file(state["file_url"])
        answer = image_qa_bytes(data, "What is shown in this image?")
        state["history"].append({"step": "image", "output": answer})
    except Exception as e:
        # Best-effort: log the failure and let the answer node proceed
        # without image material.
        state["logs"]["image_error"] = str(e)
    # NOTE(review): placed outside the except so it runs on both success and
    # failure — the diff's indentation is ambiguous here; confirm intent.
    state["current_step"] = "answer"
    return state
|
207 |
|
208 |
def _video_node(self, state: AgentState) -> AgentState:
    """Handle video-based questions.

    Downloads the attachment named by ``state["file_url"]``, classifies the
    video, and appends the label to ``state["history"]``. Failures are
    recorded in ``state["logs"]`` instead of being raised, and the workflow
    is routed to the "answer" step either way.
    """
    try:
        data = self._download_file(state["file_url"])
        label = video_label_bytes(data)
        state["history"].append({"step": "video", "output": label})
    except Exception as e:
        # Best-effort: log the failure and let the answer node proceed
        # without video material.
        state["logs"]["video_error"] = str(e)
    # NOTE(review): placed outside the except so it runs on both success and
    # failure — the diff's indentation is ambiguous here; confirm intent.
    state["current_step"] = "answer"
    return state
|
218 |
|
219 |
def _sheet_node(self, state: AgentState) -> AgentState:
    """Handle spreadsheet-based questions.

    Downloads the attachment named by ``state["file_url"]``, extracts an
    answer from the spreadsheet, and appends it to ``state["history"]``.
    Failures are recorded in ``state["logs"]`` instead of being raised, and
    the workflow is routed to the "answer" step either way.
    """
    try:
        data = self._download_file(state["file_url"])
        # file_url is passed along so the helper can infer the sheet format
        # (e.g. by extension) — presumably; verify against sheet_answer_bytes.
        answer = sheet_answer_bytes(data, state["file_url"])
        state["history"].append({"step": "sheet", "output": answer})
    except Exception as e:
        # Best-effort: log the failure and let the answer node proceed
        # without spreadsheet material.
        state["logs"]["sheet_error"] = str(e)
    # NOTE(review): placed outside the except so it runs on both success and
    # failure — the diff's indentation is ambiguous here; confirm intent.
    state["current_step"] = "answer"
    return state
|
229 |
|
230 |
def _perform_search(self, state: AgentState) -> AgentState:
|
|
|
251 |
return state
|
252 |
|
253 |
def _generate_answer(self, state: AgentState) -> AgentState:
|
254 |
+
# Collect all tool outputs with clear section headers
|
255 |
materials = []
|
256 |
for h in state["history"]:
|
257 |
+
if h["step"] == "search":
|
258 |
+
materials.append("=== Search Results ===")
|
259 |
+
for result in h.get("results", []):
|
260 |
+
materials.append(result)
|
261 |
+
elif h["step"] == "image":
|
262 |
+
materials.append("=== Image Analysis ===")
|
263 |
+
materials.append(h.get("output", ""))
|
264 |
+
elif h["step"] == "video":
|
265 |
+
materials.append("=== Video Analysis ===")
|
266 |
+
materials.append(h.get("output", ""))
|
267 |
+
elif h["step"] == "sheet":
|
268 |
+
materials.append("=== Spreadsheet Analysis ===")
|
269 |
+
materials.append(h.get("output", ""))
|
270 |
|
271 |
+
search_block = "\n\n".join(materials) if materials else "No materials available."
|
272 |
|
273 |
prompt = f"""
|
274 |
Answer this question using ONLY the materials provided.
|
|
|
344 |
"search_query": "",
|
345 |
"task_id": task_id,
|
346 |
"logs": {},
|
347 |
+
"file_url": ""
|
348 |
}
|
349 |
final_state = self.workflow.invoke(state)
|
350 |
return final_state["final_answer"]
|
|
|
437 |
"search_query": "",
|
438 |
"task_id": task_id,
|
439 |
"logs": {},
|
440 |
+
"file_url": ""
|
441 |
}
|
442 |
|
443 |
# Run the workflow
|