# Spaces status: Runtime error
import re
import urllib
import urllib.request
import xml.etree.ElementTree as ET

import gradio as gr
import torch
from transformers import pipeline
# Hugging Face Hub checkpoint for the topic classifier. It is loaded once at
# import time; classify_paper uses its tokenizer and model directly.
ARXIV_TOPICS_MODEL = "Yozhikoff/arxiv-topics-distilbert-base-cased"
classifier = pipeline(model=ARXIV_TOPICS_MODEL)
# XML namespace of the Atom feed returned by the arXiv export API.
_ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"}

# Accepts arxiv.org abs/pdf links such as https://arxiv.org/abs/1706.03762 or
# https://arxiv.org/pdf/1706.03762v5.pdf. An optional "www." host, version
# suffix ("v5"), ".pdf" extension and trailing slash are allowed.
_ARXIV_LINK_RE = re.compile(
    r'^https?://(?:www\.)?arxiv\.org/(?:abs|pdf)/(\d{4}\.\d{4,5}(?:v\d+)?)(?:\.pdf)?/?$'
)


def _extract_arxiv_id(link):
    """Return the (optionally versioned) arXiv id in *link*, or None if it is not an arXiv URL."""
    match = _ARXIV_LINK_RE.match((link or '').strip())
    return match.group(1) if match else None


def _collapse_whitespace(text):
    """Replace every run of whitespace in *text* with one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()


def _parse_arxiv_feed(xml_data):
    """Return ``(title, abstract)`` from the first entry of an arXiv API Atom feed.

    Returns None when the feed has no entry, the entry is an API error report
    (its ``<id>`` points at ``.../api/errors``), or both fields are empty.
    Raises ``ET.ParseError`` if *xml_data* is not well-formed XML.
    """
    root = ET.fromstring(xml_data)
    entry = root.find('atom:entry', _ATOM_NS)
    if entry is None:
        return None
    entry_id = entry.findtext('atom:id', default='', namespaces=_ATOM_NS)
    if '/api/errors' in entry_id:
        return None
    # ElementTree decodes XML entities (&amp;, &lt;, ...) and keeps multi-line
    # text intact, which the previous regex scraping did not.
    title = _collapse_whitespace(entry.findtext('atom:title', default='', namespaces=_ATOM_NS))
    abstract = _collapse_whitespace(entry.findtext('atom:summary', default='', namespaces=_ATOM_NS))
    if not title and not abstract:
        return None
    return title, abstract


def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of the arXiv paper that *link* points to.

    Args:
        link: An arXiv abstract or PDF URL, e.g. ``https://arxiv.org/abs/1706.03762``.
            A version suffix such as ``v5`` is accepted.

    Returns:
        ``(title, abstract)``, each with internal whitespace collapsed to single spaces.

    Raises:
        gr.Error: if the link is not a recognised arXiv URL, the arXiv API cannot
            be reached, or its response does not describe a paper.
    """
    arxiv_id = _extract_arxiv_id(link)
    if arxiv_id is None:
        raise gr.Error('Invalid arXiv URL!')

    api_url = f"https://export.arxiv.org/api/query?id_list={arxiv_id}"
    try:
        # A timeout keeps a stalled arXiv request from hanging the Gradio worker.
        with urllib.request.urlopen(api_url, timeout=15) as response:
            xml_data = response.read().decode('utf-8')
    except (OSError, UnicodeDecodeError) as err:
        # URLError, HTTPError and socket timeouts are all OSError subclasses.
        raise gr.Error('Could not fetch the paper from arXiv, please try again later.') from err

    try:
        parsed = _parse_arxiv_feed(xml_data)
    except ET.ParseError as err:
        raise gr.Error('Unexpected response from the arXiv API.') from err
    if parsed is None:
        # Well-formed link, but arXiv reports no such paper.
        raise gr.Error('Invalid arXiv URL!')
    return parsed
def classify_paper(title, abstract):
    """Predict arXiv topics for a paper from its title and/or abstract.

    Args:
        title: Paper title; may be empty when *abstract* is given.
        abstract: Paper abstract; may be empty when *title* is given.

    Returns:
        dict mapping each topic label whose sigmoid probability exceeds 0.1 to
        that probability (the mapping format ``gr.Label`` displays).

    Raises:
        gr.Error: if both *title* and *abstract* are empty or whitespace-only.
    """
    title = (title or '').strip()
    abstract = (abstract or '').strip()
    if not title and not abstract:
        raise gr.Error('Fill Title or/and Abstract')

    text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"
    model = classifier.model

    # Truncate to the model's positional limit. DistilBERT-style models have a
    # fixed number of position embeddings, so longer inputs would fail in the
    # forward pass. With max_length=None the tokenizer uses its own default.
    max_length = getattr(model.config, 'max_position_embeddings', None)
    encoded = classifier.tokenizer(
        text, truncation=True, max_length=max_length, return_tensors='pt'
    ).to(model.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(
            input_ids=encoded['input_ids'],
            attention_mask=encoded.get('attention_mask'),
        ).logits[0]

    # Independent sigmoid per label (multi-label scoring), not a softmax.
    probs = torch.sigmoid(logits).cpu().tolist()
    id2label = model.config.id2label
    return {id2label[num]: float(prob) for num, prob in enumerate(probs) if prob > 0.1}
# Two-column layout: inputs (arXiv link parser plus manual title/abstract) on the
# left, predicted topics on the right, cached arXiv examples underneath.
with gr.Blocks(title='Paper classifier') as demo:
    gr.Markdown('# Paper Topic Classifier')

    with gr.Row():
        with gr.Column():
            gr.Markdown('## Inputs')
            gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
            arxiv_link = gr.Textbox(label="Arxiv link")
            parse_button = gr.Button("Parse Link")
            title = gr.Textbox(label="Paper title")
            abstract = gr.Textbox(label="Paper abstract")
            classify_button = gr.Button("Classify Paper", variant='primary')
            # Fill title/abstract from the link; also exposed as the "parse" API endpoint.
            parse_button.click(
                fn=get_arxiv_title_and_abstract,
                inputs=arxiv_link,
                outputs=[title, abstract],
                api_name="parse",
            )

        with gr.Column():
            gr.Markdown('## Topics')
            # Two empty headings used as vertical spacers before the label.
            for _ in range(2):
                gr.Markdown('## ')
            topics_label = gr.Label(label="Topics")
            classify_button.click(
                fn=classify_paper,
                inputs=[title, abstract],
                outputs=topics_label,
            )

    gr.Markdown('## Examples')
    example_links = [
        'https://arxiv.org/abs/1706.03762',
        'https://arxiv.org/abs/1503.04376',
        'https://arxiv.org/abs/2201.06601',
    ]
    gr.Examples(
        examples=[[link] for link in example_links],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )
if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (e.g. from tests or tooling) does not launch the app as a side effect.
    demo.launch(share=True)