File size: 1,980 Bytes
c5da81d
d98e05f
5bf9261
 
c5da81d
d98e05f
 
c5da81d
5f333ba
a8f9ede
5bf9261
f944458
 
a8f9ede
f944458
a8f9ede
 
 
 
f944458
a8f9ede
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24e8e95
a8f9ede
 
c5da81d
15e4f68
 
d98e05f
 
15e4f68
 
d98e05f
 
 
 
 
15e4f68
 
 
 
 
 
 
 
 
 
d98e05f
33d6eda
c5da81d
33d6eda
d98e05f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel, pipeline
from torch import nn

# Page heading, rendered as a level-3 Markdown header.
# NOTE(review): "classificator" probably should read "classifier"; left as-is
# because it is user-facing text.
st.markdown("### Articles classificator.")
# Disabled banner image, kept for reference:
# st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)

@st.cache(allow_output_mutation=True)
def get_bert_and_tokenizer():
    """Return the pretrained ``bert-base-uncased`` tokenizer.

    Despite the function name, only the tokenizer is returned; no BERT
    model is loaded here. The result is cached by Streamlit so the
    tokenizer is not re-created on every script rerun.
    """
    return AutoTokenizer.from_pretrained('bert-base-uncased')

# Module-level tokenizer shared by every call to process().
tokenizer = get_bert_and_tokenizer()

class devops_model(nn.Module):
    """BERT encoder followed by a small feed-forward classification head.

    The encoder slot ``bert`` is left as ``None`` by the constructor and must
    be populated before ``forward`` is called. The head maps a 768-dimensional
    pooled embedding to 5 log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # Placeholder for the encoder; the constructor does not create one.
        self.bert = None
        # Layers are listed in the order they are applied; unpacking them into
        # nn.Sequential keeps the submodule indices 0..5.
        head_layers = [
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.BatchNorm1d(768),
            nn.Linear(768, 5),
            nn.LogSoftmax(dim=-1),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, train_batch):
        """Return log-probabilities for a batch.

        ``train_batch`` is a mapping that is unpacked as keyword arguments
        into ``self.bert``; the encoder's ``'pooler_output'`` entry is fed
        through the classification head.
        """
        pooled = self.bert(**train_batch)['pooler_output']
        log_probs = self.fc(pooled)
        return log_probs

# allow_output_mutation=True: without it Streamlit hashes the cached return
# value on every rerun to detect mutation, which is slow for a large model and
# triggers a mutation warning because process() calls model.eval() on it.
# This also matches how the tokenizer loader is cached.
@st.cache(allow_output_mutation=True)
def LoadModel():
    """Load the trained classifier from ``model.pt`` onto the CPU.

    Returns:
        The object stored in ``model.pt``; the rest of the script uses it as
        a callable model that supports ``.eval()``.

    NOTE: ``torch.load`` unpickles the file, which can execute arbitrary
    code. Only ship a ``model.pt`` that comes from a trusted source.
    """
    return torch.load('model.pt', map_location=torch.device('cpu'))

# Loaded once at import time; process() calls this object directly.
model = LoadModel()

# Display labels. process() indexes this list with positions of the model's
# output vector, so the order here must line up with the model's outputs
# (presumably the label order used during training — confirm).
classes = ['Computer Science', 'Mathematics', 'Physics', 'Quantitative Biology', 'Statistics']

def process(title, summary, threshold=0.95):
    """Classify an article and format the most probable classes as Markdown.

    The title and summary are concatenated, tokenized and passed through the
    module-level ``model``. Classes are listed from most to least probable
    until their cumulative probability reaches ``threshold``.

    Args:
        title: Article title text.
        summary: Article summary text.
        threshold: Cumulative probability at which listing stops
            (default 0.95).

    Returns:
        One ``"<class>: <probability>"`` row per listed class, each ended
        with a Markdown hard line break, or ``''`` when both inputs are
        empty or whitespace-only.
    """
    # NOTE(review): title and summary are joined with no separator, so the
    # last word of the title fuses with the first word of the summary. Kept
    # as-is because the training-time input format is not visible here.
    text = title + summary
    if not text.strip():
        return ''

    model.eval()
    encoded = tokenizer([text], padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        log_probs = model(encoded)[0]

    # The model emits log-probabilities; exponentiate to get probabilities.
    probs = torch.exp(log_probs)
    ranked = torch.argsort(probs, descending=True).tolist()
    values = probs.tolist()

    rows = []
    covered = 0.0
    # Iterating over the ranking (instead of an unbounded while loop) means
    # the index can never run past the end of the class list.
    for class_idx in ranked:
        if covered >= threshold:
            break
        prob = values[class_idx]
        # Two trailing spaces + newline is a Markdown hard line break; a bare
        # newline is a soft break and would put all rows on one line.
        rows.append(f'{classes[class_idx]}: {prob:.3f}  \n')
        covered += prob
    return ''.join(rows)
    
# Input widgets; Streamlit reruns the whole script whenever a value changes.
title = st.text_area("Title", height=30)

summary = st.text_area("Summary", height=200)

# Render the classification result (empty string when nothing was entered).
st.markdown(f"{process(title, summary)}")