Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import plotly.express as px | |
import requests | |
import re | |
import os | |
import glob | |
# Download the main results file | |
def download_main_results():
    """Download the ImageNet results CSV into the working directory.

    The file is fetched from the pytorch-image-models GitHub repository and
    stored as ``results-imagenet.csv``. Nothing is downloaded when that file
    already exists.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or a timeout.
    """
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if os.path.exists('results-imagenet.csv'):
        return
    # A timeout keeps a stalled connection from hanging the app forever.
    response = requests.get(url, timeout=30)
    # Fail loudly on 4xx/5xx: otherwise the error page would be written to
    # disk as the CSV and, because the file then exists, never be replaced.
    response.raise_for_status()
    with open('results-imagenet.csv', 'wb') as f:
        f.write(response.content)
def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks"
):
    """Download benchmark CSV files listed by the GitHub contents API.

    The function is best-effort: any failure while listing the folder yields
    an empty list, and a file whose download fails is simply skipped.

    Args:
        repo: GitHub repository in "owner/name" form.
        folder: Folder inside the repository whose entries are listed.
        filename_pattern: Regular expression applied with ``re.match`` to
            every entry name.
        output_dir: Local directory receiving the files; created if missing.

    Returns:
        The names of all matching entries, or ``[]`` when the listing could
        not be obtained or nothing matched. Files already present in
        ``output_dir`` are not downloaded again.
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    try:
        r = requests.get(api_url, timeout=30)
    except requests.RequestException:
        # Treat an unreachable API like a non-200 answer.
        return []
    if r.status_code != 200:
        return []
    try:
        files = r.json()
    except ValueError:
        return []
    # Anything other than a list of entries cannot be iterated as a folder.
    if not isinstance(files, list):
        return []

    matched_files = [
        entry['name']
        for entry in files
        if isinstance(entry, dict)
        and re.match(filename_pattern, entry.get('name', ''))
    ]
    if not matched_files:
        return []

    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)
    for fname in matched_files:
        out_path = os.path.join(output_dir, fname)
        if os.path.exists(out_path):
            # Already downloaded on a previous run.
            continue
        try:
            resp = requests.get(raw_base + fname, timeout=30)
        except requests.RequestException:
            continue
        if resp.ok:
            with open(out_path, "wb") as f:
                f.write(resp.content)
    return matched_files
def load_main_data():
    """Fetch (if needed) and read the main ImageNet results table.

    Returns the contents of ``results-imagenet.csv`` with the original
    ``model`` value kept in ``model_org`` and ``model`` shortened to the
    text before its first dot.
    """
    download_main_results()
    results = pd.read_csv('results-imagenet.csv')
    full_names = results['model']
    # Keep the untouched name around before shortening the model column.
    results['model_org'] = full_names
    results['model'] = full_names.str.split('.').str[0]
    return results
def get_data(benchmark_file, df_results):
    """Read a benchmark CSV, merge it with the main results and tag families.

    Args:
        benchmark_file: Path of the benchmark CSV to read.
        df_results: Main results table; joined on its ``model`` column.

    Returns:
        The merged rows whose ``family`` is one of the families accepted by
        the pattern below, with two added columns: ``secs`` (reciprocal of
        ``infer_samples_per_sec``) and ``family``. An empty DataFrame is
        returned when ``benchmark_file`` does not exist.
    """
    # Families kept for plotting; anchored so only whole family names match.
    pattern = (
        r'^(?:'
        r'eva|'
        r'maxx?vit(?:v2)?|'
        r'coatnet|coatnext|'
        r'convnext(?:v2)?|'
        r'beit(?:v2)?|'
        r'efficient(?:net(?:v2)?|former(?:v2)?|vit)|'
        r'regnet[xyvz]?|'
        r'levit|'
        r'vitd?|'
        r'swin(?:v2)?'
        r')$'
    )
    if not os.path.exists(benchmark_file):
        return pd.DataFrame()

    df = pd.read_csv(benchmark_file).merge(df_results, on='model')
    df['secs'] = 1. / df['infer_samples_per_sec']
    # Leading lowercase letters (plus an optional "v2") up to the first
    # digit, underscore or end of the name. Raw string: the pattern has a
    # backslash escape. expand=False yields a Series rather than a frame.
    df['family'] = df['model'].str.extract(
        r'^([a-z]+?(?:v2)?)(?:\d|_|$)', expand=False
    )
    # Copy the filtered rows so the assignment below targets an independent
    # frame instead of a slice of the merged one.
    df = df[~df['model'].str.endswith('gn', na=False)].copy()
    is_resnet_d = df['model'].str.contains('resnet.*d', na=False)
    df.loc[is_resnet_d, 'family'] = df.loc[is_resnet_d, 'family'] + 'd'
    # na=False: rows without an extracted family are dropped instead of
    # putting NaN into the boolean mask.
    return df[df['family'].str.contains(pattern, na=False)]
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Build the scatter figure for the current UI selections.

    Returns None when the benchmark yields no rows, or when no rows remain
    after restricting to ``selected_families``.
    """
    data = get_data(benchmark_file, load_main_data())
    if data.empty:
        return None

    # A falsy selection (empty list / None) leaves the data unfiltered.
    if selected_families:
        data = data[data['family'].isin(selected_families)]
        if data.empty:
            return None

    return px.scatter(
        data,
        x=x_axis,
        y=y_axis,
        color='family',
        log_x=log_x,
        log_y=log_y,
        hover_name='model_org',
        hover_data=['infer_samples_per_sec', 'infer_img_size'],
        title=f'Model Performance: {y_axis} vs {x_axis}',
        width=1000,
        height=800,
    )
def setup_interface():
    """Prepare the data the Gradio interface is built from.

    Downloads the benchmark files, lists the local benchmark CSVs and derives
    the selectable model families from the first of them.

    Returns:
        A tuple ``(benchmark_files, plot_columns, families)``:
        ``benchmark_files`` is the sorted list of local benchmark CSV paths,
        or a single placeholder entry when none exist; ``plot_columns`` are
        the column names offered for the axes; ``families`` is the sorted
        list of families found in the first benchmark file (empty when there
        is none or it yields no rows).
    """
    # Called for its side effect only: populating the benchmarks directory.
    download_github_csvs_api()

    # glob returns paths in arbitrary order; sorting makes the dropdown
    # order, its default entry and the sample file below deterministic.
    benchmark_files = sorted(glob.glob("benchmarks/benchmark-*.csv"))
    if not benchmark_files:
        benchmark_files = ["No benchmark files found"]

    df_results = load_main_data()

    # Relevant columns for plotting
    plot_columns = [
        'top1', 'top5', 'infer_samples_per_sec',
        'secs', 'param_count_x', 'infer_img_size'
    ]

    # Get families from a sample file (if available)
    families = []
    if benchmark_files[0] != "No benchmark files found":
        sample_df = get_data(benchmark_files[0], df_results)
        if not sample_df.empty:
            families = sorted(sample_df['family'].unique().tolist())
    return benchmark_files, plot_columns, families
# Gather the startup state the widgets are populated from
benchmark_files, plot_columns, families = setup_interface()

# Assemble the Gradio UI: controls on the left, plot on the right
with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown("Analyze and visualize performance metrics of different image models based on benchmark data.")

    with gr.Row():
        with gr.Column(scale=1):
            benchmark_dropdown = gr.Dropdown(
                label="Select Benchmark File",
                choices=benchmark_files,
                value=benchmark_files[0] if benchmark_files else None,
            )
            x_axis_radio = gr.Radio(label="X-axis", choices=plot_columns, value="secs")
            y_axis_radio = gr.Radio(label="Y-axis", choices=plot_columns, value="top1")
            family_checkboxes = gr.CheckboxGroup(
                label="Select Model Families",
                choices=families,
                value=families,
            )
            log_x_checkbox = gr.Checkbox(label="Log scale X-axis", value=True)
            log_y_checkbox = gr.Checkbox(label="Log scale Y-axis", value=False)
            update_button = gr.Button("Update Plot", variant="primary")

        with gr.Column(scale=2):
            plot_output = gr.Plot()

    # The button click and the initial page load feed create_plot the same
    # components, in the order of its parameters.
    plot_inputs = [
        benchmark_dropdown,
        x_axis_radio,
        y_axis_radio,
        family_checkboxes,
        log_x_checkbox,
        log_y_checkbox,
    ]

    # Redraw on demand
    update_button.click(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

    def update_families(benchmark_file):
        """Return a checkbox group listing the families of ``benchmark_file``.

        All families are pre-selected; the group is empty when no real
        benchmark file is chosen or the file yields no rows.
        """
        available = []
        if benchmark_file and benchmark_file != "No benchmark files found":
            merged = get_data(benchmark_file, load_main_data())
            if not merged.empty:
                available = sorted(merged['family'].unique().tolist())
        return gr.CheckboxGroup(choices=available, value=available)

    # Refresh the family choices whenever another benchmark file is picked
    benchmark_dropdown.change(
        fn=update_families,
        inputs=benchmark_dropdown,
        outputs=family_checkboxes,
    )

    # Draw the first plot as soon as the page loads
    demo.load(fn=create_plot, inputs=plot_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()