pors's picture
initial commit
21891ff
raw
history blame
7.41 kB
import gradio as gr
import pandas as pd
import plotly.express as px
import requests
import re
import os
import glob
def download_main_results():
    """Download the main ImageNet results CSV from the timm repo (once).

    Writes ``results-imagenet.csv`` into the working directory; skips the
    download if the file already exists.

    Raises:
        requests.HTTPError: if the server returns an error status.
            (Previously an HTML error page could be silently saved as the
            CSV, breaking every later ``pd.read_csv`` call.)
    """
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if not os.path.exists('results-imagenet.csv'):
        # Bound the request so a hung connection cannot stall app startup,
        # and fail loudly instead of writing a non-CSV error body to disk.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with open('results-imagenet.csv', 'wb') as f:
            f.write(response.content)
def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks"
):
    """Download benchmark CSV files from a GitHub repo folder.

    Lists ``folder`` via the GitHub contents API, then downloads every file
    whose name matches ``filename_pattern`` (anchored at the start, per
    ``re.match``) into ``output_dir``. Files already on disk are not
    re-downloaded.

    Returns:
        list[str]: names of the matched files; ``[]`` on any API failure,
        matching the function's existing best-effort contract.
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    try:
        # Timeout keeps a dead network from hanging app startup forever.
        r = requests.get(api_url, timeout=30)
    except requests.RequestException:
        return []
    if r.status_code != 200:
        return []
    files = r.json()
    # The contents API returns a dict (error object) e.g. when rate-limited;
    # only a list is an actual directory listing. Guard before iterating.
    if not isinstance(files, list):
        return []
    matched_files = [f['name'] for f in files if re.match(filename_pattern, f['name'])]
    if not matched_files:
        return []
    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)
    for fname in matched_files:
        out_path = os.path.join(output_dir, fname)
        if os.path.exists(out_path):
            continue  # already cached locally
        resp = requests.get(raw_base + fname, timeout=30)
        if resp.ok:
            with open(out_path, "wb") as f:
                f.write(resp.content)
    return matched_files
def load_main_data():
    """Fetch (if needed) and load the main ImageNet results table.

    Returns:
        pd.DataFrame: the results frame with the full model id preserved in
        ``model_org`` and ``model`` truncated to the text before the first
        '.', so it can be joined against the benchmark CSVs.
    """
    download_main_results()
    results = pd.read_csv('results-imagenet.csv')
    # assign() evaluates both kwargs against the pre-assignment frame, so
    # model_org keeps the original id while model gets the truncated form.
    return results.assign(
        model_org=results['model'],
        model=results['model'].str.split('.').str[0],
    )
def get_data(benchmark_file, df_results):
    """Load a benchmark CSV and merge it with the main ImageNet results.

    Args:
        benchmark_file: path to a timm ``benchmark-*.csv`` file.
        df_results: main results frame; must contain a ``model`` column.

    Returns:
        pd.DataFrame: merged rows restricted to a curated whitelist of model
        families, with ``secs`` (seconds per inference sample) and ``family``
        columns added. Empty DataFrame when ``benchmark_file`` is missing.
    """
    # Whitelist of model families kept in the final frame.
    pattern = (
        r'^(?:'
        r'eva|'
        r'maxx?vit(?:v2)?|'
        r'coatnet|coatnext|'
        r'convnext(?:v2)?|'
        r'beit(?:v2)?|'
        r'efficient(?:net(?:v2)?|former(?:v2)?|vit)|'
        r'regnet[xyvz]?|'
        r'levit|'
        r'vitd?|'
        r'swin(?:v2)?'
        r')$'
    )
    if not os.path.exists(benchmark_file):
        return pd.DataFrame()
    df = pd.read_csv(benchmark_file).merge(df_results, on='model')
    df['secs'] = 1. / df['infer_samples_per_sec']
    # Family = leading letters (plus optional 'v2') before a digit/underscore/
    # end, e.g. 'convnextv2_base' -> 'convnextv2'. Raw string fixes the
    # invalid '\d' escape warning raised by recent Python versions.
    df['family'] = df.model.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')
    df = df[~df.model.str.endswith('gn')]  # drop group-norm variants
    # Distinguish the 'resnet-d' tweak family from plain resnets
    # (mask computed once instead of twice).
    is_resnet_d = df.model.str.contains(r'resnet.*d')
    df.loc[is_resnet_d, 'family'] = df.loc[is_resnet_d, 'family'] + 'd'
    # na=False: models whose name didn't match the extract regex have NaN
    # family; without it the boolean mask raises instead of filtering.
    return df[df.family.str.contains(pattern, na=False)]
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Build the scatter plot for the chosen benchmark file and axes.

    Returns:
        A plotly figure, or None when there is nothing to show (missing
        benchmark file, or no rows left after family filtering).
    """
    frame = get_data(benchmark_file, load_main_data())
    if frame.empty:
        return None
    # Restrict to the families the user ticked (empty selection = keep all).
    if selected_families:
        frame = frame[frame['family'].isin(selected_families)]
    if frame.empty:
        return None
    return px.scatter(
        frame,
        x=x_axis,
        y=y_axis,
        color='family',
        hover_name='model_org',
        hover_data=['infer_samples_per_sec', 'infer_img_size'],
        log_x=log_x,
        log_y=log_y,
        width=1000,
        height=800,
        title=f'Model Performance: {y_axis} vs {x_axis}',
    )
def setup_interface():
    """Prepare the data the UI widgets are built from.

    Downloads the benchmark CSVs, then collects the available benchmark
    files, the plottable column names, and the model families found in the
    first benchmark file.

    Returns:
        tuple: (benchmark_files, plot_columns, families)
    """
    download_github_csvs_api()

    benchmark_files = glob.glob("benchmarks/benchmark-*.csv")
    if not benchmark_files:
        # Sentinel entry keeps the dropdown populated; downstream code
        # recognizes and ignores it.
        benchmark_files = ["No benchmark files found"]

    df_results = load_main_data()

    # Metrics exposed as axis choices in the UI.
    plot_columns = [
        'top1', 'top5', 'infer_samples_per_sec',
        'secs', 'param_count_x', 'infer_img_size'
    ]

    families = []
    first = benchmark_files[0]
    if first != "No benchmark files found":
        sample = get_data(first, df_results)
        if not sample.empty:
            families = sorted(sample['family'].unique().tolist())

    return benchmark_files, plot_columns, families
# ---- Application wiring -----------------------------------------------------
benchmark_files, plot_columns, families = setup_interface()

with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown("Analyze and visualize performance metrics of different image models based on benchmark data.")

    with gr.Row():
        with gr.Column(scale=1):
            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=benchmark_files[0] if benchmark_files else None,
                label="Select Benchmark File",
            )
            x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")
            y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")
            family_checkboxes = gr.CheckboxGroup(
                choices=families,
                value=families,
                label="Select Model Families",
            )
            log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")
            log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")
            update_button = gr.Button("Update Plot", variant="primary")
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # The same input set drives both the button click and the initial load.
    plot_inputs = [
        benchmark_dropdown,
        x_axis_radio,
        y_axis_radio,
        family_checkboxes,
        log_x_checkbox,
        log_y_checkbox,
    ]

    # Redraw on demand.
    update_button.click(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

    def update_families(benchmark_file):
        """Refresh the family checkboxes when a new benchmark file is picked."""
        if not benchmark_file or benchmark_file == "No benchmark files found":
            return gr.CheckboxGroup(choices=[], value=[])
        df = get_data(benchmark_file, load_main_data())
        if df.empty:
            return gr.CheckboxGroup(choices=[], value=[])
        new_families = sorted(df['family'].unique().tolist())
        return gr.CheckboxGroup(choices=new_families, value=new_families)

    benchmark_dropdown.change(
        fn=update_families,
        inputs=benchmark_dropdown,
        outputs=family_checkboxes,
    )

    # Initial render when the page loads.
    demo.load(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()