import asyncio
import io
import json
import sys
import time

import numpy as np
import pandas as pd
import torch
from gradio_client import Client
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


def get_best_torch_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps does not exist on older torch builds, hence getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = get_best_torch_device()

# Force UTF-8 (replacing unencodable characters) on the standard streams so
# non-ASCII text cannot raise UnicodeEncodeError. reconfigure() changes the
# existing wrappers in place instead of stacking a second TextIOWrapper on the
# same buffer, and keeps their buffering mode (e.g. line-buffered stderr).
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
sys.stderr.reconfigure(encoding="utf-8", errors="replace")

mcp = FastMCP("huggingface_spaces_wdi_data")

# Load the basic WDI metadata and the precomputed indicator embeddings.
wdi_data_vec_fpath = (
    "./data/avsolatorio__GIST-small-Embedding-v0__005__indicator_embeddings.json"
)
df = pd.read_json(wdi_data_vec_fpath)

# Index rows by indicator code (idno) so the tools can look them up via .loc.
df.index = df["idno"]

# Rename the IDS column names to the metadata standard.
df.rename(columns={"title": "name", "text": "definition"}, inplace=True)

# Stack the per-row embeddings into a single float32 matrix, one row per
# indicator (assumes every embedding has the same length). Going through numpy
# avoids handing torch a pandas Series indexed by string labels and avoids the
# slow list-of-lists tensor constructor.
vectors = torch.from_numpy(
    np.asarray(df["embedding"].tolist(), dtype=np.float32)
).to(device)

# The embedding file name encodes the Hugging Face model id as
# "<org>__<model>__...", giving e.g. "avsolatorio/GIST-small-Embedding-v0".
model_name = "/".join(wdi_data_vec_fpath.split("/")[-1].split("__")[:2])
embedding_model = SentenceTransformer(model_name, device=device)


def _records(frame: pd.DataFrame, fields: list[str]) -> list[dict]:
    """Return ``frame[fields]`` as a list of row dicts, with NaN mapped to None.

    pandas stores missing cells as NaN (a float), which the pydantic output
    models reject for ``str | None`` fields; None validates cleanly.
    """
    subset = frame[fields]
    return subset.astype(object).where(subset.notna(), None).to_dict("records")


def get_top_k(query: str, top_k: int = 10, fields: list[str] | None = None):
    """Return the ``top_k`` indicators most similar to ``query``.

    Similarity is the dot product between the query embedding and each row of
    ``vectors``; results are ordered from most to least similar.

    Args:
        query: Free-text search query.
        top_k: Maximum number of records to return. Values <= 0 return an
            empty list; values above the number of indicators return all rows.
        fields: Columns of ``df`` to include in each record (default ["idno"]).

    Returns:
        A list of dicts keyed by ``fields``; missing values are None.
    """
    if fields is None:
        fields = ["idno"]
    if top_k <= 0:
        return []

    # encode() on a one-element list yields a (1, dim) tensor on the model's
    # device, so the product is (1, num_indicators); take its single row.
    query_vec = embedding_model.encode([query], convert_to_tensor=True)
    scores = (query_vec @ vectors.T)[0]

    k = min(top_k, scores.shape[0])
    idx = torch.topk(scores, k).indices.tolist()
    return _records(df.iloc[idx], fields)


def _run_sanasprint(prompt: str, width: int, height: int):
    """Call the SanaSprint Space's /infer endpoint (blocking network I/O)."""
    client = Client("https://ysharma-sanasprint.hf.space/")
    # NOTE(review): positional args follow the Space's /infer signature;
    # presumably (prompt, model size, seed, randomize seed, width, height,
    # guidance scale, inference steps) - confirm on the Space's API page.
    return client.predict(
        prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
    )


@mcp.tool()
async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
    """Generate an image using SanaSprint model.

    Args:
        prompt: Text prompt describing the image to generate
        width: Image width (default: 512)
        height: Image height (default: 512)
    """
    # Building the Client (which contacts the Space) and predict() are both
    # blocking calls: run them in a worker thread so the MCP event loop keeps
    # serving. Client construction sits inside the try so connection failures
    # are reported as a JSON error rather than escaping the tool.
    try:
        result = await asyncio.to_thread(_run_sanasprint, prompt, width, height)
    except Exception as e:
        return json.dumps(
            {"type": "error", "message": f"Error generating image: {str(e)}"}
        )

    # gradio_client may return a tuple for multi-output endpoints; accept lists too.
    if isinstance(result, (list, tuple)) and len(result) >= 1:
        image_data = result[0]
        if isinstance(image_data, dict) and "url" in image_data:
            return json.dumps(
                {
                    "type": "image",
                    "url": image_data["url"],
                    "message": f"Generated image for prompt: {prompt}",
                }
            )
    return json.dumps({"type": "error", "message": "Failed to generate image"})


# Tool output schema: one shortlisted indicator.
class SearchOutput(BaseModel):
    idno: str = Field(..., description="The unique identifier of the indicator.")
    name: str = Field(..., description="The name of the indicator.")


# Tool output schema: a shortlisted indicator plus its (optional) definition.
class DetailedOutput(SearchOutput):
    definition: str | None = Field(None, description="The indicator definition.")


@mcp.tool()
async def search_relevant_indicators(query: str, top_k: int = 1) -> list[SearchOutput]:
    """Search for a shortlist of relevant indicators from the World Development Indicators (WDI) given the query. The search ranking may not be optimal, so the LLM may use this as shortlist and pick the most relevant from the list (if any).

    Args:
        query: The search query by the user or one formulated by an LLM based on the user's prompt.
        top_k: The number of shortlisted indicators that will be returned that are semantically related to the query.

    Returns:
        List of objects with keys indicator code/idno and name.
    """
    return [
        SearchOutput(**out)
        for out in get_top_k(query=query, top_k=top_k, fields=["idno", "name"])
    ]


@mcp.tool()
async def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
    """Provides definition information for the given indicator id (idno).

    Args:
        indicator_ids: A list of indicator ids (idno) for which additional information is being requested.

    Returns:
        List of objects with keys indicator code/idno, name, and definition.
    """
    # Tolerate a bare string id even though the annotation asks for a list.
    if isinstance(indicator_ids, str):
        indicator_ids = [indicator_ids]

    # Fail with a message naming the unknown ids instead of the opaque
    # KeyError pandas raises from .loc when any label is missing.
    missing = [idno for idno in indicator_ids if idno not in df.index]
    if missing:
        raise KeyError(f"Unknown indicator id(s): {missing}")

    return [
        DetailedOutput(**out)
        for out in _records(df.loc[indicator_ids], ["idno", "name", "definition"])
    ]


if __name__ == "__main__":
    mcp.run(transport="stdio")