# gaia_agent_code / agent.py
# Last change by shrutikaP8497: "Update agent.py" (commit 0375d91, verified)
from retriever import retrieve_context
from tools import tools
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import RunnableBranch
from langchain_community.chat_models import ChatOllama
from langgraph.graph import END, StateGraph
from typing import Annotated, TypedDict, List
import operator
# Chat model used by ``agent``, served through Ollama.
model = ChatOllama(model="qwen:1.8b")

# Lookup table from each tool's ``name`` to the tool object, so a tool can be
# dispatched by name (see ``tool_chain``).
tools_with_names = {registered.name: registered for registered in tools}
class AgentState(TypedDict):
    """Graph state shared by the ``agent`` and ``tool_chain`` nodes."""

    # Conversation so far. ``operator.add`` is the LangGraph reducer: the
    # "messages" list a node returns is appended to the existing list rather
    # than replacing it.
    messages: Annotated[List, operator.add]
    # Routing label written by the nodes and read by the graph's conditional
    # edges (e.g. "end").
    next: str
def _run_requested_tool(state):
    """Invoke the tool named by ``state["next"]`` on the latest message text.

    Expects ``state["next"]`` to be a key of ``tools_with_names``; any other
    value makes the lookup raise ``KeyError``.

    Returns:
        A state update holding the tool result (stringified) as an AIMessage,
        with ``next`` set to "end".
    """
    query = state["messages"][-1].content
    tool_name = state["next"]
    result = tools_with_names[tool_name].invoke(query)
    return {"messages": [AIMessage(content=str(result))], "next": "end"}


tool_chain = RunnableLambda(_run_requested_tool)
# Standing instructions for the model; bound into every prompt below.
system = """
You are a helpful assistant. Use tools if needed. Keep responses short.
"""

# Prompt fed by ``context_chain``: the system instructions, then the output of
# ``retrieve_context``, then the latest message. ``system`` is bound as a
# partial variable, so callers still supply only ``context`` and ``question``.
prompt = PromptTemplate.from_template("{system}\n{context}\n{question}\n").partial(
    system=system.strip()
)
def _latest_message_text(state):
    """Return the content of the most recent message in the graph state."""
    return state["messages"][-1].content


# Fill the prompt's ``context`` and ``question`` slots in parallel, then
# render the prompt.
# NOTE(review): retrieve_context is called with the whole state dict, not a
# query string -- confirm that is the signature it expects.
context_chain = (
    RunnableParallel(
        context=RunnableLambda(retrieve_context),
        question=RunnableLambda(_latest_message_text),
    )
    | prompt
)
def _to_agent_update(answer):
    """Wrap the model's text reply as a state update that finishes the turn.

    ``next`` is "end" because the graph's conditional edges only recognize
    "tool" and "end". No node selects a tool, so any other label (such as
    "tool_picker") would have nowhere to route. Only the AI reply is
    appended, so it remains the last message in the state.
    """
    return {"messages": [AIMessage(content=answer)], "next": "end"}


# LLM node: build the prompt from retrieve_context output plus the latest
# message, query the model, and record its plain-text reply.
agent = context_chain | model | StrOutputParser() | RunnableLambda(_to_agent_update)
# Body of the "agent" graph node: dispatch to tool_chain when state["next"]
# names a registered tool, otherwise let the LLM answer.
conditional_agent = RunnableBranch(
    # tool_chain looks up tools_with_names[state["next"]], so only take that
    # branch when the lookup can succeed. .get() because "next" may be absent
    # on the first pass, before any node has written it.
    (lambda x: x.get("next") in tools_with_names, tool_chain),
    agent,
)
def create_graph():
    """Build and compile the agent/tool state graph.

    Nodes:
      * ``agent`` -- runs ``conditional_agent``; this is the entry point.
      * ``tool_chain`` -- runs the tool named by ``state["next"]``, then
        loops back to ``agent``.

    After ``agent`` runs, the graph moves to ``tool_chain`` when
    ``state["next"]`` names a registered tool, and ends otherwise.

    Returns:
        The compiled graph.
    """
    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("agent", conditional_agent)
    graph_builder.add_node("tool_chain", tool_chain)
    graph_builder.set_entry_point("agent")

    def route_after_agent(state):
        # tool_chain dereferences tools_with_names[state["next"]], so only send
        # it labels it can resolve. Anything else (including a missing "next")
        # ends the run instead of failing on a label absent from the path map.
        return "tool" if state.get("next") in tools_with_names else "end"

    graph_builder.add_conditional_edges(
        "agent",
        route_after_agent,
        {
            "tool": "tool_chain",
            "end": END,
        },
    )
    graph_builder.add_edge("tool_chain", "agent")
    return graph_builder.compile()
app = create_graph()

# Chat histories keyed by session_id. They live in process memory only.
_session_histories = {}


def _get_session_history(session_id):
    """Return the chat history for *session_id*, creating an empty one on first use."""
    from langchain_core.chat_history import InMemoryChatMessageHistory

    if session_id not in _session_histories:
        _session_histories[session_id] = InMemoryChatMessageHistory()
    return _session_histories[session_id]


# Wrap the graph so each session's stored messages are fed back in.
# history_messages_key is deliberately left unset, so the stored history is
# prepended to the "messages" input. Setting it to the same key would replace
# the new input with the history. The graph returns a state dict, so
# output_messages_key tells the wrapper which key holds the reply messages.
# NOTE(review): with the appending "messages" reducer, the graph returns the
# full accumulated list, so stored history may gain duplicates -- confirm and
# trim to newly generated messages if so.
chain = RunnableWithMessageHistory(
    app,
    _get_session_history,
    input_messages_key="messages",
    output_messages_key="messages",
)