Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -28,29 +28,35 @@ cats = ["Computer Science", "Economics", "Electrical Engineering",
|
|
28 |
|
29 |
def predict(outputs):
|
30 |
top = 0
|
31 |
-
|
|
|
32 |
|
33 |
top_cats = []
|
34 |
top_probs = []
|
35 |
|
36 |
first = True
|
|
|
37 |
for prob, cat in sorted(zip(probs, cats), reverse=True):
|
38 |
if first:
|
39 |
if cat == "Computer Science":
|
40 |
-
|
41 |
first = False
|
42 |
if top < 95:
|
43 |
percent = prob * 100
|
44 |
top += percent
|
45 |
top_cats.append(cat)
|
46 |
-
top_probs.append(round(percent, 1))
|
47 |
-
|
|
|
|
|
|
|
|
|
48 |
|
49 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
50 |
model = init_model()
|
51 |
|
52 |
st.title("Article classifier")
|
53 |
-
st.markdown("<img width=
|
54 |
st.markdown("### Title")
|
55 |
|
56 |
title = st.text_area("*Enter title (required)", height=20)
|
@@ -67,5 +73,4 @@ else:
|
|
67 |
max_length=1024, truncation=True)
|
68 |
with torch.no_grad():
|
69 |
outputs = model(**encoded_input).pooler_output[:, 0, :]
|
70 |
-
|
71 |
-
st.write(res)
|
|
|
28 |
|
29 |
def predict(outputs):
|
30 |
top = 0
|
31 |
+
temp = 0.5
|
32 |
+
probs = nn.functional.softmax(outputs / temp, dim=1).tolist()[0]
|
33 |
|
34 |
top_cats = []
|
35 |
top_probs = []
|
36 |
|
37 |
first = True
|
38 |
+
write_cs = False
|
39 |
for prob, cat in sorted(zip(probs, cats), reverse=True):
|
40 |
if first:
|
41 |
if cat == "Computer Science":
|
42 |
+
write_cs = True
|
43 |
first = False
|
44 |
if top < 95:
|
45 |
percent = prob * 100
|
46 |
top += percent
|
47 |
top_cats.append(cat)
|
48 |
+
top_probs.append(str(round(percent, 1)))
|
49 |
+
res = pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])
|
50 |
+
st.write(res)
|
51 |
+
|
52 |
+
if write_cs:
|
53 |
+
st.write("Today everything is connected with Computer Science")
|
54 |
|
55 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
56 |
model = init_model()
|
57 |
|
58 |
st.title("Article classifier")
|
59 |
+
st.markdown("<img width=500px src='https://lionbridge.ai/wp-content/uploads/2020/09/2020-09-08_text-classification-tools-services.jpg'>", unsafe_allow_html=True)
|
60 |
st.markdown("### Title")
|
61 |
|
62 |
title = st.text_area("*Enter title (required)", height=20)
|
|
|
73 |
max_length=1024, truncation=True)
|
74 |
with torch.no_grad():
|
75 |
outputs = model(**encoded_input).pooler_output[:, 0, :]
|
76 |
+
predict(outputs)
|
|