naman1102 commited on
Commit
ebc7e51
·
1 Parent(s): 8286288

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -89
app.py CHANGED
@@ -16,7 +16,6 @@ from huggingface_hub import InferenceClient
16
  import io
17
  import mimetypes
18
  import base64
19
- from code_interpreter import CodeInterpreter
20
  from decision_maker import DecisionMaker, ToolType
21
 
22
  # -------------------------
@@ -56,18 +55,18 @@ def tighten(q: str) -> str:
56
  # -------------------------
57
 
58
  def image_qa(image_path: str, prompt: str) -> str:
59
- """Query LLaVA model for image-based QA."""
60
  with open(image_path, "rb") as f:
61
  data = {"prompt": prompt, "image": f.read()}
62
  headers = {"Content-Type": "application/octet-stream"}
63
- return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data, headers=headers)
64
 
65
  def video_label(video_path: str, topk: int = 1) -> str:
66
- """Get video classification using VideoMAE."""
67
  with open(video_path, "rb") as f:
68
  headers = {"Content-Type": "application/octet-stream"}
69
  preds = client.post(
70
- "MCG-NJU/videomae-base-finetuned-ucf101",
71
  data=f.read(),
72
  headers=headers
73
  )
@@ -107,12 +106,36 @@ class AgentState(TypedDict):
107
  # BasicAgent implementation
108
  # -------------------------
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  class BasicAgent:
111
  def __init__(self):
112
  if not OPENAI_API_KEY:
113
  raise EnvironmentError("OPENAI_API_KEY not set")
114
  self.llm = OpenAI(api_key=OPENAI_API_KEY)
115
- self.code_interpreter = CodeInterpreter()
116
  self.decision_maker = DecisionMaker()
117
  self.workflow = self._build_workflow()
118
 
@@ -145,31 +168,31 @@ class BasicAgent:
145
  return state
146
 
147
  # Check for file attachments in the question
148
- if "file_url" in state["question"]:
149
- try:
150
- # Parse the question to get file URL
151
- question_data = json.loads(state["question"])
152
- file_url = question_data.get("file_url")
153
- if file_url:
154
- # Download the file
155
- file_data = self._download_file(file_url)
156
- # Store in state
157
- state["attachment_data"] = {
158
- "content": file_data,
159
- "type": self._detect_file_type(file_data, file_url)
160
- }
161
- # Set appropriate step based on file type
162
- if state["attachment_data"]["type"] == "video":
163
- state["current_step"] = "video"
164
- elif state["attachment_data"]["type"] == "image":
165
- state["current_step"] = "image"
166
- elif state["attachment_data"]["type"] in ["excel", "csv"]:
167
- state["current_step"] = "sheet"
168
- return state
169
- except Exception as e:
170
- state["logs"]["file_download_error"] = str(e)
171
- state["current_step"] = "answer"
172
  return state
 
 
 
 
 
 
173
 
174
  # Regular text question analysis
175
  prompt = (
@@ -215,20 +238,13 @@ class BasicAgent:
215
  try:
216
  results = []
217
  for block in state["code_blocks"]:
218
- # Analyze code using the code interpreter
219
- analysis = self.code_interpreter.analyze_code(
220
- block["code"],
221
- language=block["language"]
222
- )
223
-
224
- # Get improvement suggestions
225
- suggestions = self.code_interpreter.suggest_improvements(analysis)
226
 
227
  # Format the results
228
  result = {
229
  "language": block["language"],
230
- "analysis": analysis,
231
- "suggestions": suggestions
232
  }
233
  results.append(result)
234
 
@@ -242,41 +258,12 @@ class BasicAgent:
242
  return state
243
 
244
  def _detect_file_type(self, data: bytes, url: str) -> str:
245
- """Detect file type from content and URL."""
246
- # Check URL extension first
247
- url_lower = url.lower()
248
- if url_lower.endswith((".mp4", ".avi", ".mov")):
249
- return "video"
250
- elif url_lower.endswith((".jpg", ".jpeg", ".png", ".gif")):
251
- return "image"
252
- elif url_lower.endswith(".xlsx"):
253
- return "excel"
254
- elif url_lower.endswith(".csv"):
255
- return "csv"
256
-
257
- # If URL check fails, try content-based detection
258
- try:
259
- # Try to detect image
260
- from PIL import Image
261
- Image.open(io.BytesIO(data))
262
- return "image"
263
- except:
264
- pass
265
-
266
- try:
267
- # Try to detect Excel
268
- pd.read_excel(io.BytesIO(data))
269
- return "excel"
270
- except:
271
- pass
272
-
273
- try:
274
- # Try to detect CSV
275
- pd.read_csv(io.BytesIO(data))
276
- return "csv"
277
- except:
278
- pass
279
-
280
  return "unknown"
281
 
282
  def _image_node(self, state: AgentState) -> AgentState:
@@ -329,7 +316,7 @@ class BasicAgent:
329
 
330
  def _perform_search(self, state: AgentState) -> AgentState:
331
  try:
332
- results = simple_search(state["search_query"], max_results=5)
333
  print("\nSearch Results:")
334
  for i, s in enumerate(results, 1):
335
  print(f"[{i}] {s[:120]}…")
@@ -376,16 +363,17 @@ class BasicAgent:
376
  def _generate_answer(self, state: AgentState) -> AgentState:
377
  # Collect all relevant tool outputs
378
  materials = []
379
- for item in state["history"]:
380
- if item["step"] in ("search", "image", "video", "sheet", "code_analysis"):
381
  # Handle different output formats
382
- if item["step"] == "search":
383
- output = item.get("results", [])
384
  if isinstance(output, list):
385
  output = "\n".join(output)
386
  else:
387
- output = item.get("output", "")
388
- materials.append(str(output))
 
389
 
390
  # Join all materials with proper formatting
391
  search_block = "\n".join(materials) if materials else "No artefacts available."
@@ -479,13 +467,9 @@ Think step-by-step. Write ANSWER: <answer> on its own line.
479
 
480
  def _download_file(self, url: str) -> bytes:
481
  """Download a file from a URL."""
482
- try:
483
- response = requests.get(url, timeout=20)
484
- response.raise_for_status()
485
- return response.content
486
- except Exception as e:
487
- print(f"Error downloading file from {url}: {e}")
488
- raise
489
 
490
  # ----------------------------------------------------------------------------------
491
  # Gradio Interface & Submission Routines
 
16
  import io
17
  import mimetypes
18
  import base64
 
19
  from decision_maker import DecisionMaker, ToolType
20
 
21
  # -------------------------
 
55
  # -------------------------
56
 
57
  def image_qa(image_path: str, prompt: str) -> str:
58
+ """Query MiniGPT-4-V for image-based QA."""
59
  with open(image_path, "rb") as f:
60
  data = {"prompt": prompt, "image": f.read()}
61
  headers = {"Content-Type": "application/octet-stream"}
62
+ return client.post("Vision-CAIR/MiniGPT4-V", data=data, headers=headers)
63
 
64
  def video_label(video_path: str, topk: int = 1) -> str:
65
+ """Get video classification using VideoMAE-Base-Short."""
66
  with open(video_path, "rb") as f:
67
  headers = {"Content-Type": "application/octet-stream"}
68
  preds = client.post(
69
+ "MCG-NJU/videomae-base-short-finetuned-ucf101",
70
  data=f.read(),
71
  headers=headers
72
  )
 
106
  # BasicAgent implementation
107
  # -------------------------
108
 
109
+ def quick_code_stats(src: str) -> Dict[str, Any]:
110
+ """Lightweight code analysis using AST."""
111
+ try:
112
+ tree = ast.parse(src)
113
+ funcs = [n.name for n in tree.body if isinstance(n, ast.FunctionDef)]
114
+ classes = [n.name for n in tree.body if isinstance(n, ast.ClassDef)]
115
+ imports = []
116
+ for node in ast.walk(tree):
117
+ if isinstance(node, ast.Import):
118
+ imports.extend(n.name for n in node.names)
119
+ elif isinstance(node, ast.ImportFrom):
120
+ imports.append(f"{node.module}.{node.names[0].name}")
121
+
122
+ return {
123
+ "functions": funcs,
124
+ "classes": classes,
125
+ "imports": imports,
126
+ "lines": len(src.splitlines())
127
+ }
128
+ except Exception as e:
129
+ return {
130
+ "error": str(e),
131
+ "lines": len(src.splitlines())
132
+ }
133
+
134
  class BasicAgent:
135
  def __init__(self):
136
  if not OPENAI_API_KEY:
137
  raise EnvironmentError("OPENAI_API_KEY not set")
138
  self.llm = OpenAI(api_key=OPENAI_API_KEY)
 
139
  self.decision_maker = DecisionMaker()
140
  self.workflow = self._build_workflow()
141
 
 
168
  return state
169
 
170
  # Check for file attachments in the question
171
+ try:
172
+ question_data = json.loads(state["question"])
173
+ if "file_url" in question_data:
174
+ file_url = question_data["file_url"]
175
+ # Download the file
176
+ file_data = self._download_file(file_url)
177
+ # Store in state
178
+ state["attachment_data"] = {
179
+ "content": file_data,
180
+ "type": self._detect_file_type(file_data, file_url)
181
+ }
182
+ # Set appropriate step based on file type
183
+ if state["attachment_data"]["type"] == "video":
184
+ state["current_step"] = "video"
185
+ elif state["attachment_data"]["type"] == "image":
186
+ state["current_step"] = "image"
187
+ elif state["attachment_data"]["type"] in ["excel", "csv"]:
188
+ state["current_step"] = "sheet"
 
 
 
 
 
 
189
  return state
190
+ except (json.JSONDecodeError, KeyError):
191
+ pass # Not a JSON question or no file_url
192
+ except Exception as e:
193
+ state["logs"]["file_download_error"] = str(e)
194
+ state["current_step"] = "answer"
195
+ return state
196
 
197
  # Regular text question analysis
198
  prompt = (
 
238
  try:
239
  results = []
240
  for block in state["code_blocks"]:
241
+ # Analyze code using the lightweight analyzer
242
+ analysis = quick_code_stats(block["code"])
 
 
 
 
 
 
243
 
244
  # Format the results
245
  result = {
246
  "language": block["language"],
247
+ "analysis": analysis
 
248
  }
249
  results.append(result)
250
 
 
258
  return state
259
 
260
  def _detect_file_type(self, data: bytes, url: str) -> str:
261
+ """Detect file type from URL extension."""
262
+ ext = url.split(".")[-1].lower()
263
+ if ext in {"mp4"}: return "video"
264
+ if ext in {"jpg", "jpeg", "png"}: return "image"
265
+ if ext in {"xlsx"}: return "excel"
266
+ if ext in {"csv"}: return "csv"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  return "unknown"
268
 
269
  def _image_node(self, state: AgentState) -> AgentState:
 
316
 
317
  def _perform_search(self, state: AgentState) -> AgentState:
318
  try:
319
+ results = simple_search(state["search_query"], max_results=10)
320
  print("\nSearch Results:")
321
  for i, s in enumerate(results, 1):
322
  print(f"[{i}] {s[:120]}…")
 
363
  def _generate_answer(self, state: AgentState) -> AgentState:
364
  # Collect all relevant tool outputs
365
  materials = []
366
+ for step in state["history"]:
367
+ if step["step"] in {"search", "image", "video", "sheet", "code_analysis"}:
368
  # Handle different output formats
369
+ if step["step"] == "search":
370
+ output = step.get("results", [])
371
  if isinstance(output, list):
372
  output = "\n".join(output)
373
  else:
374
+ output = step.get("output", "")
375
+ # Format the output as JSON for better readability
376
+ materials.append(json.dumps(output, indent=2))
377
 
378
  # Join all materials with proper formatting
379
  search_block = "\n".join(materials) if materials else "No artefacts available."
 
467
 
468
  def _download_file(self, url: str) -> bytes:
469
  """Download a file from a URL."""
470
+ r = requests.get(url, timeout=30)
471
+ r.raise_for_status()
472
+ return r.content
 
 
 
 
473
 
474
  # ----------------------------------------------------------------------------------
475
  # Gradio Interface & Submission Routines