import os
import itertools
from typing import TypedDict, Annotated, Literal

from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langchain_core.runnables import RunnableConfig  # for LangSmith tracking
from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems

from tools import (
    calculator_basic, datetime_tools, transcribe_audio,
    transcribe_youtube, query_image, webpage_content, read_excel
)
from prompt import system_prompt
# --------------------------------------------------------------------
# 1. API Key Rotation Setup
# --------------------------------------------------------------------
api_keys = [
    os.getenv("OPENROUTER_API_KEY"),
    os.getenv("OPENROUTER_API_KEY_1"),
]
valid_api_keys = [k for k in api_keys if k]

if not valid_api_keys:
    raise EnvironmentError("No OpenRouter API keys found in environment variables.")

api_key_cycle = itertools.cycle(valid_api_keys)

def get_next_api_key():
    """Get the next API key in rotation."""
    return next(api_key_cycle)
class RotatingChatOpenAI(ChatOpenAI):
    """ChatOpenAI wrapper that rotates API keys on rate-limit/auth failures."""

    def invoke(self, *args, **kwargs):
        # Try each valid key at most once per call.
        for _ in range(len(valid_api_keys)):
            key = get_next_api_key()
            # `openai_api_key` is the field behind the `api_key` constructor alias.
            # Note: depending on the langchain-openai version, the HTTP client built at
            # construction may cache the original key; if rotation does not take effect,
            # rebuild the model instead of mutating this field.
            self.openai_api_key = key
            try:
                return super().invoke(*args, **kwargs)
            except Exception as e:
                # Rotate on rate-limit or auth errors; re-raise anything else.
                if any(code in str(e) for code in ["429", "401", "403"]):
                    print(f"[API Key Rotation] Key {key[:5]}... failed, trying next key...")
                    continue
                raise
        raise RuntimeError("All OpenRouter API keys failed or were rate-limited.")
# --------------------------------------------------------------------
# 2. Initialize LLM with API Key Rotation
# --------------------------------------------------------------------
llm = RotatingChatOpenAI(
    base_url="https://openrouter.ai/api/v1",
    api_key=get_next_api_key(),  # start with the first key
    model="qwen/qwen3-coder:free",  # model must support function calling
    temperature=1,
)
# --------------------------------------------------------------------
# 3. Tools Setup
# --------------------------------------------------------------------
python_tool = PythonAstREPLTool()
search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},
    description="Web search using Brave"
)

community_tools = [search_tool, python_tool]
custom_tools = calculator_basic + datetime_tools + [
    transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel
]
tools = community_tools + custom_tools

llm_with_tools = llm.bind_tools(tools)
tools_by_name = {tool.name: tool for tool in tools}
# --------------------------------------------------------------------
# 4. Define LangGraph State and Nodes
# --------------------------------------------------------------------
class MessagesState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

# LLM Node
def llm_call(state: MessagesState):
    """Call the tool-bound LLM, prepending the system prompt to the conversation."""
    return {
        "messages": [
            llm_with_tools.invoke(
                [SystemMessage(content=system_prompt)] + state["messages"]
            )
        ]
    }
# Tool Node
def tool_node(state: MessagesState):
    """Executes the tools requested by the LLM's most recent message."""
    result = []
    for tool_call in state["messages"][-1].tool_calls:
        tool = tools_by_name[tool_call["name"]]
        observation = tool.invoke(tool_call["args"])
        # ToolMessage content must be a string, so coerce the tool output.
        result.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))
    return {"messages": result}
# Conditional Routing
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Route to the tool node if the LLM made a tool call, else end."""
    last_message = state["messages"][-1]
    return "Action" if last_message.tool_calls else END
# --------------------------------------------------------------------
# 5. Build LangGraph Agent
# --------------------------------------------------------------------
builder = StateGraph(MessagesState)
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment", END: END}
)
builder.add_edge("environment", "llm_call")

gaia_agent = builder.compile()
# --------------------------------------------------------------------
# 6. Agent Wrapper
# --------------------------------------------------------------------
class LangGraphAgent:
    def __init__(self):
        print("LangGraphAgent initialized with API key rotation.")

    def __call__(self, question: str) -> str:
        input_state = {"messages": [HumanMessage(content=question)]}
        print(f"Running LangGraphAgent with input: {question[:150]}...")

        # RunnableConfig is a TypedDict; pass its keys directly. LangSmith tracing is
        # enabled via the LANGCHAIN_TRACING_V2 / LANGCHAIN_API_KEY environment
        # variables rather than through this config.
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30,
        )

        result = gaia_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content

        # system_prompt is expected to make the model end with "FINAL ANSWER: ...";
        # fall back to the full response if that marker is missing.
        if "FINAL ANSWER:" in final_response:
            return final_response.split("FINAL ANSWER:")[-1].strip()
        print("Could not split on 'FINAL ANSWER:'")
        return final_response
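
# --------------------------------------------------------------------
# 7. Local smoke test (illustrative)
# --------------------------------------------------------------------
# A minimal usage sketch, assuming OPENROUTER_API_KEY (and optionally
# OPENROUTER_API_KEY_1) plus BRAVE_SEARCH_API are set in the environment.
# The sample question below is hypothetical and only demonstrates the call shape.
if __name__ == "__main__":
    agent = LangGraphAgent()
    sample_question = "What is 27 * 14? Reply with FINAL ANSWER: <number>."
    print(agent(sample_question))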