import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
import traceback

sys.path.append('./src')
sys.path.append('.')

from huggingface_hub import HfApi

repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe

# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------


def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate a submission, run the PROBE evaluation and optionally save results.

    Args:
        human_file: uploaded Human representation CSV (filepath) or None.
        skempi_file: uploaded SKEMPI representation CSV (filepath) or None.
        model_name_textbox: method name the results are reported under.
        benchmark_types: user-facing benchmark labels selected in the UI.
        similarity_tasks: user-facing similarity dataset labels.
        function_prediction_aspect: user-facing aspect label, or None if unset.
        function_prediction_dataset: function dataset identifier; passed to
            ``run_probe`` unmapped.
        family_prediction_dataset: user-facing family dataset labels.
        save: when truthy, results are persisted via ``save_results``.

    Returns:
        0 on success, -1 when validation or the evaluation fails (a
        ``gr.Warning`` explains why).
    """
    representation_name = (model_name_textbox or "").strip()
    if not representation_name:
        gr.Warning("Please provide a method name!")
        return -1
    if not benchmark_types:
        gr.Warning("Please select at least one benchmark type!")
        return -1

    # Map the user-facing labels back to the original codes.  The codes (not
    # the labels) are what the validation below and run_probe/save_results
    # work with — e.g. "similarity"/"affinity", which are also the prefixes of
    # the cached result CSVs.  NOTE(review): assumes the *_map dicts are
    # label -> code, as the original comment states.
    try:
        benchmark_types_mapped = [benchmark_type_map[b] for b in benchmark_types]
        similarity_tasks_mapped = [similarity_tasks_map[s] for s in similarity_tasks or []]
        # The aspect Radio is None until the user picks one; that is only an
        # error when the function benchmark is requested (checked below).
        function_prediction_aspect_mapped = (
            None
            if function_prediction_aspect is None
            else function_prediction_aspect_map[function_prediction_aspect]
        )
        family_prediction_dataset_mapped = [
            family_prediction_dataset_map[f] for f in family_prediction_dataset or []
        ]
    except KeyError as e:
        gr.Warning(f"Unrecognized option: {e.args[0]}")
        return -1

    # validate inputs (against the mapped codes)
    if any(task in benchmark_types_mapped for task in ("similarity", "family", "function")) and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if "function" in benchmark_types_mapped and function_prediction_aspect_mapped is None:
        gr.Warning("Please select an aspect for the function prediction benchmark!")
        return -1
    if "affinity" in benchmark_types_mapped and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")

    try:
        results = run_probe(
            benchmark_types_mapped,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks_mapped,
            function_prediction_aspect_mapped,
            function_prediction_dataset,
            family_prediction_dataset_mapped,
        )
    except Exception:
        # UI boundary: report to the user, but keep the traceback in the
        # server log so failures can be diagnosed.
        traceback.print_exc()
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        save_results(representation_name, benchmark_types_mapped, results)
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")

    return 0


def refresh_data():
    """Re-download the result CSVs from the HF Hub, then restart the Space.

    Deletes any cached ``/tmp/<name>_results.csv`` (the four benchmarks plus
    ``leaderboard``), fetches the four benchmark CSVs again with
    ``download_from_hub`` and only then asks the Hub to restart the Space, so
    the cache refresh does not run after the restart has been requested.
    Returns nothing; the Refresh button rebuilds the table in a follow-up step.
    """
    benchmark_types = ["similarity", "function", "family", "affinity"]
    for cache_name in benchmark_types + ["leaderboard"]:
        path = f"/tmp/{cache_name}_results.csv"
        if os.path.exists(path):
            os.remove(path)
    download_from_hub(benchmark_types)
    api.restart_space(repo_id=repo_id)


# ------- Leaderboard helpers -----------------------------------------------

def update_metrics(selected_benchmarks):
    """Return the metric columns belonging to *selected_benchmarks*.

    The result follows the column order of the leaderboard (``metric_names``)
    so the table layout does not depend on set iteration order, which varies
    between processes.  ``benchmark_metric_mapping`` and ``metric_names`` are
    module globals defined in the UI section below.
    """
    wanted = set()
    for benchmark in selected_benchmarks or []:
        wanted.update(benchmark_metric_mapping.get(benchmark, []))
    return [metric for metric in metric_names if metric in wanted]


def update_leaderboard(selected_methods, selected_metrics):
    """Build the leaderboard styler for the given methods and metrics."""
    return build_leaderboard_styler(selected_methods, selected_metrics)


def colour_method_html(name: str) -> str:
    """Return the method string wrapped in a coloured ``<span>``.

    Handles raw names or markdown links like '[T5](https://…)' transparently:
    the colour is looked up in ``color_dict`` with link syntax stripped, and
    falls back to black.
    """
    colour = color_dict.get(re.sub(r"\[|\]|\(.*?\)", "", name), "black")  # strip md link
    return f"<span style='color:{colour}'>{name}</span>"


# darkest → lightest green
TOP5_GREENS = ["#006400", "#228B22", "#32CD32", "#7CFC00", "#ADFF2F"]

# Legend swatches generated from TOP5_GREENS so the legend always matches the
# colours shade_top5 applies.
TOP5_LEGEND_HTML = " ".join(
    f"<span style='background-color:{colour}; padding:0 6px;'>{rank}</span>"
    for rank, colour in enumerate(TOP5_GREENS, start=1)
)


def shade_top5(col: pd.Series) -> list[str]:
    """Return a CSS list for one column: background for the top ranks, blank else.

    Ranks are descending (largest value = rank 1); ties are broken by row order
    (``method="first"``); NaN cells stay unshaded.
    NOTE(review): assumes higher is better for every metric — confirm for any
    error-type metric.
    """
    if not pd.api.types.is_numeric_dtype(col):
        return [""] * len(col)
    ranks = col.rank(ascending=False, method="first")
    return [
        f"background-color:{TOP5_GREENS[int(r) - 1]};" if r <= len(TOP5_GREENS) else ""
        for r in ranks
    ]


def build_leaderboard_styler(selected_methods=None, selected_metrics=None):
    """Return a pandas Styler for the leaderboard table.

    Values are rounded to 4 decimals, rows sorted case-insensitively by method
    name, method names coloured, and the top numeric cells per column shaded.
    """
    df = get_baseline_df(selected_methods, selected_metrics).round(4)
    df = (
        df.sort_values("Method", key=lambda s: s.str.lower())  # A->Z
        .reset_index(drop=True)  # tidy row index
    )
    df["Method"] = df["Method"].apply(colour_method_html)

    numeric_cols = [c for c in df.columns if c != "Method"]
    styler = (
        df.style
        .apply(shade_top5, axis=0, subset=numeric_cols)
        .format(precision=4)
    )
    return styler


# ------- Visualisation helpers ---------------------------------------------

def generate_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    """Render a benchmark plot with ``benchmark_plot`` and return its result
    (shown in a ``gr.Image``; presumably a file path)."""
    return benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )


# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
    position: sticky;
    left: 0;
    z-index: 2;
    /* wider “Method” column */
    min-width: 190px;
    width: 190px;
    white-space: nowrap;
}

/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
    text-align: center;
}

/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
    max-height: 1200px;
    overflow-y: auto;
    overflow-x: auto;
}
"""

# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
block = gr.Blocks(css=CUSTOM_CSS)

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣ Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):

            # ── header ────────────────────────────────────────────────────
            gr.Image(
                value="./src/data/PROBE_workflow_figure.jpg",
                show_label=False,
                height=1000,
                container=False,
            )
            gr.Markdown(
                "## For detailed explanations of the metrics and benchmarks, please refer to the 📝 About tab.",
                elem_classes="leaderboard-note",
            )

            # ── data prep ────────────────────────────────────────────────
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard["Method"].unique().tolist()
            metric_names = [c for c in leaderboard.columns if c != "Method"]

            base_method_names = [m for m in method_names if m in base_methods]
            user_method_names = [m for m in method_names if m not in base_methods]

            benchmark_metric_mapping = {
                "Semantic Similarity Inference": [m for m in metric_names if m.startswith("sim_")],
                "Ontology-based Protein Function Prediction": [m for m in metric_names if m.startswith("func")],
                "Drug Target Protein Family Classification": [m for m in metric_names if m.startswith("fam_")],
                "Protein-Protein Binding Affinity Estimation": [m for m in metric_names if m.startswith("aff_")],
            }

            # ── callback helper ──────────────────────────────────────────
            def update_leaderboard_combined(selected_base, selected_user, selected_metrics):
                """Rebuild the table from the union of base and user method selections."""
                selected_methods = (selected_base or []) + (selected_user or [])
                return build_leaderboard_styler(selected_methods, selected_metrics)

            # ── collapsible selectors ────────────────────────────────────
            with gr.Accordion("📦 Base Methods", open=False):
                leaderboard_method_selector_base = gr.CheckboxGroup(
                    choices=base_method_names,
                    label="Base Methods",
                    value=base_method_names,  # ← all selected
                    interactive=True,
                )

            with gr.Accordion("🛠️ User-defined Methods", open=False):
                leaderboard_method_selector_user = gr.CheckboxGroup(
                    choices=user_method_names,
                    label="User Methods",
                    value=[],  # ← none selected
                    interactive=True,
                )

            with gr.Accordion("🧪 Benchmark Types", open=False):
                benchmark_type_selector_lb = gr.CheckboxGroup(
                    choices=list(benchmark_metric_mapping.keys()),
                    label="Benchmark Types",
                    value=list(benchmark_metric_mapping.keys()),  # all selected
                    interactive=True,
                )

            with gr.Accordion("📐 Metrics", open=False):
                leaderboard_metric_selector = gr.CheckboxGroup(
                    choices=metric_names,
                    label="Select Metrics",
                    value=metric_names,  # ← all selected
                    interactive=True,
                )

            # ── colour / shading legend ──────────────────────────────────
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(
                        """
## Method-name colours
- 🟢 Classical representations
- 🔵 Small-scale Protein LMs
- 🔴 Large-scale Protein LMs
- 🟠 Multimodal Protein LMs
""",
                        elem_classes="leaderboard-note",
                    )
                with gr.Column(scale=1):
                    gr.Markdown(
                        f"""
## Metric-cell shading
{TOP5_LEGEND_HTML} &nbsp; top-five scores (darker → better)
""",
                        elem_classes="leaderboard-note",
                    )

            # ── dataframe ────────────────────────────────────────────────
            styler = build_leaderboard_styler(base_method_names, metric_names)

            data_component = gr.Dataframe(
                value=styler,
                headers=["Method"] + metric_names,
                type="pandas",
                datatype=["markdown"] + ["number"] * len(metric_names),
                interactive=False,
                elem_id="leaderboard-table",
                pinned_columns=1,
                max_height=1000,
                show_fullscreen_button=True,
            )

            gr.Markdown("#### If a method name ends with **^**, it suggests potential suspicions of data leakage related to ***similarity***, ***function***, or ***family*** benchmarks.")

            # ── callbacks ────────────────────────────────────────────────
            # Any change to a method or metric selection rebuilds the table.
            leaderboard_inputs = [
                leaderboard_method_selector_base,
                leaderboard_method_selector_user,
                leaderboard_metric_selector,
            ]
            for selector in leaderboard_inputs:
                selector.change(
                    update_leaderboard_combined,
                    inputs=leaderboard_inputs,
                    outputs=data_component,
                )

            # Picking benchmark types rewrites the metric selection, which in
            # turn fires the metric selector's change handler above.
            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )

        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            gr.Markdown(
                """## **Interactive Visualizations**
Choose a benchmark type; context-specific options will appear.""",
                elem_classes="markdown-text",
            )

            # ── benchmark-type selector ──────────────────────────────────
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="🧪 Benchmark Type",
                value=None,
            )

            # ── metric / dataset selectors (appear contextually) ─────────
            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="X-axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Y-axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)

            # ── method selectors (two accordions) ───────────────────────
            # Reuses base_method_names / user_method_names from the leaderboard tab.
            with gr.Accordion("📦 Base methods", open=False):
                vis_method_selector_base = gr.CheckboxGroup(
                    choices=base_method_names,
                    label="Base Methods",
                    value=base_method_names,  # default: all selected
                    interactive=True,
                )

            with gr.Accordion("🛠️ User-defined methods", open=False):
                vis_method_selector_user = gr.CheckboxGroup(
                    choices=user_method_names,
                    label="User Methods",
                    value=[],  # default: none selected
                    interactive=True,
                )

            # ── plot button & output ────────────────────────────────────
            plot_button = gr.Button("Plot")

            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")

            gr.Markdown("#### If a method name ends with **^**, it suggests potential suspicions of data leakage related to ***similarity***, ***function***, or ***family*** benchmarks.")

            # ── callbacks ───────────────────────────────────────────────
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )

            # combine the two method lists, then call the original helper
            plot_button.click(
                lambda bt, base_sel, user_sel, xm, ym, asp, ds, sm: generate_plot(
                    benchmark_type_map.get(bt, bt),
                    (base_sel or []) + (user_sel or []),  # merged method list
                    xm,
                    ym,
                    asp,
                    ds,
                    sm,
                ),
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector_base,
                    vis_method_selector_user,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output],
            )

        # ------------------------------------------------------------------
        # 3️⃣ About tab
        # ------------------------------------------------------------------
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-about", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        # ------------------------------------------------------------------
        # 4️⃣ Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-submit", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")

            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
                    similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Semantic Similarity Inference Datasets", interactive=True)
                    function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Ontology-based Function Prediction Aspects", interactive=True)
                    family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Drug Target Protein Family Classification Datasets", interactive=True)
                    function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
                    save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)

            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')

            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )

    # global refresh + citation ---------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        # refresh_data returns nothing, so it must not be wired straight to the
        # table (that would blank it); rebuild the table afterwards instead,
        # keeping the user's current selections.
        data_run.click(refresh_data).then(
            update_leaderboard_combined,
            inputs=leaderboard_inputs,
            outputs=data_component,
        )

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

# ---------------------------------------------------------------------------
block.launch()