avsolatorio commited on
Commit
33d0a71
·
1 Parent(s): ad8cab1

Improve interactivity

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. mcp_client.py +85 -20
mcp_client.py CHANGED
@@ -11,12 +11,10 @@ from mcp.client.stdio import stdio_client
11
  from anthropic import Anthropic
12
  from anthropic._exceptions import OverloadedError
13
  from dotenv import load_dotenv
 
14
 
15
  load_dotenv()
16
 
17
- loop = asyncio.new_event_loop()
18
- asyncio.set_event_loop(loop)
19
-
20
  # SYSTEM_PROMPT = f"""You are a helpful assistant and today is {datetime.now().strftime("%Y-%m-%d")}.
21
 
22
  # You do not have any knowledge of the World Development Indicators (WDI) data. However, you can use the tools provided to answer questions.
@@ -76,6 +74,10 @@ LLM_MODEL = "claude-3-5-haiku-20241022"
76
  # When a tool is needed for any step, ensure to add the token `TOOL_USE`.
77
 
78
 
 
 
 
 
79
  class MCPClientWrapper:
80
  def __init__(self):
81
  self.session = None
@@ -83,15 +85,55 @@ class MCPClientWrapper:
83
  self.anthropic = Anthropic()
84
  self.tools = []
85
 
86
- def connect(self, server_path: str) -> str:
87
- return loop.run_until_complete(self._connect(server_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- async def _connect(self, server_path: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  if self.exit_stack:
91
  await self.exit_stack.aclose()
92
 
93
  self.exit_stack = AsyncExitStack()
94
-
95
  is_python = server_path.endswith(".py")
96
  command = "python" if is_python else "node"
97
 
@@ -101,11 +143,13 @@ class MCPClientWrapper:
101
  env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
102
  )
103
 
 
104
  stdio_transport = await self.exit_stack.enter_async_context(
105
  stdio_client(server_params)
106
  )
107
  self.stdio, self.write = stdio_transport
108
 
 
109
  self.session = await self.exit_stack.enter_async_context(
110
  ClientSession(self.stdio, self.write)
111
  )
@@ -121,14 +165,13 @@ class MCPClientWrapper:
121
  for tool in response.tools
122
  ]
123
 
124
- print(self.tools)
125
-
126
  tool_names = [tool["name"] for tool in self.tools]
127
  return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
128
 
129
- def process_message(
130
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
131
- ) -> tuple:
132
  if not self.session:
133
  messages = history + [
134
  {"role": "user", "content": message},
@@ -137,17 +180,25 @@ class MCPClientWrapper:
137
  "content": "Please connect to an MCP server first.",
138
  },
139
  ]
 
140
  else:
141
- new_messages = loop.run_until_complete(
142
- self._process_query(message, history)
143
- )
144
- messages = history + [{"role": "user", "content": message}] + new_messages
 
 
 
 
 
 
 
 
 
145
 
146
  with open("messages.log.jsonl", "a+") as fl:
147
  fl.write(json.dumps(dict(time=f"{datetime.now()}", messages=messages)))
148
 
149
- return messages, gr.Textbox(value="")
150
-
151
  async def _process_query(
152
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
153
  ):
@@ -173,7 +224,7 @@ class MCPClientWrapper:
173
  tools=self.tools,
174
  )
175
  except OverloadedError:
176
- return [
177
  {
178
  "role": "assistant",
179
  "content": "The LLM API is overloaded now, try again later...",
@@ -181,6 +232,7 @@ class MCPClientWrapper:
181
  ]
182
 
183
  result_messages = []
 
184
 
185
  print(response.content)
186
  contents = response.content
@@ -194,6 +246,9 @@ class MCPClientWrapper:
194
  if content.type == "text":
195
  result_messages.append({"role": "assistant", "content": content.text})
196
  claude_messages.append({"role": "assistant", "content": content.text})
 
 
 
197
 
198
  # if (
199
  # auto_calls < MAX_CALLS
@@ -242,6 +297,8 @@ class MCPClientWrapper:
242
  },
243
  }
244
  )
 
 
245
 
246
  result_messages.append(
247
  {
@@ -256,6 +313,8 @@ class MCPClientWrapper:
256
  },
257
  }
258
  )
 
 
259
 
260
  print(f"Calling tool: {tool_name} with args: {tool_args}")
261
  result = await self.session.call_tool(tool_name, tool_args)
@@ -274,6 +333,9 @@ class MCPClientWrapper:
274
  },
275
  }
276
  )
 
 
 
277
 
278
  result_content = result.content
279
  print(result_content)
@@ -304,6 +366,9 @@ class MCPClientWrapper:
304
  },
305
  }
306
  )
 
 
 
307
 
308
  # claude_messages.append(
309
  # {
@@ -337,7 +402,7 @@ class MCPClientWrapper:
337
  )
338
  auto_calls += 1
339
  except OverloadedError:
340
- return [
341
  {
342
  "role": "assistant",
343
  "content": "The LLM API is overloaded now, try again later...",
@@ -353,7 +418,7 @@ class MCPClientWrapper:
353
  # {"role": "assistant", "content": next_response.content[0].text}
354
  # )
355
 
356
- return result_messages
357
 
358
 
359
  client = MCPClientWrapper()
 
11
  from anthropic import Anthropic
12
  from anthropic._exceptions import OverloadedError
13
  from dotenv import load_dotenv
14
+ import functools
15
 
16
  load_dotenv()
17
 
 
 
 
18
  # SYSTEM_PROMPT = f"""You are a helpful assistant and today is {datetime.now().strftime("%Y-%m-%d")}.
19
 
20
  # You do not have any knowledge of the World Development Indicators (WDI) data. However, you can use the tools provided to answer questions.
 
74
  # When a tool is needed for any step, ensure to add the token `TOOL_USE`.
75
 
76
 
77
+ loop = asyncio.new_event_loop()
78
+ asyncio.set_event_loop(loop)
79
+
80
+
81
  class MCPClientWrapper:
82
  def __init__(self):
83
  self.session = None
 
85
  self.anthropic = Anthropic()
86
  self.tools = []
87
 
88
+ # def connect(self, server_path: str) -> str:
89
+ # return loop.run_until_complete(self._connect(server_path))
90
+
91
+ # async def _connect(self, server_path: str) -> str:
92
+ # if self.exit_stack:
93
+ # await self.exit_stack.aclose()
94
+
95
+ # self.exit_stack = AsyncExitStack()
96
+
97
+ # is_python = server_path.endswith(".py")
98
+ # command = "python" if is_python else "node"
99
+
100
+ # server_params = StdioServerParameters(
101
+ # command=command,
102
+ # args=[server_path],
103
+ # env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
104
+ # )
105
+
106
+ # stdio_transport = await self.exit_stack.enter_async_context(
107
+ # stdio_client(server_params)
108
+ # )
109
+ # self.stdio, self.write = stdio_transport
110
+
111
+ # self.session = await self.exit_stack.enter_async_context(
112
+ # ClientSession(self.stdio, self.write)
113
+ # )
114
+ # await self.session.initialize()
115
 
116
+ # response = await self.session.list_tools()
117
+ # self.tools = [
118
+ # {
119
+ # "name": tool.name,
120
+ # "description": tool.description,
121
+ # "input_schema": tool.inputSchema,
122
+ # }
123
+ # for tool in response.tools
124
+ # ]
125
+
126
+ # print(self.tools)
127
+
128
+ # tool_names = [tool["name"] for tool in self.tools]
129
+ # return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
130
+
131
+ async def connect(self, server_path: str) -> str:
132
+ # If there's an existing session, close it
133
  if self.exit_stack:
134
  await self.exit_stack.aclose()
135
 
136
  self.exit_stack = AsyncExitStack()
 
137
  is_python = server_path.endswith(".py")
138
  command = "python" if is_python else "node"
139
 
 
143
  env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
144
  )
145
 
146
+ # Launch MCP subprocess and bind streams on the *current* running loop
147
  stdio_transport = await self.exit_stack.enter_async_context(
148
  stdio_client(server_params)
149
  )
150
  self.stdio, self.write = stdio_transport
151
 
152
+ # Create ClientSession on this same loop
153
  self.session = await self.exit_stack.enter_async_context(
154
  ClientSession(self.stdio, self.write)
155
  )
 
165
  for tool in response.tools
166
  ]
167
 
168
+ print("Available tools:", self.tools)
 
169
  tool_names = [tool["name"] for tool in self.tools]
170
  return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
171
 
172
+ async def process_message(
173
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
174
+ ):
175
  if not self.session:
176
  messages = history + [
177
  {"role": "user", "content": message},
 
180
  "content": "Please connect to an MCP server first.",
181
  },
182
  ]
183
+ yield messages, gr.Textbox(value="")
184
  else:
185
+ # new_messages = loop.run_until_complete(
186
+ # self._process_query(message, history)
187
+ # )
188
+ # messages = history + [{"role": "user", "content": message}] + new_messages
189
+
190
+ messages = history + [{"role": "user", "content": message}]
191
+
192
+ yield messages, gr.Textbox(value="")
193
+
194
+ async for partial in self._process_query(message, history):
195
+ messages.extend(partial)
196
+
197
+ yield messages, gr.Textbox(value="")
198
 
199
  with open("messages.log.jsonl", "a+") as fl:
200
  fl.write(json.dumps(dict(time=f"{datetime.now()}", messages=messages)))
201
 
 
 
202
  async def _process_query(
203
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
204
  ):
 
224
  tools=self.tools,
225
  )
226
  except OverloadedError:
227
+ yield [
228
  {
229
  "role": "assistant",
230
  "content": "The LLM API is overloaded now, try again later...",
 
232
  ]
233
 
234
  result_messages = []
235
+ partial_messages = []
236
 
237
  print(response.content)
238
  contents = response.content
 
246
  if content.type == "text":
247
  result_messages.append({"role": "assistant", "content": content.text})
248
  claude_messages.append({"role": "assistant", "content": content.text})
249
+ partial_messages.append(result_messages[-1])
250
+ yield [result_messages[-1]]
251
+ partial_messages = []
252
 
253
  # if (
254
  # auto_calls < MAX_CALLS
 
297
  },
298
  }
299
  )
300
+ partial_messages.append(result_messages[-1])
301
+ yield [result_messages[-1]]
302
 
303
  result_messages.append(
304
  {
 
313
  },
314
  }
315
  )
316
+ partial_messages.append(result_messages[-1])
317
+ yield [result_messages[-1]]
318
 
319
  print(f"Calling tool: {tool_name} with args: {tool_args}")
320
  result = await self.session.call_tool(tool_name, tool_args)
 
333
  },
334
  }
335
  )
336
+ partial_messages.append(result_messages[-1])
337
+ yield [result_messages[-1]]
338
+ partial_messages = []
339
 
340
  result_content = result.content
341
  print(result_content)
 
366
  },
367
  }
368
  )
369
+ partial_messages.append(result_messages[-1])
370
+ yield [result_messages[-1]]
371
+ partial_messages = []
372
 
373
  # claude_messages.append(
374
  # {
 
402
  )
403
  auto_calls += 1
404
  except OverloadedError:
405
+ yield [
406
  {
407
  "role": "assistant",
408
  "content": "The LLM API is overloaded now, try again later...",
 
418
  # {"role": "assistant", "content": next_response.content[0].text}
419
  # )
420
 
421
+ # yield result_messages
422
 
423
 
424
  client = MCPClientWrapper()