jedick commited on
Commit
7879fc7
·
1 Parent(s): 0eb9d15

Remove barplot

Browse files
Files changed (1) hide show
  1. app.py +64 -142
app.py CHANGED
@@ -50,28 +50,6 @@ if gr.NO_RELOAD:
50
  )
51
 
52
 
53
- def prediction_to_df(prediction=None):
54
- """
55
- Convert prediction text to DataFrame for barplot
56
- """
57
- if prediction is None or prediction == "":
58
- # Show an empty plot for app initialization or auto-reload
59
- prediction = {"SUPPORT": 0, "NEI": 0, "REFUTE": 0}
60
- elif "Model" in prediction:
61
- # Show full-height bars when the model is changed
62
- prediction = {"SUPPORT": 1, "NEI": 1, "REFUTE": 1}
63
- else:
64
- # Convert predictions text to dictionary
65
- prediction = eval(prediction)
66
- # Use custom order for labels (pipe() returns labels in descending order of softmax score)
67
- labels = ["SUPPORT", "NEI", "REFUTE"]
68
- prediction = {k: prediction[k] for k in labels}
69
- # Convert dictionary to DataFrame with one column (Probability)
70
- df = pd.DataFrame.from_dict(prediction, orient="index", columns=["Probability"])
71
- # Move the index to the Class column
72
- return df.reset_index(names="Class")
73
-
74
-
75
  # Setup theme without background image
76
  my_theme = gr.Theme.from_hub("NoCrypt/miku")
77
  my_theme.set(body_background_fill="#FFFFFF", body_background_fill_dark="#000000")
@@ -127,22 +105,44 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
127
  completion_tokens = gr.Number(
128
  label="Completion tokens", visible=False
129
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  with gr.Column(scale=2):
132
- # Keep the prediction textbox hidden
133
- with gr.Accordion(visible=False):
134
- prediction = gr.Textbox(label="Prediction")
135
- barplot = gr.BarPlot(
136
- prediction_to_df,
137
- x="Class",
138
- y="Probability",
139
- color="Class",
140
- color_map={"SUPPORT": "green", "NEI": "#888888", "REFUTE": "#FF8888"},
141
- inputs=prediction,
142
- y_lim=([0, 1]),
143
- visible=False,
144
- )
145
- label = gr.Label(label="Prediction")
146
  with gr.Accordion("Feedback"):
147
  gr.Markdown(
148
  "*Provide the correct label to help improve this app*<br>**NOTE:** The claim and evidence will also be saved"
@@ -189,71 +189,19 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
189
  "label"
190
  ].tolist(),
191
  )
192
-
193
- with gr.Row():
194
- with gr.Column(scale=3):
195
- with gr.Row():
196
- with gr.Column(scale=1):
197
- gr.Markdown(
198
- """
199
- ### Usage:
200
-
201
- - Input a **Claim**, then:
202
- - Upload a PDF and click **Get Evidence** OR
203
- - Input **Evidence** statements yourself
204
- """
205
- )
206
- with gr.Column(scale=2):
207
- gr.Markdown(
208
- """
209
- ### To make the prediction:
210
-
211
- - Hit 'Enter' in the **Claim** text box OR
212
- - Hit 'Shift-Enter' in the **Evidence** text box
213
-
214
- _The prediction is also made after clicking **Get Evidence**_
215
- """
216
- )
217
-
218
- with gr.Column(scale=2):
219
- with gr.Accordion("Settings", open=False):
220
- # Create dropdown menu to select the model
221
- model = gr.Dropdown(
222
- choices=[
223
- # TODO: For bert-base-uncased, how can we set num_labels = 2 in HF pipeline?
224
- # (num_labels is available in AutoModelForSequenceClassification.from_pretrained)
225
- # "bert-base-uncased",
226
- "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
227
- "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint",
228
- ],
229
- value=MODEL_NAME,
230
- label="Model",
231
- )
232
- radio = gr.Radio(
233
- ["label", "barplot"], value="label", label="Prediction"
234
- )
235
- with gr.Accordion("Sources", open=False):
236
- gr.Markdown(
237
- """
238
- #### *Capstone project*
239
- - <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
240
- - <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
241
- #### *Text Classification*
242
- - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
243
- - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
244
- #### *Evidence Retrieval*
245
- - <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S)
246
- - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa)
247
- - <img src="https://upload.wikimedia.org/wikipedia/commons/4/4d/OpenAI_Logo.svg" style="height: 1.2em; display: inline-block;"> [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT)
248
- #### *Datasets for fine-tuning*
249
- - <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
250
- - <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
251
- #### *Other sources*
252
- - <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
253
- - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
254
- - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
255
- """
256
- )
257
 
258
  # Functions
259
 
@@ -284,8 +232,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
284
  ("REFUTE" if k in ["REFUTE", "contradiction"] else k): v
285
  for k, v in prediction.items()
286
  }
287
- # Return two instances of the prediction to send to different Gradio components
288
- return prediction, prediction
289
 
290
  def select_model(model_name):
291
  """
@@ -298,15 +245,6 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
298
  model=MODEL_NAME,
299
  )
300
 
301
- def change_visualization(choice):
302
- if choice == "barplot":
303
- barplot = gr.update(visible=True)
304
- label = gr.update(visible=False)
305
- elif choice == "label":
306
- barplot = gr.update(visible=False)
307
- label = gr.update(visible=True)
308
- return barplot, label
309
-
310
  # From gradio/client/python/gradio_client/utils.py
311
  def is_http_url_like(possible_url) -> bool:
312
  """
@@ -354,13 +292,13 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
354
  return f"Unknown retrieval method: {method}"
355
 
356
  def append_feedback(
357
- claim: str, evidence: str, model: str, label: str, user_label: str
358
  ) -> None:
359
  """
360
  Append input/outputs and user feedback to a JSON Lines file.
361
  """
362
  # Get the first label (prediction with highest probability)
363
- prediction = next(iter(label))
364
  with USER_FEEDBACK_PATH.open("a") as f:
365
  f.write(
366
  json.dumps(
@@ -368,7 +306,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
368
  "claim": claim,
369
  "evidence": evidence,
370
  "model": model,
371
- "prediction": prediction,
372
  "user_label": user_label,
373
  "datetime": datetime.now().isoformat(),
374
  }
@@ -435,7 +373,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
435
  triggers=[claim.submit, evidence.submit],
436
  fn=query_model,
437
  inputs=[claim, evidence],
438
- outputs=[prediction, label],
439
  )
440
 
441
  # Get evidence from PDF and run the model
@@ -447,7 +385,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
447
  ).then(
448
  fn=query_model,
449
  inputs=[claim, evidence],
450
- outputs=[prediction, label],
451
  api_name=False,
452
  )
453
 
@@ -461,7 +399,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
461
  ).then(
462
  fn=query_model,
463
  inputs=[claim, evidence],
464
- outputs=[prediction, label],
465
  api_name=False,
466
  )
467
 
@@ -475,7 +413,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
475
  ).then(
476
  fn=query_model,
477
  inputs=[claim, evidence],
478
- outputs=[prediction, label],
479
  api_name=False,
480
  )
481
 
@@ -489,7 +427,7 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
489
  ).then(
490
  fn=query_model,
491
  inputs=[claim, evidence],
492
- outputs=[prediction, label],
493
  api_name=False,
494
  )
495
 
@@ -508,53 +446,37 @@ with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
508
  ).then(
509
  fn=query_model,
510
  inputs=[claim, evidence],
511
- outputs=[prediction, label],
512
- api_name=False,
513
- )
514
-
515
- # Change visualization
516
- radio.change(
517
- fn=change_visualization,
518
- inputs=radio,
519
- outputs=[barplot, label],
520
- api_name=False,
521
- )
522
-
523
- # Clear the previous predictions when the model is changed
524
- gr.on(
525
- triggers=[model.select],
526
- fn=lambda: "Model changed! Waiting for updated predictions...",
527
- outputs=[prediction],
528
  api_name=False,
529
  )
530
 
531
- # Change the model the update the predictions
532
  model.change(
533
  fn=select_model,
534
  inputs=model,
535
  ).then(
536
  fn=query_model,
537
  inputs=[claim, evidence],
538
- outputs=[prediction, label],
539
  api_name=False,
540
  )
541
 
542
  # Log user feedback when button is clicked
543
  flag_support.click(
544
  fn=save_feedback_support,
545
- inputs=[claim, evidence, model, label],
546
  outputs=None,
547
  api_name=False,
548
  )
549
  flag_nei.click(
550
  fn=save_feedback_nei,
551
- inputs=[claim, evidence, model, label],
552
  outputs=None,
553
  api_name=False,
554
  )
555
  flag_refute.click(
556
  fn=save_feedback_refute,
557
- inputs=[claim, evidence, model, label],
558
  outputs=None,
559
  api_name=False,
560
  )
 
50
  )
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Setup theme without background image
54
  my_theme = gr.Theme.from_hub("NoCrypt/miku")
55
  my_theme.set(body_background_fill="#FFFFFF", body_background_fill_dark="#000000")
 
105
  completion_tokens = gr.Number(
106
  label="Completion tokens", visible=False
107
  )
108
+ gr.Markdown(
109
+ """
110
+ ### App Usage:
111
+
112
+ - Input a **Claim**, then:
113
+ - Upload a PDF and click **Get Evidence** OR
114
+ - Input **Evidence** statements yourself
115
+ - Make the **Prediction**:
116
+ - Hit 'Enter' in the **Claim** text box OR
117
+ - Hit 'Shift-Enter' in the **Evidence** text box OR
118
+ - Click **Get Evidence**
119
+ """
120
+ )
121
+ with gr.Accordion("Sources", open=False):
122
+ gr.Markdown(
123
+ """
124
+ #### *Capstone project*
125
+ - <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
126
+ - <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
127
+ #### *Text Classification*
128
+ - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
129
+ - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
130
+ #### *Evidence Retrieval*
131
+ - <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S)
132
+ - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa)
133
+ - <img src="https://upload.wikimedia.org/wikipedia/commons/4/4d/OpenAI_Logo.svg" style="height: 1.2em; display: inline-block;"> [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT)
134
+ #### *Datasets for fine-tuning*
135
+ - <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
136
+ - <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
137
+ #### *Other sources*
138
+ - <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
139
+ - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
140
+ - <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
141
+ """
142
+ )
143
 
144
  with gr.Column(scale=2):
145
+ prediction = gr.Label(label="Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  with gr.Accordion("Feedback"):
147
  gr.Markdown(
148
  "*Provide the correct label to help improve this app*<br>**NOTE:** The claim and evidence will also be saved"
 
189
  "label"
190
  ].tolist(),
191
  )
192
+ # Create dropdown menu to select the model
193
+ model = gr.Dropdown(
194
+ choices=[
195
+ # TODO: For bert-base-uncased, how can we set num_labels = 2 in HF pipeline?
196
+ # (num_labels is available in AutoModelForSequenceClassification.from_pretrained)
197
+ # "bert-base-uncased",
198
+ "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
199
+ "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint",
200
+ ],
201
+ value=MODEL_NAME,
202
+ label="Model",
203
+ info="Text classification model used for claim verification",
204
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  # Functions
207
 
 
232
  ("REFUTE" if k in ["REFUTE", "contradiction"] else k): v
233
  for k, v in prediction.items()
234
  }
235
+ return prediction
 
236
 
237
  def select_model(model_name):
238
  """
 
245
  model=MODEL_NAME,
246
  )
247
 
 
 
 
 
 
 
 
 
 
248
  # From gradio/client/python/gradio_client/utils.py
249
  def is_http_url_like(possible_url) -> bool:
250
  """
 
292
  return f"Unknown retrieval method: {method}"
293
 
294
  def append_feedback(
295
+ claim: str, evidence: str, model: str, prediction: str, user_label: str
296
  ) -> None:
297
  """
298
  Append input/outputs and user feedback to a JSON Lines file.
299
  """
300
  # Get the first label (prediction with highest probability)
301
+ _prediction = next(iter(prediction))
302
  with USER_FEEDBACK_PATH.open("a") as f:
303
  f.write(
304
  json.dumps(
 
306
  "claim": claim,
307
  "evidence": evidence,
308
  "model": model,
309
+ "prediction": _prediction,
310
  "user_label": user_label,
311
  "datetime": datetime.now().isoformat(),
312
  }
 
373
  triggers=[claim.submit, evidence.submit],
374
  fn=query_model,
375
  inputs=[claim, evidence],
376
+ outputs=prediction,
377
  )
378
 
379
  # Get evidence from PDF and run the model
 
385
  ).then(
386
  fn=query_model,
387
  inputs=[claim, evidence],
388
+ outputs=prediction,
389
  api_name=False,
390
  )
391
 
 
399
  ).then(
400
  fn=query_model,
401
  inputs=[claim, evidence],
402
+ outputs=prediction,
403
  api_name=False,
404
  )
405
 
 
413
  ).then(
414
  fn=query_model,
415
  inputs=[claim, evidence],
416
+ outputs=prediction,
417
  api_name=False,
418
  )
419
 
 
427
  ).then(
428
  fn=query_model,
429
  inputs=[claim, evidence],
430
+ outputs=prediction,
431
  api_name=False,
432
  )
433
 
 
446
  ).then(
447
  fn=query_model,
448
  inputs=[claim, evidence],
449
+ outputs=prediction,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  api_name=False,
451
  )
452
 
453
+ # Change the model then update the predictions
454
  model.change(
455
  fn=select_model,
456
  inputs=model,
457
  ).then(
458
  fn=query_model,
459
  inputs=[claim, evidence],
460
+ outputs=prediction,
461
  api_name=False,
462
  )
463
 
464
  # Log user feedback when button is clicked
465
  flag_support.click(
466
  fn=save_feedback_support,
467
+ inputs=[claim, evidence, model, prediction],
468
  outputs=None,
469
  api_name=False,
470
  )
471
  flag_nei.click(
472
  fn=save_feedback_nei,
473
+ inputs=[claim, evidence, model, prediction],
474
  outputs=None,
475
  api_name=False,
476
  )
477
  flag_refute.click(
478
  fn=save_feedback_refute,
479
+ inputs=[claim, evidence, model, prediction],
480
  outputs=None,
481
  api_name=False,
482
  )