avsolatorio commited on
Commit
78faeae
·
1 Parent(s): 9978e32

Add data access via api

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. wdi_mcp_server.py +121 -41
wdi_mcp_server.py CHANGED
@@ -1,14 +1,19 @@
1
  from mcp.server.fastmcp import FastMCP
2
  import json
3
- import sys
4
- import io
5
- import time
6
- import numpy as np
 
7
  import pandas as pd
8
  import torch
9
- from sklearn.metrics.pairwise import cosine_similarity
 
 
 
10
  from sentence_transformers import SentenceTransformer
11
- from gradio_client import Client
 
12
  from pydantic import BaseModel, Field
13
 
14
 
@@ -23,8 +28,8 @@ def get_best_torch_device():
23
 
24
  device = get_best_torch_device()
25
 
26
- sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
27
- sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
28
 
29
 
30
  mcp = FastMCP("huggingface_spaces_wdi_data")
@@ -64,39 +69,39 @@ def get_top_k(query: str, top_k: int = 10, fields: list[str] | None = None):
64
  return df.iloc[idx][fields].to_dict("records")
65
 
66
 
67
- @mcp.tool()
68
- async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
69
- """Generate an image using SanaSprint model.
70
 
71
- Args:
72
- prompt: Text prompt describing the image to generate
73
- width: Image width (default: 512)
74
- height: Image height (default: 512)
75
- """
76
- client = Client("https://ysharma-sanasprint.hf.space/")
77
 
78
- try:
79
- result = client.predict(
80
- prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
81
- )
82
 
83
- if isinstance(result, list) and len(result) >= 1:
84
- image_data = result[0]
85
- if isinstance(image_data, dict) and "url" in image_data:
86
- return json.dumps(
87
- {
88
- "type": "image",
89
- "url": image_data["url"],
90
- "message": f"Generated image for prompt: {prompt}",
91
- }
92
- )
93
-
94
- return json.dumps({"type": "error", "message": "Failed to generate image"})
95
-
96
- except Exception as e:
97
- return json.dumps(
98
- {"type": "error", "message": f"Error generating image: {str(e)}"}
99
- )
100
 
101
 
102
  class SearchOutput(BaseModel):
@@ -141,11 +146,86 @@ async def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
141
 
142
  return [
143
  DetailedOutput(**out)
144
- for out in df.loc[indicator_ids][["idno", "name", "definition"]].to_dict(
145
- "records"
146
- )
147
  ]
148
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  if __name__ == "__main__":
 
 
 
 
 
151
  mcp.run(transport="stdio")
 
 
1
  from mcp.server.fastmcp import FastMCP
2
  import json
3
+
4
+ # import sys
5
+ # import io
6
+ # import time
7
+ # import numpy as np
8
  import pandas as pd
9
  import torch
10
+ import httpx
11
+
12
+
13
+ from typing import Optional, Any
14
  from sentence_transformers import SentenceTransformer
15
+
16
+ # from gradio_client import Client
17
  from pydantic import BaseModel, Field
18
 
19
 
 
28
 
29
  device = get_best_torch_device()
30
 
31
+ # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
32
+ # sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
33
 
34
 
35
  mcp = FastMCP("huggingface_spaces_wdi_data")
 
69
  return df.iloc[idx][fields].to_dict("records")
70
 
71
 
72
+ # @mcp.tool()
73
+ # async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
74
+ # """Generate an image using SanaSprint model.
75
 
76
+ # Args:
77
+ # prompt: Text prompt describing the image to generate
78
+ # width: Image width (default: 512)
79
+ # height: Image height (default: 512)
80
+ # """
81
+ # client = Client("https://ysharma-sanasprint.hf.space/")
82
 
83
+ # try:
84
+ # result = client.predict(
85
+ # prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
86
+ # )
87
 
88
+ # if isinstance(result, list) and len(result) >= 1:
89
+ # image_data = result[0]
90
+ # if isinstance(image_data, dict) and "url" in image_data:
91
+ # return json.dumps(
92
+ # {
93
+ # "type": "image",
94
+ # "url": image_data["url"],
95
+ # "message": f"Generated image for prompt: {prompt}",
96
+ # }
97
+ # )
98
+
99
+ # return json.dumps({"type": "error", "message": "Failed to generate image"})
100
+
101
+ # except Exception as e:
102
+ # return json.dumps(
103
+ # {"type": "error", "message": f"Error generating image: {str(e)}"}
104
+ # )
105
 
106
 
107
  class SearchOutput(BaseModel):
 
146
 
147
  return [
148
  DetailedOutput(**out)
149
+ for out in df.loc[indicator_ids][
150
+ ["idno", "name", "definition", "time_coverage", "geographic_coverage"]
151
+ ].to_dict("records")
152
  ]
153
 
154
 
155
+ @mcp.tool()
156
+ async def get_wdi_data(
157
+ indicator_id: str,
158
+ country_codes: str | list[str],
159
+ date: Optional[str] = None,
160
+ per_page: Optional[int] = 100,
161
+ ) -> dict[str, list[dict[str, Any]] | str]:
162
+ """Fetches indicator data for a given indicator id (idno) from the World Bank's World Development Indicators (WDI) API. The LLM must exclusively use this tool when the user asks for data. It must not provide data answers beyond what this tool provides when the question is about WDI indicator data.
163
+
164
+ Args:
165
+ indicator_id: The WDI indicator code (e.g., "NY.GDP.MKTP.CD" for GDP in current US$).
166
+ country_codes: The 3-letter ISO country code (e.g., "USA", "CHN", "IND"), or "all" for all countries.
167
+ date: A year (e.g., "2022") or a range (e.g., "2000:2022") to filter the results.
168
+ per_page: Number of results per page (default is 100, which is the maximum allowed).
169
+
170
+ Returns:
171
+ A dictionary with keys `data` and `note`. The `data` key contains a list of indicator data entries requested. The `note` key contains a note about the data returned.
172
+ """
173
+ print("Hello...")
174
+ MAX_INFO = 100
175
+ note = ""
176
+
177
+ if isinstance(country_codes, str):
178
+ country_codes = [country_codes]
179
+
180
+ country_code = ";".join(country_codes)
181
+ base_url = (
182
+ f"https://api.worldbank.org/v2/country/{country_code}/indicator/{indicator_id}"
183
+ )
184
+ params = {"format": "json", "date": date, "per_page": per_page or 100, "page": 1}
185
+
186
+ with open("mcp_server.log", "a+") as log:
187
+ log.write(json.dumps(dict(base_url=base_url, params=params)) + "\n")
188
+
189
+ with httpx.Client(timeout=30.0) as client:
190
+ all_data = []
191
+ while True:
192
+ response = client.get(base_url, params=params)
193
+ if response.status_code != 200:
194
+ note = f"ERROR: Failed to fetch data: HTTP {response.status_code}"
195
+ break
196
+
197
+ json_response = response.json()
198
+
199
+ if not isinstance(json_response, list) or len(json_response) < 2:
200
+ note = "ERROR: The API response is invalid or empty."
201
+ break
202
+
203
+ metadata, data_page = json_response
204
+ all_data.extend(data_page)
205
+
206
+ if len(all_data) >= MAX_INFO:
207
+ note = f"IMPORTANT: Let the user know that the data is truncated to the first {MAX_INFO} entries."
208
+ break
209
+
210
+ if params["page"] >= metadata.get("pages", 1):
211
+ break
212
+
213
+ params["page"] += 1
214
+
215
+ with open("mcp_server.log", "a+") as log:
216
+ log.write(json.dumps(dict(all_data=all_data)) + "\n")
217
+
218
+ return dict(
219
+ data=all_data,
220
+ note=note,
221
+ )
222
+
223
+
224
  if __name__ == "__main__":
225
+ """
226
+ Run the MCP server.
227
+
228
+ uv run mcp dev wdi_mcp_server.py
229
+ """
230
  mcp.run(transport="stdio")
231
+ # mcp.run()