File size: 2,823 Bytes
cd9e1fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import re
import urllib
import urllib.request
import xml.etree.ElementTree as ET

import gradio as gr
import torch
from transformers import pipeline

# Model id passed to transformers.pipeline (Hub repo id or local path).
_MODEL_ID = "Yozhikoff/arxiv-topics-distilbert-base-cased"

# Built once at import time; classify_paper reuses its tokenizer and model
# for every request instead of reloading them.
classifier = pipeline(model=_MODEL_ID)


# Accepted link shapes (group 1 is the identifier sent to the arXiv API):
#   http(s)://[www.]arxiv.org/abs/1706.03762
#   http(s)://[www.]arxiv.org/abs/1706.03762v5        (explicit version)
#   http(s)://[www.]arxiv.org/pdf/1706.03762.pdf
#   http(s)://[www.]arxiv.org/abs/hep-th/9901001      (pre-2007 identifier)
_ARXIV_LINK_RE = re.compile(
    r'^https?://(?:www\.)?arxiv\.org/(?:abs|pdf)/'
    r'((?:\d{4}\.\d{4,5}|[a-z][a-z\-]*(?:\.[A-Z]{2})?/\d{7})(?:v\d+)?)'
    r'(?:\.pdf)?/?$'
)
_ARXIV_API_URL = "https://export.arxiv.org/api/query?id_list={}"
_ATOM_NS = {"atom": "http://www.w3.org/2005/Atom"}
_ARXIV_TIMEOUT_S = 15


def _extract_arxiv_id(link):
    """Return the arXiv identifier in *link*, or None if it is not an arXiv abs/pdf URL."""
    if not isinstance(link, str):
        return None
    match = _ARXIV_LINK_RE.match(link.strip())
    return match.group(1) if match else None


def _parse_arxiv_feed(xml_data):
    """Return ``(title, abstract)`` of the entry in an arXiv API Atom feed.

    Raises:
        xml.etree.ElementTree.ParseError: *xml_data* is not well-formed XML.
        ValueError: the feed has no usable entry, or it is an API error entry.
    """
    root = ET.fromstring(xml_data)
    entry = root.find("atom:entry", _ATOM_NS)
    if entry is None:
        raise ValueError("arXiv response contains no entry")

    # The arXiv API reports errors as an entry whose <id> points at /api/errors.
    entry_id = entry.findtext("atom:id", default="", namespaces=_ATOM_NS)
    if "/api/errors" in entry_id:
        message = entry.findtext("atom:summary", default="", namespaces=_ATOM_NS)
        raise ValueError(f"arXiv API error: {message.strip()}")

    # A real XML parser handles multi-line fields and unescapes entities
    # (&amp;, &lt;, ...), which the previous regex scraping did not.
    title = entry.findtext("atom:title", default="", namespaces=_ATOM_NS)
    abstract = entry.findtext("atom:summary", default="", namespaces=_ATOM_NS)

    # Collapse newlines/indentation inside the fields to single spaces.
    title = re.sub(r'\s+', ' ', title).strip()
    abstract = re.sub(r'\s+', ' ', abstract).strip()
    if not title:
        raise ValueError("arXiv entry has no title")
    return title, abstract


def get_arxiv_title_and_abstract(link):
    """Fetch the title and abstract of the arXiv paper at *link*.

    Args:
        link: An arXiv ``abs`` or ``pdf`` URL; surrounding whitespace is ignored.

    Returns:
        ``(title, abstract)`` with internal whitespace collapsed to single spaces.

    Raises:
        gr.Error: the link is not a recognised arXiv URL, the arXiv API could
            not be reached, or the paper was not found in the API response.
    """
    arxiv_id = _extract_arxiv_id(link)
    if arxiv_id is None:
        raise gr.Error('Invalid arXiv URL!')

    api_url = _ARXIV_API_URL.format(arxiv_id)
    try:
        # Without a timeout a stalled request would block indefinitely.
        with urllib.request.urlopen(api_url, timeout=_ARXIV_TIMEOUT_S) as response:
            xml_data = response.read()
    except OSError as err:  # URLError, HTTPError and timeouts all derive from OSError
        raise gr.Error('Could not reach the arXiv API, please try again later.') from err

    try:
        return _parse_arxiv_feed(xml_data)
    except (ET.ParseError, ValueError) as err:
        raise gr.Error('Paper not found on arXiv!') from err


def classify_paper(title, abstract, threshold=0.1):
    """Predict arXiv topics for a paper from its title and abstract.

    Each label is scored independently with a sigmoid, so several topics
    can be returned at once.

    Args:
        title: Paper title (may be empty if an abstract is given).
        abstract: Paper abstract (may be empty if a title is given).
        threshold: Only topics whose probability is strictly greater than
            this value are returned. Defaults to 0.1.

    Returns:
        dict mapping topic label -> probability (Python float), the format
        rendered by ``gr.Label``.

    Raises:
        gr.Error: both the title and the abstract are empty.
    """
    title = title or ""
    abstract = abstract or ""
    if not title.strip() and not abstract.strip():
        raise gr.Error('Please provide a paper title and/or an abstract!')

    # Same text template as before; presumably the layout the model was
    # fine-tuned on. TODO confirm before changing it.
    text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"

    model = classifier.model
    # Truncate to the position-embedding size: longer inputs (e.g. long
    # manually pasted text) would otherwise crash the forward pass.
    encoded = classifier.tokenizer(
        text,
        truncation=True,
        max_length=getattr(model.config, "max_position_embeddings", None),
        return_tensors="pt",
    )
    # Put the input on whatever device the pipeline placed the model on.
    input_ids = encoded["input_ids"].to(model.device)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_ids).logits[0]
    probs = torch.sigmoid(logits).cpu().tolist()

    return {
        model.config.id2label[idx]: prob
        for idx, prob in enumerate(probs)
        if prob > threshold
    }
    

# Build the web UI. Statement order below determines where components are
# placed and the order in which event listeners are registered.
with gr.Blocks(title='Paper classifier') as demo:
    
    gr.Markdown('Please enter an arXiv link **OR** fill title and abstract manually')
    with gr.Row():
        # Left column: inputs. Either parse an arXiv link into the title and
        # abstract boxes, or fill those two boxes in by hand.
        with gr.Column():
            arxiv_link = gr.Textbox(label="Arxiv link")

            b1 = gr.Button("Parse Link")

            title = gr.Textbox(label="Paper title")
            abstract = gr.Textbox(label="Paper abstract")

            b2 = gr.Button("Classify Paper", variant='primary')
            
            # "Parse Link": link -> (title, abstract) via the arXiv API, writing
            # into the two text boxes; also exposed under api_name "parse".
            b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")


        # Right column: predicted topics. classify_paper returns a
        # {label: probability} dict that is displayed in the gr.Label.
        with gr.Column():
            out = gr.Label(label="Topics")
            # Wired here rather than next to b1 because `out` has to exist
            # before it can be passed as an output.
            b2.click(classify_paper, inputs=[title, abstract], outputs=out)
    
    gr.Markdown('## Examples')
    # Clickable example links that fill arxiv_link and run the parser.
    # NOTE(review): with cache_examples=True Gradio presumably runs
    # get_arxiv_title_and_abstract on these ahead of time (at startup), so
    # launching needs network access to arXiv. Confirm.
    gr.Examples(
        examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/1503.04376'], ['https://arxiv.org/abs/2201.06601']],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )

# Starts the server at import/run time; share=True asks Gradio for a public
# share link (see gradio Blocks.launch documentation).
demo.launch(share=True)