import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker

from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS
from demo.utils import (
    BANNER,
    FOOTER,
    CSS,
    UNSET,
    EmbeddingProgressTracker,
    compute_ratio,
    validate_dataset,
    preprocess_dataset,
    ensure_dataset_is_loaded,
)
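
# disable on-disk caching in the `datasets` library so each demo run loads data fresh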
disable_caching()

with gr.Blocks(css=CSS, theme=None) as demo:
    gr.Markdown(BANNER, elem_classes="banner")

    ##### 1. Load from datasets #####
    gr.Markdown("## 📚 Load Data")
    gr.Markdown(
        "Pick a dataset from the Hugging Face Hub (e.g. `trec`). This defines your downstream task."
    )

    with gr.Group():
        dataset = gr.State(None)
        dataset_id = gr.Textbox(
            label="Dataset identifier",
            placeholder="try: trec, conll2003, ag_news",
            max_lines=1,
        )

        load_dataset_button = gr.Button(
            value="Load data",
            variant="primary",
            interactive=True,
        )

        # enable loading if dataset exists on hub
        dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
    gr.Markdown(
        "⚡️ Speed mode is on: adjust the downsampling ratio in *Dataset Setup* for quicker runs. "
        "To rank on the full dataset, use the [transformer-ranker framework](https://github.com/flairNLP/transformer-ranker) directly."
    )

    ##### data preprocessing #####
    with gr.Accordion("Dataset Setup", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_id_label = gr.Label("", label="Dataset")
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Dataset size")
            num_samples.change(lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label])

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair")

        with gr.Row():
            label_column = gr.Dropdown("", label="Labels")
            task_category = gr.Dropdown("", label="Downstream Task")

        with gr.Group():
            # the slider picks an absolute sample count; `compute_ratio` turns it into
            # a fraction of the dataset, which is displayed as a percentage
            downsample_ratio = gr.State(0.0)
            sampling_rate = gr.Slider(20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1)
            downsample_ratio_label = gr.Label("", label="Sampling rate")

            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

            sampling_rate.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )

            num_samples.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )
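
    # fetch the dataset from the Hub and infer its text/label columns, task category, and size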
    def load_hf_dataset(dataset_id):
        try:
            dataset = load_dataset(dataset_id, trust_remote_code=True)
            dataset_details = preprocess_dataset(dataset)
        except ValueError as e:
            # abort the update instead of falling through with unbound locals
            raise gr.Error(f"Single datasets only. Cannot load dataset: {e}")
        return (gr.update(value="Loaded"), dataset_id, dataset, *dataset_details)
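
    # output order must match the tuple returned by load_hf_dataset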
    load_dataset_button.click(
        load_hf_dataset,
        inputs=[dataset_id],
        outputs=[
            load_dataset_button,
            dataset_id_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## 2. Select LMs ##########
    gr.Markdown("## 🧠 Select Language Models")
    gr.Markdown(
        "Add two or more pretrained models to compare. "
        "Stick to smaller models here, since the demo runs on CPU."
    )

    with gr.Group():
        # show short model names in the UI, but pass full Hub handles as values
        model_options = [(model_handle.split("/")[-1], model_handle) for model_handle in ALL_LMS]
        models = gr.CheckboxGroup(choices=model_options, label="Model List", value=PRESELECTED_LMS)

    ########## 3. Run ranking ##########
    gr.Markdown("## 🏆 Rank Models")
    gr.Markdown(
        "Rank models by how well they transfer to your task. "
        "Need more control? Adjust the transferability metric and layer aggregation in *Advanced Settings*."
    )

    with gr.Group():
        submit_button = gr.Button("Run ranking", variant="primary", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_aggregator = gr.Dropdown(
                    choices=["lastlayer", "layermean", "bestlayer"],
                    label="Layer aggregation",
                    value="layermean",
                )

    # enable the ranking button only once a dataset is loaded and its columns are set
    for component in (dataset, text_column, label_column):
        component.change(
            ensure_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button,
        )
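
    # embed the dataset with each selected model and score transferability for the task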
    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_aggregator,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        # required column selections
        if text_column == UNSET:
            raise gr.Error("Text column is required.")
        if label_column == UNSET:
            raise gr.Error("Label column is required.")
        if task_category == UNSET:
            raise gr.Error("Task category is required.")

        # the text-pair column is optional
        if text_pair_column == UNSET:
            text_pair_column = None

        progress(0.0, "Starting")
        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )
                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_aggregator,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )

                # sort models by transferability score, best first
                sorted_results = sorted(results._results.items(), key=lambda item: item[1], reverse=True)
                return [(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)]
            except Exception as e:
                gr.Warning(f"Ranking issue: {e}")
                return []

    gr.Markdown("**Leaderboard:** higher score → better downstream performance.")
    ranking_results = gr.Dataframe(
        headers=["Rank", "Model", "Score"],
        datatype=["number", "str", "number"],
        value=[["-", "-", "-"]],
        interactive=False,
    )

    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_aggregator,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(FOOTER)
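
# launch the demo when this file is run directly (assumes a standard `python app.py` entry point)
if __name__ == "__main__":
    demo.launch()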