File size: 2,004 Bytes
c5da81d
d98e05f
5bf9261
 
c5da81d
d98e05f
 
c5da81d
5f333ba
a8f9ede
5bf9261
f944458
 
a8f9ede
f944458
a8f9ede
 
 
 
f944458
a8f9ede
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24e8e95
a8f9ede
 
c5da81d
15e4f68
 
d98e05f
 
15e4f68
 
d98e05f
 
 
 
 
15e4f68
 
6b4aa4c
15e4f68
 
 
6b4aa4c
15e4f68
 
22a0333
d98e05f
33d6eda
c5da81d
33d6eda
d98e05f
a8e25eb
6b4aa4c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
from torch import nn

# Page heading, rendered before the input widgets below.
_PAGE_HEADING = "### Articles classificator."
st.markdown(_PAGE_HEADING)

@st.cache(allow_output_mutation=True)
def get_bert_and_tokenizer():
    """Load the ``bert-base-uncased`` tokenizer once and reuse it across reruns.

    Despite the name, only the tokenizer is returned. The encoder weights
    presumably come from the pickled checkpoint loaded by ``LoadModel``
    (``devops_model.__init__`` leaves ``bert`` as ``None``) — confirm.
    """
    return AutoTokenizer.from_pretrained('bert-base-uncased')


# Module-level tokenizer used by process() to encode the user's text.
tokenizer = get_bert_and_tokenizer()

class devops_model(nn.Module):
    """Text encoder followed by an MLP head that emits class log-probabilities.

    The non-PEP 8 class name is kept on purpose. The checkpoint loaded in this
    file appears to be a whole pickled module, and pickle locates classes by
    name when unpickling, so renaming would break loading.
    """

    def __init__(self, bert=None, hidden_size=768, num_classes=5, dropout=0.3):
        """Build the classification head.

        Args:
            bert: Encoder called as ``bert(**batch)`` and expected to return a
                mapping with a ``'pooler_output'`` entry. Defaults to ``None``,
                in which case it must be assigned (e.g. restored by
                unpickling) before ``forward`` is called.
            hidden_size: Width of ``pooler_output`` (768 for BERT-base).
            num_classes: Number of output classes.
            dropout: Dropout probability used inside the head.
        """
        super().__init__()
        self.bert = bert
        # Keep this layer order fixed: state_dict keys (fc.0 ... fc.5) depend on it.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.BatchNorm1d(hidden_size),
            nn.Linear(hidden_size, num_classes),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, train_batch):
        """Return per-class log-probabilities for a tokenized batch.

        Args:
            train_batch: Mapping unpacked as keyword arguments into
                ``self.bert`` (e.g. the output of a Hugging Face tokenizer).

        Returns:
            Log-probability tensor with one row per input, assuming
            ``pooler_output`` is ``(batch, hidden_size)``.

        Raises:
            TypeError: If no encoder has been attached to ``self.bert``.
        """
        if self.bert is None:
            # Same exception type as calling None, but with an actionable message.
            raise TypeError(
                "devops_model.bert is not set; pass bert=... to the constructor "
                "or load a trained model that includes the encoder"
            )
        emb = self.bert(**train_batch)['pooler_output']
        return self.fc(emb)

@st.cache(allow_output_mutation=True)
def LoadModel(path='model.pt'):
    """Load the trained classifier from *path* onto the CPU and cache it.

    ``allow_output_mutation=True`` is needed because callers mutate the
    returned module (``model.eval()`` flips its training flags). With plain
    ``st.cache``, Streamlit re-hashes the whole model on every rerun to detect
    that mutation, which is slow and triggers a warning.

    Args:
        path: Checkpoint file containing a pickled ``nn.Module``.

    Returns:
        The loaded module, already switched to eval mode.
    """
    import inspect

    load_kwargs = {'map_location': torch.device('cpu')}
    # The checkpoint is used directly as a module (eval() / __call__), i.e. a
    # fully pickled object rather than a state_dict. PyTorch >= 2.6 defaults
    # to weights_only=True, which refuses to unpickle custom classes, so opt
    # out explicitly on versions that accept the argument.
    # SECURITY: this runs pickle; only ever load checkpoints you trust.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    loaded = torch.load(path, **load_kwargs)
    loaded.eval()
    return loaded

model = LoadModel()

# Human-readable label for each model output: process() indexes this list
# with the model's output positions, so the order must match training.
# There must be one entry per output unit (the devops_model head has 5).
classes = [
    'Computer Science',
    'Mathematics',
    'Physics',
    'Quantitative Biology',
    'Statistics',
]

def process(title, summary, threshold=0.95):
    """Classify an article and return its most likely subject areas.

    The title and summary are tokenized with the module-level ``tokenizer``
    and scored by the module-level ``model``. Classes are listed in
    descending order of probability until their cumulative probability
    reaches ``threshold``, or until every class has been listed.

    Args:
        title: Article title text.
        summary: Article summary text.
        threshold: Stop once the listed classes' total probability reaches
            this value (default 0.95).

    Returns:
        A list of ``"<class name>: <probability>"`` strings, with the
        probability formatted to three decimals. The list is empty when both
        inputs are blank.
    """
    # NOTE(review): no separator is inserted, so the title's last word fuses
    # with the summary's first word. Kept as-is because the model may have been
    # trained on the same concatenation; confirm before adding a space.
    text = title + summary
    if not text.strip():
        return []
    model.eval()
    batch = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        log_probs = model(batch)
    # The model presumably ends in LogSoftmax (as devops_model does), so exp()
    # recovers probabilities for the single input row.
    probs = torch.exp(log_probs[0])
    res = []
    cumulative = 0.0
    # Bounded by the number of outputs, so an unreachable threshold can't overrun.
    for class_idx in torch.argsort(probs, descending=True).tolist():
        if cumulative >= threshold:
            break
        prob = probs[class_idx].item()
        res.append(f'{classes[class_idx]}: {prob:.3f}')
        cumulative += prob
    return res
    
# Inputs: a short box for the title and a taller one for the summary.
title = st.text_area("Title", height=30)
summary = st.text_area("Summary", height=200)

# Show each "<class>: <probability>" line produced by process().
predictions = process(title, summary)
for prediction in predictions:
    st.markdown(prediction)