"""Gradio app that plots ImageNet accuracy against benchmark metrics.

The data comes from the ``results`` folder of the
huggingface/pytorch-image-models GitHub repository: ``results-imagenet.csv``
provides accuracy figures and the ``benchmark-*.csv`` files provide inference
throughput. Both are downloaded on demand and cached on disk.
"""

import glob
import os
import re

import gradio as gr
import pandas as pd
import plotly.express as px
import requests

RESULTS_URL = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
RESULTS_FILE = "results-imagenet.csv"

# Placeholder dropdown entry used when no benchmark CSV is available locally.
NO_BENCHMARKS_LABEL = "No benchmark files found"

# Seconds to wait for any single HTTP request before giving up.
REQUEST_TIMEOUT = 30

# Model families kept in the plot. The pattern is anchored at both ends, so
# ``str.contains`` with it only accepts an exact family name.
_FAMILY_PATTERN = (
    r"^(?:"
    r"eva|"
    r"maxx?vit(?:v2)?|"
    r"coatnet|coatnext|"
    r"convnext(?:v2)?|"
    r"beit(?:v2)?|"
    r"efficient(?:net(?:v2)?|former(?:v2)?|vit)|"
    r"regnet[xyvz]?|"
    r"levit|"
    r"mobilenet(?:v\d*)?|"
    r"vitd?|"
    r"swin(?:v2)?"
    r")$"
)


def download_main_results():
    """Download the main ImageNet results CSV unless it already exists.

    Raises:
        requests.RequestException: if the request fails or the server answers
            with an HTTP error status. Nothing is written in that case, so an
            error page is never cached as the results file.
    """
    if os.path.exists(RESULTS_FILE):
        return
    response = requests.get(RESULTS_URL, timeout=REQUEST_TIMEOUT)
    # Check the status before writing: the file is only fetched when missing,
    # so a bad payload written here would be reused on every later run.
    response.raise_for_status()
    with open(RESULTS_FILE, "wb") as f:
        f.write(response.content)


def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
):
    """Download benchmark CSV files from GitHub API.

    Lists ``folder`` of ``repo`` through the GitHub contents API and downloads
    every file whose name matches ``filename_pattern`` into ``output_dir``.
    Files already present in ``output_dir`` are not downloaded again.

    Args:
        repo: GitHub repository in ``owner/name`` form.
        folder: Folder inside the repository to list.
        filename_pattern: Regex matched (with ``re.match``) against file names.
        output_dir: Local directory that receives the files.

    Returns:
        The list of matching file names, or an empty list when the listing
        request fails, returns a non-200 status, or nothing matches. A name is
        returned even if its individual download failed.
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    try:
        r = requests.get(api_url, timeout=REQUEST_TIMEOUT)
    except requests.RequestException:
        # Treat an unreachable API like a non-200 answer so that files cached
        # by an earlier run can still be used.
        return []
    if r.status_code != 200:
        return []

    files = r.json()
    matched_files = [f["name"] for f in files if re.match(filename_pattern, f["name"])]
    if not matched_files:
        return []

    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)
    for fname in matched_files:
        out_path = os.path.join(output_dir, fname)
        if os.path.exists(out_path):  # Only download if not exists
            continue
        try:
            resp = requests.get(raw_base + fname, timeout=REQUEST_TIMEOUT)
        except requests.RequestException:
            # Best effort: skip this file and keep going with the others.
            continue
        if resp.ok:
            with open(out_path, "wb") as f:
                f.write(resp.content)
    return matched_files


def load_main_data():
    """Load the main ImageNet results.

    Returns:
        The results DataFrame with two adjusted columns: ``model_org`` holds
        the model name as found in the CSV, and ``model`` holds the part of
        that name before the first ``.``.
    """
    download_main_results()
    df_results = pd.read_csv(RESULTS_FILE)
    df_results["model_org"] = df_results["model"]
    df_results["model"] = df_results["model"].str.split(".").str[0]
    return df_results


def get_data(benchmark_file, df_results):
    """Process benchmark data and merge with main results.

    Args:
        benchmark_file: Path of a benchmark CSV file.
        df_results: DataFrame as returned by ``load_main_data``.

    Returns:
        The benchmark rows merged with ``df_results`` on ``model``, extended
        with ``secs`` (the reciprocal of ``infer_samples_per_sec``) and
        ``family`` columns, restricted to the families in ``_FAMILY_PATTERN``.
        An empty DataFrame when ``benchmark_file`` is falsy or does not exist.
    """
    if not benchmark_file or not os.path.exists(benchmark_file):
        return pd.DataFrame()

    df = pd.read_csv(benchmark_file).merge(df_results, on="model")
    df["secs"] = 1.0 / df["infer_samples_per_sec"]
    # Family = leading lowercase letters (plus an optional "v2") up to the
    # first digit, underscore, or end of the name.
    df["family"] = df.model.str.extract(r"^([a-z]+?(?:v2)?)(?:\d|_|$)", expand=False)
    # Copy so the family assignment below does not write into a slice.
    df = df[~df.model.str.endswith("gn")].copy()

    # Names matching "resnet.*d" get their own family with a "d" suffix.
    is_resnet_d = df.model.str.contains("resnet.*d")
    df.loc[is_resnet_d, "family"] = df.loc[is_resnet_d, "family"] + "d"

    # na=False: a model without an extractable family yields NaN here, and a
    # mask containing NaN cannot be used for boolean indexing.
    return df[df.family.str.contains(_FAMILY_PATTERN, na=False)]


def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Create the scatter plot based on user selections.

    Args:
        benchmark_file: Path of the benchmark CSV file to plot.
        x_axis: Column name used for the x axis.
        y_axis: Column name used for the y axis.
        selected_families: Families to keep; an empty selection keeps all.
        log_x: Whether the x axis uses a log scale.
        log_y: Whether the y axis uses a log scale.

    Returns:
        A Plotly figure, or ``None`` when there is no data to plot.
    """
    df_results = load_main_data()
    df = get_data(benchmark_file, df_results)
    if df.empty:
        return None

    # Filter by selected families
    if selected_families:
        df = df[df["family"].isin(selected_families)]
        if df.empty:
            return None

    # Create the plot
    fig = px.scatter(
        df,
        width=1000,
        height=800,
        x=x_axis,
        y=y_axis,
        size=df["infer_img_size"] ** 2,
        log_x=log_x,
        log_y=log_y,
        color="family",
        hover_name="model_org",
        hover_data=["infer_samples_per_sec", "infer_img_size"],
        title=f"Model Performance: {y_axis} vs {x_axis}",
    )
    return fig


def setup_interface():
    """Set up the Gradio interface.

    Downloads the benchmark files, then collects what the UI needs.

    Returns:
        A tuple ``(benchmark_files, plot_columns, families)``: the sorted local
        benchmark file paths (or a single placeholder entry when there are
        none), the column names offered for the axes, and the sorted family
        names found in the first benchmark file.
    """
    # Download benchmark files
    download_github_csvs_api()

    # Get available benchmark files (sorted so the dropdown order and the
    # fallback default do not depend on directory listing order)
    benchmark_files = sorted(glob.glob("benchmarks/benchmark-*.csv"))
    if not benchmark_files:
        benchmark_files = [NO_BENCHMARKS_LABEL]

    # Load sample data to get families and columns
    df_results = load_main_data()

    # Relevant columns for plotting
    plot_columns = [
        "top1",
        "top5",
        "infer_samples_per_sec",
        "secs",
        "param_count_x",
        "infer_img_size",
    ]

    # Get families from a sample file (if available)
    families = []
    if benchmark_files[0] != NO_BENCHMARKS_LABEL:
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df["family"].unique().tolist())

    return benchmark_files, plot_columns, families


# Initialize the interface
benchmark_files, plot_columns, families = setup_interface()

# Create the Gradio interface
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown(
        "Analyze and visualize performance metrics of different image models based on benchmark data."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Set preferred default file
            preferred_file = (
                "benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv"
            )
            default_file = (
                preferred_file
                if preferred_file in benchmark_files
                else (benchmark_files[0] if benchmark_files else None)
            )

            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=default_file,
                label="Select Benchmark File",
            )
            x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")
            y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")
            family_checkboxes = gr.CheckboxGroup(
                choices=families, value=families, label="Select Model Families"
            )
            log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")
            log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")
            update_button = gr.Button("Update Plot", variant="primary")

        with gr.Column(scale=2):
            plot_output = gr.Plot()

    gr.Markdown("The benchmark data comes from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) repository by [Ross Wightman](https://huggingface.co/rwightman).")
    gr.Markdown("Based on the original notebook by [Jeremy Howard](https://huggingface.co/jph00).")
    gr.Markdown("Read more about the project on my blog [dronelab.dev](https://dronelab.dev/posts/which-image-models-are-best-updated/).")

    # The same controls feed the plot on button click and on initial load.
    plot_inputs = [
        benchmark_dropdown,
        x_axis_radio,
        y_axis_radio,
        family_checkboxes,
        log_x_checkbox,
        log_y_checkbox,
    ]

    # Update plot when button is clicked
    update_button.click(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

    # Auto-update when benchmark file changes
    def update_families(benchmark_file):
        """Return a CheckboxGroup listing the families of ``benchmark_file``.

        All families found are offered and pre-selected. The group is empty
        when no real file is selected or the file yields no data.
        """
        if not benchmark_file or benchmark_file == NO_BENCHMARKS_LABEL:
            return gr.CheckboxGroup(choices=[], value=[])

        df_results = load_main_data()
        df = get_data(benchmark_file, df_results)
        if df.empty:
            return gr.CheckboxGroup(choices=[], value=[])

        new_families = sorted(df["family"].unique().tolist())
        return gr.CheckboxGroup(choices=new_families, value=new_families)

    benchmark_dropdown.change(
        fn=update_families, inputs=benchmark_dropdown, outputs=family_checkboxes
    )

    # Load initial plot
    demo.load(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()