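"""Gradio app for exploring timm ImageNet benchmark results.

Downloads the results CSVs from the huggingface/pytorch-image-models
repository and plots accuracy against inference speed per model family.
"""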
import gradio as gr
import pandas as pd
import plotly.express as px
import requests
import re
import os
import glob
# Download the main results file
def download_main_results():
url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
if not os.path.exists("results-imagenet.csv"):
response = requests.get(url)
with open("results-imagenet.csv", "wb") as f:
f.write(response.content)
def download_github_csvs_api(
repo="huggingface/pytorch-image-models",
folder="results",
filename_pattern=r"benchmark-.*\.csv",
output_dir="benchmarks",
):
"""Download benchmark CSV files from GitHub API."""
api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
r = requests.get(api_url)
if r.status_code != 200:
return []
files = r.json()
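    # Keep only files whose names match the benchmark pattern
    # (e.g. benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv)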
matched_files = [f["name"] for f in files if re.match(filename_pattern, f["name"])]
if not matched_files:
return []
raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
os.makedirs(output_dir, exist_ok=True)
for fname in matched_files:
raw_url = raw_base + fname
out_path = os.path.join(output_dir, fname)
        if not os.path.exists(out_path):  # skip files already downloaded
resp = requests.get(raw_url)
if resp.ok:
with open(out_path, "wb") as f:
f.write(resp.content)
return matched_files
def load_main_data():
"""Load the main ImageNet results."""
download_main_results()
df_results = pd.read_csv("results-imagenet.csv")
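    # Keep the original name, then strip the pretrained-tag suffix
    # (everything after the first ".") so names match the benchmark CSVs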
df_results["model_org"] = df_results["model"]
df_results["model"] = df_results["model"].str.split(".").str[0]
return df_results
def get_data(benchmark_file, df_results):
"""Process benchmark data and merge with main results."""
pattern = (
r"^(?:"
r"eva|"
r"maxx?vit(?:v2)?|"
r"coatnet|coatnext|"
r"convnext(?:v2)?|"
r"beit(?:v2)?|"
r"efficient(?:net(?:v2)?|former(?:v2)?|vit)|"
r"regnet[xyvz]?|"
r"levit|"
r"mobilenet(?:v\d*)?|"
r"vitd?|"
r"swin(?:v2)?"
r")$"
)
if not os.path.exists(benchmark_file):
return pd.DataFrame()
df = pd.read_csv(benchmark_file).merge(df_results, on="model")
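    # Per-image inference latency in seconds (inverse of throughput)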
df["secs"] = 1.0 / df["infer_samples_per_sec"]
df["family"] = df.model.str.extract("^([a-z]+?(?:v2)?)(?:\d|_|$)")
    df = df[~df.model.str.endswith("gn")]  # drop GroupNorm variants
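    # Give ResNet-D style models (e.g. resnet50d) their own family label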
df.loc[df.model.str.contains("resnet.*d"), "family"] = (
df.loc[df.model.str.contains("resnet.*d"), "family"] + "d"
)
    # Keep only whitelisted families; na=False guards rows with no extracted family
    return df[df.family.str.contains(pattern, na=False)]
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
"""Create the scatter plot based on user selections."""
df_results = load_main_data()
df = get_data(benchmark_file, df_results)
if df.empty:
return None
# Filter by selected families
if selected_families:
df = df[df["family"].isin(selected_families)]
if df.empty:
return None
# Create the plot
fig = px.scatter(
df,
width=1000,
height=800,
x=x_axis,
y=y_axis,
        size=df["infer_img_size"] ** 2,  # scale markers by image size squared
log_x=log_x,
log_y=log_y,
color="family",
hover_name="model_org",
hover_data=["infer_samples_per_sec", "infer_img_size"],
title=f"Model Performance: {y_axis} vs {x_axis}",
)
return fig
def setup_interface():
"""Set up the Gradio interface."""
    # Download benchmark files (files already on disk are skipped)
    download_github_csvs_api()
# Get available benchmark files
benchmark_files = glob.glob("benchmarks/benchmark-*.csv")
if not benchmark_files:
benchmark_files = ["No benchmark files found"]
# Load sample data to get families and columns
df_results = load_main_data()
# Relevant columns for plotting
plot_columns = [
"top1",
"top5",
"infer_samples_per_sec",
"secs",
"param_count_x",
"infer_img_size",
]
# Get families from a sample file (if available)
families = []
if benchmark_files and benchmark_files[0] != "No benchmark files found":
sample_df = get_data(benchmark_files[0], df_results)
if not sample_df.empty:
families = sorted(sample_df["family"].unique().tolist())
return benchmark_files, plot_columns, families
# Initialize the interface
benchmark_files, plot_columns, families = setup_interface()
# Create the Gradio interface
with gr.Blocks(title="Image Model Performance Analysis") as demo:
gr.Markdown("# Image Model Performance Analysis")
gr.Markdown(
"Analyze and visualize performance metrics of different image models based on benchmark data."
)
with gr.Row():
with gr.Column(scale=1):
# Set preferred default file
preferred_file = (
"benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv"
)
default_file = (
preferred_file
if preferred_file in benchmark_files
else (benchmark_files[0] if benchmark_files else None)
)
benchmark_dropdown = gr.Dropdown(
choices=benchmark_files,
value=default_file,
label="Select Benchmark File",
)
x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")
y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")
family_checkboxes = gr.CheckboxGroup(
choices=families, value=families, label="Select Model Families"
)
log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")
log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")
update_button = gr.Button("Update Plot", variant="primary")
with gr.Column(scale=2):
plot_output = gr.Plot()
gr.Markdown("The benchmark data comes from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) repository by [Ross Wightman](https://huggingface.co/rwightman).")
gr.Markdown("Based on the original notebook by [Jeremy Howard](https://huggingface.co/jph00).")
gr.Markdown("Read more about the project on my blog [dronelab.dev](https://dronelab.dev/posts/which-image-models-are-best-updated/).")
# Update plot when button is clicked
update_button.click(
fn=create_plot,
inputs=[
benchmark_dropdown,
x_axis_radio,
y_axis_radio,
family_checkboxes,
log_x_checkbox,
log_y_checkbox,
],
outputs=plot_output,
)
# Auto-update when benchmark file changes
def update_families(benchmark_file):
if not benchmark_file or benchmark_file == "No benchmark files found":
return gr.CheckboxGroup(choices=[], value=[])
df_results = load_main_data()
df = get_data(benchmark_file, df_results)
if df.empty:
return gr.CheckboxGroup(choices=[], value=[])
new_families = sorted(df["family"].unique().tolist())
return gr.CheckboxGroup(choices=new_families, value=new_families)
benchmark_dropdown.change(
fn=update_families, inputs=benchmark_dropdown, outputs=family_checkboxes
)
# Load initial plot
demo.load(
fn=create_plot,
inputs=[
benchmark_dropdown,
x_axis_radio,
y_axis_radio,
family_checkboxes,
log_x_checkbox,
log_y_checkbox,
],
outputs=plot_output,
)
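# Launch the app when run as a script (Gradio serves on http://localhost:7860 by default)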
if __name__ == "__main__":
demo.launch()