File size: 2,116 Bytes
311f28b
 
 
493b784
67be689
7ed3084
311f28b
493b784
311f28b
 
 
7ed3084
2acf864
 
7ed3084
 
2acf864
7ed3084
2acf864
493b784
 
 
 
2acf864
7ed3084
2acf864
 
 
493b784
 
2acf864
493b784
 
 
311f28b
 
 
 
7ed3084
311f28b
 
 
 
 
2acf864
311f28b
 
 
7ed3084
311f28b
 
2acf864
 
311f28b
493b784
7ed3084
493b784
 
 
311f28b
 
 
 
7ed3084
311f28b
493b784
311f28b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# 🔍 Masked Word Predictor | CPU-only HF Space

import gradio as gr
import pandas as pd
from transformers import pipeline
from transformers.pipelines.base import PipelineException

# 1. Build the fill-mask pipeline a single time at import so that every
#    prediction request reuses the same loaded model (device=-1 -> CPU).
fill_mask = pipeline(
    task="fill-mask",
    model="distilroberta-base",
    device=-1,
)

def predict_mask(sentence: str, top_k: int) -> pd.DataFrame:
    """Return the top-K fill-mask completions for *sentence* as a table.

    Args:
        sentence: Text containing the mask placeholder. Users may write the
            literal ``[MASK]``; it is rewritten to the model's own mask token.
            ``None`` is treated as an empty string.
        top_k: Number of completions to request per mask. Coerced to ``int``
            before being handed to the pipeline.

    Returns:
        A DataFrame with columns ``Sequence`` and ``Score`` (rounded to three
        decimals). When the input has no mask, or the pipeline raises a
        ``PipelineException``, a single row carrying an error message and a
        score of ``0.0`` is returned instead. When the sentence holds several
        masks, the candidates for every mask are listed one after another.
    """
    columns = ["Sequence", "Score"]

    # 2. Get the actual mask token (e.g. "<mask>")
    mask = fill_mask.tokenizer.mask_token

    # 3. Allow users to type [MASK]; a missing value is handled as empty text
    #    so it falls through to the validation message below.
    sentence = (sentence or "").replace("[MASK]", mask)

    # 4. Validate presence of mask
    if mask not in sentence:
        return pd.DataFrame(
            [["Error: please include `[MASK]` in your sentence.", 0.0]],
            columns=columns,
        )

    # 5. Run the pipeline safely. top_k is cast because the caller may pass
    #    a float such as 5.0, while the pipeline expects an integer count.
    try:
        preds = fill_mask(sentence, top_k=int(top_k))
    except PipelineException as e:
        return pd.DataFrame([[f"Error: {str(e)}", 0.0]], columns=columns)

    # 6. With more than one mask the pipeline returns one list of candidates
    #    per mask (a list of lists); flatten it so each candidate is one row.
    if preds and isinstance(preds[0], list):
        preds = [cand for per_mask in preds for cand in per_mask]

    # 7. Build a DataFrame from list-of-lists
    rows = [[p["sequence"], round(p["score"], 3)] for p in preds]
    return pd.DataFrame(rows, columns=columns)

# Assemble the Gradio UI: a sentence box and a top-K slider feed
# predict_mask, whose DataFrame result is shown in a read-only table.
with gr.Blocks(title="🔍 Masked Word Predictor") as demo:
    gr.Markdown(
        "# 🔍 Masked Word Predictor\n"
        "Enter a sentence with one `[MASK]` token and see the top-K completions."
    )

    # Inputs sit side by side: the sentence and how many completions to show.
    with gr.Row():
        sentence = gr.Textbox(
            lines=2,
            placeholder="e.g. The salon’s new color treatment is [MASK].",
            label="Input Sentence"
        )
        top_k = gr.Slider(
            minimum=1, maximum=10, step=1, value=5,
            label="Top K Predictions"
        )

    predict_btn = gr.Button("Predict 🔍", variant="primary")

    # Output table; headers match the columns predict_mask returns.
    results_df = gr.Dataframe(
        headers=["Sequence", "Score"],
        datatype=["str", "number"],
        wrap=True,
        interactive=False,
        label="Predictions"
    )

    # Clicking the button runs predict_mask(sentence, top_k) and renders
    # the returned DataFrame in the table.
    predict_btn.click(
        fn=predict_mask,
        inputs=[sentence, top_k],
        outputs=results_df
    )

if __name__ == "__main__":
    # Start the app only when executed as a script; server_name "0.0.0.0"
    # binds the server to all network interfaces rather than just localhost.
    demo.launch(
        server_name="0.0.0.0",
    )