import os
import gradio as gr
import requests
import aiohttp
import asyncio
import json
import nest_asyncio
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline
from langchain_core.messages import SystemMessage, HumanMessage
from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool
from tools.search import initialize_search_tools
from state import JARVISState
import pandas as pd
from dotenv import load_dotenv
import logging
from langfuse.callback import CallbackHandler

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
# Allow nested event loops: Gradio runs its own loop, and BasicAgent.__call__
# below calls loop.run_until_complete() inside it
nest_asyncio.apply()
# Load environment variables
load_dotenv()

# Verify environment variables
required_env_vars = ["SPACE_ID", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
for var in required_env_vars:
    if not os.getenv(var):
        raise ValueError(f"Environment variable {var} is not set")
logger.info(f"Environment variables loaded: SPACE_ID={os.getenv('SPACE_ID')[:10]}..., LANGFUSE_HOST={os.getenv('LANGFUSE_HOST', 'https://cloud.langfuse.com')}")
# Initialize Hugging Face model
try:
    hf_pipeline = pipeline(
        "text-generation",
        # "Mixtral-7B-Instruct-v0.1" does not exist on the Hub; the 7B instruct
        # model is Mistral-7B-Instruct-v0.1
        model="mistralai/Mistral-7B-Instruct-v0.1",
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    logger.info("HuggingFace model initialized: mistralai/Mistral-7B-Instruct-v0.1")
except Exception as e:
    logger.error(f"Failed to initialize HuggingFace model: {e}")
    llm = None
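
# Note: if model loading fails, llm stays None and the graph nodes below fall
# back to defaults (["web_search"] in parse_question, "Unknown" in
# reasoning_agent) instead of raising.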
# Initialize search tools with LLM
try:
    initialize_search_tools(llm)
    logger.info("Search tools initialized")
except Exception as e:
    logger.error(f"Failed to initialize search tools: {e}")

# Initialize Langfuse
try:
    langfuse = CallbackHandler(
        public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
        secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
        host=os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
    )
    logger.info("Langfuse initialized successfully")
except Exception as e:
    logger.warning(f"Failed to initialize Langfuse: {e}")
    langfuse = None

# Initialize MemorySaver
memory = MemorySaver()
use_checkpointing = True

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space/api"
GAIA_FILE_URL = "https://api.gaia-benchmark.com/files/"
# --- Helper Functions ---
def log_state(task_id: str, state: JARVISState):
    """Append the intermediate state to state_log.json, one JSON object per line"""
    try:
        log_entry = {
            "task_id": task_id,
            "question": state["question"],
            "tools_needed": state["tools_needed"],
            "web_results": state["web_results"],
            "file_results": state["file_results"],
            "image_results": state["image_results"],
            "calculation_results": state["calculation_results"],
            "document_results": state["document_results"],
            "answer": state["answer"]
        }
        with open("state_log.json", "a") as f:
            # No indent: keep each appended entry on one line (JSON Lines),
            # so the file stays machine-parseable across appends
            json.dump(log_entry, f)
            f.write("\n")
    except Exception as e:
        logger.error(f"Error logging state for task {task_id}: {e}")
async def test_gaia_api(task_id: str) -> bool:
    """Test connectivity to the GAIA file API"""
    try:
        async with aiohttp.ClientSession() as session:
            # 403/404 still mean the endpoint is reachable, even if this
            # particular file is missing or restricted
            async with session.head(f"{GAIA_FILE_URL}{task_id}", timeout=5) as resp:
                return resp.status in [200, 403, 404]
    except Exception as e:
        logger.warning(f"GAIA API test failed: {e}")
        return False

# --- Node Functions ---
async def parse_question(state: JARVISState) -> JARVISState:
    try:
        question = state["question"]
        prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return only a JSON list of tool names, e.g. ["web_search", "calculator"]."""
        if llm:
            # HuggingFacePipeline is a plain LLM, so ainvoke returns the
            # generated text as a string (not a message object with .content)
            response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
            try:
                tools_needed = json.loads(response)
            except json.JSONDecodeError as je:
                logger.warning(f"Invalid JSON in LLM response for task {state['task_id']}: {je}")
                tools_needed = ["web_search"]
        else:
            logger.warning("No LLM available, using default tools")
            tools_needed = ["web_search"]
        state["tools_needed"] = tools_needed
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error parsing question for task {state['task_id']}: {e}")
        state["tools_needed"] = []
        log_state(state["task_id"], state)
        return state
async def tool_dispatcher(state: JARVISState) -> JARVISState:
    try:
        tools_needed = state["tools_needed"]
        updated_state = state.copy()
        can_download_files = await test_gaia_api(updated_state["task_id"])
        web_search_done = False
        for tool in tools_needed:
            try:
                if tool in ("web_search", "multi_hop_search"):
                    # web_search_agent handles both search variants in one pass,
                    # so run it only once even if both tools were requested
                    if not web_search_done:
                        result = await web_search_agent(updated_state)
                        updated_state["web_results"].extend(result["web_results"])
                        web_search_done = True
                elif tool == "file_parser" and can_download_files:
                    result = await file_parser_agent(updated_state)
                    updated_state["file_results"] = result["file_results"]
                elif tool == "image_parser" and can_download_files:
                    result = await image_parser_agent(updated_state)
                    updated_state["image_results"] = result["image_results"]
                elif tool == "calculator":
                    result = await calculator_agent(updated_state)
                    updated_state["calculation_results"] = result["calculation_results"]
                elif tool == "document_retriever" and can_download_files:
                    result = await document_retriever_agent(updated_state)
                    updated_state["document_results"] = result["document_results"]
            except Exception as e:
                logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {e}")
        log_state(updated_state["task_id"], updated_state)
        return updated_state
    except Exception as e:
        logger.error(f"Error in tool dispatcher for task {state['task_id']}: {e}")
        log_state(state["task_id"], state)
        return state
async def web_search_agent(state: JARVISState) -> JARVISState:
    try:
        results = []
        # Assumes the search tools expose a LangChain-style async ainvoke();
        # awaiting a synchronous invoke() would fail for standard tools
        if "web_search" in state["tools_needed"]:
            result = await search_tool.ainvoke({"query": state["question"]})
            results.append(result)
        if "multi_hop_search" in state["tools_needed"]:
            result = await multi_hop_search_tool.ainvoke({"query": state["question"], "steps": 3})
            results.append(result)
        return {"web_results": results}
    except Exception as e:
        logger.error(f"Error in web search for task {state['task_id']}: {e}")
        return {"web_results": []}
async def file_parser_agent(state: JARVISState) -> JARVISState:
    try:
        if "file_parser" in state["tools_needed"]:
            file_type = "csv" if "data" in state["question"].lower() else "txt"
            result = await file_parser_tool.aparse(state["task_id"], file_type=file_type)
            return {"file_results": result}
        return {"file_results": ""}
    except Exception as e:
        logger.error(f"Error in file parser for task {state['task_id']}: {e}")
        return {"file_results": "File parsing failed"}
async def image_parser_agent(state: JARVISState) -> JARVISState:
    try:
        if "image_parser" in state["tools_needed"]:
            task = "match" if "fruits" in state["question"].lower() else "describe"
            match_query = "fruits" if task == "match" else ""
            file_path = f"temp_{state['task_id']}.jpg"
            if not os.path.exists(file_path):
                logger.warning(f"Image file not found for task {state['task_id']}")
                return {"image_results": "Image file not found"}
            result = await image_parser_tool.aparse(
                file_path, task=task, match_query=match_query
            )
            return {"image_results": result}
        return {"image_results": ""}
    except Exception as e:
        logger.error(f"Error in image parser for task {state['task_id']}: {e}")
        return {"image_results": "Image parsing failed"}
async def calculator_agent(state: JARVISState) -> JARVISState:
    try:
        if "calculator" in state["tools_needed"]:
            prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
            if llm:
                # ainvoke on HuggingFacePipeline returns a plain string
                response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
                expression = response.strip()
            else:
                expression = "0"
            result = await calculator_tool.aparse(expression)
            return {"calculation_results": result}
        return {"calculation_results": ""}
    except Exception as e:
        logger.error(f"Error in calculator for task {state['task_id']}: {e}")
        return {"calculation_results": "Calculation failed"}
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    try:
        if "document_retriever" in state["tools_needed"]:
            file_type = "txt" if "menu" in state["question"].lower() else "csv"
            if "report" in state["question"].lower() or "document" in state["question"].lower():
                file_type = "pdf"
            result = await document_retriever_tool.aparse(
                state["task_id"], state["question"], file_type=file_type
            )
            return {"document_results": result}
        return {"document_results": ""}
    except Exception as e:
        logger.error(f"Error in document retriever for task {state['task_id']}: {e}")
        return {"document_results": "Document retrieval failed"}
async def reasoning_agent(state: JARVISState) -> JARVISState:
    try:
        prompt = f"""Question: {state['question']}
Web Results: {state['web_results']}
File Results: {state['file_results']}
Image Results: {state['image_results']}
Calculation Results: {state['calculation_results']}
Document Results: {state['document_results']}
Synthesize an exact-match answer for the GAIA benchmark.
Output only the answer (e.g., '90', 'White;5876')."""
        if llm:
            # LangChain flattens a message list into a single prompt for this
            # plain LLM, which returns the generated text as a string
            response = await llm.ainvoke(
                [
                    SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
                    HumanMessage(content=prompt)
                ],
                config={"callbacks": [langfuse] if langfuse else []}
            )
            answer = response.strip()
        else:
            answer = "Unknown"
        state["answer"] = answer
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error in reasoning for task {state['task_id']}: {e}")
        state["answer"] = "Error in reasoning"
        log_state(state["task_id"], state)
        return state
def router(state: JARVISState) -> str:
    if state["tools_needed"]:
        return "tool_dispatcher"
    return "reasoning"

# --- Define StateGraph ---
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
workflow.add_node("reasoning", reasoning_agent)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
    "parse",
    router,
    {
        "tool_dispatcher": "tool_dispatcher",
        "reasoning": "reasoning"
    }
)
workflow.add_edge("tool_dispatcher", "reasoning")
workflow.add_edge("reasoning", END)

# Compile graph
graph = workflow.compile(checkpointer=memory if use_checkpointing else None)
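
# Graph topology, for reference:
#
#   parse ──(router: tools needed)──> tool_dispatcher ──> reasoning ──> END
#     └─────(router: no tools)─────────────────────────────^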
# --- Basic Agent Definition ---
class BasicAgent:
    def __init__(self):
        logger.info("BasicAgent initialized.")

    async def process_question(self, task_id: str, question: str) -> str:
        file_type = "jpg" if "image" in question.lower() else "txt"
        if "menu" in question.lower() or "report" in question.lower() or "document" in question.lower():
            file_type = "pdf"
        elif "data" in question.lower():
            file_type = "csv"
        file_path = f"temp_{task_id}.{file_type}"
        if await test_gaia_api(task_id):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{GAIA_FILE_URL}{task_id}") as resp:
                        if resp.status == 200:
                            with open(file_path, "wb") as f:
                                f.write(await resp.read())
                        else:
                            logger.warning(f"Failed to download file for task {task_id}: HTTP {resp.status}")
            except Exception as e:
                logger.error(f"Error downloading file for task {task_id}: {e}")
        state = JARVISState(
            task_id=task_id,
            question=question,
            tools_needed=[],
            web_results=[],
            file_results="",
            image_results="",
            calculation_results="",
            document_results="",
            messages=[],
            answer=""
        )
        try:
            config = {"configurable": {"thread_id": task_id}} if use_checkpointing else {}
            result = await graph.ainvoke(state, config=config)
            return result["answer"] or "No answer generated"
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            return f"Error: {str(e)}"
        finally:
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.error(f"Error removing file {file_path}: {e}")
    async def async_call(self, question: str, task_id: str) -> str:
        return await self.process_question(task_id, question)

    def __call__(self, question: str, task_id: str = None) -> str:
        logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
        if task_id is None:
            logger.warning("task_id not provided, using placeholder")
            task_id = "placeholder_task_id"
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        # nest_asyncio (applied above) makes this safe inside Gradio's running loop
        return loop.run_until_complete(self.async_call(question, task_id))
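
# A (hypothetical) direct usage example, outside the Gradio app:
#   agent = BasicAgent()
#   print(agent("What is 12 * 7?", task_id="demo-task"))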
# --- Main Function ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    space_id = os.getenv("SPACE_ID")
    if not profile:
        logger.error("User not logged in.")
        return "Please log in to Hugging Face with the button.", None
    username = f"{profile.username}"
    logger.info(f"User logged in: {username}")
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    try:
        agent = BasicAgent()
    except Exception as e:
        logger.error(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None
    logger.info(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            logger.error("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        logger.info(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        logger.error(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    results_log = []
    answers_payload = []
    logger.info(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            logger.warning(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            submitted_answer = agent(question_text, task_id)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
        except Exception as e:
            logger.error(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
    if not answers_payload:
        logger.error("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        logger.info(f"Server response: {result_data}")
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except Exception as e:
        logger.error(f"Submission failed: {e}")
        results_df = pd.DataFrame(results_log)
        return f"Submission Failed: {e}", results_df
# --- Build Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# JARVIS Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**
        1. Log in to your Hugging Face account using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the JARVIS agent, and submit answers.
        ---
        **Disclaimers:**
        The agent uses a local Hugging Face model (Mistral-7B) and async tools for the GAIA benchmark.
        """
    )
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table]
    )
if __name__ == "__main__":
    logger.info("\n" + "-" * 30 + " App Starting " + "-" * 30)
    space_id = os.getenv("SPACE_ID")
    logger.info(f"SPACE_ID: {space_id}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)