pors committed on
Commit
10f9258
·
1 Parent(s): 21891ff

set default benchmark file

Browse files
Files changed (1) hide show
  1. app.py +109 -102
app.py CHANGED
@@ -6,191 +6,201 @@ import re
6
  import os
7
  import glob
8
 
 
9
  # Download the main results file
10
  def download_main_results():
11
  url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
12
- if not os.path.exists('results-imagenet.csv'):
13
  response = requests.get(url)
14
- with open('results-imagenet.csv', 'wb') as f:
15
  f.write(response.content)
16
 
 
17
  def download_github_csvs_api(
18
  repo="huggingface/pytorch-image-models",
19
  folder="results",
20
  filename_pattern=r"benchmark-.*\.csv",
21
- output_dir="benchmarks"
22
  ):
23
  """Download benchmark CSV files from GitHub API."""
24
  api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
25
  r = requests.get(api_url)
26
  if r.status_code != 200:
27
  return []
28
-
29
  files = r.json()
30
- matched_files = [f['name'] for f in files if re.match(filename_pattern, f['name'])]
31
-
32
  if not matched_files:
33
  return []
34
-
35
  raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
36
  os.makedirs(output_dir, exist_ok=True)
37
-
38
  for fname in matched_files:
39
  raw_url = raw_base + fname
40
  out_path = os.path.join(output_dir, fname)
41
-
42
  if not os.path.exists(out_path): # Only download if not exists
43
  resp = requests.get(raw_url)
44
  if resp.ok:
45
  with open(out_path, "wb") as f:
46
  f.write(resp.content)
47
-
48
  return matched_files
49
 
 
50
  def load_main_data():
51
  """Load the main ImageNet results."""
52
  download_main_results()
53
- df_results = pd.read_csv('results-imagenet.csv')
54
- df_results['model_org'] = df_results['model']
55
- df_results['model'] = df_results['model'].str.split('.').str[0]
56
  return df_results
57
 
 
58
  def get_data(benchmark_file, df_results):
59
  """Process benchmark data and merge with main results."""
60
  pattern = (
61
- r'^(?:'
62
- r'eva|'
63
- r'maxx?vit(?:v2)?|'
64
- r'coatnet|coatnext|'
65
- r'convnext(?:v2)?|'
66
- r'beit(?:v2)?|'
67
- r'efficient(?:net(?:v2)?|former(?:v2)?|vit)|'
68
- r'regnet[xyvz]?|'
69
- r'levit|'
70
- r'vitd?|'
71
- r'swin(?:v2)?'
72
- r')$'
73
  )
74
-
75
  if not os.path.exists(benchmark_file):
76
  return pd.DataFrame()
77
-
78
- df = pd.read_csv(benchmark_file).merge(df_results, on='model')
79
- df['secs'] = 1. / df['infer_samples_per_sec']
80
- df['family'] = df.model.str.extract('^([a-z]+?(?:v2)?)(?:\d|_|$)')
81
- df = df[~df.model.str.endswith('gn')]
82
- df.loc[df.model.str.contains('resnet.*d'),'family'] = df.loc[df.model.str.contains('resnet.*d'),'family'] + 'd'
 
 
83
  return df[df.family.str.contains(pattern)]
84
 
 
85
  def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
86
  """Create the scatter plot based on user selections."""
87
  df_results = load_main_data()
88
  df = get_data(benchmark_file, df_results)
89
-
90
  if df.empty:
91
  return None
92
-
93
  # Filter by selected families
94
  if selected_families:
95
- df = df[df['family'].isin(selected_families)]
96
-
97
  if df.empty:
98
  return None
99
-
100
  # Create the plot
101
  fig = px.scatter(
102
- df,
103
- width=1000,
104
  height=800,
105
- x=x_axis,
106
- y=y_axis,
107
- log_x=log_x,
108
  log_y=log_y,
109
- color='family',
110
- hover_name='model_org',
111
- hover_data=['infer_samples_per_sec', 'infer_img_size'],
112
- title=f'Model Performance: {y_axis} vs {x_axis}'
113
  )
114
-
115
  return fig
116
 
 
117
  def setup_interface():
118
  """Set up the Gradio interface."""
119
  # Download benchmark files
120
  downloaded_files = download_github_csvs_api()
121
-
122
  # Get available benchmark files
123
  benchmark_files = glob.glob("benchmarks/benchmark-*.csv")
124
  if not benchmark_files:
125
  benchmark_files = ["No benchmark files found"]
126
-
127
  # Load sample data to get families and columns
128
  df_results = load_main_data()
129
-
130
  # Relevant columns for plotting
131
  plot_columns = [
132
- 'top1', 'top5', 'infer_samples_per_sec',
133
- 'secs', 'param_count_x', 'infer_img_size'
 
 
 
 
134
  ]
135
-
136
  # Get families from a sample file (if available)
137
  families = []
138
  if benchmark_files and benchmark_files[0] != "No benchmark files found":
139
  sample_df = get_data(benchmark_files[0], df_results)
140
  if not sample_df.empty:
141
- families = sorted(sample_df['family'].unique().tolist())
142
-
143
  return benchmark_files, plot_columns, families
144
 
 
145
  # Initialize the interface
146
  benchmark_files, plot_columns, families = setup_interface()
147
 
148
  # Create the Gradio interface
149
  with gr.Blocks(title="Image Model Performance Analysis") as demo:
150
  gr.Markdown("# Image Model Performance Analysis")
151
- gr.Markdown("Analyze and visualize performance metrics of different image models based on benchmark data.")
152
-
 
 
153
  with gr.Row():
154
  with gr.Column(scale=1):
155
- benchmark_dropdown = gr.Dropdown(
156
- choices=benchmark_files,
157
- value=benchmark_files[0] if benchmark_files else None,
158
- label="Select Benchmark File"
159
  )
160
-
161
- x_axis_radio = gr.Radio(
162
- choices=plot_columns,
163
- value="secs",
164
- label="X-axis"
165
  )
166
-
167
- y_axis_radio = gr.Radio(
168
- choices=plot_columns,
169
- value="top1",
170
- label="Y-axis"
171
  )
172
-
 
 
 
 
173
  family_checkboxes = gr.CheckboxGroup(
174
- choices=families,
175
- value=families,
176
- label="Select Model Families"
177
- )
178
-
179
- log_x_checkbox = gr.Checkbox(
180
- value=True,
181
- label="Log scale X-axis"
182
  )
183
-
184
- log_y_checkbox = gr.Checkbox(
185
- value=False,
186
- label="Log scale Y-axis"
187
- )
188
-
189
  update_button = gr.Button("Update Plot", variant="primary")
190
-
191
  with gr.Column(scale=2):
192
  plot_output = gr.Plot()
193
-
194
  # Update plot when button is clicked
195
  update_button.click(
196
  fn=create_plot,
@@ -200,30 +210,28 @@ with gr.Blocks(title="Image Model Performance Analysis") as demo:
200
  y_axis_radio,
201
  family_checkboxes,
202
  log_x_checkbox,
203
- log_y_checkbox
204
  ],
205
- outputs=plot_output
206
  )
207
-
208
  # Auto-update when benchmark file changes
209
  def update_families(benchmark_file):
210
  if not benchmark_file or benchmark_file == "No benchmark files found":
211
  return gr.CheckboxGroup(choices=[], value=[])
212
-
213
  df_results = load_main_data()
214
  df = get_data(benchmark_file, df_results)
215
  if df.empty:
216
  return gr.CheckboxGroup(choices=[], value=[])
217
-
218
- new_families = sorted(df['family'].unique().tolist())
219
  return gr.CheckboxGroup(choices=new_families, value=new_families)
220
-
221
  benchmark_dropdown.change(
222
- fn=update_families,
223
- inputs=benchmark_dropdown,
224
- outputs=family_checkboxes
225
  )
226
-
227
  # Load initial plot
228
  demo.load(
229
  fn=create_plot,
@@ -233,11 +241,10 @@ with gr.Blocks(title="Image Model Performance Analysis") as demo:
233
  y_axis_radio,
234
  family_checkboxes,
235
  log_x_checkbox,
236
- log_y_checkbox
237
  ],
238
- outputs=plot_output
239
  )
240
 
241
  if __name__ == "__main__":
242
  demo.launch()
243
-
 
6
  import os
7
  import glob
8
 
9
+
10
# Download the main results file
def download_main_results():
    """Fetch the main ImageNet results CSV from the timm repo (cached on disk).

    Downloads ``results-imagenet.csv`` into the working directory only when a
    file with that name does not already exist. Raises ``requests.HTTPError``
    on a failed download instead of silently caching an error page.
    """
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if not os.path.exists("results-imagenet.csv"):
        # timeout: never hang app start-up on a dead connection;
        # raise_for_status: never write a 404/500 body into the cache file
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with open("results-imagenet.csv", "wb") as f:
            f.write(response.content)
17
 
18
+
19
def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
):
    """Download benchmark CSV files from a GitHub repo folder via the contents API.

    Args:
        repo: ``owner/name`` of the GitHub repository.
        folder: Folder inside the repo to list.
        filename_pattern: Regex applied to each file name with ``re.match``.
        output_dir: Local directory the CSVs are written to (created if absent).

    Returns:
        The list of matched file names (already-present files are not
        re-downloaded; individual download failures are skipped). Empty list
        when the API call fails or nothing matches.
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    # timeout keeps a dead network from blocking app start-up indefinitely
    r = requests.get(api_url, timeout=30)
    if r.status_code != 200:
        return []

    files = r.json()
    matched_files = [f["name"] for f in files if re.match(filename_pattern, f["name"])]

    if not matched_files:
        return []

    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)

    for fname in matched_files:
        raw_url = raw_base + fname
        out_path = os.path.join(output_dir, fname)

        if not os.path.exists(out_path):  # Only download if not exists
            resp = requests.get(raw_url, timeout=30)
            if resp.ok:
                with open(out_path, "wb") as f:
                    f.write(resp.content)

    return matched_files
51
 
52
+
53
def load_main_data():
    """Return the main ImageNet results table, downloading it if needed.

    The full model identifier is preserved in ``model_org``; the ``model``
    column is stripped to everything before the first '.' so it lines up
    with the model names used in the benchmark files.
    """
    download_main_results()
    table = pd.read_csv("results-imagenet.csv")
    table["model_org"] = table["model"]
    # e.g. "convnext_base.fb_in22k" -> "convnext_base"
    table["model"] = table["model"].str.partition(".")[0]
    return table
60
 
61
+
62
def get_data(benchmark_file, df_results):
    """Load a benchmark CSV, merge it with the main results, derive columns.

    Args:
        benchmark_file: Path to a timm ``benchmark-*.csv`` file.
        df_results: Main ImageNet results with ``model``/``model_org`` columns.

    Returns:
        Merged DataFrame restricted to the whitelisted model families, with
        added ``secs`` (seconds per sample) and ``family`` columns; an empty
        DataFrame when ``benchmark_file`` does not exist.
    """
    # Anchored whitelist of model-family names worth plotting.
    pattern = (
        r"^(?:"
        r"eva|"
        r"maxx?vit(?:v2)?|"
        r"coatnet|coatnext|"
        r"convnext(?:v2)?|"
        r"beit(?:v2)?|"
        r"efficient(?:net(?:v2)?|former(?:v2)?|vit)|"
        r"regnet[xyvz]?|"
        r"levit|"
        r"vitd?|"
        r"swin(?:v2)?"
        r")$"
    )

    if not os.path.exists(benchmark_file):
        return pd.DataFrame()

    df = pd.read_csv(benchmark_file).merge(df_results, on="model")
    df["secs"] = 1.0 / df["infer_samples_per_sec"]
    # family = leading lowercase run (plus optional "v2") ending at a
    # digit, underscore, or end of name. Raw string: "\d" in a plain
    # string literal is an invalid escape on modern Python.
    df["family"] = df.model.str.extract(r"^([a-z]+?(?:v2)?)(?:\d|_|$)")
    df = df[~df.model.str.endswith("gn")]
    # distinguish the ResNet-D variants from plain ResNets
    is_resnet_d = df.model.str.contains(r"resnet.*d")
    df.loc[is_resnet_d, "family"] = df.loc[is_resnet_d, "family"] + "d"
    # na=False: rows whose family could not be extracted are dropped
    # instead of raising ValueError on the NaN mask
    return df[df.family.str.contains(pattern, na=False)]
90
 
91
+
92
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Build the scatter plot for the chosen benchmark file and axes.

    Returns a plotly figure, or ``None`` when no rows survive loading
    or the family filter.
    """
    merged = get_data(benchmark_file, load_main_data())
    if merged.empty:
        return None

    # Restrict to the families ticked in the UI (empty selection keeps all)
    if selected_families:
        merged = merged[merged["family"].isin(selected_families)]
        if merged.empty:
            return None

    return px.scatter(
        merged,
        width=1000,
        height=800,
        x=x_axis,
        y=y_axis,
        log_x=log_x,
        log_y=log_y,
        color="family",
        hover_name="model_org",
        hover_data=["infer_samples_per_sec", "infer_img_size"],
        title=f"Model Performance: {y_axis} vs {x_axis}",
    )
123
 
124
+
125
def setup_interface():
    """Collect the choices the Gradio UI needs.

    Returns:
        Tuple ``(benchmark_files, plot_columns, families)``: local benchmark
        CSV paths (or a one-element placeholder list when none exist), the
        plottable column names, and the model families found in the first
        benchmark file.
    """
    # Called for its side effect of populating benchmarks/;
    # the returned file names are not needed here.
    download_github_csvs_api()

    # Get available benchmark files
    benchmark_files = glob.glob("benchmarks/benchmark-*.csv")
    if not benchmark_files:
        benchmark_files = ["No benchmark files found"]

    # Load sample data to get families and columns
    df_results = load_main_data()

    # Relevant columns for plotting
    plot_columns = [
        "top1",
        "top5",
        "infer_samples_per_sec",
        "secs",
        "param_count_x",
        "infer_img_size",
    ]

    # Get families from a sample file (if available).
    # benchmark_files is never empty here, so only the placeholder
    # sentinel needs checking.
    families = []
    if benchmark_files[0] != "No benchmark files found":
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df["family"].unique().tolist())

    return benchmark_files, plot_columns, families
156
 
157
+
158
# Initialize the interface
benchmark_files, plot_columns, families = setup_interface()

# Create the Gradio interface.
# NOTE(review): component creation order inside the Blocks context defines
# the rendered layout, so statement order here is significant.
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown(
        "Analyze and visualize performance metrics of different image models based on benchmark data."
    )

    with gr.Row():
        with gr.Column(scale=1):

            # Set preferred default file; fall back to the first available
            # file when the preferred one was not downloaded.
            preferred_file = (
                "benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv"
            )
            default_file = (
                preferred_file
                if preferred_file in benchmark_files
                else (benchmark_files[0] if benchmark_files else None)
            )

            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=default_file,
                label="Select Benchmark File",
            )

            x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")

            y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")

            family_checkboxes = gr.CheckboxGroup(
                choices=families, value=families, label="Select Model Families"
            )

            log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")

            log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")

            update_button = gr.Button("Update Plot", variant="primary")

        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # Update plot when button is clicked
    update_button.click(
        fn=create_plot,
        inputs=[
            benchmark_dropdown,
            x_axis_radio,
            y_axis_radio,
            family_checkboxes,
            log_x_checkbox,
            log_y_checkbox,
        ],
        outputs=plot_output,
    )

    # Auto-update when benchmark file changes
    def update_families(benchmark_file):
        """Return a refreshed family CheckboxGroup for the chosen file.

        Empty group when no file is selected, the placeholder sentinel is
        chosen, or the file yields no rows; otherwise all families of the
        new file, all pre-selected.
        """
        if not benchmark_file or benchmark_file == "No benchmark files found":
            return gr.CheckboxGroup(choices=[], value=[])

        df_results = load_main_data()
        df = get_data(benchmark_file, df_results)
        if df.empty:
            return gr.CheckboxGroup(choices=[], value=[])

        new_families = sorted(df["family"].unique().tolist())
        return gr.CheckboxGroup(choices=new_families, value=new_families)

    benchmark_dropdown.change(
        fn=update_families, inputs=benchmark_dropdown, outputs=family_checkboxes
    )

    # Load initial plot (same inputs as the button handler, run once on page load)
    demo.load(
        fn=create_plot,
        inputs=[
            benchmark_dropdown,
            x_axis_radio,
            y_axis_radio,
            family_checkboxes,
            log_x_checkbox,
            log_y_checkbox,
        ],
        outputs=plot_output,
    )

if __name__ == "__main__":
    demo.launch()