import os
from typing import Any

import gradio as gr
import pandas as pd

try:
    from trackio.sqlite_storage import SQLiteStorage
    from trackio.utils import RESERVED_KEYS, TRACKIO_LOGO_PATH
except ImportError:
    # Fallback for when this file is imported outside the installed `trackio`
    # package (presumably run directly next to its sibling modules).
    from sqlite_storage import SQLiteStorage
    from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH

css = """
#run-cb .wrap {
    gap: 2px;
}
#run-cb .wrap label {
    line-height: 1;
    padding: 6px;
}
"""

COLOR_PALETTE = [
    "#3B82F6",
    "#EF4444",
    "#10B981",
    "#F59E0B",
    "#8B5CF6",
    "#EC4899",
    "#06B6D4",
    "#84CC16",
    "#F97316",
    "#6366F1",
]


def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]:
    """Map each run's plot series to a palette color.

    Colors cycle through COLOR_PALETTE by run position. When smoothing is
    enabled each run contributes two series, matching the names produced by
    ``load_run_data``: ``"{run}_smoothed"`` in the base color and
    ``"{run}_original"`` in the same color with a ``4D`` alpha suffix
    (~30% opacity) so the raw data shows faintly behind the smoothed line.
    """
    color_map = {}
    for i, run in enumerate(runs):
        base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)]
        if smoothing:
            color_map[f"{run}_smoothed"] = base_color
            color_map[f"{run}_original"] = base_color + "4D"
        else:
            color_map[run] = base_color
    return color_map


def get_projects(request: gr.Request):
    """Build the project dropdown for the current page load.

    If the URL carries a ``?project=`` query parameter the dropdown is set to
    it and made non-interactive; otherwise the first stored project (if any)
    is selected and the user may switch projects. When the
    ``TRACKIO_DATASET_ID`` environment variable is set, a sync note is shown.
    """
    dataset_id = os.environ.get("TRACKIO_DATASET_ID")
    projects = SQLiteStorage.get_projects()
    if project := request.query_params.get("project"):
        interactive = False
    else:
        interactive = True
        project = projects[0] if projects else None
    return gr.Dropdown(
        label="Project",
        choices=projects,
        value=project,
        allow_custom_value=True,
        interactive=interactive,
        info=f"↻ Synced to {dataset_id} every 5 min" if dataset_id else None,
    )


def get_runs(project):
    """Return the stored run names for ``project``, or [] if no project is set."""
    if not project:
        return []
    return SQLiteStorage.get_runs(project)


def load_run_data(project: str | None, run: str | None, smoothing: bool):
    """Load one run's metrics as a DataFrame ready for plotting.

    Returns None when ``project``/``run`` is missing or the run has no metrics.
    If the stored metrics have no ``step`` column, rows are numbered 0..n-1 in
    the order storage returned them. Every returned row has a ``run`` series
    label and a ``data_type`` of ``"original"`` or ``"smoothed"``.

    With ``smoothing`` the result stacks the raw rows (labelled
    ``"{run}_original"``) with an exponentially weighted moving average
    (alpha=0.1) of every numeric metric column (labelled ``"{run}_smoothed"``).
    """
    if not project or not run:
        return None
    metrics = SQLiteStorage.get_metrics(project, run)
    if not metrics:
        return None

    df = pd.DataFrame(metrics)
    if "step" not in df.columns:
        df["step"] = range(len(df))

    if not smoothing:
        df["run"] = run
        df["data_type"] = "original"
        return df

    # Only smooth user metrics. "step" is excluded explicitly (as the
    # dashboard does when picking plot columns) so the x-axis itself is never
    # averaged, regardless of whether it appears in RESERVED_KEYS.
    numeric_cols = [
        c
        for c in df.select_dtypes(include="number").columns
        if c not in RESERVED_KEYS and c != "step"
    ]

    df_original = df.copy()
    df_original["run"] = f"{run}_original"
    df_original["data_type"] = "original"

    df_smoothed = df.copy()
    df_smoothed[numeric_cols] = df_smoothed[numeric_cols].ewm(alpha=0.1).mean()
    df_smoothed["run"] = f"{run}_smoothed"
    df_smoothed["data_type"] = "smoothed"

    return pd.concat([df_original, df_smoothed], ignore_index=True)


def update_runs(project, filter_text, user_interacted_with_runs=False):
    """Refresh the run checkbox list and the run-count label.

    The label counts every run in the project; the checkbox choices are
    narrowed to runs whose name contains ``filter_text``. Until the user has
    interacted with the checkboxes the selection is reset to the first
    (filtered) run; afterwards only the choices are sent (no value).

    Returns:
        A ``(CheckboxGroup, Textbox)`` update pair.
    """
    if project is None:
        runs = []
        num_runs = 0
    else:
        runs = get_runs(project)
        num_runs = len(runs)
        if filter_text:
            runs = [r for r in runs if filter_text in r]

    label_update = gr.Textbox(label=f"Runs ({num_runs})")
    if not user_interacted_with_runs:
        return (
            gr.CheckboxGroup(choices=runs, value=[runs[0]] if runs else []),
            label_update,
        )
    return gr.CheckboxGroup(choices=runs), label_update


def filter_runs(project, filter_text):
    """Show only runs whose name contains ``filter_text`` and select all of them.

    An empty (or missing) filter shows and selects every run.
    """
    runs = get_runs(project)
    if filter_text:
        runs = [r for r in runs if filter_text in r]
    return gr.CheckboxGroup(choices=runs, value=runs)


def toggle_timer(cb_value):
    """Activate the refresh timer when the realtime checkbox is ticked, else pause it."""
    return gr.Timer(active=bool(cb_value))


def log(project: str, run: str, metrics: dict[str, Any], dataset_id: str) -> None:
    """API endpoint: append one metrics record to ``project``/``run`` storage."""
    # Note: the type hint for dataset_id should be str | None but gr.api
    # doesn't support that, see: https://github.com/gradio-app/gradio/issues/11175#issuecomment-2920203317
    storage = SQLiteStorage(project, run, {}, dataset_id=dataset_id)
    storage.log(metrics)


def sort_metrics_by_prefix(metrics: list[str]) -> list[str]:
    """
    Sort metrics by grouping prefixes together.

    Metrics without a "/" come first, alphabetically. Prefixed metrics follow,
    grouped by the text before the first "/" (groups in alphabetical order,
    metrics sorted alphabetically within each group).

    Example:
        Input: ["train/loss", "loss", "train/acc", "val/loss"]
        Output: ["loss", "train/acc", "train/loss", "val/loss"]
    """
    no_prefix = sorted(m for m in metrics if "/" not in m)
    # Sorting on (prefix, full name) keeps each prefix group contiguous even
    # when plain string order would interleave groups (e.g. "a-b/x" < "a/y").
    with_prefix = sorted(
        (m for m in metrics if "/" in m),
        key=lambda m: (m.split("/", 1)[0], m),
    )
    return no_prefix + with_prefix


def configure(request: gr.Request):
    """Read the optional ``?metrics=a,b,c`` query parameter.

    Returns the listed metric names with surrounding whitespace stripped and
    empty entries dropped. An empty list means the dashboard shows every
    metric.
    """
    if metrics := request.query_params.get("metrics"):
        return [m.strip() for m in metrics.split(",") if m.strip()]
    return []


with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo:
    with gr.Sidebar() as sidebar:
        # The logo is served through Gradio's file route; its path is
        # whitelisted via allowed_paths in demo.launch() below.
        gr.Markdown(
            "<div style='display: flex; align-items: center; gap: 8px;'>"
            f"<img src='/gradio_api/file={TRACKIO_LOGO_PATH}' width='32' height='32'>"
            "<span style='font-size: 2em; font-weight: bold;'>Trackio</span>"
            "</div>"
        )
        project_dd = gr.Dropdown(label="Project")
        run_tb = gr.Textbox(label="Runs", placeholder="Type to filter...")
        run_cb = gr.CheckboxGroup(
            label="Runs", choices=[], interactive=True, elem_id="run-cb"
        )
    with gr.Sidebar(position="right", open=False) as settings_sidebar:
        gr.Markdown("### ⚙️ Settings")
        realtime_cb = gr.Checkbox(label="Refresh realtime", value=True)
        smoothing_cb = gr.Checkbox(label="Smoothing", value=True)

    timer = gr.Timer(value=1)
    metrics_subset = gr.State([])
    # Flips to True on the first manual checkbox interaction so the periodic
    # refresh stops overriding the user's run selection.
    user_interacted_with_run_cb = gr.State(False)

    gr.on(
        [demo.load],
        fn=configure,
        outputs=metrics_subset,
    )
    gr.on(
        [demo.load],
        fn=get_projects,
        outputs=project_dd,
        show_progress="hidden",
    )
    gr.on(
        [timer.tick],
        fn=update_runs,
        inputs=[project_dd, run_tb, user_interacted_with_run_cb],
        outputs=[run_cb, run_tb],
        show_progress="hidden",
    )
    gr.on(
        [demo.load, project_dd.change],
        fn=update_runs,
        inputs=[project_dd, run_tb],
        outputs=[run_cb, run_tb],
        show_progress="hidden",
    )

    realtime_cb.change(
        fn=toggle_timer,
        inputs=realtime_cb,
        outputs=timer,
        api_name="toggle_timer",
    )
    run_cb.input(
        fn=lambda: True,
        outputs=user_interacted_with_run_cb,
    )
    run_tb.input(
        fn=filter_runs,
        inputs=[project_dd, run_tb],
        outputs=run_cb,
    )

    gr.api(
        fn=log,
        api_name="log",
    )

    # Shared x-axis range; None means "auto". Selecting on any plot sets it and
    # double-clicking any plot clears it, and every plot is re-rendered with it.
    x_lim = gr.State(None)

    def update_x_lim(select_data: gr.SelectData):
        """Return the x-range the user selected on a plot."""
        return select_data.index

    @gr.render(
        triggers=[
            demo.load,
            run_cb.change,
            timer.tick,
            smoothing_cb.change,
            x_lim.change,
        ],
        inputs=[project_dd, run_cb, smoothing_cb, metrics_subset, x_lim],
    )
    def update_dashboard(project, runs, smoothing, metrics_subset, x_lim_value):
        """Render one line plot per numeric metric across the selected runs.

        Metric columns exclude RESERVED_KEYS and ``step``; they are restricted
        to ``metrics_subset`` when it is non-empty and ordered by
        ``sort_metrics_by_prefix``. Nothing is rendered if no selected run has
        data.
        """
        runs = runs or []
        dfs = [
            df
            for df in (load_run_data(project, run, smoothing) for run in runs)
            if df is not None
        ]
        if not dfs:
            return
        master_df = pd.concat(dfs, ignore_index=True)
        if master_df.empty:
            return

        numeric_cols = [
            c
            for c in master_df.select_dtypes(include="number").columns
            if c not in RESERVED_KEYS and c != "step"
        ]
        if metrics_subset:
            numeric_cols = [c for c in numeric_cols if c in metrics_subset]
        numeric_cols = sort_metrics_by_prefix(numeric_cols)

        color_map = get_color_mapping(runs, smoothing)

        with gr.Row(key="row"):
            for metric_idx, metric_name in enumerate(numeric_cols):
                # Runs may not log every metric at every step; drop the gaps.
                metric_df = master_df.dropna(subset=[metric_name])
                if metric_df.empty:
                    continue
                plot = gr.LinePlot(
                    metric_df,
                    x="step",
                    y=metric_name,
                    color="run" if "run" in metric_df.columns else None,
                    color_map=color_map,
                    title=metric_name,
                    key=f"plot-{metric_idx}",
                    preserved_by_key=None,
                    x_lim=x_lim_value,
                    y_lim=[
                        metric_df[metric_name].min(),
                        metric_df[metric_name].max(),
                    ],
                    show_fullscreen_button=True,
                    min_width=400,
                )
                plot.select(update_x_lim, outputs=x_lim, key=f"select-{metric_idx}")
                plot.double_click(
                    lambda: None, outputs=x_lim, key=f"double-{metric_idx}"
                )


if __name__ == "__main__":
    demo.launch(allowed_paths=[TRACKIO_LOGO_PATH], show_api=False)