FergusFindley committed on
Commit 8c9df2c · verified · 1 Parent(s): f4b1c3a

Create gaia_agent.py

Files changed (1)
  1. gaia_agent.py +200 -0
gaia_agent.py ADDED
@@ -0,0 +1,200 @@
+ from langchain_ollama import ChatOllama
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import ToolNode, tools_condition
+ from langchain_community.tools import DuckDuckGoSearchRun
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_core.tools import tool
+ from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
+ from cmath import sqrt  # cmath.sqrt returns a complex result for negative inputs
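+ # Assumed dependencies (not pinned in this commit): langchain-ollama, langchain-core,
+ # langgraph, langchain-community, langchain-huggingface; the loaders/tools additionally
+ # need the wikipedia, arxiv, duckduckgo-search, and tavily-python packages, and Tavily
+ # expects a TAVILY_API_KEY in the environment.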
+
+ ### =============== MATHEMATICAL TOOLS =============== ###
+
+
+ @tool
+ def multiply(a: float, b: float) -> float:
+     """
+     Multiplies two numbers.
+     Args:
+         a (float): the first number
+         b (float): the second number
+     """
+     return a * b
+
+
+ @tool
+ def add(a: float, b: float) -> float:
+     """
+     Adds two numbers.
+     Args:
+         a (float): the first number
+         b (float): the second number
+     """
+     return a + b
+
+
+ @tool
+ def subtract(a: float, b: float) -> float:
+     """
+     Subtracts two numbers.
+     Args:
+         a (float): the first number
+         b (float): the second number
+     """
+     return a - b
+
+
+ @tool
+ def divide(a: float, b: float) -> float:
+     """
+     Divides two numbers.
+     Args:
+         a (float): the first float number
+         b (float): the second float number
+     """
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """
+     Get the modulus of two numbers.
+     Args:
+         a (int): the first number
+         b (int): the second number
+     """
+     return a % b
+
+
+ @tool
+ def power(a: float, b: float) -> float:
+     """
+     Get the power of two numbers.
+     Args:
+         a (float): the first number
+         b (float): the second number
+     """
+     return a**b
+
+
+ @tool
+ def square_root(a: float) -> float | complex:
+     """
+     Get the square root of a number.
+     Args:
+         a (float): the number to get the square root of
+     """
+     if a >= 0:
+         return a**0.5
+     return sqrt(a)  # cmath.sqrt, so negative inputs yield a complex result
+
+
+ ### =============== BROWSER TOOLS =============== ###
+
+ search_tool = DuckDuckGoSearchRun()
+
+
+ @tool
+ def wiki_search(query: str) -> dict:
+     """Search Wikipedia for a query and return at most 2 results.
+     Args:
+         query: The search query."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+     return {"wiki_results": formatted_search_docs}
+
+
+ @tool
+ def web_search(query: str) -> dict:
+     """Search Tavily for a query and return at most 3 results.
+     Args:
+         query: The search query."""
+     print(f"I'm running web_search with {query = }")
+     # TavilySearchResults takes the query as tool input and returns a list of result dicts
+     search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+     return {"web_results": formatted_search_docs}
+
+
+ @tool
+ def arxiv_search(query: str) -> dict:
+     """Search Arxiv for a query and return at most 3 results.
+     Args:
+         query: The search query."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+             for doc in search_docs
+         ]
+     )
+     return {"arxiv_results": formatted_search_docs}
+
+
+ tools = [multiply, add, subtract, divide, modulus, power, square_root, web_search, arxiv_search, wiki_search]
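+ # NOTE: search_tool (DuckDuckGoSearchRun) is instantiated above but not included in this list.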
+
+
+ GAIA_SYSTEM_PROMPT = """
+ You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
+ """
+ sys_msg = SystemMessage(content=GAIA_SYSTEM_PROMPT)
+
+
+ # Build graph function
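+ # The agent is a tool-calling loop: the assistant node invokes the tool-bound chat model,
+ # tools_condition routes to the tool node whenever the model requests a tool call, and the
+ # tool results are fed back to the assistant until it produces a final answer.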
+ def build_graph(provider: str = "ollama"):
+     """Build the graph"""
+     # Select the chat model for the requested provider
+     if provider == "ollama":
+         chat = ChatOllama(model="llama3.1")
+     elif provider == "huggingface":
+         llm = HuggingFaceEndpoint(
+             repo_id="Qwen/Qwen2.5-Coder-32B-Instruct"
+         )
+
+         chat = ChatHuggingFace(llm=llm, verbose=True)
+     else:
+         raise ValueError("Invalid provider. Choose 'ollama' or 'huggingface'.")
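+     # NOTE: the "huggingface" provider expects a Hugging Face API token in the environment
+     # (typically HUGGINGFACEHUB_API_TOKEN); the "ollama" provider assumes a local Ollama
+     # server with the llama3.1 model pulled.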
+     # Bind tools to LLM
+     chat_with_tools = chat.bind_tools(tools)
+
+     # Assistant node: prepend the system prompt and call the tool-bound model
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         print([sys_msg] + state["messages"])
+         return {"messages": [chat_with_tools.invoke([sys_msg] + state["messages"])]}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "assistant")
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
+     )
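+     # tools_condition routes to the "tools" node when the last message contains tool calls,
+     # and otherwise ends the run.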
+     builder.add_edge("tools", "assistant")
+
+     return builder.compile()
+
+
+ # test
+ if __name__ == "__main__":
+     question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
+     # fixed_answer = "extremely"
+     graph = build_graph(provider="huggingface")
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke({"messages": messages})
+     for m in messages["messages"]:
+         m.pretty_print()