# Hugging Face file-viewer chrome, commented out so this module parses:
#   Ahmud's picture | Update agent.py | aac43d6 verified | raw | history blame | 6.41 kB
import os
import itertools
from typing import TypedDict, Annotated, Literal
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_core.runnables import RunnableConfig # for LangSmith tracking
from langchain_community.tools import BraveSearch # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool # for logic/math problems
from tools import (
calculator_basic, datetime_tools, transcribe_audio,
transcribe_youtube, query_image, webpage_content, read_excel
)
from prompt import system_prompt
# --------------------------------------------------------------------
# 1. API Key Rotation Setup
# --------------------------------------------------------------------
# Collect whichever OpenRouter keys are present in the environment.
api_keys = []
for _env_name in ("OPENROUTER_API_KEY", "OPENROUTER_API_KEY_1"):
    _key = os.getenv(_env_name)
    if _key:
        api_keys.append(_key)
if not api_keys:
    raise EnvironmentError("No OpenRouter API keys found in environment variables.")
# Round-robin iterator over the available keys; shared by get_next_api_key().
api_key_cycle = itertools.cycle(api_keys)
def get_next_api_key() -> str:
    """Return the next OpenRouter key from the round-robin cycle."""
    return next(api_key_cycle)
class RotatingChatOpenAI(ChatOpenAI):
    """ChatOpenAI wrapper that automatically rotates OpenRouter API keys on failure."""
    def invoke(self, *args, **kwargs):
        """Attempt the chat call with each configured key in turn.

        On an error whose text contains 429/401/403 (rate limit or auth),
        rotate to the next key and retry; any other exception propagates
        immediately.

        Raises:
            RuntimeError: when every configured key failed with 429/401/403.
        """
        for _ in range(len(api_keys)):  # try each key once per call
            current_key = get_next_api_key()
            # NOTE(review): assigning openai_api_key on an already-constructed
            # ChatOpenAI may not rebuild the cached HTTP client (depends on the
            # langchain-openai version) — confirm the new key is actually used
            # by the underlying client on retry.
            self.openai_api_key = current_key
            try:
                return super().invoke(*args, **kwargs)
            except Exception as e:
                # Rotate only on rate-limit/auth failures, detected by
                # string-matching the status code in the exception text.
                if any(code in str(e) for code in ["429", "401", "403"]):
                    print(f"[API Key Rotation] Key {current_key[:5]}... failed, trying next key...")
                    continue
                raise  # Re-raise unexpected errors
        raise RuntimeError("All OpenRouter API keys failed or were rate-limited.")
# --------------------------------------------------------------------
# 2. Initialize LLM with API Key Rotation
# --------------------------------------------------------------------
# Chat model routed through OpenRouter; RotatingChatOpenAI retries with the
# next key when a call fails with a rate-limit/auth error.
llm = RotatingChatOpenAI(
    base_url="https://openrouter.ai/api/v1",
    openai_api_key=get_next_api_key(),  # start with the first key in rotation
    model="qwen/qwen3-coder:free",  # must support tool/function calling
    temperature=1
)
# --------------------------------------------------------------------
# 3. Tools Setup
# --------------------------------------------------------------------
# Python REPL tool for logic/math problems.
python_tool = PythonAstREPLTool()
search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},  # top-4 results per query
    description="Web search using Brave"
)
community_tools = [search_tool, python_tool]
# Project-local tools from tools.py (calculators, transcription, image, etc.).
custom_tools = calculator_basic + datetime_tools + [
    transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel
]
tools = community_tools + custom_tools
# Bind the tool schemas to the LLM so it can emit tool calls.
llm_with_tools = llm.bind_tools(tools)
# Name -> tool lookup used by tool_node to dispatch requested calls.
tools_by_name = {tool.name: tool for tool in tools}
# --------------------------------------------------------------------
# 4. Define LangGraph State and Nodes
# --------------------------------------------------------------------
class MessagesState(TypedDict):
    """Graph state: a message list merged via the add_messages reducer.

    NOTE(review): this shadows the MessagesState imported from
    langgraph.graph above — functionally equivalent, but one of the two
    is redundant.
    """
    messages: Annotated[list[AnyMessage], add_messages]
# LLM Node
def llm_call(state: MessagesState):
    """Run the tool-bound LLM on the system prompt plus conversation history."""
    prompt_messages = [SystemMessage(content=system_prompt), *state["messages"]]
    response = llm_with_tools.invoke(prompt_messages)
    return {"messages": [response]}
# Tool Node
def tool_node(state: MessagesState):
    """Execute every tool call requested in the last LLM message.

    Returns a dict with one ToolMessage per requested call (result, error
    text, or an unknown-tool notice), keyed by the call's tool_call_id so
    the LLM can match results to requests.
    """
    results = []
    last_message = state["messages"][-1]
    for tool_call in getattr(last_message, "tool_calls", []) or []:
        tool = tools_by_name.get(tool_call["name"])
        if not tool:
            results.append(ToolMessage(content=f"Unknown tool: {tool_call['name']}", tool_call_id=tool_call["id"]))
            continue
        try:
            # BUGFIX: BaseTool.invoke takes the argument mapping as a single
            # positional `input`; the previous `tool.invoke(**args)` unpacked
            # it as keyword arguments, raising TypeError on every dict-args
            # call (masked below as a "[Tool Error]").
            observation = tool.invoke(tool_call["args"])
        except Exception as e:
            observation = f"[Tool Error] {str(e)}"
        # Coerce to str: tools may return non-string objects, but
        # ToolMessage content should be text.
        results.append(ToolMessage(content=str(observation), tool_call_id=tool_call["id"]))
    return {"messages": results}
# Conditional Routing
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Route to the tool node when the LLM requested tools; otherwise end."""
    newest = state["messages"][-1]
    if getattr(newest, "tool_calls", None):
        return "Action"
    return END
# --------------------------------------------------------------------
# 5. Build LangGraph Agent
# --------------------------------------------------------------------
# Agent loop: LLM -> (tools -> LLM)* -> END.
builder = StateGraph(MessagesState)
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)  # tool-execution node
builder.add_edge(START, "llm_call")
# After each LLM turn: run tools if it requested any, otherwise finish.
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment", END: END}
)
builder.add_edge("environment", "llm_call")  # feed tool results back to the LLM
gaia_agent = builder.compile()
# --------------------------------------------------------------------
# 6. Agent Wrapper
# --------------------------------------------------------------------
class LangGraphAgent:
    """Callable wrapper around the compiled LangGraph agent."""

    def __init__(self):
        print("LangGraphAgent initialized with API key rotation.")

    def __call__(self, question: str) -> str:
        """Run the agent on `question` and return the final answer text.

        If the LLM's last message contains "FINAL ANSWER:", only the text
        after that marker is returned; otherwise the whole content is
        stringified and returned.
        """
        input_state = {"messages": [HumanMessage(content=question)]}
        print(f"Running LangGraphAgent with input: {question[:150]}...")
        # BUGFIX: RunnableConfig is a flat mapping. The previous code nested
        # everything under a "config" key, so run_name, tags, metadata, and
        # the 30-step recursion_limit were all silently ignored. ("tracing"
        # is not a RunnableConfig field; its intent is kept in metadata.)
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question, "tracing": True},
            recursion_limit=30,
        )
        result = gaia_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content
        # Extract the answer after the "FINAL ANSWER:" marker if present.
        if isinstance(final_response, str):
            return final_response.split("FINAL ANSWER:")[-1].strip()
        return str(final_response)