"""State schema and helpers for the JARVIS GAIA agent, used with LangGraph to manage task processing."""
import logging
from typing import Any, Dict, List, Optional, TypedDict, Union

from langchain_core.messages import BaseMessage

# Module-level logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(__name__)


class JARVISState(TypedDict):
    """
    State dictionary for the JARVIS GAIA Agent, used with LangGraph to manage task processing.

    This TypedDict uses the default ``total=True``, so every key below is
    declared as required. ``validate_state`` and ``reset_state`` fill in any
    field other than ``task_id`` and ``question`` with a default value.

    Attributes:
        task_id: Unique identifier for the GAIA task.
        question: The question text to be answered.
        tools_needed: Names of the tools to be used for the task.
        web_results: Web search results (e.g., from SERPAPI, DuckDuckGo).
        file_results: Parsed content from text, CSV, Excel, or audio files.
        image_results: OCR or description results from image files.
        calculation_results: Results from mathematical calculations.
        document_results: Extracted content from PDF or text documents.
        multi_hop_results: Results from iterative multi-hop searches; each
            entry is either a string or a dict.
        messages: Messages for LLM context (e.g., user prompts, system instructions).
        answer: Final answer for the task, formatted for GAIA submission.
        results_table: Task result rows for Gradio display (Task ID, Question, Answer).
        status_output: Status message for Gradio output (e.g., submission result).
        error: Error message if task processing fails, otherwise None.
        metadata: Optional metadata (e.g., timestamps, tool execution status).
    """

    # Task identity (must be non-empty; see validate_state).
    task_id: str
    question: str

    # Tool selection and per-tool intermediate results.
    tools_needed: List[str]
    web_results: List[str]
    file_results: str
    image_results: str
    calculation_results: str
    document_results: str
    multi_hop_results: List[Union[str, Dict[str, Any]]]

    # LLM conversation context.
    messages: List[BaseMessage]

    # Final outputs and UI-facing fields.
    answer: str
    results_table: List[Dict[str, str]]
    status_output: str

    # Error reporting and free-form bookkeeping.
    error: Optional[str]
    metadata: Optional[Dict[str, Any]]


def validate_state(state: JARVISState) -> JARVISState:
    """
    Validate a JARVISState and fill in defaults for missing fields.

    The state is modified in place. Any field that is absent or ``None`` is
    set to a freshly created default value. Fields that are present but empty
    (e.g. an explicit ``[]`` for ``tools_needed``) are left untouched.

    Args:
        state: Input state dictionary. Must contain a non-empty ``task_id``
            and a non-empty ``question``.

    Returns:
        The same ``state`` object, with every JARVISState field populated.

    Raises:
        ValueError: If ``task_id`` or ``question`` is missing, None, or empty.
    """
    # Required identity fields: reject missing, None, or empty values.
    for required_key in ("task_id", "question"):
        if not state.get(required_key):
            message = f"{required_key} is required"
            # Log exactly once at the point of failure, then let callers handle it.
            logger.error("State validation failed: %s", message)
            raise ValueError(message)

    # Rebuilt on every call so the mutable defaults ([] / {}) are never
    # shared between different states.
    defaults: Dict[str, Any] = {
        "tools_needed": ["search_tool"],
        "web_results": [],
        "file_results": "",
        "image_results": "",
        "calculation_results": "",
        "document_results": "",
        "multi_hop_results": [],
        "messages": [],
        "answer": "",
        "results_table": [],
        "status_output": "",
        "error": None,
        "metadata": {},
    }
    for key, default in defaults.items():
        # Only absent or explicitly-None fields are filled; falsy-but-present
        # values are respected.
        if state.get(key) is None:
            state[key] = default

    logger.debug("Validated state for task %s", state["task_id"])
    return state


def reset_state(task_id: str, question: str) -> JARVISState:
    """
    Create a fresh JARVISState for a new task.

    Every field other than ``task_id`` and ``question`` starts at the same
    default that ``validate_state`` applies. New list/dict objects are created
    on each call, so no mutable state is shared between tasks. All keys are
    spelled out explicitly so the result satisfies the (total) TypedDict.

    Args:
        task_id: Task identifier; must be non-empty.
        question: Question text; must be non-empty.

    Returns:
        A fully initialised, validated JARVISState.

    Raises:
        ValueError: If ``task_id`` or ``question`` is empty (raised by
            ``validate_state``).
    """
    state = JARVISState(
        task_id=task_id,
        question=question,
        tools_needed=["search_tool"],
        web_results=[],
        file_results="",
        image_results="",
        calculation_results="",
        document_results="",
        multi_hop_results=[],
        messages=[],
        answer="",
        results_table=[],
        status_output="",
        error=None,
        metadata={},
    )
    return validate_state(state)