avsolatorio commited on
Commit
c812099
·
1 Parent(s): 0c92aa0

Bootstrap openai client

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. mcp_openai_client.py +100 -14
mcp_openai_client.py CHANGED
@@ -12,10 +12,23 @@ from mcp.client.sse import sse_client
12
  from anthropic import Anthropic
13
  from anthropic._exceptions import OverloadedError
14
  from dotenv import load_dotenv
15
-
 
 
 
 
 
 
 
 
 
 
16
 
17
  load_dotenv()
18
 
 
 
 
19
  SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.
20
 
21
  You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.
@@ -71,6 +84,7 @@ class MCPClientWrapper:
71
  self.session = None
72
  self.exit_stack = None
73
  self.anthropic = Anthropic()
 
74
  self.tools = []
75
 
76
  async def connect(self, server_path_or_url: str) -> str:
@@ -150,7 +164,7 @@ class MCPClientWrapper:
150
  async def process_message(
151
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
152
  ):
153
- if not self.session:
154
  messages = history + [
155
  {"role": "user", "content": message},
156
  {
@@ -173,8 +187,22 @@ class MCPClientWrapper:
173
  await asyncio.sleep(0.1)
174
  messages.pop(-1)
175
 
 
176
  async for partial in self._process_query(message, history):
177
- messages.extend(partial)
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  yield messages, gr.Textbox(value="")
179
  await asyncio.sleep(0.05)
180
 
@@ -188,7 +216,53 @@ class MCPClientWrapper:
188
  with open("messages.log.jsonl", "a+") as fl:
189
  fl.write(json.dumps(dict(time=f"{datetime.now()}", messages=messages)))
190
 
191
- async def _process_query(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
193
  ):
194
  claude_messages = []
@@ -292,10 +366,10 @@ class MCPClientWrapper:
292
  result_messages.append(
293
  {
294
  "role": "assistant",
295
- "content": f"Sorry, I encountered an error while calling the tool: {error_msg}. Please try again.",
296
  "metadata": {
297
  "title": f"Tool Error for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
298
- "status": "error",
299
  "id": f"error_{tool_name}",
300
  },
301
  }
@@ -392,6 +466,16 @@ class MCPClientWrapper:
392
 
393
  contents.extend(next_response.content)
394
 
 
 
 
 
 
 
 
 
 
 
395
 
396
  def gradio_interface(
397
  server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
@@ -467,12 +551,13 @@ def gradio_interface(
467
 
468
  # connect_btn.click(client.connect, inputs=server_path, outputs=status)
469
  # Automatically call client.connect(...) as soon as the interface loads
470
- demo.load(
471
- fn=client.connect,
472
- inputs=server_path,
473
- outputs=status,
474
- show_progress="full",
475
- )
 
476
 
477
  msg.submit(
478
  client.process_message,
@@ -483,8 +568,9 @@ def gradio_interface(
483
  # clear_btn.click(lambda: [], None, chatbot)
484
 
485
  except KeyboardInterrupt:
486
- print("Keyboard interrupt received. Disconnecting from MCP server...")
487
- asyncio.run(client.disconnect())
 
488
  raise KeyboardInterrupt
489
  # demo.unload(client.disconnect)
490
 
 
12
  from anthropic import Anthropic
13
  from anthropic._exceptions import OverloadedError
14
  from dotenv import load_dotenv
15
+ from openai import OpenAI
16
+ from openai.types.responses import (
17
+ ResponseTextDeltaEvent,
18
+ ResponseContentPartAddedEvent,
19
+ ResponseContentPartDoneEvent,
20
+ ResponseTextDoneEvent,
21
+ ResponseMcpCallInProgressEvent,
22
+ ResponseAudioDeltaEvent,
23
+ ResponseMcpCallCompletedEvent,
24
+ ResponseOutputItemDoneEvent,
25
+ )
26
 
27
  load_dotenv()
28
 
29
+ # LLM_PROVIDER = "anthropic"
30
+ LLM_PROVIDER = "openai"
31
+
32
  SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.
33
 
34
  You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.
 
84
  self.session = None
85
  self.exit_stack = None
86
  self.anthropic = Anthropic()
87
+ self.openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
88
  self.tools = []
89
 
90
  async def connect(self, server_path_or_url: str) -> str:
 
164
  async def process_message(
165
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
166
  ):
167
+ if not self.session and LLM_PROVIDER == "anthropic":
168
  messages = history + [
169
  {"role": "user", "content": message},
170
  {
 
187
  await asyncio.sleep(0.1)
188
  messages.pop(-1)
189
 
190
+ is_delta = False
191
  async for partial in self._process_query(message, history):
192
+ if partial[-1].get("delta"):
193
+ if not is_delta:
194
+ is_delta = True
195
+ messages.append(
196
+ {
197
+ "role": "assistant",
198
+ "content": "",
199
+ }
200
+ )
201
+ messages[-1]["content"] += partial[-1]["delta"]
202
+ else:
203
+ is_delta = False
204
+ messages.extend(partial)
205
+
206
  yield messages, gr.Textbox(value="")
207
  await asyncio.sleep(0.05)
208
 
 
216
  with open("messages.log.jsonl", "a+") as fl:
217
  fl.write(json.dumps(dict(time=f"{datetime.now()}", messages=messages)))
218
 
219
+ async def _process_query_openai(
220
+ self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
221
+ ):
222
+ response = self.openai.responses.create(
223
+ model="gpt-4.1-mini",
224
+ tools=[
225
+ {
226
+ "type": "mcp",
227
+ "server_label": "wdi_mcp",
228
+ "server_url": "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
229
+ "require_approval": "never",
230
+ "headers": {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"},
231
+ # "server_token": userdata.get('MCP_HF_TOKEN'),
232
+ },
233
+ ],
234
+ # input="What transport protocols are supported in the 2025-03-26 version of the MCP spec?",
235
+ instructions=SYSTEM_PROMPT,
236
+ # input="What is the gdp of india in 2020?",
237
+ input=message,
238
+ parallel_tool_calls=False,
239
+ stream=True,
240
+ temperature=0,
241
+ )
242
+
243
+ is_tool_call = False
244
+ for event in response:
245
+ if isinstance(event, ResponseMcpCallInProgressEvent):
246
+ is_tool_call = True
247
+ yield [
248
+ {
249
+ "role": "assistant",
250
+ "content": "I'll use the tool to help answer your question.",
251
+ }
252
+ ]
253
+ elif isinstance(event, ResponseOutputItemDoneEvent):
254
+ if is_tool_call:
255
+ yield [
256
+ {
257
+ "role": "assistant",
258
+ "content": "I've used the tool to help answer your question.",
259
+ }
260
+ ]
261
+ is_tool_call = False
262
+ elif isinstance(event, ResponseTextDeltaEvent):
263
+ yield [{"role": "assistant", "content": None, "delta": event.delta}]
264
+
265
+ async def _process_query_anthropic(
266
  self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
267
  ):
268
  claude_messages = []
 
366
  result_messages.append(
367
  {
368
  "role": "assistant",
369
+ "content": f"Sorry, I encountered an error while calling the tool: {error_msg}. Please try again or reload the page.",
370
  "metadata": {
371
  "title": f"Tool Error for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
372
+ "status": "done",
373
  "id": f"error_{tool_name}",
374
  },
375
  }
 
466
 
467
  contents.extend(next_response.content)
468
 
469
+ async def _process_query(
470
+ self, message: str, history: List[Union[Dict[Any, Any], ChatMessage]]
471
+ ):
472
+ if LLM_PROVIDER == "anthropic":
473
+ async for partial in self._process_query_anthropic(message, history):
474
+ yield partial
475
+ elif LLM_PROVIDER == "openai":
476
+ async for partial in self._process_query_openai(message, history):
477
+ yield partial
478
+
479
 
480
  def gradio_interface(
481
  server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
 
551
 
552
  # connect_btn.click(client.connect, inputs=server_path, outputs=status)
553
  # Automatically call client.connect(...) as soon as the interface loads
554
+ if LLM_PROVIDER == "anthropic":
555
+ demo.load(
556
+ fn=client.connect,
557
+ inputs=server_path,
558
+ outputs=status,
559
+ show_progress="full",
560
+ )
561
 
562
  msg.submit(
563
  client.process_message,
 
568
  # clear_btn.click(lambda: [], None, chatbot)
569
 
570
  except KeyboardInterrupt:
571
+ if LLM_PROVIDER == "anthropic":
572
+ print("Keyboard interrupt received. Disconnecting from MCP server...")
573
+ asyncio.run(client.disconnect())
574
  raise KeyboardInterrupt
575
  # demo.unload(client.disconnect)
576