wishwakankanamg commited on
Commit
9b58257
·
1 Parent(s): d271024

almost there

Browse files
Files changed (5) hide show
  1. __pycache__/graph.cpython-310.pyc +0 -0
  2. app.log +0 -0
  3. graph.png +0 -0
  4. graph.py +75 -20
  5. oldgraph.py +0 -208
__pycache__/graph.cpython-310.pyc CHANGED
Binary files a/__pycache__/graph.cpython-310.pyc and b/__pycache__/graph.cpython-310.pyc differ
 
app.log CHANGED
The diff for this file is too large to render. See raw diff
 
graph.png CHANGED
graph.py CHANGED
@@ -101,7 +101,15 @@ class GraphProcessingState(BaseModel):
101
  prompt: str = Field(default_factory=str, description="The prompt to be used for the model")
102
  tools_enabled: dict = Field(default_factory=dict, description="The tools enabled for the assistant")
103
  search_enabled: bool = Field(default=True, description="Whether to enable search tools")
 
 
 
104
  idea_complete: bool = Field(default=False)
 
 
 
 
 
105
 
106
 
107
  async def planning_node(state: GraphProcessingState, config=None):
@@ -166,38 +174,63 @@ async def assistant_node(state: GraphProcessingState, config=None):
166
  "idea_complete": idea_complete
167
  }
168
 
169
- async def chatbot(state: GraphProcessingState, config=None):
170
- assistant_tools = []
171
- if state.tools_enabled.get("download_website_text", True):
172
- assistant_tools.append(download_website_text)
173
- if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
174
- assistant_tools.append(tavily_search_tool)
175
- assistant_model = model.bind_tools(assistant_tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  if state.prompt:
177
- final_prompt = "\n".join([state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
178
  else:
179
- final_prompt = ASSISTANT_SYSTEM_PROMPT_BASE
180
 
 
181
  prompt = ChatPromptTemplate.from_messages(
182
  [
183
  ("system", final_prompt),
184
  MessagesPlaceholder(variable_name="messages"),
185
  ]
186
  )
 
 
 
 
 
187
  chain = prompt | assistant_model
188
 
189
- response = await chain.ainvoke({"messages": state.messages}, config=config)
190
 
191
- # message = llm_with_tools.invoke(state["messages"])
192
- # Because we will be interrupting during tool execution,
193
- # we disable parallel tool calling to avoid repeating any
194
- # tool invocations when we resume.
195
- assert len(response.tool_calls) <= 1
196
- idea_complete = evaluate_idea_completion(response)
197
 
198
  return {
199
  "messages": response,
200
- "idea_complete": idea_complete
201
  }
202
 
203
  # def assistant_cond_edge(state: GraphProcessingState):
@@ -218,6 +251,22 @@ def assistant_routing(state: GraphProcessingState) -> str:
218
  logger.info("Idea is incomplete. Routing back to 'assistant_node'.")
219
  return "assistant_node"
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
 
223
  def define_workflow() -> CompiledStateGraph:
@@ -229,11 +278,16 @@ def define_workflow() -> CompiledStateGraph:
229
  workflow.add_node("assistant_node", assistant_node)
230
  workflow.add_node("tools", ToolNode(tools))
231
  workflow.add_node("planning_node", planning_node)
 
 
232
  # workflow.add_node("chatbot", chatbot)
233
 
234
  # Edges
235
- workflow.add_edge("tools", "assistant_node")
236
- workflow.add_edge("planning_node", "assistant_node")
 
 
 
237
 
238
 
239
  # workflow.add_conditional_edges(
@@ -245,7 +299,8 @@ def define_workflow() -> CompiledStateGraph:
245
  # workflow.add_edge(START, "chatbot")
246
 
247
  # Conditional routing
248
- workflow.add_conditional_edges("assistant_node", assistant_routing)
 
249
 
250
  # # Set end nodes
251
  workflow.set_entry_point("assistant_node")
 
101
  prompt: str = Field(default_factory=str, description="The prompt to be used for the model")
102
  tools_enabled: dict = Field(default_factory=dict, description="The tools enabled for the assistant")
103
  search_enabled: bool = Field(default=True, description="Whether to enable search tools")
104
+ next_stage: str = Field(default="", description="The next stage to execute, decided by the guidance node.")
105
+
106
+ # Completion flags for each stage
107
  idea_complete: bool = Field(default=False)
108
+ brainstorming_complete: bool = Field(default=False)
109
+ planning_complete: bool = Field(default=False)
110
+ drawing_complete: bool = Field(default=False)
111
+ product_searching_complete: bool = Field(default=False)
112
+ purchasing_complete: bool = Field(default=False)
113
 
114
 
115
  async def planning_node(state: GraphProcessingState, config=None):
 
174
  "idea_complete": idea_complete
175
  }
176
 
177
+ # message = llm_with_tools.invoke(state["messages"])
178
+ # Because we will be interrupting during tool execution,
179
+ # we disable parallel tool calling to avoid repeating any
180
+ # tool invocations when we resume.
181
+ assert len(response.tool_calls) <= 1
182
+ idea_complete = evaluate_idea_completion(response)
183
+
184
+ return {
185
+ "messages": response,
186
+ "idea_complete": idea_complete
187
+ }
188
+
189
+ async def guidance_node(state: GraphProcessingState, config=None):
190
+ # Define the guiding system prompt
191
+
192
+ # Prepare context: stage completion statuses
193
+ stage_order = ["brainstorming", "planning", "drawing", "product_searching", "purchasing"]
194
+ completed = [stage for stage in stage_order if getattr(state, f"{stage}_complete", False)]
195
+ incomplete = [stage for stage in stage_order if not getattr(state, f"{stage}_complete", False)]
196
+
197
+ status_summary = f"Completed stages: {completed}\nIncomplete stages: {incomplete}"
198
+
199
+ guidance_prompt = (
200
+ "You are the guiding assistant. Based on the user's input and the current project status, "
201
+ "help decide the next stage in the DIY process. Ask the user which stage to continue next if needed.\n\n"
202
+ f"CURRENT STATUS:\n{status_summary}\n\n" # <-- The information was moved here
203
+ "Available stages: brainstorming, planning, drawing, product_searching, purchasing."
204
+ )
205
+
206
+ # Build final prompt with base and current prompt
207
  if state.prompt:
208
+ final_prompt = "\n".join([guidance_prompt, state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
209
  else:
210
+ final_prompt = "\n".join([guidance_prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
211
 
212
+ # Create prompt template
213
  prompt = ChatPromptTemplate.from_messages(
214
  [
215
  ("system", final_prompt),
216
  MessagesPlaceholder(variable_name="messages"),
217
  ]
218
  )
219
+
220
+ # Bind tools if needed (minimal here, mostly guidance)
221
+ assistant_model = model.bind_tools([])
222
+
223
+ # Create the chain
224
  chain = prompt | assistant_model
225
 
226
+
227
 
228
+ # Get response from assistant
229
+ response = await chain.ainvoke({"messages": state.messages}, config=config)
 
 
 
 
230
 
231
  return {
232
  "messages": response,
233
+ "next_stage": incomplete[0] if incomplete else None
234
  }
235
 
236
  # def assistant_cond_edge(state: GraphProcessingState):
 
251
  logger.info("Idea is incomplete. Routing back to 'assistant_node'.")
252
  return "assistant_node"
253
 
254
+ def guidance_routing(state: GraphProcessingState) -> str:
255
+ next_stage = state.get("next_stage")
256
+
257
+ if next_stage == "brainstorming":
258
+ return "assistant_node"
259
+ elif next_stage == "planning":
260
+ return "planning_node"
261
+ elif next_stage == "drawing":
262
+ return "tools"
263
+ elif next_stage == "product_searching":
264
+ return "assistant_node"
265
+ elif next_stage == "purchasing":
266
+ return "assistant_node"
267
+
268
+ return END
269
+
270
 
271
 
272
  def define_workflow() -> CompiledStateGraph:
 
278
  workflow.add_node("assistant_node", assistant_node)
279
  workflow.add_node("tools", ToolNode(tools))
280
  workflow.add_node("planning_node", planning_node)
281
+ workflow.add_node("guidance_node", guidance_node)
282
+
283
  # workflow.add_node("chatbot", chatbot)
284
 
285
  # Edges
286
+ workflow.add_edge("tools", "guidance_node")
287
+ # workflow.add_edge("planning_node", "assistant_node")
288
+ workflow.add_edge("planning_node", "guidance_node")
289
+ workflow.add_edge("assistant_node", "guidance_node")
290
+
291
 
292
 
293
  # workflow.add_conditional_edges(
 
299
  # workflow.add_edge(START, "chatbot")
300
 
301
  # Conditional routing
302
+ # workflow.add_conditional_edges("assistant_node", assistant_routing)
303
+ workflow.add_conditional_edges("guidance_node", guidance_routing)
304
 
305
  # # Set end nodes
306
  workflow.set_entry_point("assistant_node")
oldgraph.py DELETED
@@ -1,208 +0,0 @@
1
- import logging
2
- import os
3
- from typing import Annotated
4
- from typing_extensions import TypedDict
5
- from pydantic import model_validator
6
-
7
- import aiohttp
8
- from langchain_core.messages import AnyMessage
9
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
10
- from langchain_core.tools import tool
11
- from langgraph.graph.state import CompiledStateGraph
12
- from langgraph.prebuilt import ToolNode, tools_condition
13
- from langgraph.graph import StateGraph, START, END, add_messages
14
- from langchain_community.tools import TavilySearchResults
15
- from pydantic import BaseModel, Field
16
- from trafilatura import extract
17
- from langchain_anthropic import ChatAnthropic
18
- from langgraph.prebuilt import ToolNode
19
- from langgraph.checkpoint.memory import MemorySaver
20
- from langgraph.types import Command, interrupt
21
-
22
-
23
-
24
-
25
- logger = logging.getLogger(__name__)
26
- ASSISTANT_SYSTEM_PROMPT_BASE = """"""
27
- search_enabled = bool(os.environ.get("TAVILY_API_KEY"))
28
-
29
- def evaluate_idea_completion(response) -> bool:
30
- """
31
- Evaluates whether the assistant's response indicates a complete DIY project idea.
32
- You can customize the logic based on your specific criteria.
33
- """
34
- # Example logic: Check if the response contains certain keywords
35
- required_keywords = ["materials", "dimensions", "tools", "steps"]
36
-
37
- # Determine the type of response and extract text accordingly
38
- if isinstance(response, dict):
39
- # If response is a dictionary, extract values and join them into a single string
40
- response_text = ' '.join(str(value).lower() for value in response.values())
41
- elif isinstance(response, str):
42
- # If response is a string, convert it to lowercase
43
- response_text = response.lower()
44
- else:
45
- # If response is of an unexpected type, convert it to string and lowercase
46
- response_text = str(response).lower()
47
-
48
- return all(keyword in response_text for keyword in required_keywords)
49
-
50
- @tool
51
- def human_assistance(query: str) -> str:
52
- """Request assistance from a human."""
53
- human_response = interrupt({"query": query})
54
- return human_response["data"]
55
-
56
- @tool
57
- async def download_website_text(url: str) -> str:
58
- """Download the text from a website"""
59
- try:
60
- async with aiohttp.ClientSession() as session:
61
- async with session.get(url) as response:
62
- response.raise_for_status()
63
- downloaded = await response.text()
64
- result = extract(downloaded, include_formatting=True, include_links=True, output_format='json', with_metadata=True)
65
- return result or "No text found on the website"
66
- except Exception as e:
67
- logger.error(f"Failed to download {url}: {str(e)}")
68
- return f"Error retrieving website content: {str(e)}"
69
-
70
- tools = [download_website_text, human_assistance]
71
- memory = MemorySaver()
72
-
73
-
74
- if search_enabled:
75
- tavily_search_tool = TavilySearchResults(
76
- max_results=5,
77
- search_depth="advanced",
78
- include_answer=True,
79
- include_raw_content=True,
80
- )
81
- tools.append(tavily_search_tool)
82
- else:
83
- print("TAVILY_API_KEY environment variable not found. Websearch disabled")
84
-
85
- weak_model = ChatAnthropic(
86
- model="claude-3-5-sonnet-20240620",
87
- temperature=0,
88
- max_tokens=1024,
89
- timeout=None,
90
- max_retries=2,
91
- # other params...
92
- )
93
- model = weak_model
94
- assistant_model = weak_model
95
-
96
- class GraphProcessingState(BaseModel):
97
- # user_input: str = Field(default_factory=str, description="The original user input")
98
- messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list)
99
- prompt: str = Field(default_factory=str, description="The prompt to be used for the model")
100
- tools_enabled: dict = Field(default_factory=dict, description="The tools enabled for the assistant")
101
- search_enabled: bool = Field(default=True, description="Whether to enable search tools")
102
- idea_complete: bool = Field(default=False)
103
-
104
- @model_validator(mode="after")
105
- def remove_empty_messages(self):
106
- """
107
- Filters out messages with empty content from the messages list.
108
- """
109
- if self.messages:
110
- filtered_messages = []
111
- for msg in self.messages:
112
- # Keep the message if its content is not empty.
113
- # This handles str content, and list content (for tool calls)
114
- if msg.content:
115
- filtered_messages.append(msg)
116
- self.messages = filtered_messages
117
- return self
118
-
119
-
120
-
121
-
122
-
123
- async def chatbot(state: GraphProcessingState, config=None):
124
- assistant_tools = []
125
- if state.tools_enabled.get("download_website_text", True):
126
- assistant_tools.append(download_website_text)
127
- if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
128
- assistant_tools.append(tavily_search_tool)
129
- assistant_model = model.bind_tools(assistant_tools)
130
- if state.prompt:
131
- final_prompt = "\n".join([state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
132
- else:
133
- final_prompt = ASSISTANT_SYSTEM_PROMPT_BASE
134
-
135
- prompt = ChatPromptTemplate.from_messages(
136
- [
137
- ("system", final_prompt),
138
- MessagesPlaceholder(variable_name="messages"),
139
- ]
140
- )
141
- chain = prompt | assistant_model
142
-
143
- valid_messages = [
144
- msg for msg in state.messages if (isinstance(msg.content, str) and msg.content.strip()) or (isinstance(msg.content, list) and msg.content)
145
- ]
146
-
147
- # Crucial Debugging Step: If the error persists, print the messages
148
- # to see exactly what is being sent.
149
- # print("Messages being sent to model:", valid_messages)
150
-
151
- if not valid_messages:
152
- # Handle the case where all messages are filtered out
153
- # You might want to return a default response or raise an error
154
- # For now, we can create a dummy response or just return.
155
- return {"messages": [AIMessage(content="I have no valid input to respond to.")]}
156
-
157
-
158
- response = await chain.ainvoke({"messages": valid_messages}, config=config)
159
-
160
- # message = llm_with_tools.invoke(state["messages"])
161
- # Because we will be interrupting during tool execution,
162
- # we disable parallel tool calling to avoid repeating any
163
- # tool invocations when we resume.
164
- assert len(response.tool_calls) <= 1
165
- idea_complete = evaluate_idea_completion(response)
166
-
167
- return {
168
- "messages": response,
169
- "idea_complete": idea_complete
170
- }
171
-
172
-
173
-
174
-
175
- def define_workflow() -> CompiledStateGraph:
176
- """Defines the workflow graph"""
177
- # Initialize the graph
178
- workflow = StateGraph(GraphProcessingState)
179
-
180
- # Add nodes
181
- workflow.add_node("tools", ToolNode(tools))
182
- workflow.add_node("chatbot", chatbot)
183
-
184
-
185
-
186
- workflow.add_conditional_edges(
187
- "chatbot",
188
- tools_condition,
189
- )
190
-
191
- workflow.add_edge("tools", "chatbot")
192
- workflow.add_edge(START, "chatbot")
193
-
194
- compiled_graph = workflow.compile(checkpointer=memory)
195
- try:
196
- img_bytes = compiled_graph.get_graph().draw_mermaid_png()
197
- with open("graph.png", "wb") as f:
198
- f.write(img_bytes)
199
- print("Graph image saved as graph.png")
200
- except Exception as e:
201
- print("Can't print the graph:")
202
- print(e)
203
-
204
-
205
- return compiled_graph
206
-
207
- graph = define_workflow()
208
-