File size: 8,514 Bytes
ad9f5e1
 
1aa50db
 
928143b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1aa50db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
928143b
 
 
1aa50db
031d82b
 
 
 
1aa50db
031d82b
 
1aa50db
031d82b
 
1aa50db
031d82b
1aa50db
928143b
 
 
 
 
 
 
 
 
f9fbc42
 
928143b
 
 
f9fbc42
 
 
 
 
 
3afe0be
f9fbc42
 
 
 
 
 
 
 
 
 
 
3afe0be
 
031d82b
 
3afe0be
928143b
 
 
 
 
 
 
 
031d82b
928143b
 
f9fbc42
 
928143b
f9fbc42
928143b
 
 
f9fbc42
 
928143b
f9fbc42
031d82b
 
 
f9fbc42
 
928143b
f9fbc42
928143b
f9fbc42
928143b
 
 
 
f9fbc42
928143b
 
 
 
 
f9fbc42
928143b
3afe0be
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import pandas as pd
import numpy as np

# Leaderboard scores for the "Synthesized (full)" tab.
# Layout: 'Method' holds the model names; every other key is
# '<DATASET>_<metric>' (datasets STARK-AMAZON / STARK-MAG / STARK-PRIME,
# metrics Hit@1 / Hit@5 / R@20 / MRR) and its list is parallel to 'Method'.
# This table has 10 methods; it contains no reranker rows.
# NOTE(review): values are hand-entered sample data — presumably transcribed
# from the STaRK paper; confirm against the full published table.
data_synthesized_full = {
    'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2'],
    'STARK-AMAZON_Hit@1': [44.94, 15.29, 30.96, 26.56, 39.16, 40.93, 21.74, 42.08, 40.07, 46.10],
    'STARK-AMAZON_Hit@5': [67.42, 47.93, 51.06, 50.01, 62.73, 64.37, 41.65, 66.87, 64.98, 66.02],
    'STARK-AMAZON_R@20': [53.77, 44.49, 41.95, 52.05, 53.29, 54.28, 33.22, 56.52, 55.12, 53.44],
    'STARK-AMAZON_MRR': [55.30, 30.20, 40.66, 37.75, 50.35, 51.60, 31.47, 53.46, 51.55, 55.51],
    'STARK-MAG_Hit@1': [25.85, 10.51, 21.96, 12.88, 29.08, 30.06, 18.01, 37.90, 25.92, 31.18],
    'STARK-MAG_Hit@5': [45.25, 35.23, 36.50, 39.01, 49.61, 50.58, 34.85, 56.74, 50.43, 46.42],
    'STARK-MAG_R@20': [45.69, 42.11, 35.32, 46.97, 48.36, 50.49, 35.46, 46.40, 50.80, 43.94],
    'STARK-MAG_MRR': [34.91, 21.34, 29.14, 29.12, 38.62, 39.66, 26.10, 47.25, 36.94, 38.39],
    'STARK-PRIME_Hit@1': [12.75, 4.46, 6.53, 8.85, 12.63, 10.85, 10.10, 15.57, 15.10, 11.75],
    'STARK-PRIME_Hit@5': [27.92, 21.85, 15.67, 21.35, 31.49, 30.23, 22.49, 33.42, 33.56, 23.85],
    'STARK-PRIME_R@20': [31.25, 30.13, 16.52, 29.63, 36.00, 37.83, 26.34, 39.09, 38.05, 25.04],
    'STARK-PRIME_MRR': [19.84, 12.38, 11.05, 14.73, 21.41, 19.99, 16.12, 24.11, 23.49, 17.39]
}

# Leaderboard scores for the "Synthesized (10%)" tab.
# Same column layout as the other score tables ('Method' plus
# '<DATASET>_<metric>' lists parallel to it), but with 12 methods: the two
# LLM rerankers ('Claude3 Reranker', 'GPT4 Reranker') are included here.
data_synthesized_10 = {
    'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2', 'Claude3 Reranker', 'GPT4 Reranker'],
    'STARK-AMAZON_Hit@1': [42.68, 16.46, 30.09, 25.00, 39.02, 43.29, 18.90, 43.29, 40.85, 44.31, 45.49, 44.79],
    'STARK-AMAZON_Hit@5': [67.07, 50.00, 49.27, 48.17, 64.02, 67.68, 37.80, 71.34, 62.80, 65.24, 71.13, 71.17],
    'STARK-AMAZON_R@20': [54.48, 42.15, 41.91, 51.65, 49.30, 56.04, 34.73, 56.14, 52.47, 51.00, 53.77, 55.35],
    'STARK-AMAZON_MRR': [54.02, 30.20, 39.30, 36.87, 50.32, 54.20, 28.76, 55.07, 51.54, 55.07, 55.91, 55.69],
    'STARK-MAG_Hit@1': [27.81, 11.65, 22.89, 12.03, 28.20, 34.59, 19.17, 38.35, 25.56, 31.58, 36.54, 40.90],
    'STARK-MAG_Hit@5': [45.48, 36.84, 37.26, 37.97, 52.63, 50.75, 33.46, 58.64, 50.37, 47.36, 53.17, 58.18],
    'STARK-MAG_R@20': [44.59, 42.30, 44.16, 47.98, 49.25, 50.75, 29.85, 46.38, 53.03, 45.72, 48.36, 48.60],
    'STARK-MAG_MRR': [35.97, 21.82, 30.00, 28.70, 38.55, 42.90, 26.06, 48.25, 36.82, 38.98, 44.15, 49.00],
    'STARK-PRIME_Hit@1': [13.93, 5.00, 6.78, 7.14, 15.36, 12.14, 9.29, 16.79, 15.36, 15.00, 17.79, 18.28],
    'STARK-PRIME_Hit@5': [31.07, 23.57, 16.15, 17.14, 31.07, 31.42, 20.7, 34.29, 32.86, 26.07, 36.90, 37.28],
    'STARK-PRIME_R@20': [32.84, 30.50, 17.07, 32.95, 37.88, 37.34, 25.54, 41.11, 40.99, 27.78, 35.57, 34.05],
    'STARK-PRIME_MRR': [21.68, 13.50, 11.42, 16.27, 23.50, 21.23, 15.00, 24.99, 23.70, 19.98, 26.27, 26.55]
}

# Leaderboard scores for the "Human-Generated" tab.
# Same column layout as the other score tables ('Method' plus
# '<DATASET>_<metric>' lists parallel to it), with 12 methods including the
# two LLM rerankers.
# NOTE(review): BM25's STARK-AMAZON MRR (18.79) is lower than its Hit@1
# (27.16), which looks inconsistent — verify against the source table.
data_human_generated = {
    'Method': ['BM25', 'DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)', 'ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b', 'multi-ada-002', 'ColBERTv2', 'Claude3 Reranker', 'GPT4 Reranker'],
    'STARK-AMAZON_Hit@1': [27.16, 16.05, 25.93, 22.22, 39.50, 35.80, 29.63, 40.74, 46.91, 33.33, 53.09, 50.62],
    'STARK-AMAZON_Hit@5': [51.85, 39.51, 54.32, 49.38, 64.19, 62.96, 46.91, 71.60, 72.84, 55.56, 74.07, 75.31],
    'STARK-AMAZON_R@20': [29.23, 15.23, 23.69, 21.54, 35.46, 33.01, 21.21, 36.30, 40.22, 29.03, 35.46, 35.46],
    'STARK-AMAZON_MRR': [18.79, 27.21, 37.12, 31.33, 52.65, 47.84, 38.61, 53.21, 58.74, 43.77, 62.11, 61.06],
    'STARK-MAG_Hit@1': [32.14, 4.72, 25.00, 20.24, 28.57, 22.62, 16.67, 34.52, 23.81, 33.33, 38.10, 36.90],
    'STARK-MAG_Hit@5': [41.67, 9.52, 30.95, 26.19, 41.67, 36.90, 28.57, 44.04, 41.67, 36.90, 45.24, 46.43],
    'STARK-MAG_R@20': [32.46, 25.00, 27.24, 28.76, 35.95, 32.44, 21.74, 34.57, 39.85, 30.50, 35.95, 35.95],
    'STARK-MAG_MRR': [37.42, 7.90, 27.98, 25.53, 35.81, 29.68, 21.59, 38.72, 31.43, 35.97, 42.00, 40.65],
    'STARK-PRIME_Hit@1': [22.45, 2.04, 7.14, 6.12, 17.35, 16.33, 9.18, 25.51, 24.49, 15.31, 28.57, 28.57],
    'STARK-PRIME_Hit@5': [41.84, 9.18, 13.27, 13.27, 34.69, 32.65, 21.43, 41.84, 39.80, 26.53, 46.94, 44.90],
    'STARK-PRIME_R@20': [42.32, 10.69, 11.72, 17.62, 41.09, 39.01, 26.77, 48.10, 47.21, 25.56, 41.61, 41.61],
    'STARK-PRIME_MRR': [30.37, 7.05, 10.07, 9.39, 26.35, 24.33, 15.24, 34.28, 32.98, 19.67, 36.32, 34.82]
}

# Wrap each raw score dict in a DataFrame; these three frames are the
# sources that the table-update callback filters and formats.
df_synthesized_full, df_synthesized_10, df_human_generated = (
    pd.DataFrame(score_table)
    for score_table in (data_synthesized_full, data_synthesized_10, data_human_generated)
)

def format_dataframe(df, dataset):
    """Return the leaderboard view of *df* for a single dataset.

    Keeps the 'Method' column plus every column whose name contains
    *dataset* (e.g. ``"STARK-AMAZON"``), strips the dataset prefix so that
    ``"STARK-AMAZON_Hit@1"`` becomes ``"Hit@1"``, and sorts the rows by
    ``MRR`` in descending order. The input frame is not modified.

    Args:
        df: Score table with a 'Method' column and '<DATASET>_<metric>' columns.
        dataset: Dataset name used to select the metric columns.

    Returns:
        A new DataFrame with columns 'Method' and the bare metric names.

    Raises:
        KeyError: If the selected columns do not include an 'MRR' metric.
    """
    columns = ['Method'] + [col for col in df.columns if dataset in col]
    filtered_df = df[columns].copy()

    # Drop everything up to the last underscore: '<DATASET>_<metric>' -> '<metric>'.
    filtered_df.columns = [col.split('_')[-1] if '_' in col else col for col in filtered_df.columns]

    return filtered_df.sort_values('MRR', ascending=False)


# Maps each model-type label shown in the filter to the 'Method' names it covers.
model_types = {
    'Sparse Retriever': ['BM25'],
    'Small Dense Retrievers': ['DPR (roberta)', 'ANCE (roberta)', 'QAGNN (roberta)'],
    'LLM-based Dense Retrievers': ['ada-002', 'voyage-l2-instruct', 'LLM2Vec', 'GritLM-7b'],
    'Multivector Retrievers': ['multi-ada-002', 'ColBERTv2'],
    'LLM Rerankers': ['Claude3 Reranker', 'GPT4 Reranker']
}


def filter_by_model_type(df, selected_types):
    """Return the rows of *df* whose 'Method' belongs to the selected types.

    Args:
        df: Score table with a 'Method' column.
        selected_types: Iterable of keys of ``model_types``; may be empty or None.

    Returns:
        The matching rows in their original order, or an empty frame with the
        same columns when nothing is selected.

    Raises:
        KeyError: If a selected type is not a key of ``model_types``.
    """
    if not selected_types:
        # Nothing ticked: keep the column layout but show no rows.
        return df.head(0)
    selected_models = [
        model
        for model_type in selected_types
        for model in model_types[model_type]
    ]
    return df[df['Method'].isin(selected_models)]

def update_tables(selected_types):
    """Build the nine leaderboard tables for the selected model types.

    Each of the three source frames (synthesized full, synthesized 10%,
    human-generated) is filtered by model type and then formatted once per
    dataset. The returned list is ordered source-major: for each source in
    that order, the AMAZON, MAG and PRIME tables.
    """
    sources = (df_synthesized_full, df_synthesized_10, df_human_generated)
    # Filter every source first, then fan each one out over the datasets.
    per_source = [filter_by_model_type(table, selected_types) for table in sources]
    return [
        format_dataframe(table, f"STARK-{name}")
        for table in per_source
        for name in ('AMAZON', 'MAG', 'PRIME')
    ]

# Custom stylesheet hook for the page. No CSS is defined anywhere in this
# file, so use None (Gradio's default styling) rather than an undefined name.
css = None

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Semi-structured Retrieval Benchmark (STaRK) Leaderboard")
    gr.Markdown("Refer to the [STaRK paper](https://arxiv.org/pdf/2404.13207) for details on metrics, tasks and models.")

    with gr.Row():
        # All model types start ticked, so the initial load shows every method.
        model_type_filter = gr.CheckboxGroup(
            choices=list(model_types.keys()),
            value=list(model_types.keys()),
            label="Model types",
            interactive=True
        )

    # One outer tab per query split, each holding one inner tab per dataset.
    # Each table is created inside its own TabItem so that it is rendered
    # exactly once, in the right place; creating it earlier inside an active
    # context and calling .render() afterwards would render it twice.
    tables_by_split = {}
    with gr.Tabs() as outer_tabs:
        for split_label in ("Synthesized (full)", "Synthesized (10%)", "Human-Generated"):
            split_tables = []
            with gr.TabItem(split_label):
                with gr.Tabs():
                    for dataset in ['AMAZON', 'MAG', 'PRIME']:
                        with gr.TabItem(dataset):
                            split_tables.append(gr.DataFrame(interactive=False))
            tables_by_split[split_label] = split_tables

    syn_full_dfs = tables_by_split["Synthesized (full)"]
    syn_10_dfs = tables_by_split["Synthesized (10%)"]
    human_dfs = tables_by_split["Human-Generated"]

    # Order must match the list returned by update_tables:
    # split-major (full, 10%, human), then AMAZON, MAG, PRIME within each.
    all_dfs = syn_full_dfs + syn_10_dfs + human_dfs

    # Refresh every table whenever the model-type selection changes...
    model_type_filter.change(
        update_tables,
        inputs=[model_type_filter],
        outputs=all_dfs
    )

    # ...and once on page load so the tables are populated initially.
    demo.load(
        update_tables,
        inputs=[model_type_filter],
        outputs=all_dfs
    )

demo.launch()