wishwakankanamg commited on
Commit
c3defd0
·
1 Parent(s): c9f0eb5

validator for messages

Browse files
Files changed (5) hide show
  1. __pycache__/graph.cpython-310.pyc +0 -0
  2. app.log +0 -0
  3. graph.png +0 -0
  4. graph.py +40 -62
  5. oldgraph.py +264 -0
__pycache__/graph.cpython-310.pyc CHANGED
Binary files a/__pycache__/graph.cpython-310.pyc and b/__pycache__/graph.cpython-310.pyc differ
 
app.log CHANGED
The diff for this file is too large to render. See raw diff
 
graph.png CHANGED
graph.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  from typing import Annotated
 
4
 
5
  import aiohttp
6
  from langchain_core.messages import AnyMessage
@@ -8,13 +9,15 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
8
  from langchain_core.tools import tool
9
  from langgraph.graph.state import CompiledStateGraph
10
  from langgraph.prebuilt import ToolNode, tools_condition
11
- from langgraph.graph import StateGraph, END, add_messages
12
  from langchain_community.tools import TavilySearchResults
13
  from pydantic import BaseModel, Field
14
  from trafilatura import extract
15
  from langchain_anthropic import ChatAnthropic
16
  from langgraph.prebuilt import ToolNode
17
  from langgraph.checkpoint.memory import MemorySaver
 
 
18
 
19
 
20
 
@@ -43,6 +46,11 @@ def evaluate_idea_completion(response) -> bool:
43
 
44
  return all(keyword in response_text for keyword in required_keywords)
45
 
 
 
 
 
 
46
 
47
  @tool
48
  async def download_website_text(url: str) -> str:
@@ -58,7 +66,7 @@ async def download_website_text(url: str) -> str:
58
  logger.error(f"Failed to download {url}: {str(e)}")
59
  return f"Error retrieving website content: {str(e)}"
60
 
61
- tools = [download_website_text]
62
  memory = MemorySaver()
63
 
64
 
@@ -92,42 +100,19 @@ class GraphProcessingState(BaseModel):
92
  search_enabled: bool = Field(default=True, description="Whether to enable search tools")
93
  idea_complete: bool = Field(default=False)
94
 
95
-
96
- async def planning_node(state: GraphProcessingState, config=None):
97
- # Define the system prompt for planning
98
- planning_prompt = "Based on the user's idea, create a detailed step-by-step plan to build the DIY product."
99
-
100
- # Combine the planning prompt with any existing prompts
101
- if state.prompt:
102
- final_prompt = "\n".join([planning_prompt, state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
103
- else:
104
- final_prompt = "\n".join([planning_prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
105
-
106
- # Create the prompt template
107
- prompt = ChatPromptTemplate.from_messages(
108
- [
109
- ("system", final_prompt),
110
- MessagesPlaceholder(variable_name="messages"),
111
  ]
112
- )
113
 
114
- # Bind tools if necessary
115
- assistant_tools = []
116
- if state.tools_enabled.get("download_website_text", True):
117
- assistant_tools.append(download_website_text)
118
- if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
119
- assistant_tools.append(tavily_search_tool)
120
- assistant_model = model.bind_tools(assistant_tools)
121
 
122
- # Create the chain and invoke it
123
- chain = prompt | assistant_model
124
- response = await chain.ainvoke({"messages": state.messages}, config=config)
125
 
126
- return {
127
- "messages": response
128
- }
129
 
130
- async def assistant_node(state: GraphProcessingState, config=None):
 
131
  assistant_tools = []
132
  if state.tools_enabled.get("download_website_text", True):
133
  assistant_tools.append(download_website_text)
@@ -146,8 +131,16 @@ async def assistant_node(state: GraphProcessingState, config=None):
146
  ]
147
  )
148
  chain = prompt | assistant_model
149
- response = await chain.ainvoke({"messages": state.messages}, config=config)
150
 
 
 
 
 
 
 
 
 
 
151
  idea_complete = evaluate_idea_completion(response)
152
 
153
  return {
@@ -155,23 +148,6 @@ async def assistant_node(state: GraphProcessingState, config=None):
155
  "idea_complete": idea_complete
156
  }
157
 
158
- # def assistant_cond_edge(state: GraphProcessingState):
159
- # last_message = state.messages[-1]
160
- # if hasattr(last_message, "tool_calls") and last_message.tool_calls:
161
- # logger.info(f"Tool call detected: {last_message.tool_calls}")
162
- # return "tools"
163
- # return END
164
- def assistant_routing(state: GraphProcessingState) -> str:
165
- last_message = state.messages[-1]
166
- if hasattr(last_message, "tool_calls") and last_message.tool_calls:
167
- logger.info("Tool call detected. Routing to 'tools' node.")
168
- return "tools"
169
- elif state.idea_complete:
170
- logger.info("Idea is complete. Routing to 'planning_node'.")
171
- return "planning_node"
172
- else:
173
- logger.info("Idea is incomplete. Routing back to 'assistant_node'.")
174
- return "assistant_node"
175
 
176
 
177
 
@@ -181,20 +157,19 @@ def define_workflow() -> CompiledStateGraph:
181
  workflow = StateGraph(GraphProcessingState)
182
 
183
  # Add nodes
184
- workflow.add_node("assistant_node", assistant_node)
185
  workflow.add_node("tools", ToolNode(tools))
186
- workflow.add_node("planning_node", planning_node)
187
 
188
- # Edges
189
- workflow.add_edge("tools", "assistant_node")
190
- workflow.add_edge("planning_node", "assistant_node")
 
 
 
191
 
192
- # Conditional routing
193
- workflow.add_conditional_edges("assistant_node", assistant_routing)
194
 
195
- # Set end nodes
196
- workflow.set_entry_point("assistant_node")
197
- # workflow.set_finish_point("assistant_node")
198
  compiled_graph = workflow.compile(checkpointer=memory)
199
  try:
200
  img_bytes = compiled_graph.get_graph().draw_mermaid_png()
@@ -208,4 +183,7 @@ def define_workflow() -> CompiledStateGraph:
208
 
209
  return compiled_graph
210
 
211
- graph = define_workflow()
 
 
 
 
1
  import logging
2
  import os
3
  from typing import Annotated
4
+ from typing_extensions import TypedDict
5
 
6
  import aiohttp
7
  from langchain_core.messages import AnyMessage
 
9
  from langchain_core.tools import tool
10
  from langgraph.graph.state import CompiledStateGraph
11
  from langgraph.prebuilt import ToolNode, tools_condition
12
+ from langgraph.graph import StateGraph, START, END, add_messages
13
  from langchain_community.tools import TavilySearchResults
14
  from pydantic import BaseModel, Field
15
  from trafilatura import extract
16
  from langchain_anthropic import ChatAnthropic
17
  from langgraph.prebuilt import ToolNode
18
  from langgraph.checkpoint.memory import MemorySaver
19
+ from langgraph.types import Command, interrupt
20
+
21
 
22
 
23
 
 
46
 
47
  return all(keyword in response_text for keyword in required_keywords)
48
 
49
+ @tool
50
+ def human_assistance(query: str) -> str:
51
+ """Request assistance from a human."""
52
+ human_response = interrupt({"query": query})
53
+ return human_response["data"]
54
 
55
  @tool
56
  async def download_website_text(url: str) -> str:
 
66
  logger.error(f"Failed to download {url}: {str(e)}")
67
  return f"Error retrieving website content: {str(e)}"
68
 
69
+ tools = [download_website_text, human_assistance]
70
  memory = MemorySaver()
71
 
72
 
 
100
  search_enabled: bool = Field(default=True, description="Whether to enable search tools")
101
  idea_complete: bool = Field(default=False)
102
 
103
+ @root_validator
104
+ def remove_empty_messages(cls, values):
105
+ messages = values.get("messages", [])
106
+ values["messages"] = [
107
+ msg for msg in messages if getattr(msg, "content", "").strip()
 
 
 
 
 
 
 
 
 
 
 
108
  ]
109
+ return values
110
 
 
 
 
 
 
 
 
111
 
 
 
 
112
 
 
 
 
113
 
114
+
115
+ async def chatbot(state: GraphProcessingState, config=None):
116
  assistant_tools = []
117
  if state.tools_enabled.get("download_website_text", True):
118
  assistant_tools.append(download_website_text)
 
131
  ]
132
  )
133
  chain = prompt | assistant_model
 
134
 
135
+ valid_messages = [msg for msg in state.messages if getattr(msg, "content", "").strip()]
136
+
137
+ response = await chain.ainvoke({"messages": valid_messages}, config=config)
138
+
139
+ # message = llm_with_tools.invoke(state["messages"])
140
+ # Because we will be interrupting during tool execution,
141
+ # we disable parallel tool calling to avoid repeating any
142
+ # tool invocations when we resume.
143
+ assert len(response.tool_calls) <= 1
144
  idea_complete = evaluate_idea_completion(response)
145
 
146
  return {
 
148
  "idea_complete": idea_complete
149
  }
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
 
 
157
  workflow = StateGraph(GraphProcessingState)
158
 
159
  # Add nodes
 
160
  workflow.add_node("tools", ToolNode(tools))
161
+ workflow.add_node("chatbot", chatbot)
162
 
163
+
164
+
165
+ workflow.add_conditional_edges(
166
+ "chatbot",
167
+ tools_condition,
168
+ )
169
 
170
+ workflow.add_edge("tools", "chatbot")
171
+ workflow.add_edge(START, "chatbot")
172
 
 
 
 
173
  compiled_graph = workflow.compile(checkpointer=memory)
174
  try:
175
  img_bytes = compiled_graph.get_graph().draw_mermaid_png()
 
183
 
184
  return compiled_graph
185
 
186
+ graph = define_workflow()
187
+
188
+
189
+
oldgraph.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Annotated
4
+ from typing_extensions import TypedDict
5
+
6
+ import aiohttp
7
+ from langchain_core.messages import AnyMessage
8
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
9
+ from langchain_core.tools import tool
10
+ from langgraph.graph.state import CompiledStateGraph
11
+ from langgraph.prebuilt import ToolNode, tools_condition
12
+ from langgraph.graph import StateGraph, START, END, add_messages
13
+ from langchain_community.tools import TavilySearchResults
14
+ from pydantic import BaseModel, Field
15
+ from trafilatura import extract
16
+ from langchain_anthropic import ChatAnthropic
17
+ from langgraph.prebuilt import ToolNode
18
+ from langgraph.checkpoint.memory import MemorySaver
19
+ from langgraph.types import Command, interrupt
20
+
21
+ class State(TypedDict):
22
+ messages: Annotated[list, add_messages]
23
+
24
+
25
+ logger = logging.getLogger(__name__)
26
+ ASSISTANT_SYSTEM_PROMPT_BASE = """"""
27
+ search_enabled = bool(os.environ.get("TAVILY_API_KEY"))
28
+
29
+ def evaluate_idea_completion(response) -> bool:
30
+ """
31
+ Evaluates whether the assistant's response indicates a complete DIY project idea.
32
+ You can customize the logic based on your specific criteria.
33
+ """
34
+ # Example logic: Check if the response contains certain keywords
35
+ required_keywords = ["materials", "dimensions", "tools", "steps"]
36
+
37
+ # Determine the type of response and extract text accordingly
38
+ if isinstance(response, dict):
39
+ # If response is a dictionary, extract values and join them into a single string
40
+ response_text = ' '.join(str(value).lower() for value in response.values())
41
+ elif isinstance(response, str):
42
+ # If response is a string, convert it to lowercase
43
+ response_text = response.lower()
44
+ else:
45
+ # If response is of an unexpected type, convert it to string and lowercase
46
+ response_text = str(response).lower()
47
+
48
+ return all(keyword in response_text for keyword in required_keywords)
49
+
50
+ @tool
51
+ def human_assistance(query: str) -> str:
52
+ """Request assistance from a human."""
53
+ human_response = interrupt({"query": query})
54
+ return human_response["data"]
55
+
56
+ @tool
57
+ async def download_website_text(url: str) -> str:
58
+ """Download the text from a website"""
59
+ try:
60
+ async with aiohttp.ClientSession() as session:
61
+ async with session.get(url) as response:
62
+ response.raise_for_status()
63
+ downloaded = await response.text()
64
+ result = extract(downloaded, include_formatting=True, include_links=True, output_format='json', with_metadata=True)
65
+ return result or "No text found on the website"
66
+ except Exception as e:
67
+ logger.error(f"Failed to download {url}: {str(e)}")
68
+ return f"Error retrieving website content: {str(e)}"
69
+
70
+ tools = [download_website_text, human_assistance]
71
+ memory = MemorySaver()
72
+
73
+
74
+ if search_enabled:
75
+ tavily_search_tool = TavilySearchResults(
76
+ max_results=5,
77
+ search_depth="advanced",
78
+ include_answer=True,
79
+ include_raw_content=True,
80
+ )
81
+ tools.append(tavily_search_tool)
82
+ else:
83
+ print("TAVILY_API_KEY environment variable not found. Websearch disabled")
84
+
85
+ weak_model = ChatAnthropic(
86
+ model="claude-3-5-sonnet-20240620",
87
+ temperature=0,
88
+ max_tokens=1024,
89
+ timeout=None,
90
+ max_retries=2,
91
+ # other params...
92
+ )
93
+ model = weak_model
94
+ assistant_model = weak_model
95
+
96
+ class GraphProcessingState(BaseModel):
97
+ # user_input: str = Field(default_factory=str, description="The original user input")
98
+ messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list)
99
+ prompt: str = Field(default_factory=str, description="The prompt to be used for the model")
100
+ tools_enabled: dict = Field(default_factory=dict, description="The tools enabled for the assistant")
101
+ search_enabled: bool = Field(default=True, description="Whether to enable search tools")
102
+ idea_complete: bool = Field(default=False)
103
+
104
+
105
+ async def planning_node(state: GraphProcessingState, config=None):
106
+ # Define the system prompt for planning
107
+ planning_prompt = "Based on the user's idea, create a detailed step-by-step plan to build the DIY product."
108
+
109
+ # Combine the planning prompt with any existing prompts
110
+ if state.prompt:
111
+ final_prompt = "\n".join([planning_prompt, state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
112
+ else:
113
+ final_prompt = "\n".join([planning_prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
114
+
115
+ # Create the prompt template
116
+ prompt = ChatPromptTemplate.from_messages(
117
+ [
118
+ ("system", final_prompt),
119
+ MessagesPlaceholder(variable_name="messages"),
120
+ ]
121
+ )
122
+
123
+ # Bind tools if necessary
124
+ assistant_tools = []
125
+ if state.tools_enabled.get("download_website_text", True):
126
+ assistant_tools.append(download_website_text)
127
+ if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
128
+ assistant_tools.append(tavily_search_tool)
129
+ assistant_model = model.bind_tools(assistant_tools)
130
+
131
+ # Create the chain and invoke it
132
+ chain = prompt | assistant_model
133
+ response = await chain.ainvoke({"messages": state.messages}, config=config)
134
+
135
+ return {
136
+ "messages": response
137
+ }
138
+
139
+ async def assistant_node(state: GraphProcessingState, config=None):
140
+ assistant_tools = []
141
+ if state.tools_enabled.get("download_website_text", True):
142
+ assistant_tools.append(download_website_text)
143
+ if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
144
+ assistant_tools.append(tavily_search_tool)
145
+ assistant_model = model.bind_tools(assistant_tools)
146
+ if state.prompt:
147
+ final_prompt = "\n".join([state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
148
+ else:
149
+ final_prompt = ASSISTANT_SYSTEM_PROMPT_BASE
150
+
151
+ prompt = ChatPromptTemplate.from_messages(
152
+ [
153
+ ("system", final_prompt),
154
+ MessagesPlaceholder(variable_name="messages"),
155
+ ]
156
+ )
157
+ chain = prompt | assistant_model
158
+ response = await chain.ainvoke({"messages": state.messages}, config=config)
159
+
160
+ idea_complete = evaluate_idea_completion(response)
161
+
162
+ return {
163
+ "messages": response,
164
+ "idea_complete": idea_complete
165
+ }
166
+
167
+ async def chatbot(state: GraphProcessingState, config=None):
168
+ assistant_tools = []
169
+ if state.tools_enabled.get("download_website_text", True):
170
+ assistant_tools.append(download_website_text)
171
+ if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
172
+ assistant_tools.append(tavily_search_tool)
173
+ assistant_model = model.bind_tools(assistant_tools)
174
+ if state.prompt:
175
+ final_prompt = "\n".join([state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
176
+ else:
177
+ final_prompt = ASSISTANT_SYSTEM_PROMPT_BASE
178
+
179
+ prompt = ChatPromptTemplate.from_messages(
180
+ [
181
+ ("system", final_prompt),
182
+ MessagesPlaceholder(variable_name="messages"),
183
+ ]
184
+ )
185
+ chain = prompt | assistant_model
186
+
187
+ response = await chain.ainvoke({"messages": state.messages}, config=config)
188
+
189
+ # message = llm_with_tools.invoke(state["messages"])
190
+ # Because we will be interrupting during tool execution,
191
+ # we disable parallel tool calling to avoid repeating any
192
+ # tool invocations when we resume.
193
+ assert len(response.tool_calls) <= 1
194
+ idea_complete = evaluate_idea_completion(response)
195
+
196
+ return {
197
+ "messages": response,
198
+ "idea_complete": idea_complete
199
+ }
200
+
201
+ # def assistant_cond_edge(state: GraphProcessingState):
202
+ # last_message = state.messages[-1]
203
+ # if hasattr(last_message, "tool_calls") and last_message.tool_calls:
204
+ # logger.info(f"Tool call detected: {last_message.tool_calls}")
205
+ # return "tools"
206
+ # return END
207
+ def assistant_routing(state: GraphProcessingState) -> str:
208
+ last_message = state.messages[-1]
209
+ if hasattr(last_message, "tool_calls") and last_message.tool_calls:
210
+ logger.info("Tool call detected. Routing to 'tools' node.")
211
+ return "tools"
212
+ elif state.idea_complete:
213
+ logger.info("Idea is complete. Routing to 'planning_node'.")
214
+ return "planning_node"
215
+ else:
216
+ logger.info("Idea is incomplete. Routing back to 'assistant_node'.")
217
+ return "assistant_node"
218
+
219
+
220
+
221
+ def define_workflow() -> CompiledStateGraph:
222
+ """Defines the workflow graph"""
223
+ # Initialize the graph
224
+ workflow = StateGraph(GraphProcessingState)
225
+
226
+ # Add nodes
227
+ # workflow.add_node("assistant_node", assistant_node)
228
+ workflow.add_node("tools", ToolNode(tools))
229
+ # workflow.add_node("planning_node", planning_node)
230
+ workflow.add_node("chatbot", chatbot)
231
+
232
+ # Edges
233
+ # workflow.add_edge("tools", "assistant_node")
234
+ # workflow.add_edge("planning_node", "assistant_node")
235
+
236
+
237
+ workflow.add_conditional_edges(
238
+ "chatbot",
239
+ tools_condition,
240
+ )
241
+
242
+ workflow.add_edge("tools", "chatbot")
243
+ workflow.add_edge(START, "chatbot")
244
+
245
+ # Conditional routing
246
+ # workflow.add_conditional_edges("assistant_node", assistant_routing)
247
+
248
+ # # Set end nodes
249
+ # workflow.set_entry_point("assistant_node")
250
+ # workflow.set_finish_point("assistant_node")
251
+ compiled_graph = workflow.compile(checkpointer=memory)
252
+ try:
253
+ img_bytes = compiled_graph.get_graph().draw_mermaid_png()
254
+ with open("graph.png", "wb") as f:
255
+ f.write(img_bytes)
256
+ print("Graph image saved as graph.png")
257
+ except Exception as e:
258
+ print("Can't print the graph:")
259
+ print(e)
260
+
261
+
262
+ return compiled_graph
263
+
264
+ graph = define_workflow()