Kunal Pai committed on
Commit
97e9ed5
·
1 Parent(s): a997e09

Add benchmarking script for GlobleDistanceTool via Gradio API

Browse files
Files changed (1) hide show
  1. bench/benchmarking_hle.py +129 -0
bench/benchmarking_hle.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ import geopandas as gpd
3
+ import json
4
+ import time
5
+ import random
6
+ import os
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+
10
+
11
def load_countries(geo_path):
    """
    Load the list of country names from a GeoJSON/Shapefile via GeoPandas.

    The name column is the first of ADMIN, NAME, NAME_EN, NAME_LONG,
    SOVEREIGN, COUNTRY that exists in the file; if none is present, the
    first column not called "geometry" is used instead.

    Args:
        geo_path: Path to a file readable by ``geopandas.read_file``.

    Returns:
        A list of the unique, non-null values of the chosen name column.

    Raises:
        ValueError: If the file has no column other than the geometry, so
            no name column can be chosen.
    """
    gdf = gpd.read_file(geo_path)

    # Preferred name columns, in priority order.
    preferred = ("ADMIN", "NAME", "NAME_EN", "NAME_LONG", "SOVEREIGN", "COUNTRY")
    name_field = next((c for c in preferred if c in gdf.columns), None)

    if name_field is None:
        # Fall back to the first non-geometry column. str() guards against
        # non-string column labels, which have no .lower().
        non_geom = [c for c in gdf.columns if str(c).lower() != "geometry"]
        if not non_geom:
            raise ValueError(
                f"No attribute columns found in {geo_path!r}; cannot pick a country name field"
            )
        name_field = non_geom[0]

    return gdf[name_field].dropna().unique().tolist()
25
+
26
+
27
def benchmark_globle_api(
    server_url: str = "http://127.0.0.1:7860/",
    geo_path: str = "./tools/util/countries.geojson",
    num_trials: int = 10,
    results_dir: str = "results",
    max_guesses: "int | None" = None,
):
    """
    Benchmark a GlobleDistanceTool deployed behind a Gradio API.

    Each trial resets the game, reads the hidden target from the state file
    in the user's home directory, then issues random guesses until the
    server reports "correct" or "gave_up" (or until ``max_guesses`` is hit).
    One JSON line per trial is appended to a timestamped JSONL file.

    Args:
        server_url: URL of the Gradio server.
        geo_path: Path to the countries file; used locally to build the
            guess pool and also sent to the server with every request.
        num_trials: Number of games to play.
        results_dir: Directory for the JSONL results (created if missing).
        max_guesses: Optional cap on guesses per trial. ``None`` (default)
            keeps guessing until the server returns a terminal result.
            When the cap is hit the trial is recorded as failed with
            ``"aborted": True``.

    Returns:
        The list of per-trial record dicts that were written to the file.
    """
    # prepare client and results
    client = Client(server_url)
    os.makedirs(results_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = Path(results_dir) / f"globle_api_benchmark_{timestamp}.jsonl"
    state_file = Path.home().joinpath(".Globle_distance_state.json")

    # load country list locally
    country_list = load_countries(geo_path)

    all_trials = []
    for trial in range(1, num_trials + 1):
        # reset game
        try:
            client.predict(
                {"geo_path": geo_path, "guess": "", "reset": True},
                api_name="/run",
            )
        except Exception:
            # Deliberate best-effort: the reset call may error due to the
            # missing guess, but the state is reset on the server anyway.
            pass

        # read target from the shared state file
        state = json.loads(state_file.read_text())
        target = state["target"]

        trial_record = {"trial": trial, "target": target, "guesses": []}
        guess_count = 0

        # The guess pool only depends on the target, so build it once per
        # trial instead of once per guess.
        # NOTE(review): the target itself is excluded from the pool, so a
        # "correct" result looks unreachable by exact name — confirm intent.
        candidates = [c for c in country_list if c != target]

        # loop until terminal
        while True:
            guess = random.choice(candidates)
            guess_count += 1

            # perf_counter is monotonic and high-resolution, unlike time.time,
            # which can jump when the system clock is adjusted.
            start = time.perf_counter()
            out = client.predict(
                {"geo_path": geo_path, "guess": guess, "reset": False},
                api_name="/run",
            )
            latency = time.perf_counter() - start

            # parse output structure (assumes the endpoint returns a dict)
            status = out.get("status")
            msg = out.get("message")
            output = out.get("output", {})
            # "output" may be present but null (e.g. on an error response).
            result = output.get("result") if isinstance(output, dict) else None

            trial_record["guesses"].append({
                "guess": guess,
                "time_s": latency,
                "status": status,
                "message": msg,
                "output": output,
            })

            if result == "gave_up":
                trial_record["failed"] = True
                break
            if result == "correct":
                trial_record["failed"] = False
                break
            if max_guesses is not None and guess_count >= max_guesses:
                # Safety valve: the server never reached a terminal result.
                trial_record["failed"] = True
                trial_record["aborted"] = True
                break

        # Record the number of guesses for every trial, not only successes.
        trial_record["guess_count"] = guess_count

        # write JSONL line (append per trial so partial runs are preserved)
        with open(results_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(trial_record) + "\n")
        all_trials.append(trial_record)

    # summary
    latencies = [g["time_s"] for t in all_trials for g in t["guesses"]]
    avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
    print(f"Completed {num_trials} trials. Avg latency: {avg_latency:.3f}s over {len(latencies)} calls.")
    print(f"Results saved to {results_file}")

    return all_trials
113
+
114
+
115
if __name__ == "__main__":
    import argparse

    # Command-line entry point: each option maps onto one keyword argument
    # of benchmark_globle_api.
    parser = argparse.ArgumentParser(
        description="Benchmark GlobleDistanceTool via Gradio API"
    )
    option_specs = (
        ("--server", str, "http://127.0.0.1:7860/", "Gradio server URL"),
        ("--geo", str, "./tools/util/countries.geojson", "Path to geojson file"),
        ("--trials", int, 10, "Number of games to run"),
        ("--outdir", str, "results", "Output directory for JSONL results"),
    )
    for flag, value_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    cli = parser.parse_args()

    benchmark_globle_api(
        server_url=cli.server,
        geo_path=cli.geo,
        num_trials=cli.trials,
        results_dir=cli.outdir,
    )