Update app.py
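This commit removes the DecisionMaker-based code-analysis path (the decision_maker import, quick_code_stats, _extract_code_blocks, _code_analysis_node, the code_blocks state field, and the code_analysis graph node), drops the recheck/_re_evaluate retry loop and the unused _extract_boxed_answer helper, and replaces the multi-model fallback in _call_llm with a single gpt-4o-mini chat-completion call. Search results now route straight to the answer node.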
app.py (CHANGED)
@@ -16,7 +16,6 @@ from huggingface_hub import InferenceClient
 import io
 import mimetypes
 import base64
-from decision_maker import DecisionMaker, ToolType
 
 # -------------------------
 # Environment & constants
@@ -95,98 +94,43 @@ class AgentState(TypedDict):
     search_query: Annotated[str, override]
     task_id: Annotated[str, override]
     logs: Annotated[Dict[str, Any], merge_dicts]
-
-    attachment_data: Annotated[Dict[str, bytes], merge_dicts]  # Store downloaded file data
+    attachment_data: Annotated[Dict[str, bytes], merge_dicts]
 
 # -------------------------
 # BasicAgent implementation
 # -------------------------
 
-def quick_code_stats(src: str) -> Dict[str, Any]:
-    """Lightweight code analysis using AST."""
-    try:
-        tree = ast.parse(src)
-        funcs = [n.name for n in tree.body if isinstance(n, ast.FunctionDef)]
-        classes = [n.name for n in tree.body if isinstance(n, ast.ClassDef)]
-        imports = []
-        for node in ast.walk(tree):
-            if isinstance(node, ast.Import):
-                imports.extend(n.name for n in node.names)
-            elif isinstance(node, ast.ImportFrom):
-                imports.append(f"{node.module}.{node.names[0].name}")
-
-        return {
-            "functions": funcs,
-            "classes": classes,
-            "imports": imports,
-            "lines": len(src.splitlines())
-        }
-    except Exception as e:
-        return {
-            "error": str(e),
-            "lines": len(src.splitlines())
-        }
-
 class BasicAgent:
     def __init__(self):
         if not OPENAI_API_KEY:
             raise EnvironmentError("OPENAI_API_KEY not set")
         self.llm = OpenAI(api_key=OPENAI_API_KEY)
-        self.decision_maker = DecisionMaker()
         self.workflow = self._build_workflow()
 
-    # ---- Low‑level LLM call
     def _call_llm(self, prompt: str, max_tokens: int = 256) -> str:
-        ...
-                return resp.choices[0].message.content.strip()
-            except Exception as e:
-                print(f"Error with {model}: {str(e)}")
-                if model == models[-1]:  # If this was the last model
-                    print(f"All models failed. Last error: {str(e)}")
-                    print(f"Prompt that caused error:\n{prompt}")
-                    raise
-                print(f"Falling back to next model...")
-                continue
+        try:
+            resp = self.llm.chat.completions.create(
+                model="gpt-4o-mini",
+                messages=[
+                    {"role": "system", "content": "You are a careful reasoning assistant."},
+                    {"role": "user", "content": prompt},
+                ],
+                temperature=0.3,
+                max_tokens=max_tokens,
+            )
+            return resp.choices[0].message.content.strip()
+        except Exception as e:
+            print(f"\nLLM Error: {str(e)}")
+            raise
 
-    # ---- Workflow nodes
     def _analyze_question(self, state: AgentState) -> AgentState:
-        # First, analyze the request using the decision maker
-        request_analysis = self.decision_maker.analyze_request(state["question"])
-        state["logs"]["request_analysis"] = request_analysis
-
-        # Check for code-related content
-        if "code" in request_analysis["intent"]:
-            # Extract code blocks from the question
-            code_blocks = self._extract_code_blocks(state["question"])
-            if code_blocks:
-                state["current_step"] = "code_analysis"
-                state["code_blocks"] = code_blocks
-                return state
-
         # Check for file attachments in the question
         try:
            question_data = json.loads(state["question"])
            if "file_url" in question_data:
                file_url = question_data["file_url"]
-                # Download the file
                file_data = self._download_file(file_url)
-                # Store in state
                state["attachment_data"] = file_data
-                # Detect type and set appropriate step
                file_type = self._detect_file_type(file_url)
                if file_type == "video":
                    state["current_step"] = "video"
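The deleted _call_llm body above only partially survived extraction (the span elided as "..."): the remaining lines show it looped over a models list and fell back to the next model on failure. A minimal sketch of such a fallback loop, assuming everything except the models/model names and the surviving except branch:

# Hypothetical reconstruction; only the `models`/`model` names and the
# except branch are attested in the deleted lines above.
def _call_llm(self, prompt: str, max_tokens: int = 256) -> str:
    models = ["gpt-4o", "gpt-4o-mini"]  # assumed candidate list
    for model in models:
        try:
            resp = self.llm.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
            )
            return resp.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error with {model}: {str(e)}")
            if model == models[-1]:  # If this was the last model
                print(f"All models failed. Last error: {str(e)}")
                print(f"Prompt that caused error:\n{prompt}")
                raise
            print(f"Falling back to next model...")
            continue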
@@ -196,12 +140,12 @@ class BasicAgent:
                 state["current_step"] = "sheet"
             return state
         except (json.JSONDecodeError, KeyError):
             pass
         except Exception as e:
             print(f"\nFile handling error: {str(e)}")
             state["current_step"] = "answer"
             return state
 
         # Regular text question analysis
         prompt = (
             "Decide if this question needs web search. Respond with a Python dict:\n"
@@ -215,53 +159,10 @@ class BasicAgent:
             state["search_query"] = decision.get("search_query", state["question"])
         except Exception as e:
             print(f"\nLLM Error in question analysis: {str(e)}")
-            print(f"Raw response: {raw}")
             state["needs_search"] = True
             state["search_query"] = state["question"]
-            decision = {"parse_error": raw}
 
         state["current_step"] = "search" if state["needs_search"] else "answer"
-        state["history"].append({"step": "analyze", "output": decision})
-        return state
-
-    def _extract_code_blocks(self, text: str) -> List[Dict[str, str]]:
-        """Extract code blocks from text using markdown-style code blocks."""
-        code_blocks = []
-        pattern = r"```(\w+)?\n(.*?)```"
-        matches = re.finditer(pattern, text, re.DOTALL)
-
-        for match in matches:
-            language = match.group(1) or "python"
-            code = match.group(2).strip()
-            code_blocks.append({
-                "language": language,
-                "code": code
-            })
-
-        return code_blocks
-
-    def _code_analysis_node(self, state: AgentState) -> AgentState:
-        """Handle code analysis requests."""
-        try:
-            results = []
-            for block in state["code_blocks"]:
-                # Analyze code using the lightweight analyzer
-                analysis = quick_code_stats(block["code"])
-
-                # Format the results
-                result = {
-                    "language": block["language"],
-                    "analysis": analysis
-                }
-                results.append(result)
-
-            state["history"].append({"step": "code_analysis", "output": results})
-            state["current_step"] = "answer"
-
-        except Exception as e:
-            state["logs"]["code_analysis_error"] = str(e)
-            state["current_step"] = "answer"
-
         return state
 
     def _detect_file_type(self, url: str) -> str:
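For reference, the removed code-analysis path is fully recoverable and self-contained: quick_code_stats (deleted in an earlier hunk) needs only the standard-library ast module. A standalone demo of what it returned; the sample source string is illustrative:

import ast
from typing import Any, Dict

def quick_code_stats(src: str) -> Dict[str, Any]:
    """Lightweight code analysis using AST (as deleted above)."""
    try:
        tree = ast.parse(src)
        funcs = [n.name for n in tree.body if isinstance(n, ast.FunctionDef)]
        classes = [n.name for n in tree.body if isinstance(n, ast.ClassDef)]
        imports = []
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                imports.extend(n.name for n in node.names)
            elif isinstance(node, ast.ImportFrom):
                imports.append(f"{node.module}.{node.names[0].name}")
        return {"functions": funcs, "classes": classes,
                "imports": imports, "lines": len(src.splitlines())}
    except Exception as e:
        return {"error": str(e), "lines": len(src.splitlines())}

print(quick_code_stats("import os\n\ndef f(x):\n    return x"))
# -> {'functions': ['f'], 'classes': [], 'imports': ['os'], 'lines': 4}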
@@ -320,14 +221,14 @@ class BasicAgent:
 
     def _perform_search(self, state: AgentState) -> AgentState:
         try:
             results = simple_search(state["search_query"], max_results=6)
             print("\nSearch Results:")
             for i, s in enumerate(results, 1):
                 print(f"[{i}] {s[:120]}…")
 
             if not results:
                 print("Warning: No search results found")
                 state["needs_search"] = True
             else:
                 state["needs_search"] = False
 
@@ -338,34 +239,16 @@ class BasicAgent:
             state["needs_search"] = True
             state["history"].append({"step": "search", "error": str(e)})
 
-        state["current_step"] = "recheck"
+        state["current_step"] = "answer"
         return state
 
-    def _re_evaluate(self, state: AgentState) -> AgentState:
-        """If search returned nothing, reformulate a shorter query."""
-        if state["needs_search"]:
-            state["search_query"] = tighten(state["question"])
-            state["current_step"] = "search"
-        else:
-            state["current_step"] = "answer"
-        return state
-
-    def _extract_boxed_answer(self, text: str) -> str:
-        """Extract answer from boxed format or return original text if no box found."""
-        # Look for text between [box] and [/box] tags
-        box_match = re.search(r'\[box\](.*?)\[/box\]', text, re.DOTALL)
-        if box_match:
-            return box_match.group(1).strip()
-        return text.strip()
-
     def _generate_answer(self, state: AgentState) -> AgentState:
         # Collect all relevant tool outputs
         materials = []
         for h in state["history"]:
-            if h["step"] in {"search", "video", "image", "sheet", "code_analysis"}:
+            if h["step"] in {"search", "video", "image", "sheet"}:
                 materials.append(json.dumps(h.get("output") or h.get("results"), indent=2))
 
-        # Join all materials with proper formatting
         search_block = "\n".join(materials) if materials else "No artefacts available."
 
         prompt = f"""
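The removed _extract_boxed_answer helper (deleted just above) is likewise recoverable in full; a standalone demo of its behavior, with an illustrative input:

import re

def extract_boxed_answer(text: str) -> str:
    # Look for text between [box] and [/box] tags, as the deleted method did
    box_match = re.search(r'\[box\](.*?)\[/box\]', text, re.DOTALL)
    return box_match.group(1).strip() if box_match else text.strip()

print(extract_boxed_answer("reasoning... [box] Paris [/box]"))  # -> Paris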
@@ -383,55 +266,39 @@ Write ANSWER: <answer> on its own line.
             raw = self._call_llm(prompt, 300)
             answer = raw.split("ANSWER:")[-1].strip()
 
-            # Validate answer
             if not answer:
-                print("\nLLM Warning: Empty answer received")
-                print(f"Raw response: {raw}")
                 answer = "I cannot provide a definitive answer at this time."
-            elif any(k in answer.lower() for k in ["i cannot find", "sorry"]):
-                print("\nLLM Warning: LLM indicated it couldn't find an answer")
-                print(f"Raw response: {raw}")
-                answer = "Based on the available information, I cannot provide a complete answer."
             elif "ANSWER:" not in raw:
-                print("\nLLM Warning: Response missing ANSWER: prefix")
-                print(f"Raw response: {raw}")
                 answer = "I cannot provide a definitive answer at this time."
 
             state["final_answer"] = answer
-            state["history"].append({"step": "answer", "output": raw})
             state["current_step"] = "done"
 
         except Exception as e:
             print(f"\nLLM Error in answer generation: {str(e)}")
-            print(f"Question: {state['question']}")
-            print(f"Materials:\n{search_block}")
             state["final_answer"] = "I encountered an error while generating the answer."
             state["current_step"] = "done"
 
         return state
 
-    # ---- Build LangGraph workflow
     def _build_workflow(self) -> Graph:
         sg = StateGraph(state_schema=AgentState)
 
-        # Add
+        # Add nodes
         sg.add_node("analyze", self._analyze_question)
         sg.add_node("search", self._perform_search)
-        sg.add_node("recheck", self._re_evaluate)
         sg.add_node("answer", self._generate_answer)
         sg.add_node("image", self._image_node)
         sg.add_node("video", self._video_node)
         sg.add_node("sheet", self._sheet_node)
-        sg.add_node("code_analysis", self._code_analysis_node)
 
         # Add edges
         sg.add_edge("analyze", "search")
         sg.add_edge("analyze", "answer")
-        sg.add_edge("search", "recheck")
+        sg.add_edge("search", "answer")
         sg.add_edge("image", "answer")
         sg.add_edge("video", "answer")
         sg.add_edge("sheet", "answer")
-        sg.add_edge("code_analysis", "answer")
 
         def router(state: AgentState):
             return state["current_step"]
@@ -441,19 +308,13 @@ Write ANSWER: <answer> on its own line.
             "answer": "answer",
             "image": "image",
             "video": "video",
-            "sheet": "sheet",
-            "code_analysis": "code_analysis"
-        })
-        sg.add_conditional_edges("recheck", router, {
-            "search": "search",
-            "answer": "answer"
+            "sheet": "sheet"
         })
 
         sg.set_entry_point("analyze")
         sg.set_finish_point("answer")
         return sg.compile()
 
-    # ---- Public call
     def __call__(self, question: str, task_id: str = "unknown") -> str:
         state: AgentState = {
             "question": question,
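After this change the workflow is a single fan-out: analyze routes through router to search, answer, image, video, or sheet, and every tool node feeds answer. A minimal runnable sketch of that conditional-routing pattern in langgraph; the two-node toy state and node names here are illustrative, not from app.py:

from typing import TypedDict
from langgraph.graph import StateGraph

class ToyState(TypedDict):
    current_step: str
    final_answer: str

def analyze(state: ToyState) -> ToyState:
    state["current_step"] = "answer"  # decide the next node, as _analyze_question does
    return state

def answer(state: ToyState) -> ToyState:
    state["final_answer"] = "done"
    return state

sg = StateGraph(state_schema=ToyState)
sg.add_node("analyze", analyze)
sg.add_node("answer", answer)
sg.add_conditional_edges("analyze", lambda s: s["current_step"], {"answer": "answer"})
sg.set_entry_point("analyze")
sg.set_finish_point("answer")
print(sg.compile().invoke({"current_step": "", "final_answer": ""}))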
@@ -464,7 +325,6 @@ Write ANSWER: <answer> on its own line.
             "search_query": "",
             "task_id": task_id,
             "logs": {},
-            "code_blocks": [],
             "attachment_data": {}
         }
         final_state = self.workflow.invoke(state)
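Hypothetical usage, assuming OPENAI_API_KEY is set in the environment; __call__ seeds exactly this state and runs the compiled workflow:

agent = BasicAgent()
print(agent("What is the capital of France?", task_id="demo-1"))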
@@ -558,7 +418,6 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
             "search_query": "",
             "task_id": task_id,
             "logs": {},
-            "code_blocks": [],
             "attachment_data": {}
         }
 