Kunal Pai
committed on
Commit
·
4f96523
1
Parent(s):
b27e104
Add paper benchmarking, along with dataset for it
Browse files
bench/benchmarking_paper_reviews.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from gradio_client import Client
|
2 |
+
import pandas as pd
|
3 |
+
import json
|
4 |
+
import time
|
5 |
+
import os
|
6 |
+
from datetime import datetime
|
7 |
+
|
8 |
+
def get_last_assistant_content(resp):
    """
    Return the last assistant utterance from the response object
    produced by `client.predict`.

    Scans the conversation history backwards and returns, in order of
    preference for each assistant turn: the text of a structured
    (dict) content payload, plain string content, or the output of a
    tool/function response.

    Args:
        resp: a list of message dicts, or a tuple whose first element
            is such a list (gradio may wrap multiple outputs in a tuple).

    Returns:
        The last assistant's textual content, or "" if none is found.
    """
    if isinstance(resp, tuple):
        resp = resp[0]
    if not isinstance(resp, list):
        return ""
    for turn in reversed(resp):
        if turn.get("role") != "assistant":
            continue
        cont = turn.get("content")
        # Structured content: {"parts": [{"text": ...}, ...]}. This must be
        # checked BEFORE the generic truthiness test — a dict is truthy, so
        # the original order returned the raw dict instead of its text.
        if isinstance(cont, dict):
            parts = cont.get("parts", [])
            if parts and parts[0].get("text"):
                return parts[0]["text"]
        elif cont:
            return cont
        # Fall back to the output of a tool/function call, if any.
        fr = turn.get("function_response", {})
        out = fr.get("result", {}).get("output")
        if out:
            return out
    return ""
|
32 |
+
|
33 |
+
def benchmark_paper_reviews(
    csv_path,
    id_col="ID",
    text_col="concatenated_text",
    num_samples=None,
    output_dir="results"
):
    """
    Benchmark agent performance on paper reviews.

    Args:
        csv_path: path to the pipe-separated CSV of papers + existing reviews
        id_col: name of the column containing unique paper IDs
        text_col: name of the column containing the full paper text
        num_samples: if set, randomly sample this many papers
        output_dir: where to write the JSONL results

    Returns:
        List of per-paper result dicts. Each result is also appended to a
        timestamped JSONL file in `output_dir` as soon as it completes, so
        partial progress survives a crash.
    """
    # load CSV (the dataset uses "|" as its field separator)
    df = pd.read_csv(csv_path, sep="|")
    if num_samples:
        # fixed seed so repeated runs benchmark the same sample
        df = df.sample(num_samples, random_state=42).reset_index(drop=True)

    # prepare output
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"paper_review_benchmark_{timestamp}.jsonl")

    # init client — assumes the agent app is being served locally
    client = Client("http://127.0.0.1:7860/")

    results = []
    for idx, row in df.iterrows():
        paper_id = row[id_col]
        title = row["Title"]
        prompt = "Create THREE agents with different personalities, expertise, and review styles. " \
                 "Each agent should provide a review of the paper, and recommend Accept/Reject for ICLR 2023. " \
                 "The review should be detailed and include strengths and weaknesses. " \
                 "You can use ArxivTool and WikipediaTool to get more information. " \
                 "The paper title is: " + title + "\n\n" + row[text_col]
        print(f"[{idx+1}/{len(df)}] Paper ID: {paper_id}")

        try:
            start = time.time()
            resp = client.predict(
                messages=[{"role": "user", "content": prompt}],
                api_name="/run"
            )
            elapsed = time.time() - start

            # Extract the final assistant review text. The raw response may
            # be a tuple/list of message turns, so storing it directly would
            # record the whole transcript and make len() count turns, not
            # characters.
            review_text = get_last_assistant_content(resp)

            result = {
                "paper_id": paper_id,
                "prompt_snippet": prompt[:200],
                "agent_review": review_text,
                "ground_truth": row["Decision"],
                "response_time": elapsed
            }

            # write immediately
            with open(out_path, "a") as f:
                f.write(json.dumps(result) + "\n")

            print(f" → {elapsed:.2f}s, review length {len(review_text)} chars")
            results.append(result)

            # small delay between requests
            time.sleep(1)
        except Exception as e:
            # best-effort benchmark: log the failure and move on to the
            # next paper rather than aborting the whole run
            print(f" Error on {paper_id}: {e}")

    print(f"\nDone. Results written to {out_path}")
    return results
|
104 |
+
|
105 |
+
if __name__ == "__main__":
    # Example invocation — benchmark a single randomly-sampled paper from
    # the ICLR 2023 dataset. Adjust the path and sample count as needed.
    benchmark_paper_reviews(csv_path="ICLR_2023.csv", num_samples=1)
|