Commit
·
34225b1
1
Parent(s):
089cc3a
Start building MCP for WDI
Browse filesSigned-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
- wdi_mcp_server.py +78 -0
wdi_mcp_server.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from mcp.server.fastmcp import FastMCP
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import io
|
5 |
+
import time
|
6 |
+
from gradio_client import Client
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
+
|
9 |
+
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8", errors="replace")
|
10 |
+
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding="utf-8", errors="replace")
|
11 |
+
|
12 |
+
mcp = FastMCP("huggingface_spaces_wdi_data")
|
13 |
+
|
14 |
+
|
15 |
+
@mcp.tool()
|
16 |
+
async def generate_image(prompt: str, width: int = 512, height: int = 512) -> str:
|
17 |
+
"""Generate an image using SanaSprint model.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
prompt: Text prompt describing the image to generate
|
21 |
+
width: Image width (default: 512)
|
22 |
+
height: Image height (default: 512)
|
23 |
+
"""
|
24 |
+
client = Client("https://ysharma-sanasprint.hf.space/")
|
25 |
+
|
26 |
+
try:
|
27 |
+
result = client.predict(
|
28 |
+
prompt, "0.6B", 0, True, width, height, 4.0, 2, api_name="/infer"
|
29 |
+
)
|
30 |
+
|
31 |
+
if isinstance(result, list) and len(result) >= 1:
|
32 |
+
image_data = result[0]
|
33 |
+
if isinstance(image_data, dict) and "url" in image_data:
|
34 |
+
return json.dumps(
|
35 |
+
{
|
36 |
+
"type": "image",
|
37 |
+
"url": image_data["url"],
|
38 |
+
"message": f"Generated image for prompt: {prompt}",
|
39 |
+
}
|
40 |
+
)
|
41 |
+
|
42 |
+
return json.dumps({"type": "error", "message": "Failed to generate image"})
|
43 |
+
|
44 |
+
except Exception as e:
|
45 |
+
return json.dumps(
|
46 |
+
{"type": "error", "message": f"Error generating image: {str(e)}"}
|
47 |
+
)
|
48 |
+
|
49 |
+
|
50 |
+
class SearchOutput(BaseModel):
|
51 |
+
idno: str = Field(..., description="The unique identifier of the indicator.")
|
52 |
+
name: str = Field(..., description="The name of the indicator.")
|
53 |
+
definition: str | None = Field(None, description="The indicator definition.")
|
54 |
+
|
55 |
+
|
56 |
+
@mcp.tool()
|
57 |
+
async def search_relevant_indicators(query: str, top_k: int = 1) -> list[SearchOutput]:
|
58 |
+
"""Search for a shortlist of relevant indicators from the World Development Indicators (WDI) given the query. The search ranking may not be optimal, so the LLM may use this as shortlist and pick the most relevant from the list (if any).
|
59 |
+
|
60 |
+
Args:
|
61 |
+
query: The search query by the user or one formulated by an LLM based on the user's prompt.
|
62 |
+
top_k: The number of shortlisted indicators that will be returned that are semantically related to the query.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
List of objects with keys indicator name, code
|
66 |
+
"""
|
67 |
+
|
68 |
+
return [
|
69 |
+
SearchOutput(
|
70 |
+
idno="NY.GDP.MKTP.KD",
|
71 |
+
name="GDP (constant 2015 US$)",
|
72 |
+
definition="GDP at purchaser's prices is the sum of gross value added by all resident producers in the economy plus any product taxes and minus any subsidies not included in the value of the products. It is calculated without making deductions for depreciation of fabricated assets or for depletion and degradation of natural resources. Data are in constant 2015 prices, expressed in U.S. dollars. Dollar figures for GDP are converted from domestic currencies using 2015 official exchange rates. For a few countries where the official exchange rate does not reflect the rate effectively applied to actual foreign exchange transactions, an alternative conversion factor is used.",
|
73 |
+
)
|
74 |
+
]
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
mcp.run(transport="stdio")
|