import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModel
from torch import nn
st.markdown("### Articles classificator.")
# st.markdown("<img width=200px src='https://rozetked.me/images/uploads/dwoilp3BVjlE.jpg'>", unsafe_allow_html=True)

@st.cache(allow_output_mutation=True)
def get_bert_and_tokenizer():
    model_name = 'bert-base-uncased'
    # The fine-tuned BERT weights are restored from model.pt below, so only the tokenizer is needed here.
    # return AutoModel.from_pretrained(model_name), AutoTokenizer.from_pretrained(model_name)
    return AutoTokenizer.from_pretrained(model_name)

tokenizer = get_bert_and_tokenizer()

class devops_model(nn.Module):
    def __init__(self):
        super(devops_model, self).__init__()
        # The BERT backbone is attached when the trained model is loaded from model.pt.
        self.bert = None
        # Classification head: pooled 768-d BERT embedding -> 5 class log-probabilities.
        self.fc = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.BatchNorm1d(768),
            nn.Linear(768, 5),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, train_batch):
        # Use BERT's pooled [CLS] representation as the document embedding.
        emb = self.bert(**train_batch)['pooler_output']
        return self.fc(emb)

@st.cache(allow_output_mutation=True)
def LoadModel():
    # allow_output_mutation is needed because model.eval() later mutates the cached module.
    return torch.load('model.pt', map_location=torch.device('cpu'))

model = LoadModel()
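# Note: model.pt is assumed to be a devops_model instance saved with torch.save after its
# self.bert attribute was set to a fine-tuned transformers AutoModel; torch.load can only
# unpickle it because the devops_model class is defined above in this same script.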
classes = ['Computer Science', 'Mathematics', 'Physics', 'Quantitative Biology', 'Statistics']
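# This order is assumed to match the label indices the classification head was trained with.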

def process(title, summary):
    text = title + ' ' + summary
    if not text.strip():
        return ''
    model.eval()
    lines = [text]
    X = tokenizer(lines, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        out = model(X)
    probs = torch.exp(out[0])  # the model returns log-probabilities
    sorted_indexes = torch.argsort(probs, descending=True)
    # List the most likely classes until at least 95% of the probability mass is covered.
    probs_sum = 0
    idx = 0
    result = ''
    while probs_sum < 0.95:
        prob_idx = sorted_indexes[idx]
        prob = probs[prob_idx]
        result += f'{classes[prob_idx]}: {prob:.3f}\n\n'
        idx += 1
        probs_sum += prob
    return result

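# process() returns markdown text with one "<class>: <probability>" line per predicted class,
# ordered from most to least likely and cut off once the listed classes cover ~95% of the
# probability mass.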
title = st.text_area("Title", height=30)
summary = st.text_area("Summary", height=200)
st.markdown(f"{process(title, summary)}") |