Kunal Pai committed on
Commit
0577af4
·
1 Parent(s): 164e70c

Add benchmarking functionality for NYT Connections dataset

Browse files
Files changed (1) hide show
  1. bench/benchmarking_connections.py +97 -0
bench/benchmarking_connections.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ from datasets import load_dataset
3
+ import json
4
+ import time
5
+ import random
6
+ import os
7
+ from datetime import datetime
8
+ import re
9
+
10
def get_last_assistant_content(resp):
    """
    Return the last assistant utterance from the response object
    produced by `client.predict`.

    Args:
        resp: Either a list of message dicts, or a tuple whose first
            element is such a list (e.g. a ``(messages, meta)`` pair).

    Returns:
        The payload of the most recent assistant turn that carries one,
        or ``""`` when the response has an unexpected shape or no
        assistant turn yields any content.
    """
    # If the server wraps things in a (messages, meta) tuple, unwrap it.
    if isinstance(resp, tuple):
        resp = resp[0] if resp else None

    # At this point `resp` must be the list of message dicts.
    if not isinstance(resp, list):
        return ""

    # Walk backwards so the most recent assistant turn wins.
    for turn in reversed(resp):
        if not isinstance(turn, dict) or turn.get("role") != "assistant":
            continue

        content = turn.get("content")

        # a) messages stored as Part objects inside `content`.
        #    This must be checked before the generic truthiness test below,
        #    otherwise the wrapping dict is returned instead of its text.
        if isinstance(content, dict):
            parts = content.get("parts") or []
            first = parts[0] if parts else None
            if isinstance(first, dict) and first.get("text"):
                return first["text"]

        # b) plain messages (any other non-empty content is returned as-is)
        if content:
            return content

        # c) tool / function_response wrapper; tolerate explicit None values
        wrapper = turn.get("function_response") or {}
        result = wrapper.get("result") if isinstance(wrapper, dict) else None
        output = result.get("output") if isinstance(result, dict) else None
        if output:
            return output

    return ""
45
+
46
def benchmark_nyt_connections(num_samples=20, categories=None):
    """
    Benchmark agent performance on NYT connections dataset.

    Randomly samples puzzles from the dataset's "train" split, sends each
    one as a prompt to the Gradio endpoint at http://127.0.0.1:7860/ and
    records the reply together with the request latency.

    Args:
        num_samples: Number of samples to draw. Capped at the size of the
            split; a falsy value (``None`` or ``0``) selects every sample.
        categories: List of categories to include (None for all). The
            filter is applied after sampling, so fewer than `num_samples`
            puzzles may actually be evaluated.
            NOTE(review): relies on each sample having a "category"
            field - confirm against the dataset schema.

    Returns:
        A list of result dicts with the keys "input", "date", "output",
        "expected" and "elapsed_time" (seconds).
    """
    # Load NYT connections dataset
    print("Loading NYT connections dataset...")
    dataset = load_dataset("tm21cy/NYT-Connections")
    train = dataset["train"]
    total = len(train)

    # Initialize client
    client = Client("http://127.0.0.1:7860/")

    # Prepare output directory
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"nyt_connections_benchmark_{timestamp}.jsonl")
    print(f"Results will be saved to {out_path}")

    results = []
    num_samples = min(num_samples, total) if num_samples else total
    print(f"Sampling {num_samples} samples from the dataset.")
    indices = random.sample(range(total), num_samples)

    # Open the output once and append each record exactly once. Re-writing
    # the whole `results` list in append mode would duplicate earlier
    # records in the JSONL file every time the save step runs.
    with open(out_path, "a", encoding="utf-8") as f:
        for i in indices:
            sample = train[i]
            if categories and sample["category"] not in categories:
                continue
            print(f"Sample {i}: {sample['contest']}")
            prompt = f"Given the following words, group them into 4 categories of 4 words each:\n{' '.join(sample['words'])}\n\n"

            # perf_counter is monotonic, so the measured latency cannot be
            # skewed by system clock adjustments.
            start_time = time.perf_counter()
            response = client.predict(messages=[{"role": "user", "content": prompt}], api_name="/run")
            elapsed_time = time.perf_counter() - start_time

            record = {
                "input": sample["words"],
                "date": sample["contest"],
                "output": get_last_assistant_content(response),
                "expected": sample["answers"],
                "elapsed_time": elapsed_time,
            }
            results.append(record)

            # Save intermediate results: flush so a crash mid-run keeps
            # everything collected so far on disk.
            f.write(json.dumps(record) + "\n")
            f.flush()

    print(f"Results saved to {out_path}")
    return results
95
+
96
# Script entry point: run a single-sample smoke benchmark when executed directly.
if __name__ == "__main__":
    benchmark_nyt_connections(num_samples=1)