Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -19,13 +19,13 @@ def init_model():
|
|
19 |
nn.Sigmoid()
|
20 |
)
|
21 |
|
22 |
-
model_path =
|
23 |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
24 |
model.eval()
|
25 |
return model
|
26 |
|
27 |
-
cats = [
|
28 |
-
|
29 |
|
30 |
def predict(outputs):
|
31 |
top = 0
|
@@ -34,14 +34,19 @@ def predict(outputs):
|
|
34 |
top_cats = []
|
35 |
top_probs = []
|
36 |
|
|
|
37 |
for prob, cat in sorted(zip(probs, cats), reverse=True):
|
|
|
|
|
|
|
|
|
38 |
if top < 95:
|
39 |
percent = prob * 100
|
40 |
top += percent
|
41 |
top_cats.append(cat)
|
42 |
top_probs.append(round(percent, 1))
|
43 |
|
44 |
-
chart_data = pd.DataFrame(top_probs, index=top_cats, columns=[
|
45 |
st.bar_chart(chart_data)
|
46 |
|
47 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
@@ -59,7 +64,7 @@ if not title:
|
|
59 |
st.warning("Please fill in required fields")
|
60 |
else:
|
61 |
st.markdown("### Result")
|
62 |
-
encoded_input = tokenizer(title +
|
63 |
max_length = 512, truncation=True)
|
64 |
with torch.no_grad():
|
65 |
outputs = model(**encoded_input).pooler_output[:, 0, :]
|
|
|
19 |
nn.Sigmoid()
|
20 |
)
|
21 |
|
22 |
+
model_path = "model.pt"
|
23 |
model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
|
24 |
model.eval()
|
25 |
return model
|
26 |
|
27 |
+
cats = ["Computer Science", "Economics", "Electrical Engineering",
|
28 |
+
"Mathematics", "Physics", "Biology", "Finance", "Statistics"]
|
29 |
|
30 |
def predict(outputs):
|
31 |
top = 0
|
|
|
34 |
top_cats = []
|
35 |
top_probs = []
|
36 |
|
37 |
+
first = True
|
38 |
for prob, cat in sorted(zip(probs, cats), reverse=True):
|
39 |
+
if first:
|
40 |
+
if cat == "Computer Science":
|
41 |
+
st.write("Today everything is connected with Computer Science"
|
42 |
+
first = False
|
43 |
if top < 95:
|
44 |
percent = prob * 100
|
45 |
top += percent
|
46 |
top_cats.append(cat)
|
47 |
top_probs.append(round(percent, 1))
|
48 |
|
49 |
+
chart_data = pd.DataFrame(top_probs, index=top_cats, columns=["percent"])
|
50 |
st.bar_chart(chart_data)
|
51 |
|
52 |
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
|
|
|
64 |
st.warning("Please fill in required fields")
|
65 |
else:
|
66 |
st.markdown("### Result")
|
67 |
+
encoded_input = tokenizer(title + ". " + abstract, return_tensors="pt", padding=True,
|
68 |
max_length = 512, truncation=True)
|
69 |
with torch.no_grad():
|
70 |
outputs = model(**encoded_input).pooler_output[:, 0, :]
|