Kunal Pai committed
Commit aa7e221 · 1 Parent(s): 97e9ed5

Refactor benchmarking script to implement HLE dataset performance evaluation and improve response handling

Files changed (1):
  1. bench/benchmarking_hle.py +138 -111
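
The rewritten script exposes one entry point, benchmark_hle(num_samples, categories); running `python bench/benchmarking_hle.py` performs a single-sample smoke test. As a minimal usage sketch (assuming the agent's Gradio app is already serving at http://127.0.0.1:7860/; the "Math" category name is illustrative, not taken from this diff):

    # Hypothetical usage sketch, not part of the commit.
    from bench.benchmarking_hle import benchmark_hle

    # Run 5 random text-only HLE samples from one category; each result is
    # appended to results/hle_benchmark_<timestamp>.jsonl as it completes.
    results = benchmark_hle(num_samples=5, categories=["Math"])
    print(f"{sum(r['is_correct'] for r in results)}/{len(results)} correct")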
bench/benchmarking_hle.py CHANGED
@@ -1,129 +1,156 @@
  from gradio_client import Client
- import geopandas as gpd
  import json
  import time
  import random
  import os
  from datetime import datetime
- from pathlib import Path

-
- def load_countries(geo_path):
      """
-     Load country centroids list from a GeoJSON/Shapefile via GeoPandas.
-     Returns a list of country names.
      """
-     gdf = gpd.read_file(geo_path)
-     # pick a name field
-     name_field = next((c for c in ["ADMIN", "NAME", "NAME_EN", "NAME_LONG", "SOVEREIGN", "COUNTRY"] if c in gdf.columns), None)
-     if not name_field:
-         # fallback to first non-geometry column
-         non_geom = [c for c in gdf.columns if c.lower() != "geometry"]
-         name_field = non_geom[0]
-
-     return gdf[name_field].dropna().unique().tolist()
-
-
- def benchmark_globle_api(
-     server_url: str = "http://127.0.0.1:7860/",
-     geo_path: str = "./tools/util/countries.geojson",
-     num_trials: int = 10,
-     results_dir: str = "results"
- ):
      """
-     Benchmark a GlobleDistanceTool deployed behind a Gradio API.
-
-     Each trial resets the game, reads the hidden target, then issues random guesses until correct or gave up.
-     Results are written to JSONL, one line per trial.
      """
-     # prepare client and results
-     client = Client(server_url)
-     os.makedirs(results_dir, exist_ok=True)
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-     results_file = Path(results_dir) / f"globle_api_benchmark_{timestamp}.jsonl"
-
-     # load country list locally
-     country_list = load_countries(geo_path)
-
-     all_trials = []
-     for trial in range(1, num_trials + 1):
-         # reset game
          try:
-             client.predict(
-                 {"geo_path": geo_path, "guess": "", "reset": True},
                  api_name="/run"
              )
-         except Exception:
-             # reset call may error due to missing guess, but state is reset on server
-             pass
-
-         # read target from the shared state file
-         state = json.loads(Path.home().joinpath(".Globle_distance_state.json").read_text())
-         target = state["target"]
-
-         trial_record = {"trial": trial, "target": target, "guesses": []}
-         guess_count = 0
-
-         # loop until terminal
-         while True:
-             guess = random.choice([c for c in country_list if c != target])
-             guess_count += 1
-
-             start = time.time()
-             out = client.predict(
-                 {"geo_path": geo_path, "guess": guess, "reset": False},
-                 api_name="/run"
-             )
-             latency = time.time() - start
-
-             # parse output structure
-             status = out.get("status")
-             msg = out.get("message")
-             output = out.get("output", {})
-             result = output.get("result")
-
-             trial_record["guesses"].append({
-                 "guess": guess,
-                 "time_s": latency,
-                 "status": status,
-                 "message": msg,
-                 "output": output
-             })
-
-             if result == "gave_up":
-                 trial_record["failed"] = True
-                 break
-             if result == "correct":
-                 trial_record["failed"] = False
-                 trial_record["guess_count"] = guess_count
-                 break
-
-         # write JSONL line
-         with open(results_file, "a") as f:
-             f.write(json.dumps(trial_record) + "\n")
-         all_trials.append(trial_record)
-
-     # summary
-     latencies = [g["time_s"] for t in all_trials for g in t["guesses"]]
-     avg_latency = sum(latencies) / len(latencies) if latencies else 0.0
-     print(f"Completed {num_trials} trials. Avg latency: {avg_latency:.3f}s over {len(latencies)} calls.")
-     print(f"Results saved to {results_file}")
-
-     return all_trials
-

  if __name__ == "__main__":
-     import argparse
-     p = argparse.ArgumentParser(description="Benchmark GlobleDistanceTool via Gradio API")
-     p.add_argument("--server", type=str, default="http://127.0.0.1:7860/", help="Gradio server URL")
-     p.add_argument("--geo", type=str, default="./tools/util/countries.geojson", help="Path to geojson file")
-     p.add_argument("--trials", type=int, default=10, help="Number of games to run")
-     p.add_argument("--outdir", type=str, default="results", help="Output directory for JSONL results")
-     args = p.parse_args()
-
-     benchmark_globle_api(
-         server_url=args.server,
-         geo_path=args.geo,
-         num_trials=args.trials,
-         results_dir=args.outdir
      )

  from gradio_client import Client
+ from datasets import load_dataset
  import json
  import time
  import random
  import os
  from datetime import datetime
+ import re

+ def get_last_assistant_content(agent_response_json):
      """
+     Parses the agent's full response JSON to find the content of the last
+     turn with the 'assistant' role that contains content.
+     Returns the content string if found, otherwise an empty string.
      """
+     content = ""
+     # Find the content of the last turn with the 'assistant' role
+     if agent_response_json and 'agent_response' in agent_response_json and isinstance(agent_response_json['agent_response'], list):
+         for turn in reversed(agent_response_json['agent_response']):
+             # Check for the 'assistant' role and non-empty content
+             turn_content = turn.get('content')
+             if turn.get('role') == 'assistant' and turn_content is not None and turn_content != "":
+                 content = turn_content
+                 break  # Found the last assistant turn with non-empty content
+
+     return content
+
+ def benchmark_hle(num_samples=20, categories=None):
      """
+     Benchmark agent performance on the HLE dataset.
+
+     Args:
+         num_samples: Number of samples to test
+         categories: List of categories to include (None for all)
      """
+     # Load HLE dataset
+     print("Loading HLE dataset...")
+     dataset = load_dataset("cais/hle")
+
+     # Initialize client
+     client = Client("http://127.0.0.1:7860/")
+
+     # Create results directory if it doesn't exist
+     os.makedirs("results", exist_ok=True)
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     results_file = f"results/hle_benchmark_{timestamp}.jsonl"
+
+     # Select samples
+     all_samples = []
+     for split in ['validation', 'test']:  # Using the validation and test splits
+         if split in dataset:
+             all_samples.extend(dataset[split])
+
+     # Filter by category if specified
+     if categories:
+         all_samples = [s for s in all_samples if s.get('category') in categories]
+
+     # Filter out prompts mentioning images (text-substring only)
+     filtered_samples = [s for s in all_samples if 'image' not in s.get('question', '').lower()]
+     removed = len(all_samples) - len(filtered_samples)
+     if removed > 0:
+         print(f"Filtered out {removed} samples containing 'image'.")
+     all_samples = filtered_samples
+
+     # Select random samples
+     if len(all_samples) > num_samples:
+         samples = random.sample(all_samples, num_samples)
+     else:
+         samples = all_samples
+         print(f"Warning: Only found {len(samples)} samples after filtering.")
+
+     print(f"Running benchmark on {len(samples)} samples...")
+
+     # Run benchmarks
+     results = []
+     for i, sample in enumerate(samples):
+         print(f"\nProcessing sample {i+1}/{len(samples)}")
+         category = sample.get('category', 'Unknown')
+         prompt = sample.get('question', '')
+         print(f"Category: {category}")
+         print(f"Question: {prompt[:100]}...")
+
+         # Send query to agent
          try:
+             start_time = time.time()
+             response = client.predict(
+                 messages=[{"role": "user", "content": prompt}],
                  api_name="/run"
              )
+             end_time = time.time()
+
+             target_answer_phrase = sample.get('answer', '').strip()
+
+             agent_final_response_content = get_last_assistant_content(response)
+
+             is_correct = False
+
+             # Only attempt the check if both the target phrase and the agent content are non-empty
+             if target_answer_phrase and agent_final_response_content:
+                 # Perform the simple case-insensitive substring check
+                 if target_answer_phrase.lower() in agent_final_response_content.lower():
+                     is_correct = True
+
+             # Record result
+             result = {
+                 "sample_id": sample.get('id', f'sample_{i}'),
+                 "category": category,
+                 "input": prompt,
+                 "target_output": sample.get('answer', ''),
+                 "agent_full_response": response,
+                 "agent_final_response": agent_final_response_content,
+                 "response_time": end_time - start_time,
+                 "is_correct": is_correct
+             }
+
+             results.append(result)
+
+             # Write to file immediately to preserve progress
+             with open(results_file, 'a') as f:
+                 f.write(json.dumps(result) + '\n')
+
+             print(f"Response received in {end_time - start_time:.2f} seconds")
+             print(f"Response: {agent_final_response_content[:100]}...")
+
+             # Add a delay to avoid overwhelming the server
+             time.sleep(1)
+
+         except Exception as e:
+             print(f"Error processing sample: {e}")
+             continue
+
+     # Print summary statistics
+     print("\n===== HLE BENCHMARK SUMMARY =====")
+     print(f"Samples processed: {len(results)}")
+
+     # Group results by category
+     by_category = {}
+     for result in results:
+         category = result.get('category', 'Unknown')
+         by_category.setdefault(category, []).append(result)
+
+     print("\nSamples by category:")
+     for category, items in by_category.items():
+         print(f"  {category}: {len(items)} samples")
+
+     avg_time = sum(r.get('response_time', 0) for r in results) / len(results) if results else 0
+     print(f"\nAverage response time: {avg_time:.2f} seconds")
+     print(f"Results saved to: {results_file}")
+
+     return results

  if __name__ == "__main__":
+     benchmark_hle(
+         num_samples=1,
+         categories=None
      )
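
Scoring in the new script hinges on get_last_assistant_content, which walks the transcript in reverse for the last non-empty 'assistant' turn; the grade is then a case-insensitive substring match against HLE's answer field. A minimal sketch of the payload shape the parser expects (the exact structure returned by the /run endpoint is an assumption here; only the 'agent_response' list of role/content turns is relied on):

    # Hypothetical /run payload; roles and ordering are illustrative.
    from bench.benchmarking_hle import get_last_assistant_content

    fake_response = {
        "agent_response": [
            {"role": "user", "content": "What is 17 * 3?"},
            {"role": "assistant", "content": ""},   # skipped: empty content
            {"role": "tool", "content": "51"},      # skipped: not an assistant turn
            {"role": "assistant", "content": "The product is 51."},
        ]
    }
    assert get_last_assistant_content(fake_response) == "The product is 51."

Note that the substring check is deliberately loose: any response containing the target phrase anywhere counts as correct, so short answers such as "51" can match by coincidence.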