Update app.py
app.py
CHANGED
@@ -16,7 +16,6 @@ from huggingface_hub import InferenceClient
 import io
 import mimetypes
 import base64
-from code_interpreter import CodeInterpreter
 from decision_maker import DecisionMaker, ToolType
 
 # -------------------------
@@ -56,18 +55,18 @@ def tighten(q: str) -> str:
 # -------------------------
 
 def image_qa(image_path: str, prompt: str) -> str:
-    """Query …
+    """Query MiniGPT-4-V for image-based QA."""
     with open(image_path, "rb") as f:
         data = {"prompt": prompt, "image": f.read()}
         headers = {"Content-Type": "application/octet-stream"}
-        return client.post(" …
+        return client.post("Vision-CAIR/MiniGPT4-V", data=data, headers=headers)
 
 def video_label(video_path: str, topk: int = 1) -> str:
-    """Get video classification using VideoMAE."""
+    """Get video classification using VideoMAE-Base-Short."""
     with open(video_path, "rb") as f:
         headers = {"Content-Type": "application/octet-stream"}
         preds = client.post(
-            "MCG-NJU/videomae-base-finetuned-ucf101",
+            "MCG-NJU/videomae-base-short-finetuned-ucf101",
             data=f.read(),
             headers=headers
         )
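The two helpers above take a local file path plus, for `image_qa`, a free-form prompt. A minimal usage sketch, not part of the commit; the paths and prompt are made up for illustration:

```python
# Hypothetical calls to the helpers defined in app.py; paths are illustrative.
caption = image_qa("samples/chart.png", "What trend does this chart show?")
top_label = video_label("samples/clip.mp4")
print(caption)
print(top_label)
```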
@@ -107,12 +106,36 @@ class AgentState(TypedDict):
 # BasicAgent implementation
 # -------------------------
 
+def quick_code_stats(src: str) -> Dict[str, Any]:
+    """Lightweight code analysis using AST."""
+    try:
+        tree = ast.parse(src)
+        funcs = [n.name for n in tree.body if isinstance(n, ast.FunctionDef)]
+        classes = [n.name for n in tree.body if isinstance(n, ast.ClassDef)]
+        imports = []
+        for node in ast.walk(tree):
+            if isinstance(node, ast.Import):
+                imports.extend(n.name for n in node.names)
+            elif isinstance(node, ast.ImportFrom):
+                imports.append(f"{node.module}.{node.names[0].name}")
+
+        return {
+            "functions": funcs,
+            "classes": classes,
+            "imports": imports,
+            "lines": len(src.splitlines())
+        }
+    except Exception as e:
+        return {
+            "error": str(e),
+            "lines": len(src.splitlines())
+        }
+
 class BasicAgent:
     def __init__(self):
         if not OPENAI_API_KEY:
             raise EnvironmentError("OPENAI_API_KEY not set")
         self.llm = OpenAI(api_key=OPENAI_API_KEY)
-        self.code_interpreter = CodeInterpreter()
         self.decision_maker = DecisionMaker()
         self.workflow = self._build_workflow()
 
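A quick sanity check of the new `quick_code_stats` helper. This snippet is illustrative and not part of the commit; it assumes app.py is importable as a module. Note that `ast.parse` only understands Python source, so non-Python blocks fall through to the `error` branch with just a line count.

```python
from app import quick_code_stats  # assumption: app.py is importable as a module

sample = '''
import math

def area(r):
    return math.pi * r ** 2

class Circle:
    pass
'''

print(quick_code_stats(sample))
# {'functions': ['area'], 'classes': ['Circle'], 'imports': ['math'], 'lines': 8}
```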
@@ -145,31 +168,31 @@ class BasicAgent:
         return state
 
         # Check for file attachments in the question
-        … (removed lines not preserved in this view)
-        elif state["attachment_data"]["type"] in ["excel", "csv"]:
-            state["current_step"] = "sheet"
-            return state
-        except Exception as e:
-            state["logs"]["file_download_error"] = str(e)
-            state["current_step"] = "answer"
+        try:
+            question_data = json.loads(state["question"])
+            if "file_url" in question_data:
+                file_url = question_data["file_url"]
+                # Download the file
+                file_data = self._download_file(file_url)
+                # Store in state
+                state["attachment_data"] = {
+                    "content": file_data,
+                    "type": self._detect_file_type(file_data, file_url)
+                }
+                # Set appropriate step based on file type
+                if state["attachment_data"]["type"] == "video":
+                    state["current_step"] = "video"
+                elif state["attachment_data"]["type"] == "image":
+                    state["current_step"] = "image"
+                elif state["attachment_data"]["type"] in ["excel", "csv"]:
+                    state["current_step"] = "sheet"
                 return state
+        except (json.JSONDecodeError, KeyError):
+            pass  # Not a JSON question or no file_url
+        except Exception as e:
+            state["logs"]["file_download_error"] = str(e)
+            state["current_step"] = "answer"
+            return state
 
         # Regular text question analysis
         prompt = (
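The new branch only fires when the incoming question is itself a JSON object carrying a `file_url` key; for plain-text questions `json.loads` raises, the first `except` swallows it, and control falls through to the regular analysis below. An illustrative payload, with a made-up URL:

```python
import json

# Illustrative question payload; the URL is a placeholder, not a real endpoint.
question = json.dumps({
    "question": "What is the total of column B?",
    "file_url": "https://example.com/files/report.xlsx",
})

# json.loads(question) succeeds and "file_url" is present, so the agent downloads
# the file, detects "excel" from the extension, and routes to the "sheet" step.
```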
@@ -215,20 +238,13 @@ class BasicAgent:
         try:
             results = []
             for block in state["code_blocks"]:
-                # Analyze code using the …
-                analysis = …
-                    block["code"],
-                    language=block["language"]
-                )
-
-                # Get improvement suggestions
-                suggestions = self.code_interpreter.suggest_improvements(analysis)
+                # Analyze code using the lightweight analyzer
+                analysis = quick_code_stats(block["code"])
 
                 # Format the results
                 result = {
                     "language": block["language"],
-                    "analysis": analysis
-                    "suggestions": suggestions
+                    "analysis": analysis
                 }
                 results.append(result)
 
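With the lightweight analyzer in place, each entry appended to `results` has roughly this shape; the values below are made up for illustration:

```python
# Illustrative shape of one entry in `results`; values are made up.
result = {
    "language": "python",
    "analysis": {
        "functions": ["area"],
        "classes": [],
        "imports": ["math"],
        "lines": 3,
    },
}
```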
@@ -242,41 +258,12 @@ class BasicAgent:
         return state
 
     def _detect_file_type(self, data: bytes, url: str) -> str:
-        """Detect file type from …
-
-
-        if …
-
-
-            return "image"
-        elif url_lower.endswith(".xlsx"):
-            return "excel"
-        elif url_lower.endswith(".csv"):
-            return "csv"
-
-        # If URL check fails, try content-based detection
-        try:
-            # Try to detect image
-            from PIL import Image
-            Image.open(io.BytesIO(data))
-            return "image"
-        except:
-            pass
-
-        try:
-            # Try to detect Excel
-            pd.read_excel(io.BytesIO(data))
-            return "excel"
-        except:
-            pass
-
-        try:
-            # Try to detect CSV
-            pd.read_csv(io.BytesIO(data))
-            return "csv"
-        except:
-            pass
-
+        """Detect file type from URL extension."""
+        ext = url.split(".")[-1].lower()
+        if ext in {"mp4"}: return "video"
+        if ext in {"jpg", "jpeg", "png"}: return "image"
+        if ext in {"xlsx"}: return "excel"
+        if ext in {"csv"}: return "csv"
         return "unknown"
 
     def _image_node(self, state: AgentState) -> AgentState:
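The new `_detect_file_type` only inspects the URL extension. If broader coverage is ever wanted without re-adding content sniffing, the `mimetypes` module that app.py already imports could be mapped onto the same categories. A rough sketch under that assumption; `guess_kind` is a hypothetical name, not part of the commit:

```python
import mimetypes

def guess_kind(url: str) -> str:
    """Hypothetical fallback: map a guessed MIME type onto the agent's file categories."""
    mime, _ = mimetypes.guess_type(url)
    if not mime:
        return "unknown"
    if mime.startswith("video/"):
        return "video"
    if mime.startswith("image/"):
        return "image"
    # Whether .xlsx maps to this MIME type depends on the platform's MIME table.
    if mime == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
        return "excel"
    if mime == "text/csv":
        return "csv"
    return "unknown"
```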
@@ -329,7 +316,7 @@ class BasicAgent:
 
     def _perform_search(self, state: AgentState) -> AgentState:
         try:
-            results = simple_search(state["search_query"], max_results= …
+            results = simple_search(state["search_query"], max_results=10)
             print("\nSearch Results:")
             for i, s in enumerate(results, 1):
                 print(f"[{i}] {s[:120]}…")
@@ -376,16 +363,17 @@ class BasicAgent:
     def _generate_answer(self, state: AgentState) -> AgentState:
         # Collect all relevant tool outputs
         materials = []
-        for …
-            if …
+        for step in state["history"]:
+            if step["step"] in {"search", "image", "video", "sheet", "code_analysis"}:
                 # Handle different output formats
-                if …
-                    output = …
+                if step["step"] == "search":
+                    output = step.get("results", [])
                 if isinstance(output, list):
                     output = "\n".join(output)
                 else:
-                    output = …
-
+                    output = step.get("output", "")
+                # Format the output as JSON for better readability
+                materials.append(json.dumps(output, indent=2))
 
         # Join all materials with proper formatting
         search_block = "\n".join(materials) if materials else "No artefacts available."
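One detail worth knowing about the new formatting: when `output` is an already-joined string, `json.dumps` simply quotes it and escapes the newlines, so `indent=2` only changes the layout of list or dict outputs. A small illustration:

```python
import json

# For a plain string, json.dumps just quotes and escapes it; indent has no effect.
print(json.dumps("line one\nline two", indent=2))
# "line one\nline two"

# For structured output (e.g. a dict from the sheet or code-analysis steps),
# indent=2 pretty-prints it across multiple lines.
print(json.dumps({"functions": ["area"], "lines": 3}, indent=2))
```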
@@ -479,13 +467,9 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
 
     def _download_file(self, url: str) -> bytes:
         """Download a file from a URL."""
-        … (removed lines not preserved in this view)
-            return response.content
-        except Exception as e:
-            print(f"Error downloading file from {url}: {e}")
-            raise
+        r = requests.get(url, timeout=30)
+        r.raise_for_status()
+        return r.content
 
 # ----------------------------------------------------------------------------------
 # Gradio Interface & Submission Routines