from mcp.server.fastmcp import FastMCP
import json
import sys
import io
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from gradio_client import Client
from pydantic import BaseModel, Field


def get_best_torch_device():
    """Return the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


device = get_best_torch_device()

# Re-wrap stdio as UTF-8 so consoles with narrower default encodings
# (e.g. cp1252 on Windows) don't crash on non-ASCII indicator text.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")


mcp = FastMCP("huggingface_spaces_wdi_data")


# Load the basic WDI metadata and vectors.
wdi_data_vec_fpath = (
    "./data/avsolatorio__GIST-small-Embedding-v0__005__indicator_embeddings.json"
)
df = pd.read_json(wdi_data_vec_fpath)

# Make it easy to index based on the idno
df.index = df["idno"]
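# e.g. df.loc["NY.GDP.MKTP.CD"] (a standard WDI code, assuming it is present
# in this dataset) now resolves directly to that indicator's row.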

# Change the IDS naming to metadata standard
df.rename(columns={"title": "name", "text": "definition"}, inplace=True)

# Stack the per-row embedding lists into a single (n_indicators, dim) tensor.
# df["embedding"] holds Python lists (object dtype), so stack with NumPy first
# rather than passing the Series straight to the torch.Tensor constructor.
vectors = torch.tensor(
    np.vstack(df["embedding"].values), dtype=torch.float32
).to(device)


# Load the embedding model. Its name is encoded in the data file name:
# "avsolatorio__GIST-small-Embedding-v0__..." -> "avsolatorio/GIST-small-Embedding-v0"
model_name = "/".join(wdi_data_vec_fpath.split("/")[-1].split("__")[:2])
embedding_model = SentenceTransformer(model_name, device=device)


def get_top_k(query: str, top_k: int = 10, fields: list[str] | None = None):
    if fields is None:
        fields = ["idno"]

    # Embed the query and score every indicator by its dot product with the
    # query embedding (equivalent to cosine similarity if the stored vectors
    # are L2-normalized).
    query_vec = embedding_model.encode([query], convert_to_tensor=True)
    scores = (query_vec @ vectors.T).squeeze(0)

    # Take the indices of the top_k highest-scoring indicators.
    idx = scores.topk(min(top_k, len(df))).indices.tolist()

    return df.iloc[idx][fields].to_dict("records")
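
# Hypothetical smoke test (not run by the server): a query such as
#   get_top_k("extreme poverty headcount", top_k=3, fields=["idno", "name"])
# should return a list of dicts ranked by similarity to the query.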


@mcp.tool()
async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
    """Generate an image using SanaSprint model.

    Args:
        prompt: Text prompt describing the image to generate
        width: Image width (default: 512)
        height: Image height (default: 512)
    """
    client = Client("https://ysharma-sanasprint.hf.space/")

    try:
        # Positional args presumably map to the Space's /infer parameters:
        # prompt, model size, seed, randomize-seed flag, width, height,
        # guidance scale, and number of inference steps.
        result = client.predict(
            prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
        )

        if isinstance(result, list) and len(result) >= 1:
            image_data = result[0]
            if isinstance(image_data, dict) and "url" in image_data:
                return json.dumps(
                    {
                        "type": "image",
                        "url": image_data["url"],
                        "message": f"Generated image for prompt: {prompt}",
                    }
                )

        return json.dumps({"type": "error", "message": "Failed to generate image"})

    except Exception as e:
        return json.dumps(
            {"type": "error", "message": f"Error generating image: {str(e)}"}
        )


class SearchOutput(BaseModel):
    idno: str = Field(..., description="The unique identifier of the indicator.")
    name: str = Field(..., description="The name of the indicator.")


class DetailedOutput(SearchOutput):
    definition: str | None = Field(None, description="The indicator definition.")


@mcp.tool()
async def search_relevant_indicators(query: str, top_k: int = 1) -> list[SearchOutput]:
    """Search for a shortlist of relevant indicators from the World Development Indicators (WDI) given the query. The search ranking may not be optimal, so the LLM may use this as shortlist and pick the most relevant from the list (if any).

    Args:
        query: The search query by the user or one formulated by an LLM based on the user's prompt.
        top_k: The number of shortlisted indicators that will be returned that are semantically related to the query.

    Returns:
        List of objects with keys indicator code/idno and name.
    """

    return [
        SearchOutput(**out)
        for out in get_top_k(query=query, top_k=top_k, fields=["idno", "name"])
    ]


@mcp.tool()
async def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
    """Provides definition information for the given indicator id (idno).

    Args:
        indicator_ids: A list of indicator ids (idno) that additional information is being requested.

    Returns:
        List of objects with keys indicator code/idno, name, and definition.
    """
    if isinstance(indicator_ids, str):
        indicator_ids = [indicator_ids]

    return [
        DetailedOutput(**out)
        for out in df.loc[indicator_ids][["idno", "name", "definition"]].to_dict(
            "records"
        )
    ]


if __name__ == "__main__":
    mcp.run(transport="stdio")
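
# To use this server from an MCP client, register it as a stdio server. A
# hypothetical Claude Desktop config entry (script name "wdi_server.py" assumed):
#
#   {
#     "mcpServers": {
#       "wdi_data": {
#         "command": "python",
#         "args": ["wdi_server.py"]
#       }
#     }
#   }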