test-data-mcp-server / wdi_mcp_server.py
from mcp.server.fastmcp import FastMCP
import json
import sys
import io
import time
import numpy as np
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from gradio_client import Client
from pydantic import BaseModel, Field


def get_best_torch_device():
    """Return the best available torch device: CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")
device = get_best_torch_device()
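
# Re-wrap stdout/stderr with UTF-8 encoders so that printing indicator text
# containing non-ASCII characters does not raise encoding errors on consoles
# whose default encoding is not UTF-8 (an assumption about the deployment host).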
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
mcp = FastMCP("huggingface_spaces_wdi_data")
# Load the basic WDI metadata and vectors.
wdi_data_vec_fpath = (
    "./data/avsolatorio__GIST-small-Embedding-v0__005__indicator_embeddings.json"
)
df = pd.read_json(wdi_data_vec_fpath)
# Make it easy to index based on the idno
df.index = df["idno"]
# Rename the fields to follow the metadata standard naming
df.rename(columns={"title": "name", "text": "definition"}, inplace=True)
# Stack the precomputed embeddings into a single torch tensor on the device
vectors = torch.tensor(df["embedding"].tolist(), dtype=torch.float32).to(device)
# Load the embedding model
model_name = "/".join(wdi_data_vec_fpath.split("/")[-1].split("__")[:2])
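# With the file path above, this resolves to "avsolatorio/GIST-small-Embedding-v0",
# the same model that produced the precomputed indicator embeddings.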
embedding_model = SentenceTransformer(model_name, device=device)

def get_top_k(query: str, top_k: int = 10, fields: list[str] | None = None):
    if fields is None:
        fields = ["idno"]
    # Encode the query and score it against all indicator vectors (dot product)
    scores = embedding_model.encode([query], convert_to_tensor=True) @ vectors.T
    # Take the indices of the top_k highest-scoring indicators
    idx = scores.argsort(descending=True)[0][:top_k].tolist()
    return df.iloc[idx][fields].to_dict("records")
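
# A minimal sketch of how get_top_k can be exercised directly; the query below
# is only an illustrative placeholder, not part of the server's API:
#
#   results = get_top_k("share of population with access to electricity",
#                       top_k=5, fields=["idno", "name"])
#   for row in results:
#       print(row["idno"], "-", row["name"])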

@mcp.tool()
async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
    """Generate an image using the SanaSprint model.

    Args:
        prompt: Text prompt describing the image to generate
        width: Image width (default: 512)
        height: Image height (default: 512)
    """
    client = Client("https://ysharma-sanasprint.hf.space/")

    try:
        result = client.predict(
            prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
        )

        if isinstance(result, list) and len(result) >= 1:
            image_data = result[0]
            if isinstance(image_data, dict) and "url" in image_data:
                return json.dumps(
                    {
                        "type": "image",
                        "url": image_data["url"],
                        "message": f"Generated image for prompt: {prompt}",
                    }
                )

        return json.dumps({"type": "error", "message": "Failed to generate image"})
    except Exception as e:
        return json.dumps(
            {"type": "error", "message": f"Error generating image: {str(e)}"}
        )
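
# Note: generate_image calls out to a third-party Hugging Face Space
# (https://ysharma-sanasprint.hf.space/), so it requires network access and is
# independent of the WDI search tools below.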

class SearchOutput(BaseModel):
    idno: str = Field(..., description="The unique identifier of the indicator.")
    name: str = Field(..., description="The name of the indicator.")


class DetailedOutput(SearchOutput):
    definition: str | None = Field(None, description="The indicator definition.")

@mcp.tool()
async def search_relevant_indicators(query: str, top_k: int = 1) -> list[SearchOutput]:
    """Search for a shortlist of relevant World Development Indicators (WDI) given the query.

    The search ranking may not be optimal, so the LLM may treat the result as a
    shortlist and pick the most relevant indicators from it (if any).

    Args:
        query: The search query from the user, or one formulated by an LLM based on the user's prompt.
        top_k: The number of shortlisted indicators, semantically related to the query, to return.

    Returns:
        List of objects with the indicator code (idno) and name.
    """
    return [
        SearchOutput(**out)
        for out in get_top_k(query=query, top_k=top_k, fields=["idno", "name"])
    ]
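
# A sketch of the JSON-RPC request an MCP client would send to invoke this tool
# over the stdio transport (the method and params shape follow the MCP spec;
# the query string is only an example):
#
#   {
#     "jsonrpc": "2.0",
#     "id": 1,
#     "method": "tools/call",
#     "params": {
#       "name": "search_relevant_indicators",
#       "arguments": {"query": "GDP per capita", "top_k": 5}
#     }
#   }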

@mcp.tool()
async def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
    """Provide definitions for the given indicator ids (idno).

    Args:
        indicator_ids: A list of indicator ids (idno) for which additional information is requested.

    Returns:
        List of objects with the indicator code (idno), name, and definition.
    """
    # Accept a single id passed as a bare string.
    if isinstance(indicator_ids, str):
        indicator_ids = [indicator_ids]

    return [
        DetailedOutput(**out)
        for out in df.loc[indicator_ids][["idno", "name", "definition"]].to_dict(
            "records"
        )
    ]

if __name__ == "__main__":
    mcp.run(transport="stdio")
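
# A minimal sketch of how an MCP client (e.g., Claude Desktop) might be
# configured to launch this server over stdio; the server name "wdi" and the
# interpreter/path are assumptions, not defined in this file:
#
#   {
#     "mcpServers": {
#       "wdi": {
#         "command": "python",
#         "args": ["wdi_mcp_server.py"]
#       }
#     }
#   }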