"""JARVIS agent for the GAIA benchmark, served through a Gradio UI.

A LangGraph workflow parses each question to decide which tools it needs,
dispatches those tools, and asks the LLM to synthesize an exact-match answer.
``run_and_submit_all`` fetches the question set from the scoring API, runs the
agent on every question, and submits the answers.
"""
import asyncio
import inspect
import json
import logging
import os
import re

import aiohttp
import gradio as gr
import nest_asyncio
import pandas as pd
import requests
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_huggingface import HuggingFacePipeline
from langfuse.callback import CallbackHandler
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
from transformers import pipeline

from tools import (
    calculator_tool,
    document_retriever_tool,
    file_parser_tool,
    image_parser_tool,
    multi_hop_search_tool,
    search_tool,
)
from tools.search import initialize_search_tools
from state import JARVISState

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Patch asyncio so run_until_complete() may be nested inside an already running
# loop; BasicAgent.__call__ relies on this.
nest_asyncio.apply()

# Load environment variables (from a .env file if present)
load_dotenv()

# Verify environment variables
required_env_vars = ["SPACE_ID", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
for var in required_env_vars:
    if not os.getenv(var):
        raise ValueError(f"Environment variable {var} is not set")

LANGFUSE_HOST = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
logger.info(
    f"Environment variables loaded: SPACE_ID={os.getenv('SPACE_ID')[:10]}..., "
    f"LANGFUSE_HOST={LANGFUSE_HOST}"
)

# Hugging Face model used for planning, calculation extraction and reasoning.
# Override with the HF_MODEL_ID environment variable.
HF_MODEL_ID = os.getenv("HF_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.1")

# Initialize Hugging Face model
try:
    hf_pipeline = pipeline(
        "text-generation",
        model=HF_MODEL_ID,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        # Return only the completion. By default the prompt is echoed back,
        # which would break JSON parsing and exact-match answers downstream.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    logger.info(f"HuggingFace model initialized: {HF_MODEL_ID}")
except Exception as e:
    logger.error(f"Failed to initialize HuggingFace model: {e}")
    llm = None

# Initialize search tools with LLM
try:
    initialize_search_tools(llm)
    logger.info("Search tools initialized")
except Exception as e:
    logger.error(f"Failed to initialize search tools: {e}")

# Initialize Langfuse (tracing is optional: failures only log a warning)
try:
    langfuse = CallbackHandler(
        public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
        secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
        host=LANGFUSE_HOST,
    )
    logger.info("Langfuse initialized successfully")
except Exception as e:
    logger.warning(f"Failed to initialize Langfuse: {e}")
    langfuse = None

# Initialize MemorySaver
memory = MemorySaver()
use_checkpointing = True

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space/api"
# NOTE(review): confirm this host actually serves GAIA attachments; the scoring
# API may expose files under its own /files/{task_id} route instead.
GAIA_FILE_URL = "https://api.gaia-benchmark.com/files/"

# Tools the planner may request, in the order the dispatcher runs them.
# The calculator runs last because it reads file_results produced earlier.
TOOL_ORDER = (
    "web_search",
    "multi_hop_search",
    "file_parser",
    "image_parser",
    "document_retriever",
    "calculator",
)
# Tools that need the task's attachment from the GAIA file API.
FILE_TOOLS = frozenset({"file_parser", "image_parser", "document_retriever"})


# --- Helper Functions ---
def _llm_config() -> dict:
    """Return the runnable config that attaches the Langfuse callback when available."""
    return {"callbacks": [langfuse] if langfuse else []}


def _llm_text(response) -> str:
    """Return the text of an LLM response.

    ``HuggingFacePipeline`` is a plain LLM whose ``ainvoke`` returns a ``str``,
    while chat models return a message object carrying ``.content``; both are
    accepted here.
    """
    content = getattr(response, "content", response)
    return content if isinstance(content, str) else str(content)


def _parse_tool_list(text: str, task_id: str) -> list[str]:
    """Extract the planner's tool list from raw LLM output.

    Tries the whole reply as JSON first, then the first ``[...]`` span in it.
    Only names in ``TOOL_ORDER`` are kept (deduplicated, in ``TOOL_ORDER``
    order). Falls back to ``["web_search"]`` when no JSON list can be found.
    """
    candidates = [text.strip()]
    match = re.search(r"\[.*?\]", text, re.DOTALL)
    if match:
        candidates.append(match.group(0))
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            requested = {str(name).strip() for name in parsed}
            return [tool for tool in TOOL_ORDER if tool in requested]
    logger.warning(f"Invalid JSON in LLM response for task {task_id}: no JSON tool list found")
    return ["web_search"]


async def _invoke_tool(tool, payload: dict):
    """Call a search tool whose ``invoke`` may be sync or async.

    Prefers ``ainvoke`` (the LangChain async entry point) when the tool has one;
    otherwise calls ``invoke`` and awaits the result only if it is awaitable.
    """
    if hasattr(tool, "ainvoke"):
        return await tool.ainvoke(payload)
    result = tool.invoke(payload)
    if inspect.isawaitable(result):
        result = await result
    return result


def log_state(task_id: str, state: JARVISState):
    """Append a snapshot of *state* for *task_id* to ``state_log.json``.

    Each entry is an indented JSON object followed by a newline. The entry is
    serialized before the file is opened, so a value that cannot be encoded
    never leaves a half-written record behind; values json cannot encode
    natively are written via ``str()``. Errors are logged, never raised.
    """
    try:
        log_entry = {
            "task_id": task_id,
            "question": state["question"],
            "tools_needed": state["tools_needed"],
            "web_results": state["web_results"],
            "file_results": state["file_results"],
            "image_results": state["image_results"],
            "calculation_results": state["calculation_results"],
            "document_results": state["document_results"],
            "answer": state["answer"],
        }
        payload = json.dumps(log_entry, indent=2, default=str)
        with open("state_log.json", "a", encoding="utf-8") as f:
            f.write(payload + "\n")
    except Exception as e:
        logger.error(f"Error logging state for task {task_id}: {e}")


async def test_gaia_api(task_id: str) -> bool:
    """Probe the GAIA file API for *task_id* with a HEAD request.

    Returns True when the server answers 200, 403 or 404 (reachable, even if
    the file is absent), False for any other status or on error/timeout (5 s).
    """
    try:
        timeout = aiohttp.ClientTimeout(total=5)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.head(f"{GAIA_FILE_URL}{task_id}") as resp:
                return resp.status in (200, 403, 404)
    except Exception as e:
        logger.warning(f"GAIA API test failed: {e}")
        return False


# --- Node Functions ---
async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question needs and store them in ``tools_needed``.

    Uses ``["web_search"]`` when no LLM is loaded or its reply holds no JSON
    list, and ``[]`` (go straight to reasoning) on any other error.
    """
    try:
        question = state["question"]
        prompt = (
            f"Analyze this GAIA question: {question}\n"
            "Determine which tools are needed (web_search, multi_hop_search, "
            "file_parser, image_parser, calculator, document_retriever).\n"
            "Return a JSON list of tool names."
        )
        if llm:
            response = await llm.ainvoke(prompt, config=_llm_config())
            tools_needed = _parse_tool_list(_llm_text(response), state["task_id"])
        else:
            logger.warning("No LLM available, using default tools")
            tools_needed = ["web_search"]
        state["tools_needed"] = tools_needed
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error parsing question for task {state['task_id']}: {e}")
        state["tools_needed"] = []
        log_state(state["task_id"], state)
        return state


async def tool_dispatcher(state: JARVISState) -> JARVISState:
    """Run each requested tool once and merge its output into a copy of *state*.

    Tools run in ``TOOL_ORDER`` regardless of the order the planner listed them,
    and unknown names are ignored. ``web_search`` and ``multi_hop_search`` share
    a single ``web_search_agent`` call (that agent runs whichever of the two are
    requested). File-based tools are skipped unless the GAIA file API answers.
    A failing tool is logged and skipped; the other tools still run.
    """
    try:
        requested = set(state["tools_needed"])
        updated_state = state.copy()
        # state.copy() is shallow: give the copy its own list so the extend()
        # below does not mutate the caller's web_results.
        updated_state["web_results"] = list(state["web_results"])

        # Only probe the file API when a file-based tool will actually use it.
        can_download_files = False
        if requested & FILE_TOOLS:
            can_download_files = await test_gaia_api(updated_state["task_id"])

        web_search_done = False
        for tool in TOOL_ORDER:
            if tool not in requested:
                continue
            try:
                if tool in ("web_search", "multi_hop_search"):
                    if web_search_done:
                        continue
                    web_search_done = True
                    result = await web_search_agent(updated_state)
                    updated_state["web_results"].extend(result["web_results"])
                elif tool == "file_parser" and can_download_files:
                    result = await file_parser_agent(updated_state)
                    updated_state["file_results"] = result["file_results"]
                elif tool == "image_parser" and can_download_files:
                    result = await image_parser_agent(updated_state)
                    updated_state["image_results"] = result["image_results"]
                elif tool == "calculator":
                    result = await calculator_agent(updated_state)
                    updated_state["calculation_results"] = result["calculation_results"]
                elif tool == "document_retriever" and can_download_files:
                    result = await document_retriever_agent(updated_state)
                    updated_state["document_results"] = result["document_results"]
            except Exception as e:
                logger.warning(f"Error in tool {tool} for task {updated_state['task_id']}: {e}")
        log_state(updated_state["task_id"], updated_state)
        return updated_state
    except Exception as e:
        logger.error(f"Error in tool dispatcher for task {state['task_id']}: {e}")
        log_state(state["task_id"], state)
        return state


async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run the single-hop and/or multi-hop (3 steps) search for the question.

    Returns ``{"web_results": [...]}`` with one entry per search that ran, or an
    empty list on error.
    """
    try:
        results = []
        if "web_search" in state["tools_needed"]:
            results.append(await _invoke_tool(search_tool, {"query": state["question"]}))
        if "multi_hop_search" in state["tools_needed"]:
            results.append(
                await _invoke_tool(
                    multi_hop_search_tool, {"query": state["question"], "steps": 3}
                )
            )
        return {"web_results": results}
    except Exception as e:
        logger.error(f"Error in web search for task {state['task_id']}: {e}")
        return {"web_results": []}


async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attachment, assuming CSV when the question mentions "data"."""
    try:
        if "file_parser" in state["tools_needed"]:
            file_type = "csv" if "data" in state["question"].lower() else "txt"
            result = await file_parser_tool.aparse(state["task_id"], file_type=file_type)
            return {"file_results": result}
        return {"file_results": ""}
    except Exception as e:
        logger.error(f"Error in file parser for task {state['task_id']}: {e}")
        return {"file_results": "File parsing failed"}


async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Describe (or, for fruit questions, match) the downloaded ``temp_{task_id}.jpg``."""
    try:
        if "image_parser" in state["tools_needed"]:
            task = "match" if "fruits" in state["question"].lower() else "describe"
            match_query = "fruits" if task == "match" else ""
            # Written by BasicAgent.process_question when the question mentions "image".
            file_path = f"temp_{state['task_id']}.jpg"
            if not os.path.exists(file_path):
                logger.warning(f"Image file not found for task {state['task_id']}")
                return {"image_results": "Image file not found"}
            result = await image_parser_tool.aparse(file_path, task=task, match_query=match_query)
            return {"image_results": result}
        return {"image_results": ""}
    except Exception as e:
        logger.error(f"Error in image parser for task {state['task_id']}: {e}")
        return {"image_results": "Image parsing failed"}


async def calculator_agent(state: JARVISState) -> JARVISState:
    """Have the LLM extract an expression from the question and evaluate it.

    The prompt includes ``file_results``, which is why the dispatcher runs this
    tool after the file-based ones. Without an LLM the expression is ``"0"``.
    """
    try:
        if "calculator" in state["tools_needed"]:
            prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
            if llm:
                response = await llm.ainvoke(prompt, config=_llm_config())
                expression = _llm_text(response).strip()
            else:
                expression = "0"
            result = await calculator_tool.aparse(expression)
            return {"calculation_results": result}
        return {"calculation_results": ""}
    except Exception as e:
        logger.error(f"Error in calculator for task {state['task_id']}: {e}")
        return {"calculation_results": "Calculation failed"}


async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Retrieve from the task's document, guessing its type from question keywords."""
    try:
        if "document_retriever" in state["tools_needed"]:
            question_lower = state["question"].lower()
            file_type = "txt" if "menu" in question_lower else "csv"
            if "report" in question_lower or "document" in question_lower:
                file_type = "pdf"
            result = await document_retriever_tool.aparse(
                state["task_id"], state["question"], file_type=file_type
            )
            return {"document_results": result}
        return {"document_results": ""}
    except Exception as e:
        logger.error(f"Error in document retriever for task {state['task_id']}: {e}")
        return {"document_results": "Document retrieval failed"}


async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Combine all tool results into a single exact-match answer stored in ``answer``.

    Sets ``"Unknown"`` when no LLM is loaded and ``"Error in reasoning"`` on failure.
    """
    try:
        prompt = (
            f"Question: {state['question']}\n"
            f"Web Results: {state['web_results']}\n"
            f"File Results: {state['file_results']}\n"
            f"Image Results: {state['image_results']}\n"
            f"Calculation Results: {state['calculation_results']}\n"
            f"Document Results: {state['document_results']}\n"
            "Synthesize an exact-match answer for the GAIA benchmark. "
            "Output only the answer (e.g., '90', 'White;5876')."
        )
        if llm:
            response = await llm.ainvoke(
                [
                    SystemMessage(
                        content="You are JARVIS, a precise assistant for the GAIA benchmark. "
                        "Provide exact answers only."
                    ),
                    HumanMessage(content=prompt),
                ],
                config=_llm_config(),
            )
            answer = _llm_text(response).strip()
        else:
            answer = "Unknown"
        state["answer"] = answer
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error in reasoning for task {state['task_id']}: {e}")
        state["answer"] = "Error in reasoning"
        log_state(state["task_id"], state)
        return state


def router(state: JARVISState) -> str:
    """Route to the tool dispatcher when any tools were requested, else straight to reasoning."""
    if state["tools_needed"]:
        return "tool_dispatcher"
    return "reasoning"


# --- Define StateGraph ---
# parse -> (tool_dispatcher ->) reasoning -> END
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
workflow.add_node("reasoning", reasoning_agent)
workflow.set_entry_point("parse")
workflow.add_conditional_edges(
    "parse",
    router,
    {
        "tool_dispatcher": "tool_dispatcher",
        "reasoning": "reasoning",
    },
)
workflow.add_edge("tool_dispatcher", "reasoning")
workflow.add_edge("reasoning", END)

# Compile graph
graph = workflow.compile(checkpointer=memory if use_checkpointing else None)


# --- Basic Agent Definition ---
class BasicAgent:
    """Callable wrapper that runs the compiled JARVIS graph on one question."""

    def __init__(self):
        logger.info("BasicAgent initialized.")

    async def process_question(self, task_id: str, question: str) -> str:
        """Run the graph on one question and return its answer.

        If the GAIA file API is reachable, the task's attachment is first saved
        to ``temp_{task_id}.{ext}``, with the extension guessed from keywords in
        the question; that file is removed afterwards. Graph errors are returned
        as an ``"Error: ..."`` string rather than raised.
        """
        question_lower = question.lower()
        file_type = "jpg" if "image" in question_lower else "txt"
        if "menu" in question_lower or "report" in question_lower or "document" in question_lower:
            file_type = "pdf"
        elif "data" in question_lower:
            file_type = "csv"
        file_path = f"temp_{task_id}.{file_type}"

        if await test_gaia_api(task_id):
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"{GAIA_FILE_URL}{task_id}") as resp:
                        if resp.status == 200:
                            with open(file_path, "wb") as f:
                                f.write(await resp.read())
                        else:
                            logger.warning(f"Failed to download file for task {task_id}: HTTP {resp.status}")
            except Exception as e:
                logger.error(f"Error downloading file for task {task_id}: {e}")

        state = JARVISState(
            task_id=task_id,
            question=question,
            tools_needed=[],
            web_results=[],
            file_results="",
            image_results="",
            calculation_results="",
            document_results="",
            messages=[],
            answer="",
        )
        try:
            # One checkpointer thread per task id.
            config = {"configurable": {"thread_id": task_id}} if use_checkpointing else {}
            result = await graph.ainvoke(state, config=config)
            return result["answer"] or "No answer generated"
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            return f"Error: {str(e)}"
        finally:
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.error(f"Error removing file {file_path}: {e}")

    async def async_call(self, question: str, task_id: str) -> str:
        """Async entry point taking arguments in ``__call__`` order."""
        return await self.process_question(task_id, question)

    def __call__(self, question: str, task_id: str | None = None) -> str:
        """Synchronously answer *question*.

        Uses the current thread's event loop, creating and installing a new one
        if the thread has none. nest_asyncio (applied at import) allows this even
        while a loop is already running.
        """
        logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
        if task_id is None:
            logger.warning("task_id not provided, using placeholder")
            task_id = "placeholder_task_id"
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(self.async_call(question, task_id))


# --- Main Function ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all questions, answer each with ``BasicAgent``, and submit the answers.

    Returns ``(status_message, results_dataframe_or_None)`` for the Gradio outputs.
    """
    space_id = os.getenv("SPACE_ID")
    if not profile:
        logger.error("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = str(profile.username)
    logger.info(f"User logged in: {username}")

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    try:
        agent = BasicAgent()
    except Exception as e:
        logger.error(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    logger.info(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            logger.error("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        logger.info(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        logger.error(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None

    results_log = []
    answers_payload = []
    logger.info(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            logger.warning(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            submitted_answer = agent(question_text, task_id)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
        except Exception as e:
            logger.error(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})

    if not answers_payload:
        logger.error("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    results_df = pd.DataFrame(results_log)
    try:
        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        logger.info(f"Server response: {result_data}")
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        return final_status, results_df
    except Exception as e:
        logger.error(f"Submission failed: {e}")
        return f"Submission Failed: {e}", results_df


# --- Build Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# JARVIS Agent Evaluation Runner")
    gr.Markdown(
        f"""
        **Instructions:**

        1. Log in to your Hugging Face account using the button below.
        2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the JARVIS agent, and submit answers.

        ---
        **Disclaimers:**
        The agent uses a local Hugging Face model ({HF_MODEL_ID}) and async tools for the GAIA benchmark.
        """
    )
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )

if __name__ == "__main__":
    logger.info("\n" + "-" * 30 + " App Starting " + "-" * 30)
    space_id = os.getenv("SPACE_ID")
    logger.info(f"SPACE_ID: {space_id}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)