rrevoid committed on
Commit
93acac7
·
1 Parent(s): dbc0dcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -28,29 +28,35 @@ cats = ["Computer Science", "Economics", "Electrical Engineering",
28
 
29
  def predict(outputs):
30
  top = 0
31
- probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
 
32
 
33
  top_cats = []
34
  top_probs = []
35
 
36
  first = True
 
37
  for prob, cat in sorted(zip(probs, cats), reverse=True):
38
  if first:
39
  if cat == "Computer Science":
40
- st.write("Today everything is connected with Computer Science")
41
  first = False
42
  if top < 95:
43
  percent = prob * 100
44
  top += percent
45
  top_cats.append(cat)
46
- top_probs.append(round(percent, 1))
47
- return pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])
 
 
 
 
48
 
49
  tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
50
  model = init_model()
51
 
52
  st.title("Article classifier")
53
- st.markdown("<img width=200px src='https://lionbridge.ai/wp-content/uploads/2020/09/2020-09-08_text-classification-tools-services.jpg'>", unsafe_allow_html=True)
54
  st.markdown("### Title")
55
 
56
  title = st.text_area("*Enter title (required)", height=20)
@@ -67,5 +73,4 @@ else:
67
  max_length=1024, truncation=True)
68
  with torch.no_grad():
69
  outputs = model(**encoded_input).pooler_output[:, 0, :]
70
- res = predict(outputs)
71
- st.write(res)
 
28
 
29
  def predict(outputs):
30
  top = 0
31
+ temp = 0.5
32
+ probs = nn.functional.softmax(outputs / temp, dim=1).tolist()[0]
33
 
34
  top_cats = []
35
  top_probs = []
36
 
37
  first = True
38
+ write_cs = False
39
  for prob, cat in sorted(zip(probs, cats), reverse=True):
40
  if first:
41
  if cat == "Computer Science":
42
+ write_cs = True
43
  first = False
44
  if top < 95:
45
  percent = prob * 100
46
  top += percent
47
  top_cats.append(cat)
48
+ top_probs.append(str(round(percent, 1)))
49
+ res = pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])
50
+ st.write(res)
51
+
52
+ if write_cs:
53
+ st.write("Today everything is connected with Computer Science")
54
 
55
  tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
56
  model = init_model()
57
 
58
  st.title("Article classifier")
59
+ st.markdown("<img width=500px src='https://lionbridge.ai/wp-content/uploads/2020/09/2020-09-08_text-classification-tools-services.jpg'>", unsafe_allow_html=True)
60
  st.markdown("### Title")
61
 
62
  title = st.text_area("*Enter title (required)", height=20)
 
73
  max_length=1024, truncation=True)
74
  with torch.no_grad():
75
  outputs = model(**encoded_input).pooler_output[:, 0, :]
76
+ predict(outputs)