import re
import urllib.request

import gradio as gr
import torch
from transformers import pipeline

classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")


def get_arxiv_title_and_abstract(link):
    try:
        # Regular expression for arXiv link validation
        # (matches new-style, unversioned IDs such as 1706.03762, via abs or pdf URLs)
        pattern = r'^https?://arxiv\.org/(?:abs|pdf)/(\d{4}\.\d{4,5})(?:\.pdf)?/?$'
        match = re.match(pattern, link)
        if not match:
            raise ValueError("Invalid arXiv link")

        # Construct the arXiv API URL for the paper
        arxiv_id = match.group(1)
        api_url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"

        # Retrieve the paper metadata using the arXiv API
        with urllib.request.urlopen(api_url) as response:
            xml_data = response.read().decode()

        # Extract the title and abstract from the Atom XML response
        title = re.search(r'<title>(.*?)</title>', xml_data).group(1)
        abstract = re.search(r'<summary>(.*?)</summary>', xml_data, re.DOTALL).group(1)

        # Collapse whitespace in the title and abstract
        title = re.sub(r'\s+', ' ', title).strip()
        abstract = re.sub(r'\s+', ' ', abstract).strip()

        return title, abstract
    except Exception:
        raise gr.Error('Invalid arXiv URL!')


def classify_paper(title, abstract):
    if title == '' and abstract == '':
        raise gr.Error('Fill in the title and/or abstract')

    text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"

    # Tokenize and run the classifier head manually to get per-label probabilities
    # (truncation guards against inputs longer than the model's maximum length)
    item = classifier.tokenizer(text, truncation=True)
    input_tensor = torch.tensor(item['input_ids'])[None]
    logits = classifier.model(input_tensor).logits[0]
    preds = torch.sigmoid(logits).detach().cpu().numpy()

    # Keep only topics with probability above 0.1
    result = {
        classifier.model.config.id2label[num]: float(prob)
        for num, prob in enumerate(preds)
        if prob > 0.1
    }
    return result


with gr.Blocks(title='Paper classifier') as demo:
    gr.Markdown('# Paper Topic Classifier')
    with gr.Row():
        with gr.Column():
            gr.Markdown('## Inputs')
            gr.Markdown('#### Please enter an arXiv link **OR** fill in the title and abstract manually')

            arxiv_link = gr.Textbox(label="arXiv link")
            b1 = gr.Button("Parse Link")

            title = gr.Textbox(label="Paper title")
            abstract = gr.Textbox(label="Paper abstract")
            b2 = gr.Button("Classify Paper", variant='primary')

            b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")

        with gr.Column():
            gr.Markdown('## Topics')
            # Empty headings act as vertical spacers to align the output with the inputs
            gr.Markdown('## ')
            gr.Markdown('## ')
            out = gr.Label(label="Topics")
            b2.click(classify_paper, inputs=[title, abstract], outputs=out)

    gr.Markdown('## Examples')
    gr.Examples(
        examples=[
            ['https://arxiv.org/abs/1706.03762'],
            ['https://arxiv.org/abs/1503.04376'],
            ['https://arxiv.org/abs/2201.06601'],
        ],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )

demo.launch(share=True)
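
# A minimal usage sketch (kept as comments so nothing runs after the blocking
# launch() call above): the two functions can also be called directly, without
# the UI. This assumes network access to the arXiv API; the link is the first
# entry from the Examples block, and the result shape is illustrative only.
#
#     title, abstract = get_arxiv_title_and_abstract(
#         "https://arxiv.org/abs/1706.03762"
#     )
#     topics = classify_paper(title, abstract)
#     # topics is a dict {label: probability} for labels scoring above 0.1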