Spaces:
Sleeping
Sleeping
File size: 5,842 Bytes
9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 9c9b3ff 66f6cc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import os
import itertools
from typing import TypedDict, Annotated, Literal
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_core.runnables import RunnableConfig # for LangSmith tracking
from langchain_community.tools import BraveSearch # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool # for logic/math problems
from tools import (
calculator_basic, datetime_tools, transcribe_audio,
transcribe_youtube, query_image, webpage_content, read_excel
)
from prompt import system_prompt
# --------------------------------------------------------------------
# 1. API Key Rotation Setup
# --------------------------------------------------------------------
# Candidate OpenRouter keys read from the environment. A variable that is
# not set shows up here as None and is left out of the rotation below.
api_keys = [
    os.getenv(env_name)
    for env_name in ("OPENROUTER_API_KEY", "OPENROUTER_API_KEY_1")
]
if all(not key for key in api_keys):
    raise EnvironmentError("No OpenRouter API keys found in environment variables.")
# Endless round-robin over the keys that are actually set (non-empty).
api_key_cycle = itertools.cycle(list(filter(None, api_keys)))


def get_next_api_key():
    """Advance the shared round-robin and return the OpenRouter key it lands on."""
    upcoming_key = next(api_key_cycle)
    return upcoming_key
class RotatingChatOpenAI(ChatOpenAI):
    """ChatOpenAI subclass that retries ``invoke`` with the next OpenRouter key.

    Every ``invoke`` takes the next key from the shared round-robin
    (``get_next_api_key``). If the request fails with an auth or rate-limit
    error (HTTP 401, 403 or 429), the call is retried with the following key.
    Each configured key is tried at most once per call. Any other exception
    propagates unchanged.

    Raises:
        RuntimeError: every configured key failed with an auth or
            rate-limit error. The last such error is chained as the cause.

    NOTE(review): only ``invoke`` rotates keys. ``ainvoke``/``stream``/``batch``
    reuse whichever key the clients were last switched to.
    """

    def invoke(self, *args, **kwargs):
        # Only keys that are actually set take part in the rotation.
        # ``api_keys`` can contain None placeholders for unset variables, and
        # counting those would retry the same key more than once.
        attempts = sum(1 for key in api_keys if key)
        last_error = None
        for _ in range(attempts):
            key = get_next_api_key()
            self._use_api_key(key)
            try:
                return super().invoke(*args, **kwargs)
            except Exception as exc:
                if not self._is_key_error(exc):
                    raise  # Re-raise other unexpected errors
                print(f"[API Key Rotation] Key {key[:5]}... failed, trying next key...")
                last_error = exc
        raise RuntimeError("All OpenRouter API keys failed or rate-limited.") from last_error

    def _use_api_key(self, key):
        """Point this model's OpenAI clients at *key*."""
        # Assigning ``self.api_key`` cannot work. ``api_key`` is only the
        # constructor alias of the ``openai_api_key`` field, so pydantic
        # rejects the assignment. In addition, langchain_openai builds its
        # OpenAI clients once, at construction time. Instead, copy each
        # client with the new key and swap it in.
        root = getattr(self, "root_client", None)
        if root is not None:
            self.root_client = root.with_options(api_key=key)
            self.client = self.root_client.chat.completions
        async_root = getattr(self, "root_async_client", None)
        if async_root is not None:
            self.root_async_client = async_root.with_options(api_key=key)
            self.async_client = self.root_async_client.chat.completions

    @staticmethod
    def _is_key_error(exc):
        """Return True if *exc* is an auth or rate-limit failure (HTTP 401/403/429)."""
        # OpenAI SDK HTTP errors carry a structured ``status_code``. Prefer it
        # over searching the message text, where e.g. a token count could
        # contain "429".
        status = getattr(exc, "status_code", None)
        if status is not None:
            return status in (401, 403, 429)
        # Fallback for errors that carry no status code.
        return any(code in str(exc) for code in ("429", "401", "403"))
# --------------------------------------------------------------------
# 2. Initialize LLM with API Key Rotation
# --------------------------------------------------------------------
# The rotating ChatOpenAI wrapper is pointed at OpenRouter's base URL. The
# model has to support function calling because tools are bound to it.
llm = RotatingChatOpenAI(
    model="qwen/qwen3-coder:free",
    temperature=1,
    base_url="https://openrouter.ai/api/v1",
    api_key=get_next_api_key(),  # seed key: consumes the first entry of the rotation
)
# --------------------------------------------------------------------
# 3. Tools Setup
# --------------------------------------------------------------------
# Off-the-shelf LangChain tools: a Python AST REPL for logic/math and
# Brave web search (4 results per query).
python_tool = PythonAstREPLTool()
search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},
    description="Web search using Brave",
)
community_tools = [search_tool, python_tool]

# Project tools from tools.py. calculator_basic and datetime_tools are
# lists of tools; the rest are single tools.
custom_tools = [
    *calculator_basic,
    *datetime_tools,
    transcribe_audio,
    transcribe_youtube,
    query_image,
    webpage_content,
    read_excel,
]

tools = [*community_tools, *custom_tools]
# Expose every tool schema to the model; tool_node resolves calls by name.
llm_with_tools = llm.bind_tools(tools)
tools_by_name = {registered.name: registered for registered in tools}
# --------------------------------------------------------------------
# 4. Define LangGraph State and Nodes
# --------------------------------------------------------------------
class MessagesState(TypedDict):
    """Graph state: the conversation accumulated over the run.

    ``messages`` is annotated with the ``add_messages`` reducer, so the
    message lists returned by nodes are merged into the existing history
    rather than replacing it.
    """

    # NOTE(review): this definition shadows the MessagesState imported from
    # langgraph.graph at the top of the file.
    messages: Annotated[list[AnyMessage], add_messages]
# LLM Node
def llm_call(state: MessagesState):
    """Ask the tool-aware model for its next message.

    The system prompt is prepended on every call and is never stored in the
    state. The model's reply is returned as a one-element ``messages`` update.
    """
    conversation = [SystemMessage(content=system_prompt), *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}
# Tool Node
def tool_node(state: MessagesState):
    """Execute the tools requested by the LLM's latest message.

    Produces one ``ToolMessage`` per entry in the last message's
    ``tool_calls``, in order, each tied to its call via ``tool_call_id``.

    An unknown tool name, or an exception raised while a tool runs, does not
    abort the graph. It is reported back to the model as a ``status="error"``
    ToolMessage, so the model can recover and every tool call still gets a
    reply.

    Returns:
        ``{"messages": [...]}`` state update holding the ToolMessages.
    """
    result = []
    for tool_call in state["messages"][-1].tool_calls:
        name = tool_call["name"]
        call_id = tool_call["id"]
        tool = tools_by_name.get(name)
        if tool is None:
            # The model asked for a tool that was never registered.
            available = ", ".join(sorted(tools_by_name))
            result.append(ToolMessage(
                content=f"Error: unknown tool '{name}'. Available tools: {available}",
                tool_call_id=call_id,
                status="error",
            ))
            continue
        try:
            observation = tool.invoke(tool_call["args"])
        except Exception as exc:
            # Tool boundary: one failing tool (network, bad args, ...) should
            # not kill the whole agent run. Log it and let the LLM see the error.
            print(f"[Tool Error] {name} failed: {exc!r}")
            result.append(ToolMessage(
                content=f"Error: tool '{name}' raised {type(exc).__name__}: {exc}",
                tool_call_id=call_id,
                status="error",
            ))
            continue
        result.append(ToolMessage(content=observation, tool_call_id=call_id))
    return {"messages": result}
# Conditional Routing
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Choose the branch to take after ``llm_call``.

    Returns ``"Action"`` when the newest message requests tool calls, and
    ``END`` otherwise, which finishes the run.
    """
    if state["messages"][-1].tool_calls:
        return "Action"
    return END
# --------------------------------------------------------------------
# 5. Build LangGraph Agent
# --------------------------------------------------------------------
# Loop: START -> llm_call -> ("Action") environment -> llm_call -> ...
# until should_continue returns END.
builder = StateGraph(MessagesState)

builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

builder.add_edge(START, "llm_call")
builder.add_edge("environment", "llm_call")
# Translate should_continue's labels into destination nodes.
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment", END: END},
)

gaia_agent = builder.compile()
# --------------------------------------------------------------------
# 6. Agent Wrapper
# --------------------------------------------------------------------
class LangGraphAgent:
    """Callable wrapper: run ``gaia_agent`` on one question and return its answer text."""

    def __init__(self):
        print("LangGraphAgent initialized with API key rotation.")

    def __call__(self, question: str) -> str:
        """Run the agent graph on *question* and return the final answer.

        The answer is the text after the last "FINAL ANSWER:" marker in the
        graph's final message. If there is no marker, the whole message text
        is returned (stripped).
        """
        input_state = {"messages": [HumanMessage(content=question)]}
        print(f"Running LangGraphAgent with input: {question[:150]}...")
        # RunnableConfig is a TypedDict, so its keys go in as keyword
        # arguments. Nesting them under ``config=`` yields {"config": {...}},
        # and run_name/tags/metadata/recursion_limit would then be ignored.
        # LangSmith tracing is not a RunnableConfig option; it is switched on
        # through environment variables.
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30,
        )
        result = gaia_agent.invoke(input_state, config)
        return self._extract_final_answer(result["messages"][-1].content)

    @staticmethod
    def _extract_final_answer(content) -> str:
        """Return the text after the last "FINAL ANSWER:" marker in *content*.

        *content* is a message ``content`` value. It is usually a string, but
        it may be a list of content blocks (plain strings, or dicts carrying
        a ``"text"`` entry). Text parts are joined with newlines before
        searching. Anything else is converted with ``str``. When the marker
        is missing, the whole text is returned stripped.
        """
        if isinstance(content, str):
            text = content
        elif isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            text = "\n".join(parts)
        else:
            text = str(content)
        # rpartition keeps everything after the LAST marker. If the marker is
        # absent, it returns ("", "", text).
        _, marker, answer = text.rpartition("FINAL ANSWER:")
        if not marker:
            print("Could not split on 'FINAL ANSWER:'")
        return answer.strip()
|