from gradio_client import Client
from datasets import load_dataset
import json
import time
import random
import os
from datetime import datetime
import re

def get_last_assistant_content(resp):
    """
    Return the last assistant utterance from the response object
    produced by `client.predict`.
    """
    # ❶ If the server wraps things in a (messages, meta) tuple
    if isinstance(resp, tuple):
        resp = resp[0]

    # ❷ At this point `resp` must be the list of message dicts
    if not isinstance(resp, list):
        return ""

    for turn in reversed(resp):
        if turn.get("role") != "assistant":
            continue

        cont = turn.get("content")

        # a) plain string messages
        if isinstance(cont, str) and cont:
            return cont

        # b) tool / function_response wrapper
        fr = turn.get("function_response", {})
        out = fr.get("result", {}).get("output")
        if out:
            return out

        # c) messages stored as Part objects inside `content`
        if isinstance(cont, dict):
            parts = cont.get("parts", [])
            if parts and parts[0].get("text"):
                return parts[0]["text"]

    return ""

def benchmark_nyt_connections(num_samples=20, categories=None):
    """
    Benchmark agent performance on NYT connections dataset
    Args:
        num_samples: Number of samples to test
        categories: List of categories to include (None for all)
    """
    # Load NYT connections dataset
    print("Loading NYT connections dataset...")
    dataset = load_dataset("tm21cy/NYT-Connections")
    
    # Initialize client
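    # Assumes the Gradio app is running locally and exposes the "/run" endpoint
    # used by client.predict below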
    client = Client("http://127.0.0.1:7860/")

    # Prepare output directory
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"nyt_connections_benchmark_{timestamp}.jsonl")
    print(f"Results will be saved to {out_path}")
    results = []
    num_samples = min(num_samples, len(dataset["train"])) if num_samples else len(dataset["train"])
    print(f"Sampling {num_samples} samples from the dataset.")
    indices = random.sample(range(len(dataset["train"])), num_samples)
    for i in indices:
        sample = dataset["train"][i]
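        # Skip samples outside the requested categories; note this can yield
        # fewer than num_samples results when a filter is given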
        if categories and sample["category"] not in categories:
            continue
        print(f"Sample {i}: {sample['contest']}")
        prompt = (
            "Given the following words, group them into 4 categories of 4 words each:\n"
            f"{' '.join(sample['words'])}\n\n"
            "Once you've solved it, the final output should be in the following format:\n"
            "Group 1: word1, word2, word3, word4\nGroup 2: ..."
        )
        start_time = time.time()
        response = client.predict(messages=[{"role": "user", "content": prompt}], api_name="/run")
        end_time = time.time()
        elapsed_time = end_time - start_time
        assistant_content = get_last_assistant_content(response)
        results.append({
            "input": sample["words"],
            "date": sample["contest"],
            "output": assistant_content,
            "expected": sample["answers"],
            "elapsed_time": elapsed_time,
        })

        # Append this sample's result so partial progress is saved
        with open(out_path, "a") as f:
            f.write(json.dumps(results[-1]) + "\n")
    print(f"Results saved to {out_path}")
    return results
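
# A minimal scoring sketch (not exercised by the script above). It assumes the
# model follows the "Group N: w1, w2, w3, w4" format requested in the prompt,
# and that each entry in result["expected"] can be reduced to a set of four
# words (either a plain list of words or a dict with a "words" key); adapt
# that part to the actual structure of sample["answers"] in the dataset.
def score_result(result):
    """Count how many predicted groups exactly match an expected group."""
    predicted = []
    for line in result["output"].splitlines():
        m = re.match(r"\s*Group\s*\d+\s*:\s*(.+)", line, flags=re.IGNORECASE)
        if m:
            words = {w.strip().lower() for w in m.group(1).split(",") if w.strip()}
            if words:
                predicted.append(words)

    expected = []
    for group in result["expected"]:
        members = group.get("words", group) if isinstance(group, dict) else group
        expected.append({str(w).strip().lower() for w in members})

    # A predicted group counts as correct if it exactly matches one expected group
    return sum(1 for p in predicted if p in expected)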

if __name__ == "__main__":
    benchmark_nyt_connections(num_samples=1)