Mars-Board / app /leaderboard.py
gremlin97's picture
gremlin97: hf setup
ed881c4
"""Leaderboard visualization functions."""
import pandas as pd
import plotly.express as px
from typing import Dict, List, Tuple
from plotly.graph_objects import Figure
from app.data import TASK_DATA
def create_leaderboard_table(data: Dict[str, List], selected_models: List[str] = None) -> pd.DataFrame:
    """Build the display table for the leaderboard.

    Args:
        data: Column-oriented table (column name -> list of values). A
            ``'Model'`` column is required whenever ``selected_models`` is
            non-empty.
        selected_models: Model names to keep. ``None`` or an empty list
            leaves every row in place.

    Returns:
        A DataFrame in which every ``float64``/``int64`` column has been
        replaced by strings rendered with one decimal place; all other
        columns are passed through unchanged.
    """
    def _one_decimal(column):
        # Render each cell of a numeric column as a fixed one-decimal string.
        return column.apply(lambda value: f"{value:.1f}")

    table = pd.DataFrame(data)

    if selected_models:
        keep = table['Model'].isin(selected_models)
        table = table[keep]

    # Only the 64-bit float/int columns are reformatted.
    numeric_cols = table.select_dtypes(include=['float64', 'int64']).columns
    table[numeric_cols] = table[numeric_cols].apply(_one_decimal)
    return table
def create_performance_plot(data: Dict[str, List], task: str, metric: str, selected_models: List[str] = None) -> Figure:
    """Draw a grouped bar chart of one metric per model.

    Args:
        data: Column-oriented table; the chart reads its ``'Model'`` and
            ``'Dataset'`` columns plus the column named by ``metric``.
        task: Task name, used only in the chart title.
        metric: Column plotted on the y axis; also used in the title and as
            the y-axis label.
        selected_models: Model names to keep. ``None`` or an empty list
            plots every row.

    Returns:
        A Plotly figure with models on the x axis and one bar colour per
        dataset, grouped side by side.
    """
    frame = pd.DataFrame(data)
    if selected_models:
        keep = frame['Model'].isin(selected_models)
        frame = frame[keep]

    figure = px.bar(
        frame,
        x="Model",
        y=metric,
        color="Dataset",
        title=f"{task} - {metric}",
        barmode="group",
    )
    # Explicit axis labels and legend.
    figure.update_layout(xaxis_title="Model", yaxis_title=metric, showlegend=True)
    return figure
def get_best_models(data: Dict[str, List], metrics: List[str], top_n: int = 3) -> pd.DataFrame:
    """Rank the top models for each metric by their mean score.

    Scores are averaged per ``'Model'`` over every row of ``data``, sorted
    in descending order, and the first ``top_n`` models are kept for each
    metric.

    Args:
        data: Column-oriented table; must contain a ``'Model'`` column and
            one column per entry in ``metrics``.
        metrics: Names of the score columns to rank by.
        top_n: How many models to report per metric. Defaults to 3, the
            previously hard-coded cutoff.

    Returns:
        A DataFrame with the columns ``Metric``, ``Rank`` (1-based),
        ``Model`` and ``Average Score`` (string with one decimal place).
        The columns are present even when there are no rows, e.g. when
        ``metrics`` is empty.

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    if top_n < 0:
        # head(-n) means "all but the last n", which is never what a caller
        # asking for a top-N ranking wants; fail loudly instead.
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    columns = ['Metric', 'Rank', 'Model', 'Average Score']
    df = pd.DataFrame(data)
    rows = []
    for metric in metrics:
        mean_scores = df.groupby('Model')[metric].mean()
        leaders = mean_scores.sort_values(ascending=False).head(top_n)
        for rank, (model, score) in enumerate(leaders.items(), start=1):
            rows.append({
                'Metric': metric,
                'Rank': rank,
                'Model': model,
                'Average Score': f"{score:.1f}",
            })
    # Passing the column list keeps the schema stable when rows is empty.
    return pd.DataFrame(rows, columns=columns)
def update_leaderboard(task: str, selected_models: List[str] = None) -> Tuple[pd.DataFrame, Figure, Figure, pd.DataFrame]:
    """Recompute every leaderboard output for a task and model selection.

    Args:
        task: Key into ``TASK_DATA``; its entry is unpacked as
            ``(data, metrics)``.
        selected_models: Model names to restrict the table and plots to.
            The best-models ranking is always computed over all rows.

    Returns:
        A flat tuple: the formatted leaderboard table, then one performance
        plot per metric (in ``metrics`` order), then the best-models table.
        NOTE(review): the return annotation has exactly two ``Figure``
        slots, which only matches tasks with two metrics -- confirm against
        ``TASK_DATA``.
    """
    data, metrics = TASK_DATA[task]

    table = create_leaderboard_table(data, selected_models)
    plots = [
        create_performance_plot(data, task, metric_name, selected_models)
        for metric_name in metrics
    ]
    ranking = get_best_models(data, metrics)

    return (table, *plots, ranking)