from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from state import JARVISState
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from tools import (
    search_tool,
    multi_hop_search_tool,
    file_parser_tool,
    image_parser_tool,
    calculator_tool,
    document_retriever_tool,
)
# NOTE(review): recent langfuse 2.x releases export this handler as
# `CallbackHandler` (`from langfuse.callback import CallbackHandler`);
# confirm `LangfuseCallbackHandler` exists in the pinned langfuse version.
from langfuse.callback import LangfuseCallbackHandler
import json
import os
from dotenv import load_dotenv

# Load environment variables from .env before any key is read.
load_dotenv()

# Debug: verify the required environment variables were picked up.
print(f"OPENAI_API_KEY loaded in graph.py: {'set' if os.getenv('OPENAI_API_KEY') else 'not set'}")
print(f"LANGFUSE_PUBLIC_KEY loaded in graph.py: {'set' if os.getenv('LANGFUSE_PUBLIC_KEY') else 'not set'}")

# --- LLM, tracing, and checkpointing setup ---------------------------------
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("OPENAI_API_KEY environment variable not set")

llm = ChatOpenAI(model="gpt-4o", api_key=api_key)

langfuse = LangfuseCallbackHandler(
    public_key=os.getenv("LANGFUSE_PUBLIC_KEY"),
    secret_key=os.getenv("LANGFUSE_SECRET_KEY"),
    host=os.getenv("LANGFUSE_HOST"),
)

memory = MemorySaver()


def _parse_tool_list(raw: str) -> list:
    """Parse the LLM's tool-selection reply into a list of tool names.

    Chat models frequently wrap JSON in markdown code fences
    (```json ... ```); a bare ``json.loads(response.content)`` crashes on
    such replies. Strip one surrounding fence, then parse.

    Raises:
        json.JSONDecodeError: if the remaining text is still not valid JSON.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line (e.g. "```json") and the closing fence.
        text = text.split("\n", 1)[1] if "\n" in text else text[3:]
        text = text.rsplit("```", 1)[0].strip()
    return json.loads(text)


# Question Parser Node
async def parse_question(state: JARVISState) -> JARVISState:
    """Ask the LLM which tools the question needs; record the list in state."""
    question = state["question"]
    prompt = f"""Analyze this GAIA question: {question}
Determine which tools are needed (web_search, multi_hop_search, file_parser, image_parser, calculator, document_retriever).
Return a JSON list of tool names."""
    response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
    tools_needed = _parse_tool_list(response.content)
    return {"messages": state["messages"] + [response], "tools_needed": tools_needed}


# Web Search Agent Node
async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run single-hop and/or multi-hop web searches when requested."""
    results = []
    if "web_search" in state["tools_needed"]:
        result = await search_tool.arun(state["question"])
        results.append(result)
    if "multi_hop_search" in state["tools_needed"]:
        result = await multi_hop_search_tool.aparse(state["question"], steps=3)
        results.append(result)
    return {"web_results": results}


# File Parser Agent Node
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attached file when the file_parser tool was selected."""
    if "file_parser" in state["tools_needed"]:
        result = await file_parser_tool.aparse(state["task_id"])
        return {"file_results": result}
    return {"file_results": ""}


# Image Parser Agent Node
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Describe or match against the task image when requested.

    Uses the "match" task (query "fruits") when the question mentions fruits,
    otherwise a plain "describe".
    """
    if "image_parser" in state["tools_needed"]:
        task = "match" if "fruits" in state["question"].lower() else "describe"
        match_query = "fruits" if task == "match" else ""
        # NOTE(review): assumes the downloaded image is always a .jpg named
        # after the task id — confirm against the download step.
        result = await image_parser_tool.aparse(
            f"temp_{state['task_id']}.jpg", task=task, match_query=match_query
        )
        return {"image_results": result}
    return {"image_results": ""}


# Calculator Agent Node
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Have the LLM extract a math expression, then evaluate it."""
    if "calculator" in state["tools_needed"]:
        prompt = f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}"
        response = await llm.ainvoke(prompt, config={"callbacks": [langfuse]})
        expression = response.content
        result = await calculator_tool.aparse(expression)
        return {"calculation_results": result}
    return {"calculation_results": ""}


# Document Retriever Agent Node
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Retrieve from the task document, guessing the file type from keywords."""
    if "document_retriever" in state["tools_needed"]:
        # Heuristic file-type guess: "menu" -> txt, "report"/"document" -> pdf,
        # anything else -> csv.
        file_type = "txt" if "menu" in state["question"].lower() else "csv"
        if "report" in state["question"].lower() or "document" in state["question"].lower():
            file_type = "pdf"
        result = await document_retriever_tool.aparse(
            state["task_id"], state["question"], file_type=file_type
        )
        return {"document_results": result}
    return {"document_results": ""}


# Reasoning Agent Node
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Synthesize all tool outputs into a single exact-match GAIA answer."""
    prompt = f"""Question: {state['question']}
Web Results: {state['web_results']}
File Results: {state['file_results']}
Image Results: {state['image_results']}
Calculation Results: {state['calculation_results']}
Document Results: {state['document_results']}
Synthesize an exact-match answer for the GAIA benchmark. Output only the answer (e.g., '90', 'White;5876')."""
    response = await llm.ainvoke(
        [
            SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
            HumanMessage(content=prompt),
        ],
        config={"callbacks": [langfuse]},
    )
    return {"answer": response.content, "messages": state["messages"] + [response]}


# Conditional Edge Router
def router(state: JARVISState) -> str:
    """Route to the tool fan-out when any tool was selected, else to reasoning.

    Uses .get so a state missing "tools_needed" falls through to reasoning
    instead of raising KeyError.
    """
    if state.get("tools_needed"):
        return "tools"
    return "reasoning"


# --- Build Graph -----------------------------------------------------------
workflow = StateGraph(JARVISState)
workflow.add_node("parse", parse_question)
workflow.add_node("web_search", web_search_agent)
workflow.add_node("file_parser", file_parser_agent)
workflow.add_node("image_parser", image_parser_agent)
workflow.add_node("calculator", calculator_agent)
workflow.add_node("document_retriever", document_retriever_agent)
workflow.add_node("reasoning", reasoning_agent)
workflow.set_entry_point("parse")
# "tools" fans out to every tool node; each node no-ops (returns an empty
# result) when its tool is not in tools_needed.
# NOTE(review): mapping a branch label to a *list* of nodes must be supported
# by the pinned langgraph version — confirm add_conditional_edges accepts it.
workflow.add_conditional_edges(
    "parse",
    router,
    {
        "tools": ["web_search", "file_parser", "image_parser", "calculator", "document_retriever"],
        "reasoning": "reasoning",
    },
)
workflow.add_edge("web_search", "reasoning")
workflow.add_edge("file_parser", "reasoning")
workflow.add_edge("image_parser", "reasoning")
workflow.add_edge("calculator", "reasoning")
workflow.add_edge("document_retriever", "reasoning")
workflow.add_edge("reasoning", END)

graph = workflow.compile(checkpointer=memory)