File size: 5,850 Bytes
97e9ed5
aa7e221
97e9ed5
 
 
 
 
aa7e221
97e9ed5
81fafc1
97e9ed5
81fafc1
 
97e9ed5
81fafc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa7e221
 
97e9ed5
aa7e221
 
 
 
 
97e9ed5
aa7e221
 
 
 
 
 
 
 
 
97e9ed5
aa7e221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97e9ed5
aa7e221
0581f43
 
 
97e9ed5
aa7e221
 
 
 
0581f43
aa7e221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0581f43
aa7e221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97e9ed5
 
aa7e221
 
 
97e9ed5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from gradio_client import Client
from datasets import load_dataset
import json
import time
import random
import os
from datetime import datetime
import re

def get_last_assistant_content(resp):
    """
    Return the last assistant utterance from the response object
    produced by `client.predict`.
    """
    # ❶ If the server wraps things in a (messages, meta) tuple
    if isinstance(resp, tuple):
        resp = resp[0]

    # ❷ At this point `resp` must be the list of message dicts
    if not isinstance(resp, list):
        return ""

    for turn in reversed(resp):
        if turn.get("role") != "assistant":
            continue

        # a) plain messages
        if turn.get("content"):
            return turn["content"]

        # b) tool / function_response wrapper
        fr = turn.get("function_response", {})
        out = fr.get("result", {}).get("output")
        if out:
            return out

        # c) messages stored as Part objects inside `content`
        cont = turn.get("content")
        if isinstance(cont, dict):
            parts = cont.get("parts", [])
            if parts and parts[0].get("text"):
                return parts[0]["text"]

    return ""

def benchmark_hle(num_samples=20, categories=None):
    """
    Benchmark agent performance on HLE dataset
    
    Args:
        num_samples: Number of samples to test
        categories: List of categories to include (None for all)
    """
    # Load HLE dataset
    print("Loading HLE dataset...")
    dataset = load_dataset("cais/hle")
    
    # Initialize client
    client = Client("http://127.0.0.1:7860/")
    
    # Create results directory if it doesn't exist
    os.makedirs("results", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"results/hle_benchmark_{timestamp}.jsonl"
    
    # Select samples
    all_samples = []
    for split in ['validation', 'test']:  # Using validation and test splits
        if split in dataset:
            all_samples.extend(dataset[split])
    
    # Filter by category if specified
    if categories:
        all_samples = [s for s in all_samples if s.get('category') in categories]
    
    # Filter out prompts mentioning images (text-substring only)
    filtered_samples = [s for s in all_samples if 'image' not in s.get('input', '').lower()]
    removed = len(all_samples) - len(filtered_samples)
    if removed > 0:
        print(f"Filtered out {removed} samples containing 'image'.")
    all_samples = filtered_samples
    
    # Select random samples
    if len(all_samples) > num_samples:
        samples = random.sample(all_samples, num_samples)
    else:
        samples = all_samples
        print(f"Warning: Only found {len(samples)} samples after filtering.")
    
    print(f"Running benchmark on {len(samples)} samples...")
    
    # Run benchmarks
    results = []
    for i, sample in enumerate(samples):
        print(f"\nProcessing sample {i+1}/{len(samples)}")
        category = sample.get('category', 'Unknown')
        prompt = sample.get('question', '')
        print(f"Category: {category}")
        print(f"Question: {prompt[:100]}...")
        
        # Send query to agent
        try:
            start_time = time.time()
            response, history = client.predict(
                message={"text": prompt, "files": []},
                api_name="/chat"
            )
            end_time = time.time()

            target_answer_phrase = sample.get('answer', '').strip()

            agent_final_response_content = get_last_assistant_content(history)

            is_correct = False

            # Only attempt the check if both the target phrase and the agent content are non-empty
            if target_answer_phrase and agent_final_response_content:
                # Perform the simple case-insensitive substring check
                if target_answer_phrase.lower() in agent_final_response_content.lower():
                    is_correct = True
            
            # Record result
            result = {
                "sample_id": sample.get('id', f'sample_{i}'),
                "category": category,
                "input": prompt,
                "target_output": sample.get('answer', ''),
                "agent_full_response": history,
                "agent_final_response": agent_final_response_content,
                "response_time": end_time - start_time,
                "is_correct": is_correct
            }
            
            results.append(result)
            
            # Write to file immediately to preserve progress
            with open(results_file, 'a') as f:
                f.write(json.dumps(result) + '\n')
            
            print(f"Response received in {end_time - start_time:.2f} seconds")
            print(f"Response: {response[:100]}...")
            
            # Add a delay to avoid overwhelming the server
            time.sleep(1)
            
        except Exception as e:
            print(f"Error processing sample: {e}")
            continue
    
    # Print summary statistics
    print("\n===== HLE BENCHMARK SUMMARY =====")
    print(f"Samples processed: {len(results)}")
    
    # Categorize by categories
    by_category = {}
    for result in results:
        category = result.get('category', 'Unknown')
        by_category.setdefault(category, []).append(result)
    
    print("\nSamples by category:")
    for category, items in by_category.items():
        print(f"  {category}: {len(items)} samples")
    
    avg_time = sum(r.get('response_time', 0) for r in results) / len(results) if results else 0
    print(f"\nAverage response time: {avg_time:.2f} seconds")
    print(f"Results saved to: {results_file}")
    
    return results

if __name__ == "__main__":
    benchmark_hle(
        num_samples=1,
        categories=None
    )