import time
from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM


@lru_cache(maxsize=1)  # only cache the latest model
def get_model_and_tokenizer(model_id):
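    """Load and cache the model/tokenizer pair for `model_id`.

    Encoder-decoder architectures (e.g. T5) are loaded as seq2seq models;
    everything else is loaded as a causal LM. With `maxsize=1`, requesting a
    new model id evicts the previous one, so only one model stays in memory.
    """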
    config = AutoConfig.from_pretrained(model_id)
    if config.is_encoder_decoder:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_id)

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer


@lru_cache(maxsize=32768)  # cache up to 32k examples
def run_generation(
    text,
    model_id,
    max_new_tokens,
    alpha=0.0,
    top_k=0,
    num_beams=1,
    do_sample=False,
    top_p=0.0,
    seed=0
):
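    """Generate text from `text` with the decoding strategy implied by the arguments.

    Zero-valued `alpha`, `top_k`, and `top_p` are converted to `None` below so
    that `generate()` falls back to its own defaults for them. For illustration,
    the demo's default settings map to `generate()` as follows:
      contrastive search -> penalty_alpha=0.6, top_k=6
      beam search        -> num_beams=4
      top-k sampling     -> do_sample=True, top_k=50
      nucleus sampling   -> do_sample=True, top_p=0.95
    Note that results (including the reported time) are memoized by `lru_cache`,
    so a repeated query returns the originally measured time.
    """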
    model, tokenizer = get_model_and_tokenizer(model_id)

    inputs = tokenizer(text, return_tensors='pt')
    if seed:  # a seed of 0 means "do not reseed"
        torch.manual_seed(seed)

    start = time.time_ns()
    output_ids = model.generate(
        # from the tokenizer
        **inputs,
        # fixed arguments
        num_return_sequences=1,
        early_stopping=True,  # only meaningful for beam search
        # variable arguments
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=num_beams,
        penalty_alpha=alpha or None,
        top_k=top_k or None,
        top_p=top_p or None,
    )
    end = time.time_ns()

    generation_time = (end - start) / 1e6  # nanoseconds -> milliseconds
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return output_text, generation_time
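

# For reference, a minimal stand-alone contrastive search call with
# `transformers`, outside the Gradio UI -- a sketch using the demo's default
# model and hyperparameters (passing `penalty_alpha` together with `top_k`
# is what triggers contrastive search in `generate()`):
#
#   model, tokenizer = get_model_and_tokenizer("facebook/opt-125m")
#   inputs = tokenizer("DeepMind Company is", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=50, penalty_alpha=0.6, top_k=6)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))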


def generate_beam_search(text, model_id, max_new_tokens, alpha, k, num_beams):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    beam_search_text, beam_search_time = run_generation(text, model_id, max_new_tokens, num_beams=num_beams)
    return contrastive_text, contrastive_time, beam_search_text, beam_search_time


def generate_top_k(text, model_id, max_new_tokens, alpha, k, top_k, seed):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    top_k_text, top_k_time = run_generation(
        text, model_id, max_new_tokens, top_k=top_k, seed=seed, do_sample=True
    )
    return contrastive_text, contrastive_time, top_k_text, top_k_time


def generate_top_p(text, model_id, max_new_tokens, alpha, k, top_p, seed):
    contrastive_text, contrastive_time = run_generation(text, model_id, max_new_tokens, alpha=alpha, top_k=k)
    top_p_text, top_p_time = run_generation(
        text, model_id, max_new_tokens, top_p=top_p, seed=seed, do_sample=True
    )
    return contrastive_text, contrastive_time, top_p_text, top_p_time


demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Contrastive Search Generation Comparison

        Credits to the authors of the contrastive search generation [paper](https://arxiv.org/abs/2202.06417),
        including @[pangpang666](https://huggingface.co/pangpang666) and @[GMFTBY](https://huggingface.co/GMFTBY).
        Check out their follow-up [work](https://arxiv.org/abs/2210.14140), which demonstrates that the technique
        also works well with off-the-shelf LLMs, as well as their
        [HF guest blog post](https://huggingface.co/blog/introducing-csearch).

        From the paper:
        "At each decoding step, the key ideas of contrastive search are (i) the generated output should be selected
        from the set of most probable candidates predicted by the model; and (ii) the generated output should be
        discriminative enough with respect to the previous context. In this way, the generated text can (i) better
        maintain the semantic coherence with respect to the prefix while (ii) avoiding model degeneration."
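
        In 🤗 `transformers`, this corresponds to calling `model.generate(..., penalty_alpha=alpha, top_k=k)`,
        which is exactly what this demo does under the hood.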

        🚨 Warnings 🚨
        - Avoid large models (> 1GB) in this demo: they take a long time to load and to generate text.
        - Too slow, or stuck in a long queue? Try our
        [colab](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/115_introducing_contrastive_search.ipynb)
        instead.
        """
    )
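    # Three tabs, one per baseline decoding strategy. Each tab pairs the same
    # contrastive search inputs/outputs against one baseline, so the layouts
    # below are intentionally near-identical.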
    with gr.Tabs():
        with gr.TabItem("vs. Beam Search"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Inputs ✍️")
                    gr.Markdown("General options:")
                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
                    gr.Markdown("Contrastive Search options:")
                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
                    gr.Markdown("Beam Search options:")
                    num_beams = gr.Slider(value=4, minimum=1, maximum=16, step=1, label="Number of beams")
                    generate_button = gr.Button(value="Generate")

                with gr.Column():
                    gr.Markdown("## Outputs πŸ€–")
                    gr.Markdown("Contrastive Search generation:")
                    text_contrastive = gr.Textbox(value="", label="")
                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
                    gr.Markdown("Beam Search generation:")
                    text_beam_search = gr.Textbox(value="", label="")
                    time_beam_search = gr.Number(value=0.0, precision=1, label="Generation time (ms)")

            # actions
            generate_button.click(
                fn=generate_beam_search,
                inputs=[input_text, model_id, max_new_tokens, alpha, k, num_beams],
                outputs=[text_contrastive, time_contrastive, text_beam_search, time_beam_search]
            )

        with gr.TabItem("vs. Top K Sampling"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Inputs ✍️")
                    gr.Markdown("General options:")
                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
                    gr.Markdown("Contrastive Search options:")
                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
                    gr.Markdown("Sampling options:")
                    top_k = gr.Slider(value=50, minimum=1, maximum=100, step=1, label="Top K")
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    generate_button = gr.Button(value="Generate")

                with gr.Column():
                    gr.Markdown("## Outputs πŸ€–")
                    gr.Markdown("Contrastive Search generation:")
                    text_contrastive = gr.Textbox(value="", label="")
                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
                    gr.Markdown("Top K Sampling generation:")
                    text_top_k = gr.Textbox(value="", label="")
                    time_top_k = gr.Number(value=0.0, precision=1, label="Generation time (ms)")

            # actions
            generate_button.click(
                fn=generate_top_k,
                inputs=[input_text, model_id, max_new_tokens, alpha, k, top_k, seed],
                outputs=[text_contrastive, time_contrastive, text_top_k, time_top_k]
            )

        with gr.TabItem("vs. Nucleus Sampling"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("## Inputs ✍️")
                    gr.Markdown("General options:")
                    model_id = gr.Text(value="facebook/opt-125m", label="Model Repository")
                    input_text = gr.Textbox(value="DeepMind Company is", lines=5, label="Input Text")
                    max_new_tokens = gr.Slider(value=50, minimum=1, maximum=256, label="New tokens to generate")
                    gr.Markdown("Contrastive Search options:")
                    alpha = gr.Slider(value=0.6, minimum=0.01, maximum=1.0, step=0.01, label="Alpha")
                    k = gr.Slider(value=6, minimum=1, maximum=20, step=1, label="K")
                    gr.Markdown("Sampling options:")
                    top_p = gr.Slider(value=0.95, minimum=0.01, maximum=1.0, step=0.01, label="Top P")
                    seed = gr.Number(value=42, precision=0, label="Seed")
                    generate_button = gr.Button(value="Generate")

                with gr.Column():
                    gr.Markdown("## Outputs πŸ€–")
                    gr.Markdown("Contrastive Search generation:")
                    text_contrastive = gr.Textbox(value="", label="")
                    time_contrastive = gr.Number(value=0.0, precision=1, label="Generation time (ms)")
                    gr.Markdown("Nucleus Sampling generation:")
                    text_top_p = gr.Textbox(value="", label="")
                    time_top_p = gr.Number(value=0.0, precision=1, label="Generation time (ms)")

            # actions
            generate_button.click(
                fn=generate_top_p,
                inputs=[input_text, model_id, max_new_tokens, alpha, k, top_p, seed],
                outputs=[text_contrastive, time_contrastive, text_top_p, time_top_p]
            )

demo.launch()