avsolatorio committed on
Commit
37530e6
·
1 Parent(s): 8822f57

Add support for remote SSE MCP server


Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1)
  1. mcp_remote_client.py +387 -0
mcp_remote_client.py ADDED
@@ -0,0 +1,387 @@
import asyncio
import os
import json
from typing import List, Dict, Any, Union
from contextlib import AsyncExitStack
from datetime import datetime
import gradio as gr
from gradio.components.chatbot import ChatMessage
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from anthropic import Anthropic
from anthropic._exceptions import OverloadedError
from dotenv import load_dotenv


load_dotenv()

SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.

You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.

When responding, you must always plan your steps and enumerate all the tools that you plan to use to answer the user's query.

### Your Instructions:

1. **Tool Use Only**:
    - You must not provide any answers based on prior knowledge or assumptions.
    - You must **not** fabricate data or simulate the behavior of the `get_wdi_data` tool.
    - You cannot use the `get_wdi_data` tool without using the `search_relevant_indicators` tool first.
    - If the user requests WDI data, you **MUST ALWAYS** first call the `search_relevant_indicators` tool to see if there's any relevant data.
    - If relevant data exists, call the `get_wdi_data` tool to get the data.

2. **Tool Invocation**:
    - Use any relevant tools provided to you to answer the user's question.
    - You may call multiple tools if needed, and you should do so in a logical sequence to minimize unnecessary user interaction.
    - Do not hesitate to invoke tools as soon as they are relevant.

3. **Limitations**:
    - If a user request cannot be fulfilled using the tools available, respond by clearly stating that you do not have access to that information.

4. **Ethical Guidelines**:
    - Do not make or endorse statements based on stereotypes, bias, or assumptions.
    - Ensure all claims and explanations are grounded in the data or factual evidence retrieved via tools.
    - Politely refuse to respond to requests that involve stereotypes or unfounded generalizations.

5. **Communication Style**:
    - Present the data in clear, user-friendly language.
    - You may summarize, explain, or describe the data in a way that is easy to understand, but you must **not** elaborate based on outside or implicit knowledge.

Stay strictly within these boundaries while maintaining a helpful and respectful tone."""

LLM_MODEL = "claude-3-5-haiku-20241022"
# Example query: What is the military spending of Bangladesh in 2014?
# When a tool is needed for any step, ensure to add the token `TOOL_USE`.

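# A single long-lived event loop is created up front so the MCP transports, the
# ClientSession, and the Gradio callbacks all run on the same loop (see the
# comments in connect() below).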
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)


class MCPClientWrapper:
    def __init__(self):
        self.session = None
        self.exit_stack = None
        self.anthropic = Anthropic()
        self.tools = []

    async def connect(self, server_path_or_url: str) -> str:
        # If there's an existing session, close it
        if self.exit_stack:
            await self.exit_stack.aclose()

        self.exit_stack = AsyncExitStack()

        if server_path_or_url.endswith(".py"):
            command = "python"

            server_params = StdioServerParameters(
                command=command,
                args=[server_path_or_url],
                env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
            )

            # Launch MCP subprocess and bind streams on the *current* running loop
            stdio_transport = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            self.stdio, self.write = stdio_transport
        else:
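            # A non-.py argument is treated as a remote MCP server URL: connect
            # over SSE, sending HF_TOKEN as a bearer token (e.g., for private Spaces).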
            sse_transport = await self.exit_stack.enter_async_context(
                sse_client(
                    server_path_or_url,
                    headers={"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"},
                )
            )
            self.stdio, self.write = sse_transport

        # Create ClientSession on this same loop
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(self.stdio, self.write)
        )
        await self.session.initialize()

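        # Expose the server's tools to Claude: MCP tool metadata maps directly
        # onto the fields of Anthropic's tool schema (name, description, input_schema).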
        response = await self.session.list_tools()
        self.tools = [
            {
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema,
            }
            for tool in response.tools
        ]

        print("Available tools:", self.tools)
        tool_names = [tool["name"] for tool in self.tools]
        return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"

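    # Gradio chat callback: an async generator that yields updated
    # (chat history, cleared textbox) pairs so replies stream into the UI.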
    async def process_message(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
        if not self.session:
            messages = history + [
                {"role": "user", "content": message},
                {
                    "role": "assistant",
                    "content": "Please connect to an MCP server first.",
                },
            ]
            yield messages, gr.Textbox(value="")
        else:
            messages = history + [{"role": "user", "content": message}]

            yield messages, gr.Textbox(value="")

            async for partial in self._process_query(message, history):
                messages.extend(partial)

                yield messages, gr.Textbox(value="")

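            # Append the full exchange to a local JSONL log, one JSON object per line.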
            with open("messages.log.jsonl", "a+") as fl:
                fl.write(
                    json.dumps(dict(time=f"{datetime.now()}", messages=messages))
                    + "\n"  # newline keeps the log valid JSONL
                )

    async def _process_query(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
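        # Rebuild the Anthropic message list from the Gradio history, keeping
        # only user/assistant/system entries.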
        claude_messages = []
        for msg in history:
            if isinstance(msg, ChatMessage):
                role, content = msg.role, msg.content
            else:
                role, content = msg.get("role"), msg.get("content")

            if role in ["user", "assistant", "system"]:
                claude_messages.append({"role": role, "content": content})

        claude_messages.append({"role": "user", "content": message})

        try:
            response = self.anthropic.messages.create(
                # model="claude-3-5-sonnet-20241022",
                model=LLM_MODEL,
                system=SYSTEM_PROMPT,
                max_tokens=1000,
                messages=claude_messages,
                tools=self.tools,
            )
        except OverloadedError:
            yield [
                {
                    "role": "assistant",
                    "content": "The LLM API is overloaded now, try again later...",
                }
            ]
            # Stop here; otherwise `response` would be unbound below.
            return

        result_messages = []
        partial_messages = []

        print(response.content)
        contents = response.content

        MAX_CALLS = 10
        auto_calls = 0

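        # Agent loop: drain the model's content blocks, executing tool calls as
        # they appear; MAX_CALLS caps the number of automatic follow-up requests.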
        while len(contents) > 0 and auto_calls < MAX_CALLS:
            content = contents.pop(0)

            if content.type == "text":
                result_messages.append({"role": "assistant", "content": content.text})
                claude_messages.append({"role": "assistant", "content": content.text})
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

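            # Tool-use block: messages carrying "metadata" (title/id/parent_id/
            # status) render as collapsible tool panels in the Gradio Chatbot.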
            elif content.type == "tool_use":
                tool_id = content.id
                tool_name = content.name
                tool_args = content.input

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": f"I'll use the {tool_name} tool to help answer your question.",
                        "metadata": {
                            "title": f"Using tool: {tool_name}",
                            "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                            "status": "pending",
                            "id": f"tool_call_{tool_name}",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "```json\n"
                        + json.dumps(tool_args, indent=2, ensure_ascii=True)
                        + "\n```",
                        "metadata": {
                            "parent_id": f"tool_call_{tool_name}",
                            "id": f"params_{tool_name}",
                            "title": "Tool Parameters",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]

                print(f"Calling tool: {tool_name} with args: {tool_args}")
                result = await self.session.call_tool(tool_name, tool_args)

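                # Flip the pending "Using tool" panel to "done" now that the call has returned.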
                if len(result_messages) >= 2 and "metadata" in result_messages[-2]:
                    result_messages[-2]["metadata"]["status"] = "done"

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "Here are the results from the tool:",
                        "metadata": {
                            "title": f"Tool Result for {tool_name}",
                            "status": "done",
                            "id": f"result_{tool_name}",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

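                # Normalize MCP content blocks into plain dicts and parse any JSON
                # text payloads so the raw output renders cleanly.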
                result_content = result.content
                print(result_content)
                if isinstance(result_content, list):
                    result_content = [r.model_dump() for r in result_content]

                for r in result_content:
                    # Remove annotations field from each item if it exists
                    r.pop("annotations", None)
                    try:
                        r["text"] = json.loads(r["text"])
                    except (KeyError, TypeError, json.JSONDecodeError):
                        pass

                print("result_content", result_content)

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "```\n"
                        + json.dumps(result_content, indent=2)
                        + "\n```",
                        "metadata": {
                            "parent_id": f"result_{tool_name}",
                            "id": f"raw_result_{tool_name}",
                            "title": "Raw Output",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

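                # Per Anthropic's tool-use protocol, echo the assistant's tool_use
                # block back, followed by a user message carrying the tool_result.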
                claude_messages.append(
                    {"role": "assistant", "content": [content.model_dump()]}
                )
                claude_messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": tool_id,
                                "content": json.dumps(result_content, indent=2),
                            }
                        ],
                    }
                )

                try:
                    next_response = self.anthropic.messages.create(
                        model=LLM_MODEL,
                        system=SYSTEM_PROMPT,
                        max_tokens=1000,
                        messages=claude_messages,
                        tools=self.tools,
                    )
                    auto_calls += 1
                except OverloadedError:
                    yield [
                        {
                            "role": "assistant",
                            "content": "The LLM API is overloaded now, try again later...",
                        }
                    ]
                    # Stop here; `next_response` is unbound if the request failed.
                    return

                print("next_response", next_response.content)

                contents.extend(next_response.content)


def gradio_interface(
    server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
):
    # server_path_or_url = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse"
    # server_path_or_url = "wdi_mcp_server.py"

    client = MCPClientWrapper()

    with gr.Blocks(title="WDI MCP Client") as demo:
        gr.Markdown("## Ask about the World Development Indicators (WDI) data")
        # gr.Markdown("Connect to the WDI MCP server and chat with the assistant")

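        # The manual-connect controls are only shown for local .py servers; for a
        # remote SSE URL the connection happens automatically via demo.load below.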
        with gr.Accordion(
            "Connect to the WDI MCP server and chat with the assistant",
            open=False,
            visible=server_path_or_url.endswith(".py"),
        ):
            with gr.Row(equal_height=True):
                with gr.Column(scale=4):
                    server_path = gr.Textbox(
                        label="Server Script Path",
                        placeholder="Enter path to server script (e.g., wdi_mcp_server.py)",
                        value=server_path_or_url,
                    )
                with gr.Column(scale=1):
                    connect_btn = gr.Button("Connect")

            status = gr.Textbox(label="Connection Status", interactive=False)

        chatbot = gr.Chatbot(
            value=[],
            height=600,
            type="messages",
            show_copy_button=True,
            avatar_images=("img/small-user.png", "img/small-robot.png"),
            autoscroll=True,
        )

        with gr.Row(equal_height=True):
            msg = gr.Textbox(
                label="Your Question",
                placeholder="Ask about what indicators are available for a specific topic (e.g., What's the definition of GDP?)",
                scale=4,
            )
            clear_btn = gr.Button("Clear Chat", scale=1)

        connect_btn.click(client.connect, inputs=server_path, outputs=status)
        # Automatically call client.connect(...) as soon as the interface loads
        demo.load(fn=client.connect, inputs=server_path, outputs=status)

        msg.submit(client.process_message, [msg, chatbot], [chatbot, msg])
        clear_btn.click(lambda: [], None, chatbot)

    return demo


if __name__ == "__main__":
    if not os.getenv("ANTHROPIC_API_KEY"):
        print(
            "Warning: ANTHROPIC_API_KEY not found in environment. Please set it in your .env file."
        )

    interface = gradio_interface()
    interface.launch(server_name=os.getenv("SERVER_NAME", "127.0.0.1"), debug=True)
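
For reference, a minimal sketch of exercising the new remote-SSE path without the Gradio UI. This is not part of the commit; it assumes the file above is importable as `mcp_remote_client` and that `ANTHROPIC_API_KEY` and `HF_TOKEN` are set in the environment or a local `.env` file:

# smoke_test.py (hypothetical helper, not part of this commit)
from mcp_remote_client import MCPClientWrapper, loop

client = MCPClientWrapper()
# connect() is async; drive it on the module's own event loop.
status = loop.run_until_complete(
    client.connect("https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse")
)
print(status)  # e.g., "Connected to MCP server. Available tools: ..."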