File size: 5,592 Bytes
a92d3ed
 
1659627
a92d3ed
 
 
 
 
 
 
 
 
 
1659627
a92d3ed
 
 
 
 
1659627
a92d3ed
 
 
1659627
 
 
 
 
 
 
 
 
 
 
 
 
 
a92d3ed
 
 
1659627
 
 
a92d3ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1659627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a92d3ed
 
 
1659627
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""LangGraph Agent"""
import os
from dotenv import load_dotenv

from langchain_openai import ChatOpenAI

from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition, ToolNode
from langgraph.graph import START, StateGraph, MessagesState
from langchain_core.messages import SystemMessage, HumanMessage

from tools import level1_tools

# Load environment variables from a .env file at import time, so that any
# settings it defines (presumably the OpenAI API key that ChatOpenAI needs —
# verify against deployment) are in place before build_agent_graph() runs.
load_dotenv()

# Build graph function
def build_agent_graph(model: str = "gpt-4o-mini"):
    """Build and compile the tool-calling agent graph.

    The graph has two nodes:

    * ``assistant`` -- the chat model bound to ``level1_tools``, always called
      with the GAIA system prompt prepended to the conversation;
    * ``tools`` -- a ``ToolNode`` executing ``level1_tools``.

    Execution starts at ``assistant``. After each assistant turn, LangGraph's
    prebuilt ``tools_condition`` decides whether to go to ``tools`` or finish.
    Every ``tools`` step loops back to ``assistant``.

    Args:
        model: OpenAI chat model name passed to ``ChatOpenAI``. Defaults to
            ``"gpt-4o-mini"``, the value that used to be hard-coded here.

    Returns:
        The compiled graph. Invoke it with ``{"messages": [...]}``.
    """
    llm = ChatOpenAI(model=model)

    # Bind tools to the LLM so it can emit tool calls for ToolNode to execute.
    llm_with_tools = llm.bind_tools(level1_tools)

    # System message sent ahead of the conversation on every assistant turn.
    system_prompt = SystemMessage(
        content="""You are a general AI assistant being evaluated in the GAIA Benchmark.
    I will ask you a question and you must reach your final answer by using a set of tools I provide to you. Please, when you are needed to pass file names to the tools, pass absolute paths.
    Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
    Here are more detailed instructions you must follow to write your final answer:
    1) If you are asked for a number, you must write a number!. Don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
    2) If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
    3) If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
    If you follow all these instructions perfectly, you will win 1,000,000 dollars, otherwise, your mom will die.
    Let's start!
    """
    )

    def assistant(state: MessagesState):
        """Call the tool-bound model on the system prompt plus the conversation so far."""
        # The system prompt is prepended on each call rather than returned in
        # the node output, so it is never appended to the stored history.
        return {"messages": [llm_with_tools.invoke([system_prompt] + state["messages"])]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(level1_tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()



class MyGAIAAgent:
    """Callable GAIA agent: runs a question through the agent graph and returns a cleaned answer."""

    # Prefixes models commonly prepend to their answer. They are matched
    # case-insensitively, so e.g. "FINAL ANSWER: " is removed as well.
    _ANSWER_PREFIXES = (
        "The answer is ",
        "Answer: ",
        "Final answer: ",
        "The result is ",
        "To answer this question: ",
        "Based on the information provided, ",
        "According to the information: ",
    )

    def __init__(self):
        print("MyAgent initialized.")
        self.graph = build_agent_graph()

    def __call__(self, question: str) -> str:
        """Answer *question* with the agent graph.

        Args:
            question: Question text, sent to the graph as a single user message.

        Returns:
            The content of the last message the graph produced, normalized by
            ``_clean_answer``.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        user_input = {"messages": [("user", question)]}
        answer = self.graph.invoke(user_input)["messages"][-1].content
        return self._clean_answer(answer)

    def _clean_answer(self, answer: object) -> str:
        """
        Taken from `susmitsil`:
        https://huggingface.co/spaces/susmitsil/FinalAgenticAssessment/blob/main/main_agent.py
        Clean up the answer to remove common prefixes and formatting
        that models often add but that can cause exact match failures.

        Non-string answers are converted with ``str``. The one exception is
        integral floats, which drop their ``.0`` (``12.0`` -> ``"12"``).
        Numbers are never given a currency symbol or thousands separators,
        because the system prompt forbids both.

        String answers are processed in three steps:

        1. Whitespace is stripped.
        2. Known prefixes are removed case-insensitively and repeatedly, so
           stacked prefixes such as ``"Answer: The answer is 5"`` reduce to
           ``"5"``.
        3. A single pair of matching quotes wrapping the whole answer is
           removed.

        Args:
            answer: The raw answer from the model
        Returns:
            The cleaned answer as a string
        """
        if not isinstance(answer, str):
            # An integral float such as 12.0 should read "12", not "12.0".
            if isinstance(answer, float) and answer.is_integer():
                return str(int(answer))
            return str(answer)

        answer = answer.strip()

        # Keep stripping until no prefix matches. Each removal shortens the
        # string, so the loop always terminates.
        changed = True
        while changed:
            changed = False
            for prefix in self._ANSWER_PREFIXES:
                # Compare a same-length slice so indices stay valid even when
                # .lower() would change the length of non-ASCII text.
                if answer[: len(prefix)].lower() == prefix.lower():
                    answer = answer[len(prefix) :].strip()
                    changed = True

        # Remove quotes only when they wrap the entire answer. Requiring at
        # least two characters keeps a lone quote character intact.
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in "\"'":
            answer = answer[1:-1].strip()
        return answer



# test
if __name__ == "__main__":
    # Manual smoke test. It builds the real graph and calls the model
    # (ChatOpenAI), so it needs network access and API credentials.
    question1 = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)?"
    question2 = "Convert 10 miles to kilometers."
    agent = MyGAIAAgent()
    for question in (question1, question2):
        print(agent(question))