import torch
import pandas as pd
import streamlit as st
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
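
# Streamlit app for the "Article classifier" Space: it takes an article title and
# abstract and predicts which of eight subject areas the article belongs to, using a
# roberta-large-mnli encoder with a custom classification head (see init_model below).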

# Cache the model across Streamlit reruns; allow_output_mutation keeps Streamlit from
# trying to hash the returned torch module on every run.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def init_model():
    # Start from the pretrained roberta-large-mnli encoder and replace its pooler with
    # a small classification head mapping the 1024-dim hidden state to 8 category scores.
    model = RobertaModel.from_pretrained("roberta-large-mnli")
    model.pooler = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid()
    )
    # Load the fine-tuned weights shipped with the app and switch to inference mode.
    model_path = "model.pt"
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model
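
# Note: model.pt is expected to hold a state_dict for this exact architecture (the
# encoder plus the replaced pooler head), e.g. one saved with
# torch.save(model.state_dict(), "model.pt") after fine-tuning; load_state_dict() is
# strict by default and fails on mismatched keys.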

# Target subject areas; the order corresponds to the 8 outputs of the classification head.
cats = ["Computer Science", "Economics", "Electrical Engineering",
        "Mathematics", "Physics", "Biology", "Finance", "Statistics"]

def predict(outputs):
    # Turn the 8 head outputs into a probability distribution and keep the top
    # categories until their cumulative share reaches 95%.
    top = 0
    probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
    top_cats = []
    top_probs = []
    first = True
    for prob, cat in sorted(zip(probs, cats), reverse=True):
        if first:
            if cat == "Computer Science":
                st.write("Today everything is connected with Computer Science")
            first = False
        if top < 95:
            percent = prob * 100
            top += percent
            top_cats.append(cat)
            top_probs.append(round(percent, 1))
    return pd.DataFrame(top_probs, index=top_cats, columns=['Percent'])
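
# Hypothetical usage sketch: for a head output such as
#   torch.tensor([[0.9, 0.1, 0.2, 0.8, 0.3, 0.1, 0.1, 0.2]])
# predict() returns a DataFrame indexed by the most likely categories (largest first)
# whose 'Percent' column covers at least 95% of the softmax probability mass.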

tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
model = init_model()

st.title("Article classifier")
st.markdown("<img width=200px src='https://lionbridge.ai/wp-content/uploads/2020/09/2020-09-08_text-classification-tools-services.jpg'>", unsafe_allow_html=True)
st.markdown("### Title")
title = st.text_area("Enter title (required)", height=20)
st.markdown("### Abstract")
abstract = st.text_area("Enter abstract", height=200)

if not title:
    st.warning("Please fill in required fields")
else:
    st.markdown("### Result")
    # Tokenize the title and abstract together, truncating to RoBERTa's 512-token limit.
    encoded_input = tokenizer(title + ". " + abstract, return_tensors="pt", padding=True,
                              max_length=512, truncation=True)
    with torch.no_grad():
        # The replaced pooler is applied at every position, so pooler_output has shape
        # [batch, seq_len, 8]; take the scores at the first ([CLS]-equivalent) token.
        outputs = model(**encoded_input).pooler_output[:, 0, :]
        res = predict(outputs)
    st.write(res)
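
# To run this Space locally (assuming this file is app.py and the fine-tuned weights
# model.pt sit alongside it):
#   pip install streamlit torch transformers pandas
#   streamlit run app.py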