# RAG / app.py
# Uploaded by Rohit108 (commit 6d35820, 5.05 kB).
from typing import Annotated, Sequence, TypedDict
import functools
import operator
from bertopic import BERTopic
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langgraph.graph import END, StateGraph
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_experimental.tools import PythonREPLTool
# Tools handed to the worker agents: Tavily web search (up to 5 results per
# query) and a Python REPL that runs code written by the model.
tavily_tool = TavilySearchResults(max_results=5)
python_repl_tool = PythonREPLTool()

# Pre-trained BERTopic model used by the Viewer node. It is loaded from the
# local "topic_model" path when this module is imported.
TOPIC_MODEL_PATH = "topic_model"
topic_model = BERTopic.load(TOPIC_MODEL_PATH)
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an OpenAI tool-calling agent and wrap it in an ``AgentExecutor``.

    The agent's prompt is, in order: a system message with ``system_prompt``,
    the running conversation (``messages`` placeholder), and the
    ``agent_scratchpad`` placeholder reserved for the agent's intermediate
    steps.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; also handed to the executor.
        system_prompt: Instructions placed in the leading system message.

    Returns:
        An ``AgentExecutor`` that runs the agent with ``tools``.
    """
    prompt_messages = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    agent_prompt = ChatPromptTemplate.from_messages(prompt_messages)
    return AgentExecutor(
        agent=create_openai_tools_agent(llm, tools, agent_prompt),
        tools=tools,
    )
def agent_node(state, agent, name):
    """Run ``agent`` on the graph state and wrap its answer as a message.

    Args:
        state: Current graph state, passed unchanged to ``agent.invoke``.
        agent: Runnable whose result has an ``"output"`` key (as an
            ``AgentExecutor`` built by ``create_agent`` returns).
        name: Worker name attached to the reply message.

    Returns:
        A partial state update with one ``HumanMessage`` under ``"messages"``.
        ``AgentState`` annotates that field with ``operator.add``, so the
        message is added to the history rather than replacing it.
    """
    output_text = agent.invoke(state)["output"]
    reply = HumanMessage(content=output_text, name=name)
    return {"messages": [reply]}
def viewer_agent(state):
    """Label the latest message with its BERTopic topic.

    Takes the content of the most recent message in ``state["messages"]``,
    predicts its topic with the module-level ``topic_model`` and replies with
    that topic's label.

    Args:
        state: Graph state; only ``state["messages"][-1].content`` is read.

    Returns:
        A partial state update with one ``HumanMessage`` named ``"Viewer"``.
        Its content is the predicted topic's label, or ``"Topic <id>"`` when
        the model has no label for the predicted id.
    """
    review = state["messages"][-1].content
    # transform() returns (predictions, probabilities). We passed a single
    # document, so its topic id is the first prediction.
    predictions = topic_model.transform([review])[0]
    # Normalize to a plain int; predictions may be NumPy integers.
    topic_id = int(predictions[0])
    # Without a default, dict.get() returns None for an unknown id, and
    # HumanMessage rejects None as content. Fall back to a readable label
    # instead of crashing the graph run.
    answer = topic_model.topic_labels_.get(topic_id, f"Topic {topic_id}")
    return {"messages": [HumanMessage(content=answer, name="Viewer")]}
# Shared state flowing through the graph.
#   messages: conversation history. The ``operator.add`` annotation is the
#             reducer, so each node's returned list is concatenated onto the
#             existing history instead of replacing it.
#   next:     name of the node chosen next by the supervisor (read by the
#             conditional edge out of "supervisor").
AgentState = TypedDict(
    "AgentState",
    {
        "messages": Annotated[Sequence[BaseMessage], operator.add],
        "next": str,
    },
)
# OpenAI chat model shared by the supervisor chain and by the Researcher and
# Coder agents.
LLM_MODEL_NAME = "gpt-4-1106-preview"
llm = ChatOpenAI(model=LLM_MODEL_NAME)
# Supervisor instructions: choose the next worker or FINISH. Review and
# sentiment-style requests are explicitly steered to the Viewer.
system_prompt = " ".join(
    [
        "You are a supervisor tasked with managing a conversation between the following workers: Researcher, Coder, Viewer.",
        "Given the following user request, respond with the worker to act next.",
        "Each worker will perform a task and respond with their results and status.",
        "When finished, respond with FINISH.",
        "If the request seems like a product review or sentiment analysis, route it to the Viewer.",
    ]
)
# Routing targets for the supervisor: the three workers, or FINISH to end.
options = ["FINISH", "Researcher", "Coder", "Viewer"]

# OpenAI function-calling schema for ``route``. Its single required argument,
# ``next``, is restricted to the values in ``options``, so a parsed call looks
# like {"next": <option>}.
function_def = dict(
    name="route",
    description="Select the next role.",
    parameters=dict(
        title="routeSchema",
        type="object",
        properties=dict(
            next=dict(title="Next", anyOf=[dict(enum=options)]),
        ),
        required=["next"],
    ),
)
# Supervisor prompt, in order:
#   1. the supervisor's system instructions,
#   2. the conversation so far,
#   3. a closing system turn that asks for the routing decision.
# The template's only variable besides ``messages`` is ``{options}``. It is
# pre-filled with the string form of the ``options`` list; no other partial
# is needed.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}",
        ),
    ]
).partial(options=str(options))
# Supervisor pipeline: format the prompt, call the LLM forced to invoke the
# ``route`` function, then parse that call's JSON arguments into a dict such
# as {"next": "<option>"}, which becomes the state's "next" value.
supervisor_chain = prompt.pipe(
    llm.bind_functions(functions=[function_def], function_call="route"),
    JsonOutputFunctionsParser(),
)
# Worker nodes. Researcher and Coder are LLM tool-calling agents adapted to
# graph nodes by ``agent_node``.
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")

# SECURITY NOTE: PythonREPLTool executes whatever Python the model writes, on
# the host running this app. The word "safe" in the prompt below is only a
# request to the model, not a sandbox. Do not expose this to untrusted users
# without isolation.
code_agent = create_agent(
    llm,
    [python_repl_tool],
    "You may generate safe python code to analyze data and generate charts using matplotlib.",
)
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")

# viewer_agent already takes just the state and returns a state update, so it
# is used as the node directly; there are no arguments to bind.
viewer_node = viewer_agent
# Assemble the supervisor/worker graph. Every worker reports back to the
# supervisor. The supervisor's "next" value picks the following node, and
# "FINISH" maps to END.
workflow = StateGraph(AgentState)
for node_name, node_fn in (
    ("Researcher", research_node),
    ("Coder", code_node),
    ("Viewer", viewer_node),
    ("supervisor", supervisor_chain),
):
    workflow.add_node(node_name, node_fn)

members = ["Researcher", "Coder", "Viewer"]
for member in members:
    workflow.add_edge(member, "supervisor")

conditional_map = {**{worker: worker for worker in members}, "FINISH": END}
workflow.add_conditional_edges(
    "supervisor", lambda state: state["next"], conditional_map
)

workflow.set_entry_point("supervisor")
graph = workflow.compile()
# Smoke test, run at import time: stream the graph on a sample request and
# print each step's update. Chunks keyed "__end__" (if the installed
# LangGraph emits them) are skipped.
demo_request = HumanMessage(
    content="write a report of gopal who worked in 3 k technologies"
)
for step in graph.stream({"messages": [demo_request]}):
    if "__end__" in step:
        continue
    print(step)
    print("----")