import os from typing import Any import gradio as gr import pandas as pd try: from trackio.sqlite_storage import SQLiteStorage from trackio.utils import RESERVED_KEYS, TRACKIO_LOGO_PATH except: # noqa: E722 from sqlite_storage import SQLiteStorage from utils import RESERVED_KEYS, TRACKIO_LOGO_PATH css = """ #run-cb .wrap { gap: 2px; } #run-cb .wrap label { line-height: 1; padding: 6px; } """ COLOR_PALETTE = [ "#3B82F6", "#EF4444", "#10B981", "#F59E0B", "#8B5CF6", "#EC4899", "#06B6D4", "#84CC16", "#F97316", "#6366F1", ] def get_color_mapping(runs: list[str], smoothing: bool) -> dict[str, str]: """Generate color mapping for runs, with transparency for original data when smoothing is enabled.""" color_map = {} for i, run in enumerate(runs): base_color = COLOR_PALETTE[i % len(COLOR_PALETTE)] if smoothing: color_map[f"{run}_smoothed"] = base_color color_map[f"{run}_original"] = base_color + "4D" else: color_map[run] = base_color return color_map def get_projects(request: gr.Request): dataset_id = os.environ.get("TRACKIO_DATASET_ID") projects = SQLiteStorage.get_projects() if project := request.query_params.get("project"): interactive = False else: interactive = True project = projects[0] if projects else None return gr.Dropdown( label="Project", choices=projects, value=project, allow_custom_value=True, interactive=interactive, info=f"↻ Synced to {dataset_id} every 5 min" if dataset_id else None, ) def get_runs(project): if not project: return [] return SQLiteStorage.get_runs(project) def load_run_data(project: str | None, run: str | None, smoothing: bool): if not project or not run: return None metrics = SQLiteStorage.get_metrics(project, run) if not metrics: return None df = pd.DataFrame(metrics) if "step" not in df.columns: df["step"] = range(len(df)) if smoothing: numeric_cols = df.select_dtypes(include="number").columns numeric_cols = [c for c in numeric_cols if c not in RESERVED_KEYS] df_original = df.copy() df_original["run"] = f"{run}_original" df_original["data_type"] = "original" df_smoothed = df.copy() df_smoothed[numeric_cols] = df_smoothed[numeric_cols].ewm(alpha=0.1).mean() df_smoothed["run"] = f"{run}_smoothed" df_smoothed["data_type"] = "smoothed" combined_df = pd.concat([df_original, df_smoothed], ignore_index=True) return combined_df else: df["run"] = run df["data_type"] = "original" return df def update_runs(project, filter_text, user_interacted_with_runs=False): if project is None: runs = [] num_runs = 0 else: runs = get_runs(project) num_runs = len(runs) if filter_text: runs = [r for r in runs if filter_text in r] if not user_interacted_with_runs: return gr.CheckboxGroup( choices=runs, value=[runs[0]] if runs else [] ), gr.Textbox(label=f"Runs ({num_runs})") else: return gr.CheckboxGroup(choices=runs), gr.Textbox(label=f"Runs ({num_runs})") def filter_runs(project, filter_text): runs = get_runs(project) runs = [r for r in runs if filter_text in r] return gr.CheckboxGroup(choices=runs, value=runs) def toggle_timer(cb_value): if cb_value: return gr.Timer(active=True) else: return gr.Timer(active=False) def log(project: str, run: str, metrics: dict[str, Any], dataset_id: str) -> None: # Note: the type hint for dataset_id should be str | None but gr.api # doesn't support that, see: https://github.com/gradio-app/gradio/issues/11175#issuecomment-2920203317 storage = SQLiteStorage(project, run, {}, dataset_id=dataset_id) storage.log(metrics) def sort_metrics_by_prefix(metrics: list[str]) -> list[str]: """ Sort metrics by grouping prefixes together. Metrics without prefixes come first, then grouped by prefix. Example: Input: ["train/loss", "loss", "train/acc", "val/loss"] Output: ["loss", "train/acc", "train/loss", "val/loss"] """ no_prefix = [] with_prefix = [] for metric in metrics: if "/" in metric: with_prefix.append(metric) else: no_prefix.append(metric) no_prefix.sort() prefix_groups = {} for metric in with_prefix: prefix = metric.split("/")[0] if prefix not in prefix_groups: prefix_groups[prefix] = [] prefix_groups[prefix].append(metric) sorted_with_prefix = [] for prefix in sorted(prefix_groups.keys()): sorted_with_prefix.extend(sorted(prefix_groups[prefix])) return no_prefix + sorted_with_prefix def configure(request: gr.Request): if metrics := request.query_params.get("metrics"): return metrics.split(",") else: return [] with gr.Blocks(theme="citrus", title="Trackio Dashboard", css=css) as demo: with gr.Sidebar() as sidebar: gr.Markdown( f"