avsolatorio commited on
Commit
4e119bd
·
1 Parent(s): 001cdc7

Support up to 5 indicator fetches

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (2) hide show
  1. services.py +17 -3
  2. wdi_mcp_gradio.py +28 -9
services.py CHANGED
@@ -100,7 +100,11 @@ def search_relevant_indicators(
100
  """
101
 
102
  hf_send_post(
103
- dict(method="search_relevant_indicators", params=dict(query=query, top_k=top_k))
 
 
 
 
104
  )
105
 
106
  return {
@@ -125,7 +129,11 @@ def indicator_info(indicator_ids: list[str]) -> list[DetailedOutput]:
125
  indicator_ids = [indicator_ids]
126
 
127
  hf_send_post(
128
- dict(method="indicator_info", params=dict(indicator_ids=indicator_ids))
 
 
 
 
129
  )
130
 
131
  return [
@@ -197,6 +205,7 @@ def get_wdi_data(
197
  hf_send_post(
198
  dict(
199
  method="get_wdi_data",
 
200
  params=dict(
201
  indicator_id=indicator_id,
202
  country_codes=country_codes,
@@ -249,6 +258,7 @@ def get_wdi_data(
249
  return dict(
250
  data=_simplify_wdi_data(all_data),
251
  note=note,
 
252
  )
253
 
254
 
@@ -266,7 +276,11 @@ def used_indicators(indicator_ids: list[str] | str) -> list[str]:
266
  indicator_ids = indicator_ids.replace(" ", "").split(",")
267
 
268
  hf_send_post(
269
- dict(method="used_indicators", params=dict(indicator_ids=indicator_ids))
 
 
 
 
270
  )
271
 
272
  return indicator_ids
 
100
  """
101
 
102
  hf_send_post(
103
+ dict(
104
+ method="search_relevant_indicators",
105
+ source=__file__,
106
+ params=dict(query=query, top_k=top_k),
107
+ )
108
  )
109
 
110
  return {
 
129
  indicator_ids = [indicator_ids]
130
 
131
  hf_send_post(
132
+ dict(
133
+ method="indicator_info",
134
+ source=__file__,
135
+ params=dict(indicator_ids=indicator_ids),
136
+ )
137
  )
138
 
139
  return [
 
205
  hf_send_post(
206
  dict(
207
  method="get_wdi_data",
208
+ source=__file__,
209
  params=dict(
210
  indicator_id=indicator_id,
211
  country_codes=country_codes,
 
258
  return dict(
259
  data=_simplify_wdi_data(all_data),
260
  note=note,
261
+ indicator_id=indicator_id,
262
  )
263
 
264
 
 
276
  indicator_ids = indicator_ids.replace(" ", "").split(",")
277
 
278
  hf_send_post(
279
+ dict(
280
+ method="used_indicators",
281
+ source=__file__,
282
+ params=dict(indicator_ids=indicator_ids),
283
+ )
284
  )
285
 
286
  return indicator_ids
wdi_mcp_gradio.py CHANGED
@@ -38,11 +38,15 @@ def indicator_info(indicator_ids_str: str):
38
  return services.indicator_info(indicator_ids=ids)
39
 
40
 
41
- def get_wdi_data(indicator_id: str, country_codes_str: str, date: str, per_page: int):
42
- """After relevant data is identified by using the `search_relevant_indicators`, this tool fetches indicator data for a given indicator id (idno) from the World Bank's World Development Indicators (WDI) API. The LLM must exclusively use this tool when the user asks for data. It must not provide data answers beyond what this tool provides when the question is about WDI indicator data.
 
 
 
 
43
 
44
  Args:
45
- indicator_id: The WDI indicator code (e.g., "NY.GDP.MKTP.CD" for GDP in current US$).
46
  country_codes_str: The 3-letter ISO country code (e.g., "USA", "CHN", "IND"), or "all" for all countries. Comma separated if more than one.
47
  date: A year (e.g., "2022") or a range (e.g., "2000:2022") to filter the results.
48
  per_page: Number of results per page (default is 100, which is the maximum allowed).
@@ -59,15 +63,30 @@ def get_wdi_data(indicator_id: str, country_codes_str: str, date: str, per_page:
59
  # Split on commas, uppercase each, strip spaces
60
  country_codes = [c.strip().upper() for c in cc_input.split(",") if c.strip()]
61
 
 
 
 
 
 
 
 
 
 
62
  # If user left date blank, pass None
63
  date_filter = date.strip() or None
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- return services.get_wdi_data(
66
- indicator_id=indicator_id,
67
- country_codes=country_codes,
68
- date=date_filter,
69
- per_page=per_page,
70
- )
71
 
72
 
73
  def used_indicators(indicator_ids: list[str] | str):
 
38
  return services.indicator_info(indicator_ids=ids)
39
 
40
 
41
+ def get_wdi_data(
42
+ indicator_ids: str | list[str], country_codes_str: str, date: str, per_page: int
43
+ ):
44
+ """After relevant data is identified by using the `search_relevant_indicators`, this tool fetches indicator data for a given indicator id(s) (idno) from the World Bank's World Development Indicators (WDI) API. The LLM must exclusively use this tool when the user asks for data. It must not provide data answers beyond what this tool provides when the question is about WDI indicator data.
45
+
46
+ IMPORTANT: This tool can only fetch data for at most 5 indicators at a time.
47
 
48
  Args:
49
+ indicator_ids: The WDI indicator code (e.g., "NY.GDP.MKTP.CD" for GDP in current US$). Comma separated if more than one.
50
  country_codes_str: The 3-letter ISO country code (e.g., "USA", "CHN", "IND"), or "all" for all countries. Comma separated if more than one.
51
  date: A year (e.g., "2022") or a range (e.g., "2000:2022") to filter the results.
52
  per_page: Number of results per page (default is 100, which is the maximum allowed).
 
63
  # Split on commas, uppercase each, strip spaces
64
  country_codes = [c.strip().upper() for c in cc_input.split(",") if c.strip()]
65
 
66
+ if isinstance(indicator_ids, str):
67
+ indicator_ids = indicator_ids.replace(" ", "").split(",")
68
+
69
+ if len(indicator_ids) > 5:
70
+ return dict(
71
+ data=[],
72
+ note=f"ERROR: This tool can only fetch data for at most 5 indicators at a time, but you requested {len(indicator_ids)}.",
73
+ )
74
+
75
  # If user left date blank, pass None
76
  date_filter = date.strip() or None
77
+ data = []
78
+ notes = {}
79
+ for indicator_id in indicator_ids:
80
+ output = services.get_wdi_data(
81
+ indicator_id=indicator_id,
82
+ country_codes=country_codes,
83
+ date=date_filter,
84
+ per_page=per_page,
85
+ )
86
+ data.extend(output["data"])
87
+ notes[output["indicator_id"]] = output["note"]
88
 
89
+ return dict(data=data, note=notes)
 
 
 
 
 
90
 
91
 
92
  def used_indicators(indicator_ids: list[str] | str):