pors committed on
Commit
21891ff
·
1 Parent(s): a46f3d2

initial commit

Browse files
Files changed (3) hide show
  1. README.md +18 -1
  2. app.py +243 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -11,4 +11,21 @@ license: apache-2.0
11
  short_description: Alternative to the timm leaderboard
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  short_description: Alternative to the timm leaderboard
12
  ---
13
 
14
+ # Image Model Performance Analysis
15
+
16
+ This interactive tool analyzes and visualizes performance metrics of different image models based on benchmark data from the pytorch-image-models repository.
17
+
18
+ ## Features
19
+
20
+ - Select from various benchmark files
21
+ - Choose different metrics for X and Y axes
22
+ - Filter by model families
23
+ - Toggle log scales
24
+ - Interactive Plotly visualizations
25
+
26
+ ## Data Source
27
+
28
+ The benchmark data comes from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) repository by Ross Wightman.
29
+
30
+ Based on the original notebook by Jeremy Howard.
31
+
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import requests
5
+ import re
6
+ import os
7
+ import glob
8
+
9
+ # Download the main results file
10
def download_main_results():
    """Download the main ImageNet results CSV unless it is already cached.

    The file is saved as ``results-imagenet.csv`` in the current working
    directory. An existing file is never re-downloaded.

    Raises:
        requests.HTTPError: If the server answers with an error status, so
            that an error page is never cached under the CSV file name.
        requests.RequestException: On connection failures or timeouts.
    """
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    target = 'results-imagenet.csv'
    if os.path.exists(target):
        return

    # A timeout keeps the app from hanging forever on a stalled connection.
    response = requests.get(url, timeout=30)
    # Fail before opening the file so a bad response cannot poison the cache.
    response.raise_for_status()
    with open(target, 'wb') as f:
        f.write(response.content)
16
+
17
def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
    timeout=30,
):
    """Download benchmark CSV files listed by the GitHub contents API.

    The download is best-effort: any network failure is treated like an
    unsuccessful response, so files cached by an earlier run stay usable.

    Args:
        repo: ``owner/name`` of the GitHub repository.
        folder: Folder inside the repository whose entries are listed.
        filename_pattern: Regular expression applied with ``re.match`` (i.e.
            anchored at the start of the name only) to select files.
        output_dir: Local directory the files are written to; created when
            at least one file name matches.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        The list of matching file names reported by the API, or ``[]`` when
        the listing could not be fetched or nothing matched. A name in the
        list is not a guarantee that its download succeeded.
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    try:
        r = requests.get(api_url, timeout=timeout)
    except requests.RequestException:
        # Same outcome as a non-200 answer: nothing new can be downloaded.
        return []
    if r.status_code != 200:
        return []

    files = r.json()
    matched_files = [f['name'] for f in files if re.match(filename_pattern, f['name'])]

    if not matched_files:
        return []

    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)

    for fname in matched_files:
        out_path = os.path.join(output_dir, fname)
        if os.path.exists(out_path):
            # Already cached locally; never re-download.
            continue

        try:
            resp = requests.get(raw_base + fname, timeout=timeout)
        except requests.RequestException:
            # Skip this file but keep trying the remaining ones.
            continue
        if resp.ok:
            with open(out_path, "wb") as f:
                f.write(resp.content)

    return matched_files
49
+
50
def load_main_data():
    """Return the ImageNet results table keyed by base model name.

    Ensures ``results-imagenet.csv`` is present (downloading it if needed),
    keeps the untouched model identifier in ``model_org`` and replaces
    ``model`` with the part of the name before the first ``.``.
    """
    download_main_results()
    results = pd.read_csv('results-imagenet.csv')
    full_names = results['model']
    results['model_org'] = full_names
    # Drop everything after the first dot so names line up with benchmark files.
    results['model'] = full_names.str.split('.').str[0]
    return results
57
+
58
def get_data(benchmark_file, df_results):
    """Process benchmark data and merge it with the main results.

    Args:
        benchmark_file: Path to a benchmark CSV. It must provide ``model``
            and ``infer_samples_per_sec`` columns.
        df_results: Results table with a ``model`` column to join on.

    Returns:
        The merged rows whose derived ``family`` is one of the families in
        the allow-list below, with two added columns: ``secs`` (the
        reciprocal of ``infer_samples_per_sec``) and ``family``. An empty
        ``DataFrame`` is returned when ``benchmark_file`` does not exist.
    """
    # Allow-list of model families that are kept in the result.
    pattern = (
        r'^(?:'
        r'eva|'
        r'maxx?vit(?:v2)?|'
        r'coatnet|coatnext|'
        r'convnext(?:v2)?|'
        r'beit(?:v2)?|'
        r'efficient(?:net(?:v2)?|former(?:v2)?|vit)|'
        r'regnet[xyvz]?|'
        r'levit|'
        r'vitd?|'
        r'swin(?:v2)?'
        r')$'
    )

    if not os.path.exists(benchmark_file):
        return pd.DataFrame()

    df = pd.read_csv(benchmark_file).merge(df_results, on='model')
    df['secs'] = 1. / df['infer_samples_per_sec']
    # Family = leading lowercase letters (plus an optional "v2") up to the
    # first digit, underscore or end of name. Raw string: "\d" is not a valid
    # escape in a normal string literal. expand=False yields a Series.
    df['family'] = df.model.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)', expand=False)
    # Copy so the assignment below works on an independent frame, not a slice.
    df = df[~df.model.str.endswith('gn')].copy()

    # Models matching "resnet.*d" get a "d" appended to their family name.
    resnet_d = df.model.str.contains(r'resnet.*d')
    df.loc[resnet_d, 'family'] = df.loc[resnet_d, 'family'] + 'd'

    # na=False: names for which no family could be extracted are dropped
    # instead of producing a NaN mask that cannot be used for indexing.
    return df[df.family.str.contains(pattern, na=False)]
84
+
85
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Build the scatter plot for the current UI selections.

    Returns a Plotly figure, or ``None`` when there is no data to show
    (nothing loaded for the benchmark file, or no rows left after the
    family filter). An empty ``selected_families`` disables the filter.
    """
    data = get_data(benchmark_file, load_main_data())

    # Only narrow down by family when there is data and a selection was made.
    if not data.empty and selected_families:
        data = data[data['family'].isin(selected_families)]

    if data.empty:
        return None

    scatter_options = dict(
        width=1000,
        height=800,
        x=x_axis,
        y=y_axis,
        log_x=log_x,
        log_y=log_y,
        color='family',
        hover_name='model_org',
        hover_data=['infer_samples_per_sec', 'infer_img_size'],
        title=f'Model Performance: {y_axis} vs {x_axis}',
    )
    return px.scatter(data, **scatter_options)
116
+
117
def setup_interface():
    """Prepare the data needed to build the Gradio interface.

    Downloads the benchmark files (best effort), lists the ones available
    locally and derives the initial set of model families from the first one.

    Returns:
        A ``(benchmark_files, plot_columns, families)`` tuple:
        ``benchmark_files`` is the sorted list of local benchmark CSV paths,
        or the single placeholder ``"No benchmark files found"``;
        ``plot_columns`` is the list of column names offered for the axes;
        ``families`` is the sorted list of families found in the first
        benchmark file (empty when none is available).
    """
    # Download benchmark files; the returned name list is not needed because
    # the local directory is listed right below.
    download_github_csvs_api()

    # glob returns matches in arbitrary order; sort so the dropdown order and
    # the default selection are deterministic across runs and platforms.
    benchmark_files = sorted(glob.glob("benchmarks/benchmark-*.csv"))
    if not benchmark_files:
        benchmark_files = ["No benchmark files found"]

    # Load the main results used to derive families from a sample file.
    df_results = load_main_data()

    # Relevant columns for plotting.
    # NOTE(review): 'param_count_x' presumably comes from the merge suffix in
    # get_data when both CSVs carry a 'param_count' column; confirm.
    plot_columns = [
        'top1', 'top5', 'infer_samples_per_sec',
        'secs', 'param_count_x', 'infer_img_size'
    ]

    # Get families from a sample file (if available).
    families = []
    if benchmark_files[0] != "No benchmark files found":
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df['family'].unique().tolist())

    return benchmark_files, plot_columns, families
144
+
145
# Gather the dropdown entries, axis columns and initial family list up front.
benchmark_files, plot_columns, families = setup_interface()

# Build the Gradio UI: controls in the left column, plot in the right one.
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown("Analyze and visualize performance metrics of different image models based on benchmark data.")

    with gr.Row():
        with gr.Column(scale=1):
            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=benchmark_files[0] if benchmark_files else None,
                label="Select Benchmark File",
            )
            x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")
            y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")
            family_checkboxes = gr.CheckboxGroup(
                choices=families,
                value=families,
                label="Select Model Families",
            )
            log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")
            log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")
            update_button = gr.Button("Update Plot", variant="primary")

        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # The button click and the initial page load feed create_plot the same
    # controls, in the order of its parameters.
    plot_inputs = [
        benchmark_dropdown,
        x_axis_radio,
        y_axis_radio,
        family_checkboxes,
        log_x_checkbox,
        log_y_checkbox,
    ]

    # Redraw the plot when the button is clicked.
    update_button.click(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

    def update_families(benchmark_file):
        """Return a checkbox group listing (and selecting) the file's families.

        The group is empty when no real benchmark file is selected or the
        file yields no data.
        """
        available = []
        if benchmark_file and benchmark_file != "No benchmark files found":
            merged = get_data(benchmark_file, load_main_data())
            if not merged.empty:
                available = sorted(merged['family'].unique().tolist())
        return gr.CheckboxGroup(choices=available, value=available)

    # Refresh the family choices whenever another benchmark file is picked.
    benchmark_dropdown.change(
        fn=update_families,
        inputs=benchmark_dropdown,
        outputs=family_checkboxes,
    )

    # Draw the initial plot when the page loads.
    demo.load(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()
243
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ pandas
3
+ plotly
4
+ requests
5
+