import os
import json
import logging
import asyncio
import aiohttp
import ssl
import nest_asyncio
import requests
import pandas as pd
from typing import Dict, Any, List
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END
import torch
from sentence_transformers import SentenceTransformer
import gradio as gr
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import together

from state import JARVISState, validate_state, reset_state
from tools.answer_generator import generate_answer, preprocess_question
from tools.file_fetcher import fetch_task_file
from tools.search import search_tool, multi_hop_search_tool
from tools.file_parser import file_parser_tool
from tools.image_parser import image_parser_tool
from tools.calculator import calculator_tool
from tools.document_retriever import document_retriever_tool
from tools.duckduckgo_search import duckduckgo_search_tool
from tools.weather_info import weather_info_tool
from tools.hub_stats import hub_stats_tool
from tools.guest_info import guest_info_retriever_tool
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Apply nest_asyncio
nest_asyncio.apply()

# Load environment variables
load_dotenv()
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
GAIA_API_URL = "https://agents-course-unit4-api-1.hf.space/api"
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")

# Verify environment variables
if not SPACE_ID:
    raise ValueError("SPACE_ID not set")
if not HF_API_TOKEN:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN not set")
if not TOGETHER_API_KEY:
    raise ValueError("TOGETHER_API_KEY not set")
if not OPENWEATHERMAP_API_KEY:
    logger.warning("OPENWEATHERMAP_API_KEY not set; weather_info_tool may fail")
logger.info(f"SPACE_ID: {SPACE_ID}")
# Model configuration
TOGETHER_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
]
HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
# Initialize LLM clients
def initialize_llm():
    for model in TOGETHER_MODELS:
        try:
            together.api_key = TOGETHER_API_KEY
            client = together.Together()
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": "Test"}],
                max_tokens=10
            )
            logger.info(f"Initialized Together AI model: {model}")
            return client, "together", model
        except Exception as e:
            logger.warning(f"Failed to initialize Together AI model {model}: {e}")
    try:
        client = InferenceClient(
            model=HF_MODEL,
            token=HF_API_TOKEN,
            timeout=30
        )
        logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
        return client, "hf_api", HF_MODEL
    except Exception as e:
        logger.warning(f"Failed to initialize HF Inference API: {e}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
        model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
        logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
        return (model, tokenizer), "hf_local", HF_MODEL
    except Exception as e:
        logger.error(f"Failed to initialize local HF model: {e}")
    raise Exception("No LLM could be initialized")


llm_client, llm_type, llm_model = initialize_llm()
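# llm_client's concrete type depends on llm_type: a together.Together client
# ("together"), a huggingface_hub.InferenceClient ("hf_api"), or a
# (model, tokenizer) tuple ("hf_local"). Downstream code branches on llm_type.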

# Initialize embedder
_embedder = None


def get_embedder():
    global _embedder
    if _embedder is None:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            _embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                device=device,
                cache_folder="./cache"
            )
            logger.info(f"SentenceTransformer initialized on {device.upper()}")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            raise RuntimeError(f"Embedder initialization failed: {e}")
    return _embedder


try:
    embedder = get_embedder()
except Exception as e:
    logger.error(f"Failed to initialize embedder: {e}")
    embedder = None

# Log device
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# HTTP session with SSL handling
async def create_http_session():
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_context),
        timeout=aiohttp.ClientTimeout(total=30)
    )
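# Sessions created here skip certificate verification and enforce a 30-second
# total timeout; they are used for the GAIA questions and submit endpoints below.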

# Tool registration
tools = {
    "search_tool": search_tool,
    "multi_hop_search_tool": multi_hop_search_tool,
    "file_parser_tool": file_parser_tool,
    "image_parser_tool": image_parser_tool,
    "calculator_tool": calculator_tool,
    "document_retriever_tool": document_retriever_tool,
    "duckduckgo_search_tool": duckduckgo_search_tool,
    "weather_info_tool": weather_info_tool,
    "hub_stats_tool": hub_stats_tool,
    "guest_info_retriever_tool": guest_info_retriever_tool,
}
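# This registry maps the names the LLM is asked to return onto the actual tool
# objects; tool_dispatcher looks each selected name up here and calls .ainvoke().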

# Parse question to select tools
async def parse_question(state: JARVISState) -> JARVISState:
    """
    Parse the question to select appropriate tools using LLM with retries,
    preprocess the question, and integrate file-based tools.

    Args:
        state (JARVISState): The input state containing task_id, question.

    Returns:
        JARVISState: Updated state with selected tools_needed and metadata.
    """
    state = validate_state(state)
    task_id = state["task_id"]
    question = state["question"]
    logger.info(f"Task {task_id} Parsing question: {question}")
    try:
        # Preprocess question
        processed_question = await preprocess_question(question)
        if processed_question != question:
            logger.info(f"Task {task_id} Preprocessed question: {processed_question}")
            state["question"] = processed_question
            question = processed_question
        # Default to search_tool
        tools_needed = ["search_tool"]
        # LLM-based tool selection
        if llm_client:
            prompt = ChatPromptTemplate.from_messages([
                SystemMessage(content="""Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
Return a JSON list of all relevant tools, e.g., ["search_tool", "duckduckgo_search_tool"].
Rules:
- Include "search_tool" for web-based questions unless purely computational or file-based.
- Include "multi_hop_search_tool" for questions with >20 words or requiring multiple steps.
- Include "file_parser_tool" for 'data', 'table', 'excel', 'csv', 'txt', 'mp3', or file extensions.
- Include "image_parser_tool" for 'image', 'video', 'picture', or 'painting'.
- Include "calculator_tool" for 'calculate', 'math', 'sum', 'average', 'total', or numerical operations.
- Include "document_retriever_tool" for 'document', 'pdf', 'report', or 'paper'.
- Include "duckduckgo_search_tool" for 'search', 'wikipedia', 'online', or general knowledge.
- Include "weather_info_tool" for 'weather', 'temperature', or 'forecast'.
- Include "hub_stats_tool" for 'model', 'huggingface', or 'dataset'.
- Include "guest_info_retriever_tool" for 'guest', 'name', 'relation', or 'person'.
- Select multiple tools if the question spans multiple domains (e.g., web and file).
- Output ONLY valid JSON."""),
                HumanMessage(content=f"Query: {question}")
            ])
            messages = prompt.format_messages()
            for attempt in range(3):  # Retry up to 3 times
                try:
                    formatted_messages = [
                        {"role": "system" if isinstance(m, SystemMessage) else "user", "content": m.content}
                        for m in messages
                    ]
                    if llm_type == "hf_local":
                        model, tokenizer = llm_client
                        inputs = tokenizer.apply_chat_template(
                            formatted_messages,
                            return_tensors="pt"
                        ).to(model.device)
                        outputs = model.generate(inputs, max_new_tokens=100, temperature=0.5)
                        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
                    elif llm_type == "together":
                        response = llm_client.chat.completions.create(
                            model=llm_model,
                            messages=formatted_messages,
                            max_tokens=100,
                            temperature=0.5
                        )
                        response = response.choices[0].message.content.strip()
                    else:  # hf_api
                        response = llm_client.chat.completions.create(
                            messages=formatted_messages,
                            max_tokens=100,
                            temperature=0.5
                        )
                        response = response.choices[0].message.content.strip()
                    logger.info(f"Task {task_id} LLM tool selection response: {response}")
                    try:
                        tools_needed = json.loads(response)
                        if isinstance(tools_needed, list) and all(isinstance(t, str) and t in tools for t in tools_needed):
                            break  # Valid response, exit retry loop
                        else:
                            raise ValueError("Invalid tool list format")
                    except json.JSONDecodeError as e:
                        logger.warning(f"Task {task_id}: Invalid JSON (attempt {attempt + 1}): {e}")
                        if attempt == 2:
                            tools_needed = ["search_tool"]  # Fallback after retries
                except Exception as e:
                    logger.warning(f"Task {task_id} Tool selection failed (attempt {attempt + 1}): {e}")
                    if attempt == 2:
                        tools_needed = ["search_tool"]  # Fallback after retries
        # Fallback to keyword-based selection if the LLM left only the default tool
        if tools_needed == ["search_tool"]:
            question_lower = question.lower()
            if any(kw in question_lower for kw in ["excel", "csv", "mp3", "data", "table", "xlsx"]):
                tools_needed.append("file_parser_tool")
            if any(kw in question_lower for kw in ["image", "video", "picture", "painting"]):
                tools_needed.append("image_parser_tool")
            if any(kw in question_lower for kw in ["calculate", "math", "sum", "average", "total"]):
                tools_needed.append("calculator_tool")
            if any(kw in question_lower for kw in ["document", "pdf", "report", "paper"]):
                tools_needed.append("document_retriever_tool")
            if any(kw in question_lower for kw in ["search", "wikipedia", "online"]):
                tools_needed.append("duckduckgo_search_tool")
            if any(kw in question_lower for kw in ["weather", "temperature", "forecast"]):
                tools_needed.append("weather_info_tool")
            if any(kw in question_lower for kw in ["model", "huggingface", "dataset"]):
                tools_needed.append("hub_stats_tool")
            if any(kw in question_lower for kw in ["guest", "name", "relation", "person"]):
                tools_needed.append("guest_info_retriever_tool")
            if len(question.split()) > 20 or "multiple" in question_lower:
                tools_needed.append("multi_hop_search_tool")
        # Integrate file-based tools
        file_results = await fetch_task_file(task_id, question)
        for ext, content in file_results.items():
            if content:
                os.makedirs("temp", exist_ok=True)
                file_path = f"temp/{task_id}.{ext}"
                with open(file_path, "wb") as f:
                    f.write(content)
                state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
                if ext in ["txt", "csv", "xlsx", "mp3"] and "file_parser_tool" not in tools_needed:
                    tools_needed.append("file_parser_tool")
                elif ext in ["jpg", "png"] and "image_parser_tool" not in tools_needed:
                    tools_needed.append("image_parser_tool")
                elif ext == "pdf" and "document_retriever_tool" not in tools_needed:
                    tools_needed.append("document_retriever_tool")
        state["tools_needed"] = list(set(tools_needed))  # Remove duplicates
        logger.info(f"Task {task_id} Selected tools: {state['tools_needed']}")
        return state
    except Exception as e:
        logger.error(f"Task {task_id} Tool selection failed: {e}")
        state["error"] = f"Parse question failed: {str(e)}"
        state["tools_needed"] = ["search_tool"]
        return state
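
# Each selected tool writes into a dedicated slot of JARVISState (web_results,
# multi_hop_results, file_results, image_results, document_results,
# calculation_results); the dispatcher below runs the tools sequentially and then
# hands everything to generate_answer.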
# Tool dispatcher
async def tool_dispatcher(state: JARVISState) -> JARVISState:
    state = validate_state(state)
    try:
        task_id = state["task_id"]
        question = state["question"]
        tools_needed = state["tools_needed"]
        for tool_name in tools_needed:
            try:
                if tool_name == "search_tool":
                    result = await tools["search_tool"].ainvoke({"query": question})
                    state["web_results"].extend([str(r) for r in result] if result else ["No results from search_tool"])
                elif tool_name == "multi_hop_search_tool":
                    result = await tools["multi_hop_search_tool"].ainvoke({
                        "query": question,
                        "steps": 3,
                        "llm_client": llm_client,
                        "llm_type": llm_type,
                        "llm_model": llm_model
                    })
                    state["multi_hop_results"].extend([r["content"] if isinstance(r, dict) else str(r) for r in result] if result else ["No results from multi_hop_search_tool"])
elif tool_name == "file_parser_tool": | |
file_path = state["metadata"].get("file_path") | |
file_ext = state["metadata"].get("file_ext") | |
if file_path and os.path.exists(file_path) and file_ext: | |
result = await tools["file_parser_tool"].ainvoke({ | |
"task_id": task_id, | |
"file_type": file_ext, | |
"file_path": file_path, | |
"query": question | |
}) | |
state["file_results"] = str(result) if result else "No file results" | |
else: | |
state["file_results"] = "No file available" | |
elif tool_name == "image_parser_tool": | |
file_path = state["metadata"].get("file_path") | |
if file_path and os.path.exists(file_path) and file_path.split('.')[-1] in ["jpg", "png"]: | |
result = await tools["image_parser_tool"].ainvoke({"task_id": task_id, "file_path": file_path}) | |
state["image_results"] = str(result) if result else "No image results" | |
else: | |
state["image_results"] = "No image available" | |
elif tool_name == "calculator_tool": | |
result = await tools["calculator_tool"].ainvoke({"expression": question}) | |
state["calculation_results"] = str(result) if result else "No calculation results" | |
elif tool_name == "document_retriever_tool": | |
file_path = state["metadata"].get("file_path") | |
if file_path and os.path.exists(file_path) and file_path.split('.')[-1] == "pdf": | |
result = await tools["document_retriever_tool"].ainvoke({ | |
"task_id": task_id, | |
"query": question, | |
"file_path": file_path | |
}) | |
state["document_results"] = str(result) if result else "No document results" | |
else: | |
state["document_results"] = "No document available" | |
elif tool_name == "duckduckgo_search_tool": | |
result = await tools["duckduckgo_search_tool"].ainvoke({ | |
"query": question, | |
"original_query": question, | |
"embedder": embedder | |
}) | |
state["web_results"].extend(result if isinstance(result, list) else [str(result)] if result else ["No results from duckduckgo_search_tool"]) | |
elif tool_name == "weather_info_tool": | |
location = question.split()[-1] if "weather" in question.lower() else "Unknown" | |
result = await tools["weather_info_tool"].ainvoke({"location": location}) | |
state["web_results"].append(str(result) if result else "No weather results") | |
elif tool_name == "hub_stats_tool": | |
author = question.split("by ")[1].split()[0] if "by" in question.lower() else "Unknown" | |
result = await tools["hub_stats_tool"].ainvoke({"author": author}) | |
state["web_results"].append(str(result) if result else "No hub stats results") | |
elif tool_name == "guest_info_retriever_tool": | |
result = await tools["guest_info_retriever_tool"].ainvoke({"query": question}) | |
state["web_results"].append(str(result) if result else "No guest info results") | |
state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_executed": True} | |
logger.info(f"Task {task_id}: Executed {tool_name}") | |
except Exception as e: | |
logger.warning(f"Tool {tool_name} failed for task {task_id}: {e}") | |
state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_error": str(e)} | |
        # Ensure results are populated
        state["web_results"] = state.get("web_results", ["No web results found"])
        state["file_results"] = state.get("file_results", "No file results found")
        state["image_results"] = state.get("image_results", "No image results found")
        state["document_results"] = state.get("document_results", "No document results found")
        state["calculation_results"] = state.get("calculation_results", "No calculation results found")
        state["answer"] = await generate_answer(
            task_id=task_id,
            question=question,
            search_results=state.get("web_results", []) + [
                r["content"] if isinstance(r, dict) else str(r) for r in state.get("multi_hop_results", [])
            ],
            file_results=state.get("file_results", "") + state.get("document_results", "") + state.get("image_results", "") + state.get("calculation_results", ""),
            llm_client=llm_client
        )
        logger.info(f"Task {task_id}: Generated answer: {state['answer']}")
        return state
    except Exception as e:
        logger.error(f"Tool dispatch failed: {e}")
        state["error"] = f"Tool dispatch failed: {e}"
        return state

# Define StateGraph
workflow = StateGraph(JARVISState)
workflow.add_node("parse_question", parse_question)
workflow.add_node("tool_dispatcher", tool_dispatcher)
workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "tool_dispatcher")
workflow.add_edge("tool_dispatcher", END)
graph = workflow.compile()
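# The compiled graph is a simple linear pipeline:
#   parse_question -> tool_dispatcher -> END
# Illustrative usage (a sketch only; the task_id/question values are made up):
#   state = reset_state(task_id="demo", question="What is the capital of France?")
#   result = await graph.ainvoke(state)
#   print(result["answer"])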

# Agent class
class JARVISAgent:
    def __init__(self):
        self.state = reset_state(task_id="init", question="Agent initialized")
        self.state["results_table"] = []  # Initialize as empty list
        logger.info("JARVISAgent initialized.")

    async def process_question(self, task_id: str, question: str) -> str:
        state = reset_state(task_id=task_id, question=question)
        try:
            result = await graph.ainvoke(state)
            answer = result.get("answer", "Unknown")
            logger.info(f"Task {task_id} Final answer: {answer}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
            self.state["metadata"] = {"last_task_id": task_id, "answer": answer}
            return answer
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
            self.state["error"] = f"Task {task_id} failed: {str(e)}"
            return f"Error: {str(e)}"
        finally:
            # Remove any temp/<task_id>.<ext> file downloaded for this task
            for ext in ["txt", "csv", "xlsx", "mp3", "jpg", "png", "pdf"]:
                file_path = f"temp/{task_id}.{ext}"
                if os.path.exists(file_path):
                    try:
                        os.remove(file_path)
                        logger.info(f"Removed temp file: {file_path}")
                    except Exception as e:
                        logger.error(f"Error removing file {file_path}: {e}")
    async def process_all_questions(self, profile: gr.OAuthProfile | None):
        if not profile:
            logger.error("User not logged in.")
            self.state["status_output"] = "Please Login to Hugging Face."
            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
        username = profile.username
        logger.info(f"User logged in: {username}")
        questions_url = f"{GAIA_API_URL}/questions"
        submit_url = f"{GAIA_API_URL}/submit"
        agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"
        try:
            async with await create_http_session() as session:
                async with session.get(questions_url) as response:
                    response.raise_for_status()
                    questions = await response.json()
                    logger.info(f"Fetched {len(questions)} questions.")
        except Exception as e:
            logger.error(f"Error fetching questions: {e}")
            self.state["status_output"] = f"Error fetching questions: {e}"
            self.state["error"] = f"Fetch questions failed: {str(e)}"
            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
        answers_payload = []
        for item in questions:
            task_id = item.get("task_id")
            question = item.get("question")
            if not task_id or not question:
                logger.warning(f"Skipping invalid item: {item}")
                continue
            answer = await self.process_question(task_id, question)
            answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        if not answers_payload:
            logger.error("No answers generated.")
            self.state["status_output"] = "No answers to submit."
            self.state["error"] = "No answers generated"
            return pd.DataFrame(self.state["results_table"]), self.state["status_output"]
        submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
        try:
            async with await create_http_session() as session:
                async with session.post(submit_url, json=submission_data) as response:
                    response.raise_for_status()
                    result_data = await response.json()
                    self.state["status_output"] = (
                        f"Submission Successful!\n"
                        f"User: {result_data.get('username')}\n"
                        f"Overall Score: {result_data.get('score', 'N/A')}% "
                        f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
                        f"Message: {result_data.get('message', 'No message received.')}"
                    )
                    self.state["metadata"] = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
        except Exception as e:
            logger.error(f"Submission failed: {e}")
            self.state["status_output"] = f"Submission Failed: {e}"
            self.state["error"] = f"Submission failed: {str(e)}"
        return pd.DataFrame(self.state["results_table"] if self.state["results_table"] else [], columns=["Task ID", "Question", "Answer"]), self.state["status_output"]
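
# process_all_questions returns the accumulated results DataFrame plus a status
# string; both are bound to the Gradio components defined below. The submission
# payload pairs the username and agent_code (a link to this Space) with one
# {"task_id", "submitted_answer"} entry per question.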
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# JARVIS GAIA Agent")
    gr.Markdown(
        """
**Instructions:**
1. Log in to Hugging Face using the button below.
2. Click 'Run Evaluation & Submit All Answers' to process GAIA questions and submit.
---
**Disclaimers:**
Uses Hugging Face Inference, Together AI, SERPAPI, and OpenWeatherMap for the GAIA benchmark.
        """
    )
    with gr.Row():
        gr.LoginButton(value="Login to Hugging Face")
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
    agent = JARVISAgent()
    run_button.click(
        fn=agent.process_all_questions,
        outputs=[results_table, status_output]
    )
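
# gr.LoginButton enables Hugging Face OAuth; because process_all_questions is
# annotated with gr.OAuthProfile | None, Gradio injects the logged-in profile
# (or None) automatically, so no explicit inputs are wired to the click event.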
if __name__ == "__main__": | |
logger.info("\n" + "-"*30 + " App Starting " + "-"*30) | |
logger.info(f"SPACE_ID: {SPACE_ID}") | |
logger.info("Launching Gradio Interface...") | |
demo.launch(debug=True, share=False) |