Kunal Pai: Refactor get_last_assistant_content function to improve response handling and support various response formats (81fafc1)
from gradio_client import Client
from datasets import load_dataset
import json
import time
import random
import os
from datetime import datetime
import re

def get_last_assistant_content(resp):
    """
    Return the last assistant utterance from the response object
    produced by `client.predict`.
    """
    # ❶ If the server wraps things in a (messages, meta) tuple
    if isinstance(resp, tuple):
        resp = resp[0]

    # ❷ At this point `resp` must be the list of message dicts
    if not isinstance(resp, list):
        return ""

    for turn in reversed(resp):
        if turn.get("role") != "assistant":
            continue

        content = turn.get("content")

        # a) plain string messages
        if isinstance(content, str) and content:
            return content

        # b) tool / function_response wrapper
        fr = turn.get("function_response", {})
        out = fr.get("result", {}).get("output")
        if out:
            return out

        # c) messages stored as Part objects inside `content`
        if isinstance(content, dict):
            parts = content.get("parts", [])
            if parts and parts[0].get("text"):
                return parts[0]["text"]

    return ""

def benchmark_hle(num_samples=20, categories=None):
    """
    Benchmark agent performance on the HLE dataset.

    Args:
        num_samples: Number of samples to test
        categories: List of categories to include (None for all)
    """
    # Load HLE dataset
    print("Loading HLE dataset...")
    dataset = load_dataset("cais/hle")

    # Initialize client
    client = Client("http://127.0.0.1:7860/")

    # Create results directory if it doesn't exist
    os.makedirs("results", exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"results/hle_benchmark_{timestamp}.jsonl"

    # Select samples
    all_samples = []
    for split in ['validation', 'test']:  # Using validation and test splits
        if split in dataset:
            all_samples.extend(dataset[split])

    # Filter by category if specified
    if categories:
        all_samples = [s for s in all_samples if s.get('category') in categories]

    # Filter out prompts mentioning images (simple substring check on the question text)
    filtered_samples = [s for s in all_samples if 'image' not in s.get('question', '').lower()]
    removed = len(all_samples) - len(filtered_samples)
    if removed > 0:
        print(f"Filtered out {removed} samples containing 'image'.")
    all_samples = filtered_samples

    # Select random samples
    if len(all_samples) > num_samples:
        samples = random.sample(all_samples, num_samples)
    else:
        samples = all_samples
        print(f"Warning: Only found {len(samples)} samples after filtering.")

    print(f"Running benchmark on {len(samples)} samples...")

    # Run benchmarks
    results = []
    for i, sample in enumerate(samples):
        print(f"\nProcessing sample {i+1}/{len(samples)}")
        category = sample.get('category', 'Unknown')
        prompt = sample.get('question', '')
        print(f"Category: {category}")
        print(f"Question: {prompt[:100]}...")

        # Send query to agent
        try:
            start_time = time.time()
            response = client.predict(
                messages=[{"role": "user", "content": prompt}],
                api_name="/run"
            )
            end_time = time.time()

            target_answer_phrase = sample.get('answer', '').strip()
            agent_final_response_content = get_last_assistant_content(response)

            is_correct = False
            # Only attempt the check if both the target phrase and the agent content are non-empty
            if target_answer_phrase and agent_final_response_content:
                # Perform the simple case-insensitive substring check
                if target_answer_phrase.lower() in agent_final_response_content.lower():
                    is_correct = True

            # Record result
            result = {
                "sample_id": sample.get('id', f'sample_{i}'),
                "category": category,
                "input": prompt,
                "target_output": sample.get('answer', ''),
                "agent_full_response": response,
                "agent_final_response": agent_final_response_content,
                "response_time": end_time - start_time,
                "is_correct": is_correct
            }
            results.append(result)

            # Write to file immediately to preserve progress
            with open(results_file, 'a') as f:
                f.write(json.dumps(result) + '\n')

            print(f"Response received in {end_time - start_time:.2f} seconds")
            print(f"Response: {agent_final_response_content[:100]}...")

            # Add a delay to avoid overwhelming the server
            time.sleep(1)

        except Exception as e:
            print(f"Error processing sample: {e}")
            continue

    # Print summary statistics
    print("\n===== HLE BENCHMARK SUMMARY =====")
    print(f"Samples processed: {len(results)}")

    # Group results by category
    by_category = {}
    for result in results:
        category = result.get('category', 'Unknown')
        by_category.setdefault(category, []).append(result)

    print("\nSamples by category:")
    for category, items in by_category.items():
        print(f"  {category}: {len(items)} samples")

    avg_time = sum(r.get('response_time', 0) for r in results) / len(results) if results else 0
    print(f"\nAverage response time: {avg_time:.2f} seconds")
    print(f"Results saved to: {results_file}")

    return results
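
# Optional post-hoc helper (a sketch): it assumes the JSONL layout written by
# benchmark_hle above and recomputes a simple accuracy figure from a saved results
# file; the function name and docstring are illustrative, not part of the benchmark.
def summarize_results(path):
    """Reload a results .jsonl file and report overall substring-match accuracy."""
    with open(path) as f:
        records = [json.loads(line) for line in f if line.strip()]
    correct = sum(1 for r in records if r.get("is_correct"))
    if records:
        print(f"Accuracy: {correct}/{len(records)} ({correct / len(records):.1%})")
    return records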

if __name__ == "__main__":
    benchmark_hle(
        num_samples=1,
        categories=None
    )
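
# Example invocation with more samples and a category filter; the category names are
# illustrative and must match the `category` values used in the HLE dataset:
#   benchmark_hle(num_samples=20, categories=["Math", "Physics"])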