Shiyu Zhao committed
Commit · 4c504d3 · 1 Parent(s): 2c4b436

Update space
app.py CHANGED
@@ -1,104 +1,12 @@
 import gradio as gr
 import pandas as pd
+import numpy as np
 import os
 import re
 from datetime import datetime
 import json
-import pandas as pd
-import torch
-import numpy as np
-from tqdm import tqdm
-from concurrent.futures import ProcessPoolExecutor, as_completed
-
-from stark_qa import load_qa
-from stark_qa.evaluator import Evaluator
-
-
-def process_single_instance(args):
-    idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
-    query, query_id, answer_ids, meta_info = qa_dataset[idx]
-
-    try:
-        pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
-    except IndexError:
-        raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
-    except Exception as e:
-        raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}')
-
-    if isinstance(pred_rank, str):
-        try:
-            pred_rank = eval(pred_rank)
-        except SyntaxError as e:
-            raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
-
-    if not isinstance(pred_rank, list):
-        raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
-
-    pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
-    answer_ids = torch.LongTensor(answer_ids)
-    result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
-
-    result["idx"], result["query_id"] = idx, query_id
-    return result
-
-
-def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
-    candidate_ids_dict = {
-        'amazon': [i for i in range(957192)],
-        'mag': [i for i in range(1172724, 1872968)],
-        'prime': [i for i in range(129375)]
-    }
-    try:
-        eval_csv = pd.read_csv(csv_path)
-        if 'query_id' not in eval_csv.columns:
-            raise ValueError('No `query_id` column found in the submitted csv.')
-        if 'pred_rank' not in eval_csv.columns:
-            raise ValueError('No `pred_rank` column found in the submitted csv.')
-
-        eval_csv = eval_csv[['query_id', 'pred_rank']]
-
-        if dataset not in candidate_ids_dict:
-            raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
-        if split not in ['test', 'test-0.1', 'human_generated_eval']:
-            raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
 
-        eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-        qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
-        split_idx = qa_dataset.get_idx_split()
-        all_indices = split_idx[split].tolist()
-
-        results_list = []
-        query_ids = []
-
-        # Prepare args for each worker
-        args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
-
-        with ProcessPoolExecutor(max_workers=num_workers) as executor:
-            futures = [executor.submit(process_single_instance, arg) for arg in args]
-            for future in tqdm(as_completed(futures), total=len(futures)):
-                result = future.result()  # This will raise an error if the worker encountered one
-                results_list.append(result)
-                query_ids.append(result['query_id'])
-
-        # Concatenate results and compute final metrics
-        eval_csv = pd.concat([eval_csv, pd.DataFrame(results_list)], ignore_index=True)
-        final_results = {
-            metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
-        }
-        return final_results
-
-    except pd.errors.EmptyDataError:
-        return "Error: The CSV file is empty or could not be read. Please check the file and try again."
-    except FileNotFoundError:
-        return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
-    except Exception as error:
-        return f"{error}"
-
-
-
-
-# Sample data based on your table (you'll need to update this with the full dataset)
+# Data dictionaries for leaderboard
 data_synthesized_full = {
     'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2'],
     'STARK-AMAZON_Hit@1': [44.94, 15.29, 30.96, 26.56, 39.16, 40.93, 21.74, 42.08, 40.07, 46.10],
@@ -147,9 +55,21 @@ data_human_generated = {
     'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
 }
 
+# Initialize DataFrames
 df_synthesized_full = pd.DataFrame(data_synthesized_full)
 df_synthesized_10 = pd.DataFrame(data_synthesized_10)
 df_human_generated = pd.DataFrame(data_human_generated)
+
+# Model type definitions
+model_types = {
+    'Sparse Retriever': ['BM25'],
+    'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
+    'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
+    'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
+    'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
+}
+
+# Submission form validation functions
 def validate_email(email_str):
     """Validate email format(s)"""
     emails = [e.strip() for e in email_str.split(';')]
@@ -169,11 +89,9 @@ def validate_csv(file_obj):
     df = pd.read_csv(file_obj.name)
     required_cols = ['query_id', 'pred_rank']
 
-    # Check columns
     if not all(col in df.columns for col in required_cols):
         return False, "CSV must contain 'query_id' and 'pred_rank' columns"
 
-    # Check pred_rank format and length
     try:
         first_rank = eval(df['pred_rank'].iloc[0]) if isinstance(df['pred_rank'].iloc[0], str) else df['pred_rank'].iloc[0]
         if not isinstance(first_rank, list) or len(first_rank) < 20:
@@ -190,16 +108,39 @@ def save_submission(submission_data):
     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
     submission_id = f"{submission_data['team_name']}_{timestamp}"
 
-    # Create submissions directory if it doesn't exist
     os.makedirs("submissions", exist_ok=True)
-
-    # Save submission data
     submission_path = f"submissions/{submission_id}.json"
     with open(submission_path, 'w') as f:
         json.dump(submission_data, f, indent=4)
 
     return submission_id
 
+# Leaderboard functions
+def filter_by_model_type(df, selected_types):
+    if not selected_types:
+        return df.head(0)
+    selected_models = [model for type in selected_types for model in model_types[type]]
+    return df[df['Method'].isin(selected_models)]
+
+def format_dataframe(df, dataset):
+    columns = ['Method'] + [col for col in df.columns if dataset in col]
+    filtered_df = df[columns].copy()
+    filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
+    filtered_df = filtered_df.sort_values('MRR', ascending=False)
+    return filtered_df
+
+def update_tables(selected_types):
+    filtered_df_full = filter_by_model_type(df_synthesized_full, selected_types)
+    filtered_df_10 = filter_by_model_type(df_synthesized_10, selected_types)
+    filtered_df_human = filter_by_model_type(df_human_generated, selected_types)
+
+    outputs = []
+    for df in [filtered_df_full, filtered_df_10, filtered_df_human]:
+        for dataset in ['AMAZON', 'MAG', 'PRIME']:
+            outputs.append(format_dataframe(df, f"STARK-{dataset}"))
+
+    return outputs
+
 def process_submission(
     method_name, team_name, dataset, split, contact_email,
     code_repo, csv_file, model_description, hardware, paper_link
@@ -270,6 +211,21 @@ def process_submission(
     except Exception as e:
         return f"Error processing submission: {str(e)}"
 
+# CSS styling
+css = """
+table > thead {
+    white-space: normal
+}
+
+table {
+    --cell-width-1: 250px
+}
+
+table > tbody > tr > td:nth-child(2) > div {
+    overflow-x: auto
+}
+"""
+
 def add_submission_form(demo):
     with demo:
         gr.Markdown("---")
@@ -338,102 +294,47 @@ def add_submission_form(demo):
         ],
         outputs=result
     )
-    def format_dataframe(df, dataset):
-        # Filter the dataframe for the selected dataset
-        columns = ['Method'] + [col for col in df.columns if dataset in col]
-        filtered_df = df[columns].copy()
-
-        # Rename columns
-        filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
-
-        # Sort by MRR
-        filtered_df = filtered_df.sort_values('MRR', ascending=False)
-
-        return filtered_df
 
-    return outputs
-
-css = """
-table > thead {
-    white-space: normal
-}
-
-table {
-    --cell-width-1: 250px
-}
-
-table > tbody > tr > td:nth-child(2) > div {
-    overflow-x: auto
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
-    gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
-
-    with gr.Row():
-        model_type_filter = gr.CheckboxGroup(
-            choices=list(model_types.keys()),
-            value=list(model_types.keys()),
-            label="Model types",
-            interactive=True
-        )
-
-    with gr.Tabs() as outer_tabs:
-        for tab_name, df_source in [("Synthesized (full)", df_synthesized_full),
-                                    ("Synthesized (10%)", df_synthesized_10),
-                                    ("Human-Generated", df_human_generated)]:
-            with gr.TabItem(tab_name):
-                with gr.Tabs() as inner_tabs:
-                    for dataset in ['AMAZON', 'MAG', 'PRIME']:
-                        with gr.TabItem(dataset):
-                            df = gr.DataFrame(interactive=False)
-                            all_dfs.append(df)
-
-    model_type_filter.change(
-        update_tables,
-        inputs=[model_type_filter],
-        outputs=all_dfs
-    )
-
-    demo.load(
-        update_tables,
-        inputs=[model_type_filter],
-        outputs=all_dfs
-    )
-    add_submission_form(demo)
-
-demo.launch()
+# Main application
+if __name__ == "__main__":
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
+        gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
+
+        with gr.Row():
+            model_type_filter = gr.CheckboxGroup(
+                choices=list(model_types.keys()),
+                value=list(model_types.keys()),
+                label="Model types",
+                interactive=True
+            )
+
+        all_dfs = []
+
+        with gr.Tabs() as outer_tabs:
+            for tab_name, df_source in [("Synthesized (full)", df_synthesized_full),
+                                        ("Synthesized (10%)", df_synthesized_10),
+                                        ("Human-Generated", df_human_generated)]:
+                with gr.TabItem(tab_name):
+                    with gr.Tabs() as inner_tabs:
+                        for dataset in ['AMAZON', 'MAG', 'PRIME']:
+                            with gr.TabItem(dataset):
+                                df = gr.DataFrame(interactive=False)
+                                all_dfs.append(df)
+
+        model_type_filter.change(
+            update_tables,
+            inputs=[model_type_filter],
+            outputs=all_dfs
+        )
+
+        demo.load(
+            update_tables,
+            inputs=[model_type_filter],
+            outputs=all_dfs
+        )
+
+        # Add submission form
+        add_submission_form(demo)
 
+    demo.launch()
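Note on the refactor: because this commit moves the `gr.Blocks` construction and `demo.launch()` under an `if __name__ == "__main__":` guard, app.py can now be imported without starting the Space. A minimal sketch of the new update path (not part of the commit; it assumes app.py is importable from the working directory and only uses the names introduced in the diff above):

    # Minimal sketch, not part of the commit: exercise the refactored update
    # path outside Gradio. Importing app is safe because demo.launch() now
    # runs only under the __main__ guard.
    from app import update_tables, model_types

    # Same default selection as the CheckboxGroup: every model type checked.
    selected = list(model_types.keys())

    # update_tables returns nine DataFrames ordered (full, 10%, human-generated)
    # x (AMAZON, MAG, PRIME), matching the order in which the gr.DataFrame
    # components are appended to all_dfs.
    views = ["Synthesized (full)", "Synthesized (10%)", "Human-Generated"]
    datasets = ["AMAZON", "MAG", "PRIME"]
    for (view, ds), table in zip(
            ((v, d) for v in views for d in datasets), update_tables(selected)):
        # format_dataframe sorts each table by MRR descending, so row 0 is the
        # best-performing method for that view/dataset pair.
        print(f"{view} / {ds}: top method = {table.iloc[0]['Method']}")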