Shiyu Zhao committed on
Commit 4c504d3 · 1 Parent(s): 2c4b436

Update space

Files changed (1)
  1. app.py +96 -195
app.py CHANGED
@@ -1,104 +1,12 @@
  import gradio as gr
  import pandas as pd
  import os
  import re
  from datetime import datetime
  import json
- import pandas as pd
- import torch
- import numpy as np
- from tqdm import tqdm
- from concurrent.futures import ProcessPoolExecutor, as_completed
-
- from stark_qa import load_qa
- from stark_qa.evaluator import Evaluator
-
-
- def process_single_instance(args):
-     idx, eval_csv, qa_dataset, evaluator, eval_metrics = args
-     query, query_id, answer_ids, meta_info = qa_dataset[idx]
-
-     try:
-         pred_rank = eval_csv[eval_csv['query_id'] == query_id]['pred_rank'].item()
-     except IndexError:
-         raise IndexError(f'Error when processing query_id={query_id}, please make sure the predicted results exist for this query.')
-     except Exception as e:
-         raise RuntimeError(f'Unexpected error occurred while fetching prediction rank for query_id={query_id}: {e}')
-
-     if isinstance(pred_rank, str):
-         try:
-             pred_rank = eval(pred_rank)
-         except SyntaxError as e:
-             raise ValueError(f'Failed to parse pred_rank as a list for query_id={query_id}: {e}')
-
-     if not isinstance(pred_rank, list):
-         raise TypeError(f'Error when processing query_id={query_id}, expected pred_rank to be a list but got {type(pred_rank)}.')
-
-     pred_dict = {pred_rank[i]: -i for i in range(min(100, len(pred_rank)))}
-     answer_ids = torch.LongTensor(answer_ids)
-     result = evaluator.evaluate(pred_dict, answer_ids, metrics=eval_metrics)
-
-     result["idx"], result["query_id"] = idx, query_id
-     return result
-
-
- def compute_metrics(csv_path: str, dataset: str, split: str, num_workers: int = 4):
-     candidate_ids_dict = {
-         'amazon': [i for i in range(957192)],
-         'mag': [i for i in range(1172724, 1872968)],
-         'prime': [i for i in range(129375)]
-     }
-     try:
-         eval_csv = pd.read_csv(csv_path)
-         if 'query_id' not in eval_csv.columns:
-             raise ValueError('No `query_id` column found in the submitted csv.')
-         if 'pred_rank' not in eval_csv.columns:
-             raise ValueError('No `pred_rank` column found in the submitted csv.')
-
-         eval_csv = eval_csv[['query_id', 'pred_rank']]
-
-         if dataset not in candidate_ids_dict:
-             raise ValueError(f"Invalid dataset '{dataset}', expected one of {list(candidate_ids_dict.keys())}.")
-         if split not in ['test', 'test-0.1', 'human_generated_eval']:
-             raise ValueError(f"Invalid split '{split}', expected one of ['test', 'test-0.1', 'human_generated_eval'].")
-
-         evaluator = Evaluator(candidate_ids_dict[dataset])
-         eval_metrics = ['hit@1', 'hit@5', 'recall@20', 'mrr']
-         qa_dataset = load_qa(dataset, human_generated_eval=split == 'human_generated_eval')
-         split_idx = qa_dataset.get_idx_split()
-         all_indices = split_idx[split].tolist()
-
-         results_list = []
-         query_ids = []
-
-         # Prepare args for each worker
-         args = [(idx, eval_csv, qa_dataset, evaluator, eval_metrics) for idx in all_indices]
-
-         with ProcessPoolExecutor(max_workers=num_workers) as executor:
-             futures = [executor.submit(process_single_instance, arg) for arg in args]
-             for future in tqdm(as_completed(futures), total=len(futures)):
-                 result = future.result() # This will raise an error if the worker encountered one
-                 results_list.append(result)
-                 query_ids.append(result['query_id'])
-
-         # Concatenate results and compute final metrics
-         eval_csv = pd.concat([eval_csv, pd.DataFrame(results_list)], ignore_index=True)
-         final_results = {
-             metric: np.mean(eval_csv[eval_csv['query_id'].isin(query_ids)][metric]) for metric in eval_metrics
-         }
-         return final_results
-
-     except pd.errors.EmptyDataError:
-         return "Error: The CSV file is empty or could not be read. Please check the file and try again."
-     except FileNotFoundError:
-         return f"Error: The file {csv_path} could not be found. Please check the file path and try again."
-     except Exception as error:
-         return f"{error}"
-
-
-
-
- # Sample data based on your table (you'll need to update this with the full dataset)
  data_synthesized_full = {
      'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2'],
      'STARK-AMAZON_Hit@1': [44.94, 15.29, 30.96, 26.56, 39.16, 40.93, 21.74, 42.08, 40.07, 46.10],
@@ -147,9 +55,21 @@ data_human_generated = {
      'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
  }
  
  df_synthesized_full = pd.DataFrame(data_synthesized_full)
  df_synthesized_10 = pd.DataFrame(data_synthesized_10)
  df_human_generated = pd.DataFrame(data_human_generated)
  def validate_email(email_str):
      """Validate email format(s)"""
      emails = [e.strip() for e in email_str.split(';')]
@@ -169,11 +89,9 @@ def validate_csv(file_obj):
      df = pd.read_csv(file_obj.name)
      required_cols = ['query_id', 'pred_rank']
  
-     # Check columns
      if not all(col in df.columns for col in required_cols):
          return False, "CSV must contain 'query_id' and 'pred_rank' columns"
  
-     # Check pred_rank format and length
      try:
          first_rank = eval(df['pred_rank'].iloc[0]) if isinstance(df['pred_rank'].iloc[0], str) else df['pred_rank'].iloc[0]
          if not isinstance(first_rank, list) or len(first_rank) < 20:
@@ -190,16 +108,39 @@ def save_submission(submission_data):
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
      submission_id = f"{submission_data['team_name']}_{timestamp}"
  
-     # Create submissions directory if it doesn't exist
      os.makedirs("submissions", exist_ok=True)
-
-     # Save submission data
      submission_path = f"submissions/{submission_id}.json"
      with open(submission_path, 'w') as f:
          json.dump(submission_data, f, indent=4)
  
      return submission_id
  
  def process_submission(
      method_name, team_name, dataset, split, contact_email,
      code_repo, csv_file, model_description, hardware, paper_link
@@ -270,6 +211,21 @@ def process_submission(
      except Exception as e:
          return f"Error processing submission: {str(e)}"
  
  def add_submission_form(demo):
      with demo:
          gr.Markdown("---")
@@ -338,102 +294,47 @@ def add_submission_form(demo):
          ],
          outputs=result
      )
- def format_dataframe(df, dataset):
-     # Filter the dataframe for the selected dataset
-     columns = ['Method'] + [col for col in df.columns if dataset in col]
-     filtered_df = df[columns].copy()
-
-     # Rename columns
-     filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
-
-     # Sort by MRR
-     filtered_df = filtered_df.sort_values('MRR', ascending=False)
-
-     return filtered_df
-
- model_types = {
-     'Sparse Retriever': ['BM25'],
-     'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
-     'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
-     'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
-     'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
- }
-
- def filter_by_model_type(df, selected_types):
-     if not selected_types: # If no types are selected, return an empty DataFrame
-         return df.head(0)
-     selected_models = [model for type in selected_types for model in model_types[type]]
-     return df[df['Method'].isin(selected_models)]
-
- def format_dataframe(df, dataset):
-     columns = ['Method'] + [col for col in df.columns if dataset in col]
-     filtered_df = df[columns].copy()
-     filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
-     filtered_df = filtered_df.sort_values('MRR', ascending=False)
-     return filtered_df
-
- def update_tables(selected_types):
-     filtered_df_full = filter_by_model_type(df_synthesized_full, selected_types)
-     filtered_df_10 = filter_by_model_type(df_synthesized_10, selected_types)
-     filtered_df_human = filter_by_model_type(df_human_generated, selected_types)
-
-     outputs = []
-     for df in [filtered_df_full, filtered_df_10, filtered_df_human]:
-         for dataset in ['AMAZON', 'MAG', 'PRIME']:
-             outputs.append(format_dataframe(df, f"STARK-{dataset}"))
-
-     return outputs
-
- css = """
- table > thead {
-     white-space: normal
- }
-
- table {
-     --cell-width-1: 250px
- }
-
- table > tbody > tr > td:nth-child(2) > div {
-     overflow-x: auto
- }
- """
-
- with gr.Blocks(css=css) as demo:
-     gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
-     gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
-
-     with gr.Row():
-         model_type_filter = gr.CheckboxGroup(
-             choices=list(model_types.keys()),
-             value=list(model_types.keys()),
-             label="Model types",
-             interactive=True
          )
  
-     all_dfs = []
-
-     with gr.Tabs() as outer_tabs:
-         for tab_name, df_source in [("Synthesized (full)", df_synthesized_full),
-                                     ("Synthesized (10%)", df_synthesized_10),
-                                     ("Human-Generated", df_human_generated)]:
-             with gr.TabItem(tab_name):
-                 with gr.Tabs() as inner_tabs:
-                     for dataset in ['AMAZON', 'MAG', 'PRIME']:
-                         with gr.TabItem(dataset):
-                             df = gr.DataFrame(interactive=False)
-                             all_dfs.append(df)
-
-     model_type_filter.change(
-         update_tables,
-         inputs=[model_type_filter],
-         outputs=all_dfs
-     )
-
-     demo.load(
-         update_tables,
-         inputs=[model_type_filter],
-         outputs=all_dfs
-     )
-     add_submission_form(demo)
-
-     demo.launch()
 
  import gradio as gr
  import pandas as pd
+ import numpy as np
  import os
  import re
  from datetime import datetime
  import json
  
+ # Data dictionaries for leaderboard
  data_synthesized_full = {
      'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2'],
      'STARK-AMAZON_Hit@1': [44.94, 15.29, 30.96, 26.56, 39.16, 40.93, 21.74, 42.08, 40.07, 46.10],
 
      'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
  }
  
+ # Initialize DataFrames
  df_synthesized_full = pd.DataFrame(data_synthesized_full)
  df_synthesized_10 = pd.DataFrame(data_synthesized_10)
  df_human_generated = pd.DataFrame(data_human_generated)
+
+ # Model type definitions
+ model_types = {
+     'Sparse Retriever': ['BM25'],
+     'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
+     'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
+     'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
+     'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
+ }
+
+ # Submission form validation functions
  def validate_email(email_str):
      """Validate email format(s)"""
      emails = [e.strip() for e in email_str.split(';')]
 
      df = pd.read_csv(file_obj.name)
      required_cols = ['query_id', 'pred_rank']
  
      if not all(col in df.columns for col in required_cols):
          return False, "CSV must contain 'query_id' and 'pred_rank' columns"
  
      try:
          first_rank = eval(df['pred_rank'].iloc[0]) if isinstance(df['pred_rank'].iloc[0], str) else df['pred_rank'].iloc[0]
          if not isinstance(first_rank, list) or len(first_rank) < 20:
 
      timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
      submission_id = f"{submission_data['team_name']}_{timestamp}"
  
      os.makedirs("submissions", exist_ok=True)
      submission_path = f"submissions/{submission_id}.json"
      with open(submission_path, 'w') as f:
          json.dump(submission_data, f, indent=4)
  
      return submission_id
  
+ # Leaderboard functions
+ def filter_by_model_type(df, selected_types):
+     if not selected_types:
+         return df.head(0)
+     selected_models = [model for type in selected_types for model in model_types[type]]
+     return df[df['Method'].isin(selected_models)]
+
+ def format_dataframe(df, dataset):
+     columns = ['Method'] + [col for col in df.columns if dataset in col]
+     filtered_df = df[columns].copy()
+     filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
+     filtered_df = filtered_df.sort_values('MRR', ascending=False)
+     return filtered_df
+
+ def update_tables(selected_types):
+     filtered_df_full = filter_by_model_type(df_synthesized_full, selected_types)
+     filtered_df_10 = filter_by_model_type(df_synthesized_10, selected_types)
+     filtered_df_human = filter_by_model_type(df_human_generated, selected_types)
+
+     outputs = []
+     for df in [filtered_df_full, filtered_df_10, filtered_df_human]:
+         for dataset in ['AMAZON', 'MAG', 'PRIME']:
+             outputs.append(format_dataframe(df, f"STARK-{dataset}"))
+
+     return outputs
+
  def process_submission(
      method_name, team_name, dataset, split, contact_email,
      code_repo, csv_file, model_description, hardware, paper_link
 
      except Exception as e:
          return f"Error processing submission: {str(e)}"
  
+ # CSS styling
+ css = """
+ table > thead {
+     white-space: normal
+ }
+
+ table {
+     --cell-width-1: 250px
+ }
+
+ table > tbody > tr > td:nth-child(2) > div {
+     overflow-x: auto
+ }
+ """
+
  def add_submission_form(demo):
      with demo:
          gr.Markdown("---")
 
          ],
          outputs=result
      )
  
+ # Main application
+ if __name__ == "__main__":
+     with gr.Blocks(css=css) as demo:
+         gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
+         gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
+
+         with gr.Row():
+             model_type_filter = gr.CheckboxGroup(
+                 choices=list(model_types.keys()),
+                 value=list(model_types.keys()),
+                 label="Model types",
+                 interactive=True
+             )
+
+         all_dfs = []
+
+         with gr.Tabs() as outer_tabs:
+             for tab_name, df_source in [("Synthesized (full)", df_synthesized_full),
+                                         ("Synthesized (10%)", df_synthesized_10),
+                                         ("Human-Generated", df_human_generated)]:
+                 with gr.TabItem(tab_name):
+                     with gr.Tabs() as inner_tabs:
+                         for dataset in ['AMAZON', 'MAG', 'PRIME']:
+                             with gr.TabItem(dataset):
+                                 df = gr.DataFrame(interactive=False)
+                                 all_dfs.append(df)
+
+         model_type_filter.change(
+             update_tables,
+             inputs=[model_type_filter],
+             outputs=all_dfs
          )
+
+         demo.load(
+             update_tables,
+             inputs=[model_type_filter],
+             outputs=all_dfs
+         )
+
+         # Add submission form
+         add_submission_form(demo)
  
+     demo.launch()
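
For reference, a minimal sketch of how the leaderboard helpers added above behave when called outside Gradio. The function bodies mirror filter_by_model_type and format_dataframe from the new app.py; the small stand-in DataFrame, its MRR values, and the trimmed model_types mapping are illustrative placeholders, not the leaderboard numbers shown in this commit.

import pandas as pd

# Stand-in for df_synthesized_full: the column naming scheme ('STARK-<DATASET>_<metric>')
# matches the app; the MRR numbers are made-up placeholders for illustration only.
df_synthesized_full = pd.DataFrame({
    'Method': ['BM25', 'ada-002', 'ColBERTv2'],
    'STARK-AMAZON_Hit@1': [44.94, 39.16, 46.10],
    'STARK-AMAZON_MRR': [1.0, 2.0, 3.0],  # placeholder values
})

model_types = {
    'Sparse Retriever': ['BM25'],
    'LLM-based Dense Retrievers': ['ada-002'],
    'Multivector Retrievers': ['ColBERTv2'],
}

def filter_by_model_type(df, selected_types):
    # An empty selection yields an empty table, matching the app's behaviour.
    if not selected_types:
        return df.head(0)
    selected = [m for t in selected_types for m in model_types[t]]
    return df[df['Method'].isin(selected)]

def format_dataframe(df, dataset):
    # Keep only the selected dataset's columns, strip the 'STARK-<DATASET>_' prefix, sort by MRR.
    columns = ['Method'] + [col for col in df.columns if dataset in col]
    out = df[columns].copy()
    out.columns = [col.split('_')[-1] if '_' in col else col for col in out.columns]
    return out.sort_values('MRR', ascending=False)

subset = filter_by_model_type(df_synthesized_full, ['Sparse Retriever', 'Multivector Retrievers'])
print(format_dataframe(subset, 'STARK-AMAZON'))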