Shiyu Zhao committed on
Commit
630c9ca
·
1 Parent(s): 4d52cf5

Update space

Browse files
Files changed (1) hide show
  1. app.py +143 -97
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
 
 
4
 
5
  # Sample data based on your table (you'll need to update this with the full dataset)
6
  data_synthesized_full = {
@@ -51,105 +53,149 @@ data_human_generated = {
51
  'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
52
  }
53
 
54
# One DataFrame per evaluation split, built from the raw result dictionaries
# defined above.
df_synthesized_full = pd.DataFrame(data_synthesized_full)
df_synthesized_10 = pd.DataFrame(data_synthesized_10)
df_human_generated = pd.DataFrame(data_human_generated)
57
-
58
def format_dataframe(df, dataset):
    """Return the leaderboard view of ``df`` for a single dataset.

    Keeps the ``Method`` column plus every column whose name contains
    ``dataset``, renames each kept column to the text after its last
    underscore (``'STARK-MAG_MRR'`` -> ``'MRR'``), and orders the rows by
    ``MRR`` in descending order.

    NOTE: this file defines ``format_dataframe`` again further down; the
    later definition is the one bound at runtime.

    Args:
        df: Results table with a ``Method`` column and
            ``<dataset>_<metric>`` columns.
        dataset: Substring that selects the dataset's columns,
            e.g. ``'STARK-MAG'``.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    # Filter the dataframe for the selected dataset
    columns = ['Method'] + [col for col in df.columns if dataset in col]
    filtered_df = df[columns].copy()

    # Rename columns: keep only the metric name after the last underscore
    filtered_df.columns = [col.rsplit('_', 1)[-1] for col in filtered_df.columns]

    # Sort by MRR; skip when the dataset matched no MRR column so that an
    # unknown dataset yields a Method-only table instead of a KeyError.
    if 'MRR' in filtered_df.columns:
        filtered_df = filtered_df.sort_values('MRR', ascending=False)

    return filtered_df
70
-
71
# Maps each model-type label shown in the UI filter to the ``Method`` values
# that belong to it.
model_types = {
    'Sparse Retriever': ['BM25'],
    'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
    'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
    'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
    'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker'],
}
78
-
79
def filter_by_model_type(df, selected_types):
    """Keep only the rows whose ``Method`` belongs to a selected model type.

    Args:
        df: Results table with a ``Method`` column.
        selected_types: Keys of the module-level ``model_types`` mapping.

    Returns:
        The matching rows of ``df``; an empty frame with the same columns
        when nothing is selected.

    Raises:
        KeyError: If an entry of ``selected_types`` is not in ``model_types``.
    """
    if not selected_types:  # If no types are selected, return an empty DataFrame
        return df.head(0)
    # ``type_name`` rather than ``type`` so the builtin is not shadowed.
    selected_models = [
        model
        for type_name in selected_types
        for model in model_types[type_name]
    ]
    return df[df['Method'].isin(selected_models)]
84
-
85
def format_dataframe(df, dataset):
    """Return the leaderboard view of ``df`` for a single dataset.

    Selects ``Method`` plus the columns whose name contains ``dataset``,
    strips everything up to the last underscore from the column names, and
    sorts the rows by ``MRR`` (highest first).

    Args:
        df: Results table with a ``Method`` column and
            ``<dataset>_<metric>`` columns.
        dataset: Substring that selects the dataset's columns,
            e.g. ``'STARK-MAG'``.

    Returns:
        A new DataFrame; ``df`` itself is not modified.
    """
    columns = ['Method'] + [col for col in df.columns if dataset in col]
    filtered_df = df[columns].copy()
    filtered_df.columns = [col.rsplit('_', 1)[-1] for col in filtered_df.columns]
    # Without a matching MRR column there is nothing to sort by; return the
    # Method-only table rather than raising KeyError.
    if 'MRR' in filtered_df.columns:
        filtered_df = filtered_df.sort_values('MRR', ascending=False)
    return filtered_df
91
-
92
def update_tables(selected_types):
    """Build the nine leaderboard tables for the selected model types.

    For each data source (synthesized full, synthesized 10%, human-generated)
    the rows are filtered by model type and then formatted once per dataset
    (AMAZON, MAG, PRIME).

    Args:
        selected_types: Model-type labels passed to ``filter_by_model_type``.

    Returns:
        A list of DataFrames ordered source-major, dataset-minor.
    """
    outputs = []
    for source_df in (df_synthesized_full, df_synthesized_10, df_human_generated):
        filtered_df = filter_by_model_type(source_df, selected_types)
        for dataset in ['AMAZON', 'MAG', 'PRIME']:
            outputs.append(format_dataframe(filtered_df, f"STARK-{dataset}"))
    return outputs
103
-
104
# Custom stylesheet handed to gr.Blocks: lets table headers wrap, sets the
# width variable for the first column, and lets the second cell scroll
# horizontally.
css = """
table > thead {
    white-space: normal
}

table {
    --cell-width-1: 250px
}

table > tbody > tr > td:nth-child(2) > div {
    overflow-x: auto
}
"""
117
-
118
# Leaderboard UI: a model-type filter above a two-level tab layout
# (data source -> dataset), with one read-only table per inner tab.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
    gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")

    with gr.Row():
        model_type_filter = gr.CheckboxGroup(
            choices=list(model_types.keys()),
            value=list(model_types.keys()),
            label="Model types",
            interactive=True
        )

    # Table components in creation order (source-major, dataset-minor); this
    # is the order in which update_tables returns its DataFrames.
    all_dfs = []

    with gr.Tabs():
        for tab_name, df_source in [("Synthesized (full)", df_synthesized_full),
                                    ("Synthesized (10%)", df_synthesized_10),
                                    ("Human-Generated", df_human_generated)]:
            with gr.TabItem(tab_name):
                with gr.Tabs():
                    for dataset in ['AMAZON', 'MAG', 'PRIME']:
                        with gr.TabItem(dataset):
                            df = gr.DataFrame(interactive=False)
                            all_dfs.append(df)

    # Refresh every table whenever the filter selection changes ...
    model_type_filter.change(
        update_tables,
        inputs=[model_type_filter],
        outputs=all_dfs
    )

    # ... and once when the page first loads.
    demo.load(
        update_tables,
        inputs=[model_type_filter],
        outputs=all_dfs
    )

demo.launch()
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
+ from typing import List, Dict
5
+
6
 
7
  # Sample data based on your table (you'll need to update this with the full dataset)
8
  data_synthesized_full = {
 
53
  'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
54
  }
55
 
56
class DataManager:
    """Holds the leaderboard result tables and the helpers that slice them.

    Attributes are plain public fields:
        df_synthesized_full / df_synthesized_10 / df_human_generated:
            One DataFrame per evaluation split.
        model_types: Maps a model-type label to the ``Method`` values in it.
        metrics: Metric names offered for sorting.
        datasets: Dataset suffixes; columns are named ``STARK-<dataset>_<metric>``.
    """

    def __init__(self, data_synthesized_full: Dict, data_synthesized_10: Dict, data_human_generated: Dict):
        """Build the three DataFrames from column-oriented dictionaries."""
        self.df_synthesized_full = pd.DataFrame(data_synthesized_full)
        self.df_synthesized_10 = pd.DataFrame(data_synthesized_10)
        self.df_human_generated = pd.DataFrame(data_human_generated)

        self.model_types = {
            'Sparse Retriever': ['BM25'],
            'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
            'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
            'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
            'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker'],
        }

        self.metrics = ['Hit@1', 'Hit@5', 'R@20', 'MRR']
        self.datasets = ['AMAZON', 'MAG', 'PRIME']

    def filter_by_model_type(self, df: pd.DataFrame, selected_types: List[str]) -> pd.DataFrame:
        """Keep only rows whose ``Method`` belongs to one of ``selected_types``.

        Returns an empty frame with the same columns when nothing is
        selected. Raises ``KeyError`` for a label missing from
        ``self.model_types``.
        """
        if not selected_types:
            return df.head(0)
        # ``type_name`` rather than ``type`` so the builtin is not shadowed.
        selected_models = [
            model
            for type_name in selected_types
            for model in self.model_types[type_name]
        ]
        return df[df['Method'].isin(selected_models)]

    def format_dataframe(self, df: pd.DataFrame, dataset: str) -> pd.DataFrame:
        """Return the table for one dataset: short column names, rounded, sorted.

        Selects ``Method`` plus the columns whose name contains ``dataset``,
        renames each to the text after its last underscore, rounds the
        metric columns to two decimals, and sorts by ``MRR`` descending.
        ``df`` itself is not modified.
        """
        columns = ['Method'] + [col for col in df.columns if dataset in col]
        filtered_df = df[columns].copy()
        filtered_df.columns = [col.rsplit('_', 1)[-1] for col in filtered_df.columns]

        # Format numeric columns to 2 decimal places
        for col in filtered_df.columns:
            if col != 'Method':
                filtered_df[col] = filtered_df[col].round(2)

        # Sort by MRR by default; skipped when the dataset matched no MRR
        # column so an unknown dataset does not raise KeyError.
        if 'MRR' in filtered_df.columns:
            filtered_df = filtered_df.sort_values('MRR', ascending=False)
        return filtered_df

    def get_best_model(self, df: pd.DataFrame, metric: str) -> str:
        """Return the ``Method`` of the row with the highest ``metric``.

        Raises:
            ValueError: If ``df`` has no rows.
        """
        if df.empty:
            # idxmax would also raise ValueError here, with a message that
            # does not mention the metric.
            raise ValueError(f"cannot determine best model for {metric!r}: no rows")
        return df.loc[df[metric].idxmax(), 'Method']
95
+
96
+ # Custom components
97
def create_metric_summary(df: pd.DataFrame, dataset: str) -> str:
    """Render a short Markdown summary of the best MRR and Hit@1 results.

    Args:
        df: Formatted table with ``Method``, ``MRR`` and ``Hit@1`` columns.
        dataset: Dataset name used in the heading.

    Returns:
        A Markdown snippet with a heading and one bullet per metric. For an
        empty table the bullets are replaced by a single placeholder line.
    """
    if df.empty:
        # idxmax raises on an empty frame; return a placeholder instead.
        return f"""
### {dataset} Dataset Summary
- No results available
"""

    # Look up each winning row once and read both the score and the method
    # name from it.
    top_mrr = df.loc[df['MRR'].idxmax()]
    top_hit1 = df.loc[df['Hit@1'].idxmax()]
    best_mrr, best_model_mrr = top_mrr['MRR'], top_mrr['Method']
    best_hit1, best_model_hit1 = top_hit1['Hit@1'], top_hit1['Method']

    return f"""
### {dataset} Dataset Summary
- Best MRR: {best_mrr:.2f}% ({best_model_mrr})
- Best Hit@1: {best_hit1:.2f}% ({best_model_hit1})
"""
108
+
109
+ # Main application
110
def create_app(data_manager: "DataManager"):
    """Build the Gradio Blocks UI for the STaRK leaderboard.

    Layout: a title, a model-type checkbox filter, a sort-metric radio, and
    one tab per data source. Inside each tab every dataset in
    ``data_manager.datasets`` gets a Markdown summary followed by a read-only
    results table. The tables are refreshed on page load and whenever the
    filter or the sort metric changes.

    Args:
        data_manager: Supplies the source DataFrames, ``model_types``,
            ``metrics``, ``datasets`` and the filter/format helpers.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    # Single list drives both the layout and the refresh callback, so the
    # order of returned tables always matches the order of table components.
    data_sources = [
        ("Synthesized (Full)", data_manager.df_synthesized_full),
        ("Synthesized (10%)", data_manager.df_synthesized_10),
        ("Human-Generated", data_manager.df_human_generated),
    ]

    # gr.Box is not available in newer Gradio releases (removed in 4.x --
    # confirm against the installed version); use it when present and fall
    # back to gr.Group otherwise.
    make_container = getattr(gr, "Box", None) or gr.Group

    def update_tables(selected_types: List[str], sort_metric: str):
        """Return one filtered, sorted table per (source, dataset) pair."""
        outputs = []
        for _, source_df in data_sources:
            filtered_df = data_manager.filter_by_model_type(source_df, selected_types)
            for dataset in data_manager.datasets:
                formatted_df = data_manager.format_dataframe(filtered_df, f"STARK-{dataset}")
                outputs.append(formatted_df.sort_values(sort_metric, ascending=False))
        return outputs

    with gr.Blocks(css="""
    .metric-summary { margin: 1rem 0; padding: 1rem; background: #f7f7f7; border-radius: 4px; }
    .table-container { margin-top: 1rem; }
    .model-filter { margin-bottom: 1rem; }
    """) as demo:

        gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
        gr.Markdown("### An evaluation benchmark for semi-structured text retrieval")

        with gr.Row():
            with gr.Column(scale=3):
                model_type_filter = gr.CheckboxGroup(
                    choices=list(data_manager.model_types.keys()),
                    value=list(data_manager.model_types.keys()),
                    label="Model Types",
                    interactive=True,
                    elem_classes=["model-filter"]
                )

            with gr.Column(scale=1):
                sort_by = gr.Radio(
                    choices=data_manager.metrics,
                    value="MRR",
                    label="Sort by Metric",
                    interactive=True
                )

        all_dataframes = []

        with gr.Tabs():
            for source_name, source_df in data_sources:
                with gr.TabItem(source_name):
                    for dataset in data_manager.datasets:
                        with make_container():
                            # The summary is computed once from the
                            # unfiltered source and is not refreshed by the
                            # filter or sort controls.
                            gr.Markdown(
                                create_metric_summary(
                                    data_manager.format_dataframe(source_df, f"STARK-{dataset}"),
                                    dataset
                                ),
                                # Apply the .metric-summary rule declared in
                                # the stylesheet above.
                                elem_classes=["metric-summary"]
                            )
                            df_display = gr.DataFrame(
                                interactive=False,
                                elem_classes=["table-container"]
                            )
                            all_dataframes.append(df_display)

        # Register event handlers: filter change, sort change and initial
        # page load all run the same refresh with the same inputs/outputs.
        for register in (model_type_filter.change, sort_by.change, demo.load):
            register(
                update_tables,
                inputs=[model_type_filter, sort_by],
                outputs=all_dataframes
            )

    return demo
 
 
 
 
194
 
195
if __name__ == "__main__":
    # Initialize data manager with the result dictionaries defined above
    data_manager = DataManager(data_synthesized_full, data_synthesized_10, data_human_generated)

    # Create and launch the app
    demo = create_app(data_manager)
    demo.launch()