Shiyu Zhao commited on
Commit
2c8dbc2
·
1 Parent(s): b47352e

Update space

Browse files
Files changed (1) hide show
  1. app.py +97 -146
app.py CHANGED
@@ -1,8 +1,6 @@
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
4
- from typing import List, Dict
5
-
6
 
7
  # Sample data based on your table (you'll need to update this with the full dataset)
8
  data_synthesized_full = {
@@ -53,152 +51,105 @@ data_human_generated = {
53
  'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
54
  }
55
 
56
- class DataManager:
57
- def __init__(self, data_synthesized_full: Dict, data_synthesized_10: Dict, data_human_generated: Dict):
58
- self.df_synthesized_full = pd.DataFrame(data_synthesized_full)
59
- self.df_synthesized_10 = pd.DataFrame(data_synthesized_10)
60
- self.df_human_generated = pd.DataFrame(data_human_generated)
61
-
62
- self.model_types = {
63
- 'Sparse Retriever': ['BM25'],
64
- 'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
65
- 'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
66
- 'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
67
- 'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
68
- }
69
-
70
- self.metrics = ['Hit@1', 'Hit@5', 'R@20', 'MRR']
71
- self.datasets = ['AMAZON', 'MAG', 'PRIME']
72
-
73
- def filter_by_model_type(self, df: pd.DataFrame, selected_types: List[str]) -> pd.DataFrame:
74
- if not selected_types:
75
- return df.head(0)
76
- selected_models = [model for type in selected_types for model in self.model_types[type]]
77
- return df[df['Method'].isin(selected_models)]
78
-
79
- def format_dataframe(self, df: pd.DataFrame, dataset: str) -> pd.DataFrame:
80
- columns = ['Method'] + [col for col in df.columns if dataset in col]
81
- filtered_df = df[columns].copy()
82
- filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
83
-
84
- # Format numeric columns to 2 decimal places
85
- for col in filtered_df.columns:
86
- if col != 'Method':
87
- filtered_df[col] = filtered_df[col].round(2)
88
-
89
- # Sort by MRR by default
90
- filtered_df = filtered_df.sort_values('MRR', ascending=False)
91
- return filtered_df
92
-
93
- def get_best_model(self, df: pd.DataFrame, metric: str) -> str:
94
- return df.loc[df[metric].idxmax(), 'Method']
95
-
96
- # Custom components
97
- def create_metric_summary(df: pd.DataFrame, dataset: str) -> str:
98
- best_mrr = df['MRR'].max()
99
- best_hit1 = df['Hit@1'].max()
100
- best_model_mrr = df.loc[df['MRR'].idxmax(), 'Method']
101
- best_model_hit1 = df.loc[df['Hit@1'].idxmax(), 'Method']
102
 
103
- return f"""
104
- ### {dataset} Dataset Summary
105
- - Best MRR: {best_mrr:.2f}% ({best_model_mrr})
106
- - Best Hit@1: {best_hit1:.2f}% ({best_model_hit1})
107
- """
108
-
109
- # Main application
110
- def create_app(data_manager: DataManager):
111
- with gr.Blocks(css="""
112
- .metric-summary { margin: 1rem 0; padding: 1rem; background: #f7f7f7; border-radius: 4px; }
113
- .table-container { margin-top: 1rem; }
114
- .model-filter { margin-bottom: 1rem; }
115
- .dataset-section { border: 1px solid #ddd; padding: 1rem; margin: 1rem 0; border-radius: 4px; }
116
- """) as demo:
117
-
118
- gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
119
- gr.Markdown("### An evaluation benchmark for semi-structured text retrieval")
120
- gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
121
-
122
- with gr.Row():
123
- with gr.Column(scale=3):
124
- model_type_filter = gr.CheckboxGroup(
125
- choices=list(data_manager.model_types.keys()),
126
- value=list(data_manager.model_types.keys()),
127
- label="Model Types",
128
- interactive=True,
129
- elem_classes=["model-filter"]
130
- )
131
-
132
- with gr.Column(scale=1):
133
- sort_by = gr.Radio(
134
- choices=data_manager.metrics,
135
- value="MRR",
136
- label="Sort by Metric",
137
- interactive=True
138
- )
139
-
140
- all_dataframes = []
141
-
142
- with gr.Tabs() as tabs:
143
- data_sources = [
144
- ("Synthesized (Full)", data_manager.df_synthesized_full),
145
- ("Synthesized (10%)", data_manager.df_synthesized_10),
146
- ("Human-Generated", data_manager.df_human_generated)
147
- ]
148
-
149
- for source_name, source_df in data_sources:
150
- with gr.TabItem(source_name):
151
- for dataset in data_manager.datasets:
152
- with gr.Row(elem_classes=["dataset-section"]):
153
- with gr.Column():
154
- gr.Markdown(create_metric_summary(
155
- data_manager.format_dataframe(source_df, f"STARK-{dataset}"),
156
- dataset
157
- ))
158
- df_display = gr.DataFrame(
159
- interactive=False,
160
- elem_classes=["table-container"]
161
- )
162
- all_dataframes.append(df_display)
163
-
164
- def update_tables(selected_types: List[str], sort_metric: str):
165
- outputs = []
166
- for df_source in [data_manager.df_synthesized_full,
167
- data_manager.df_synthesized_10,
168
- data_manager.df_human_generated]:
169
- filtered_df = data_manager.filter_by_model_type(df_source, selected_types)
170
- for dataset in data_manager.datasets:
171
- formatted_df = data_manager.format_dataframe(filtered_df, f"STARK-{dataset}")
172
- formatted_df = formatted_df.sort_values(sort_metric, ascending=False)
173
- outputs.append(formatted_df)
174
- return outputs
175
-
176
- # Register event handlers
177
- model_type_filter.change(
178
- update_tables,
179
- inputs=[model_type_filter, sort_by],
180
- outputs=all_dataframes
181
- )
182
-
183
- sort_by.change(
184
- update_tables,
185
- inputs=[model_type_filter, sort_by],
186
- outputs=all_dataframes
187
- )
188
 
189
- # Initial load
190
- demo.load(
191
- update_tables,
192
- inputs=[model_type_filter, sort_by],
193
- outputs=all_dataframes
194
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
- return demo
 
 
 
 
 
 
 
197
 
198
- if __name__ == "__main__":
199
- # Initialize data manager with your existing data
200
- data_manager = DataManager(data_synthesized_full, data_synthesized_10, data_human_generated)
201
 
202
- # Create and launch the app
203
- demo = create_app(data_manager)
204
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import numpy as np
 
 
4
 
5
  # Sample data based on your table (you'll need to update this with the full dataset)
6
  data_synthesized_full = {
 
51
  'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
52
  }
53
 
54
+ df_synthesized_full = pd.DataFrame(data_synthesized_full)
55
+ df_synthesized_10 = pd.DataFrame(data_synthesized_10)
56
+ df_human_generated = pd.DataFrame(data_human_generated)
57
+
58
+ def format_dataframe(df, dataset):
59
+ # Filter the dataframe for the selected dataset
60
+ columns = ['Method'] + [col for col in df.columns if dataset in col]
61
+ filtered_df = df[columns].copy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # Rename columns
64
+ filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
65
+
66
+ # Sort by MRR
67
+ filtered_df = filtered_df.sort_values('MRR', ascending=False)
68
+
69
+ return filtered_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ model_types = {
72
+ 'Sparse Retriever': ['BM25'],
73
+ 'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
74
+ 'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
75
+ 'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
76
+ 'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
77
+ }
78
+
79
+ def filter_by_model_type(df, selected_types):
80
+ if not selected_types: # If no types are selected, return an empty DataFrame
81
+ return df.head(0)
82
+ selected_models = [model for type in selected_types for model in model_types[type]]
83
+ return df[df['Method'].isin(selected_models)]
84
+
85
+ def format_dataframe(df, dataset):
86
+ columns = ['Method'] + [col for col in df.columns if dataset in col]
87
+ filtered_df = df[columns].copy()
88
+ filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]
89
+ filtered_df = filtered_df.sort_values('MRR', ascending=False)
90
+ return filtered_df
91
+
92
+ def update_tables(selected_types):
93
+ filtered_df_full = filter_by_model_type(df_synthesized_full, selected_types)
94
+ filtered_df_10 = filter_by_model_type(df_synthesized_10, selected_types)
95
+ filtered_df_human = filter_by_model_type(df_human_generated, selected_types)
96
+
97
+ outputs = []
98
+ for df in [filtered_df_full, filtered_df_10, filtered_df_human]:
99
+ for dataset in ['AMAZON', 'MAG', 'PRIME']:
100
+ outputs.append(format_dataframe(df, f"STARK-{dataset}"))
101
+
102
+ return outputs
103
+
104
+ css = """
105
+ table > thead {
106
+ white-space: normal
107
+ }
108
 
109
+ table {
110
+ --cell-width-1: 250px
111
+ }
112
+
113
+ table > tbody > tr > td:nth-child(2) > div {
114
+ overflow-x: auto
115
+ }
116
+ """
117
 
118
+ with gr.Blocks(css=css) as demo:
119
+ gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
120
+ gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")
121
 
122
+ with gr.Row():
123
+ model_type_filter = gr.CheckboxGroup(
124
+ choices=list(model_types.keys()),
125
+ value=list(model_types.keys()),
126
+ label="Model types",
127
+ interactive=True
128
+ )
129
+
130
+ all_dfs = []
131
+
132
+ with gr.Tabs() as outer_tabs:
133
+ for tab_name, df_source in [("Synthesized (full)", df_synthesized_full),
134
+ ("Synthesized (10%)", df_synthesized_10),
135
+ ("Human-Generated", df_human_generated)]:
136
+ with gr.TabItem(tab_name):
137
+ with gr.Tabs() as inner_tabs:
138
+ for dataset in ['AMAZON', 'MAG', 'PRIME']:
139
+ with gr.TabItem(dataset):
140
+ df = gr.DataFrame(interactive=False)
141
+ all_dfs.append(df)
142
+
143
+ model_type_filter.change(
144
+ update_tables,
145
+ inputs=[model_type_filter],
146
+ outputs=all_dfs
147
+ )
148
+
149
+ demo.load(
150
+ update_tables,
151
+ inputs=[model_type_filter],
152
+ outputs=all_dfs
153
+ )
154
+
155
+ demo.launch()