import os
import itertools
from typing import TypedDict, Annotated, Literal

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_core.runnables import RunnableConfig  # for LangSmith tracking
from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems

from tools import (
    calculator_basic,
    datetime_tools,
    transcribe_audio,
    transcribe_youtube,
    query_image,
    webpage_content,
    read_excel,
)
from prompt import system_prompt

# --------------------------------------------------------------------
# 1. API Key Rotation Setup
# --------------------------------------------------------------------
# Collect whichever OpenRouter keys are configured; unset/empty ones are skipped.
api_keys = [k for k in [
    os.getenv("OPENROUTER_API_KEY"),
    os.getenv("OPENROUTER_API_KEY_1"),
] if k]

if not api_keys:
    raise EnvironmentError("No OpenRouter API keys found in environment variables.")

# Endless round-robin iterator over the configured keys.
api_key_cycle = itertools.cycle(api_keys)

# HTTP status codes treated as "this key is unusable right now, try another":
# 429 (rate-limited), 401 (unauthorized), 403 (forbidden).
_ROTATABLE_STATUS_CODES = (429, 401, 403)

# Keyword names under which the API key may be passed to the ChatOpenAI constructor.
_API_KEY_PARAMS = ("openai_api_key", "api_key")


def get_next_api_key() -> str:
    """Return the next API key in round-robin order."""
    return next(api_key_cycle)


def _is_rotatable_error(exc: Exception) -> bool:
    """Return True if *exc* looks like a rate-limit or auth failure for the current key.

    Prefers a numeric ``status_code`` attribute on the exception when present.
    Otherwise it falls back to searching the message text for the status codes.
    """
    status = getattr(exc, "status_code", None)
    if isinstance(status, int):
        return status in _ROTATABLE_STATUS_CODES
    message = str(exc)
    return any(str(code) in message for code in _ROTATABLE_STATUS_CODES)


class RotatingChatOpenAI(ChatOpenAI):
    """ChatOpenAI wrapper that rotates OpenRouter API keys on rate-limit/auth failure.

    ``ChatOpenAI`` creates its HTTP client from the key it is constructed with.
    Re-assigning the key attribute afterwards does not rebuild that client.
    This class therefore remembers its constructor keyword arguments and
    lazily builds, then caches, one plain ``ChatOpenAI`` per key. Each
    delegate shares every other setting (model, base_url, temperature, ...).

    Only ``invoke`` rotates keys; other entry points are inherited unchanged.
    """

    # Pydantic private attributes (underscore-prefixed, so not model fields).
    _init_kwargs: dict = {}
    _delegates: dict = {}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Kept so an identically configured client can be built for each key.
        self._init_kwargs = dict(kwargs)
        self._delegates = {}

    def _delegate_for(self, key: str) -> ChatOpenAI:
        """Return (building on first use) a ChatOpenAI configured like self but using *key*."""
        delegate = self._delegates.get(key)
        if delegate is None:
            params = {
                name: value
                for name, value in self._init_kwargs.items()
                if name not in _API_KEY_PARAMS
            }
            # Plain ChatOpenAI (not type(self)) so the delegate does not rotate recursively.
            delegate = ChatOpenAI(**params, api_key=key)
            self._delegates[key] = delegate
        return delegate

    def invoke(self, *args, **kwargs):
        """Invoke the model, trying each configured key at most once.

        Arguments (including tool bindings forwarded by ``bind_tools``) are
        passed through unchanged to the per-key delegate.

        Raises:
            RuntimeError: every key failed with a rate-limit/auth error.
            Exception: any other error is re-raised immediately.
        """
        for _ in range(len(api_keys)):  # try each key once per call
            current_key = get_next_api_key()
            try:
                return self._delegate_for(current_key).invoke(*args, **kwargs)
            except Exception as e:
                if not _is_rotatable_error(e):
                    raise  # Re-raise unexpected errors
                print(f"[API Key Rotation] Key {current_key[:5]}... failed, trying next key...")
        raise RuntimeError("All OpenRouter API keys failed or were rate-limited.")
# --------------------------------------------------------------------
# 2. Initialize LLM with API Key Rotation
# --------------------------------------------------------------------
llm = RotatingChatOpenAI(
    base_url="https://openrouter.ai/api/v1",
    openai_api_key=get_next_api_key(),  # initial key from the rotation
    model="qwen/qwen3-coder:free",  # must support tool/function calling
    temperature=1,
)

# --------------------------------------------------------------------
# 3. Tools Setup
# --------------------------------------------------------------------
# NOTE(security): this is a Python REPL tool that executes model-generated
# code. Review the implications before exposing the agent to untrusted input.
python_tool = PythonAstREPLTool()

# Fail fast with a clear message instead of passing None as the Brave key.
brave_api_key = os.getenv("BRAVE_SEARCH_API")
if not brave_api_key:
    raise EnvironmentError("BRAVE_SEARCH_API environment variable is not set.")

search_tool = BraveSearch.from_api_key(
    api_key=brave_api_key,
    search_kwargs={"count": 4},
    description="Web search using Brave",
)

community_tools = [search_tool, python_tool]
# calculator_basic and datetime_tools are concatenated as lists of tools.
custom_tools = calculator_basic + datetime_tools + [
    transcribe_audio,
    transcribe_youtube,
    query_image,
    webpage_content,
    read_excel,
]
tools = community_tools + custom_tools

# The LLM sees the tool schemas; tool_node resolves calls by name via this map.
llm_with_tools = llm.bind_tools(tools)
tools_by_name = {tool.name: tool for tool in tools}
# --------------------------------------------------------------------
# 4. Define LangGraph State and Nodes
# --------------------------------------------------------------------
class MessagesState(TypedDict):
    """Graph state: the running conversation.

    NOTE: this redefines (and shadows) the ``MessagesState`` imported from
    ``langgraph.graph``. The ``add_messages`` reducer combines each node's
    returned messages with the existing list instead of replacing it.
    """

    messages: Annotated[list[AnyMessage], add_messages]


# LLM Node
def llm_call(state: MessagesState):
    """Run the tool-bound LLM on the conversation, prefixed by the system prompt."""
    prompt_messages = [SystemMessage(content=system_prompt)] + state["messages"]
    return {"messages": [llm_with_tools.invoke(prompt_messages)]}


# Tool Node
def tool_node(state: MessagesState):
    """Execute every tool call requested by the last message.

    Returns one ToolMessage per tool call, matched by ``tool_call_id``.
    Unknown tools and tool exceptions are reported back to the LLM as message
    text rather than raised, so the agent can recover.
    """
    results = []
    last_message = state["messages"][-1]
    for tool_call in getattr(last_message, "tool_calls", None) or []:
        name = tool_call["name"]
        tool = tools_by_name.get(name)
        if tool is None:
            results.append(ToolMessage(content=f"Unknown tool: {name}", tool_call_id=tool_call["id"]))
            continue
        try:
            # BaseTool.invoke takes the argument dict (or string) as its single
            # positional `input`; unpacking it with ** would raise TypeError.
            observation = tool.invoke(tool_call["args"])
        except Exception as e:
            observation = f"[Tool Error] {e}"
        results.append(ToolMessage(content=observation, tool_call_id=tool_call["id"], name=name))
    return {"messages": results}


# Conditional Routing
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Route to the tool node if the LLM requested tool calls, else end the run."""
    last_message = state["messages"][-1]
    return "Action" if getattr(last_message, "tool_calls", None) else END


# --------------------------------------------------------------------
# 5. Build LangGraph Agent
# --------------------------------------------------------------------
# Loop: llm_call -> (tool calls?) environment -> llm_call ... -> END
builder = StateGraph(MessagesState)
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)
builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment", END: END},
)
builder.add_edge("environment", "llm_call")
gaia_agent = builder.compile()
# --------------------------------------------------------------------
# 6. Agent Wrapper
# --------------------------------------------------------------------
class LangGraphAgent:
    """Callable wrapper that answers a single question with the compiled GAIA graph."""

    def __init__(self):
        print("LangGraphAgent initialized with API key rotation.")

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its answer text.

        If the final message is a string containing "FINAL ANSWER:", only the
        text after the last occurrence is returned (stripped). Non-string
        content is returned via ``str()``.
        """
        input_state = {"messages": [HumanMessage(content=question)]}
        print(f"Running LangGraphAgent with input: {question[:150]}...")

        # RunnableConfig is a TypedDict, so its keys must be passed at top
        # level. Nesting them under a "config" key means they are not
        # recognised, so recursion_limit/tags/metadata would silently not
        # apply. LangSmith tracing is switched on through LangSmith's
        # environment variables, not through a config key.
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30,
        )

        result = gaia_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content

        # Extract "FINAL ANSWER" if present
        if isinstance(final_response, str):
            return final_response.split("FINAL ANSWER:")[-1].strip()
        return str(final_response)