Commit
·
78faeae
1
Parent(s):
9978e32
Add data access via api
Browse files
Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
- wdi_mcp_server.py +121 -41
wdi_mcp_server.py
CHANGED
@@ -1,14 +1,19 @@
|
|
1 |
from mcp.server.fastmcp import FastMCP
|
2 |
import json
|
3 |
-
|
4 |
-
import
|
5 |
-
import
|
6 |
-
import
|
|
|
7 |
import pandas as pd
|
8 |
import torch
|
9 |
-
|
|
|
|
|
|
|
10 |
from sentence_transformers import SentenceTransformer
|
11 |
-
|
|
|
12 |
from pydantic import BaseModel, Field
|
13 |
|
14 |
|
@@ -23,8 +28,8 @@ def get_best_torch_device():
|
|
23 |
|
24 |
device = get_best_torch_device()
|
25 |
|
26 |
-
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
27 |
-
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
28 |
|
29 |
|
30 |
mcp = FastMCP("huggingface_spaces_wdi_data")
|
@@ -64,39 +69,39 @@ def get_top_k(query: str, top_k: int = 10, fields: list[str] | None = None):
|
|
64 |
return df.iloc[idx][fields].to_dict("records")
|
65 |
|
66 |
|
67 |
-
@mcp.tool()
|
68 |
-
async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
|
69 |
-
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
|
101 |
|
102 |
class SearchOutput(BaseModel):
|
@@ -141,11 +146,86 @@ async def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
|
|
141 |
|
142 |
return [
|
143 |
DetailedOutput(**out)
|
144 |
-
for out in df.loc[indicator_ids][
|
145 |
-
"
|
146 |
-
)
|
147 |
]
|
148 |
|
149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
151 |
mcp.run(transport="stdio")
|
|
|
|
1 |
from mcp.server.fastmcp import FastMCP
|
2 |
import json
|
3 |
+
|
4 |
+
# import sys
|
5 |
+
# import io
|
6 |
+
# import time
|
7 |
+
# import numpy as np
|
8 |
import pandas as pd
|
9 |
import torch
|
10 |
+
import httpx
|
11 |
+
|
12 |
+
|
13 |
+
from typing import Optional, Any
|
14 |
from sentence_transformers import SentenceTransformer
|
15 |
+
|
16 |
+
# from gradio_client import Client
|
17 |
from pydantic import BaseModel, Field
|
18 |
|
19 |
|
|
|
28 |
|
29 |
device = get_best_torch_device()
|
30 |
|
31 |
+
# sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
32 |
+
# sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
33 |
|
34 |
|
35 |
mcp = FastMCP("huggingface_spaces_wdi_data")
|
|
|
69 |
return df.iloc[idx][fields].to_dict("records")
|
70 |
|
71 |
|
72 |
+
# @mcp.tool()
|
73 |
+
# async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
|
74 |
+
# """Generate an image using SanaSprint model.
|
75 |
|
76 |
+
# Args:
|
77 |
+
# prompt: Text prompt describing the image to generate
|
78 |
+
# width: Image width (default: 512)
|
79 |
+
# height: Image height (default: 512)
|
80 |
+
# """
|
81 |
+
# client = Client("https://ysharma-sanasprint.hf.space/")
|
82 |
|
83 |
+
# try:
|
84 |
+
# result = client.predict(
|
85 |
+
# prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
|
86 |
+
# )
|
87 |
|
88 |
+
# if isinstance(result, list) and len(result) >= 1:
|
89 |
+
# image_data = result[0]
|
90 |
+
# if isinstance(image_data, dict) and "url" in image_data:
|
91 |
+
# return json.dumps(
|
92 |
+
# {
|
93 |
+
# "type": "image",
|
94 |
+
# "url": image_data["url"],
|
95 |
+
# "message": f"Generated image for prompt: {prompt}",
|
96 |
+
# }
|
97 |
+
# )
|
98 |
+
|
99 |
+
# return json.dumps({"type": "error", "message": "Failed to generate image"})
|
100 |
+
|
101 |
+
# except Exception as e:
|
102 |
+
# return json.dumps(
|
103 |
+
# {"type": "error", "message": f"Error generating image: {str(e)}"}
|
104 |
+
# )
|
105 |
|
106 |
|
107 |
class SearchOutput(BaseModel):
|
|
|
146 |
|
147 |
return [
|
148 |
DetailedOutput(**out)
|
149 |
+
for out in df.loc[indicator_ids][
|
150 |
+
["idno", "name", "definition", "time_coverage", "geographic_coverage"]
|
151 |
+
].to_dict("records")
|
152 |
]
|
153 |
|
154 |
|
155 |
+
@mcp.tool()
async def get_wdi_data(
    indicator_id: str,
    country_codes: str | list[str],
    date: Optional[str] = None,
    per_page: Optional[int] = 100,
) -> dict[str, list[dict[str, Any]] | str]:
    """Fetches indicator data for a given indicator id (idno) from the World Bank's World Development Indicators (WDI) API. The LLM must exclusively use this tool when the user asks for data. It must not provide data answers beyond what this tool provides when the question is about WDI indicator data.

    Args:
        indicator_id: The WDI indicator code (e.g., "NY.GDP.MKTP.CD" for GDP in current US$).
        country_codes: The 3-letter ISO country code (e.g., "USA", "CHN", "IND"), or "all" for all countries.
        date: A year (e.g., "2022") or a range (e.g., "2000:2022") to filter the results.
        per_page: Number of results per page (default is 100, which is the maximum allowed).

    Returns:
        A dictionary with keys `data` and `note`. The `data` key contains a list of indicator data entries requested. The `note` key contains a note about the data returned.
    """
    # NOTE: never print() here — stdout is the MCP stdio transport channel and
    # any stray output corrupts the JSON-RPC framing. Debug info goes to the
    # log file below instead.
    MAX_INFO = 100  # hard cap on the number of entries returned to the caller
    note = ""

    # Accept a single code or a list; the API expects codes joined by ";".
    if isinstance(country_codes, str):
        country_codes = [country_codes]

    country_code = ";".join(country_codes)
    base_url = (
        f"https://api.worldbank.org/v2/country/{country_code}/indicator/{indicator_id}"
    )
    params = {"format": "json", "date": date, "per_page": per_page or 100, "page": 1}

    # Debug/audit trail of outgoing requests (stderr/stdout are off-limits).
    with open("mcp_server.log", "a+") as log:
        log.write(json.dumps(dict(base_url=base_url, params=params)) + "\n")

    all_data: list[dict[str, Any]] = []
    # The tool is async, so use the async client — the sync httpx.Client would
    # block the event loop for the duration of every page fetch.
    async with httpx.AsyncClient(timeout=30.0) as client:
        while True:
            response = await client.get(base_url, params=params)
            if response.status_code != 200:
                note = f"ERROR: Failed to fetch data: HTTP {response.status_code}"
                break

            json_response = response.json()

            # A well-formed reply is [metadata, data]; a bare error object or
            # a short list means the request itself was rejected.
            if not isinstance(json_response, list) or len(json_response) < 2:
                note = "ERROR: The API response is invalid or empty."
                break

            metadata, data_page = json_response

            # The API returns None (not []) for the data element when there
            # are no observations — guard before extending.
            if data_page:
                all_data.extend(data_page)

            if len(all_data) >= MAX_INFO:
                note = f"IMPORTANT: Let the user know that the data is truncated to the first {MAX_INFO} entries."
                # Enforce the cap so the payload matches the note.
                all_data = all_data[:MAX_INFO]
                break

            # Stop once the last page (per the API's own page count) is fetched.
            if params["page"] >= metadata.get("pages", 1):
                break

            params["page"] += 1

    with open("mcp_server.log", "a+") as log:
        log.write(json.dumps(dict(all_data=all_data)) + "\n")

    return dict(
        data=all_data,
        note=note,
    )
|
222 |
+
|
223 |
+
|
224 |
if __name__ == "__main__":
    # Start the MCP server over the stdio transport.
    #
    # For local development use:
    #     uv run mcp dev wdi_mcp_server.py
    mcp.run(transport="stdio")
    # Alternative: mcp.run() to use the default transport.
|