rrevoid commited on
Commit
875cfff
·
1 Parent(s): a0dce4b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import streamlit as st
4
+ import torch.nn as nn
5
+ from transformers import RobertaTokenizer, RobertaModel
6
+
7
+ @st.cache(suppress_st_warning=True)
8
+ def init_tokenizer():
9
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
10
+ return tokenizer
11
+
12
+
13
+ @st.cache(suppress_st_warning=True)
14
+ def init_model():
15
+ model = RobertaModel.from_pretrained("roberta-large-mnli")
16
+
17
+ model.pooler = nn.Sequential(
18
+ nn.Linear(1024, 256),
19
+ nn.LayerNorm(256),
20
+ nn.ReLU(),
21
+ nn.Linear(256, 8),
22
+ nn.Sigmoid()
23
+ )
24
+
25
+ model_path = 'model.pt'
26
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
27
+
28
+ cats = ['Computer Science', 'Economics', 'Electrical Engineering',
29
+ 'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
30
+
31
+ def predict(outputs):
32
+ top = 0
33
+ probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
34
+
35
+ for prob, cat in sorted(zip(probs, cats), reverse=True):
36
+ if top < 95:
37
+ percent = prob * 100
38
+ top += percent
39
+ st.write(f'{cat}: {round(percent, 1)}')
40
+
41
+
42
+ st.markdown("### Title")
43
+
44
+ title = st.text_area("Enter title", height=20)
45
+
46
+ st.markdown("### Abstract")
47
+
48
+ abstract = st.text_area("Enter abstract", height=200)
49
+
50
+ if not title:
51
+ st.warning("Please fill out so required fields")
52
+ else:
53
+ tokenizer = init_tokenizer()
54
+ model = init_model()
55
+
56
+ encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
57
+ max_length = 512, truncation=True)
58
+ outputs = model(**encoded_input).pooler_output[:, 0, :]
59
+ predict(outputs)