wishwakankanamg commited on
Commit
5992e45
·
1 Parent(s): 86a0d42
Files changed (3) hide show
  1. app.py +528 -90
  2. graph.py +116 -0
  3. notapp.py +120 -0
app.py CHANGED
@@ -1,120 +1,558 @@
1
  import os
 
 
 
 
 
 
2
  import gradio as gr
3
- from huggingface_hub import InferenceClient
4
  from langchain_core.messages import HumanMessage
5
- from langchain.agents import AgentExecutor
6
- from agent import build_graph
7
- from PIL import Image
8
 
9
- """
10
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
11
- """
12
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
13
 
 
 
14
 
15
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
 
16
 
17
- class BasicAgent:
18
- """A langgraph agent."""
19
- def __init__(self):
20
- print("BasicAgent initialized.")
21
- self.graph = build_graph()
22
 
23
- def __call__(self, question: str) -> str:
24
- print(f"Agent received question (first 50 chars): {question[:50]}...")
25
- # Wrap the question in a HumanMessage from langchain_core
26
- messages = [HumanMessage(content=question)]
27
- config = {"recursion_limit": 27}
28
 
29
- messages = self.graph.invoke({"messages": messages}, config=config)
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- answer = messages['messages'][-1].content
32
- return answer[14:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- try:
35
- agent = BasicAgent()
36
- except Exception as e:
37
- print(f"Error instantiating agent: {e}")
38
 
39
- def show_graph():
40
- if not os.path.exists("graph.png"):
41
- return None
42
- return Image.open("graph.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- config = {"configurable": {"thread_id": "1"}}
 
45
 
 
 
 
46
 
47
- def run_langgraph_agent(user_input: str):
48
- graph = build_graph()
49
- result = graph.stream(
50
- {"messages": [HumanMessage(content=user_input)]},
51
- config,
52
- stream_mode="values",
 
 
 
 
 
53
  )
54
- return result["messages"][-1].content if "messages" in result else result
 
 
 
 
 
 
 
 
 
 
 
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- demo = gr.Interface(
58
- fn=run_langgraph_agent,
59
- inputs=gr.Textbox(lines=2, placeholder="Enter your message..."),
60
- outputs="text",
61
- title="LangGraph Agent Chat",
62
- )
 
 
 
 
 
 
 
 
 
 
63
 
64
- if __name__ == "__main__":
65
- demo.launch()
 
66
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
 
 
 
 
 
 
 
 
69
 
70
- # def respond(
71
- # message,
72
- # history: list[tuple[str, str]],
73
- # system_message,
74
- # max_tokens,
75
- # temperature,
76
- # top_p,
77
- # ):
78
- # messages = [{"role": "system", "content": system_message}]
79
 
80
- # for val in history:
81
- # if val[0]:
82
- # messages.append({"role": "user", "content": val[0]})
83
- # if val[1]:
84
- # messages.append({"role": "assistant", "content": val[1]})
 
 
 
 
 
 
85
 
86
- # messages.append({"role": "user", "content": message})
 
 
 
 
 
87
 
88
- # response = ""
 
 
89
 
90
- # for message in client.chat_completion(
91
- # messages,
92
- # max_tokens=max_tokens,
93
- # stream=True,
94
- # temperature=temperature,
95
- # top_p=top_p,
96
- # ):
97
- # token = message.choices[0].delta.content
98
 
99
- # response += token
100
- # yield response
 
 
101
 
 
 
102
 
103
- """
104
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
105
- """
106
- # demo = gr.ChatInterface(
107
- # respond,
108
- # additional_inputs=[
109
- # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
110
- # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
111
- # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
112
- # gr.Slider(
113
- # minimum=0.1,
114
- # maximum=1.0,
115
- # value=0.95,
116
- # step=0.05,
117
- # label="Top-p (nucleus sampling)",
118
- # ),
119
- # ],
120
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import logging
3
+ import logging.config
4
+ from typing import Any
5
+ from uuid import uuid4, UUID
6
+ import json
7
+
8
  import gradio as gr
9
+ from dotenv import load_dotenv
10
  from langchain_core.messages import HumanMessage
11
+ from langgraph.types import RunnableConfig
12
+ from pydantic import BaseModel
 
13
 
14
+ load_dotenv()
 
 
 
15
 
16
# There are tools set here dependent on environment variables
from graph import graph, weak_model, search_enabled  # noqa

# Number of follow-up suggestion buttons rendered under the chat.
FOLLOWUP_QUESTION_NUMBER = 3
TRIM_MESSAGE_LENGTH = 16  # Includes tool messages
USER_INPUT_MAX_LENGTH = 10000  # Characters

# We need the same secret for data persistence
# If you store sensitive data, you should store your secret in .env
# NOTE(review): the secret is hard-coded, so browser-persisted state is not
# actually protected -- move it to an environment variable for real secrets.
BROWSER_STORAGE_SECRET = "itsnosecret"

# Configure logging from the JSON config shipped with the app; this raises at
# import time if logging-config.json is missing.
with open('logging-config.json', 'r') as fh:
    config = json.load(fh)
logging.config.dictConfig(config)
logger = logging.getLogger(__name__)
 
31
 
32
async def chat_fn(user_input: str, history: dict, input_graph_state: dict, uuid: UUID, prompt: str, search_enabled: bool, download_website_text_enabled: bool):
    """
    Stream one chat turn through the langgraph graph.

    Args:
        user_input (str): The user's input message
        history (dict): The history of the conversation in gradio (unused; see note below)
        input_graph_state (dict): The current state of the graph. This includes tool call history
        uuid (UUID): The unique identifier for the current conversation. This can be used in conjunction with langgraph or for memory
        prompt (str): The system prompt
        search_enabled (bool): Whether the tavily search tool may run this turn
        download_website_text_enabled (bool): Whether the url-download tool may run this turn
    Yields:
        str: The output message
        dict|Any: The final state of the graph
        bool|Any: Whether to trigger follow up questions

    We do not use gradio history in the graph since we want the ToolMessage in the history
    ordered properly. GraphProcessingState.messages is used as history instead
    """
    try:
        logger.info(f"Prompt: {prompt}")
        # Per-turn tool toggles; read by assistant_node in graph.py.
        input_graph_state["tools_enabled"] = {
            "download_website_text": download_website_text_enabled,
            "tavily_search_results_json": search_enabled,
        }
        if prompt:
            input_graph_state["prompt"] = prompt
        if "messages" not in input_graph_state:
            input_graph_state["messages"] = []
        # Cap the user's message length, then keep only the newest messages.
        input_graph_state["messages"].append(
            HumanMessage(user_input[:USER_INPUT_MAX_LENGTH])
        )
        input_graph_state["messages"] = input_graph_state["messages"][-TRIM_MESSAGE_LENGTH:]
        config = RunnableConfig(
            recursion_limit=20,
            run_name="user_chat",
            configurable={"thread_id": uuid}
        )

        output: str = ""
        final_state: dict | Any = {}
        waiting_output_seq: list[str] = []

        # "values" yields full state snapshots (used to detect tool calls);
        # "messages" yields token chunks for incremental display.
        async for stream_mode, chunk in graph.astream(
            input_graph_state,
            config=config,
            stream_mode=["values", "messages"],
        ):
            if stream_mode == "values":
                final_state = chunk
                last_message = chunk["messages"][-1]
                # Show a progress line for each tool the assistant invoked.
                if hasattr(last_message, "tool_calls"):
                    for msg_tool_call in last_message.tool_calls:
                        tool_name: str = msg_tool_call['name']
                        if tool_name == "tavily_search_results_json":
                            query = msg_tool_call['args']['query']
                            waiting_output_seq.append(f"Searching for '{query}'...")
                            yield "\n".join(waiting_output_seq), gr.skip(), gr.skip()
                        # download_website_text is the name of the function defined in graph.py
                        elif tool_name == "download_website_text":
                            url = msg_tool_call['args']['url']
                            waiting_output_seq.append(f"Downloading text from '{url}'...")
                            yield "\n".join(waiting_output_seq), gr.skip(), gr.skip()
                        else:
                            waiting_output_seq.append(f"Running {tool_name}...")
                            yield "\n".join(waiting_output_seq), gr.skip(), gr.skip()
            elif stream_mode == "messages":
                msg, metadata = chunk
                # assistant_node is the name we defined in the langgraph graph
                if metadata['langgraph_node'] == "assistant_node" and msg.content:
                    output += msg.content
                    yield output, gr.skip(), gr.skip()
        # Trigger for asking follow up questions
        # + store the graph state for next iteration
        # NOTE(review): the trailing space presumably makes the final value
        # differ from the last streamed one so gradio emits a change -- confirm.
        yield output + " ", dict(final_state), True
    except Exception:
        logger.exception("Exception occurred")
        user_error_message = "There was an error processing your request. Please try again."
        yield user_error_message, gr.skip(), False
110
 
111
def clear():
    """Reset the conversation: empty graph state plus a brand-new chat uuid."""
    fresh_state = {}
    return fresh_state, uuid4()
113
 
114
class FollowupQuestions(BaseModel):
    """Model for langchain to use for structured output for followup questions"""
    # Callers expect exactly FOLLOWUP_QUESTION_NUMBER entries; see
    # populate_followup_questions, which raises otherwise.
    questions: list[str]
117
 
118
async def populate_followup_questions(end_of_chat_response: bool, messages: list[dict[str, str]], uuid: UUID):
    """
    Generate FOLLOWUP_QUESTION_NUMBER suggested questions as sidebar buttons.

    This function gets called a lot due to the asynchronous nature of streaming

    Only populate followup questions if streaming has completed and the message is coming from the assistant
    """
    if not end_of_chat_response or not messages or messages[-1]["role"] != "assistant":
        # Leave every button untouched and reset the end-of-response flag.
        return *[gr.skip() for _ in range(FOLLOWUP_QUESTION_NUMBER)], False
    config = RunnableConfig(
        run_name="populate_followup_questions",
        configurable={"thread_id": uuid}
    )
    weak_model_with_config = weak_model.with_config(config)
    # Structured output guarantees a FollowupQuestions instance (not raw text).
    follow_up_questions = await weak_model_with_config.with_structured_output(FollowupQuestions).ainvoke([
        ("system", f"suggest {FOLLOWUP_QUESTION_NUMBER} followup questions for the user to ask the assistant. Refrain from asking personal questions."),
        *messages,
    ])
    if len(follow_up_questions.questions) != FOLLOWUP_QUESTION_NUMBER:
        raise ValueError("Invalid value of followup questions")
    buttons = []
    for i in range(FOLLOWUP_QUESTION_NUMBER):
        buttons.append(
            gr.Button(follow_up_questions.questions[i], visible=True, elem_classes="chat-tab"),
        )
    return *buttons, False
143
 
144
async def summarize_chat(end_of_chat_response: bool, messages: list[dict], sidebar_summaries: dict, uuid: UUID):
    """Summarize chat for tab names.

    Runs only once per conversation, after the assistant's response has
    finished streaming; otherwise returns skips.
    """
    should_return = (
        not end_of_chat_response or
        not messages or
        messages[-1]["role"] != "assistant" or
        # This is a bug with gradio: the state sometimes arrives as a
        # function object instead of a dict -- skip those invocations.
        isinstance(sidebar_summaries, type(lambda x: x)) or
        # Already created summary
        uuid in sidebar_summaries
    )
    if should_return:
        return gr.skip(), gr.skip()
    config = RunnableConfig(
        run_name="summarize_chat",
        configurable={"thread_id": uuid}
    )
    weak_model_with_config = weak_model.with_config(config)
    summary_response = await weak_model_with_config.ainvoke([
        ("system", "summarize this chat in 7 tokens or less. Refrain from using periods"),
        *messages,
    ])
    if uuid not in sidebar_summaries:
        sidebar_summaries[uuid] = summary_response.content
    # Second value resets end_of_assistant_response_state.
    return sidebar_summaries, False
176
 
177
async def new_tab(uuid, gradio_graph, messages, tabs, prompt, sidebar_summaries):
    """Archive the current chat into *tabs* and start a fresh conversation.

    Returns the new uuid, empty graph state, empty messages, the updated tab
    store, a default prompt, the updated summaries, and one hidden follow-up
    button per slot.
    """
    new_uuid = uuid4()
    new_graph = {}
    # Give the outgoing chat a sidebar summary before archiving it.
    if uuid not in sidebar_summaries:
        sidebar_summaries, _ = await summarize_chat(True, messages, sidebar_summaries, uuid)
    tabs[uuid] = {
        "graph": gradio_graph,
        "messages": messages,
        "prompt": prompt,
    }
    suggestion_buttons = []
    for _ in range(FOLLOWUP_QUESTION_NUMBER):
        suggestion_buttons.append(gr.Button(visible=False))
    new_messages = {}
    new_prompt = "You are a helpful assistant."
    return new_uuid, new_graph, new_messages, tabs, new_prompt, sidebar_summaries, *suggestion_buttons
193
 
194
def switch_tab(selected_uuid, tabs, gradio_graph, uuid, messages, prompt):
    """Make the chat identified by *selected_uuid* the active conversation.

    The current chat (graph state, messages, prompt) is archived into *tabs*
    first so nothing is lost when switching away.

    Returns:
        (graph, uuid, messages, tabs, prompt, *hidden follow-up buttons) --
        always 5 + FOLLOWUP_QUESTION_NUMBER values, matching the click
        handler's output components.
    """
    # I don't know of another way to lookup uuid other than
    # by the button value

    # Save current state before switching away
    if messages:
        tabs[uuid] = {
            "graph": gradio_graph,
            "messages": messages,
            "prompt": prompt
        }

    if selected_uuid not in tabs:
        logger.error(f"Could not find the selected tab in offloaded_tabs_data_storage {selected_uuid}")
        # Bug fix: this handler is wired to 5 components plus one per
        # follow-up button; the old code returned only 4 skips here, which
        # mismatched gradio's output mapping on the error path.
        return tuple(gr.skip() for _ in range(5 + FOLLOWUP_QUESTION_NUMBER))
    selected_tab_state = tabs[selected_uuid]
    selected_graph = selected_tab_state["graph"]
    selected_messages = selected_tab_state["messages"]
    selected_prompt = selected_tab_state.get("prompt", "")
    # Hide any stale follow-up suggestions from the previous chat.
    suggestion_buttons = []
    for _ in range(FOLLOWUP_QUESTION_NUMBER):
        suggestion_buttons.append(gr.Button(visible=False))
    return selected_graph, selected_uuid, selected_messages, tabs, selected_prompt, *suggestion_buttons
217
 
218
def delete_tab(current_chat_uuid, selected_uuid, sidebar_summaries, tabs):
    """Delete a chat tab; clear the chat window if it was the active one."""
    # Only wipe the visible chatbot when the deleted tab is the open chat;
    # otherwise leave the chat window untouched.
    cleared_messages = gr.skip()
    if current_chat_uuid == selected_uuid:
        cleared_messages = dict()
    # Drop the tab's archived state and its sidebar summary, if present.
    tabs.pop(selected_uuid, None)
    sidebar_summaries.pop(selected_uuid, None)
    return sidebar_summaries, tabs, cleared_messages
227
 
228
def submit_edit_tab(selected_uuid, sidebar_summaries, text):
    """Rename a sidebar tab: store *text* as its summary and leave edit mode."""
    sidebar_summaries[selected_uuid] = text
    # Returning "" resets tab_edit_uuid_state, flipping the textbox back
    # into a button on the next sidebar render.
    return sidebar_summaries, ""
 
 
 
 
 
 
231
 
232
+ CSS = """
233
+ footer {visibility: hidden}
234
+ .followup-question-button {font-size: 12px }
235
+ .chat-tab {
236
+ font-size: 12px;
237
+ padding-inline: 0;
238
+ }
239
+ .chat-tab.active {
240
+ background-color: #654343;
241
+ }
242
+ #new-chat-button { background-color: #0f0f11; color: white; }
243
 
244
+ .tab-button-control {
245
+ min-width: 0;
246
+ padding-left: 0;
247
+ padding-right: 0;
248
+ }
249
+ """
250
 
251
+ # We set the ChatInterface textbox id to chat-textbox for this to work
252
+ TRIGGER_CHATINTERFACE_BUTTON = """
253
+ function triggerChatButtonClick() {
254
 
255
+ // Find the div with id "chat-textbox"
256
+ const chatTextbox = document.getElementById("chat-textbox");
 
 
 
 
 
 
257
 
258
+ if (!chatTextbox) {
259
+ console.error("Error: Could not find element with id 'chat-textbox'");
260
+ return;
261
+ }
262
 
263
+ // Find the button that is a descendant of the div
264
+ const button = chatTextbox.querySelector("button");
265
 
266
+ if (!button) {
267
+ console.error("Error: No button found inside the chat-textbox element");
268
+ return;
269
+ }
270
+
271
+ // Trigger the click event
272
+ button.click();
273
+ }"""
274
+
275
+ if __name__ == "__main__":
276
+ logger.info("Starting the interface")
277
+ with gr.Blocks(title="Langgraph Template", fill_height=True, css=CSS) as app:
278
+ current_prompt_state = gr.BrowserState(
279
+ storage_key="current_prompt_state",
280
+ secret=BROWSER_STORAGE_SECRET,
281
+ )
282
+ current_uuid_state = gr.BrowserState(
283
+ uuid4,
284
+ storage_key="current_uuid_state",
285
+ secret=BROWSER_STORAGE_SECRET,
286
+ )
287
+ current_langgraph_state = gr.BrowserState(
288
+ dict(),
289
+ storage_key="current_langgraph_state",
290
+ secret=BROWSER_STORAGE_SECRET,
291
+ )
292
+ end_of_assistant_response_state = gr.State(
293
+ bool(),
294
+ )
295
+ # [uuid] -> summary of chat
296
+ sidebar_names_state = gr.BrowserState(
297
+ dict(),
298
+ storage_key="sidebar_names_state",
299
+ secret=BROWSER_STORAGE_SECRET,
300
+ )
301
+ # [uuid] -> {"graph": gradio_graph, "messages": messages}
302
+ offloaded_tabs_data_storage = gr.BrowserState(
303
+ dict(),
304
+ storage_key="offloaded_tabs_data_storage",
305
+ secret=BROWSER_STORAGE_SECRET,
306
+ )
307
+
308
+ chatbot_message_storage = gr.BrowserState(
309
+ [],
310
+ storage_key="chatbot_message_storage",
311
+ secret=BROWSER_STORAGE_SECRET,
312
+ )
313
+ with gr.Column():
314
+ prompt_textbox = gr.Textbox(show_label=False, interactive=True)
315
+ with gr.Row():
316
+ checkbox_search_enabled = gr.Checkbox(
317
+ value=True,
318
+ label="Enable search",
319
+ show_label=True,
320
+ visible=search_enabled,
321
+ scale=1,
322
+ )
323
+ checkbox_download_website_text = gr.Checkbox(
324
+ value=True,
325
+ show_label=True,
326
+ label="Enable downloading text from urls",
327
+ scale=1,
328
+ )
329
+ chatbot = gr.Chatbot(
330
+ type="messages",
331
+ scale=1,
332
+ show_copy_button=True,
333
+ height=600,
334
+ editable="all",
335
+ )
336
+ tab_edit_uuid_state = gr.State(
337
+ str()
338
+ )
339
+ prompt_textbox.change(lambda prompt: prompt, inputs=[prompt_textbox], outputs=[current_prompt_state])
340
+ with gr.Sidebar() as sidebar:
341
+ @gr.render(inputs=[tab_edit_uuid_state, end_of_assistant_response_state, sidebar_names_state, current_uuid_state, chatbot, offloaded_tabs_data_storage])
342
+ def render_chats(tab_uuid_edit, end_of_chat_response, sidebar_summaries, active_uuid, messages, tabs):
343
+ current_tab_button_text = ""
344
+ if active_uuid not in sidebar_summaries:
345
+ current_tab_button_text = "Current Chat"
346
+ elif active_uuid not in tabs:
347
+ current_tab_button_text = sidebar_summaries[active_uuid]
348
+ if current_tab_button_text:
349
+ gr.Button(current_tab_button_text, elem_classes=["chat-tab", "active"])
350
+ for chat_uuid, tab in reversed(tabs.items()):
351
+ elem_classes = ["chat-tab"]
352
+ if chat_uuid == active_uuid:
353
+ elem_classes.append("active")
354
+ button_uuid_state = gr.State(chat_uuid)
355
+ with gr.Row():
356
+ clear_tab_button = gr.Button(
357
+ "🗑",
358
+ scale=0,
359
+ elem_classes=["tab-button-control"]
360
+ )
361
+ clear_tab_button.click(
362
+ fn=delete_tab,
363
+ inputs=[
364
+ current_uuid_state,
365
+ button_uuid_state,
366
+ sidebar_names_state,
367
+ offloaded_tabs_data_storage
368
+ ],
369
+ outputs=[
370
+ sidebar_names_state,
371
+ offloaded_tabs_data_storage,
372
+ chat_interface.chatbot_value
373
+ ]
374
+ )
375
+ chat_button_text = sidebar_summaries.get(chat_uuid)
376
+ if not chat_button_text:
377
+ chat_button_text = str(chat_uuid)
378
+ if chat_uuid != tab_uuid_edit:
379
+ set_edit_tab_button = gr.Button(
380
+ "✎",
381
+ scale=0,
382
+ elem_classes=["tab-button-control"]
383
+ )
384
+ set_edit_tab_button.click(
385
+ fn=lambda x: x,
386
+ inputs=[button_uuid_state],
387
+ outputs=[tab_edit_uuid_state]
388
+ )
389
+ chat_tab_button = gr.Button(
390
+ chat_button_text,
391
+ elem_id=f"chat-{chat_uuid}-button",
392
+ elem_classes=elem_classes,
393
+ scale=2
394
+ )
395
+ chat_tab_button.click(
396
+ fn=switch_tab,
397
+ inputs=[
398
+ button_uuid_state,
399
+ offloaded_tabs_data_storage,
400
+ current_langgraph_state,
401
+ current_uuid_state,
402
+ chatbot,
403
+ prompt_textbox
404
+ ],
405
+ outputs=[
406
+ current_langgraph_state,
407
+ current_uuid_state,
408
+ chat_interface.chatbot_value,
409
+ offloaded_tabs_data_storage,
410
+ prompt_textbox,
411
+ *followup_question_buttons
412
+ ]
413
+ )
414
+ else:
415
+ chat_tab_text = gr.Textbox(
416
+ chat_button_text,
417
+ scale=2,
418
+ interactive=True,
419
+ show_label=False
420
+ )
421
+ chat_tab_text.submit(
422
+ fn=submit_edit_tab,
423
+ inputs=[
424
+ button_uuid_state,
425
+ sidebar_names_state,
426
+ chat_tab_text
427
+ ],
428
+ outputs=[
429
+ sidebar_names_state,
430
+ tab_edit_uuid_state
431
+ ]
432
+ )
433
+ # )
434
+ # return chat_tabs, sidebar_summaries
435
+ new_chat_button = gr.Button("New Chat", elem_id="new-chat-button")
436
+ chatbot.clear(fn=clear, outputs=[current_langgraph_state, current_uuid_state])
437
+ with gr.Row():
438
+ followup_question_buttons = []
439
+ for i in range(FOLLOWUP_QUESTION_NUMBER):
440
+ btn = gr.Button(f"Button {i+1}", visible=False)
441
+ followup_question_buttons.append(btn)
442
+
443
+ multimodal = False
444
+ textbox_component = (
445
+ gr.MultimodalTextbox if multimodal else gr.Textbox
446
+ )
447
+ with gr.Column():
448
+ textbox = textbox_component(
449
+ show_label=False,
450
+ label="Message",
451
+ placeholder="Type a message...",
452
+ scale=7,
453
+ autofocus=True,
454
+ submit_btn=True,
455
+ stop_btn=True,
456
+ elem_id="chat-textbox",
457
+ lines=1,
458
+ )
459
+ chat_interface = gr.ChatInterface(
460
+ chatbot=chatbot,
461
+ fn=chat_fn,
462
+ additional_inputs=[
463
+ current_langgraph_state,
464
+ current_uuid_state,
465
+ prompt_textbox,
466
+ checkbox_search_enabled,
467
+ checkbox_download_website_text,
468
+ ],
469
+ additional_outputs=[
470
+ current_langgraph_state,
471
+ end_of_assistant_response_state
472
+ ],
473
+ type="messages",
474
+ multimodal=multimodal,
475
+ textbox=textbox,
476
+ )
477
+
478
+ new_chat_button.click(
479
+ new_tab,
480
+ inputs=[
481
+ current_uuid_state,
482
+ current_langgraph_state,
483
+ chatbot,
484
+ offloaded_tabs_data_storage,
485
+ prompt_textbox,
486
+ sidebar_names_state,
487
+ ],
488
+ outputs=[
489
+ current_uuid_state,
490
+ current_langgraph_state,
491
+ chat_interface.chatbot_value,
492
+ offloaded_tabs_data_storage,
493
+ prompt_textbox,
494
+ sidebar_names_state,
495
+ *followup_question_buttons,
496
+ ]
497
+ )
498
+
499
+ def click_followup_button(btn):
500
+ buttons = [gr.Button(visible=False) for _ in range(len(followup_question_buttons))]
501
+ return btn, *buttons
502
+ for btn in followup_question_buttons:
503
+ btn.click(
504
+ fn=click_followup_button,
505
+ inputs=[btn],
506
+ outputs=[
507
+ chat_interface.textbox,
508
+ *followup_question_buttons
509
+ ]
510
+ ).success(lambda: None, js=TRIGGER_CHATINTERFACE_BUTTON)
511
+
512
+ chatbot.change(
513
+ fn=populate_followup_questions,
514
+ inputs=[
515
+ end_of_assistant_response_state,
516
+ chatbot,
517
+ current_uuid_state
518
+ ],
519
+ outputs=[
520
+ *followup_question_buttons,
521
+ end_of_assistant_response_state
522
+ ],
523
+ trigger_mode="multiple"
524
+ )
525
+ chatbot.change(
526
+ fn=summarize_chat,
527
+ inputs=[
528
+ end_of_assistant_response_state,
529
+ chatbot,
530
+ sidebar_names_state,
531
+ current_uuid_state
532
+ ],
533
+ outputs=[
534
+ sidebar_names_state,
535
+ end_of_assistant_response_state
536
+ ],
537
+ trigger_mode="multiple"
538
+ )
539
+ chatbot.change(
540
+ fn=lambda x: x,
541
+ inputs=[chatbot],
542
+ outputs=[chatbot_message_storage],
543
+ trigger_mode="always_last"
544
+ )
545
+
546
+ @app.load(inputs=[chatbot_message_storage], outputs=[chat_interface.chatbot_value])
547
+ def load_messages(messages):
548
+ return messages
549
+
550
+ @app.load(inputs=[current_prompt_state], outputs=[prompt_textbox])
551
+ def load_prompt(current_prompt):
552
+ return current_prompt
553
+
554
+ app.launch(
555
+ server_name="127.0.0.1",
556
+ server_port=int(os.getenv("GRADIO_SERVER_PORT", 7860)),
557
+ # favicon_path="assets/favicon.ico"
558
+ )
graph.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import Annotated
4
+
5
+ import aiohttp
6
+ from langchain_core.messages import AnyMessage
7
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
8
+ from langchain_core.tools import tool
9
+ from langchain_openai import ChatOpenAI
10
+ from langgraph.graph.state import CompiledStateGraph
11
+ from langgraph.prebuilt import ToolNode
12
+ from langgraph.graph import StateGraph, END, add_messages
13
+ from langchain_community.tools import TavilySearchResults
14
+ from pydantic import BaseModel, Field
15
+ from trafilatura import extract
16
+
17
+ logger = logging.getLogger(__name__)
18
+ ASSISTANT_SYSTEM_PROMPT_BASE = """"""
19
+ search_enabled = bool(os.environ.get("TAVILY_API_KEY"))
20
+
21
@tool
async def download_website_text(url: str) -> str:
    """Download the text from a website"""
    # NOTE: the docstring above doubles as the tool description shown to the
    # LLM -- keep it short and imperative.
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                response.raise_for_status()
                downloaded = await response.text()
                # trafilatura strips boilerplate and returns the main content
                # as a JSON string with formatting, links, and metadata kept.
                result = extract(downloaded, include_formatting=True, include_links=True, output_format='json', with_metadata=True)
                return result or "No text found on the website"
    except Exception as e:
        # Best-effort tool: report the failure as the tool result instead of
        # raising, so the agent can see and react to the error.
        logger.error(f"Failed to download {url}: {str(e)}")
        return f"Error retrieving website content: {str(e)}"
34
+
35
+ tools = [download_website_text]
36
+
37
+ if search_enabled:
38
+ tavily_search_tool = TavilySearchResults(
39
+ max_results=5,
40
+ search_depth="advanced",
41
+ include_answer=True,
42
+ include_raw_content=True,
43
+ )
44
+ tools.append(tavily_search_tool)
45
+ else:
46
+ print("TAVILY_API_KEY environment variable not found. Websearch disabled")
47
+
48
+ weak_model = ChatOpenAI(model="gpt-4o-mini", tags=["assistant"])
49
+ model = weak_model
50
+ assistant_model = weak_model
51
+
52
class GraphProcessingState(BaseModel):
    """Langgraph state shared between the assistant node and the tool node."""
    # Conversation history; merged via langgraph's add_messages reducer so
    # node returns are appended rather than replacing the list.
    messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list)
    # User-supplied system prompt, prepended by assistant_node.
    prompt: str = Field(default_factory=str, description="The prompt to be used for the model")
    # Per-turn tool toggles keyed by tool name; set by the UI (chat_fn).
    tools_enabled: dict = Field(default_factory=dict, description="The tools enabled for the assistant")
    search_enabled: bool = Field(default=True, description="Whether to enable search tools")
58
+
59
async def assistant_node(state: GraphProcessingState, config=None):
    """Single LLM step: invoke the model with the system prompt, the message
    history, and whichever tools are enabled for this turn."""
    # Build the tool list from the per-turn toggles (default: enabled).
    assistant_tools = []
    if state.tools_enabled.get("download_website_text", True):
        assistant_tools.append(download_website_text)
    if search_enabled and state.tools_enabled.get("tavily_search_results_json", True):
        assistant_tools.append(tavily_search_tool)
    assistant_model = model.bind_tools(assistant_tools)
    # User prompt (if any) comes first, followed by the base system prompt.
    if state.prompt:
        final_prompt = "\n".join([state.prompt, ASSISTANT_SYSTEM_PROMPT_BASE])
    else:
        final_prompt = ASSISTANT_SYSTEM_PROMPT_BASE

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", final_prompt),
            MessagesPlaceholder(variable_name="messages"),
        ]
    )
    chain = prompt | assistant_model
    response = await chain.ainvoke({"messages": state.messages}, config=config)

    # The add_messages reducer appends this response to state.messages.
    return {
        "messages": response
    }
83
+
84
def assistant_cond_edge(state: GraphProcessingState):
    """Route to the tool node when the assistant requested tool calls,
    otherwise end the graph run."""
    newest_message = state.messages[-1]
    pending_tool_calls = getattr(newest_message, "tool_calls", None)
    if not pending_tool_calls:
        return END
    logger.info(f"Tool call detected: {pending_tool_calls}")
    return "tools"
90
+
91
def define_workflow() -> CompiledStateGraph:
    """Defines the workflow graph"""
    # Initialize the graph
    workflow = StateGraph(GraphProcessingState)

    # Add nodes
    workflow.add_node("assistant_node", assistant_node)
    workflow.add_node("tools", ToolNode(tools))

    # Edges: after any tool run, control always returns to the assistant.
    workflow.add_edge("tools", "assistant_node")

    # Conditional routing
    workflow.add_conditional_edges(
        "assistant_node",
        # If the latest message (result) from assistant is a tool call -> assistant_cond_edge routes to tools
        # If the latest message (result) from assistant is a not a tool call -> assistant_cond_edge routes to END
        assistant_cond_edge,
    )
    # Set entry node; finish is reached via the conditional END route above.
    workflow.set_entry_point("assistant_node")

    return workflow.compile()
115
+
116
+ graph = define_workflow()
notapp.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from huggingface_hub import InferenceClient
4
+ from langchain_core.messages import HumanMessage
5
+ from langchain.agents import AgentExecutor
6
+ from agent import build_graph
7
+ from PIL import Image
8
+
9
+ """
10
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
11
+ """
12
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
13
+
14
+
15
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
+
17
class BasicAgent:
    """A langgraph agent."""

    # Presumably the graph's final message starts with this 14-character
    # prefix -- the original code sliced answer[14:] unconditionally.
    # TODO confirm against the graph's output format.
    _ANSWER_PREFIX = "FINAL ANSWER: "

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Run *question* through the graph and return the cleaned final answer."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        # Wrap the question in a HumanMessage from langchain_core
        messages = [HumanMessage(content=question)]
        config = {"recursion_limit": 27}

        result = self.graph.invoke({"messages": messages}, config=config)

        answer = result['messages'][-1].content
        # Bug fix: the original returned answer[14:] unconditionally, which
        # mangled any reply that did not start with the expected prefix.
        # Strip the prefix only when it is actually present.
        return answer.removeprefix(self._ANSWER_PREFIX)
33
+
34
+ try:
35
+ agent = BasicAgent()
36
+ except Exception as e:
37
+ print(f"Error instantiating agent: {e}")
38
+
39
def show_graph():
    """Return the pre-rendered graph diagram, or None if it was never saved."""
    diagram_path = "graph.png"
    if os.path.exists(diagram_path):
        return Image.open(diagram_path)
    return None
43
+
44
config = {"configurable": {"thread_id": "1"}}


def run_langgraph_agent(user_input: str):
    """Run the user's message through a freshly built graph and return the
    assistant's final reply.

    Bug fix: ``graph.stream(..., stream_mode="values")`` returns a generator
    of state snapshots, so the original ``"messages" in result`` /
    ``result["messages"]`` could never index the final state.  We now consume
    the stream and keep the last snapshot.
    """
    graph = build_graph()
    final_state = None
    for state in graph.stream(
        {"messages": [HumanMessage(content=user_input)]},
        config,
        stream_mode="values",
    ):
        final_state = state
    if final_state and "messages" in final_state:
        return final_state["messages"][-1].content
    return final_state
55
+
56
+
57
+ demo = gr.Interface(
58
+ fn=run_langgraph_agent,
59
+ inputs=gr.Textbox(lines=2, placeholder="Enter your message..."),
60
+ outputs="text",
61
+ title="LangGraph Agent Chat",
62
+ )
63
+
64
+ if __name__ == "__main__":
65
+ demo.launch()
66
+
67
+
68
+
69
+
70
+ # def respond(
71
+ # message,
72
+ # history: list[tuple[str, str]],
73
+ # system_message,
74
+ # max_tokens,
75
+ # temperature,
76
+ # top_p,
77
+ # ):
78
+ # messages = [{"role": "system", "content": system_message}]
79
+
80
+ # for val in history:
81
+ # if val[0]:
82
+ # messages.append({"role": "user", "content": val[0]})
83
+ # if val[1]:
84
+ # messages.append({"role": "assistant", "content": val[1]})
85
+
86
+ # messages.append({"role": "user", "content": message})
87
+
88
+ # response = ""
89
+
90
+ # for message in client.chat_completion(
91
+ # messages,
92
+ # max_tokens=max_tokens,
93
+ # stream=True,
94
+ # temperature=temperature,
95
+ # top_p=top_p,
96
+ # ):
97
+ # token = message.choices[0].delta.content
98
+
99
+ # response += token
100
+ # yield response
101
+
102
+
103
+ """
104
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
105
+ """
106
+ # demo = gr.ChatInterface(
107
+ # respond,
108
+ # additional_inputs=[
109
+ # gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
110
+ # gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
111
+ # gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
112
+ # gr.Slider(
113
+ # minimum=0.1,
114
+ # maximum=1.0,
115
+ # value=0.95,
116
+ # step=0.05,
117
+ # label="Top-p (nucleus sampling)",
118
+ # ),
119
+ # ],
120
+ # )