# mchinea — "update answer and agent class" (1659627)
# raw | history | blame | 5.59 kB
"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph import START, StateGraph, MessagesState
from langchain_core.messages import SystemMessage, HumanMessage
from tools import level1_tools
# Load environment variables from a .env file at import time, so they are
# in place before build_agent_graph() constructs the ChatOpenAI client.
load_dotenv()
# Build graph function
def build_agent_graph(model: str = "gpt-4o-mini", tools=None):
    """Build and compile the tool-using agent graph.

    The graph has two nodes: ``assistant`` (the chat model with the tools
    bound to it) and ``tools`` (a ``ToolNode`` executing the model's tool
    calls). Execution starts at ``assistant``; the conditional edge uses
    langgraph's ``tools_condition`` to decide whether to run ``tools`` or
    finish, and every ``tools`` step loops back to ``assistant``.

    Args:
        model: Chat model name passed to ``ChatOpenAI``. Defaults to
            ``"gpt-4o-mini"``, the value previously hard-coded here.
        tools: Tools to bind to the model and serve from the ``ToolNode``.
            ``None`` (the default) uses ``level1_tools``.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.
    """
    # Default to the level-1 tool set so existing zero-argument callers are unchanged.
    if tools is None:
        tools = level1_tools

    llm = ChatOpenAI(model=model)
    # The same tool list is bound to the model and served by the ToolNode, so
    # every tool the model is told about is one the graph can execute.
    llm_with_tools = llm.bind_tools(tools)

    # System message
    system_prompt = SystemMessage(
        content="""You are a general AI assistant being evaluated in the GAIA Benchmark.
I will ask you a question and you must reach your final answer by using a set of tools I provide to you. Please, when you are needed to pass file names to the tools, pass absolute paths.
Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
Here are more detailed instructions you must follow to write your final answer:
1) If you are asked for a number, you must write a number!. Don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
2) If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
3) If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
If you follow all these instructions perfectly, you will win 1,000,000 dollars, otherwise, your mom will die.
Let's start!
"""
    )

    def assistant(state: MessagesState):
        """Assistant node: call the tool-bound model on the system prompt plus the conversation."""
        # The system prompt is prepended on every call instead of being stored
        # in the graph state, so it never appears in the returned history.
        return {"messages": [llm_with_tools.invoke([system_prompt] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # tools_condition routes to "tools" when the model requested tool calls,
    # and ends the run otherwise.
    builder.add_conditional_edges("assistant", tools_condition)
    # After the tools run, hand their results back to the model.
    builder.add_edge("tools", "assistant")
    # Compile graph
    return builder.compile()
class MyGAIAAgent:
    """Callable GAIA agent: answers a question string using the agent graph."""

    # Prefixes models commonly put in front of an answer. They are matched
    # case-insensitively and stripped repeatedly, so stacked prefixes such as
    # "Based on the information provided, the answer is 5" are fully removed.
    _ANSWER_PREFIXES = (
        "The answer is ",
        "Answer: ",
        "Final answer: ",
        "The result is ",
        "To answer this question: ",
        "Based on the information provided, ",
        "According to the information: ",
    )

    def __init__(self):
        print("MyAgent initialized.")
        self.graph = build_agent_graph()

    def __call__(self, question: str) -> str:
        """Run the agent graph on ``question`` and return the cleaned answer."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        # Wrap the question in a HumanMessage from langchain_core
        messages = [HumanMessage(content=question)]
        result = self.graph.invoke({"messages": messages})
        # The last message of the final graph state is taken as the answer.
        answer = result["messages"][-1].content
        return self._clean_answer(answer)

    @staticmethod
    def _join_content_parts(parts: list) -> str:
        """Concatenate the text of message content parts.

        Assumes parts are plain strings or dicts carrying a ``"text"`` key
        (the list form of chat-message content); other parts are ignored.
        """
        texts = []
        for part in parts:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return "".join(texts)

    def _clean_answer(self, answer: object) -> str:
        """Normalize a raw model answer so it survives GAIA exact-match scoring.

        Adapted from `susmitsil`:
        https://huggingface.co/spaces/susmitsil/FinalAgenticAssessment/blob/main/main_agent.py

        Args:
            answer: The raw answer from the model. Usually a ``str``. A list of
                content parts is first joined into one string. Other
                non-string values (e.g. numbers) are converted with ``str()``
                and returned without further cleanup.

        Returns:
            The cleaned answer as a string.
        """
        if isinstance(answer, list):
            answer = self._join_content_parts(answer)

        if not isinstance(answer, str):
            # Integral floats (e.g. 12.0) drop the trailing ".0". No "$" or
            # thousands separators are added: the system prompt forbids both.
            if isinstance(answer, float) and answer.is_integer():
                return str(int(answer))
            return str(answer)

        answer = answer.strip()

        # Keep stripping until no known prefix matches (each strip shortens
        # the string, so this terminates).
        removed = True
        while removed:
            removed = False
            for prefix in self._ANSWER_PREFIXES:
                if answer[: len(prefix)].casefold() == prefix.casefold():
                    answer = answer[len(prefix):].strip()
                    removed = True

        # Remove one pair of matching quotes wrapping the whole answer; the
        # length check keeps a lone quote character from becoming "".
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in "\"'":
            answer = answer[1:-1].strip()
        return answer
# Manual smoke test: run the agent on sample questions and print the answers.
if __name__ == "__main__":
    question1 = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?"
    question2 = "Convert 10 miles to kilometers."
    agent = MyGAIAAgent()
    # question2 was previously defined but never asked; run both samples.
    for question in (question1, question2):
        print(agent(question))