# Source: Hugging Face Space by lukasgarbas — "new build" (commit 2018b94, 8.44 kB).
import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker
from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS, GRADIO_THEME
from demo.utils import (
BANNER, FOOTER, CSS, UNSET,
EmbeddingProgressTracker, compute_ratio,
validate_dataset, preprocess_dataset, ensure_dataset_is_loaded
)
# Don't persist HF datasets to the on-disk cache — this demo host is ephemeral.
disable_caching()
# Top-level Gradio layout: banner, dataset loading, dataset setup accordion.
# NOTE(review): indentation appears mangled in this paste — code kept verbatim.
with gr.Blocks(css=CSS, theme=None) as demo:
gr.Markdown(BANNER)
##### 1. Load from datasets #####
gr.Markdown("## Load Downstream Dataset")
gr.Markdown(
"Select a dataset from the Hugging Face Hub such as `trec`. "
"This defines your downstream task."
)
with gr.Group():
# Session state holding the loaded dataset object; None until a load succeeds.
dataset = gr.State(None)
dataset_id = gr.Textbox(
label="Dataset name",
placeholder="try: trec, conll2003, ag_news",
max_lines=1,
)
load_dataset_button = gr.Button(value="Load data", variant="primary", interactive=True,)
# enable loading if dataset exists on hub
dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)
gr.Markdown(
"Settings auto-configured. "
"Adjust the downsampling ratio in Dataset Setup, "
"or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
)
##### data preprocessing #####
# Collapsed accordion showing auto-detected columns/task and sampling controls.
with gr.Accordion("Dataset Setup", open=False) as dataset_config:
with gr.Row() as dataset_details:
dataset_id_label = gr.Label("", label="Dataset")
# num_samples is state; the Label below mirrors it as text via the change hook.
num_samples = gr.State(0)
num_samples_label = gr.Label("", label="Dataset size")
num_samples.change(
lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
)
with gr.Row():
text_column = gr.Dropdown("", label="Text Column")
text_pair_column = gr.Dropdown("", label="Text Pair")
with gr.Row():
label_column = gr.Dropdown("", label="Labels")
task_category = gr.Dropdown("", label="Downstream Task")
with gr.Group():
# Fraction of the dataset used for ranking; derived from the slider below.
downsample_ratio = gr.State(0.0)
sampling_rate = gr.Slider(
20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
)
downsample_ratio_label = gr.Label("", label="Sampling rate")
# Render the ratio state as a percentage string whenever it changes.
downsample_ratio.change(
lambda x: f"{x:.1%}",
inputs=[downsample_ratio],
outputs=[downsample_ratio_label],
)
# Recompute the ratio when either the slider or the dataset size changes.
sampling_rate.change(
compute_ratio,
inputs=[sampling_rate, num_samples],
outputs=downsample_ratio,
)
num_samples.change(
compute_ratio,
inputs=[sampling_rate, num_samples],
outputs=downsample_ratio,
)
# load and show details
def load_hf_dataset(dataset_id):
    """Load a Hub dataset and derive column/task defaults for the UI.

    Returns an 8-tuple matching the `load_dataset_button.click` outputs:
    (button update, dataset name, dataset object, task category,
    text column, text pair column, label column, sample count).
    """
    try:
        dataset = load_dataset(dataset_id, trust_remote_code=True)
        dataset_details = preprocess_dataset(dataset)
    except ValueError:
        # Dataset collections bundle several configs; only single datasets
        # are supported here.
        gr.Warning("Collections not supported. Load one dataset only.")
        # BUGFIX: the original fell through and returned the unbound locals
        # `dataset`/`dataset_details`, raising UnboundLocalError. Return
        # no-op updates instead so the previous UI state is preserved.
        return tuple(gr.update() for _ in range(8))
    return (
        gr.update(value="Loaded"),
        dataset_id,
        dataset,
        *dataset_details
    )
# Wire the load button: populate state, labels, and dropdowns from the Hub dataset.
load_dataset_button.click(
load_hf_dataset,
inputs=[dataset_id],
outputs=[
load_dataset_button,
dataset_id_label,
dataset,
task_category,
text_column,
text_pair_column,
label_column,
num_samples,
],
scroll_to_output=True,
)
########## 2. Select LMs ##########
gr.Markdown("## Select Language Models")
gr.Markdown(
"Add two or more pretrained models for ranking. "
"Go with small models since this demo runs on CPU."
)
with gr.Group():
# Show only the model's short name in the UI; keep the full handle as the value.
model_options = [
(model_handle.split("/")[-1], model_handle)
for model_handle in ALL_LMS
]
models = gr.CheckboxGroup(
choices=model_options, label="Model List", value=PRESELECTED_LMS
)
########## 3. Run ranking ##########
gr.Markdown("## Rank Language Models")
gr.Markdown(
"Rank models by transferability to your downstream task. "
"Adjust the metric and layer aggregation in Advanced Settings."
)
with gr.Group():
# Disabled until a dataset with valid columns has been loaded (see below).
submit_button = gr.Button("Run ranking", variant="primary", interactive=False)
with gr.Accordion("Advanced Settings", open=False):
with gr.Row():
estimator = gr.Dropdown(
choices=["hscore", "logme", "knn"],
label="Transferability metric",
value="hscore",
)
layer_aggregator = gr.Dropdown(
choices=["lastlayer", "layermean", "bestlayer"],
label="Layer aggregation",
value="layermean",
)
# ranking button works after dataset loads
# Re-validate on any change to the dataset or its key columns.
dataset.change(
ensure_dataset_is_loaded,
inputs=[dataset, text_column, label_column, task_category],
outputs=submit_button
)
label_column.change(
ensure_dataset_is_loaded,
inputs=[dataset, text_column, label_column, task_category],
outputs=submit_button
)
text_column.change(
ensure_dataset_is_loaded,
inputs=[dataset, text_column, label_column, task_category],
outputs=submit_button
)
def rank_models(
    dataset,
    downsample_ratio,
    selected_models,
    layer_aggregator,
    estimator,
    text_column,
    text_pair_column,
    label_column,
    task_category,
    progress=gr.Progress(),
):
    """Score the selected models on the loaded dataset.

    Returns rows of (rank, model name, transferability score), best first,
    or an empty list if ranking fails.
    """
    # Guard clauses: required columns must be configured before running.
    for chosen, complaint in (
        (text_column, "Text column is not set."),
        (label_column, "Label column is not set."),
    ):
        if chosen == UNSET:
            raise gr.Error(complaint)
    if task_category == UNSET:
        raise gr.Error(
            "Task category not set. Dataset must support classification or regression."
        )
    # The text-pair column is optional; map the sentinel to None.
    text_pair_column = None if text_pair_column == UNSET else text_pair_column

    progress(0.0, "Starting")
    with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
        try:
            ranker = TransformerRanker(
                dataset,
                dataset_downsample=downsample_ratio,
                text_column=text_column,
                text_pair_column=text_pair_column,
                label_column=label_column,
                task_category=task_category,
            )
            results = ranker.run(
                models=selected_models,
                layer_aggregator=layer_aggregator,
                estimator=estimator,
                batch_size=64,
                tracker=tracker,
            )
            # Highest transferability first; attach a 1-based rank to each row.
            ordered = sorted(
                results._results.items(), key=lambda pair: pair[1], reverse=True
            )
            return [
                (position, name, score)
                for position, (name, score) in enumerate(ordered, start=1)
            ]
        except Exception as e:
            # Top-level UI boundary: surface the error as a warning, keep the app up.
            print(e)
            gr.Warning(f"Ranking issue: {e}")
            return []
gr.Markdown("Ranking table → higher scores indicate better downstream performance.")
# Results table; placeholder dashes until the first run completes.
ranking_results = gr.Dataframe(
headers=["Rank", "Model", "Score"],
datatype=["number", "str", "number"],
value=[["-", "-", "-"]]
)
# Run the ranking with the current dataset/column/model/metric selections.
submit_button.click(
rank_models,
inputs=[
dataset,
downsample_ratio,
models,
layer_aggregator,
estimator,
text_column,
text_pair_column,
label_column,
task_category,
],
outputs=ranking_results,
scroll_to_output=True,
)
gr.Markdown(FOOTER)
# Script entry point: serve the demo with a bounded request queue.
if __name__ == "__main__":
# run up to 3 requests at once
demo.queue(default_concurrency_limit=3)
# run with 6 workers
demo.launch(max_threads=6)