|
import os |
|
import gradio as gr |
|
import requests |
|
import ast |
|
import json |
|
import time |
|
import pandas as pd |
|
from datetime import datetime |
|
from typing import List, Dict, Any, Annotated |
|
from langgraph.graph import Graph, StateGraph |
|
from typing_extensions import TypedDict |
|
from openai import OpenAI |
|
from tools import simple_search |
|
import re |
|
from huggingface_hub import InferenceClient |
|
import io |
|
import mimetypes |
|
import base64 |
|
import cv2 |
|
import numpy as np |
|
from io import BytesIO |
|
import tempfile |
|
import subprocess |
|
import sys |
|
import textwrap |
|
|
|
|
|
|
|
|
|
|
|
# Base URL of the course scoring service (questions, attachments, submissions).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# Credentials are read from the environment; each is None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")
|
|
|
|
|
# Shared Hugging Face Inference client used by the image and video helpers.
# HF_TOKEN is None when the env var is unset, in which case the client is created without a token.
client = InferenceClient(token=HF_TOKEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def override(_, new):
    """Reducer that ignores the existing value and returns the incoming one.

    Used in the ``AgentState`` annotations so that each update replaces the field.
    """
    replacement = new
    return replacement
|
|
|
def merge_dicts(old: Dict, new: Dict) -> Dict:
    """Return a fresh dict combining *old* and *new*; keys in *new* win on conflict.

    Neither input is modified. Used as the reducer for ``AgentState.logs``.
    """
    merged = {**old}
    merged.update(new)
    return merged
|
|
|
def tighten(q: str) -> str:
    """Condense a long question into its quoted phrases and salient tokens.

    Salient tokens are words of three or more characters that start with an
    uppercase letter or a digit. Quoted phrases come first, followed by every
    salient token, in order of appearance. If nothing is extracted, the
    original text is returned unchanged.
    """
    phrases = re.findall(r'"([^"]+)"', q)
    keywords = re.findall(r'\b([A-Z0-9][\w-]{2,})', q)
    condensed = " ".join([*phrases, *keywords])
    if condensed:
        return condensed
    return q
|
|
|
|
|
|
|
|
|
|
|
def retry_hf_inference(func):
    """Decorator that retries HF Inference API calls with linear backoff.

    The wrapped call is attempted up to three times: one initial try plus
    ``max_retries`` (2) retries. After a failure it sleeps ``base_delay * n``
    seconds (7 s, then 14 s). If the final attempt also fails, that exception
    propagates unchanged. Any ``Exception`` triggers a retry, not only
    network errors.

    ``functools.wraps`` keeps the wrapped function's name and docstring, so
    logs and introspection still show the real function.
    """
    import functools  # not imported at module level

    max_retries = 2
    base_delay = 7

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        for attempt in range(max_retries + 1):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                if attempt == max_retries:
                    raise
                delay = base_delay * (attempt + 1)
                print(f"HF API error: {str(e)}. Retrying in {delay}s...")
                time.sleep(delay)

    return wrapper
|
|
|
@retry_hf_inference
def image_qa_bytes(data: bytes, prompt: str) -> str:
    """Ask the LLaVA vision-language model *prompt* about the image in *data*.

    The image is embedded as a base64 ``data:`` URL in a chat-completion
    request, so both the image and the question reach the model.

    Args:
        data: Raw image bytes. PNG, JPEG, GIF and WebP are recognised by their
            signature; anything else is labelled ``image/png``.
        prompt: The question to ask about the image.

    Returns:
        The model's text answer, stripped of surrounding whitespace
        (an empty string if the model returned no content).
    """
    # Sniff the MIME type from magic bytes: the caller only hands us raw bytes.
    if data.startswith(b"\x89PNG"):
        mime = "image/png"
    elif data.startswith(b"\xff\xd8"):
        mime = "image/jpeg"
    elif data.startswith((b"GIF87a", b"GIF89a")):
        mime = "image/gif"
    elif data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        mime = "image/webp"
    else:
        # NOTE(review): best guess — confirm the endpoint tolerates a mislabelled type.
        mime = "image/png"
    image_url = f"data:{mime};base64,{base64.b64encode(data).decode('ascii')}"

    response = client.chat_completion(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
        max_tokens=256,
    )
    return (response.choices[0].message.content or "").strip()
|
|
|
def _sample_video_frames(path: str, target_frames: int, target_duration: float) -> list:
    """Decode up to *target_frames* frames from the video at *path*, resized to 224x224.

    Raises:
        ValueError: OpenCV could not open the file.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise ValueError("OpenCV could not open the video data")
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report 0 (or garbage) for fps/frame count; in that case
        # sample consecutive frames rather than dividing by zero.
        if fps > 0 and frame_count > 0:
            frame_interval = max(1, int(frame_count / (fps * target_duration)))
        else:
            frame_interval = 1

        frames = []
        frame_idx = 0
        while len(frames) < target_frames:
            ok, frame = cap.read()
            if not ok:
                break
            if frame_idx % frame_interval == 0:
                frames.append(cv2.resize(frame, (224, 224)))
            frame_idx += 1
        return frames
    finally:
        cap.release()


def _encode_frames_mp4(frames: list, fps: float, out_path: str) -> bytes:
    """Write equally-sized BGR *frames* to an MP4 at *out_path* and return the file's bytes.

    Raises:
        ValueError: OpenCV could not create the writer (e.g. codec unavailable).
    """
    height, width = frames[0].shape[:2]
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    if not writer.isOpened():
        raise ValueError("OpenCV could not create an MP4 writer")
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()
    with open(out_path, "rb") as fh:
        return fh.read()


@retry_hf_inference
def video_label_bytes(data: bytes) -> str:
    """Classify a video with VideoMAE (UCF101 fine-tune) and return the top label.

    The video is written to a temporary file, because ``cv2.VideoCapture`` only
    reads from paths or URLs and not from in-memory buffers. Up to 16 frames are
    sampled and resized to 224x224, then padded by repeating the last frame.
    The frames are re-encoded as a small MP4 with ``cv2.VideoWriter``
    (``cv2.imencode`` handles still images only), and the clip is posted to
    the inference endpoint.

    Args:
        data: Raw bytes of the video file.

    Returns:
        The label with the highest score.

    Raises:
        ValueError: the video could not be decoded or re-encoded, or the
            classifier returned no predictions.
    """
    target_frames = 16
    target_duration = 8

    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "input_video.mp4")
        with open(src_path, "wb") as fh:
            fh.write(data)

        frames = _sample_video_frames(src_path, target_frames, target_duration)
        if not frames:
            raise ValueError("No frames could be decoded from the video")
        # Pad short clips by repeating the last frame (the list never exceeds target_frames).
        frames.extend([frames[-1]] * (target_frames - len(frames)))

        processed_bytes = _encode_frames_mp4(
            frames, target_frames / target_duration, os.path.join(tmpdir, "clip.mp4")
        )

    headers = {"Content-Type": "application/octet-stream"}
    # NOTE(review): InferenceClient.post is deprecated in recent huggingface_hub
    # releases — confirm it exists in the pinned version.
    raw = client.post(
        "MCG-NJU/videomae-base-finetuned-ucf101",
        data=processed_bytes,
        headers=headers,
    )
    # The raw endpoint returns undecoded JSON bytes, not a parsed list.
    preds = json.loads(raw) if isinstance(raw, (bytes, bytearray, str)) else raw
    if not preds:
        raise ValueError("Video classifier returned no predictions")
    return max(preds, key=lambda p: p["score"])["label"]
|
|
|
def sheet_answer_bytes(
    data: bytes,
    question: str,
    category: str = "Food",
    category_col: str = "Category",
    value_col: str = "Sales",
) -> str:
    """Sum *value_col* over rows whose *category_col* equals *category*.

    The format is detected from the file content, not from the question text:
    an xlsx (ZIP, ``PK\\x03\\x04``) or legacy xls (OLE2, ``\\xd0\\xcf\\x11\\xe0``)
    signature goes through ``pd.read_excel``; anything else is parsed as CSV.

    Args:
        data: Raw spreadsheet bytes.
        question: The task question. Currently unused, kept for interface
            compatibility.
        category: Category value to select (default ``"Food"``).
        category_col: Column holding the category (default ``"Category"``).
        value_col: Column to sum (default ``"Sales"``).

    Returns:
        The total formatted with two decimals, e.g. ``"4.50"``.

    Raises:
        KeyError: *category_col* or *value_col* is missing from the sheet.
    """
    buffer = io.BytesIO(data)
    if data.startswith((b"PK\x03\x04", b"\xd0\xcf\x11\xe0")):
        df = pd.read_excel(buffer)
    else:
        df = pd.read_csv(buffer)

    total = df.loc[df[category_col] == category, value_col].sum()
    return f"{total:.2f}"
|
|
|
|
|
|
|
|
|
|
|
def run_python(code: str, timeout: float = 10) -> str:
    """Execute *code* in a fresh Python interpreter and return its stripped stdout.

    SECURITY: the code comes from a downloaded task attachment and runs with
    this process's privileges and no sandbox. Only use it on trusted input.

    Args:
        code: Python source. Common leading indentation is removed first.
        timeout: Seconds before the child process is killed (default 10).

    Returns:
        The child's stdout, decoded and stripped.

    Raises:
        subprocess.CalledProcessError: the script exited with a non-zero status.
        subprocess.TimeoutExpired: the script ran longer than *timeout*.
    """
    fd, path = tempfile.mkstemp(suffix=".py")
    try:
        # Write as UTF-8 (Python's default source encoding) and close the file
        # before running it, so the child can open it on every platform.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(textwrap.dedent(code))
        out = subprocess.check_output([sys.executable, path], timeout=timeout)
    finally:
        os.remove(path)  # never leave the script behind in the temp dir
    return out.decode().strip()
|
|
|
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """State dictionary passed through the LangGraph workflow.

    Each field is ``Annotated`` with a reducer that LangGraph applies when a
    node returns an update: ``override`` keeps the new value, ``list.__add__``
    concatenates lists, and ``merge_dicts`` merges dicts with new keys winning.
    """

    # The question text being answered.
    question: Annotated[str, override]
    # Workflow step marker: starts as "answer"; the answer node sets "done".
    current_step: Annotated[str, override]
    # Answer string returned to the caller of BasicAgent.
    final_answer: Annotated[str, override]
    # History entries; NOTE(review): initialised empty and not written by any visible node.
    history: Annotated[List[Dict[str, str]], list.__add__]
    # NOTE(review): search flag/query are initialised but not read by the current workflow.
    needs_search: Annotated[bool, override]
    search_query: Annotated[str, override]
    # Identifier of the task on the scoring server (used for logging).
    task_id: Annotated[str, override]
    # Diagnostic key/value logs, merged across updates.
    logs: Annotated[Dict[str, Any], merge_dicts]
    # Attachment URL; an empty string means the question has no attachment.
    file_url: Annotated[str, override]
    # Code snippets; NOTE(review): initialised empty and not populated by any visible node.
    code_blocks: Annotated[List[Dict[str, str]], list.__add__]
|
|
|
|
|
|
|
|
|
|
|
class BasicAgent:
    """Single-node LangGraph agent that answers GAIA-style questions.

    Questions with a downloadable attachment are routed to a file-type helper:
    image QA, video classification, spreadsheet sum, or Python execution.
    Everything else is answered by an OpenAI chat model asked to reply in JSON.
    That includes tasks with no attachment, attachments that cannot be
    downloaded, and attachments no helper can handle.
    """

    # Upper bound on characters of a textual attachment included in the LLM prompt.
    MAX_TEXT_ATTACHMENT_CHARS = 20_000

    def __init__(self, session: requests.Session):
        """Create the OpenAI client and compile the workflow.

        Args:
            session: HTTP session reused for attachment downloads.

        Raises:
            EnvironmentError: ``OPENAI_API_KEY`` is not set.
        """
        if not OPENAI_API_KEY:
            raise EnvironmentError("OPENAI_API_KEY not set")
        self.llm = OpenAI(api_key=OPENAI_API_KEY)
        self.session = session
        self.workflow = self._build_workflow()

    def _call_llm(self, prompt: str, max_tokens: int = 256) -> str:
        """Send *prompt* as one user message and return the reply text, stripped.

        Errors from the OpenAI client are printed and re-raised.
        """
        try:
            resp = self.llm.chat.completions.create(
                model="gpt-4.1",
                messages=[
                    {"role": "user", "content": prompt},
                ],
                temperature=0,
                top_p=0.1,
                max_tokens=max_tokens,
            )
        except Exception as e:
            print(f"\nLLM Error: {str(e)}")
            raise
        # `content` may be None (e.g. on refusals); treat it as an empty reply.
        return (resp.choices[0].message.content or "").strip()

    def _safe_parse(self, raw: str) -> str:
        """Extract the ``ANSWER`` value from the model reply as a string.

        Three strategies are tried in order:
        1. Parse the whole reply as JSON.
        2. Parse the first ``{...}`` span, which handles Markdown fences or
           chatter around the object.
        3. Take the text after the first colon, with stray braces and quotes
           trimmed.

        Non-string JSON values (e.g. numbers) are converted with ``str`` so the
        submitted answer is always text.
        """
        try:
            return str(json.loads(raw)["ANSWER"])
        except (ValueError, KeyError, TypeError):
            pass

        match = re.search(r'\{.*?\}', raw, re.S)
        if match:
            try:
                return str(json.loads(match.group())["ANSWER"])
            except (ValueError, KeyError, TypeError):
                pass

        return raw.split(':', 1)[-1].strip().strip('{}').strip().strip('"').strip()

    def __call__(self, question: str, task_id: str = "unknown", file_url: str = "") -> str:
        """Run the workflow for one question and return its final answer."""
        state: AgentState = {
            "question": question,
            "current_step": "answer",
            "final_answer": "",
            "history": [],
            "needs_search": False,
            "search_query": "",
            "task_id": task_id,
            "logs": {},
            "file_url": file_url,
            "code_blocks": [],
        }

        print(f"\nProcessing task {task_id}")
        print(f"Question: {state['question']}")
        print(f"File URL: {state['file_url']}")

        final_state = self.workflow.invoke(state)
        return final_state["final_answer"]

    def _fetch_attachment(self, url: str):
        """Download *url*; return ``(data, kind, name)`` or None if nothing usable came back.

        Attachment URLs such as ``/files/<task_id>`` carry no extension, so the
        MIME ``kind`` is taken from the ``Content-Type`` header. It falls back
        to a guess from the ``Content-Disposition`` filename, or from the last
        URL segment. ``name`` is that filename, lower-cased. A failed download
        (e.g. 404 for a task without a file) or an empty body yields None.
        """
        try:
            print(f"Downloading {url} …")
            response = self.session.get(url, timeout=30)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error downloading file: {e}")
            return None

        data = response.content
        print(f"Successfully downloaded file, size: {len(data)} bytes")
        if not data:
            return None

        disposition = response.headers.get("Content-Disposition", "")
        match = re.search(r'filename\*?=(?:[\w-]+\'\')?"?([^";]+)"?', disposition)
        name = (match.group(1) if match else url.rsplit("/", 1)[-1]).strip().lower()

        kind = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()
        if not kind or kind in ("application/octet-stream", "binary/octet-stream"):
            kind = mimetypes.guess_type(name)[0] or ""
        print(f"Detected file type: {kind}")
        return data, kind, name

    def _answer_from_file(self, data: bytes, kind: str, name: str, question: str):
        """Dispatch the attachment to the matching helper; return None if none applies."""
        if kind.startswith("image/"):
            print("Processing as image...")
            answer = image_qa_bytes(data, question)
        elif kind.startswith("video/"):
            print("Processing as video...")
            answer = video_label_bytes(data)
        elif (
            "spreadsheet" in kind  # xlsx MIME ends in "spreadsheetml.sheet"
            or kind in ("text/csv", "application/vnd.ms-excel")
            or name.endswith((".xlsx", ".xls", ".csv"))
        ):
            print("Processing as spreadsheet...")
            answer = sheet_answer_bytes(data, question)
        elif kind == "text/x-python" or name.endswith(".py"):
            print("Processing as Python file...")
            answer = run_python(data.decode("utf-8", errors="replace"))
        else:
            return None

        # Helpers may hand back raw bytes (e.g. an undecoded inference response).
        if isinstance(answer, (bytes, bytearray)):
            answer = answer.decode("utf-8", errors="replace")
        return str(answer)

    def _attachment_text(self, data: bytes, kind: str, name: str) -> str:
        """Return the attachment decoded as text (truncated) when it looks textual, else ""."""
        textual = (
            kind.startswith("text/")
            or kind in ("application/json", "application/xml")
            or name.endswith((".txt", ".md", ".json", ".csv", ".py", ".xml"))
        )
        if not textual:
            return ""
        return data.decode("utf-8", errors="replace")[: self.MAX_TEXT_ATTACHMENT_CHARS]

    def _answer_text_only(self, question: str, materials: str = "") -> str:
        """Ask the LLM to answer *question*, with *materials* included as context when given.

        Returns a fixed apology string if the LLM call fails.
        """
        print("\nProcessing as text-only question...")
        parts = ["Answer this question using the materials provided.", ""]
        if materials:
            parts += ["MATERIALS:", materials, ""]
        parts += [
            "QUESTION:",
            question,
            "",
            "Return ONLY this exact JSON object:",
            '{"ANSWER": "<answer text>"}',
        ]
        prompt = "\n".join(parts)

        try:
            raw = self._call_llm(prompt, 300)
            answer = self._safe_parse(raw)
            print(f"Generated answer: {answer}")
            return answer
        except Exception as e:
            print(f"\nLLM Error in answer generation: {str(e)}")
            return "I encountered an error while generating the answer."

    def _generate_answer(self, state: AgentState) -> AgentState:
        """Workflow node: fill ``final_answer`` and set ``current_step`` to ``"done"``.

        An attachment handled by a helper answers the question directly. In
        every other case the LLM answers instead: no attachment, a failed
        download, an unsupported type, or a helper error. An error message is
        never submitted as the answer. Textual attachments are passed to the
        LLM as context.
        """
        materials = ""
        if state["file_url"]:
            attachment = self._fetch_attachment(state["file_url"])
            if attachment is not None:
                data, kind, name = attachment
                try:
                    answer = self._answer_from_file(data, kind, name, state["question"])
                except Exception as e:
                    print(f"\nError processing file {state['file_url']}: {str(e)}")
                    answer = None
                else:
                    if answer is None:
                        print(f"Unsupported file type: {kind}")

                if answer is not None:
                    print(f"Generated answer: {answer}")
                    state["final_answer"] = answer
                    state["current_step"] = "done"
                    return state
                materials = self._attachment_text(data, kind, name)

        state["final_answer"] = self._answer_text_only(state["question"], materials)
        state["current_step"] = "done"
        return state

    def _build_workflow(self) -> Graph:
        """Compile a one-node graph whose entry and finish point is the answer node."""
        sg = StateGraph(state_schema=AgentState)
        sg.add_node("answer", self._generate_answer)
        sg.set_entry_point("answer")
        sg.set_finish_point("answer")
        return sg.compile()
|
|
|
|
|
|
|
|
|
|
|
def _attachment_url(api_url: str, task_id: str, item: dict) -> str:
    """Return the absolute attachment URL for a question *item*, or "" if it has none.

    An explicit ``file_url`` on the item wins. Otherwise the ``/files/<task_id>``
    endpoint is used, unless the item reports an empty ``file_name``. This
    skips a pointless (404) download for text-only questions. Items without a
    ``file_name`` key still try the endpoint.
    """
    raw_url = item.get("file_url") or ""
    if not raw_url:
        if "file_name" in item and not item["file_name"]:
            return ""
        raw_url = f"/files/{task_id}"
    if raw_url.startswith(("http://", "https://")):
        return raw_url
    return f"{api_url}/{raw_url.lstrip('/')}"


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, run BasicAgent on each, submit the answers, and report.

    Args:
        profile: OAuth profile injected by Gradio, or None when not logged in.

    Returns:
        ``(status_message, results_dataframe)`` for the two Gradio outputs.
        The dataframe is None when the run aborts before any question is
        processed.
    """
    space_id = os.getenv("SPACE_ID")
    print("Space ID: ", space_id)
    if not profile:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # A single session is reused for the question fetch, attachment downloads and submission.
    sess = requests.Session()

    try:
        print("Initializing agent...")
        agent = BasicAgent(session=sess)
        print("Agent initialized successfully.")
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    print(f"Fetching questions from: {questions_url}")
    try:
        response = sess.get(questions_url, timeout=30)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    # Guard the format here: the loop below calls item.get() outside any try block.
    if not questions_data or not isinstance(questions_data, list):
        print("Fetched questions list is empty.")
        return "Fetched questions list is empty or invalid format.", None
    print(f"Fetched {len(questions_data)} questions.")

    results_log = []
    answers_payload = []

    for item in questions_data:
        if not isinstance(item, dict):
            continue
        task_id = item.get("task_id")
        if not task_id:
            continue
        question_text = item.get("question", "")

        try:
            print(f"\nProcessing question {task_id}...")
            file_url = _attachment_url(api_url, task_id, item)
            answer = agent(
                question=question_text,
                task_id=task_id,
                file_url=file_url,
            )
            # The submission is JSON; bytes or other objects would break json.dumps.
            if isinstance(answer, (bytes, bytearray)):
                answer = answer.decode("utf-8", errors="replace")
            elif not isinstance(answer, str):
                answer = str(answer)

            answers_payload.append({
                "task_id": task_id,
                "submitted_answer": answer,
            })
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": answer,
            })
        except Exception as e:
            print(f"Error processing task {task_id}: {e}")
            results_log.append({
                "Task ID": task_id,
                "Question": question_text,
                "Submitted Answer": f"ERROR: {e}",
            })

    if not answers_payload:
        return "No answers were generated.", pd.DataFrame(results_log)

    submission_data = {
        "username": username.strip(),
        "agent_code": f"https://huggingface.co/spaces/{space_id}/tree/main",
        "answers": answers_payload,
    }

    try:
        response = sess.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        return final_status, pd.DataFrame(results_log)
    except Exception as e:
        return f"Submission Failed: {str(e)}", pd.DataFrame(results_log)
|
|
|
|
|
# Gradio UI: login button, a run button, and outputs for the status text and the results table.
with gr.Blocks() as demo:
    gr.Markdown("# Basic Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**

        1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
        2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.

        ---
        **Disclaimers:**
        Once you click the "Run Evaluation & Submit All Answers" button, it can take quite some time (this is the time for the agent to go through all the questions).
        This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, to avoid the long wait after clicking the submit button, a solution could be to cache the answers and submit them in a separate action, or even to answer the questions asynchronously.
        """
    )

    gr.LoginButton()

    run_button = gr.Button("Run Evaluation & Submit All Answers")

    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    # One width per column of results_log: "Task ID", "Question", "Submitted Answer".
    results_table = gr.DataFrame(
        label="Questions and Agent Answers",
        wrap=True,
        column_widths=["10%", "45%", "45%"],
    )

    # Gradio injects the OAuth profile into run_and_submit_all's `profile` parameter.
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )
|
|
|
if __name__ == "__main__":
    # Print Space diagnostics, then launch the UI.
    print("\n" + "-"*30 + " App Starting " + "-"*30)

    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        # NOTE: the Spaces docs describe SPACE_HOST as the full host (e.g. "user-space.hf.space");
        # only append the domain when it is missing, so the URL is not doubled.
        if space_host_startup.endswith(".hf.space"):
            runtime_host = space_host_startup
        else:
            runtime_host = f"{space_host_startup}.hf.space"
        print(f" Runtime URL should be: https://{runtime_host}")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)
|
|