# hashiruAI / bench / benchmarking_connections.py
# Last commit by helloparthshah (6900003):
#   "QOL updates and refactoring. Also fixed the tool/agent budgeting"
# File size: 3.47 kB
from gradio_client import Client
from datasets import load_dataset
import json
import time
import random
import os
from datetime import datetime
import re
def get_last_assistant_content(resp):
    """
    Return the last non-empty assistant utterance from a ``client.predict`` response.

    The server may return either a list of message dicts or a tuple whose
    first element is that list. Assistant turns are scanned from newest to
    oldest, and the first one that yields content wins. Each turn is checked for:

    a) plain ``content``: any truthy non-dict value is returned as-is;
    b) Part-style ``content``: a dict whose ``parts[0]["text"]`` holds the text;
    c) a tool wrapper at ``function_response["result"]["output"]``.

    Args:
        resp: The raw value returned by ``client.predict``.

    Returns:
        The extracted content, or ``""`` if the response has an unexpected
        shape or no assistant turn carries any content.
    """
    # If the server wraps things in a (messages, meta) tuple, unwrap it.
    if isinstance(resp, tuple):
        if not resp:
            return ""
        resp = resp[0]
    # At this point `resp` must be the list of message dicts.
    if not isinstance(resp, list):
        return ""
    for turn in reversed(resp):
        # Ignore non-dict entries as well as user/system/tool turns.
        if not isinstance(turn, dict) or turn.get("role") != "assistant":
            continue
        content = turn.get("content")
        if isinstance(content, dict):
            # b) Part objects. This must be checked before the generic
            # truthiness test, otherwise the dict itself would be returned
            # instead of the text it wraps.
            parts = content.get("parts")
            first = parts[0] if isinstance(parts, (list, tuple)) and parts else None
            if isinstance(first, dict) and first.get("text"):
                return first["text"]
        elif content:
            # a) plain message
            return content
        # c) tool / function_response wrapper. Tolerate keys that are present
        # but hold None (or anything else that is not a dict).
        fr = turn.get("function_response")
        result = fr.get("result") if isinstance(fr, dict) else None
        out = result.get("output") if isinstance(result, dict) else None
        if out:
            return out
    return ""
def benchmark_nyt_connections(num_samples=20, categories=None):
    """
    Benchmark agent performance on the NYT Connections dataset.

    Randomly samples puzzles from the dataset's ``train`` split and sends each
    one to the agent server at ``http://127.0.0.1:7860/`` (endpoint ``/run``).
    One JSON line per puzzle is appended to
    ``results/nyt_connections_benchmark_<timestamp>.jsonl`` as soon as that
    puzzle finishes, so completed samples survive an interrupted run.

    Args:
        num_samples: Number of puzzles to evaluate. ``None`` or ``0`` means
            every eligible puzzle; larger values are capped at the number
            available.
        categories: Collection of ``category`` values to include (``None`` or
            empty for all). Filtering happens before sampling, so
            ``num_samples`` puzzles are evaluated whenever enough match.

    Returns:
        List of result dicts with keys ``input``, ``date``, ``output``,
        ``expected`` and ``elapsed_time``, in evaluation order.
    """
    # Load NYT connections dataset
    print("Loading NYT connections dataset...")
    dataset = load_dataset("tm21cy/NYT-Connections")
    train = dataset["train"]
    # Initialize client
    client = Client("http://127.0.0.1:7860/")
    # Prepare output directory
    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"nyt_connections_benchmark_{timestamp}.jsonl")
    print(f"Results will be saved to {out_path}")

    # Restrict the population *before* sampling. Filtering afterwards would
    # silently evaluate fewer than `num_samples` puzzles.
    if categories:
        candidates = [i for i in range(len(train)) if train[i]["category"] in categories]
    else:
        candidates = range(len(train))
    num_samples = min(num_samples, len(candidates)) if num_samples else len(candidates)
    print(f"Sampling {num_samples} samples from the dataset.")
    indices = random.sample(candidates, num_samples)

    results = []
    for i in indices:
        sample = train[i]
        print(f"Sample {i}: {sample['contest']}")
        prompt = f"Given the following words, group them into 4 categories of 4 words each:\n{' '.join(sample['words'])}\n\n Once you've solved it, final output should be in the following format Group 1: word1, word2, word3, word4\nGroup 2: ..."
        # perf_counter is monotonic and high-resolution, so it is the right
        # clock for measuring durations (time.time can jump with clock changes).
        start_time = time.perf_counter()
        response = client.predict(messages=[{"role": "user", "content": prompt}], api_name="/run")
        elapsed_time = time.perf_counter() - start_time
        result = {
            "input": sample["words"],
            "date": sample["contest"],
            "output": get_last_assistant_content(response),
            "expected": sample["answers"],
            "elapsed_time": elapsed_time,
        }
        results.append(result)
        # Save intermediate results: append only the new record. Rewriting the
        # whole `results` list in append mode would duplicate earlier samples.
        with open(out_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(result) + "\n")
    print(f"Results saved to {out_path}")
    return results
if __name__ == "__main__":
    # Quick smoke run: benchmark a single randomly chosen puzzle.
    smoke_test_size = 1
    benchmark_nyt_connections(num_samples=smoke_test_size)