hashiruAI / bench / benchmarking_hle.py
Kunal Pai
Refactor benchmarking script to implement HLE dataset performance evaluation and improve response handling
from gradio_client import Client
from datasets import load_dataset
import json
import time
import random
import os
from datetime import datetime


def get_last_assistant_content(agent_response_json):
"""
Parses the agent's full response JSON to find the content of the last
turn with the 'assistant' role that contains content.
Returns the content string if found, otherwise an empty string.
"""
content = ""
# Find the content of the last turn with the 'assistant' role
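    # Example of the expected agent response structure (assumed from the parsing
    # below; not a documented API contract):
    #   {"agent_response": [{"role": "user", "content": "..."},
    #                       {"role": "assistant", "content": "final answer text"}]}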
if agent_response_json and 'agent_response' in agent_response_json and isinstance(agent_response_json['agent_response'], list):
for turn in reversed(agent_response_json['agent_response']):
# Check for 'assistant' role and if the turn has content
turn_content = turn.get('content')
if turn.get('role') == 'assistant' and turn_content is not None and turn_content != "":
content = turn_content
break # Found the last assistant turn with non-empty content
    return content


def benchmark_hle(num_samples=20, categories=None):
"""
    Benchmark agent performance on the HLE dataset.

    Args:
        num_samples: Number of samples to test.
        categories: List of categories to include (None for all).

    Returns:
        A list of per-sample result dicts (also written to a JSONL file under results/).
    """
# Load HLE dataset
print("Loading HLE dataset...")
dataset = load_dataset("cais/hle")
    # Initialize the Gradio client (assumes the agent app is running locally on port 7860)
    client = Client("http://127.0.0.1:7860/")
# Create results directory if it doesn't exist
os.makedirs("results", exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f"results/hle_benchmark_{timestamp}.jsonl"
# Select samples
all_samples = []
for split in ['validation', 'test']: # Using validation and test splits
if split in dataset:
all_samples.extend(dataset[split])
# Filter by category if specified
if categories:
all_samples = [s for s in all_samples if s.get('category') in categories]
    # Filter out questions that mention images (simple case-insensitive substring check);
    # this benchmark only sends text prompts to the agent
    filtered_samples = [s for s in all_samples if 'image' not in s.get('question', '').lower()]
removed = len(all_samples) - len(filtered_samples)
if removed > 0:
print(f"Filtered out {removed} samples containing 'image'.")
all_samples = filtered_samples
# Select random samples
if len(all_samples) > num_samples:
samples = random.sample(all_samples, num_samples)
else:
samples = all_samples
print(f"Warning: Only found {len(samples)} samples after filtering.")
print(f"Running benchmark on {len(samples)} samples...")
# Run benchmarks
results = []
for i, sample in enumerate(samples):
print(f"\nProcessing sample {i+1}/{len(samples)}")
category = sample.get('category', 'Unknown')
prompt = sample.get('question', '')
print(f"Category: {category}")
print(f"Question: {prompt[:100]}...")
# Send query to agent
try:
start_time = time.time()
response = client.predict(
messages=[{"role": "user", "content": prompt}],
api_name="/run"
)
end_time = time.time()
target_answer_phrase = sample.get('answer', '').strip()
agent_final_response_content = get_last_assistant_content(response)
is_correct = False
# Only attempt the check if both the target phrase and the agent content are non-empty
if target_answer_phrase and agent_final_response_content:
# Perform the simple case-insensitive substring check
if target_answer_phrase.lower() in agent_final_response_content.lower():
is_correct = True
# Record result
result = {
"sample_id": sample.get('id', f'sample_{i}'),
"category": category,
"input": prompt,
"target_output": sample.get('answer', ''),
"agent_full_response": response,
"agent_final_response": agent_final_response_content,
"response_time": end_time - start_time,
"is_correct": is_correct
}
results.append(result)
# Write to file immediately to preserve progress
with open(results_file, 'a') as f:
f.write(json.dumps(result) + '\n')
print(f"Response received in {end_time - start_time:.2f} seconds")
            print(f"Response: {agent_final_response_content[:100]}...")
# Add a delay to avoid overwhelming the server
time.sleep(1)
except Exception as e:
print(f"Error processing sample: {e}")
continue
# Print summary statistics
print("\n===== HLE BENCHMARK SUMMARY =====")
print(f"Samples processed: {len(results)}")
# Categorize by categories
by_category = {}
for result in results:
category = result.get('category', 'Unknown')
by_category.setdefault(category, []).append(result)
print("\nSamples by category:")
for category, items in by_category.items():
print(f" {category}: {len(items)} samples")
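    # Overall substring-match accuracy from the is_correct flags recorded above.
    # This case-insensitive substring check is a loose proxy, not an official HLE
    # grading method, so treat the number as indicative only.
    correct = sum(1 for r in results if r.get('is_correct'))
    if results:
        print(f"\nSubstring-match accuracy: {correct}/{len(results)} ({correct / len(results):.1%})")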
avg_time = sum(r.get('response_time', 0) for r in results) / len(results) if results else 0
print(f"\nAverage response time: {avg_time:.2f} seconds")
print(f"Results saved to: {results_file}")
return results
if __name__ == "__main__":
benchmark_hle(
num_samples=1,
categories=None
)
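
# Example of a larger run restricted to specific categories. The category names
# below are illustrative; use the labels present in the dataset's 'category' field.
# benchmark_hle(num_samples=20, categories=["Math", "Computer Science/AI"])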