rrevoid commited on
Commit
8d3f7b8
·
1 Parent(s): df1dde0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -19,13 +19,13 @@ def init_model():
19
  nn.Sigmoid()
20
  )
21
 
22
- model_path = 'model.pt'
23
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
24
  model.eval()
25
  return model
26
 
27
- cats = ['Computer Science', 'Economics', 'Electrical Engineering',
28
- 'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']
29
 
30
  def predict(outputs):
31
  top = 0
@@ -34,14 +34,19 @@ def predict(outputs):
34
  top_cats = []
35
  top_probs = []
36
 
 
37
  for prob, cat in sorted(zip(probs, cats), reverse=True):
 
 
 
 
38
  if top < 95:
39
  percent = prob * 100
40
  top += percent
41
  top_cats.append(cat)
42
  top_probs.append(round(percent, 1))
43
 
44
- chart_data = pd.DataFrame(top_probs, index=top_cats, columns=['percent'])
45
  st.bar_chart(chart_data)
46
 
47
  tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
@@ -59,7 +64,7 @@ if not title:
59
  st.warning("Please fill in required fields")
60
  else:
61
  st.markdown("### Result")
62
- encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
63
  max_length = 512, truncation=True)
64
  with torch.no_grad():
65
  outputs = model(**encoded_input).pooler_output[:, 0, :]
 
19
  nn.Sigmoid()
20
  )
21
 
22
+ model_path = "model.pt"
23
  model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
24
  model.eval()
25
  return model
26
 
27
+ cats = ["Computer Science", "Economics", "Electrical Engineering",
28
+ "Mathematics", "Physics", "Biology", "Finance", "Statistics"]
29
 
30
  def predict(outputs):
31
  top = 0
 
34
  top_cats = []
35
  top_probs = []
36
 
37
+ first = True
38
  for prob, cat in sorted(zip(probs, cats), reverse=True):
39
+ if first:
40
+ if cat == "Computer Science":
41
+ st.write("Today everything is connected with Computer Science"
42
+ first = False
43
  if top < 95:
44
  percent = prob * 100
45
  top += percent
46
  top_cats.append(cat)
47
  top_probs.append(round(percent, 1))
48
 
49
+ chart_data = pd.DataFrame(top_probs, index=top_cats, columns=["percent"])
50
  st.bar_chart(chart_data)
51
 
52
  tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
 
64
  st.warning("Please fill in required fields")
65
  else:
66
  st.markdown("### Result")
67
+ encoded_input = tokenizer(title + ". " + abstract, return_tensors="pt", padding=True,
68
  max_length = 512, truncation=True)
69
  with torch.no_grad():
70
  outputs = model(**encoded_input).pooler_output[:, 0, :]