Commit
·
33d0a71
1
Parent(s):
ad8cab1
Improve interactivity
Browse filesSigned-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
- mcp_client.py +85 -20
mcp_client.py
CHANGED
@@ -11,12 +11,10 @@ from mcp.client.stdio import stdio_client
|
|
11 |
from anthropic import Anthropic
|
12 |
from anthropic._exceptions import OverloadedError
|
13 |
from dotenv import load_dotenv
|
|
|
14 |
|
15 |
load_dotenv()
|
16 |
|
17 |
-
loop = asyncio.new_event_loop()
|
18 |
-
asyncio.set_event_loop(loop)
|
19 |
-
|
20 |
# SYSTEM_PROMPT = f"""You are a helpful assistant and today is {datetime.now().strftime("%Y-%m-%d")}.
|
21 |
|
22 |
# You do not have any knowledge of the World Development Indicators (WDI) data. However, you can use the tools provided to answer questions.
|
@@ -76,6 +74,10 @@ LLM_MODEL = "claude-3-5-haiku-20241022"
|
|
76 |
# When a tool is needed for any step, ensure to add the token `TOOL_USE`.
|
77 |
|
78 |
|
|
|
|
|
|
|
|
|
79 |
class MCPClientWrapper:
|
80 |
def __init__(self):
|
81 |
self.session = None
|
@@ -83,15 +85,55 @@ class MCPClientWrapper:
|
|
83 |
self.anthropic = Anthropic()
|
84 |
self.tools = []
|
85 |
|
86 |
-
def connect(self, server_path: str) -> str:
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
if self.exit_stack:
|
91 |
await self.exit_stack.aclose()
|
92 |
|
93 |
self.exit_stack = AsyncExitStack()
|
94 |
-
|
95 |
is_python = server_path.endswith(".py")
|
96 |
command = "python" if is_python else "node"
|
97 |
|
@@ -101,11 +143,13 @@ class MCPClientWrapper:
|
|
101 |
env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
|
102 |
)
|
103 |
|
|
|
104 |
stdio_transport = await self.exit_stack.enter_async_context(
|
105 |
stdio_client(server_params)
|
106 |
)
|
107 |
self.stdio, self.write = stdio_transport
|
108 |
|
|
|
109 |
self.session = await self.exit_stack.enter_async_context(
|
110 |
ClientSession(self.stdio, self.write)
|
111 |
)
|
@@ -121,14 +165,13 @@ class MCPClientWrapper:
|
|
121 |
for tool in response.tools
|
122 |
]
|
123 |
|
124 |
-
print(self.tools)
|
125 |
-
|
126 |
tool_names = [tool["name"] for tool in self.tools]
|
127 |
return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
|
128 |
|
129 |
-
def process_message(
|
130 |
self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
|
131 |
-
)
|
132 |
if not self.session:
|
133 |
messages = history + [
|
134 |
{"role": "user", "content": message},
|
@@ -137,17 +180,25 @@ class MCPClientWrapper:
|
|
137 |
"content": "Please connect to an MCP server first.",
|
138 |
},
|
139 |
]
|
|
|
140 |
else:
|
141 |
-
new_messages = loop.run_until_complete(
|
142 |
-
|
143 |
-
)
|
144 |
-
messages = history + [{"role": "user", "content": message}] + new_messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
|
146 |
with open("messages.log.jsonl", "a+") as fl:
|
147 |
fl.write(json.dumps(dict(time=f"{datetime.now()}", messages=messages)))
|
148 |
|
149 |
-
return messages, gr.Textbox(value="")
|
150 |
-
|
151 |
async def _process_query(
|
152 |
self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
|
153 |
):
|
@@ -173,7 +224,7 @@ class MCPClientWrapper:
|
|
173 |
tools=self.tools,
|
174 |
)
|
175 |
except OverloadedError:
|
176 |
-
|
177 |
{
|
178 |
"role": "assistant",
|
179 |
"content": "The LLM API is overloaded now, try again later...",
|
@@ -181,6 +232,7 @@ class MCPClientWrapper:
|
|
181 |
]
|
182 |
|
183 |
result_messages = []
|
|
|
184 |
|
185 |
print(response.content)
|
186 |
contents = response.content
|
@@ -194,6 +246,9 @@ class MCPClientWrapper:
|
|
194 |
if content.type == "text":
|
195 |
result_messages.append({"role": "assistant", "content": content.text})
|
196 |
claude_messages.append({"role": "assistant", "content": content.text})
|
|
|
|
|
|
|
197 |
|
198 |
# if (
|
199 |
# auto_calls < MAX_CALLS
|
@@ -242,6 +297,8 @@ class MCPClientWrapper:
|
|
242 |
},
|
243 |
}
|
244 |
)
|
|
|
|
|
245 |
|
246 |
result_messages.append(
|
247 |
{
|
@@ -256,6 +313,8 @@ class MCPClientWrapper:
|
|
256 |
},
|
257 |
}
|
258 |
)
|
|
|
|
|
259 |
|
260 |
print(f"Calling tool: {tool_name} with args: {tool_args}")
|
261 |
result = await self.session.call_tool(tool_name, tool_args)
|
@@ -274,6 +333,9 @@ class MCPClientWrapper:
|
|
274 |
},
|
275 |
}
|
276 |
)
|
|
|
|
|
|
|
277 |
|
278 |
result_content = result.content
|
279 |
print(result_content)
|
@@ -304,6 +366,9 @@ class MCPClientWrapper:
|
|
304 |
},
|
305 |
}
|
306 |
)
|
|
|
|
|
|
|
307 |
|
308 |
# claude_messages.append(
|
309 |
# {
|
@@ -337,7 +402,7 @@ class MCPClientWrapper:
|
|
337 |
)
|
338 |
auto_calls += 1
|
339 |
except OverloadedError:
|
340 |
-
|
341 |
{
|
342 |
"role": "assistant",
|
343 |
"content": "The LLM API is overloaded now, try again later...",
|
@@ -353,7 +418,7 @@ class MCPClientWrapper:
|
|
353 |
# {"role": "assistant", "content": next_response.content[0].text}
|
354 |
# )
|
355 |
|
356 |
-
|
357 |
|
358 |
|
359 |
client = MCPClientWrapper()
|
|
|
11 |
from anthropic import Anthropic
|
12 |
from anthropic._exceptions import OverloadedError
|
13 |
from dotenv import load_dotenv
|
14 |
+
import functools
|
15 |
|
16 |
load_dotenv()
|
17 |
|
|
|
|
|
|
|
18 |
# SYSTEM_PROMPT = f"""You are a helpful assistant and today is {datetime.now().strftime("%Y-%m-%d")}.
|
19 |
|
20 |
# You do not have any knowledge of the World Development Indicators (WDI) data. However, you can use the tools provided to answer questions.
|
|
|
74 |
# When a tool is needed for any step, ensure to add the token `TOOL_USE`.
|
75 |
|
76 |
|
77 |
+
loop = asyncio.new_event_loop()
|
78 |
+
asyncio.set_event_loop(loop)
|
79 |
+
|
80 |
+
|
81 |
class MCPClientWrapper:
|
82 |
def __init__(self):
|
83 |
self.session = None
|
|
|
85 |
self.anthropic = Anthropic()
|
86 |
self.tools = []
|
87 |
|
88 |
+
# def connect(self, server_path: str) -> str:
|
89 |
+
# return loop.run_until_complete(self._connect(server_path))
|
90 |
+
|
91 |
+
# async def _connect(self, server_path: str) -> str:
|
92 |
+
# if self.exit_stack:
|
93 |
+
# await self.exit_stack.aclose()
|
94 |
+
|
95 |
+
# self.exit_stack = AsyncExitStack()
|
96 |
+
|
97 |
+
# is_python = server_path.endswith(".py")
|
98 |
+
# command = "python" if is_python else "node"
|
99 |
+
|
100 |
+
# server_params = StdioServerParameters(
|
101 |
+
# command=command,
|
102 |
+
# args=[server_path],
|
103 |
+
# env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
|
104 |
+
# )
|
105 |
+
|
106 |
+
# stdio_transport = await self.exit_stack.enter_async_context(
|
107 |
+
# stdio_client(server_params)
|
108 |
+
# )
|
109 |
+
# self.stdio, self.write = stdio_transport
|
110 |
+
|
111 |
+
# self.session = await self.exit_stack.enter_async_context(
|
112 |
+
# ClientSession(self.stdio, self.write)
|
113 |
+
# )
|
114 |
+
# await self.session.initialize()
|
115 |
|
116 |
+
# response = await self.session.list_tools()
|
117 |
+
# self.tools = [
|
118 |
+
# {
|
119 |
+
# "name": tool.name,
|
120 |
+
# "description": tool.description,
|
121 |
+
# "input_schema": tool.inputSchema,
|
122 |
+
# }
|
123 |
+
# for tool in response.tools
|
124 |
+
# ]
|
125 |
+
|
126 |
+
# print(self.tools)
|
127 |
+
|
128 |
+
# tool_names = [tool["name"] for tool in self.tools]
|
129 |
+
# return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
|
130 |
+
|
131 |
+
async def connect(self, server_path: str) -> str:
|
132 |
+
# If there's an existing session, close it
|
133 |
if self.exit_stack:
|
134 |
await self.exit_stack.aclose()
|
135 |
|
136 |
self.exit_stack = AsyncExitStack()
|
|
|
137 |
is_python = server_path.endswith(".py")
|
138 |
command = "python" if is_python else "node"
|
139 |
|
|
|
143 |
env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
|
144 |
)
|
145 |
|
146 |
+
# Launch MCP subprocess and bind streams on the *current* running loop
|
147 |
stdio_transport = await self.exit_stack.enter_async_context(
|
148 |
stdio_client(server_params)
|
149 |
)
|
150 |
self.stdio, self.write = stdio_transport
|
151 |
|
152 |
+
# Create ClientSession on this same loop
|
153 |
self.session = await self.exit_stack.enter_async_context(
|
154 |
ClientSession(self.stdio, self.write)
|
155 |
)
|
|
|
165 |
for tool in response.tools
|
166 |
]
|
167 |
|
168 |
+
print("Available tools:", self.tools)
|
|
|
169 |
tool_names = [tool["name"] for tool in self.tools]
|
170 |
return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
|
171 |
|
172 |
+
async def process_message(
|
173 |
self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
|
174 |
+
):
|
175 |
if not self.session:
|
176 |
messages = history + [
|
177 |
{"role": "user", "content": message},
|
|
|
180 |
"content": "Please connect to an MCP server first.",
|
181 |
},
|
182 |
]
|
183 |
+
yield messages, gr.Textbox(value="")
|
184 |
else:
|
185 |
+
# new_messages = loop.run_until_complete(
|
186 |
+
# self._process_query(message, history)
|
187 |
+
# )
|
188 |
+
# messages = history + [{"role": "user", "content": message}] + new_messages
|
189 |
+
|
190 |
+
messages = history + [{"role": "user", "content": message}]
|
191 |
+
|
192 |
+
yield messages, gr.Textbox(value="")
|
193 |
+
|
194 |
+
async for partial in self._process_query(message, history):
|
195 |
+
messages.extend(partial)
|
196 |
+
|
197 |
+
yield messages, gr.Textbox(value="")
|
198 |
|
199 |
with open("messages.log.jsonl", "a+") as fl:
|
200 |
fl.write(json.dumps(dict(time=f"{datetime.now()}", messages=messages)))
|
201 |
|
|
|
|
|
202 |
async def _process_query(
|
203 |
self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
|
204 |
):
|
|
|
224 |
tools=self.tools,
|
225 |
)
|
226 |
except OverloadedError:
|
227 |
+
yield [
|
228 |
{
|
229 |
"role": "assistant",
|
230 |
"content": "The LLM API is overloaded now, try again later...",
|
|
|
232 |
]
|
233 |
|
234 |
result_messages = []
|
235 |
+
partial_messages = []
|
236 |
|
237 |
print(response.content)
|
238 |
contents = response.content
|
|
|
246 |
if content.type == "text":
|
247 |
result_messages.append({"role": "assistant", "content": content.text})
|
248 |
claude_messages.append({"role": "assistant", "content": content.text})
|
249 |
+
partial_messages.append(result_messages[-1])
|
250 |
+
yield [result_messages[-1]]
|
251 |
+
partial_messages = []
|
252 |
|
253 |
# if (
|
254 |
# auto_calls < MAX_CALLS
|
|
|
297 |
},
|
298 |
}
|
299 |
)
|
300 |
+
partial_messages.append(result_messages[-1])
|
301 |
+
yield [result_messages[-1]]
|
302 |
|
303 |
result_messages.append(
|
304 |
{
|
|
|
313 |
},
|
314 |
}
|
315 |
)
|
316 |
+
partial_messages.append(result_messages[-1])
|
317 |
+
yield [result_messages[-1]]
|
318 |
|
319 |
print(f"Calling tool: {tool_name} with args: {tool_args}")
|
320 |
result = await self.session.call_tool(tool_name, tool_args)
|
|
|
333 |
},
|
334 |
}
|
335 |
)
|
336 |
+
partial_messages.append(result_messages[-1])
|
337 |
+
yield [result_messages[-1]]
|
338 |
+
partial_messages = []
|
339 |
|
340 |
result_content = result.content
|
341 |
print(result_content)
|
|
|
366 |
},
|
367 |
}
|
368 |
)
|
369 |
+
partial_messages.append(result_messages[-1])
|
370 |
+
yield [result_messages[-1]]
|
371 |
+
partial_messages = []
|
372 |
|
373 |
# claude_messages.append(
|
374 |
# {
|
|
|
402 |
)
|
403 |
auto_calls += 1
|
404 |
except OverloadedError:
|
405 |
+
yield [
|
406 |
{
|
407 |
"role": "assistant",
|
408 |
"content": "The LLM API is overloaded now, try again later...",
|
|
|
418 |
# {"role": "assistant", "content": next_response.content[0].text}
|
419 |
# )
|
420 |
|
421 |
+
# yield result_messages
|
422 |
|
423 |
|
424 |
client = MCPClientWrapper()
|