avsolatorio committed

Commit 8cd6fde · Parent: 11faf9a

Duplicate code

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1): mcp_openai_client.py (+508, -0)

mcp_openai_client.py ADDED
@@ -0,0 +1,508 @@
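+ # Gradio chat client for a WDI MCP server: connects over stdio or SSE and
+ # routes user questions through Claude, which may only answer via MCP tools.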
+ import asyncio
+ import os
+ import json
+ from typing import List, Dict, Any, Union
+ from contextlib import AsyncExitStack
+ from datetime import datetime
+ import gradio as gr
+ from gradio.components.chatbot import ChatMessage
+ from mcp import ClientSession, StdioServerParameters
+ from mcp.client.stdio import stdio_client
+ from mcp.client.sse import sse_client
+ from anthropic import Anthropic
+ from anthropic._exceptions import OverloadedError
+ from dotenv import load_dotenv
+
+
+ load_dotenv()
+
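+ # Note: `Anthropic()` reads ANTHROPIC_API_KEY from the environment, and the
+ # SSE transport below sends HF_TOKEN as a bearer token; both live in `.env`.
+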
+ SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.
+
+ You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.
+
+ When responding, always plan your steps and enumerate all the tools you intend to use to answer the user's query.
+
+ ### Your Instructions:
+
+ 1. **Tool Use Only**:
+    - Do not provide answers based on prior knowledge or assumptions.
+    - Do **not** fabricate data or simulate the behavior of the `get_wdi_data` tool.
+    - Never call the `get_wdi_data` tool without calling the `search_relevant_indicators` tool first.
+    - If the user requests WDI data, you **MUST ALWAYS** first call the `search_relevant_indicators` tool to check whether relevant data exists.
+    - If relevant data exists, call the `get_wdi_data` tool to retrieve it.
+
+ 2. **Tool Invocation**:
+    - Use any relevant tools provided to you to answer the user's question.
+    - You may call multiple tools if needed, in a logical sequence that minimizes unnecessary user interaction.
+    - Do not hesitate to invoke tools as soon as they are relevant.
+
+ 3. **Limitations**:
+    - If a user request cannot be fulfilled with the available tools, state clearly that you do not have access to that information.
+
+ 4. **Ethical Guidelines**:
+    - Do not make or endorse statements based on stereotypes, bias, or assumptions.
+    - Ground all claims and explanations in data or factual evidence retrieved via tools.
+    - Politely refuse requests that involve stereotypes or unfounded generalizations.
+
+ 5. **Communication Style**:
+    - Present the data in clear, user-friendly language.
+    - You may summarize or explain the retrieved data, but do **not** elaborate based on outside or implicit knowledge.
+
+ 6. **Presentation**:
+    - Summarize the data in a table with clear column names and values.
+    - If the requested data is not available, state clearly that you do not have access to that information.
+
+ Stay strictly within these boundaries while maintaining a helpful and respectful tone."""
+
+
+ LLM_MODEL = "claude-3-5-haiku-20241022"
+ # Sample query: What is the military spending of Bangladesh in 2014?
+ # When a tool is needed for any step, be sure to add the token `TOOL_USE`.
+
+
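+ # Create a dedicated event loop up front so the MCP streams and ClientSession
+ # created in `connect` are bound to the same loop that Gradio drives.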
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
+
+ class MCPClientWrapper:
+     def __init__(self):
+         self.session = None
+         self.exit_stack = None
+         # Transport streams; set by `connect` and checked before tool calls.
+         self.stdio = None
+         self.write = None
+         self.anthropic = Anthropic()
+         self.tools = []
+
+     async def connect(self, server_path_or_url: str) -> str:
+         try:
+             # Refuse to reconnect over an existing session; disconnect first.
+             if self.exit_stack:
+                 return "Already connected to an MCP server. Please disconnect first."
+
+             self.exit_stack = AsyncExitStack()
+
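+             # Transport selection: a local `.py` path is launched as a stdio
+             # subprocess; any other value is treated as an SSE endpoint URL.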
+             if server_path_or_url.endswith(".py"):
+                 command = "python"
+
+                 server_params = StdioServerParameters(
+                     command=command,
+                     args=[server_path_or_url],
+                     env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
+                 )
+
+                 print(
+                     f"Starting MCP server with command: {command} {server_path_or_url}"
+                 )
+                 # Launch MCP subprocess and bind streams on the current running loop
+                 stdio_transport = await self.exit_stack.enter_async_context(
+                     stdio_client(server_params)
+                 )
+                 self.stdio, self.write = stdio_transport
+             else:
+                 print(f"Connecting to MCP server at: {server_path_or_url}")
+                 sse_transport = await self.exit_stack.enter_async_context(
+                     sse_client(
+                         server_path_or_url,
+                         headers={"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"},
+                     )
+                 )
+                 self.stdio, self.write = sse_transport
+
+             print("Creating MCP client session...")
+             # Create ClientSession on this same loop
+             self.session = await self.exit_stack.enter_async_context(
+                 ClientSession(self.stdio, self.write)
+             )
+             await self.session.initialize()
+             print("MCP session initialized successfully")
+
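+             # Map the MCP tool metadata into the schema Anthropic's `tools`
+             # parameter expects: name, description, and a JSON `input_schema`.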
+             response = await self.session.list_tools()
+             self.tools = [
+                 {
+                     "name": tool.name,
+                     "description": tool.description,
+                     "input_schema": tool.inputSchema,
+                 }
+                 for tool in response.tools
+             ]
+
+             print("Available tools:", self.tools)
+             tool_names = [tool["name"] for tool in self.tools]
+             return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
+         except Exception as e:
+             error_msg = f"Failed to connect to MCP server: {str(e)}"
+             print(error_msg)
+             # Clean up on error
+             if self.exit_stack:
+                 await self.exit_stack.aclose()
+             self.exit_stack = None
+             self.session = None
+             return error_msg
+
+     async def disconnect(self):
+         if self.exit_stack:
+             print("Disconnecting from MCP server...")
+             await self.exit_stack.aclose()
+             self.exit_stack = None
+             self.session = None
+             self.stdio = None
+             self.write = None
+
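+     # Async generator wired to Gradio: every yield streams the updated chat
+     # history plus a cleared input textbox back to the UI.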
+     async def process_message(
+         self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
+     ):
+         if not self.session:
+             messages = history + [
+                 {"role": "user", "content": message},
+                 {
+                     "role": "assistant",
+                     "content": "Please connect to an MCP server first by reloading the page.",
+                 },
+             ]
+             yield messages, gr.Textbox(value="")
+         else:
+             messages = history + [
+                 {"role": "user", "content": message},
+                 {
+                     "role": "assistant",
+                     "content": "Ok, let me think about your query 🤔...",
+                 },
+             ]
+
+             yield messages, gr.Textbox(value="")
+             # Brief pause so the "thinking" placeholder renders before removal
+             await asyncio.sleep(0.1)
+             messages.pop(-1)
+
+             async for partial in self._process_query(message, history):
+                 messages.extend(partial)
+                 yield messages, gr.Textbox(value="")
+                 await asyncio.sleep(0.05)
+
+                 if (
+                     messages[-1]["role"] == "assistant"
+                     and messages[-1]["content"]
+                     == "The LLM API is overloaded now, try again later..."
+                 ):
+                     break
+
+             with open("messages.log.jsonl", "a+") as fl:
+                 # JSONL: one JSON record per line, hence the trailing newline
+                 fl.write(
+                     json.dumps(dict(time=f"{datetime.now()}", messages=messages))
+                     + "\n"
+                 )
+
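+     # Core agent loop: send the conversation and MCP tool schemas to Claude,
+     # execute any `tool_use` blocks via the MCP session, feed the results back
+     # as `tool_result` turns, and repeat until no content remains or the call
+     # budget is exhausted.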
+     async def _process_query(
+         self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
+     ):
+         claude_messages = []
+         for msg in history:
+             if isinstance(msg, ChatMessage):
+                 role, content = msg.role, msg.content
+             else:
+                 role, content = msg.get("role"), msg.get("content")
+
+             if role in ["user", "assistant", "system"]:
+                 claude_messages.append({"role": role, "content": content})
+
+         claude_messages.append({"role": "user", "content": message})
+
+         try:
+             response = self.anthropic.messages.create(
+                 # model="claude-3-5-sonnet-20241022",
+                 model=LLM_MODEL,
+                 system=SYSTEM_PROMPT,
+                 max_tokens=1000,
+                 messages=claude_messages,
+                 tools=self.tools,
+             )
+         except OverloadedError:
+             yield [
+                 {
+                     "role": "assistant",
+                     "content": "The LLM API is overloaded now, try again later...",
+                 }
+             ]
+             # TODO: Add a retry mechanism
+             return  # `response` is unbound past this point, so bail out.
+
+         result_messages = []
+         partial_messages = []
+
+         print(response.content)
+         contents = response.content
+
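+         # Cap automatic follow-up model calls so a runaway tool loop cannot
+         # iterate unbounded.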
+         MAX_CALLS = 10
+         auto_calls = 0
+
+         while len(contents) > 0 and auto_calls < MAX_CALLS:
+             content = contents.pop(0)
+
+             if content.type == "text":
+                 result_messages.append({"role": "assistant", "content": content.text})
+                 claude_messages.append({"role": "assistant", "content": content.text})
+                 partial_messages.append(result_messages[-1])
+                 yield [result_messages[-1]]
+                 partial_messages = []
+
+             elif content.type == "tool_use":
+                 tool_id = content.id
+                 tool_name = content.name
+                 tool_args = content.input
+
+                 result_messages.append(
+                     {
+                         "role": "assistant",
+                         "content": f"I'll use the {tool_name} tool to help answer your question.",
+                         "metadata": {
+                             "title": f"Using tool: {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
+                             "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
+                             "status": "done",
+                             "id": f"tool_call_{tool_name}",
+                         },
+                     }
+                 )
+                 partial_messages.append(result_messages[-1])
+                 yield [result_messages[-1]]
+
+                 result_messages.append(
+                     {
+                         "role": "assistant",
+                         "content": "```json\n"
+                         + json.dumps(tool_args, indent=2, ensure_ascii=True)
+                         + "\n```",
+                         "metadata": {
+                             "parent_id": f"tool_call_{tool_name}",
+                             "id": f"params_{tool_name}",
+                             "title": "Tool Parameters",
+                         },
+                     }
+                 )
+                 partial_messages.append(result_messages[-1])
+                 yield [result_messages[-1]]
+
+                 print(f"Calling tool: {tool_name} with args: {tool_args}")
+                 try:
+                     # Check if session is still valid
+                     if not self.session or not self.stdio or not self.write:
+                         raise Exception(
+                             "MCP session is not connected or has been closed"
+                         )
+
+                     result = await self.session.call_tool(tool_name, tool_args)
+                 except Exception as e:
+                     error_msg = f"Error calling tool {tool_name}: {str(e)}"
+                     print(error_msg)
+                     result_messages.append(
+                         {
+                             "role": "assistant",
+                             "content": f"Sorry, I encountered an error while calling the tool: {error_msg}. Please try again.",
+                             "metadata": {
+                                 "title": f"Tool Error for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
+                                 "status": "error",
+                                 "id": f"error_{tool_name}",
+                             },
+                         }
+                     )
+                     partial_messages.append(result_messages[-1])
+                     yield [result_messages[-1]]
+                     partial_messages = []
+                     continue
+
+                 # Mark the pending tool-call entry as done (guard the index)
+                 if len(result_messages) >= 2 and "metadata" in result_messages[-2]:
+                     result_messages[-2]["metadata"]["status"] = "done"
+
+                 result_messages.append(
+                     {
+                         "role": "assistant",
+                         "content": "Here are the results from the tool:",
+                         "metadata": {
+                             "title": f"Tool Result for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
+                             "status": "done",
+                             "id": f"result_{tool_name}",
+                         },
+                     }
+                 )
+                 partial_messages.append(result_messages[-1])
+                 yield [result_messages[-1]]
+                 partial_messages = []
+
+                 result_content = result.content
+                 print(result_content)
+                 if isinstance(result_content, list):
+                     result_content = [r.model_dump() for r in result_content]
+
+                 for r in result_content:
+                     # Remove the annotations field from each item if it exists
+                     r.pop("annotations", None)
+                     try:
+                         r["text"] = json.loads(r["text"])
+                     except (json.JSONDecodeError, TypeError, KeyError):
+                         # Leave non-JSON text payloads as-is
+                         pass
+
+                 print("result_content", result_content)
+
+                 result_messages.append(
+                     {
+                         "role": "assistant",
+                         "content": "```\n"
+                         + json.dumps(result_content, indent=2)
+                         + "\n```",
+                         "metadata": {
+                             "parent_id": f"result_{tool_name}",
+                             "id": f"raw_result_{tool_name}",
+                             "title": "Raw Output",
+                         },
+                     }
+                 )
+                 partial_messages.append(result_messages[-1])
+                 yield [result_messages[-1]]
+                 partial_messages = []
+
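+                 # Anthropic tool-use protocol: echo the assistant's `tool_use`
+                 # block, then add a user turn whose `tool_result` references
+                 # the same `tool_use_id`.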
+                 claude_messages.append(
+                     {"role": "assistant", "content": [content.model_dump()]}
+                 )
+                 claude_messages.append(
+                     {
+                         "role": "user",
+                         "content": [
+                             {
+                                 "type": "tool_result",
+                                 "tool_use_id": tool_id,
+                                 "content": json.dumps(result_content, indent=2),
+                             }
+                         ],
+                     }
+                 )
+
+                 try:
+                     next_response = self.anthropic.messages.create(
+                         model=LLM_MODEL,
+                         system=SYSTEM_PROMPT,
+                         max_tokens=1000,
+                         messages=claude_messages,
+                         tools=self.tools,
+                     )
+                     auto_calls += 1
+                 except OverloadedError:
+                     yield [
+                         {
+                             "role": "assistant",
+                             "content": "The LLM API is overloaded now, try again later...",
+                         }
+                     ]
+                     return  # Stop the loop; `next_response` is unbound here.
+
+                 print("next_response", next_response.content)
+
+                 contents.extend(next_response.content)
+
+
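+ # Build the Gradio Blocks UI. The connection accordion is only shown when
+ # pointing at a local server script; hosted SSE endpoints connect on load.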
+ def gradio_interface(
+     server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
+ ):
+     # Alternative targets:
+     # server_path_or_url = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse"
+     # server_path_or_url = "wdi_mcp_server.py"
+
+     client = MCPClientWrapper()
+     custom_css = """
+     .gradio-container {
+         background-color: #fff !important;
+     }
+     .message-row.panel.bot-row {
+         background-color: #fff !important;
+     }
+     .message-row.panel.user-row {
+         background-color: #fff !important;
+     }
+     .user {
+         background-color: #f1f6ff !important;
+     }
+     .bot {
+         background-color: #fff !important;
+     }
+     .role {
+         margin-left: 10px !important;
+     }
+     footer {display: none !important}
+     """
+
+     # Disable auto-dark mode by setting theme to None
+     with gr.Blocks(title="WDI MCP Client", css=custom_css, theme=None) as demo:
+         try:
+             gr.Markdown("# Development Data Chat")
+
+             with gr.Accordion(
+                 "Connect to the WDI MCP server and chat with the assistant",
+                 open=False,
+                 visible=server_path_or_url.endswith(".py"),
+             ):
+                 with gr.Row(equal_height=True):
+                     with gr.Column(scale=4):
+                         server_path = gr.Textbox(
+                             label="Server Script Path",
+                             placeholder="Enter path to server script (e.g., wdi_mcp_server.py)",
+                             value=server_path_or_url,
+                         )
+                     with gr.Column(scale=1):
+                         connect_btn = gr.Button("Connect")
+
+                 status = gr.Textbox(label="Connection Status", interactive=False)
+
+             chatbot = gr.Chatbot(
+                 value=[],
+                 height="81vh",
+                 type="messages",
+                 show_copy_button=False,
+                 avatar_images=("img/small-user.png", "img/small-robot.png"),
+                 autoscroll=True,
+                 layout="panel",
+                 placeholder="Ask development data questions!",
+             )
+
+             with gr.Row(equal_height=True):
+                 msg = gr.Textbox(
+                     placeholder="Ask about what indicators are available for a specific topic (e.g., What's the definition of GDP?)",
+                     scale=4,
+                     show_label=False,
+                 )
+
+             # Automatically call client.connect(...) as soon as the interface loads
+             demo.load(
+                 fn=client.connect,
+                 inputs=server_path,
+                 outputs=status,
+                 show_progress="full",
+             )
+
+             msg.submit(
+                 client.process_message,
+                 [msg, chatbot],
+                 [chatbot, msg],
+                 concurrency_limit=10,
+             )
+
+         except KeyboardInterrupt:
+             print("Keyboard interrupt received. Disconnecting from MCP server...")
+             asyncio.run(client.disconnect())
+             raise
+
+     return demo
+
+
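+ # Entry point: warn if the Anthropic key is missing, build the UI against the
+ # hosted SSE server, and launch. SERVER_NAME/SERVER_PORT override the bind
+ # address and port (e.g., SERVER_NAME=0.0.0.0 inside a container).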
+ if __name__ == "__main__":
+     if not os.getenv("ANTHROPIC_API_KEY"):
+         print(
+             "Warning: ANTHROPIC_API_KEY not found in environment. Please set it in your .env file."
+         )
+
+     # interface = gradio_interface(server_path_or_url="wdi_mcp_server.py")
+     interface = gradio_interface(
+         server_path_or_url="https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse"
+     )
+     interface.launch(
+         server_name=os.getenv("SERVER_NAME", "127.0.0.1"),
+         # Cast to int: environment variables are strings
+         server_port=int(os.getenv("SERVER_PORT", "7860")),
+         debug=True,
+     )