Update gaia_agent.py
Browse files- gaia_agent.py +7 -1
gaia_agent.py
CHANGED
@@ -8,6 +8,7 @@ from langchain_community.document_loaders import WikipediaLoader
|
|
8 |
from langchain_community.document_loaders import ArxivLoader
|
9 |
from langchain_core.tools import tool
|
10 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
|
|
11 |
from math import sqrt
|
12 |
|
13 |
### =============== MATHEMATICAL TOOLS =============== ###
|
@@ -165,6 +166,11 @@ def build_graph(provider: str = "huggingface"):
|
|
165 |
)
|
166 |
|
167 |
chat = ChatHuggingFace(llm=llm, verbose=True)
|
|
|
|
|
|
|
|
|
|
|
168 |
else:
|
169 |
raise ValueError("Invalid provider. Choose 'ollama' or 'huggingface'.")
|
170 |
# Bind tools to LLM
|
@@ -193,7 +199,7 @@ def build_graph(provider: str = "huggingface"):
|
|
193 |
if __name__ == "__main__":
|
194 |
question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
|
195 |
# fixed_answer = "extremely"
|
196 |
-
graph = build_graph(provider="
|
197 |
messages = [HumanMessage(content=question)]
|
198 |
messages = graph.invoke({"messages": messages})
|
199 |
for m in messages["messages"]:
|
|
|
8 |
from langchain_community.document_loaders import ArxivLoader
|
9 |
from langchain_core.tools import tool
|
10 |
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
|
11 |
+
from langchain_xai import ChatXAI
|
12 |
from math import sqrt
|
13 |
|
14 |
### =============== MATHEMATICAL TOOLS =============== ###
|
|
|
166 |
)
|
167 |
|
168 |
chat = ChatHuggingFace(llm=llm, verbose=True)
|
169 |
+
elif provider == "xai":
|
170 |
+
chat = ChatXAI(
|
171 |
+
model="grok-3-latest",
|
172 |
+
temperature=0,
|
173 |
+
)
|
174 |
else:
|
175 |
raise ValueError("Invalid provider. Choose 'ollama' or 'huggingface'.")
|
176 |
# Bind tools to LLM
|
|
|
199 |
if __name__ == "__main__":
|
200 |
question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
|
201 |
# fixed_answer = "extremely"
|
202 |
+
graph = build_graph(provider="xai")
|
203 |
messages = [HumanMessage(content=question)]
|
204 |
messages = graph.invoke({"messages": messages})
|
205 |
for m in messages["messages"]:
|