rrevoid commited on
Commit
9baaef5
·
1 Parent(s): e739fa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -6,10 +6,8 @@ import torch.nn as nn
6
  from transformers import RobertaTokenizer, RobertaModel
7
 
8
 
9
- @st.cache(hash_funcs={tokenizers.AddedToken: lambda _: None, _regex.Pattern: lambda _: None})
10
- def init():
11
- tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
12
-
13
  model = RobertaModel.from_pretrained("roberta-large-mnli")
14
 
15
  model.pooler = nn.Sequential(
@@ -23,7 +21,7 @@ def init():
23
  model_path = 'model.pt'
24
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
25
  model.eval()
26
- return tokenizer, model
27
 
28
  cats = ['Computer Science', 'Economics', 'Electrical Engineering',
29
  'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
@@ -38,7 +36,8 @@ def predict(outputs):
38
  top += percent
39
  st.write(f'{cat}: {round(percent, 1)}')
40
 
41
- tokenizer, model = init()
 
42
 
43
  st.markdown("### Title")
44
 
 
6
  from transformers import RobertaTokenizer, RobertaModel
7
 
8
 
9
+ @st.cache(suppress_st_warning=True)
10
+ def init_model():
 
 
11
  model = RobertaModel.from_pretrained("roberta-large-mnli")
12
 
13
  model.pooler = nn.Sequential(
 
21
  model_path = 'model.pt'
22
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
23
  model.eval()
24
+ return model
25
 
26
  cats = ['Computer Science', 'Economics', 'Electrical Engineering',
27
  'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
 
36
  top += percent
37
  st.write(f'{cat}: {round(percent, 1)}')
38
 
39
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
40
+ model = init_model()
41
 
42
  st.markdown("### Title")
43