File size: 6,871 Bytes
2018b94
 
73d9a01
 
 
 
 
 
 
 
2018b94
 
73d9a01
 
2018b94
 
73d9a01
2018b94
73d9a01
 
2018b94
 
 
 
73d9a01
 
2018b94
73d9a01
2018b94
 
73d9a01
 
2018b94
73d9a01
 
 
 
2018b94
 
73d9a01
 
 
2018b94
 
 
 
73d9a01
 
2018b94
73d9a01
 
2018b94
73d9a01
 
2018b94
 
73d9a01
2018b94
 
73d9a01
 
2018b94
73d9a01
 
2018b94
 
 
73d9a01
 
2018b94
73d9a01
2018b94
 
73d9a01
 
2018b94
73d9a01
2018b94
 
73d9a01
2018b94
 
73d9a01
2018b94
73d9a01
2018b94
73d9a01
2018b94
 
 
 
 
73d9a01
2018b94
 
 
73d9a01
2018b94
73d9a01
 
 
 
 
 
2018b94
73d9a01
 
2018b94
73d9a01
 
2018b94
73d9a01
 
 
2018b94
73d9a01
 
 
 
 
 
 
 
2018b94
 
 
 
 
 
 
73d9a01
 
 
 
 
 
2018b94
73d9a01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import math

import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder

# HTML header for the demo page: title, tagline, badge links and affiliation.
# NOTE(review): presumably rendered by the UI layout elsewhere in this file
# (not visible in this excerpt) — confirm where it is used.
BANNER = """
<h1 align="center">🔥 TransformerRanker 🔥</h1>

<p align="center" style="max-width: 560px; margin: auto;">
    Find the best language model for your downstream task.
    Load a dataset, select models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
</p>

<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
    <a href="https://github.com/flairNLP/transformer-ranker">
        <img src="https://img.shields.io/badge/Code Repo-black?style=flat&logo=github" alt="repository">
    </a>
    <a href="https://opensource.org/licenses/MIT">
        <img src="https://img.shields.io/badge/License-MIT-brightgreen?style=flat" alt="license">
    </a>
    <a href="https://pypi.org/project/transformer-ranker/">
        <img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="package">
    </a>
    <a href="https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md">
        <img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="tutorials">
    </a>
</p>

<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
"""

# Markdown footer: demo note, authors and a link for questions.
FOOTER = """
**Note:** CPU-only quick demo. **Built by:** @lukasgarbas & @plonerma
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫.
"""

# Custom CSS: caps the Gradio app container at 800px wide and centers it.
CSS = """
.gradio-container {
    max-width: 800px;
    margin: auto;
}
"""

# Placeholder dropdown value meaning "nothing selected": used for columns that
# could not be auto-detected, an unknown task category, and "no text pair".
UNSET = "-"

# Shared Hub client; validate_dataset uses it to check that a dataset exists.
hf_api = HfApi()
# Dataset cleaner whose column / task-category detection preprocess_dataset reuses.
preprocessing = DatasetCleaner()


def validate_dataset(dataset_name):
    """Enable or disable a component depending on whether a Hub dataset exists.

    Args:
        dataset_name: Repository id of the dataset on the Hugging Face Hub.

    Returns:
        ``gr.update(interactive=True)`` when ``hf_api.dataset_info`` succeeds.
        If that lookup raises ``HTTPError`` or ``HFValidationError``, returns an
        update that resets the label to "Load data" and disables the component.
    """
    try:
        # Lightweight metadata request; invalid ids and failed Hub lookups
        # surface as the exceptions handled below.
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        return gr.update(value="Load data", interactive=False)
    return gr.update(interactive=True)


def _detect_column(data, column_role, warning):
    """Return the column the cleaner detects for ``column_role``, or ``UNSET``.

    When detection fails (the cleaner raises ``ValueError``), ``warning`` is
    shown to the user as a Gradio warning so they can pick the column manually.
    """
    try:
        return preprocessing._find_column(data, column_role)
    except ValueError:
        gr.Warning(warning)
        return UNSET


def _dropdown_choices(options, value):
    """Return ``options`` as a list, prepending ``UNSET`` when ``value`` is ``UNSET``.

    NOTE(review): Gradio's Dropdown rejects a submitted value that is not among
    its choices unless ``allow_custom_value`` is set. Listing the placeholder
    keeps the "not detected" state valid. Confirm against the Gradio version used.
    """
    choices = list(options)
    return [UNSET, *choices] if value == UNSET else choices


def preprocess_dataset(dataset):
    """Auto-detect the text column, label column and task category of a dataset.

    All splits are concatenated before detection. Anything that cannot be
    detected is left as ``UNSET``, and a warning asks the user to choose it in
    the settings.

    Args:
        dataset: Mapping of split name to dataset (presumably a ``DatasetDict``);
            its ``values()`` are passed to ``concatenate_datasets``.

    Returns:
        Tuple of ``gr.update`` objects for the task-category, text-column,
        text-pair and label-column dropdowns (in that order), followed by the
        total number of samples across all splits.
    """
    data = concatenate_datasets(list(dataset.values()))
    columns = data.column_names

    text_column = _detect_column(
        data, "text column", "Text column not auto-detected — select in settings."
    )
    label_column = _detect_column(
        data, "label column", "Label column not auto-detected — select in settings."
    )

    # Task detection needs a label column; skip it when none was found.
    task_category = UNSET
    if label_column != UNSET:
        try:
            task_category = preprocessing._find_task_category(data, label_column)
        except ValueError:
            gr.Warning("Task category not auto-detected — framework supports classification, regression.")

    task_choices = [str(t) for t in TaskCategory]

    return (
        gr.update(
            value=task_category,
            choices=_dropdown_choices(task_choices, task_category),
            interactive=True,
        ),
        gr.update(
            value=text_column,
            choices=_dropdown_choices(columns, text_column),
            interactive=True,
        ),
        # The text-pair dropdown always offers UNSET and starts on it.
        gr.update(value=UNSET, choices=[UNSET, *columns], interactive=True),
        gr.update(
            value=label_column,
            choices=_dropdown_choices(columns, label_column),
            interactive=True,
        ),
        len(data),
    )


def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction ``num_samples_to_use / num_samples``.

    Yields ``0.0`` when ``num_samples`` is zero or negative, so an empty
    dataset never causes a division by zero.
    """
    return num_samples_to_use / num_samples if num_samples > 0 else 0.0


def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Make a component interactive only once the dataset is fully configured.

    Ready means ``dataset`` is truthy and none of the text column, label column
    or task category is still the ``UNSET`` placeholder.
    """
    selections = (text_column, label_column, task_category)
    is_ready = bool(dataset) and UNSET not in selections
    return gr.update(interactive=is_ready)


def ensure_one_lm_selected(checkbox_values, previous_values):
    """Prevent the model selection from becoming empty.

    Returns ``checkbox_values`` if it contains at least one truthy entry.
    Otherwise returns ``previous_values`` (e.g. after the user unticks every box).
    """
    return checkbox_values if any(checkbox_values) else previous_values


# Monkey-patch Embedder.embed so an attached progress tracker learns how many
# batches the next model run will take. The unpatched method is remembered on
# the wrapper (``_unpatched``). Re-executing this module (e.g. under Gradio's
# reload mode) therefore wraps the original method again instead of stacking a
# second wrapper, which would report every model start twice.
_old_embed = getattr(Embedder.embed, "_unpatched", Embedder.embed)


def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    """Report the expected batch count to ``embedder.tracker`` (if any), then embed."""
    # Instances whose construction bypassed the patched __init__ have no
    # ``tracker`` attribute; treat them as untracked instead of crashing.
    tracker = getattr(embedder, "tracker", None)
    if tracker is not None:
        # NOTE(review): assumes len(sentences) is the number of items the
        # original method splits into batches of ``batch_size`` — confirm.
        tracker.update_num_batches(math.ceil(len(sentences) / batch_size))

    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)


_new_embed._unpatched = _old_embed
Embedder.embed = _new_embed

# Monkey-patch Embedder.embed_batch so an attached progress tracker is notified
# after every finished batch. The unpatched method is kept on the wrapper
# (``_unpatched``), so re-executing this module does not stack wrappers and
# double-count batches.
_old_embed_batch = getattr(Embedder.embed_batch, "_unpatched", Embedder.embed_batch)


def _new_embed_batch(embedder, *args, **kw):
    """Embed one batch, then tell ``embedder.tracker`` (if any) that it completed."""
    result = _old_embed_batch(embedder, *args, **kw)
    # Missing attribute (construction bypassed the patched __init__) means untracked.
    tracker = getattr(embedder, "tracker", None)
    if tracker is not None:
        tracker.update_batch_complete()
    return result


_new_embed_batch._unpatched = _old_embed_batch
Embedder.embed_batch = _new_embed_batch

# Monkey-patch Embedder.__init__ to accept an optional ``tracker`` keyword that
# the patched embed methods report progress to. The unpatched constructor is
# kept on the wrapper (``_unpatched``), so re-executing this module does not
# grow a chain of nested constructor wrappers.
_old_init = getattr(Embedder.__init__, "_unpatched", Embedder.__init__)


def _new_init(embedder, *args, tracker=None, **kw):
    """Run the original constructor, then attach the (optional) progress tracker."""
    _old_init(embedder, *args, **kw)
    embedder.tracker = tracker


_new_init._unpatched = _old_init
Embedder.__init__ = _new_init


class EmbeddingProgressTracker:
    """Context manager that maps per-model, per-batch embedding progress onto a progress bar.

    Each :meth:`update_num_batches` call marks the start of the next model in
    ``model_names``. Each :meth:`update_batch_complete` call advances within that
    model. The overall fraction shown is
    ``(model_index + batches_complete / batches_total) / number_of_models``.
    On exit the bar is set to 1.0 with "Done", or "Error" if an exception escaped.
    """

    def __init__(self, *, progress=None, model_names):
        """
        Args:
            progress: Callable progress bar invoked as ``progress(fraction, desc=...)``,
                e.g. a ``gr.Progress``. When ``None``, a fresh
                ``gr.Progress(track_tqdm=False)`` is created on ``__enter__``.
            model_names: Names of the models, in the order they will be embedded.
        """
        self.model_names = model_names
        self.progress_bar = progress
        # Initialise counters here too, so updates outside a ``with`` block
        # do not fail with AttributeError.
        self._reset()

    @property
    def total(self):
        """Number of models being tracked."""
        return len(self.model_names)

    def _reset(self):
        """Return to the initial state: no model started, batch count unknown."""
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    def __enter__(self):
        # Honour a caller-supplied progress bar; only create one as a fallback.
        if self.progress_bar is None:
            self.progress_bar = gr.Progress(track_tqdm=False)
        self._reset()
        return self

    def __exit__(self, typ, value, tb):
        self.progress_bar(1.0, desc="Done" if typ is None else "Error")

        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        """Advance to the next model, which will run ``total`` batches."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one more finished batch for the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the current fraction and description to the progress bar.

        A progress display must never abort the embedding run, so out-of-range
        states are clamped instead of raising:
        - updates before the first model or beyond the last one,
        - models with zero batches,
        - an empty model list (nothing is reported).
        """
        if not self.model_names:
            return

        i = min(max(self.current_model, 0), self.total - 1)
        description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"

        progress = i / self.total
        # A falsy total (None = unknown, 0 = empty input) adds no partial progress.
        if self.batches_total:
            fraction_done = min(self.batches_complete / self.batches_total, 1.0)
            progress += fraction_done / self.total

        self.progress_bar(progress=progress, desc=description)