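"""Gradio app for exploring timm ImageNet benchmark results.

Downloads the results CSVs from the pytorch-image-models repository and plots
speed/accuracy trade-offs per model family with Plotly.
"""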
import gradio as gr
import pandas as pd
import plotly.express as px
import requests
import re
import os
import glob

# Download the main results file once and cache it on disk.
def download_main_results():
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if not os.path.exists("results-imagenet.csv"):
        response = requests.get(url)
        response.raise_for_status()  # avoid silently saving an error page
        with open("results-imagenet.csv", "wb") as f:
            f.write(response.content)

def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
):
    """Download benchmark CSV files matching `filename_pattern` via the GitHub contents API."""
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    r = requests.get(api_url)
    if r.status_code != 200:
        return []
    files = r.json()
    matched_files = [f["name"] for f in files if re.match(filename_pattern, f["name"])]
    if not matched_files:
        return []
    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)
    for fname in matched_files:
        raw_url = raw_base + fname
        out_path = os.path.join(output_dir, fname)
        if not os.path.exists(out_path):  # only download missing files
            resp = requests.get(raw_url)
            if resp.ok:
                with open(out_path, "wb") as f:
                    f.write(resp.content)
    return matched_files
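
# Note: unauthenticated GitHub API requests are rate-limited (currently 60 per
# hour per IP), so the directory listing above can fail with HTTP 403 under
# load; in that case the function returns an empty list and the UI falls back
# to whatever files are already on disk.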

def load_main_data():
    """Load the main ImageNet results."""
    download_main_results()
    df_results = pd.read_csv("results-imagenet.csv")
    # Keep the full name (e.g. "convnext_tiny.fb_in22k") for hover labels, and
    # strip the pretrained tag so benchmark rows merge on the bare model name.
    df_results["model_org"] = df_results["model"]
    df_results["model"] = df_results["model"].str.split(".").str[0]
    return df_results

def get_data(benchmark_file, df_results):
    """Process benchmark data and merge with main results."""
    # Whitelist of model families to keep in the plot.
    pattern = (
        r"^(?:"
        r"eva|"
        r"maxx?vit(?:v2)?|"
        r"coatnet|coatnext|"
        r"convnext(?:v2)?|"
        r"beit(?:v2)?|"
        r"efficient(?:net(?:v2)?|former(?:v2)?|vit)|"
        r"regnet[xyvz]?|"
        r"levit|"
        r"vitd?|"
        r"swin(?:v2)?"
        r")$"
    )
    if not os.path.exists(benchmark_file):
        return pd.DataFrame()
    df = pd.read_csv(benchmark_file).merge(df_results, on="model")
    # Seconds per image is the inverse of throughput.
    df["secs"] = 1.0 / df["infer_samples_per_sec"]
    # Family = leading lowercase letters plus an optional "v2" suffix.
    df["family"] = df.model.str.extract(r"^([a-z]+?(?:v2)?)(?:\d|_|$)", expand=False)
    df = df[~df.model.str.endswith("gn")]
    # Treat the ResNet-D variants as their own family.
    df.loc[df.model.str.contains("resnet.*d"), "family"] = (
        df.loc[df.model.str.contains("resnet.*d"), "family"] + "d"
    )
    # na=False drops models whose names did not match the family regex.
    return df[df.family.str.contains(pattern, na=False)]
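
# Family extraction examples (following the regex in get_data):
#   "convnextv2_tiny"    -> "convnextv2"
#   "maxvit_base_tf_224" -> "maxvit"
#   "levit_128"          -> "levit"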

def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Create the scatter plot based on user selections."""
    df_results = load_main_data()
    df = get_data(benchmark_file, df_results)
    if df.empty:
        return None
    # Filter by selected families
    if selected_families:
        df = df[df["family"].isin(selected_families)]
    if df.empty:
        return None
    # Create the plot
    fig = px.scatter(
        df,
        width=1000,
        height=800,
        x=x_axis,
        y=y_axis,
        log_x=log_x,
        log_y=log_y,
        color="family",
        hover_name="model_org",
        hover_data=["infer_samples_per_sec", "infer_img_size"],
        title=f"Model Performance: {y_axis} vs {x_axis}",
    )
    return fig
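
# Example call outside the UI (uses the same default benchmark file as below):
#   fig = create_plot(
#       "benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv",
#       "secs", "top1", None, True, False,
#   )
#   fig.show()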

def setup_interface():
    """Set up the Gradio interface."""
    # Download benchmark files
    download_github_csvs_api()
    # Get available benchmark files
    benchmark_files = glob.glob("benchmarks/benchmark-*.csv")
    if not benchmark_files:
        benchmark_files = ["No benchmark files found"]
    # Load sample data to get families and columns
    df_results = load_main_data()
    # Relevant columns for plotting
    plot_columns = [
        "top1",
        "top5",
        "infer_samples_per_sec",
        "secs",
        "param_count_x",
        "infer_img_size",
    ]
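    # "param_count_x" is the parameter count from the benchmark CSV; pandas'
    # merge in get_data adds the _x/_y suffixes when both inputs share a column.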
    # Get families from a sample file (if available)
    families = []
    if benchmark_files and benchmark_files[0] != "No benchmark files found":
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df["family"].unique().tolist())
    return benchmark_files, plot_columns, families

# Initialize the interface
benchmark_files, plot_columns, families = setup_interface()

# Create the Gradio interface
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown(
        "Analyze and visualize performance metrics of different image models based on benchmark data."
    )
    gr.Markdown("The benchmark data comes from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) repository by [Ross Wightman](https://huggingface.co/rwightman).")
    gr.Markdown("Based on the original notebook by [Jeremy Howard](https://huggingface.co/jph00).")

    with gr.Row():
        with gr.Column(scale=1):
            # Set preferred default file
            preferred_file = (
                "benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv"
            )
            default_file = (
                preferred_file
                if preferred_file in benchmark_files
                else (benchmark_files[0] if benchmark_files else None)
            )
            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=default_file,
                label="Select Benchmark File",
            )
            x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")
            y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")
            family_checkboxes = gr.CheckboxGroup(
                choices=families, value=families, label="Select Model Families"
            )
            log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")
            log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")
            update_button = gr.Button("Update Plot", variant="primary")

        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # Update plot when button is clicked
    update_button.click(
        fn=create_plot,
        inputs=[
            benchmark_dropdown,
            x_axis_radio,
            y_axis_radio,
            family_checkboxes,
            log_x_checkbox,
            log_y_checkbox,
        ],
        outputs=plot_output,
    )
    # Auto-update the family list when the benchmark file changes
    def update_families(benchmark_file):
        if not benchmark_file or benchmark_file == "No benchmark files found":
            return gr.CheckboxGroup(choices=[], value=[])
        df_results = load_main_data()
        df = get_data(benchmark_file, df_results)
        if df.empty:
            return gr.CheckboxGroup(choices=[], value=[])
        new_families = sorted(df["family"].unique().tolist())
        return gr.CheckboxGroup(choices=new_families, value=new_families)

    benchmark_dropdown.change(
        fn=update_families, inputs=benchmark_dropdown, outputs=family_checkboxes
    )
    # Load initial plot on app start
    demo.load(
        fn=create_plot,
        inputs=[
            benchmark_dropdown,
            x_axis_radio,
            y_axis_radio,
            family_checkboxes,
            log_x_checkbox,
            log_y_checkbox,
        ],
        outputs=plot_output,
    )

if __name__ == "__main__":
    demo.launch()
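
# Note: demo.launch(share=True) would additionally expose a temporary public
# URL when running locally; on Hugging Face Spaces the defaults suffice.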