Spaces (Running)

Kunal Pai committed · Commit aa7e221 · Parent(s): 97e9ed5
Refactor benchmarking script to implement HLE dataset performance evaluation and improve response handling

Browse files: bench/benchmarking_hle.py (+138, −111)

bench/benchmarking_hle.py CHANGED
@@ -1,129 +1,156 @@
Removed (previous Globle benchmarking script; most deleted lines do not survive in this view, so gaps are marked with …):

from gradio_client import Client
…
import json
import time
import random
import os
from datetime import datetime
…

def load_countries(geo_path):
    """
    …
    """
    # …

def benchmark_globle_api(
    …
    geo_path: str = "./tools/util/countries.geojson",
    num_trials: int = 10,
    results_dir: str = "results"
):
    """
    Benchmark …
    """
    # …
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = …
    # …
    try:
        …
            api_name="/run"
        )
        …
    …

if __name__ == "__main__":
    …
    p.add_argument("--geo", type=str, default="./tools/util/countries.geojson", help="Path to geojson file")
    p.add_argument("--trials", type=int, default=10, help="Number of games to run")
    p.add_argument("--outdir", type=str, default="results", help="Output directory for JSONL results")
    args = p.parse_args()

    benchmark_globle_api(
        server_url=args.server,
        geo_path=args.geo,
        num_trials=args.trials,
        results_dir=args.outdir
    )
Added (new file content):

from gradio_client import Client
from datasets import load_dataset
import json
import time
import random
import os
from datetime import datetime
import re

def get_last_assistant_content(agent_response_json):
    """
    Parses the agent's full response JSON to find the content of the last
    turn with the 'assistant' role that contains content.
    Returns the content string if found, otherwise an empty string.
    """
    content = ""
    # Walk the turns in reverse to find the last 'assistant' turn with content
    if agent_response_json and 'agent_response' in agent_response_json and isinstance(agent_response_json['agent_response'], list):
        for turn in reversed(agent_response_json['agent_response']):
            turn_content = turn.get('content')
            if turn.get('role') == 'assistant' and turn_content is not None and turn_content != "":
                content = turn_content
                break  # Found the last assistant turn with non-empty content

    return content

def benchmark_hle(num_samples=20, categories=None):
    """
    Benchmark agent performance on the HLE dataset.

    Args:
        num_samples: Number of samples to test
        categories: List of categories to include (None for all)
    """
    # Load HLE dataset
    print("Loading HLE dataset...")
    dataset = load_dataset("cais/hle")

    # Initialize client
    client = Client("http://127.0.0.1:7860/")

    # Create results directory if it doesn't exist
    os.makedirs("results", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"results/hle_benchmark_{timestamp}.jsonl"

    # Collect samples from the available splits
    all_samples = []
    for split in ['validation', 'test']:
        if split in dataset:
            all_samples.extend(dataset[split])

    # Filter by category if specified
    if categories:
        all_samples = [s for s in all_samples if s.get('category') in categories]

    # Filter out questions mentioning images (text-substring check on the question field only)
    filtered_samples = [s for s in all_samples if 'image' not in s.get('question', '').lower()]
    removed = len(all_samples) - len(filtered_samples)
    if removed > 0:
        print(f"Filtered out {removed} samples containing 'image'.")
    all_samples = filtered_samples

    # Select random samples
    if len(all_samples) > num_samples:
        samples = random.sample(all_samples, num_samples)
    else:
        samples = all_samples
        print(f"Warning: Only found {len(samples)} samples after filtering.")

    print(f"Running benchmark on {len(samples)} samples...")

    # Run benchmarks
    results = []
    for i, sample in enumerate(samples):
        print(f"\nProcessing sample {i+1}/{len(samples)}")
        category = sample.get('category', 'Unknown')
        prompt = sample.get('question', '')
        print(f"Category: {category}")
        print(f"Question: {prompt[:100]}...")

        # Send query to agent
        try:
            start_time = time.time()
            response = client.predict(
                messages=[{"role": "user", "content": prompt}],
                api_name="/run"
            )
            end_time = time.time()

            target_answer_phrase = sample.get('answer', '').strip()
            agent_final_response_content = get_last_assistant_content(response)

            is_correct = False
            # Only check when both the target phrase and the agent content are non-empty
            if target_answer_phrase and agent_final_response_content:
                # Simple case-insensitive substring check
                if target_answer_phrase.lower() in agent_final_response_content.lower():
                    is_correct = True

            # Record result
            result = {
                "sample_id": sample.get('id', f'sample_{i}'),
                "category": category,
                "input": prompt,
                "target_output": sample.get('answer', ''),
                "agent_full_response": response,
                "agent_final_response": agent_final_response_content,
                "response_time": end_time - start_time,
                "is_correct": is_correct
            }
            results.append(result)

            # Write to file immediately to preserve progress
            with open(results_file, 'a') as f:
                f.write(json.dumps(result) + '\n')

            print(f"Response received in {end_time - start_time:.2f} seconds")
            print(f"Response: {str(response)[:100]}...")  # str() keeps the preview safe if the response is a dict

            # Delay to avoid overwhelming the server
            time.sleep(1)

        except Exception as e:
            print(f"Error processing sample: {e}")
            continue

    # Print summary statistics
    print("\n===== HLE BENCHMARK SUMMARY =====")
    print(f"Samples processed: {len(results)}")

    # Group results by category
    by_category = {}
    for result in results:
        category = result.get('category', 'Unknown')
        by_category.setdefault(category, []).append(result)

    print("\nSamples by category:")
    for category, items in by_category.items():
        print(f"  {category}: {len(items)} samples")

    avg_time = sum(r.get('response_time', 0) for r in results) / len(results) if results else 0
    print(f"\nAverage response time: {avg_time:.2f} seconds")
    print(f"Results saved to: {results_file}")

    return results

if __name__ == "__main__":
    benchmark_hle(
        num_samples=1,
        categories=None
    )
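
For reference, a minimal sketch (not part of the commit) of the response shape get_last_assistant_content expects. The exact payload returned by the Space's /run endpoint is an assumption here; the helper only requires a dict with an 'agent_response' list of chat-style turns:

# Illustrative only; assumes get_last_assistant_content from the script above.
mock_response = {
    "agent_response": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": None},   # e.g. a tool-call turn with no text
        {"role": "tool", "content": "Paris"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ]
}

# Returns the last non-empty assistant turn, skipping empty/None content.
assert get_last_assistant_content(mock_response) == "The capital of France is Paris."
# Missing or malformed payloads fall back to an empty string.
assert get_last_assistant_content({}) == ""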
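Because each result is appended to the JSONL file as soon as it is scored, partial runs can be re-analyzed from disk. A hypothetical post-processing helper (summarize_results is not in the repo) might recompute per-category accuracy like this:

import json
from collections import defaultdict

def summarize_results(results_file):
    """Reload a hle_benchmark_*.jsonl file and print accuracy by category."""
    per_category = defaultdict(lambda: {"correct": 0, "total": 0})
    with open(results_file) as f:
        for line in f:
            record = json.loads(line)
            stats = per_category[record.get("category", "Unknown")]
            stats["total"] += 1
            stats["correct"] += bool(record.get("is_correct"))
    for category, stats in sorted(per_category.items()):
        print(f"{category}: {stats['correct']}/{stats['total']} correct")

# Example (filename is illustrative):
# summarize_results("results/hle_benchmark_20250101_120000.jsonl")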