File size: 3,103 Bytes
cd9e1fc
 
 
 
 
 
 
 
 
 
 
 
84837a0
 
 
 
cd9e1fc
 
84837a0
 
cd9e1fc
 
84837a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd9e1fc
 
 
 
 
dcadea4
 
 
cd9e1fc
 
 
 
 
9e56b79
cd9e1fc
 
 
 
8defca4
cd9e1fc
 
8defca4
0749c7f
18b8fa4
cd9e1fc
 
 
18b8fa4
 
cd9e1fc
 
 
 
 
 
 
5055671
0749c7f
dcadea4
cd9e1fc
 
 
 
 
9e56b79
cd9e1fc
 
 
 
 
 
e993597
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import xml.etree.ElementTree as ET
import re
import urllib
import torch


from transformers import pipeline

# Hugging Face pipeline built once, at module import, from the model id below.
# classify_paper reaches into it for the tokenizer, the model and its config.
classifier = pipeline(model="Yozhikoff/arxiv-topics-distilbert-base-cased")


import re
import urllib.request
import xml.etree.ElementTree as ET

def get_arxiv_title_and_abstract(link):
    """Look up the title and abstract of an arXiv paper given its URL.

    Args:
        link: An ``abs`` or ``pdf`` URL on arxiv.org carrying a new-style
            identifier (e.g. ``https://arxiv.org/abs/1706.03762``). A
            ``www.`` host prefix, a version suffix (``v2``) and a trailing
            ``.pdf`` are accepted; surrounding whitespace is ignored.

    Returns:
        A ``(title, summary)`` tuple with the text of the Atom ``title`` and
        ``summary`` elements of the first entry returned by the arXiv API.

    Raises:
        gr.Error: If the link is not a recognised arXiv URL, if the API
            request or XML parsing fails, or if the response has no entry
            with both a title and a summary.
    """
    # Dots are escaped so that only the literal host "arxiv.org" matches.
    pattern = (
        r'https?://(?:www\.)?arxiv\.org/(?:abs|pdf)/'
        r'(\d{4}\.\d{4,5}(?:v\d+)?)(?:\.pdf)?'
    )
    match = re.fullmatch(pattern, link.strip()) if isinstance(link, str) else None
    if match is None:
        raise gr.Error('Invalid arXiv URL!')

    arxiv_id = match.group(1)
    api_url = f'http://export.arxiv.org/api/query?id_list={arxiv_id}'

    try:
        # The context manager closes the connection even if read() fails;
        # the timeout keeps a stalled request from hanging the UI forever.
        with urllib.request.urlopen(api_url, timeout=10) as response:
            xml_data = response.read()
        root = ET.fromstring(xml_data)
    except Exception as err:
        # UI boundary: surface any network/parse failure as a Gradio error,
        # keeping the original exception chained for the server log.
        raise gr.Error(
            'Could not fetch the paper from arXiv, please try again later'
        ) from err

    atom = '{http://www.w3.org/2005/Atom}'
    entry = root.find(f'{atom}entry')
    if entry is None:
        # A feed without any entry means the API found no such paper.
        raise gr.Error('Invalid arXiv URL!')

    title = entry.findtext(f'{atom}title')
    summary = entry.findtext(f'{atom}summary')
    if title is None or summary is None:
        raise gr.Error('Invalid arXiv URL!')

    return title, summary


def classify_paper(title, abstract):
    """Predict topic labels for a paper from its title and/or abstract.

    Args:
        title: Paper title; may be empty when ``abstract`` is provided.
        abstract: Paper abstract; may be empty when ``title`` is provided.

    Returns:
        A dict mapping label names (taken from the model config's
        ``id2label``) to their sigmoid scores. It contains every label
        scoring above 0.25, or only the single best-scoring label when none
        clears that threshold.

    Raises:
        gr.Error: If both ``title`` and ``abstract`` are empty or blank.
    """
    # Treat a missing value like an empty textbox so "None" never ends up
    # in the prompt text.
    title = title or ''
    abstract = abstract or ''
    if not title.strip() and not abstract.strip():
        raise gr.Error('Fill Title or/and Abstract')

    text = f"TITLE\n{title}\n\nABSTRACT\n{abstract}"

    model = classifier.model
    # Truncate long inputs: feeding more tokens than the model has position
    # embeddings for makes the forward pass fail. When the config does not
    # expose a limit, max_length=None defers to the tokenizer's own maximum.
    max_length = getattr(model.config, 'max_position_embeddings', None)
    item = classifier.tokenizer(text, truncation=True, max_length=max_length)

    # [None] adds the batch dimension expected by the model.
    input_tensor = torch.tensor(item['input_ids'])[None]
    with torch.no_grad():  # inference only, no autograd graph needed
        logits = model(input_tensor).logits[0]
    preds = torch.sigmoid(logits).cpu().numpy()

    id2label = model.config.id2label
    result = {
        id2label[num]: float(prob)
        for num, prob in enumerate(preds)
        if prob > 0.25
    }
    if not result:
        # Nothing cleared the threshold: fall back to the top label so the
        # caller always receives at least one prediction.
        best = int(preds.argmax())
        result = {id2label[best]: float(preds[best])}
    return result
    

# Two-column UI: inputs (link, title, abstract) on the left, predicted topics
# on the right, followed by clickable example links.
# NOTE: statement order inside the `with` blocks is kept as-is because the
# components are declared in the order they should appear on the page.
with gr.Blocks(title='Paper classifier') as demo:
    gr.Markdown('# Paper Topic Classifier')
    with gr.Row():
        with gr.Column():
            gr.Markdown('## Inputs')
            gr.Markdown('#### Please enter an arXiv link **OR** fill title and abstract manually')
            # Placeholder shows the expected link format.
            arxiv_link = gr.Textbox(label="Arxiv link", placeholder="https://arxiv.org/abs/1706.03762")

            b1 = gr.Button("Parse Link")

            title = gr.Textbox(label="Paper title", placeholder="Title text")
            abstract = gr.Textbox(label="Paper abstract", placeholder="Abstract text")

            b2 = gr.Button("Classify Paper", variant='primary')
            
            # "Parse Link" fills the title/abstract boxes from the arXiv link.
            b1.click(fn=get_arxiv_title_and_abstract, inputs=arxiv_link, outputs=[title, abstract], api_name="parse")


        with gr.Column():
            gr.Markdown('## Topics')
            # Two empty headings; presumably vertical spacers that line the
            # label up with the inputs column (confirm visually).
            gr.Markdown('## ')
            gr.Markdown('## ')
            out = gr.Label(label="Topics")
            # "Classify Paper" sends the title/abstract boxes to the model and
            # shows the returned label->score dict.
            b2.click(classify_paper, inputs=[title, abstract], outputs=out)
    
    gr.Markdown('## Examples')
    # Example links are resolved through the same parser as the button.
    # NOTE(review): with cache_examples=True Gradio is expected to run `fn`
    # on every example ahead of time, which needs network access to the
    # arXiv API at startup; confirm this is acceptable for the deployment.
    gr.Examples(
        examples=[['https://arxiv.org/abs/1706.03762'], ['https://arxiv.org/abs/2304.06718'], ['https://arxiv.org/abs/1307.0058']],
        inputs=arxiv_link,
        outputs=[title, abstract],
        fn=get_arxiv_title_and_abstract,
        cache_examples=True,
    )

demo.launch()