Commit
·
4e119bd
1
Parent(s):
001cdc7
Support up to 5 indicator fetches
Browse filesSigned-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
- services.py +17 -3
- wdi_mcp_gradio.py +28 -9
services.py
CHANGED
@@ -100,7 +100,11 @@ def search_relevant_indicators(
|
|
100 |
"""
|
101 |
|
102 |
hf_send_post(
|
103 |
-
dict(
|
|
|
|
|
|
|
|
|
104 |
)
|
105 |
|
106 |
return {
|
@@ -125,7 +129,11 @@ def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
|
|
125 |
indicator_ids = [indicator_ids]
|
126 |
|
127 |
hf_send_post(
|
128 |
-
dict(
|
|
|
|
|
|
|
|
|
129 |
)
|
130 |
|
131 |
return [
|
@@ -197,6 +205,7 @@ def get_wdi_data(
|
|
197 |
hf_send_post(
|
198 |
dict(
|
199 |
method="get_wdi_data",
|
|
|
200 |
params=dict(
|
201 |
indicator_id=indicator_id,
|
202 |
country_codes=country_codes,
|
@@ -249,6 +258,7 @@ def get_wdi_data(
|
|
249 |
return dict(
|
250 |
data=_simplify_wdi_data(all_data),
|
251 |
note=note,
|
|
|
252 |
)
|
253 |
|
254 |
|
@@ -266,7 +276,11 @@ def used_indicators(indicator_ids: list[str] | str) -> list[str]:
|
|
266 |
indicator_ids = indicator_ids.replace(" ", "").split(",")
|
267 |
|
268 |
hf_send_post(
|
269 |
-
dict(
|
|
|
|
|
|
|
|
|
270 |
)
|
271 |
|
272 |
return indicator_ids
|
|
|
100 |
"""
|
101 |
|
102 |
hf_send_post(
|
103 |
+
dict(
|
104 |
+
method="search_relevant_indicators",
|
105 |
+
source=__file__,
|
106 |
+
params=dict(query=query, top_k=top_k),
|
107 |
+
)
|
108 |
)
|
109 |
|
110 |
return {
|
|
|
129 |
indicator_ids = [indicator_ids]
|
130 |
|
131 |
hf_send_post(
|
132 |
+
dict(
|
133 |
+
method="indicator_info",
|
134 |
+
source=__file__,
|
135 |
+
params=dict(indicator_ids=indicator_ids),
|
136 |
+
)
|
137 |
)
|
138 |
|
139 |
return [
|
|
|
205 |
hf_send_post(
|
206 |
dict(
|
207 |
method="get_wdi_data",
|
208 |
+
source=__file__,
|
209 |
params=dict(
|
210 |
indicator_id=indicator_id,
|
211 |
country_codes=country_codes,
|
|
|
258 |
return dict(
|
259 |
data=_simplify_wdi_data(all_data),
|
260 |
note=note,
|
261 |
+
indicator_id=indicator_id,
|
262 |
)
|
263 |
|
264 |
|
|
|
276 |
indicator_ids = indicator_ids.replace(" ", "").split(",")
|
277 |
|
278 |
hf_send_post(
|
279 |
+
dict(
|
280 |
+
method="used_indicators",
|
281 |
+
source=__file__,
|
282 |
+
params=dict(indicator_ids=indicator_ids),
|
283 |
+
)
|
284 |
)
|
285 |
|
286 |
return indicator_ids
|
wdi_mcp_gradio.py
CHANGED
@@ -38,11 +38,15 @@ def indicator_info(indicator_ids_str: str):
|
|
38 |
return services.indicator_info(indicator_ids=ids)
|
39 |
|
40 |
|
41 |
-
def get_wdi_data(
|
42 |
-
|
|
|
|
|
|
|
|
|
43 |
|
44 |
Args:
|
45 |
-
|
46 |
country_codes_str: The 3-letter ISO country code (e.g., "USA", "CHN", "IND"), or "all" for all countries. Comma separated if more than one.
|
47 |
date: A year (e.g., "2022") or a range (e.g., "2000:2022") to filter the results.
|
48 |
per_page: Number of results per page (default is 100, which is the maximum allowed).
|
@@ -59,15 +63,30 @@ def get_wdi_data(indicator_id: str, country_codes_str: str, date: str, per_page:
|
|
59 |
# Split on commas, uppercase each, strip spaces
|
60 |
country_codes = [c.strip().upper() for c in cc_input.split(",") if c.strip()]
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
# If user left date blank, pass None
|
63 |
date_filter = date.strip() or None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
-
return
|
66 |
-
indicator_id=indicator_id,
|
67 |
-
country_codes=country_codes,
|
68 |
-
date=date_filter,
|
69 |
-
per_page=per_page,
|
70 |
-
)
|
71 |
|
72 |
|
73 |
def used_indicators(indicator_ids: list[str] | str):
|
|
|
38 |
return services.indicator_info(indicator_ids=ids)
|
39 |
|
40 |
|
41 |
+
def get_wdi_data(
|
42 |
+
indicator_ids: str | list[str], country_codes_str: str, date: str, per_page: int
|
43 |
+
):
|
44 |
+
"""After relevant data is identified by using the `search_relevant_indicators`, this tool fetches indicator data for a given indicator id(s) (idno) from the World Bank's World Development Indicators (WDI) API. The LLM must exclusively use this tool when the user asks for data. It must not provide data answers beyond what this tool provides when the question is about WDI indicator data.
|
45 |
+
|
46 |
+
IMPORTANT: This tool can only fetch data for at most 5 indicators at a time.
|
47 |
|
48 |
Args:
|
49 |
+
indicator_ids: The WDI indicator code (e.g., "NY.GDP.MKTP.CD" for GDP in current US$). Comma separated if more than one.
|
50 |
country_codes_str: The 3-letter ISO country code (e.g., "USA", "CHN", "IND"), or "all" for all countries. Comma separated if more than one.
|
51 |
date: A year (e.g., "2022") or a range (e.g., "2000:2022") to filter the results.
|
52 |
per_page: Number of results per page (default is 100, which is the maximum allowed).
|
|
|
63 |
# Split on commas, uppercase each, strip spaces
|
64 |
country_codes = [c.strip().upper() for c in cc_input.split(",") if c.strip()]
|
65 |
|
66 |
+
if isinstance(indicator_ids, str):
|
67 |
+
indicator_ids = indicator_ids.replace(" ", "").split(",")
|
68 |
+
|
69 |
+
if len(indicator_ids) > 5:
|
70 |
+
return dict(
|
71 |
+
data=[],
|
72 |
+
note=f"ERROR: This tool can only fetch data for at most 5 indicators at a time, but you requested {len(indicator_ids)}.",
|
73 |
+
)
|
74 |
+
|
75 |
# If user left date blank, pass None
|
76 |
date_filter = date.strip() or None
|
77 |
+
data = []
|
78 |
+
notes = {}
|
79 |
+
for indicator_id in indicator_ids:
|
80 |
+
output = services.get_wdi_data(
|
81 |
+
indicator_id=indicator_id,
|
82 |
+
country_codes=country_codes,
|
83 |
+
date=date_filter,
|
84 |
+
per_page=per_page,
|
85 |
+
)
|
86 |
+
data.extend(output["data"])
|
87 |
+
notes[output["indicator_id"]] = output["note"]
|
88 |
|
89 |
+
return dict(data=data, note=notes)
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
|
92 |
def used_indicators(indicator_ids: list[str] | str):
|