Kunal Pai committed on
Commit
4f96523
·
1 Parent(s): b27e104

Add paper benchmarking, along with dataset for it

Browse files
Files changed (1) hide show
  1. bench/benchmarking_paper_reviews.py +110 -0
bench/benchmarking_paper_reviews.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ import pandas as pd
3
+ import json
4
+ import time
5
+ import os
6
+ from datetime import datetime
7
+
8
def get_last_assistant_content(resp):
    """
    Return the last assistant utterance from the response object
    produced by `client.predict`.

    Scans the conversation history from newest to oldest and, for each
    assistant turn, extracts text in this priority order:
      1. the first text part of a structured (dict) ``content`` payload,
      2. a plain truthy ``content`` value (normally a string),
      3. the ``output`` of a ``function_response`` result.

    Args:
        resp: raw return value of ``client.predict`` — either the message
            list itself or a tuple whose first element is that list.

    Returns:
        The extracted text, or "" when no assistant text is found.
    """
    if isinstance(resp, tuple):
        resp = resp[0]
    if not isinstance(resp, list):
        return ""
    for turn in reversed(resp):
        if turn.get("role") != "assistant":
            continue
        cont = turn.get("content")
        # BUG FIX: a dict is truthy, so the original `if turn.get("content")`
        # returned the raw dict and the parts-extraction branch below was
        # unreachable. Handle the structured form before the plain form.
        if isinstance(cont, dict):
            parts = cont.get("parts", [])
            if parts and parts[0].get("text"):
                return parts[0]["text"]
        elif cont:
            return cont
        fr = turn.get("function_response", {})
        out = fr.get("result", {}).get("output")
        if out:
            return out
    return ""
32
+
33
def benchmark_paper_reviews(
    csv_path,
    id_col="ID",
    text_col="concatenated_text",
    num_samples=None,
    output_dir="results",
    title_col="Title",
    decision_col="Decision",
):
    """
    Benchmark agent performance on paper reviews.

    For each paper, sends a multi-agent review prompt to a locally running
    Gradio app (http://127.0.0.1:7860/), records the raw response together
    with the ground-truth decision, and appends each result to a
    timestamped JSONL file as soon as it is available (so partial runs
    still leave usable output).

    Args:
        csv_path: path to the pipe-separated CSV of papers + existing reviews
        id_col: name of the column containing unique paper IDs
        text_col: name of the column containing the full paper text
        num_samples: if set, randomly sample this many papers
        output_dir: where to write the JSONL results
        title_col: name of the column containing the paper title
        decision_col: name of the column containing the ground-truth decision

    Returns:
        List of per-paper result dicts (papers that raised an error are
        skipped — errors are printed and the loop continues).
    """
    # load CSV (pipe-separated on purpose: paper text contains commas)
    df = pd.read_csv(csv_path, sep="|")
    if num_samples:
        # fixed random_state so repeated runs benchmark the same subset
        df = df.sample(num_samples, random_state=42).reset_index(drop=True)

    # prepare output
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = os.path.join(output_dir, f"paper_review_benchmark_{timestamp}.jsonl")

    # init client
    client = Client("http://127.0.0.1:7860/")

    results = []
    for idx, row in df.iterrows():
        paper_id = row[id_col]
        title = row[title_col]
        prompt = "Create THREE agents with different personalities, expertise, and review styles. " \
                 "Each agent should provide a review of the paper, and recommend Accept/Reject for ICLR 2023. " \
                 "The review should be detailed and include strengths and weaknesses. " \
                 "You can use ArxivTool and WikipediaTool to get more information. " \
                 "The paper title is: " + title + "\n\n" + row[text_col]
        print(f"[{idx+1}/{len(df)}] Paper ID: {paper_id}")

        try:
            start = time.time()
            resp = client.predict(
                messages=[{"role": "user", "content": prompt}],
                api_name="/run"
            )
            elapsed = time.time() - start

            result = {
                "paper_id": paper_id,
                "prompt_snippet": prompt[:200],
                "agent_review": resp,
                "ground_truth": row[decision_col],
                "response_time": elapsed
            }

            # write immediately so a crash mid-run loses at most one paper
            with open(out_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(result) + "\n")

            print(f" → {elapsed:.2f}s, review length {len(resp)} chars")
            results.append(result)

            # small delay to avoid hammering the local server
            time.sleep(1)
        except Exception as e:
            # best-effort: log and move on to the next paper
            print(f" Error on {paper_id}: {e}")

    print(f"\nDone. Results written to {out_path}")
    return results
104
+
105
+ if __name__ == "__main__":
106
+ # example usage: adjust path & sample count as needed
107
+ benchmark_paper_reviews(
108
+ csv_path="ICLR_2023.csv",
109
+ num_samples=1
110
+ )