"""Gradio chat client for the WDI MCP server.

Routes user queries through either the Anthropic Messages API (with a local
MCP session) or the OpenAI Responses API (with its hosted MCP tool), and
streams tool calls and results into a chat UI.
"""

import asyncio
import ast
import json
import os
from contextlib import AsyncExitStack
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import gradio as gr
from gradio.components.chatbot import ChatMessage

from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client

from anthropic import Anthropic
from anthropic._exceptions import OverloadedError

import openai
from openai import OpenAI
from openai.types.responses import (
    ResponseCompletedEvent,
    ResponseContentPartDoneEvent,
    ResponseOutputItemAddedEvent,
    ResponseOutputItemDoneEvent,
    ResponseTextDeltaEvent,
)

from dotenv import load_dotenv

load_dotenv()


# Select the backend: "anthropic" drives a local MCP session through the
# Messages API; "openai" delegates MCP calls to the Responses API's hosted
# "mcp" tool.
LLM_PROVIDER = "openai"

SYSTEM_PROMPT = f"""You are a helpful assistant. Today is {datetime.now().strftime("%Y-%m-%d")}.

You **do not** have prior knowledge of the World Development Indicators (WDI) data. Instead, you must rely entirely on the tools available to you to answer the user's questions.

Detect the language of the user's query and respond in that language, unless the user specifies otherwise.

When responding, always plan your steps and enumerate all the tools you plan to use to answer the user's query.

### Your Instructions:

1. **Tool Use Only**:
    - Do not provide answers based on prior knowledge or assumptions.
    - Do **not** fabricate data or simulate the behavior of the `get_wdi_data` tool.
    - Never call the `get_wdi_data` tool without first calling the `search_relevant_indicators` tool.
    - If the user requests WDI data, you **MUST ALWAYS** first call the `search_relevant_indicators` tool to check whether relevant data exists.
    - If relevant data exists, call the `get_wdi_data` tool to retrieve it.

2. **Tool Invocation**:
    - Use any relevant tools provided to you to answer the user's question.
    - You may call multiple tools if needed, and you should do so in a logical sequence to minimize unnecessary user interaction.
    - Do not hesitate to invoke tools as soon as they are relevant.

3. **Limitations**:
    - If a user request cannot be fulfilled using the available tools, state clearly that you do not have access to that information.

4. **Ethical Guidelines**:
    - Do not make or endorse statements based on stereotypes, bias, or assumptions.
    - Ground all claims and explanations in the data or factual evidence retrieved via tools.
    - Politely refuse requests that involve stereotypes or unfounded generalizations.

5. **Communication Style**:
    - Present the data in clear, user-friendly language.
    - You may summarize or explain the retrieved data, but you MUST NOT elaborate based on outside or implicit knowledge.
    - In the last step, provide a summary of the answer describing observations and insights based solely on the data.

6. **Presentation**:
    - Present the data in a way that is easy to understand.
    - Summarize the data in a table with clear column names and values.
    - If the data is not available, state clearly that you do not have access to that information.

7. **Tool Use**:
    - Fetch each indicator's data with an independent tool call.
    - Provide a brief explanation between tool calls.

Stay strictly within these boundaries while maintaining a helpful and respectful tone."""

LLM_MODEL = "claude-3-5-haiku-20241022"  # Anthropic model
OPENAI_MODEL = "gpt-4.1"


# Create and register a module-level event loop for the async MCP client.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)


class MCPClientWrapper:
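    """Chat backend wrapping an MCP client session.

    Connects to an MCP server (stdio or SSE) and answers queries through
    either the Anthropic Messages API or the OpenAI Responses API, streaming
    tool calls and results to the Gradio chatbot.
    """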
    def __init__(self):
        self.session = None
        self.exit_stack = None
        # Transport handles are set by connect(); initialize them so checks
        # elsewhere don't raise AttributeError before a connection exists.
        self.stdio = None
        self.write = None
        self.anthropic = Anthropic()
        self.openai = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.tools = []

    async def connect(self, server_path_or_url: str) -> str:
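        """Connect to an MCP server, either by launching a local ``.py``
        script over stdio or by opening an SSE connection to a URL, and
        cache the server's tool list. Returns a status string for the UI."""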
        try:
            if self.exit_stack:
                return "Already connected to an MCP server. Please disconnect first."

            self.exit_stack = AsyncExitStack()

            if server_path_or_url.endswith(".py"):
                command = "python"

                server_params = StdioServerParameters(
                    command=command,
                    args=[server_path_or_url],
                    env={"PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"},
                )

                print(
                    f"Starting MCP server with command: {command} {server_path_or_url}"
                )

                stdio_transport = await self.exit_stack.enter_async_context(
                    stdio_client(server_params)
                )
                self.stdio, self.write = stdio_transport
            else:
                print(f"Connecting to MCP server at: {server_path_or_url}")
                sse_transport = await self.exit_stack.enter_async_context(
                    sse_client(
                        server_path_or_url,
                        headers={"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"},
                    )
                )
                self.stdio, self.write = sse_transport

            print("Creating MCP client session...")

            self.session = await self.exit_stack.enter_async_context(
                ClientSession(self.stdio, self.write)
            )
            await self.session.initialize()
            print("MCP session initialized successfully")

            response = await self.session.list_tools()
            self.tools = [
                {
                    "name": tool.name,
                    "description": tool.description,
                    "input_schema": tool.inputSchema,
                }
                for tool in response.tools
            ]

            print("Available tools:", self.tools)
            tool_names = [tool["name"] for tool in self.tools]
            return f"Connected to MCP server. Available tools: {', '.join(tool_names)}"
        except Exception as e:
            error_msg = f"Failed to connect to MCP server: {str(e)}"
            print(error_msg)

            if self.exit_stack:
                await self.exit_stack.aclose()
            self.exit_stack = None
            self.session = None
            return error_msg

    async def disconnect(self):
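        """Close the MCP session and release its transport resources."""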
        if self.exit_stack:
            print("Disconnecting from MCP server...")
            await self.exit_stack.aclose()
            self.exit_stack = None
            self.session = None

    async def process_message(
        self,
        message: str,
        history: List[Union[Dict[str, Any], ChatMessage]],
        previous_response_id: Optional[str] = None,
    ):
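        """Gradio chat handler: append the user's message to the history,
        then re-yield the growing message list as ``_process_query`` streams
        partial results. Yields ``(messages, textbox, previous_response_id)``
        tuples so the UI updates incrementally."""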
        print("previous_response_id", previous_response_id)
        if not self.session and LLM_PROVIDER == "anthropic":
            messages = history + [
                {"role": "user", "content": message},
                {
                    "role": "assistant",
                    "content": "Please connect to an MCP server first by reloading the page.",
                },
            ]
            yield messages, gr.Textbox(value=""), previous_response_id
        else:
            messages = history + [
                {"role": "user", "content": message},
                {
                    "role": "assistant",
                    "content": "Ok, let me think about your query 🤔...",
                },
            ]

            yield messages, gr.Textbox(value=""), previous_response_id

            await asyncio.sleep(0.2)
            messages.pop(-1)

            is_delta = False
            async for partial in self._process_query(
                message, history, previous_response_id
            ):
                # Check for the key itself: the "done" marker arrives as an
                # empty-string delta, which is falsy and would otherwise be
                # misrouted to the non-delta branch below.
                if "delta" in partial[-1]:
                    if not is_delta:
                        is_delta = True
                        messages.append(
                            {
                                "role": "assistant",
                                "content": "",
                            }
                        )
                    messages[-1]["content"] += partial[-1]["delta"]
                    if partial[-1].get("status") == "done":
                        await asyncio.sleep(0.05)
                else:
                    is_delta = False
                    if partial[-1].get("response_id"):
                        previous_response_id = partial[-1]["response_id"]
                        yield (
                            messages,
                            gr.Textbox(value=""),
                            previous_response_id,
                        )
                        await asyncio.sleep(0.01)
                        continue
                    else:
                        messages.extend(partial)
                        print(partial)

                yield (
                    messages,
                    gr.Textbox(value=""),
                    previous_response_id,
                )
                await asyncio.sleep(0.01)

                if (
                    messages[-1]["role"] == "assistant"
                    and messages[-1]["content"]
                    == "The LLM API is overloaded now, try again later..."
                ):
                    break

            with open("messages.log.jsonl", "a+") as fl:
                # One JSON object per line (JSONL), hence the trailing newline.
                fl.write(
                    json.dumps(
                        dict(
                            time=f"{datetime.now()}",
                            messages=messages,
                            previous_response_id=previous_response_id,
                        )
                    )
                    + "\n"
                )

    async def _process_query_openai(
        self,
        message: str,
        history: List[Union[Dict[str, Any], ChatMessage]],
        previous_response_id: Optional[str] = None,
    ):
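        """Stream a reply from the OpenAI Responses API, which invokes the
        MCP server itself via the hosted ``mcp`` tool, and translate the
        event stream into chat messages."""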
        response = self.openai.responses.create(
            model=OPENAI_MODEL,
            tools=[
                {
                    "type": "mcp",
                    "server_label": "wdi_mcp",
                    "server_url": "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
                    "require_approval": "never",
                    "headers": {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"},
                },
            ],
            instructions=SYSTEM_PROMPT,
            input=message,
            parallel_tool_calls=False,
            stream=True,
            max_output_tokens=32768,
            temperature=0,
            previous_response_id=(
                previous_response_id.strip() if previous_response_id else None
            ),
            store=True,
        )
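
        # Walk the event stream, turning MCP tool calls into chat messages
        # grouped under tool panels (parameters, then raw output) and
        # assistant text into streaming deltas. is_tool_call/tool_name/
        # tool_args track the call currently in flight.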
        is_tool_call = False
        tool_name = None
        tool_args = None
        for event in response:
            if isinstance(event, ResponseCompletedEvent):
                yield [
                    {
                        "response_id": event.response.id,
                    }
                ]
            elif (
                isinstance(event, ResponseOutputItemAddedEvent)
                and event.item.type == "mcp_call"
            ):
                is_tool_call = True
                tool_name = event.item.name

            if is_tool_call:
                # Match on the type string: the SDK in use here may not ship
                # a dedicated event class for this event type.
                if getattr(event, "type", None) == "response.mcp_call_arguments.done":
                    tool_args = getattr(event, "arguments", "")

                    try:
                        # Pretty-print the arguments when they parse as JSON.
                        tool_args = json.dumps(
                            json.loads(tool_args), ensure_ascii=True, indent=2
                        )
                    except Exception:
                        pass

                    yield [
                        {
                            "role": "assistant",
                            "content": f"I'll use the {tool_name} tool to help answer your question.",
                            "metadata": {
                                "title": f"Using tool: {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                                "log": f"Parameters: {tool_args}",
                                "status": "done",
                                "id": f"tool_call_{tool_name}",
                            },
                        }
                    ]

                    yield [
                        {
                            "role": "assistant",
                            "content": "```json\n" + tool_args + "\n```",
                            "metadata": {
                                "parent_id": f"tool_call_{tool_name}",
                                "id": f"params_{tool_name}",
                                "title": "Tool Parameters",
                            },
                        }
                    ]

                elif isinstance(event, ResponseOutputItemDoneEvent):
                    if event.item.type == "mcp_call":
                        yield [
                            {
                                "role": "assistant",
                                "content": "Here are the results from the tool:",
                                "metadata": {
                                    "title": f"Tool Result for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                                    "status": "done",
                                    "id": f"result_{tool_name}",
                                },
                            }
                        ]

                        result_content = event.item.output
                        # Some tools return a "root=..." model repr; strip the
                        # prefix and pretty-print the payload when possible.
                        if result_content.startswith("root="):
                            result_content = result_content[len("root=") :]
                        try:
                            result_content = ast.literal_eval(result_content)
                            result_content = json.dumps(result_content, indent=2)
                        except Exception:
                            pass

                        yield [
                            {
                                "role": "assistant",
                                "content": "```\n" + result_content + "\n```",
                                "metadata": {
                                    "parent_id": f"result_{tool_name}",
                                    "id": f"raw_result_{tool_name}",
                                    "title": "Raw Output",
                                },
                            }
                        ]
                        is_tool_call = False
                        tool_name = None
                        tool_args = None

            elif isinstance(event, ResponseContentPartDoneEvent):
                # Signal the end of a streamed text part.
                yield [
                    {
                        "role": "assistant",
                        "content": "",
                        "delta": "",
                        "status": "done",
                    }
                ]
            elif isinstance(event, ResponseTextDeltaEvent):
                yield [{"role": "assistant", "content": None, "delta": event.delta}]

    async def _process_query_anthropic(
        self, message: str, history: List[Union[Dict[str, Any], ChatMessage]]
    ):
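        """Drive a multi-turn tool-use conversation with the Anthropic
        Messages API, executing requested tools through the live MCP session
        and yielding chat messages for each step."""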
        claude_messages = []
        for msg in history:
            if isinstance(msg, ChatMessage):
                role, content = msg.role, msg.content
            else:
                role, content = msg.get("role"), msg.get("content")

            if role in ["user", "assistant", "system"]:
                claude_messages.append({"role": role, "content": content})

        claude_messages.append({"role": "user", "content": message})

        try:
            response = self.anthropic.messages.create(
                model=LLM_MODEL,
                system=SYSTEM_PROMPT,
                max_tokens=1000,
                messages=claude_messages,
                tools=self.tools,
            )
        except OverloadedError:
            yield [
                {
                    "role": "assistant",
                    "content": "The LLM API is overloaded now, try again later...",
                }
            ]
            # Stop here: `response` is unbound when the request fails.
            return

        result_messages = []
        partial_messages = []

        print(response.content)
        contents = response.content

        # Cap the number of automatic follow-up model calls per user turn.
        MAX_CALLS = 10
        auto_calls = 0
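
        # Agentic loop: pop the model's content blocks, run any tool_use
        # blocks against the MCP session, feed the tool results back to the
        # model, and continue until it stops requesting tools or MAX_CALLS
        # follow-up calls have been made.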
        while len(contents) > 0 and auto_calls < MAX_CALLS:
            content = contents.pop(0)

            if content.type == "text":
                result_messages.append({"role": "assistant", "content": content.text})
                claude_messages.append({"role": "assistant", "content": content.text})
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

            elif content.type == "tool_use":
                tool_id = content.id
                tool_name = content.name
                tool_args = content.input

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": f"I'll use the {tool_name} tool to help answer your question.",
                        "metadata": {
                            "title": f"Using tool: {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                            "log": f"Parameters: {json.dumps(tool_args, ensure_ascii=True)}",
                            "status": "done",
                            "id": f"tool_call_{tool_name}",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "```json\n"
                        + json.dumps(tool_args, indent=2, ensure_ascii=True)
                        + "\n```",
                        "metadata": {
                            "parent_id": f"tool_call_{tool_name}",
                            "id": f"params_{tool_name}",
                            "title": "Tool Parameters",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]

                print(f"Calling tool: {tool_name} with args: {tool_args}")
                try:
                    if not self.session or not self.stdio or not self.write:
                        raise Exception(
                            "MCP session is not connected or has been closed"
                        )

                    result = await self.session.call_tool(tool_name, tool_args)
                except Exception as e:
                    error_msg = f"Error calling tool {tool_name}: {str(e)}"
                    print(error_msg)
                    result_messages.append(
                        {
                            "role": "assistant",
                            "content": f"Sorry, I encountered an error while calling the tool: {error_msg}. Please try again or reload the page.",
                            "metadata": {
                                "title": f"Tool Error for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                                "status": "done",
                                "id": f"error_{tool_name}",
                            },
                        }
                    )
                    partial_messages.append(result_messages[-1])
                    yield [result_messages[-1]]
                    partial_messages = []
                    continue

                # Guard against indexing past the start of the list.
                if len(result_messages) >= 2 and "metadata" in result_messages[-2]:
                    result_messages[-2]["metadata"]["status"] = "done"

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "Here are the results from the tool:",
                        "metadata": {
                            "title": f"Tool Result for {tool_name.replace('avsolatorio_test_data_mcp_server', '')}",
                            "status": "done",
                            "id": f"result_{tool_name}",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

                result_content = result.content
                print(result_content)
                if isinstance(result_content, list):
                    result_content = [r.model_dump() for r in result_content]

                    for r in result_content:
                        r.pop("annotations", None)
                        try:
                            r["text"] = json.loads(r["text"])
                        except Exception:
                            pass  # Keep the raw text if it isn't valid JSON.

                print("result_content", result_content)

                result_messages.append(
                    {
                        "role": "assistant",
                        "content": "```\n"
                        + json.dumps(result_content, indent=2)
                        + "\n```",
                        "metadata": {
                            "parent_id": f"result_{tool_name}",
                            "id": f"raw_result_{tool_name}",
                            "title": "Raw Output",
                        },
                    }
                )
                partial_messages.append(result_messages[-1])
                yield [result_messages[-1]]
                partial_messages = []

                # Echo the tool_use block and its result back to the model so
                # it can continue the conversation.
                claude_messages.append(
                    {"role": "assistant", "content": [content.model_dump()]}
                )
                claude_messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "tool_result",
                                "tool_use_id": tool_id,
                                "content": json.dumps(result_content, indent=2),
                            }
                        ],
                    }
                )

                try:
                    next_response = self.anthropic.messages.create(
                        model=LLM_MODEL,
                        system=SYSTEM_PROMPT,
                        max_tokens=1000,
                        messages=claude_messages,
                        tools=self.tools,
                    )
                    auto_calls += 1
                except OverloadedError:
                    yield [
                        {
                            "role": "assistant",
                            "content": "The LLM API is overloaded now, try again later...",
                        }
                    ]
                    # Stop here: `next_response` is unbound when the API is
                    # overloaded.
                    return

                print("next_response", next_response.content)

                contents.extend(next_response.content)

    async def _process_query(
        self,
        message: str,
        history: List[Union[Dict[str, Any], ChatMessage]],
        previous_response_id: Optional[str] = None,
    ):
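        """Dispatch to the provider-specific processor and convert provider
        errors into user-facing chat messages."""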
        if LLM_PROVIDER == "anthropic":
            async for partial in self._process_query_anthropic(message, history):
                yield partial
        elif LLM_PROVIDER == "openai":
            try:
                async for partial in self._process_query_openai(
                    message, history, previous_response_id
                ):
                    yield partial
            except openai.APIError as e:
                print(e)
                yield [
                    {
                        "role": "assistant",
                        "content": "The LLM encountered an error. Please try again or reload the page.",
                    }
                ]
            except Exception as e:
                print(e)
                yield [
                    {
                        "role": "assistant",
                        "content": f"Sorry, I encountered an unexpected error: `{e}`. Please try again or reload the page.",
                    }
                ]


def gradio_interface(
    server_path_or_url: str = "https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse",
):
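    """Build the Gradio Blocks UI and wire the chat box to MCPClientWrapper.

    ``server_path_or_url`` may be a local server script (run over stdio) or
    an SSE endpoint URL.
    """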
    client = MCPClientWrapper()
    custom_css = """
    .gradio-container {
        background-color: #fff !important;
    }
    .message-row.panel.bot-row {
        background-color: #fff !important;
    }
    .message-row.panel.user-row {
        background-color: #fff !important;
    }
    .user {
        background-color: #f1f6ff !important;
    }
    .bot {
        background-color: #fff !important;
    }
    .role {
        margin-left: 10px !important;
    }
    footer {
        display: none !important;
    }
    """

    with gr.Blocks(title="WDI MCP Client", css=custom_css, theme=None) as demo:
        try:
            gr.Markdown("# Data360 Chat [Prototype]")

            with gr.Accordion(
                "Connect to the WDI MCP server and chat with the assistant",
                open=False,
                visible=server_path_or_url.endswith(".py"),
            ):
                with gr.Row(equal_height=True):
                    with gr.Column(scale=4):
                        server_path = gr.Textbox(
                            label="Server Script Path",
                            placeholder="Enter path to server script (e.g., wdi_mcp_server.py)",
                            value=server_path_or_url,
                        )
                    with gr.Column(scale=1):
                        connect_btn = gr.Button("Connect")

                status = gr.Textbox(label="Connection Status", interactive=False)

            chatbot = gr.Chatbot(
                value=[],
                height="81vh",
                type="messages",
                show_copy_button=False,
                avatar_images=("img/small-user.png", "img/small-robot.png"),
                autoscroll=True,
                layout="panel",
                placeholder="Ask development data questions!",
            )
            previous_response_id = gr.State(None)

            with gr.Row(equal_height=True):
                msg = gr.Textbox(
                    label=None,
                    placeholder="Ask about what indicators are available for a specific topic (e.g., What's the definition of GDP?)",
                    scale=4,
                    show_label=False,
                )

            # Only the Anthropic path talks to the MCP server directly, so it
            # is the only one that needs a connection on page load; the OpenAI
            # path delegates MCP calls to the hosted tool.
            if LLM_PROVIDER == "anthropic":
                demo.load(
                    fn=client.connect,
                    inputs=server_path,
                    outputs=status,
                    show_progress="full",
                )

            msg.submit(
                client.process_message,
                [msg, chatbot, previous_response_id],
                [chatbot, msg, previous_response_id],
                concurrency_limit=10,
            )

        except KeyboardInterrupt:
            if LLM_PROVIDER == "anthropic":
                print("Keyboard interrupt received. Disconnecting from MCP server...")
                asyncio.run(client.disconnect())
            raise

    return demo


if __name__ == "__main__":
    # Warn about the API key required by the configured provider.
    required_key = "OPENAI_API_KEY" if LLM_PROVIDER == "openai" else "ANTHROPIC_API_KEY"
    if not os.getenv(required_key):
        print(
            f"Warning: {required_key} not found in environment. Please set it in your .env file."
        )

    interface = gradio_interface(
        server_path_or_url="https://avsolatorio-test-data-mcp-server.hf.space/gradio_api/mcp/sse"
    )
    interface.launch(
        server_name=os.getenv("SERVER_NAME", "127.0.0.1"),
        # launch() expects an int port; os.getenv returns a string when set.
        server_port=int(os.getenv("SERVER_PORT", "7860")),
        debug=True,
    )