lukasgarbas's picture
new build
2018b94
raw
history blame
6.87 kB
import math
import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder
# HTML banner: title, one-line tagline, badge links (code repository, MIT
# license, PyPI package, tutorials) and the developing institution.
# Presumably rendered at the top of the Gradio UI (usage not shown here).
BANNER = """
<h1 align="center">🔥 TransformerRanker 🔥</h1>
<p align="center" style="max-width: 560px; margin: auto;">
Find the best language model for your downstream task.
Load a dataset, select models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
</p>
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
<a href="https://github.com/flairNLP/transformer-ranker">
<img src="https://img.shields.io/badge/Code Repo-black?style=flat&logo=github" alt="repository">
</a>
<a href="https://opensource.org/licenses/MIT">
<img src="https://img.shields.io/badge/License-MIT-brightgreen?style=flat" alt="license">
</a>
<a href="https://pypi.org/project/transformer-ranker/">
<img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="package">
</a>
<a href="https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md">
<img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="tutorials">
</a>
</p>
<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
"""
# Markdown footer: CPU-only demo note, author handles, and a link to the
# project's GitHub issue tracker for questions.
FOOTER = """
**Note:** CPU-only quick demo. **Built by:** @lukasgarbas & @plonerma
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫.
"""
# Custom stylesheet: caps the width of the `.gradio-container` element at
# 800px and centres it horizontally via auto margins.
CSS = """
.gradio-container {
max-width: 800px;
margin: auto;
}
"""
# Placeholder value meaning "nothing selected" in the settings dropdowns,
# e.g. when a text/label column or task category could not be auto-detected.
UNSET = "-"
# Shared Hugging Face Hub client; validate_dataset uses it to look up datasets.
hf_api = HfApi()
# Shared cleaner whose column / task-category detection helpers are used by
# preprocess_dataset.
preprocessing = DatasetCleaner()
def validate_dataset(dataset_name):
    """Enable or disable a component depending on whether a Hub dataset exists.

    Args:
        dataset_name: Repository id typed by the user.

    Returns:
        ``gr.update(interactive=True)`` if ``hf_api.dataset_info`` succeeds.
        If the lookup raises ``HTTPError`` or ``HFValidationError`` (e.g. a
        missing repository or a malformed id), returns an update that disables
        the component and resets its value to "Load data".
    """
    try:
        # Only the success/failure of the metadata lookup matters here.
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        return gr.update(value="Load data", interactive=False)
    return gr.update(interactive=True)
def _detect_column(data, column_role, warning):
    """Return the column auto-detected for ``column_role``, or ``UNSET``.

    If ``preprocessing._find_column`` raises ``ValueError``, ``warning`` is
    shown in the UI via ``gr.Warning`` and ``UNSET`` is returned instead.
    """
    try:
        return preprocessing._find_column(data, column_role)
    except ValueError:
        gr.Warning(warning)
        return UNSET


def preprocess_dataset(dataset):
    """Auto-detect text/label columns and the task category, and fill the settings dropdowns.

    Args:
        dataset: Mapping of splits to datasets; its ``.values()`` are
            concatenated before detection (presumably a ``DatasetDict``).

    Returns:
        A 5-tuple ``(task_category, text_column, text_pair, label_column,
        sample_size)``. The first four are ``gr.update`` objects for dropdowns.
        ``sample_size`` is the number of rows across all splits. Any value
        that could not be detected is ``UNSET``, and a warning is shown for it.
    """
    data = concatenate_datasets(list(dataset.values()))
    columns = data.column_names

    text_column = _detect_column(data, "text column", "Text column not auto-detected — select in settings.")
    label_column = _detect_column(data, "label column", "Label column not auto-detected — select in settings.")

    # The task category can only be inferred from a known label column.
    task_category = UNSET
    if label_column != UNSET:
        try:
            task_category = preprocessing._find_task_category(data, label_column)
        except ValueError:
            gr.Warning("Task category not auto-detected — framework supports classification, regression.")

    # The choices are the string forms of TaskCategory, so the selected value
    # is converted the same way to guarantee it matches one of them.
    task_choices = [str(t) for t in TaskCategory]

    return (
        gr.update(value=str(task_category), choices=task_choices, interactive=True),
        gr.update(value=text_column, choices=columns, interactive=True),
        # The text pair is optional: default to UNSET and offer it as a choice.
        gr.update(value=UNSET, choices=[UNSET, *columns], interactive=True),
        gr.update(value=label_column, choices=columns, interactive=True),
        len(data),
    )
def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of samples to use, ``num_samples_to_use / num_samples``.

    When ``num_samples`` is not positive, ``0.0`` is returned instead of
    dividing by it.
    """
    return num_samples_to_use / num_samples if num_samples > 0 else 0.0
def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable the next step only when a dataset is loaded and all settings are chosen.

    Args:
        dataset: The loaded dataset, or a falsy value if none is loaded yet.
        text_column: Selected text column.
        label_column: Selected label column.
        task_category: Selected task category.

    Returns:
        ``gr.update(interactive=True)`` when ``dataset`` is truthy and none of
        the three selections is missing; otherwise ``gr.update(interactive=False)``.
    """
    # Besides the UNSET placeholder, a dropdown may report None or "" when it
    # holds no selection (e.g. after being cleared); those must not count as chosen.
    missing_markers = (None, "", UNSET)
    selections = (text_column, label_column, task_category)
    is_ready = bool(dataset) and all(value not in missing_markers for value in selections)
    return gr.update(interactive=is_ready)
def ensure_one_lm_selected(checkbox_values, previous_values):
    """Prevent the model selection from becoming empty.

    Args:
        checkbox_values: The newly selected values.
        previous_values: The selection before this change.

    Returns:
        ``checkbox_values`` if it contains at least one truthy entry,
        otherwise ``previous_values`` (i.e. the change is reverted).
    """
    has_selection = any(checkbox_values)
    return checkbox_values if has_selection else previous_values
# Monkey-patch ``Embedder.embed`` so an attached progress tracker is told how
# many batches the upcoming embedding run will take.
_old_embed = Embedder.embed


def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    """Announce the expected batch count to the tracker, then run the original ``embed``.

    Args:
        embedder: The ``Embedder`` instance (the wrapped method's ``self``).
        sentences: Input forwarded unchanged to the original ``embed``.
        batch_size: Forwarded to the original ``embed`` and used to compute
            the expected number of batches. NOTE(review): the default of 32
            is assumed to match the library's own default — confirm.
        **kw: Further keyword arguments, forwarded unchanged.

    Returns:
        Whatever the original ``Embedder.embed`` returns.
    """
    # getattr: an instance that never went through a ``tracker``-aware
    # constructor is treated as untracked instead of raising AttributeError.
    tracker = getattr(embedder, "tracker", None)
    if tracker is not None:
        # len() of a bare string counts characters, not sentences; count it as
        # one input (presumably how the original embed treats it — verify).
        num_sentences = 1 if isinstance(sentences, str) else len(sentences)
        tracker.update_num_batches(math.ceil(num_sentences / batch_size))
    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)


Embedder.embed = _new_embed
# Monkey-patch ``Embedder.embed_batch`` so the tracker hears about each finished batch.
_old_embed_batch = Embedder.embed_batch


def _new_embed_batch(embedder, *args, **kw):
    """Run the original ``embed_batch``, then notify the tracker if one is attached.

    All arguments are forwarded unchanged and the original return value is
    passed back. The tracker is notified only after the call returns, so a
    batch that raises is not counted as complete.
    """
    batch_output = _old_embed_batch(embedder, *args, **kw)
    tracker = embedder.tracker
    if tracker is not None:
        tracker.update_batch_complete()
    return batch_output


Embedder.embed_batch = _new_embed_batch
# Monkey-patch ``Embedder.__init__`` to accept an optional ``tracker`` keyword.
_old_init = Embedder.__init__


def _new_init(embedder, *args, tracker=None, **kw):
    """Call the original constructor, then store ``tracker`` on the instance.

    ``tracker`` is keyword-only and is consumed here instead of being
    forwarded, so the original ``__init__`` never receives it. With the
    default ``None``, the patched ``embed``/``embed_batch`` skip progress
    reporting.
    """
    _old_init(embedder, *args, **kw)
    # Attached only after the original constructor has completed.
    embedder.tracker = tracker


Embedder.__init__ = _new_init
class EmbeddingProgressTracker:
    """Context manager that turns embedding callbacks into progress-bar updates.

    Each call to :meth:`update_num_batches` is treated as the start of the next
    model. Each call to :meth:`update_batch_complete` marks one batch of that
    model as done. The reported progress is
    ``(model_index + fraction_of_batches_done) / number_of_models``, clamped
    to ``[0, 1]``.

    Args:
        progress: Callable stored as ``progress_bar`` and invoked as
            ``progress_bar(progress=..., desc=...)``. NOTE(review):
            ``__enter__`` replaces it with a fresh
            ``gr.Progress(track_tqdm=False)``, so the value passed here is
            only used when the tracker is driven without entering the
            context — confirm this is intended.
        model_names: Names of the models, in the order they are embedded.
    """

    def __init__(self, *, progress, model_names):
        self.model_names = model_names
        self.progress_bar = progress
        # Initialise the counters here too, so the update methods cannot hit
        # AttributeError if they are called before ``__enter__``.
        self._reset()

    @property
    def total(self):
        """Number of models being tracked."""
        return len(self.model_names)

    def _reset(self):
        """Put the counters into the "no model started yet" state."""
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    def __enter__(self):
        self.progress_bar = gr.Progress(track_tqdm=False)
        self._reset()
        return self

    def __exit__(self, typ, value, tb):
        self.progress_bar(1.0, desc="Done" if typ is None else "Error")
        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        """Advance to the next model, which will run ``total`` batches."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch of the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the current overall progress and description to the progress bar.

        Progress reporting must never abort the ranking run. An empty model
        list therefore reports nothing. A model index outside
        ``[0, total - 1]`` (a batch before any model started, or more runs
        than models) is clamped. A zero batch count contributes no
        within-model progress instead of dividing by zero.
        """
        total = self.total
        if total == 0:
            return
        i = min(max(self.current_model, 0), total - 1)
        description = f"Running {self.model_names[i]} ({i + 1} / {total})"
        progress = i / total
        if self.batches_total:  # skips both None (not started) and 0 (empty input)
            fraction_done = min(self.batches_complete / self.batches_total, 1.0)
            progress += fraction_done / total
        self.progress_bar(progress=min(progress, 1.0), desc=description)