
Improve call to action when disconnected from the MCP server. Try closing the client when terminating the app
471d861
import asyncio
import os
import json
from typing import List, Dict, Any, Union
from contextlib import AsyncExitStack
from datetime import datetime

import gradio as gr
from gradio.components.chatbot import ChatMessage

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client

from anthropic import Anthropic
from anthropic._exceptions import OverloadedError

from dotenv import load_dotenv

load_dotenv()
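
# Environment variables used below: ANTHROPIC_API_KEY (Anthropic client),
# HF_TOKEN (bearer auth for the SSE endpoint), and SERVER_NAME/SERVER_PORT
# (Gradio launch settings in __main__).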

SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.

You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.

When responding, you must always plan the steps and enumerate all the tools that you plan to use to answer the user's query.

### Your Instructions:

1. **Tool Use Only**:
   - You must not provide any answers based on prior knowledge or assumptions.
   - You must **not** fabricate data or simulate the behavior of the `get_wdi_data` tool.
   - You cannot use the `get_wdi_data` tool without using the `search_relevant_indicators` tool first.
   - If the user requests WDI data, you **MUST ALWAYS** first call the `search_relevant_indicators` tool to see if there's any relevant data.
   - If relevant data exists, call the `get_wdi_data` tool to get the data.

2. **Tool Invocation**:
   - Use any relevant tools provided to you to answer the user's question.
   - You may call multiple tools if needed, and you should do so in a logical sequence to minimize unnecessary user interaction.
   - Do not hesitate to invoke tools as soon as they are relevant.

3. **Limitations**:
   - If a user request cannot be fulfilled using the available tools, respond by clearly stating that you do not have access to that information.

4. **Ethical Guidelines**:
   - Do not make or endorse statements based on stereotypes, bias, or assumptions.
   - Ensure all claims and explanations are grounded in the data or factual evidence retrieved via tools.
   - Politely refuse to respond to requests that involve stereotypes or unfounded generalizations.

5. **Communication Style**:
   - Present the data in clear, user-friendly language.
   - You may summarize, explain, or describe the retrieved data in a way that is easy to understand, but you **MUST NOT** elaborate based on outside or implicit knowledge.

6. **Presentation**:
   - Present the data in a way that is easy to understand.
   - Summarize the data in a table format with clear column names and values.
   - If the data is not available, respond by clearly stating that you do not have access to that information.

Stay strictly within these boundaries while maintaining a helpful and respectful tone."""

LLM_MODEL = "claude-3-5-haiku-20241022"

# Example query: What is the military spending of Bangladesh in 2014?
# When a tool is needed for any step, ensure to add the token `TOOL_USE`.

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
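
# A dedicated event loop is created and set before Gradio starts so that the
# MCP subprocess streams and ClientSession (see connect() below) are bound to
# one long-lived loop rather than a per-request one. This is an assumption
# about how Gradio schedules this app's async handlers.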


class MCPClientWrapper:
    def __init__(self):
        self.session = None
        self.exit_stack = None
        self.stdio = None
        self.write = None
        self.anthropic = Anthropic()
        self.tools = []

    async def connect(self, server_path_or_url: str) -> str:
        try:
            # If there's already an active session, keep it and ask the user to disconnect first
            if self.exit_stack:
                return "Already connected to an MCP server. Please disconnect first."
                # await self.exit_stack.aclose()

            self.exit_stack = AsyncExitStack()

            if server_path_or_url.endswith(".py"):
                command = "python"
                server_params = StdioServerParameters(
                    command=command,
                    args=[server_path_or_url],
                    env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
                )
                print(f"Starting MCP server with command: {command} {server_path_or_url}")

                # Launch the MCP subprocess and bind its streams on the current running loop
                stdio_transport = await self.exit_stack.enter_async_context(
                    stdio_client(server_params)
                )
                self.stdio, self.write = stdio_transport
            else:
                print(f"Connecting to MCP server at: {server_path_or_url}")
                sse_transport = await self.exit_stack.enter_async_context(
                    sse_client(
                        server_path_or_url,
                        headers={"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"},
                    )
                )
                self.stdio, self.write = sse_transport

            print("Creating MCP client session...")
            # Create the ClientSession on this same loop
            self.session = await self.exit_stack.enter_async_context(
                ClientSession(self.stdio, self.write)
            )
            await self.session.initialize()
            print("MCP session initialized successfully")

            response = await self.session.list_tools()
            self.tools = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.inputSchema,
                }
                for tool in response.tools
            ]
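            # Each entry follows the {name, description, input_schema} shape that
            # Anthropic's Messages API expects in the `tools` parameter, so this
            # list can be passed to messages.create() as-is.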
print("Available tools:", self.tools) | |
tool_names = [tool["name"] for tool in self.tools] | |
return f"Connected to MCP server. Available tools: {', '.join(tool_names)}" | |
except Exception as e: | |
error_msg = f"Failed to connect to MCP server: {str(e)}" | |
print(error_msg) | |
# Clean up on error | |
if self.exit_stack: | |
await self.exit_stack.aclose() | |
self.exit_stack = None | |
self.session = None | |
return error_msg | |

    async def disconnect(self):
        if self.exit_stack:
            print("Disconnecting from MCP server...")
            await self.exit_stack.aclose()
            self.exit_stack = None
            self.session = None
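
    # disconnect() is intended to be wired to app teardown: the KeyboardInterrupt
    # handler in gradio_interface() calls it, and demo.unload(client.disconnect)
    # (left commented out below) would be the Gradio-native alternative.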

    async def process_message(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
        if not self.session:
            messages = history + [
                {"role": "user", "content": message},
                {
                    "role": "assistant",
                    "content": "Please connect to an MCP server first by reloading the page.",
                },
            ]
            yield messages, gr.Textbox(value="")
        else:
            messages = history + [
                {"role": "user", "content": message},
                {
                    "role": "assistant",
                    "content": "Ok, let me think about your query 🤔...",
                },
            ]
            yield messages, gr.Textbox(value="")

            # Simulate thinking with asyncio.sleep
            await asyncio.sleep(0.1)
            messages.pop(-1)

            async for partial in self._process_query(message, history):
                messages.extend(partial)
                yield messages, gr.Textbox(value="")
                await asyncio.sleep(0.05)

                if (
                    messages[-1]["role"] == "assistant"
                    and messages[-1]["content"]
                    == "The LLM API is overloaded now, try again later..."
                ):
                    break

            with open("messages.log.jsonl", "a+") as fl:
                # Append one JSON object per line (JSONL), hence the trailing newline
                fl.write(
                    json.dumps(dict(time=f"{datetime.now()}", messages=messages)) + "\n"
                )

    async def _process_query(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
        claude_messages = []
        for msg in history:
            if isinstance(msg, ChatMessage):
                role, content = msg.role, msg.content
            else:
                role, content = msg.get("role"), msg.get("content")

            if role in ["user", "assistant", "system"]:
                claude_messages.append({"role": role, "content": content})

        claude_messages.append({"role": "user", "content": message})

        try:
            response = self.anthropic.messages.create(
                # model="claude-3-5-sonnet-20241022",
                model=LLM_MODEL,
                system=SYSTEM_PROMPT,
                max_tokens=1000,
                messages=claude_messages,
                tools=self.tools,
            )
        except OverloadedError:
            yield [
                {
                    "role": "assistant",
                    "content": "The LLM API is overloaded now, try again later...",
                }
            ]
            # TODO: Add a retry mechanism. For now, bail out: `response` is
            # unbound here, so falling through would raise a NameError.
            return
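
        # A minimal retry sketch for the TODO above (hypothetical, not wired in):
        # retry the overloaded call a few times with exponential backoff before
        # giving up, e.g.
        #
        #     for attempt in range(3):
        #         try:
        #             response = self.anthropic.messages.create(
        #                 model=LLM_MODEL,
        #                 system=SYSTEM_PROMPT,
        #                 max_tokens=1000,
        #                 messages=claude_messages,
        #                 tools=self.tools,
        #             )
        #             break
        #         except OverloadedError:
        #             await asyncio.sleep(2**attempt)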

        result_messages = []
        partial_messages = []

        print(response.content)
        contents = response.content

        MAX_CALLS = 10
        auto_calls = 0
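
        # Agent loop: consume Claude's content blocks in order. Text blocks are
        # streamed straight to the UI; each tool_use block triggers an MCP tool
        # call whose output is sent back to Claude as a tool_result, and the
        # follow-up response is appended to the queue. MAX_CALLS caps the number
        # of automatic follow-up LLM calls to avoid unbounded tool loops.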
        while len(contents) > 0 and auto_calls < MAX_CALLS:
            content = contents.pop(0)

            if content.type == "text":
                result_messages.append({"role": "assistant", "content": content.text})
                claude_messages.append({"role": "assistant", "content": content.text})
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []
            elif content.type == "tool_use":
                tool_id = content.id
                tool_name = content.name
                tool_args = content.input

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": f"I'll use the {tool_name} tool to help answer your question.",
                        "metadata": {
                            "title": f"Using tool: {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                            "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                            # "status": "pending",
                            "status": "done",
                            "id": f"tool_call_{tool_name}",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "```json\n"
                        + json.dumps(tool_args, indent=2, ensure_ascii=True)
                        + "\n```",
                        "metadata": {
                            "parent_id": f"tool_call_{tool_name}",
                            "id": f"params_{tool_name}",
                            "title": "Tool Parameters",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
print(f"Calling tool: {tool_name} with args: {tool_args}") | |
try: | |
# Check if session is still valid | |
if not self.session or not self.stdio or not self.write: | |
raise Exception( | |
"MCP session is not connected or has been closed" | |
) | |
result = await self.session.call_tool(tool_name, tool_args) | |
except Exception as e: | |
error_msg = f"Error calling tool {tool_name}: {str(e)}" | |
print(error_msg) | |
result_messages.append( | |
{ | |
"role": "assistant", | |
"content": f"Sorry, I encountered an error while calling the tool: {error_msg}. Please try again.", | |
"metadata": { | |
"title": f"Tool Error for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}", | |
"status": "error", | |
"id": f"error_{tool_name}", | |
}, | |
} | |
) | |
partial_messages.append(result_messages[-1]) | |
yield [result_messages[-1]] | |
partial_messages = [] | |
continue | |

                # Mark the originating tool-call message (two entries back, behind
                # the parameters message) as done; guard the index to avoid an
                # IndexError on short message lists.
                if len(result_messages) >= 2 and "metadata" in result_messages[-2]:
                    result_messages[-2]["metadata"]["status"] = "done"

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "Here are the results from the tool:",
                        "metadata": {
                            "title": f"Tool Result for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                            "status": "done",
                            "id": f"result_{tool_name}",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

                result_content = result.content
                print(result_content)
                if isinstance(result_content, list):
                    result_content = [r.model_dump() for r in result_content]
                    for r in result_content:
                        # Remove the annotations field from each item if it exists
                        r.pop("annotations", None)
                        try:
                            r["text"] = json.loads(r["text"])
                        except (KeyError, TypeError, json.JSONDecodeError):
                            # Leave non-JSON (or missing) text payloads as-is
                            pass
print("result_content", result_content) | |
result_messages.append( | |
{ | |
"role": "assistant", | |
"content": "```\n" | |
+ json.dumps(result_content, indent=2) | |
+ "\n```", | |
"metadata": { | |
"parent_id": f"result_{tool_name}", | |
"id": f"raw_result_{tool_name}", | |
"title": "Raw Output", | |
}, | |
} | |
) | |
partial_messages.append(result_messages[-1]) | |
yield [result_messages[-1]] | |
partial_messages = [] | |
claude_messages.append( | |
{"role": "assistant", "content": [content.model_dump()]} | |
) | |
claude_messages.append( | |
{ | |
"role": "user", | |
"content": [ | |
{ | |
"type": "tool_result", | |
"tool_use_id": tool_id, | |
"content": json.dumps(result_content, indent=2), | |
} | |
], | |
} | |
) | |
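                # Anthropic's tool-use protocol pairs the assistant's tool_use block
                # with a user-role tool_result block carrying the same tool_use_id,
                # which is exactly what the two appends above construct.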

                try:
                    next_response = self.anthropic.messages.create(
                        model=LLM_MODEL,
                        system=SYSTEM_PROMPT,
                        max_tokens=1000,
                        messages=claude_messages,
                        tools=self.tools,
                    )
                    auto_calls += 1
                except OverloadedError:
                    yield [
                        {
                            "role": "assistant",
                            "content": "The LLM API is overloaded now, try again later...",
                        }
                    ]
                    # Bail out: `next_response` is unbound here, so falling
                    # through would raise a NameError.
                    return
print("next_response", next_response.content) | |
contents.extend(next_response.content) | |


def gradio_interface(
    server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
):
    # server_path_or_url = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse"
    # server_path_or_url = "wdi_mcp_server.py"
    client = MCPClientWrapper()

    custom_css = """
    .gradio-container {
        background-color: #fff !important;
    }
    .message-row.panel.bot-row {
        background-color: #fff !important;
    }
    .message-row.panel.user-row {
        background-color: #fff !important;
    }
    .user {
        background-color: #f1f6ff !important;
    }
    .bot {
        background-color: #fff !important;
    }
    .role {
        margin-left: 10px !important;
    }
    footer {display: none !important}
    """

    # Disable auto-dark mode by setting theme to None
    with gr.Blocks(title="WDI MCP Client", css=custom_css, theme=None) as demo:
        try:
            gr.Markdown("# Development Data Chat")
            # gr.Markdown("Connect to the WDI MCP server and chat with the assistant")

            with gr.Accordion(
                "Connect to the WDI MCP server and chat with the assistant",
                open=False,
                visible=server_path_or_url.endswith(".py"),
            ):
                with gr.Row(equal_height=True):
                    with gr.Column(scale=4):
                        server_path = gr.Textbox(
                            label="Server Script Path",
                            placeholder="Enter path to server script (e.g., wdi_mcp_server.py)",
                            value=server_path_or_url,
                        )
                    with gr.Column(scale=1):
                        connect_btn = gr.Button("Connect")
                status = gr.Textbox(label="Connection Status", interactive=False)

            chatbot = gr.Chatbot(
                value=[],
                height="81vh",
                type="messages",
                show_copy_button=False,
                avatar_images=("img/small-user.png", "img/small-robot.png"),
                autoscroll=True,
                layout="panel",
                placeholder="Ask development data questions!",
            )

            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    label=None,
                    placeholder="Ask about what indicators are available for a specific topic (e.g., What's the definition of GDP?)",
                    scale=4,
                    show_label=False,
                )
                # clear_btn = gr.Button("Clear Chat", scale=1)

            # connect_btn.click(client.connect, inputs=server_path, outputs=status)
            # Automatically call client.connect(...) as soon as the interface loads
            demo.load(
                fn=client.connect,
                inputs=server_path,
                outputs=status,
                show_progress="full",
            )

            msg.submit(
                client.process_message,
                [msg, chatbot],
                [chatbot, msg],
                concurrency_limit=10,
            )
            # clear_btn.click(lambda: [], None, chatbot)
        except KeyboardInterrupt:
            print("Keyboard interrupt received. Disconnecting from MCP server...")
            asyncio.run(client.disconnect())
            raise
    # demo.unload(client.disconnect)

    return demo


if __name__ == "__main__":
    if not os.getenv("ANTHROPIC_API_KEY"):
        print(
            "Warning: ANTHROPIC_API_KEY not found in environment. Please set it in your .env file."
        )

    # interface = gradio_interface(server_path_or_url="wdi_mcp_server.py")
    interface = gradio_interface(
        server_path_or_url="https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse"
    )
    interface.launch(
        server_name=os.getenv("SERVER_NAME", "127.0.0.1"),
        # SERVER_PORT comes from the environment as a string, so cast it to int
        server_port=int(os.getenv("SERVER_PORT", "7860")),
        debug=True,
    )