from retriever import retrieve_context
from tools import tools
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_models import ChatOllama
from langchain_community.chat_message_histories import ChatMessageHistory
from langgraph.graph import END, StateGraph
from typing import Annotated, TypedDict, List
import operator
# Local chat model served by Ollama (the model must already be pulled, e.g. `ollama pull qwen:1.8b`).
model = ChatOllama(model="qwen:1.8b")

# Index the tools by name so the graph can dispatch on the "next" field of the state.
tools_with_names = {tool.name: tool for tool in tools}
class AgentState(TypedDict):
    # Conversation so far; the operator.add reducer appends new messages instead of overwriting.
    messages: Annotated[List, operator.add]
    # Name of the next step: "end", "tool_picker", or the name of a tool.
    next: str
# Dispatches the latest message to the tool named in state["next"], then wraps the
# tool output as an AIMessage and signals the graph to finish.
tool_chain = (
    RunnableParallel({
        "message": lambda x: x["messages"][-1].content,
        "tool": lambda x: x["next"],
    })
    | (lambda x: tools_with_names[x["tool"]].invoke(x["message"]))
    | (lambda x: {"messages": [AIMessage(content=str(x))], "next": "end"})
)
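# Example (hypothetical, for debugging): if tools.py exposed a tool named "calculator",
# the sub-chain could be exercised on its own like this:
#   tool_chain.invoke({"messages": [HumanMessage(content="2 + 2")], "next": "calculator"})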
# System instructions; prepended to the prompt template so the model actually sees them.
system = "You are a helpful assistant. Use tools if needed. Keep responses short."

prompt = PromptTemplate.from_template(system + "\n\n{context}\n{question}\n")
context_chain = (
    {
        "context": RunnableLambda(retrieve_context),
        "question": lambda x: x["messages"][-1].content,
    }
    | prompt
)
# Full RAG turn: retrieve context, answer, then ask whether a tool should be used.
agent = (
    context_chain
    | model
    | StrOutputParser()
    | (lambda x: {
        "messages": [
            AIMessage(content=x),
            HumanMessage(content="Do you want to use a tool?"),
        ],
        "next": "tool_picker",
    })
)
# If "next" already names a known tool, run it directly; otherwise fall back to the RAG agent.
conditional_agent = RunnableBranch(
    (lambda x: x.get("next", "") in tools_with_names, tool_chain),
    agent,
)
def create_graph():
    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("agent", conditional_agent)
    graph_builder.set_entry_point("agent")
    graph_builder.add_node("tool_chain", tool_chain)
    # Route to the tool node only when "next" names a real tool; otherwise finish,
    # so values like "tool_picker" or "end" do not crash the router.
    graph_builder.add_conditional_edges(
        "agent",
        lambda x: "tool" if x.get("next", "") in tools_with_names else "end",
        {
            "tool": "tool_chain",
            "end": END,
        },
    )
    graph_builder.add_edge("tool_chain", "agent")
    return graph_builder.compile()
app = create_graph()
# Per-session in-memory histories (the original lambda returned a plain dict,
# which is not a valid chat message history object).
session_histories: dict = {}

chain = RunnableWithMessageHistory(
    app,
    lambda session_id: session_histories.setdefault(session_id, ChatMessageHistory()),
    input_messages_key="messages",
    history_messages_key="messages",
    output_messages_key="messages",
)
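
# Minimal usage sketch (assumes retriever.py and tools.py sit alongside this file
# and that the qwen:1.8b model is available in Ollama; the question text and
# session id below are illustrative, not part of the original app).
if __name__ == "__main__":
    result = chain.invoke(
        {"messages": [HumanMessage(content="What can you help me with?")]},
        config={"configurable": {"session_id": "demo-session"}},
    )
    for message in result["messages"]:
        print(f"{message.type}: {message.content}")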