from typing import Annotated, Sequence, TypedDict
import functools
import operator

from bertopic import BERTopic
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
# Initialize tools
tavily_tool = TavilySearchResults(max_results=5)
python_repl_tool = PythonREPLTool()

# Load BERTopic model
topic_model = BERTopic.load("topic_model")
# Function to create an agent
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_tools_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor
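# Illustrative usage sketch (not from the original): executors built this way
# consume the shared message list and return a dict containing an "output"
# string, which is why agent_node below reads result["output"]:
#   researcher = create_agent(llm, [tavily_tool], "You are a web researcher.")
#   researcher.invoke({"messages": [HumanMessage(content="Who made LangGraph?")]})["output"]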
# Function to define an agent node
def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {"messages": [HumanMessage(content=result["output"], name=name)]}
# Define the Viewer agent using the BERTopic model
def viewer_agent(state):
    review = state["messages"][-1].content
    # transform() returns (topics, probabilities); take the first document's topic id
    topics, _ = topic_model.transform([review])
    index = topics[0]
    # Fall back to a generic label in case the topic id has no stored label
    answer = topic_model.topic_labels_.get(index, f"Topic {index}")
    return {"messages": [HumanMessage(content=answer, name="Viewer")]}
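# Illustrative sketch (not from the original) of the BERTopic calls used above,
# assuming the saved model has human-readable labels set via set_topic_labels():
#   topics, probs = topic_model.transform(["The battery dies within hours."])
#   topic_model.topic_labels_.get(topics[0])   # e.g. "3_battery_charge_life"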
# Define AgentState
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str
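# Note on the reducer (added explanation): annotating messages with operator.add
# makes LangGraph concatenate each node's returned messages onto the existing
# list ([m1] + [m2] == [m1, m2]) instead of overwriting state["messages"].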
# Create LLM for the supervisor
llm = ChatOpenAI(model="gpt-4-1106-preview")

# Define the system prompt for the supervisor
system_prompt = (
    "You are a supervisor tasked with managing a conversation between the "
    "following workers: Researcher, Coder, Viewer. Given the following user "
    "request, respond with the worker to act next. Each worker will perform a "
    "task and respond with their results and status. When finished, respond "
    "with FINISH. If the request seems like a product review or sentiment "
    "analysis, route it to the Viewer."
)
# Define options
options = ["FINISH", "Researcher", "Coder", "Viewer"]

# Define the function for routing
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": {
                "title": "Next",
                "anyOf": [
                    {"enum": options},
                ],
            }
        },
        "required": ["next"],
    },
}
# Define the prompt for the supervisor
# (the worker names are hardcoded in system_prompt, so no {members} variable is needed)
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next? "
            "Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options))
# Create the supervisor chain
supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)
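# Illustrative check (not from the original); requires an OPENAI_API_KEY.
# Because function_call="route" forces the function call, the parser always
# yields a dict such as {"next": "Researcher"} or {"next": "FINISH"}:
#   supervisor_chain.invoke({"messages": [HumanMessage(content="Plot GDP data")]})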
# Define agents
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

code_agent = create_agent(
    llm,
    [python_repl_tool],
    "You may generate safe python code to analyze data and generate charts using matplotlib.",
)
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

# viewer_agent already matches the node signature, so no partial wrapper is needed
viewer_node = viewer_agent
# Create the workflow
workflow = StateGraph(AgentState)
workflow.add_node("Researcher", research_node)
workflow.add_node("Coder", code_node)
workflow.add_node("Viewer", viewer_node)
workflow.add_node("supervisor", supervisor_chain)

# Add edges for each agent to report back to the supervisor
members = ["Researcher", "Coder", "Viewer"]
for member in members:
    workflow.add_edge(member, "supervisor")

# Add conditional edges: the supervisor's "next" value picks the next node
conditional_map = {k: k for k in members}
conditional_map["FINISH"] = END
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)

# Set the entry point
workflow.set_entry_point("supervisor")

# Compile the workflow
graph = workflow.compile()
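# Illustrative one-shot run (not from the original); invoke() returns the final
# accumulated state rather than streaming intermediate steps:
#   final_state = graph.invoke({"messages": [HumanMessage(content="hello")]})
#   print(final_state["messages"][-1].content)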
# Testing the workflow
for s in graph.stream(
    {
        "messages": [
            HumanMessage(content="write a report of gopal who worked in 3 k technologies")
        ]
    }
):
    if "__end__" not in s:
        print(s)
        print("----")
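# A second, hypothetical test input phrased like a product review, which the
# supervisor prompt should route to the Viewer/BERTopic branch:
for s in graph.stream(
    {
        "messages": [
            HumanMessage(
                content="The battery drains overnight and the screen flickers. Very disappointed."
            )
        ]
    }
):
    if "__end__" not in s:
        print(s)
        print("----")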