Spaces:
Running
on
Zero
Running
on
Zero
import pandas as pd | |
import gradio as gr | |
from transformers import pipeline | |
import nltk | |
from retrieval_bm25s import retrieve_with_bm25s | |
from retrieval_bert import retrieve_with_deberta | |
from retrieval_gpt import retrieve_with_gpt | |
import os | |
import json | |
from datetime import datetime | |
from pathlib import Path | |
from uuid import uuid4 | |
import spaces | |
def is_running_in_hf_spaces():
    """
    Report whether the app is running inside a Hugging Face Space.

    Returns:
        bool: True when the ``SPACE_ID`` environment variable is set
        (even to an empty string), False otherwise.
    """
    # Environment values are always strings, so a missing key is the only
    # way for .get() to yield None.
    return os.environ.get("SPACE_ID") is not None
# One-time startup work: fetch the NLTK tokenizer data and load the default
# text-classification pipeline.
# NOTE(review): the indentation of this listing was lost; the model setup is
# assumed to sit inside the NO_RELOAD guard together with the NLTK download
# so that neither is repeated on a Gradio hot reload — confirm.
if gr.NO_RELOAD:
    # Resource punkt_tab not found during application startup on HF spaces
    nltk.download("punkt_tab")

    # Keep track of the model name in a global variable so correct model is shown after page refresh
    # https://github.com/gradio-app/gradio/issues/3173
    MODEL_NAME = "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint"
    # Module-level pipeline; select_model() rebinds both MODEL_NAME and pipe.
    pipe = pipeline(
        "text-classification",
        model=MODEL_NAME,
    )
# Setup user feedback file for uploading to HF dataset
# https://huggingface.co/spaces/Wauplin/space_to_dataset_saver
# https://huggingface.co/docs/huggingface_hub/v0.16.3/en/guides/upload#scheduled-uploads
USER_FEEDBACK_DIR = Path("user_feedback")
USER_FEEDBACK_DIR.mkdir(parents=True, exist_ok=True)
# A fresh UUID per process gives every app instance its own feedback file.
# append_feedback() writes one JSON object per line to it.
USER_FEEDBACK_PATH = USER_FEEDBACK_DIR / f"train-{uuid4()}.json"

# The scheduler (and its import) only exist when running on HF Spaces; the
# save_feedback_* handlers check the same condition before touching
# `scheduler.lock`.
if is_running_in_hf_spaces():
    from huggingface_hub import CommitScheduler

    # No `every` argument is passed, so the library's default upload interval
    # applies (the UI text in the Feedback accordion states 5 minutes).
    scheduler = CommitScheduler(
        repo_id="AI4citations-feedback",
        repo_type="dataset",
        folder_path=USER_FEEDBACK_DIR,
        path_in_repo="data",
    )
# Setup theme without background image
my_theme = gr.Theme.from_hub("NoCrypt/miku")
# Override the body background fill with plain white (light mode) / black (dark mode)
my_theme.set(body_background_fill="#FFFFFF", body_background_fill_dark="#000000")

# Define the HTML for Font Awesome
# Passed as `head=` to gr.Blocks so the <i class="fa-..."> icons used in the
# "Sources" markdown can render.
font_awesome_html = '<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css" rel="stylesheet">'
# Gradio interface setup
# NOTE(review): this listing had lost its indentation. The container nesting
# below is a reconstruction (consistent with how the long calls were wrapped);
# confirm it against the rendered layout, in particular which widgets share
# the gr.Row() that holds the retrieval-method radio.
# NOTE(review): text inside the triple-quoted Markdown strings is kept
# flush-left exactly as listed; any sub-bullet indentation that may have
# existed could not be recovered.
with gr.Blocks(theme=my_theme, head=font_awesome_html) as demo:
    # Layout
    with gr.Row():
        # Left column: title, claim/evidence inputs, usage notes and sources
        with gr.Column(scale=3):
            with gr.Row():
                gr.Markdown("# AI4citations")
                gr.Markdown(
                    "## *AI-powered citation verification* ([more info](https://github.com/jedick/AI4citations))"
                )
            claim = gr.Textbox(
                label="Claim",
                info="aka hypothesis",
                placeholder="Input claim",
            )
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Accordion("Get Evidence from PDF"):
                        pdf_file = gr.File(
                            label="Upload PDF", type="filepath", height=120
                        )
                        with gr.Row():
                            retrieval_method = gr.Radio(
                                choices=["BM25S", "DeBERTa", "GPT"],
                                value="BM25S",
                                label="Retrieval Method",
                                info="Keyword search (BM25S) or AI (DeBERTa, GPT)",
                            )
                            get_evidence = gr.Button(value="Get Evidence")
                        top_k = gr.Slider(
                            1,
                            10,
                            value=5,
                            step=1,
                            label="Top k sentences",
                        )
                with gr.Column(scale=3):
                    evidence = gr.TextArea(
                        label="Evidence",
                        info="aka premise",
                        placeholder="Input evidence or use Get Evidence from PDF",
                    )
                    # Token counters start hidden; number_visible() shows them
                    # when the GPT retrieval method is selected.
                    with gr.Row():
                        prompt_tokens = gr.Number(label="Prompt tokens", visible=False)
                        completion_tokens = gr.Number(
                            label="Completion tokens", visible=False
                        )
            gr.Markdown(
                """
### App Usage:
- Input a **Claim**, then:
- Upload a PDF and click **Get Evidence** OR
- Input **Evidence** statements yourself
- Make the **Prediction**:
- Hit 'Enter' in the **Claim** text box OR
- Hit 'Shift-Enter' in the **Evidence** text box OR
- Click **Get Evidence**
"""
            )
            with gr.Accordion("Sources", open=False):
                gr.Markdown(
                    """
#### *Capstone project*
- <i class="fa-brands fa-github"></i> [jedick/MLE-capstone-project](https://github.com/jedick/MLE-capstone-project) (project repo)
- <i class="fa-brands fa-github"></i> [jedick/AI4citations](https://github.com/jedick/AI4citations) (app repo)
#### *Text Classification*
- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint](https://huggingface.co/jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint) (fine-tuned)
- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli](https://huggingface.co/MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli) (base)
#### *Evidence Retrieval*
- <i class="fa-brands fa-github"></i> [xhluca/bm25s](https://github.com/xhluca/bm25s) (BM25S)
- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [deepset/deberta-v3-large-squad2](https://huggingface.co/deepset/deberta-v3-large-squad2) (DeBERTa)
- <img src="https://upload.wikimedia.org/wikipedia/commons/4/4d/OpenAI_Logo.svg" style="height: 1.2em; display: inline-block;"> [gpt-4o-mini-2024-07-18](https://platform.openai.com/docs/pricing) (GPT)
#### *Datasets for fine-tuning*
- <i class="fa-brands fa-github"></i> [allenai/SciFact](https://github.com/allenai/scifact) (SciFact)
- <i class="fa-brands fa-github"></i> [ScienceNLP-Lab/Citation-Integrity](https://github.com/ScienceNLP-Lab/Citation-Integrity) (CitInt)
#### *Other sources*
- <img src="https://plos.org/wp-content/uploads/2020/01/logo-color-blue.svg" style="height: 1.4em; display: inline-block;"> [Medicine](https://doi.org/10.1371/journal.pmed.0030197), <i class="fa-brands fa-wikipedia-w"></i> [CRISPR](https://en.wikipedia.org/wiki/CRISPR) (evidence retrieval examples)
- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [nyu-mll/multi_nli](https://huggingface.co/datasets/nyu-mll/multi_nli/viewer/default/train?row=37&views%5B%5D=train) (MNLI example)
- <img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg" style="height: 1.2em; display: inline-block;"> [NoCrypt/miku](https://huggingface.co/spaces/NoCrypt/miku) (theme)
"""
                )
        # Right column: prediction, feedback buttons, examples and model picker
        with gr.Column(scale=2):
            prediction = gr.Label(label="Prediction")
            with gr.Accordion("Feedback"):
                # The Markdown calls in this column previously ended with a
                # stray trailing comma, which wrapped each component in a
                # discarded one-element tuple; the commas are removed.
                gr.Markdown(
                    "*Provide the correct label to help improve this app*<br>**NOTE:** The claim and evidence will also be saved"
                )
                with gr.Row():
                    flag_support = gr.Button("Support")
                    flag_nei = gr.Button("NEI")
                    flag_refute = gr.Button("Refute")
                gr.Markdown(
                    "Feedback is uploaded every 5 minutes to [AI4citations-feedback](https://huggingface.co/datasets/jedick/AI4citations-feedback)"
                )
            with gr.Accordion("Examples"):
                gr.Markdown("*Examples are run when clicked*")
                # Each example set is loaded from a directory; the labels shown
                # to the user come from the "label" column of that directory's
                # log.csv.
                with gr.Row():
                    support_example = gr.Examples(
                        examples="examples/Support",
                        label="Support",
                        inputs=[claim, evidence],
                        example_labels=pd.read_csv("examples/Support/log.csv")[
                            "label"
                        ].tolist(),
                    )
                    nei_example = gr.Examples(
                        examples="examples/NEI",
                        label="NEI",
                        inputs=[claim, evidence],
                        example_labels=pd.read_csv("examples/NEI/log.csv")[
                            "label"
                        ].tolist(),
                    )
                    refute_example = gr.Examples(
                        examples="examples/Refute",
                        label="Refute",
                        inputs=[claim, evidence],
                        example_labels=pd.read_csv("examples/Refute/log.csv")[
                            "label"
                        ].tolist(),
                    )
                retrieval_example = gr.Examples(
                    examples="examples/retrieval",
                    label="Get Evidence from PDF",
                    inputs=[pdf_file, claim],
                    example_labels=pd.read_csv("examples/retrieval/log.csv")[
                        "label"
                    ].tolist(),
                )
            # Create dropdown menu to select the model
            model = gr.Dropdown(
                choices=[
                    # TODO: For bert-base-uncased, how can we set num_labels = 2 in HF pipeline?
                    # (num_labels is available in AutoModelForSequenceClassification.from_pretrained)
                    # "bert-base-uncased",
                    "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
                    "jedick/DeBERTa-v3-base-mnli-fever-anli-scifact-citint",
                ],
                # Initialised from the global so a page refresh shows the
                # model that is actually loaded.
                value=MODEL_NAME,
                label="Model",
                info="Text classification model used for claim verification",
            )
# Functions | |
def query_model(claim, evidence):
    """
    Classify a claim against its evidence with the currently loaded pipeline.

    Args:
        claim: The claim (hypothesis) text.
        evidence: The evidence (premise) text.

    Returns:
        dict: ``{label: score}`` in the format expected by gr.Label(), with
        model-specific label names mapped onto SUPPORT / NEI / REFUTE.
    """
    # Different models name the same classes differently; map them onto one
    # consistent vocabulary. Labels that are not listed pass through unchanged.
    canonical_labels = {
        "entailment": "SUPPORT",
        "neutral": "NEI",
        "contradiction": "REFUTE",
    }
    # Send a dictionary containing {"text", "text_pair"} keys; use top_k=3 to get results for all classes
    # https://huggingface.co/docs/transformers/v4.51.3/en/main_classes/pipelines#transformers.TextClassificationPipeline.__call__.inputs
    # Put evidence before claim
    # https://github.com/jedick/MLE-capstone-project
    class_scores = pipe({"text": evidence, "text_pair": claim}, top_k=3)
    # Output {label: confidence} dictionary format as expected by gr.Label()
    # https://github.com/gradio-app/gradio/issues/11170
    prediction = {}
    for entry in class_scores:
        raw_label = entry["label"]
        prediction[canonical_labels.get(raw_label, raw_label)] = entry["score"]
    return prediction
def select_model(model_name):
    """
    Switch the app to a different text-classification model.

    Rebinds the module-level ``MODEL_NAME`` and ``pipe`` globals so that later
    calls to query_model() use the newly selected model.

    Args:
        model_name: Model identifier passed to ``pipeline(model=...)``.
    """
    global pipe, MODEL_NAME
    # The name is recorded before the pipeline is built, as in the original
    # statement order.
    MODEL_NAME = model_name
    pipe = pipeline("text-classification", model=MODEL_NAME)
# From gradio/client/python/gradio_client/utils.py | |
def is_http_url_like(possible_url) -> bool:
    """
    Tell whether a value is a string beginning with ``http://`` or ``https://``.

    Non-string values are never URL-like. The prefix match is case-sensitive.
    """
    return isinstance(possible_url, str) and possible_url.startswith(
        ("http://", "https://")
    )
def select_example(value, evt: gr.EventData):
    """
    Extract the claim and evidence from a selected example.

    Args:
        value: Selection payload; its second element (``value[1]``) must
            unpack into exactly two items, the claim and the evidence.
        evt: Event data supplied by Gradio (unused).

    Returns:
        tuple: ``(claim, evidence)``.
    """
    selected_row = value[1]
    example_claim, example_evidence = selected_row
    return example_claim, example_evidence
def select_retrieval_example(value, evt: gr.EventData):
    """
    Extract the PDF file and claim from a selected retrieval example.

    Args:
        value: Selection payload; ``value[1]`` must unpack into the PDF file
            reference and the claim.
        evt: Event data supplied by Gradio (unused).

    Returns:
        tuple: ``(pdf_file, claim)``. HTTP(S) URLs are returned untouched;
        any other file reference is prefixed with ``examples/retrieval/``.
    """
    pdf_file, claim = value[1]
    if is_http_url_like(pdf_file):
        return pdf_file, claim
    # Anything that is not a URL is taken to be a file in the examples directory
    return f"examples/retrieval/{pdf_file}", claim
def _retrieve_with_deberta(pdf_file, claim, top_k):
    """
    Delegate evidence retrieval to ``retrieve_with_deberta``.

    Arguments are forwarded positionally and the result is returned as-is.
    """
    # NOTE(review): this wrapper adds no logic of its own; it presumably exists
    # as a place to attach a decorator (the file imports `spaces`) — confirm.
    retrieved = retrieve_with_deberta(pdf_file, claim, top_k)
    return retrieved
def retrieve_evidence(pdf_file, claim, top_k, method):
    """
    Retrieve evidence for a claim from a PDF using the selected method.

    Args:
        pdf_file: PDF file reference passed through to the retrieval backend.
        claim: The claim to find evidence for.
        top_k: Number of sentences to retrieve (used by BM25S and DeBERTa;
            not passed to the GPT backend).
        method: One of "BM25S", "DeBERTa" or "GPT".

    Returns:
        A 3-item result matching the wired outputs
        ``(evidence, prompt_tokens, completion_tokens)``. For BM25S and
        DeBERTa the token counts are 0. For GPT the result of
        ``retrieve_with_gpt`` is returned unchanged (presumably already such
        a triple — confirm against that helper). For an unknown method the
        evidence slot carries an explanatory message and the counts are 0.
    """
    if method == "BM25S":
        # Append 0 for number of prompt and completion tokens
        return retrieve_with_bm25s(pdf_file, claim, top_k), 0, 0
    if method == "DeBERTa":
        return _retrieve_with_deberta(pdf_file, claim, top_k), 0, 0
    if method == "GPT":
        return retrieve_with_gpt(pdf_file, claim)
    # Keep the same arity as the other branches: the event listeners bind this
    # function to three output components, so a bare string would not fit.
    return f"Unknown retrieval method: {method}", 0, 0
def append_feedback(
    claim: str,
    evidence: str,
    model: str,
    prediction: "dict[str, float] | str | None",
    user_label: str,
) -> None:
    """
    Append input/outputs and user feedback to a JSON Lines file.

    One JSON object is written per call to ``USER_FEEDBACK_PATH``, followed by
    a success toast for the user.

    Args:
        claim: Claim text shown in the UI.
        evidence: Evidence text shown in the UI.
        model: Name of the model selected in the dropdown.
        prediction: Value of the prediction component. A ``{label: score}``
            mapping is reduced to its first label; a bare label is stored
            as-is; ``None`` (no prediction made yet) or an empty mapping is
            stored as ``null``.
        user_label: The label chosen by the user (e.g. "SUPPORT").
    """
    if isinstance(prediction, dict):
        # Get the first label (prediction with highest probability)
        # NOTE(review): relies on the mapping being ordered by descending
        # score — confirm against the Label component's input format.
        _prediction = next(iter(prediction), None)
    else:
        # Iterating None would raise, and iterating a bare label string would
        # keep only its first character, so pass these through untouched.
        _prediction = prediction

    record = {
        "claim": claim,
        "evidence": evidence,
        "model": model,
        "prediction": _prediction,
        "user_label": user_label,
        "datetime": datetime.now().isoformat(),
    }
    # Emit the record and its newline in one write so a line is never split.
    with USER_FEEDBACK_PATH.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    gr.Success(f"Saved your feedback: {user_label}", duration=2, title="Thank you!")
def save_feedback_support(*args) -> None:
    """
    Record user feedback with the label "SUPPORT".

    All positional arguments are forwarded to append_feedback().
    """
    if not is_running_in_hf_spaces():
        # No scheduler exists outside HF Spaces, so write directly
        append_feedback(*args, user_label="SUPPORT")
        return
    # Hold the scheduler's lock while writing to avoid concurrent writes from
    # different users.
    with scheduler.lock:
        append_feedback(*args, user_label="SUPPORT")
def save_feedback_nei(*args) -> None:
    """
    Record user feedback with the label "NEI".

    All positional arguments are forwarded to append_feedback().
    """
    if not is_running_in_hf_spaces():
        # No scheduler exists outside HF Spaces, so write directly
        append_feedback(*args, user_label="NEI")
        return
    # Hold the scheduler's lock while writing to avoid concurrent writes from
    # different users.
    with scheduler.lock:
        append_feedback(*args, user_label="NEI")
def save_feedback_refute(*args) -> None:
    """
    Record user feedback with the label "REFUTE".

    All positional arguments are forwarded to append_feedback().
    """
    if not is_running_in_hf_spaces():
        # No scheduler exists outside HF Spaces, so write directly
        append_feedback(*args, user_label="REFUTE")
        return
    # Hold the scheduler's lock while writing to avoid concurrent writes from
    # different users.
    with scheduler.lock:
        append_feedback(*args, user_label="REFUTE")
def number_visible(value):
    """
    Show numbers (token counts) only if GPT is selected for retrieval.

    Args:
        value: The selected retrieval method.

    Returns:
        A gr.Number update that is visible exactly when ``value`` is "GPT".
    """
    show_counts = value == "GPT"
    return gr.Number(visible=show_counts)
def slider_visible(value):
    """
    Hide slider (top_k) if GPT is selected for retrieval.

    Args:
        value: The selected retrieval method.

    Returns:
        A gr.Slider update that is hidden exactly when ``value`` is "GPT".
    """
    show_slider = value != "GPT"
    return gr.Slider(visible=show_slider)
# Event listeners
# NOTE(review): these registrations must execute inside the
# `with gr.Blocks(...) as demo:` context; indent to match it when splicing
# into the file (the listing had lost its indentation).

# Press Enter (claim) or Shift-Enter (evidence) to run the prediction
gr.on(
    triggers=[claim.submit, evidence.submit],
    fn=query_model,
    inputs=[claim, evidence],
    outputs=prediction,
)

# "Get Evidence" button: retrieve evidence from the PDF, then run the model
gr.on(
    triggers=[get_evidence.click],
    fn=retrieve_evidence,
    inputs=[pdf_file, claim, top_k, retrieval_method],
    outputs=[evidence, prompt_tokens, completion_tokens],
).then(
    fn=query_model,
    inputs=[claim, evidence],
    outputs=prediction,
    api_name=False,
)

# "Support", "NEI" and "Refute" examples share the same wiring: fill in the
# claim and evidence from the selected row, then run the model
for _example_set in (support_example, nei_example, refute_example):
    gr.on(
        triggers=[_example_set.dataset.select],
        fn=select_example,
        inputs=_example_set.dataset,
        outputs=[claim, evidence],
        api_name=False,
    ).then(
        fn=query_model,
        inputs=[claim, evidence],
        outputs=prediction,
        api_name=False,
    )

# Evidence retrieval examples: fill in the PDF and claim, get evidence from
# the PDF, then run the model
gr.on(
    triggers=[retrieval_example.dataset.select],
    fn=select_retrieval_example,
    inputs=retrieval_example.dataset,
    outputs=[pdf_file, claim],
    api_name=False,
).then(
    fn=retrieve_evidence,
    inputs=[pdf_file, claim, top_k, retrieval_method],
    outputs=[evidence, prompt_tokens, completion_tokens],
    api_name=False,
).then(
    fn=query_model,
    inputs=[claim, evidence],
    outputs=prediction,
    api_name=False,
)

# Change the model then update the predictions
model.change(
    fn=select_model,
    inputs=model,
).then(
    fn=query_model,
    inputs=[claim, evidence],
    outputs=prediction,
    api_name=False,
)

# Log user feedback when one of the feedback buttons is clicked
for _button, _handler in (
    (flag_support, save_feedback_support),
    (flag_nei, save_feedback_nei),
    (flag_refute, save_feedback_refute),
):
    _button.click(
        fn=_handler,
        inputs=[claim, evidence, model, prediction],
        outputs=None,
        api_name=False,
    )

# Change visibility of token counts and top-k slider if GPT is selected for retrieval
for _token_box in (prompt_tokens, completion_tokens):
    retrieval_method.change(
        number_visible, retrieval_method, _token_box, api_name=False
    )
retrieval_method.change(slider_visible, retrieval_method, top_k, api_name=False)
if __name__ == "__main__":
    # allowed_paths is needed to upload PDFs from specific example directory.
    # The path is built from the current working directory, so the app is
    # expected to be started from the directory that contains `examples/`.
    retrieval_examples_dir = f"{os.getcwd()}/examples/retrieval"
    demo.launch(allowed_paths=[retrieval_examples_dir])