Spaces:
Running
Running
# Masked Word Predictor | CPU-only HF Space
import gradio as gr | |
from transformers import pipeline | |
from transformers.pipelines.base import PipelineException # correct import | |
# Build the fill-mask pipeline a single time at import so every request
# reuses the same loaded model.
fill_mask = pipeline(
    task="fill-mask",
    model="distilroberta-base",
    device=-1,  # -1 keeps inference on the CPU
)
def predict_mask(sentence: str, top_k: int):
    """Return the model's top-k completions for the masked slot in *sentence*.

    Users may write the placeholder as ``[MASK]``; it is swapped for the
    loaded tokenizer's own mask token before inference.

    Args:
        sentence: Text containing ``[MASK]`` or the tokenizer's mask token.
        top_k: Number of candidates to request from the pipeline. Coerced to
            ``int`` before the call.

    Returns:
        A list of ``{"sequence": str, "score": float}`` dicts with scores
        rounded to three decimals. A missing mask token or a
        ``PipelineException`` is reported as a single row whose ``sequence``
        starts with ``"Error:"`` and whose ``score`` is ``0.0``.
    """
    # The model's actual mask token (e.g. "<mask>").
    mask = fill_mask.tokenizer.mask_token

    # Treat a missing value as empty text so the check below reports the
    # friendly error instead of raising TypeError.
    text = (sentence or "").replace("[MASK]", mask)

    if mask not in text:
        return [{"sequence": "Error: please include `[MASK]` in your sentence.", "score": 0.0}]

    # Only pipeline-specific failures are converted into an error row.
    try:
        preds = fill_mask(text, top_k=int(top_k))
    except PipelineException as e:
        return [{"sequence": f"Error: {str(e)}", "score": 0.0}]

    # With several mask tokens the pipeline returns one candidate list per
    # mask; flatten so the row-building below always sees plain dicts.
    if preds and isinstance(preds[0], list):
        preds = [candidate for per_mask in preds for candidate in per_mask]

    # NOTE(review): rows are kept as dicts for backward compatibility; confirm
    # the Dataframe component renders list-of-dicts as intended (its documented
    # inputs are list-of-lists / DataFrame-like values).
    return [
        {"sequence": p["sequence"], "score": round(p["score"], 3)}
        for p in preds
    ]
# Assemble the UI: sentence box and top-K slider, a trigger button, and a
# read-only results table wired to predict_mask.
# NOTE(review): the original nesting was not recoverable from this view; the
# Row is assumed to hold only the two inputs -- confirm the intended layout.
# The mis-encoded glyphs in the user-facing strings were removed and the
# apostrophe in the placeholder restored.
with gr.Blocks(title="Masked Word Predictor") as demo:
    gr.Markdown(
        "# Masked Word Predictor\n"
        "Enter a sentence with one `[MASK]` token and see the top-K model predictions."
    )

    with gr.Row():
        sentence = gr.Textbox(
            lines=2,
            placeholder="e.g. The salon's new color treatment is [MASK].",
            label="Input Sentence",
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=10,
            value=5,
            step=1,
            label="Top K Predictions",
        )

    predict_btn = gr.Button("Predict", variant="primary")

    results = gr.Dataframe(
        headers=["sequence", "score"],
        datatype=["str", "number"],
        wrap=True,
        interactive=False,
        label="Predictions",
    )

    # Clicking the button sends both inputs to predict_mask and shows its
    # return value in the results table.
    predict_btn.click(
        predict_mask,
        inputs=[sentence, top_k],
        outputs=results,
    )
if __name__ == "__main__":
    # Start the app only when run as a script; 0.0.0.0 binds every network
    # interface rather than just localhost.
    demo.launch(server_name="0.0.0.0")