Update app.py
app.py
CHANGED
@@ -12,6 +12,21 @@ from typing_extensions import TypedDict
 from openai import OpenAI
 from tools import simple_search
 import re
+from huggingface_hub import InferenceClient
+import io
+import mimetypes
+import base64
+
+# -------------------------
+# Environment & constants
+# -------------------------
+
+DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+# Initialize HF client
+client = InferenceClient(token=HF_TOKEN)
 
 # -------------------------
 # Utility helpers
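Note on this hunk: `base64` is imported and `OPENAI_API_KEY` is read here, but neither is used anywhere in the diff, and nothing yet downloads the task attachments that the new multimodal nodes expect to find on disk. A minimal sketch of how `DEFAULT_API_URL` could serve that purpose, assuming the scoring API exposes a GET /files/{task_id} route (`fetch_file` is a hypothetical helper, not part of this commit):

import requests

def fetch_file(task_id: str, dest: str) -> str:
    # Hypothetical: download a task attachment so the image/video/sheet
    # nodes can open it from a local path.
    resp = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=30)
    resp.raise_for_status()
    with open(dest, "wb") as f:
        f.write(resp.content)
    return dest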
@@ -35,13 +50,36 @@ def tighten(q: str) -> str:
     return short or q
 
 # -------------------------
-#
+# Multimodal helpers
 # -------------------------
 
-
-
+def image_qa(image_path: str, prompt: str) -> str:
+    """Query LLaVA model for image-based QA."""
+    with open(image_path, "rb") as f:
+        data = {"prompt": prompt, "image": f.read()}
+    return client.post("llava-hf/llava-v1.6-mistral-7b-hf", data=data)
+
+def video_label(video_path: str, topk: int = 1) -> str:
+    """Get video classification using VideoMAE."""
+    with open(video_path, "rb") as f:
+        preds = client.post(
+            "MCG-NJU/videomae-base-finetuned-ucf101", data=f.read()
+        )
+    preds = sorted(preds, key=lambda x: x["score"], reverse=True)[:topk]
+    return preds[0]["label"]
 
-
+def sheet_answer(data: bytes, question: str) -> str:
+    """Process spreadsheet data and answer questions."""
+    if mimetypes.guess_type("x.xlsx")[0] == "text/csv" or question.endswith(".csv"):
+        df = pd.read_csv(io.BytesIO(data))
+    else:
+        df = pd.read_excel(io.BytesIO(data))
+    numeric_cols = df.select_dtypes("number")
+    col = numeric_cols.max().idxmax()
+    row = numeric_cols[col].idxmax()
+    value = df.loc[row, col]
+    label = df.columns[col]
+    return f"{label}: {value}"
 
 # -------------------------
 # State definition
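Two review notes on the new helpers. First, `InferenceClient.post` returns the raw HTTP response body as bytes, so `image_qa` and `video_label` likely need a `json.loads` step before the string handling and `sorted(...)` can work; whether that bites depends on the huggingface_hub version pinned for this Space. Second, `sheet_answer` has two quirks: `mimetypes.guess_type("x.xlsx")` is evaluated on a hardcoded name, so only the `question.endswith(".csv")` test ever varies, and `numeric_cols.max().idxmax()` already returns a column label, so `df.columns[col]` will raise unless the columns happen to be integer-labeled. A minimal corrected sketch, assuming pandas is imported as `pd` earlier in app.py:

def sheet_answer_fixed(data: bytes, filename: str) -> str:
    # Return "column: value" for the largest numeric cell in the sheet.
    if filename.endswith(".csv"):
        df = pd.read_csv(io.BytesIO(data))
    else:
        df = pd.read_excel(io.BytesIO(data))
    numeric = df.select_dtypes("number")
    col = numeric.max().idxmax()   # label of the column holding the global max
    row = numeric[col].idxmax()    # row index of that column's max
    return f"{col}: {df.loc[row, col]}"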
@@ -83,27 +121,70 @@ class BasicAgent:
 
     # ---- Workflow nodes
     def _analyze_question(self, state: AgentState) -> AgentState:
-
-
-
-        "
-
-
-
+        # Check for multimodal content
+        q = state["question"].lower()
+        if "video" in q or q.endswith(".mp4"):
+            state["current_step"] = "video"
+        elif q.endswith((".jpg", ".png", ".jpeg")):
+            state["current_step"] = "image"
+        elif q.endswith((".xlsx", ".csv")):
+            state["current_step"] = "sheet"
+        else:
+            # Regular text question analysis
+            prompt = (
+                "You will receive a user question. Think step-by-step to decide whether external web search is required. "
+                "Respond ONLY with a valid Python dict literal in the following format and NOTHING else:\n"
+                "{\n 'needs_search': bool,\n 'search_query': str\n} \n\n"
+                f"Question: {state['question']}"
+            )
+            raw = self._call_llm(prompt)
+            try:
+                decision = ast.literal_eval(raw)
+                state["needs_search"] = bool(decision.get("needs_search", False))
+                state["search_query"] = decision.get("search_query", state["question"])
+            except Exception:
+                state["needs_search"] = True
+                state["search_query"] = state["question"]
+                decision = {"parse_error": raw}
+            state["logs"] = {
+                "analyze": {"prompt": prompt, "llm_response": raw, "decision": decision}
+            }
+            state["current_step"] = "search" if state["needs_search"] else "answer"
+            state["history"].append({"step": "analyze", "output": decision})
+        return state
+
+    def _image_node(self, state: AgentState) -> AgentState:
+        """Handle image-based questions."""
         try:
-
-        state["
-        state["
-        except Exception:
-
-        state["
-
-
-
-
-
-
-
+            answer = image_qa(state["question"], "What is shown in this image?")
+            state["history"].append({"step": "image", "output": answer})
+            state["current_step"] = "answer"
+        except Exception as e:
+            state["logs"]["image_error"] = str(e)
+            state["current_step"] = "answer"
+        return state
+
+    def _video_node(self, state: AgentState) -> AgentState:
+        """Handle video-based questions."""
+        try:
+            label = video_label(state["question"])
+            state["history"].append({"step": "video", "output": label})
+            state["current_step"] = "answer"
+        except Exception as e:
+            state["logs"]["video_error"] = str(e)
+            state["current_step"] = "answer"
+        return state
+
+    def _sheet_node(self, state: AgentState) -> AgentState:
+        """Handle spreadsheet-based questions."""
+        try:
+            with open(state["question"], "rb") as f:
+                answer = sheet_answer(f.read(), state["question"])
+            state["history"].append({"step": "sheet", "output": answer})
+            state["current_step"] = "answer"
+        except Exception as e:
+            state["logs"]["sheet_error"] = str(e)
+            state["current_step"] = "answer"
         return state
 
     def _perform_search(self, state: AgentState) -> AgentState:
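The text branch leans on `ast.literal_eval`, which accepts exactly the dict literal the prompt demands and raises on anything else; that failure path is what drives the `needs_search = True` fallback. A quick self-contained check of that contract:

import ast

good = "{'needs_search': True, 'search_query': 'capital of France'}"
assert ast.literal_eval(good)["needs_search"] is True

try:
    ast.literal_eval("Sure! needs_search: yes")
except (ValueError, SyntaxError):
    pass  # malformed replies fall back to needs_search = True

One more note: the image/video/sheet branches treat `state["question"]` itself as a local file path, and the media nodes write into `state["logs"]` on error, so `logs` presumably has to be initialized when the state is first built; that code sits outside this diff.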
@@ -147,26 +228,28 @@ class BasicAgent:
             search_block = "Error retrieving search results."
 
         prompt = f"""
-You are an expert
+You are an expert assistant. Use ONLY the materials below to answer.
 
-
+QUESTION:
 {state['question']}
 
-
+MATERIALS:
 {search_block}
 
-Think step-by-step.
-END EACH STEP with ➤. After reasoning, output:
-
-ANSWER: <the short answer here>
-
-No other text.
+Think step-by-step. Write ANSWER: <answer> on its own line.
 """
         raw = self._call_llm(prompt, 300)
-        answer = raw.
+        answer = raw.split("ANSWER:")[-1].strip()
+
+        # Validate answer
+        if not answer:
+            answer = "I cannot provide a definitive answer at this time."
+        elif any(k in answer.lower() for k in ["i cannot find", "sorry"]):
+            # Fall back to a more general response
+            answer = "Based on the available information, I cannot provide a complete answer."
 
         state["final_answer"] = answer
-        state["history"].append({"step": "answer", "output": raw})
+        state["history"].append({"step": "answer", "output": raw})
         state["logs"]["final_answer"] = {"prompt": prompt, "response": raw}
         state["current_step"] = "done"
         return state
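`raw.split("ANSWER:")[-1]` degrades quietly: with the marker present it isolates the answer, and with the marker absent `split` returns the whole string, so the full reasoning text becomes the answer and only the emptiness/apology checks above stand between that and `final_answer`. For example:

raw = "The capital is discussed in step 2.\nANSWER: Paris"
assert raw.split("ANSWER:")[-1].strip() == "Paris"

no_marker = "I think the answer is Paris."
assert no_marker.split("ANSWER:")[-1].strip() == no_marker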
@@ -174,21 +257,39 @@ No other text.
     # ---- Build LangGraph workflow
     def _build_workflow(self) -> Graph:
         sg = StateGraph(state_schema=AgentState)
+
+        # Add all nodes
         sg.add_node("analyze", self._analyze_question)
         sg.add_node("search", self._perform_search)
         sg.add_node("recheck", self._re_evaluate)
         sg.add_node("answer", self._generate_answer)
+        sg.add_node("image", self._image_node)
+        sg.add_node("video", self._video_node)
+        sg.add_node("sheet", self._sheet_node)
 
-        #
+        # Add edges
         sg.add_edge("analyze", "search")
         sg.add_edge("analyze", "answer")
         sg.add_edge("search", "recheck")
+        sg.add_edge("image", "answer")
+        sg.add_edge("video", "answer")
+        sg.add_edge("sheet", "answer")
 
         def router(state: AgentState):
             return state["current_step"]
 
-        sg.add_conditional_edges("analyze", router, {
-
+        sg.add_conditional_edges("analyze", router, {
+            "search": "search",
+            "answer": "answer",
+            "image": "image",
+            "video": "video",
+            "sheet": "sheet"
+        })
+        sg.add_conditional_edges("recheck", router, {
+            "search": "search",
+            "answer": "answer"
+        })
+
         sg.set_entry_point("analyze")
         sg.set_finish_point("answer")
         return sg.compile()
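One graph-level note: the unconditional `sg.add_edge("analyze", "search")` and `sg.add_edge("analyze", "answer")` lines survive from the previous version alongside the new `add_conditional_edges("analyze", ...)`. In LangGraph, plain edges fire on every run, so "analyze" now fans out to "search" and "answer" in addition to whichever branch the router picks; if single-path routing is the intent, the two static edges should probably be dropped.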
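For reference, a self-contained sketch of that single-path pattern, with conditional edges as the only route out of the entry node (node names here are illustrative, not taken from app.py):

from typing import TypedDict
from langgraph.graph import StateGraph

class S(TypedDict):
    current_step: str

def analyze(s: S) -> S:
    # A real node would inspect the question and set current_step here.
    return s

def answer(s: S) -> S:
    return s

g = StateGraph(S)
g.add_node("analyze", analyze)
g.add_node("answer", answer)
# The router's return value decides the next node; no unconditional edges.
g.add_conditional_edges("analyze", lambda s: s["current_step"], {"answer": "answer"})
g.set_entry_point("analyze")
g.set_finish_point("answer")
print(g.compile().invoke({"current_step": "answer"}))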