File size: 3,639 Bytes
8c0b646
 
d5cc744
13f7861
d5cc744
 
8c0b646
 
d5cc744
fce4c33
d5cc744
 
fce4c33
 
 
 
 
 
 
 
 
d5cc744
802cf91
d5cc744
 
 
 
 
 
 
 
 
802cf91
d5cc744
802cf91
 
d5cc744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fce4c33
2bf7c7e
8c0b646
3566540
8c0b646
 
 
 
 
3566540
8c0b646
3566540
13f7861
8c0b646
b5b3297
38bdfc6
 
 
3566540
38bdfc6
 
 
 
 
 
 
 
 
 
 
8c0b646
802cf91
 
8c0b646
 
3566540
8c0b646
 
3566540
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import numpy as np
import torch
from transformers import pipeline, Pipeline
from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

# Fill-mask pipeline backed by a BERT model fine-tuned on patent text.
unmasker = pipeline("fill-mask", model="anferico/bert-for-patents")

# Demo sentence with one [MASK] slot, plus the dict form some pipelines accept.
example = 'A crustless [MASK] made from two slices of baked bread'
example_dict = {'input_ids': example}

def add_mask(text, size=1):
    """Replace `size` randomly chosen words of `text` with the '[MASK]' token.

    Args:
        text: Whitespace-separated input string.
        size: Number of distinct words to mask; capped at the word count.

    Returns:
        The text with the chosen words replaced by '[MASK]'. An empty or
        all-whitespace string is returned unchanged.
    """
    words = text.split()
    if not words:
        # Original code crashed here: np.random.randint(0) raises ValueError.
        return text
    # choice(..., replace=False) guarantees distinct positions; the original
    # randint() could draw duplicates and mask fewer than `size` words.
    count = min(size, len(words))
    for i in np.random.choice(len(words), size=count, replace=False):
        words[i] = '[MASK]'
    return ' '.join(words)


class TempScalePipe(FillMaskPipeline):
    """Fill-mask pipeline variant that temperature-scales the logits and then
    samples candidate tokens per mask instead of taking a plain argmax/top-k.
    """

    # Divisor applied to the raw logits; >1 flattens the distribution.
    # (Preserves the original hard-coded `/ 1e1`.)
    _TEMPERATURE = 10.0
    # Candidates drawn per masked position (the original sampled 3).
    _NUM_SAMPLES = 3

    def postprocess(self, model_outputs, top_k=3, target_ids=None):
        """Turn masked-LM logits into scored token proposals.

        Args:
            model_outputs: dict with "input_ids" (1, seq) and "logits"
                (1, seq, vocab) tensors, as produced by the parent pipeline.
            top_k: number of proposals to keep per mask (capped at the
                number of sampled candidates and at len(target_ids)).
            target_ids: optional tensor of allowed vocabulary ids.

        Returns:
            A list of proposal dicts for a single mask, or a list of such
            lists when the input contains several [MASK] tokens.
        """
        # Cap top_k if there are targets
        if target_ids is not None and target_ids.shape[0] < top_k:
            top_k = target_ids.shape[0]
        input_ids = model_outputs["input_ids"][0]
        outputs = model_outputs["logits"]

        masked_index = torch.nonzero(input_ids == self.tokenizer.mask_token_id, as_tuple=False).squeeze(-1)

        # Temperature-scale the logits of every masked position.
        logits = outputs[0, masked_index, :] / self._TEMPERATURE
        if target_ids is not None:
            # Restrict the distribution to the allowed targets *before*
            # sampling. (The original sliced probs after sampling, which
            # indexed the wrong axis.)
            logits = logits[..., target_ids]
        probs = logits.softmax(dim=-1)

        # Sample candidate tokens per mask. `sampled` holds indices into the
        # vocab dimension of `probs`, shape (num_masks, _NUM_SAMPLES).
        sampled = torch.multinomial(probs, num_samples=self._NUM_SAMPLES)
        # BUGFIX: the original did `probs[indices]`, which indexes the
        # mask-row axis with token ids and yields a (masks, k, vocab) tensor;
        # gather() selects each sampled token's own probability instead.
        sampled_probs = probs.gather(-1, sampled)

        # Rank the sampled candidates; positions index into `sampled`.
        values, positions = sampled_probs.topk(min(top_k, self._NUM_SAMPLES))
        predictions = sampled.gather(-1, positions)

        result = []
        single_mask = values.shape[0] == 1
        for i, (_values, _predictions) in enumerate(zip(values.tolist(), predictions.tolist())):
            row = []
            for v, p in zip(_values, _predictions):
                if target_ids is not None:
                    # Map the index within target_ids back to a real vocab id.
                    p = target_ids[p].tolist()
                # Copy is important since we're going to modify this array in place
                tokens = input_ids.numpy().copy()
                tokens[masked_index[i]] = p
                # Filter padding out:
                tokens = tokens[np.where(tokens != self.tokenizer.pad_token_id)]
                # Originally we skip special tokens to give readable output.
                # For multi masks though, the other [MASK] would be removed otherwise
                # making the output look odd, so we add them back
                sequence = self.tokenizer.decode(tokens, skip_special_tokens=single_mask)
                proposition = {"score": v, "token": p, "token_str": self.tokenizer.decode([p]), "sequence": sequence}
                row.append(proposition)
            result.append(row)
        if single_mask:
            return result[0]
        return result


# Register the custom pipeline so pipeline("temp-scale", model=...) resolves
# to TempScalePipe with a masked-LM head.
PIPELINE_REGISTRY.register_pipeline(
    "temp-scale",
    pipeline_class=TempScalePipe,
    pt_model=AutoModelForMaskedLM,
)


def unmask(text):
    """Run the fill-mask model on `text` and return {token_str: score}."""
    predictions = unmasker(text)
    return {prediction["token_str"]: prediction["score"] for prediction in predictions}



textbox = gr.Textbox(label="Type language here", lines=5)

# BUGFIX: the original passed a slider list positionally (which binds to
# `inputs`) AND `inputs=textbox` by keyword, raising
# "TypeError: got multiple values for argument 'inputs'" at import time.
# The slider also did not match unmask()'s single-text signature, so it is
# dropped rather than wired in. (Dead commented-out duplicates of the code
# above were removed as well.)
demo = gr.Interface(
    fn=unmask,
    inputs=textbox,
    outputs="label",
    examples=[example],
)

demo.launch()