import gradio as gr
import pandas as pd
import plotly.express as px
import requests
import re
import os
import glob


# Download the main results file
def download_main_results():
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if not os.path.exists('results-imagenet.csv'):
        response = requests.get(url)
        with open('results-imagenet.csv', 'wb') as f:
            f.write(response.content)


def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
):
    """Download benchmark CSV files listed by the GitHub contents API."""
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    r = requests.get(api_url)
    if r.status_code != 200:
        return []
    files = r.json()
    matched_files = [f['name'] for f in files if re.match(filename_pattern, f['name'])]
    if not matched_files:
        return []
    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)
    for fname in matched_files:
        raw_url = raw_base + fname
        out_path = os.path.join(output_dir, fname)
        if not os.path.exists(out_path):  # Only download files we don't already have
            resp = requests.get(raw_url)
            if resp.ok:
                with open(out_path, "wb") as f:
                    f.write(resp.content)
    return matched_files


def load_main_data():
    """Load the main ImageNet results and normalize model names for merging."""
    download_main_results()
    df_results = pd.read_csv('results-imagenet.csv')
    # Keep the original name for hover labels; strip the pretrained-tag suffix
    # (everything after the first '.') so names line up with the benchmark CSVs.
    df_results['model_org'] = df_results['model']
    df_results['model'] = df_results['model'].str.split('.').str[0]
    return df_results


def get_data(benchmark_file, df_results):
    """Merge a benchmark CSV with the main results and keep selected model families."""
    # Model families to keep in the final plot
    pattern = (
        r'^(?:'
        r'eva|'
        r'maxx?vit(?:v2)?|'
        r'coatnet|coatnext|'
        r'convnext(?:v2)?|'
        r'beit(?:v2)?|'
        r'efficient(?:net(?:v2)?|former(?:v2)?|vit)|'
        r'regnet[xyvz]?|'
        r'levit|'
        r'vitd?|'
        r'swin(?:v2)?'
        r')$'
    )
    if not os.path.exists(benchmark_file):
        return pd.DataFrame()
    df = pd.read_csv(benchmark_file).merge(df_results, on='model')
    # Seconds per image at inference time
    df['secs'] = 1. / df['infer_samples_per_sec']
    # Derive the family name from the leading lowercase prefix (optionally ending in 'v2')
    df['family'] = df.model.str.extract(r'^([a-z]+?(?:v2)?)(?:\d|_|$)')
    df = df[~df.model.str.endswith('gn')]
    # Treat the 'd' ResNet variants (e.g. resnet50d) as their own family
    is_resnet_d = df.model.str.contains('resnet.*d')
    df.loc[is_resnet_d, 'family'] = df.loc[is_resnet_d, 'family'] + 'd'
    # na=False guards against model names the family regex did not match
    return df[df.family.str.contains(pattern, na=False)]


def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Create the scatter plot based on user selections."""
    df_results = load_main_data()
    df = get_data(benchmark_file, df_results)
    if df.empty:
        return None

    # Filter by selected families
    if selected_families:
        df = df[df['family'].isin(selected_families)]
    if df.empty:
        return None

    # Create the plot
    fig = px.scatter(
        df,
        width=1000,
        height=800,
        x=x_axis,
        y=y_axis,
        log_x=log_x,
        log_y=log_y,
        color='family',
        hover_name='model_org',
        hover_data=['infer_samples_per_sec', 'infer_img_size'],
        title=f'Model Performance: {y_axis} vs {x_axis}',
    )
    return fig


def setup_interface():
    """Prepare the data needed to build the Gradio interface."""
    # Download benchmark files (the returned names are not needed; we glob the folder below)
    download_github_csvs_api()

    # Get available benchmark files
    benchmark_files = glob.glob("benchmarks/benchmark-*.csv")
    if not benchmark_files:
        benchmark_files = ["No benchmark files found"]

    # Load the main results once so families can be derived from a sample file
    df_results = load_main_data()

    # Relevant columns for plotting
    plot_columns = [
        'top1', 'top5', 'infer_samples_per_sec',
        'secs', 'param_count_x', 'infer_img_size',
    ]

    # Get families from a sample file (if available)
    families = []
    if benchmark_files and benchmark_files[0] != "No benchmark files found":
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df['family'].unique().tolist())

    return benchmark_files, plot_columns, families


# Initialize the interface
benchmark_files, plot_columns, families = setup_interface()

# Create the Gradio interface
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown("Analyze and visualize performance metrics of different image models based on benchmark data.")

    with gr.Row():
        with gr.Column(scale=1):
            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=benchmark_files[0] if benchmark_files else None,
                label="Select Benchmark File"
            )
            x_axis_radio = gr.Radio(
                choices=plot_columns,
                value="secs",
                label="X-axis"
            )
            y_axis_radio = gr.Radio(
                choices=plot_columns,
                value="top1",
                label="Y-axis"
            )
            family_checkboxes = gr.CheckboxGroup(
                choices=families,
                value=families,
                label="Select Model Families"
            )
            log_x_checkbox = gr.Checkbox(
                value=True,
                label="Log scale X-axis"
            )
            log_y_checkbox = gr.Checkbox(
                value=False,
                label="Log scale Y-axis"
            )
            update_button = gr.Button("Update Plot", variant="primary")
        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # Update the plot when the button is clicked
    update_button.click(
        fn=create_plot,
        inputs=[
            benchmark_dropdown, x_axis_radio, y_axis_radio,
            family_checkboxes, log_x_checkbox, log_y_checkbox
        ],
        outputs=plot_output
    )

    # Refresh the family choices when the benchmark file changes
    def update_families(benchmark_file):
        if not benchmark_file or benchmark_file == "No benchmark files found":
            return gr.CheckboxGroup(choices=[], value=[])
        df_results = load_main_data()
        df = get_data(benchmark_file, df_results)
        if df.empty:
            return gr.CheckboxGroup(choices=[], value=[])
        new_families = sorted(df['family'].unique().tolist())
        return gr.CheckboxGroup(choices=new_families, value=new_families)

    benchmark_dropdown.change(
        fn=update_families,
        inputs=benchmark_dropdown,
        outputs=family_checkboxes
    )

    # Load initial plot
    demo.load(
        fn=create_plot,
        inputs=[
            benchmark_dropdown, x_axis_radio, y_axis_radio,
            family_checkboxes, log_x_checkbox, log_y_checkbox
        ],
        outputs=plot_output
    )


if __name__ == "__main__":
    demo.launch()