# onisj's picture
# feat(tools): add more tools to extend the functionality of jarvis
# 751d628
import asyncio
import json
import logging
import os
import re
import ssl
from typing import Dict, Any, List

import aiohttp
import gradio as gr
import nest_asyncio
import pandas as pd
import requests
import together
import torch
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

from state import JARVISState, validate_state, reset_state
from tools.answer_generator import generate_answer, preprocess_question
from tools.file_fetcher import fetch_task_file
from tools.search import search_tool, multi_hop_search_tool
from tools.file_parser import file_parser_tool
from tools.image_parser import image_parser_tool
from tools.calculator import calculator_tool
from tools.document_retriever import document_retriever_tool
from tools.duckduckgo_search import duckduckgo_search_tool
from tools.weather_info import weather_info_tool
from tools.hub_stats import hub_stats_tool
from tools.guest_info import guest_info_retriever_tool
# Setup logging: timestamped INFO-level records for the whole application.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Patch asyncio so an already-running event loop can be re-entered.
nest_asyncio.apply()

# Read a local .env file (if present) into the process environment, then pull settings.
load_dotenv()
SPACE_ID = os.getenv("SPACE_ID", "onisj/jarvis_gaia_agent")
GAIA_API_URL = "https://agents-course-unit4-api-1.hf.space/api"
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
HF_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENWEATHERMAP_API_KEY = os.getenv("OPENWEATHERMAP_API_KEY")

# Mandatory settings: abort at import time with "<NAME> not set", checked in this order.
# SPACE_ID has a default above, so its check only fires when it is set to an empty string.
for _required_name, _required_value in (
    ("SPACE_ID", SPACE_ID),
    ("HUGGINGFACEHUB_API_TOKEN", HF_API_TOKEN),
    ("TOGETHER_API_KEY", TOGETHER_API_KEY),
):
    if not _required_value:
        raise ValueError(f"{_required_name} not set")
del _required_name, _required_value

# Optional setting: a missing key only produces a warning.
if not OPENWEATHERMAP_API_KEY:
    logger.warning("OPENWEATHERMAP_API_KEY not set; weather_info_tool may fail")
logger.info(f"SPACE_ID: {SPACE_ID}")
# Model configuration: Together AI models are tried in list order before the HF fallbacks.
TOGETHER_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-70B-free",
]
HF_MODEL = "meta-llama/Llama-3.2-1B-Instruct"


# Initialize LLM clients
def initialize_llm():
    """Return the first LLM backend that can be initialized.

    Backends are tried in this order:
      1. each model in ``TOGETHER_MODELS`` via Together AI, validated with a
         one-message test completion;
      2. the Hugging Face Inference API client for ``HF_MODEL``;
      3. ``HF_MODEL`` loaded locally with transformers.

    Returns:
        tuple: ``(client, llm_type, model_name)`` where ``llm_type`` is
        ``"together"``, ``"hf_api"`` or ``"hf_local"``. For ``"hf_local"`` the
        client is a ``(model, tokenizer)`` pair.

    Raises:
        RuntimeError: if no backend could be initialized (chained to the last
        failure).
    """
    together_client = None
    try:
        # Legacy module-level key kept for older SDK code paths; the client also gets it explicitly.
        together.api_key = TOGETHER_API_KEY
        together_client = together.Together(api_key=TOGETHER_API_KEY)
    except Exception as e:
        logger.warning(f"Failed to create Together AI client: {e}")

    if together_client is not None:
        for model_name in TOGETHER_MODELS:
            try:
                # A tiny completion proves both the key and this model are usable.
                together_client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": "Test"}],
                    max_tokens=10,
                )
            except Exception as e:
                logger.warning(f"Failed to initialize Together AI model {model_name}: {e}")
                continue
            logger.info(f"Initialized Together AI model: {model_name}")
            return together_client, "together", model_name

    try:
        # NOTE(review): constructing InferenceClient appears not to contact the server,
        # so this step rarely fails and the local fallback is seldom reached — confirm
        # whether a validating test request is wanted here.
        hf_client = InferenceClient(
            model=HF_MODEL,
            token=HF_API_TOKEN,
            timeout=30,
        )
        logger.info(f"Initialized Hugging Face Inference API model: {HF_MODEL}")
        return hf_client, "hf_api", HF_MODEL
    except Exception as e:
        logger.warning(f"Failed to initialize HF Inference API: {e}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
        local_model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
        logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
        return (local_model, tokenizer), "hf_local", HF_MODEL
    except Exception as e:
        logger.error(f"Failed to initialize local HF model: {e}")
        raise RuntimeError("No LLM could be initialized") from e


llm_client, llm_type, llm_model = initialize_llm()
# Initialize embedder
_embedder = None  # process-wide cache filled by get_embedder()
# Computed once; used both for loading the embedder and for the startup log below.
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_embedder():
    """Return the shared SentenceTransformer, creating it on first use.

    The ``all-MiniLM-L6-v2`` model is loaded on ``device`` with its files
    cached under ``./cache``; later calls return the same instance.

    Returns:
        SentenceTransformer: the cached embedder.

    Raises:
        RuntimeError: if the model cannot be loaded. The original exception is
        chained as ``__cause__`` so the root failure is not lost.
    """
    global _embedder
    if _embedder is None:
        try:
            _embedder = SentenceTransformer(
                "all-MiniLM-L6-v2",
                device=device,
                cache_folder="./cache",
            )
            logger.info(f"SentenceTransformer initialized on {device.upper()}")
        except Exception as e:
            logger.error(f"Failed to initialize SentenceTransformer: {e}")
            raise RuntimeError(f"Embedder initialization failed: {e}") from e
    return _embedder


# The embedder is optional: on failure the module continues with embedder = None.
# get_embedder() wraps every load failure in RuntimeError, so that is all we catch.
try:
    embedder = get_embedder()
except RuntimeError as e:
    logger.error(f"Failed to initialize embedder: {e}")
    embedder = None

# Log device
logger.info(f"Using device: {device}")
# HTTP session with SSL handling
async def create_http_session(*, verify_ssl: bool = True, timeout: float = 30):
    """Create an aiohttp session for talking to the GAIA questions/submission API.

    Server certificates and hostnames are verified by default. The previous
    implementation disabled verification unconditionally, which let anyone on
    the network path impersonate the API and read or alter the submission.

    Args:
        verify_ssl: Verify TLS certificates and hostnames (default ``True``).
            Pass ``False`` only on a trusted network where verification cannot
            succeed, e.g. behind an intercepting proxy.
        timeout: Total per-request timeout in seconds (default 30).

    Returns:
        aiohttp.ClientSession: a new session owned by the caller, who must close
        it (e.g. ``async with await create_http_session() as session:``).
    """
    # create_default_context() enables CERT_REQUIRED and hostname checking.
    ssl_context = ssl.create_default_context()
    if not verify_ssl:
        # Order matters: check_hostname must be off before CERT_NONE is accepted.
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_context),
        timeout=aiohttp.ClientTimeout(total=timeout),
    )
# Tool registration: maps each tool's name to its tool object. parse_question
# validates LLM-selected names against these keys and tool_dispatcher invokes them.
tools = dict(
    search_tool=search_tool,
    multi_hop_search_tool=multi_hop_search_tool,
    file_parser_tool=file_parser_tool,
    image_parser_tool=image_parser_tool,
    calculator_tool=calculator_tool,
    document_retriever_tool=document_retriever_tool,
    duckduckgo_search_tool=duckduckgo_search_tool,
    weather_info_tool=weather_info_tool,
    hub_stats_tool=hub_stats_tool,
    guest_info_retriever_tool=guest_info_retriever_tool,
)
# Parse question to select tools

# System prompt for LLM-based tool selection. Continuation lines deliberately
# start at column 0 so no source indentation leaks into the prompt text.
_TOOL_SELECTION_PROMPT = """Select tools from: ['search_tool', 'multi_hop_search_tool', 'file_parser_tool', 'image_parser_tool', 'calculator_tool', 'document_retriever_tool', 'duckduckgo_search_tool', 'weather_info_tool', 'hub_stats_tool', 'guest_info_retriever_tool'].
Return a JSON list of all relevant tools, e.g., ["search_tool", "duckduckgo_search_tool"].
Rules:
- Include "search_tool" for web-based questions unless purely computational or file-based.
- Include "multi_hop_search_tool" for questions with >20 words or requiring multiple steps.
- Include "file_parser_tool" for 'data', 'table', 'excel', 'csv', 'txt', 'mp3', or file extensions.
- Include "image_parser_tool" for 'image', 'video', 'picture', or 'painting'.
- Include "calculator_tool" for 'calculate', 'math', 'sum', 'average', 'total', or numerical operations.
- Include "document_retriever_tool" for 'document', 'pdf', 'report', or 'paper'.
- Include "duckduckgo_search_tool" for 'search', 'wikipedia', 'online', or general knowledge.
- Include "weather_info_tool" for 'weather', 'temperature', or 'forecast'.
- Include "hub_stats_tool" for 'model', 'huggingface', or 'dataset'.
- Include "guest_info_retriever_tool" for 'guest', 'name', 'relation', or 'person'.
- Select multiple tools if the question spans multiple domains (e.g., web and file).
- Output ONLY valid JSON."""

# Keyword heuristics applied when only the default search tool was selected.
# Tuple order is the order in which tools get appended.
_KEYWORD_TOOL_RULES = (
    ("file_parser_tool", ("excel", "csv", "mp3", "data", "table", "xlsx")),
    ("image_parser_tool", ("image", "video", "picture", "painting")),
    ("calculator_tool", ("calculate", "math", "sum", "average", "total")),
    ("document_retriever_tool", ("document", "pdf", "report", "paper")),
    ("duckduckgo_search_tool", ("search", "wikipedia", "online")),
    ("weather_info_tool", ("weather", "temperature", "forecast")),
    ("hub_stats_tool", ("model", "huggingface", "dataset")),
    ("guest_info_retriever_tool", ("guest", "name", "relation", "person")),
)

# Tool implied by the extension of a file attached to the task.
_FILE_EXTENSION_TOOLS = {
    "txt": "file_parser_tool",
    "csv": "file_parser_tool",
    "xlsx": "file_parser_tool",
    "mp3": "file_parser_tool",
    "jpg": "image_parser_tool",
    "png": "image_parser_tool",
    "pdf": "document_retriever_tool",
}


def _generate_llm_text(messages: list[dict], max_tokens: int = 100, temperature: float = 0.5) -> str:
    # Run one chat completion on the backend chosen by initialize_llm() and return the reply text.
    if llm_type == "hf_local":
        model, tokenizer = llm_client
        # add_generation_prompt opens an assistant turn so the model answers
        # instead of continuing the user message.
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        # do_sample=True is required for `temperature` to have any effect.
        outputs = model.generate(
            input_ids, max_new_tokens=max_tokens, do_sample=True, temperature=temperature
        )
        # generate() returns prompt + completion; decode only the new tokens,
        # otherwise the echoed prompt makes the reply unparseable.
        return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()
    if llm_type == "together":
        completion = llm_client.chat.completions.create(
            model=llm_model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
    else:  # hf_api: the InferenceClient was created bound to its model
        completion = llm_client.chat.completions.create(
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
    return completion.choices[0].message.content.strip()


def _extract_tool_list(response: str) -> list[str] | None:
    # Parse a non-empty JSON list of registered tool names from an LLM reply; None if absent/invalid.
    # Reasoning models (the DeepSeek-R1 distill in TOGETHER_MODELS) may prepend <think>...</think>.
    text = response.split("</think>")[-1].strip()
    candidates = [text]
    start, end = text.find("["), text.rfind("]")
    if 0 <= start < end:
        # Tolerate prose or ``` fences around the JSON list.
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # An empty list is rejected: the prompt always expects at least one tool.
        if isinstance(parsed, list) and parsed and all(
            isinstance(name, str) and name in tools for name in parsed
        ):
            return parsed
    return None


def _select_tools_with_llm(task_id, question: str, attempts: int = 3) -> list[str]:
    # Ask the LLM for a tool list, retrying on errors/invalid output; fall back to ["search_tool"].
    messages = [
        {"role": "system", "content": _TOOL_SELECTION_PROMPT},
        {"role": "user", "content": f"Query: {question}"},
    ]
    for attempt in range(1, attempts + 1):
        try:
            response = _generate_llm_text(messages, max_tokens=100, temperature=0.5)
        except Exception as e:
            logger.warning(f"Task {task_id} Tool selection failed (attempt {attempt}): {e}")
            continue
        logger.info(f"Task {task_id} LLM tool selection response: {response}")
        selected = _extract_tool_list(response)
        if selected is not None:
            return selected
        logger.warning(f"Task {task_id}: Invalid tool list (attempt {attempt})")
    return ["search_tool"]


def _keyword_tool_selection(question: str) -> list[str]:
    # Extra tools suggested by keyword heuristics, in rule order (search_tool itself not included).
    question_lower = question.lower()
    selected = [
        tool_name
        for tool_name, keywords in _KEYWORD_TOOL_RULES
        if any(kw in question_lower for kw in keywords)
    ]
    if len(question.split()) > 20 or "multiple" in question_lower:
        selected.append("multi_hop_search_tool")
    return selected


async def _attach_task_files(state, task_id, question: str, tools_needed: list[str]) -> None:
    # Save task attachments under temp/, record the last one in metadata, and append the tool each extension needs.
    # Assumes fetch_task_file returns {extension: bytes-or-empty}; confirm against tools.file_fetcher.
    file_results = await fetch_task_file(task_id, question)
    for ext, content in file_results.items():
        if not content:
            continue
        os.makedirs("temp", exist_ok=True)
        file_path = f"temp/{task_id}.{ext}"
        with open(file_path, "wb") as f:
            f.write(content)
        state["metadata"] = state.get("metadata", {}) | {"file_ext": ext, "file_path": file_path}
        tool_name = _FILE_EXTENSION_TOOLS.get(ext)
        if tool_name and tool_name not in tools_needed:
            tools_needed.append(tool_name)


async def parse_question(state: JARVISState) -> JARVISState:
    """
    Parse the question to select appropriate tools using LLM with retries, preprocess the question, and integrate file-based tools.

    Steps:
      1. Validate the state and preprocess the question (``state["question"]``
         is replaced when preprocessing changes it).
      2. Ask the LLM for a JSON list of tools (up to 3 attempts).
      3. If only the default ``search_tool`` survived, add tools suggested by
         keyword heuristics.
      4. Download any task attachment to ``temp/`` and add the tool its file
         extension requires.

    Args:
        state (JARVISState): The input state containing task_id, question.
    Returns:
        JARVISState: Updated state with ``tools_needed`` (deduplicated, in
        selection order) and, when an attachment exists, ``metadata["file_ext"]``
        and ``metadata["file_path"]``. On an unexpected failure ``error`` is set
        and ``tools_needed`` is ``["search_tool"]``.
    """
    state = validate_state(state)
    task_id = state["task_id"]
    question = state["question"]
    logger.info(f"Task {task_id} Parsing question: {question}")
    try:
        processed_question = await preprocess_question(question)
        if processed_question != question:
            logger.info(f"Task {task_id} Preprocessed question: {processed_question}")
            state["question"] = processed_question
            question = processed_question

        tools_needed = ["search_tool"]
        if llm_client:
            tools_needed = _select_tools_with_llm(task_id, question)

        # Only the default is left (chosen by the LLM or after every attempt failed):
        # let keyword heuristics add specific tools. The old extra condition excluded
        # exactly the questions containing those keywords, making the rules dead.
        if tools_needed == ["search_tool"]:
            tools_needed.extend(_keyword_tool_selection(question))

        await _attach_task_files(state, task_id, question, tools_needed)

        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (list(set(...)) reordered tools between runs).
        state["tools_needed"] = list(dict.fromkeys(tools_needed))
        logger.info(f"Task {task_id} Selected tools: {state['tools_needed']}")
        return state
    except Exception as e:
        logger.error(f"Task {task_id} Tool selection failed: {e}")
        state["error"] = f"Parse question failed: {str(e)}"
        state["tools_needed"] = ["search_tool"]
        return state
# Tool dispatcher

# Characters stripped from words pulled out of a question (e.g. "Paris?" -> "Paris").
_QUESTION_PUNCTUATION = "?.,!;:'\""


def _extract_weather_location(question: str) -> str:
    # Heuristic location: the question's last word for weather questions, else "Unknown".
    if "weather" not in question.lower():
        return "Unknown"
    words = question.split()
    location = words[-1].strip(_QUESTION_PUNCTUATION) if words else ""
    return location or "Unknown"


def _extract_hub_author(question: str) -> str:
    # The word after a standalone "by" (any case), else "Unknown". The old split on a
    # case-sensitive "by " raised IndexError for "By ..." or words like "nearby".
    match = re.search(r"\bby\s+(\S+)", question, flags=re.IGNORECASE)
    author = match.group(1).strip(_QUESTION_PUNCTUATION) if match else ""
    return author or "Unknown"


async def _run_tool(state, tool_name: str, task_id: str, question: str) -> None:
    # Invoke one tool and store its output in the state field it feeds (unknown names do nothing).
    if tool_name == "search_tool":
        result = await tools["search_tool"].ainvoke({"query": question})
        state["web_results"].extend([str(r) for r in result] if result else ["No results from search_tool"])
    elif tool_name == "multi_hop_search_tool":
        result = await tools["multi_hop_search_tool"].ainvoke({
            "query": question,
            "steps": 3,
            "llm_client": llm_client,
            "llm_type": llm_type,
            "llm_model": llm_model,
        })
        state["multi_hop_results"].extend(
            [r["content"] if isinstance(r, dict) else str(r) for r in result]
            if result else ["No results from multi_hop_search_tool"]
        )
    elif tool_name == "file_parser_tool":
        file_path = state["metadata"].get("file_path")
        file_ext = state["metadata"].get("file_ext")
        if file_path and os.path.exists(file_path) and file_ext:
            result = await tools["file_parser_tool"].ainvoke({
                "task_id": task_id,
                "file_type": file_ext,
                "file_path": file_path,
                "query": question,
            })
            state["file_results"] = str(result) if result else "No file results"
        else:
            state["file_results"] = "No file available"
    elif tool_name == "image_parser_tool":
        file_path = state["metadata"].get("file_path")
        if file_path and os.path.exists(file_path) and file_path.split('.')[-1] in ["jpg", "png"]:
            result = await tools["image_parser_tool"].ainvoke({"task_id": task_id, "file_path": file_path})
            state["image_results"] = str(result) if result else "No image results"
        else:
            state["image_results"] = "No image available"
    elif tool_name == "calculator_tool":
        result = await tools["calculator_tool"].ainvoke({"expression": question})
        state["calculation_results"] = str(result) if result else "No calculation results"
    elif tool_name == "document_retriever_tool":
        file_path = state["metadata"].get("file_path")
        if file_path and os.path.exists(file_path) and file_path.split('.')[-1] == "pdf":
            result = await tools["document_retriever_tool"].ainvoke({
                "task_id": task_id,
                "query": question,
                "file_path": file_path,
            })
            state["document_results"] = str(result) if result else "No document results"
        else:
            state["document_results"] = "No document available"
    elif tool_name == "duckduckgo_search_tool":
        result = await tools["duckduckgo_search_tool"].ainvoke({
            "query": question,
            "original_query": question,
            "embedder": embedder,
        })
        if isinstance(result, list):
            state["web_results"].extend(result)
        else:
            state["web_results"].append(str(result) if result else "No results from duckduckgo_search_tool")
    elif tool_name == "weather_info_tool":
        result = await tools["weather_info_tool"].ainvoke({"location": _extract_weather_location(question)})
        state["web_results"].append(str(result) if result else "No weather results")
    elif tool_name == "hub_stats_tool":
        result = await tools["hub_stats_tool"].ainvoke({"author": _extract_hub_author(question)})
        state["web_results"].append(str(result) if result else "No hub stats results")
    elif tool_name == "guest_info_retriever_tool":
        result = await tools["guest_info_retriever_tool"].ainvoke({"query": question})
        state["web_results"].append(str(result) if result else "No guest info results")


async def tool_dispatcher(state: JARVISState) -> JARVISState:
    """Run every tool in ``state["tools_needed"]`` and generate the final answer.

    Tool outputs accumulate in the state: ``web_results`` / ``multi_hop_results``
    (lists) and ``file_results`` / ``image_results`` / ``document_results`` /
    ``calculation_results`` (strings). Each tool records
    ``metadata["<tool>_executed"]`` on success or ``metadata["<tool>_error"]`` on
    failure; one failing tool does not stop the others. The collected results
    are then passed to ``generate_answer``.

    Args:
        state (JARVISState): State produced by ``parse_question``.

    Returns:
        JARVISState: The state with ``answer`` set, or with ``error`` set if the
        dispatch as a whole failed.
    """
    state = validate_state(state)
    try:
        task_id = state["task_id"]
        question = state["question"]
        for tool_name in state["tools_needed"]:
            try:
                await _run_tool(state, tool_name, task_id, question)
                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_executed": True}
                logger.info(f"Task {task_id}: Executed {tool_name}")
            except Exception as e:
                logger.warning(f"Tool {tool_name} failed for task {task_id}: {e}")
                state["metadata"] = state.get("metadata", {}) | {f"{tool_name}_error": str(e)}

        # Placeholders only for result fields absent from the state (present-but-empty values are kept).
        for key, placeholder in (
            ("web_results", ["No web results found"]),
            ("file_results", "No file results found"),
            ("image_results", "No image results found"),
            ("document_results", "No document results found"),
            ("calculation_results", "No calculation results found"),
        ):
            state[key] = state.get(key, placeholder)

        search_results = state.get("web_results", []) + [
            r["content"] if isinstance(r, dict) else str(r)
            for r in state.get("multi_hop_results", [])
        ]
        # Newline-separate the non-web sections; plain "+" concatenation fused them
        # into one run-on string (e.g. "No file availableNo document available...").
        file_context = "\n".join(
            str(state[key])
            for key in ("file_results", "document_results", "image_results", "calculation_results")
            if state[key]
        )
        state["answer"] = await generate_answer(
            task_id=task_id,
            question=question,
            search_results=search_results,
            file_results=file_context,
            llm_client=llm_client,
        )
        logger.info(f"Task {task_id}: Generated answer: {state['answer']}")
        return state
    except Exception as e:
        logger.error(f"Tool dispatch failed: {e}")
        state["error"] = f"Tool dispatch failed: {e}"
        return state
# Define StateGraph
def _build_graph():
    """Wire the two-node pipeline: parse_question -> tool_dispatcher -> END."""
    builder = StateGraph(JARVISState)
    builder.add_node("parse_question", parse_question)
    builder.add_node("tool_dispatcher", tool_dispatcher)
    builder.set_entry_point("parse_question")
    builder.add_edge("parse_question", "tool_dispatcher")
    builder.add_edge("tool_dispatcher", END)
    return builder


workflow = _build_graph()
graph = workflow.compile()  # compiled runnable used by JARVISAgent.process_question
# Agent class
class JARVISAgent:
    """Runs the LangGraph pipeline for each GAIA question and submits the answers.

    ``self.state`` is a long-lived state dict (built with ``reset_state``) used
    for bookkeeping across runs: ``results_table`` accumulates one row per
    processed task, while ``status_output``, ``metadata`` and ``error`` hold the
    outcome of the latest operation.
    """

    # Column layout of every DataFrame returned to the Gradio results table.
    _RESULT_COLUMNS = ["Task ID", "Question", "Answer"]
    # Extensions of the temp/<task_id>.<ext> attachment files removed after each task.
    _TEMP_FILE_EXTENSIONS = ("txt", "csv", "xlsx", "mp3", "jpg", "png", "pdf")

    def __init__(self):
        self.state = reset_state(task_id="init", question="Agent initialized")
        self.state["results_table"] = []  # Initialize as empty list
        logger.info("JARVISAgent initialized.")

    def _results_frame(self) -> pd.DataFrame:
        """Return ``results_table`` as a DataFrame with fixed columns, even when empty.

        Previously only the final return path passed ``columns``; the early
        returns produced a column-less frame when no results existed yet.
        """
        return pd.DataFrame(self.state["results_table"], columns=self._RESULT_COLUMNS)

    def _remove_temp_files(self, task_id: str) -> None:
        """Delete any ``temp/<task_id>.<ext>`` attachment left behind by the pipeline."""
        for ext in self._TEMP_FILE_EXTENSIONS:
            file_path = f"temp/{task_id}.{ext}"
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                    logger.info(f"Removed temp file: {file_path}")
                except OSError as e:
                    logger.error(f"Error removing file {file_path}: {e}")

    async def process_question(self, task_id: str, question: str) -> str:
        """Run the graph for one question and return its answer (or ``"Error: ..."``).

        The outcome is appended to ``results_table``; temp attachment files for
        the task are removed afterwards whether or not processing succeeded.
        """
        state = reset_state(task_id=task_id, question=question)
        try:
            result = await graph.ainvoke(state)
            answer = result.get("answer", "Unknown")
            logger.info(f"Task {task_id} Final answer: {answer}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": answer})
            self.state["metadata"] = {"last_task_id": task_id, "answer": answer}
            return answer
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            self.state["results_table"].append({"Task ID": task_id, "Question": question, "Answer": f"Error: {e}"})
            self.state["error"] = f"Task {task_id} failed: {str(e)}"
            return f"Error: {str(e)}"
        finally:
            self._remove_temp_files(task_id)

    async def process_all_questions(self, profile: gr.OAuthProfile | None):
        """Fetch all GAIA questions, answer each one, and submit the answers.

        Args:
            profile: The logged-in Hugging Face user, or ``None`` when nobody is
                logged in (presumably supplied by Gradio based on this
                annotation — keep the annotation unchanged).

        Returns:
            tuple[pd.DataFrame, str]: The accumulated results table (columns
            "Task ID", "Question", "Answer") and a human-readable status message.
        """
        if not profile:
            logger.error("User not logged in.")
            self.state["status_output"] = "Please Login to Hugging Face."
            return self._results_frame(), self.state["status_output"]
        username = profile.username
        logger.info(f"User logged in: {username}")
        questions_url = f"{GAIA_API_URL}/questions"
        submit_url = f"{GAIA_API_URL}/submit"
        agent_code = f"https://huggingface.co/spaces/{SPACE_ID}/tree/main"

        try:
            async with await create_http_session() as session:
                async with session.get(questions_url) as response:
                    response.raise_for_status()
                    questions = await response.json()
            # Any other JSON shape would make the loop below iterate keys or characters.
            if not isinstance(questions, list):
                raise ValueError(f"expected a JSON list of questions, got {type(questions).__name__}")
            logger.info(f"Fetched {len(questions)} questions.")
        except Exception as e:
            logger.error(f"Error fetching questions: {e}")
            self.state["status_output"] = f"Error fetching questions: {e}"
            self.state["error"] = f"Fetch questions failed: {str(e)}"
            return self._results_frame(), self.state["status_output"]

        answers_payload = []
        for item in questions:
            is_record = isinstance(item, dict)
            task_id = item.get("task_id") if is_record else None
            question = item.get("question") if is_record else None
            if not task_id or not question:
                logger.warning(f"Skipping invalid item: {item}")
                continue
            answer = await self.process_question(task_id, question)
            answers_payload.append({"task_id": task_id, "submitted_answer": answer})

        if not answers_payload:
            logger.error("No answers generated.")
            self.state["status_output"] = "No answers to submit."
            self.state["error"] = "No answers generated"
            return self._results_frame(), self.state["status_output"]

        submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
        try:
            async with await create_http_session() as session:
                async with session.post(submit_url, json=submission_data) as response:
                    response.raise_for_status()
                    result_data = await response.json()
                    self.state["status_output"] = (
                        f"Submission Successful!\n"
                        f"User: {result_data.get('username')}\n"
                        f"Overall Score: {result_data.get('score', 'N/A')}% "
                        f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
                        f"Message: {result_data.get('message', 'No message received.')}"
                    )
                    self.state["metadata"] = self.state.get("metadata", {}) | {"submission_score": result_data.get('score', 'N/A')}
        except Exception as e:
            logger.error(f"Submission failed: {e}")
            self.state["status_output"] = f"Submission Failed: {e}"
            self.state["error"] = f"Submission failed: {str(e)}"
        return self._results_frame(), self.state["status_output"]
# Gradio interface
# Instruction panel text; equal to the original triple-quoted literal, including
# its leading and trailing newline.
_INSTRUCTIONS_MARKDOWN = (
    "\n"
    "**Instructions:**\n"
    "1. Log in to Hugging Face using the button below.\n"
    "2. Click 'Run Evaluation & Submit All Answers' to process GAIA questions and submit.\n"
    "---\n"
    "**Disclaimers:**\n"
    "Uses Hugging Face Inference, Together AI, SERPAPI, and OpenWeatherMap for GAIA benchmark.\n"
)

with gr.Blocks() as demo:
    gr.Markdown("# JARVIS GAIA Agent")
    gr.Markdown(_INSTRUCTIONS_MARKDOWN)
    with gr.Row():
        gr.LoginButton(value="Login to Hugging Face")
        run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Answers", wrap=True, headers=["Task ID", "Question", "Answer"])
    agent = JARVISAgent()
    # No explicit inputs: the handler's OAuthProfile-annotated parameter is filled by Gradio.
    run_button.click(fn=agent.process_all_questions, outputs=[results_table, status_output])

if __name__ == "__main__":
    logger.info(f"\n{'-' * 30} App Starting {'-' * 30}")
    logger.info(f"SPACE_ID: {SPACE_ID}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)