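# Streamlit app that classifies a scientific article (title + abstract) into
# one of eight subject areas with a fine-tuned RoBERTa-large-MNLI model.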
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel


# Cache the loaded model across Streamlit reruns; allow_output_mutation keeps
# Streamlit from re-hashing the returned torch module on every run.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def init_model():
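    """Build the classifier: a RoBERTa-large-MNLI backbone whose pooler is
    replaced by an 8-way classification head, with fine-tuned weights
    loaded from model.pt."""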
    model = RobertaModel.from_pretrained("roberta-large-mnli")

    model.pooler = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid()
    )
    
    model_path = "model.pt"
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model
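
# The eight target categories; the order must match the outputs of the classification head.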

cats = ["Computer Science", "Economics", "Electrical Engineering", 
        "Mathematics", "Physics", "Biology", "Finance", "Statistics"]

def predict(outputs):
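    """Convert the model outputs into percentages and keep the most probable
    categories until their cumulative share reaches 95%; the result is
    returned as a DataFrame indexed by category name."""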
    top = 0
    probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
    
    top_cats = []
    top_probs = []

    first = True
    for prob, cat in sorted(zip(probs, cats), reverse=True):
        if first:
            if cat == "Computer Science":
                st.write("Today everything is connected with Computer Science")
            first = False
        if top < 95:
            percent = prob * 100
            top += percent
            top_cats.append(cat)
            top_probs.append(round(percent, 1))
    return pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])

tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
model = init_model()
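
# Streamlit UI: ask for a required title and an optional abstract.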

st.title("Article classifier")
st.markdown("<img width=200px src='https://lionbridge.ai/wp-content/uploads/2020/09/2020-09-08_text-classification-tools-services.jpg'>", unsafe_allow_html=True)
st.markdown("### Title")

title = st.text_area("*Enter title (required)", height=20)

st.markdown("### Abstract")

abstract = st.text_area(" Enter abstract", height=200)
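
# A title is required; warn and skip inference until one is provided.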

if not title:
    st.warning("Please fill in required fields")
else:    
    st.markdown("### Result")
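    # Tokenize the combined title and abstract, run one forward pass, and take
    # the CLS position of the custom pooler output as the 8 category scores.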
    encoded_input = tokenizer(title + ". " + abstract, return_tensors="pt", padding=True,
                              max_length=512, truncation=True)
    with torch.no_grad():
        outputs = model(**encoded_input).pooler_output[:, 0, :]
        res = predict(outputs)
        st.write(res)