# Commit 9867f8a (lukasgarbas): use markdown for demo banner
import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker
from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
from demo.utils import (
BANNER,
FOOTER,
CSS,
UNSET,
EmbeddingProgressTracker,
compute_ratio,
validate_dataset,
preprocess_dataset,
ensure_dataset_is_loaded,
)
# Keep HF datasets from writing cache files in the demo environment.
disable_caching()

with gr.Blocks(css=CSS, theme=None) as demo:
    gr.Markdown(BANNER, elem_classes="banner")

    ##### 1. Load from datasets #####
    gr.Markdown("## 📚 Load Data")
    gr.Markdown(
        "Pick a dataset from the Hugging Face Hub (e.g. `trec`). This defines your downstream task."
    )

    with gr.Group():
        dataset = gr.State(None)  # holds the loaded dataset object across callbacks
        dataset_id = gr.Textbox(
            label="Dataset identifier",
            placeholder="try: trec, conll2003, ag_news",
            max_lines=1,
        )
        load_dataset_button = gr.Button(
            value="Load data",
            variant="primary",
            interactive=True,
        )

    # enable loading if dataset exists on hub
    dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)

    gr.Markdown(
        "⚡️ Speed mode on: tweak the downsampling ratio in *Dataset Setup* for quicker runs. "
        "Unlock the full data via [framework](https://github.com/flairNLP/transformer-ranker)."
    )

    ##### data preprocessing #####
    with gr.Accordion("Dataset Setup", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_id_label = gr.Label("", label="Dataset")
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Dataset size")

        # mirror the numeric state into its display label
        num_samples.change(lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label])

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair")

        with gr.Row():
            label_column = gr.Dropdown("", label="Labels")
            task_category = gr.Dropdown("", label="Downstream Task")

        with gr.Group():
            downsample_ratio = gr.State(0.0)
            sampling_rate = gr.Slider(20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1)
            downsample_ratio_label = gr.Label("", label="Sampling rate")

        # show the ratio as a percentage whenever it changes
        downsample_ratio.change(
            lambda x: f"{x:.1%}",
            inputs=[downsample_ratio],
            outputs=[downsample_ratio_label],
        )
        # recompute the ratio when either the slider or the dataset size moves
        sampling_rate.change(
            compute_ratio,
            inputs=[sampling_rate, num_samples],
            outputs=downsample_ratio,
        )
        num_samples.change(
            compute_ratio,
            inputs=[sampling_rate, num_samples],
            outputs=downsample_ratio,
        )

    def load_hf_dataset(dataset_id):
        """Download *dataset_id* from the Hub and pre-fill the setup widgets.

        Returns one value per component in the ``load_dataset_button.click``
        outputs list below (eight in total); ``preprocess_dataset`` is
        presumably responsible for the trailing five — TODO confirm.
        On failure, every output is left unchanged.
        """
        try:
            dataset = load_dataset(dataset_id, trust_remote_code=True)
            dataset_details = preprocess_dataset(dataset)
        except ValueError as e:
            gr.Warning(f"Watch out — single datasets only. Cannot load dataset: {e}")
            # BUG FIX: the original fell through to the success return with
            # `dataset` and `dataset_details` unbound, raising NameError.
            # Return a no-op update for each of the eight outputs instead.
            return tuple(gr.update() for _ in range(8))
        return (gr.update(value="Loaded"), dataset_id, dataset, *dataset_details)

    load_dataset_button.click(
        load_hf_dataset,
        inputs=[dataset_id],
        outputs=[
            load_dataset_button,
            dataset_id_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## 2. Select LMs ##########
    gr.Markdown("## 🧠 Select Language Models")
    gr.Markdown(
        "Add two or more pretrained models to compare. "
        "Stick to smaller models here since the demo runs on CPU."
    )

    with gr.Group():
        # display only the model name, keep the full hub handle as the value
        model_options = [(model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS]
        models = gr.CheckboxGroup(choices=model_options, label="Model List", value=PRESELECTED_LMS)

    ########## 3. Run ranking ##########
    gr.Markdown("## 🏆 Rank Models")
    gr.Markdown(
        "Rank models by transferability to your task. "
        "More control? Tweak transferability metric and layer aggregation in *Settings*."
    )

    with gr.Group():
        submit_button = gr.Button("Run ranking", variant="primary", interactive=False)
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_aggregator = gr.Dropdown(
                    choices=["lastlayer", "layermean", "bestlayer"],
                    label="Layer aggregation",
                    value="layermean",
                )

    # ranking button works after dataset loads: re-validate whenever any of
    # the relevant inputs changes (deduplicated from three identical wirings)
    for trigger in (dataset, label_column, text_column):
        trigger.change(
            ensure_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button,
        )

    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_aggregator,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        """Rank *selected_models* by transferability to the loaded dataset.

        Raises gr.Error when a required column/category is still UNSET.
        Returns leaderboard rows as (rank, model, score) tuples sorted by
        descending score, or [] when ranking fails (a warning is shown).
        """
        if text_column == UNSET:
            raise gr.Error("Text column is required.")
        if label_column == UNSET:
            raise gr.Error("Label column is required.")
        if task_category == UNSET:
            raise gr.Error("Task category is required.")
        if text_pair_column == UNSET:
            text_pair_column = None  # the pair column is optional

        progress(0.0, "Starting")
        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )
                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_aggregator,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )
                # NOTE(review): relies on the private `_results` dict of the
                # ranker result — confirm against the transformer-ranker API
                # if this breaks on upgrade.
                sorted_results = sorted(results._results.items(), key=lambda item: item[1], reverse=True)
                return [(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)]
            except Exception as e:
                # best-effort UI: surface the failure and clear the table
                gr.Warning(f"Ranking issue: {e}")
                return []

    gr.Markdown("**Leaderboard:** higher score → better downstream performance.")
    ranking_results = gr.Dataframe(
        headers=["Rank", "Model", "Score"],
        datatype=["number", "str", "number"],
        value=[["-", "-", "-"]],
        interactive=False,
    )

    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_aggregator,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(FOOTER)