ghostai1 commited on
Commit
2acf864
Β·
verified Β·
1 Parent(s): 23c6c15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -8
app.py CHANGED
@@ -1,15 +1,30 @@
1
  # πŸ” Masked Word Predictor | CPU-only HF Space
2
 
3
  import gradio as gr
4
- from transformers import pipeline
5
 
6
  # Load the fill-mask pipeline once at startup
7
  fill_mask = pipeline("fill-mask", model="distilroberta-base", device=-1)
8
 
9
  def predict_mask(sentence: str, top_k: int):
10
- if "[MASK]" not in sentence:
11
- return [{"sequence": "Error: include [MASK] in your sentence.", "score": 0.0}]
12
- preds = fill_mask(sentence, top_k=top_k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  return [
14
  {"sequence": p["sequence"], "score": round(p["score"], 3)}
15
  for p in preds
@@ -18,20 +33,21 @@ def predict_mask(sentence: str, top_k: int):
18
  with gr.Blocks(title="πŸ” Masked Word Predictor") as demo:
19
  gr.Markdown(
20
  "# πŸ” Masked Word Predictor\n"
21
- "Enter a sentence with one `[MASK]` token and see the model’s top predictions."
22
  )
23
 
24
  with gr.Row():
25
  sentence = gr.Textbox(
26
  lines=2,
27
- placeholder="The capital of France is [MASK].",
28
  label="Input Sentence"
29
  )
30
  top_k = gr.Slider(
31
- minimum=1, maximum=10, step=1, value=5,
32
  label="Top K Predictions"
33
  )
34
- predict_btn = gr.Button("Predict", variant="primary")
 
35
 
36
  results = gr.Dataframe(
37
  headers=["sequence", "score"],
 
1
  # πŸ” Masked Word Predictor | CPU-only HF Space
2
 
3
  import gradio as gr
4
+ from transformers import pipeline, PipelineException
5
 
6
  # Load the fill-mask pipeline once at startup
7
  fill_mask = pipeline("fill-mask", model="distilroberta-base", device=-1)
8
 
9
  def predict_mask(sentence: str, top_k: int):
10
+ # Get the model’s actual mask token (e.g. "<mask>")
11
+ mask = fill_mask.tokenizer.mask_token
12
+
13
+ # Allow users to type [MASK]; convert it under the hood
14
+ if "[MASK]" in sentence:
15
+ sentence = sentence.replace("[MASK]", mask)
16
+
17
+ # If no mask token present, show error
18
+ if mask not in sentence:
19
+ return [{"sequence": f"Error: please include `[MASK]` in your sentence.", "score": 0.0}]
20
+
21
+ # Call the pipeline and catch any unexpected exceptions
22
+ try:
23
+ preds = fill_mask(sentence, top_k=top_k)
24
+ except PipelineException as e:
25
+ return [{"sequence": f"Error: {str(e)}", "score": 0.0}]
26
+
27
+ # Format into list-of-dicts for Gradio Dataframe
28
  return [
29
  {"sequence": p["sequence"], "score": round(p["score"], 3)}
30
  for p in preds
 
33
  with gr.Blocks(title="πŸ” Masked Word Predictor") as demo:
34
  gr.Markdown(
35
  "# πŸ” Masked Word Predictor\n"
36
+ "Enter a sentence with one `[MASK]` token and see the top-K model predictions."
37
  )
38
 
39
  with gr.Row():
40
  sentence = gr.Textbox(
41
  lines=2,
42
+ placeholder="e.g. The salon’s new color treatment is [MASK].",
43
  label="Input Sentence"
44
  )
45
  top_k = gr.Slider(
46
+ minimum=1, maximum=10, value=5, step=1,
47
  label="Top K Predictions"
48
  )
49
+
50
+ predict_btn = gr.Button("Predict πŸ”", variant="primary")
51
 
52
  results = gr.Dataframe(
53
  headers=["sequence", "score"],