Spaces:
Runtime error
Runtime error
File size: 1,848 Bytes
875cfff 89d4e60 875cfff ad3d6a3 875cfff 6c23168 aabbe07 9baaef5 875cfff 954b6e7 9baaef5 875cfff ad3d6a3 875cfff ad3d6a3 5c1deee ad3d6a3 7853862 df1dde0 875cfff 9baaef5 875cfff 5c1780d 875cfff 5c1780d 875cfff ad3d6a3 875cfff 954b6e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import torch
import tokenizers
import pandas as pd
import streamlit as st
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
# allow_output_mutation=True stops the legacy st.cache from hashing the
# returned model after every call to detect mutation; for a network of this
# size that check is very slow and can fail on unhashable members.
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def init_model():
    """Build the RoBERTa-based classifier and load its fine-tuned weights.

    The encoder is created from the "roberta-large-mnli" checkpoint, its
    pooler is replaced by a small feed-forward head with 8 sigmoid outputs,
    and the weights stored in ``model.pt`` are then loaded on top.

    Returns:
        The model in evaluation mode, with weights mapped to the CPU.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise when ``model.pt``
        is missing or does not match the architecture built here.
    """
    model = RobertaModel.from_pretrained("roberta-large-mnli")

    # Replace the stock pooler with the classification head: 1024 -> 256 -> 8.
    model.pooler = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid()
    )

    model_path = 'model.pt'
    # map_location forces the checkpoint tensors onto the CPU, so the file
    # loads even if it was saved from a GPU run.
    model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
    model.eval()
    return model
# Category labels shown to the user. predict() pairs entry i of this list with
# output i of the model, so the order here must not be changed.
cats = [
    'Computer Science',
    'Economics',
    'Electrical Engineering',
    'Mathematics',
    'Physics',
    'Biology',
    'Finance',
    'Statistics',
]
def predict(outputs):
    """Draw a bar chart of the most probable categories for one prediction.

    The first row of ``outputs`` is turned into a probability distribution
    with softmax and paired with the labels in ``cats``. Categories are then
    taken from most to least probable until the percentages already taken
    add up to at least 95; those categories are plotted with their
    percentage rounded to one decimal place.

    Args:
        outputs: tensor of model scores; only row 0 is used.
    """
    # NOTE(review): the model head already ends in a Sigmoid, so this softmax
    # runs over values in (0, 1) — confirm this matches how it was trained.
    distribution = nn.functional.softmax(outputs, dim=1).tolist()[0]
    ranked = sorted(zip(distribution, cats), reverse=True)

    labels = []
    shares = []
    covered = 0
    for weight, label in ranked:
        # Written as "not <" on purpose: it mirrors the original "< 95"
        # admission test exactly, including for non-finite totals.
        if not covered < 95:
            break
        share = weight * 100
        covered += share
        labels.append(label)
        shares.append(round(share, 1))

    st.bar_chart(pd.DataFrame(shares, index=labels, columns=['percent']))
# ---- Streamlit page: collect a title/abstract and show the predicted topics.
tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
model = init_model()

st.markdown("### Title")
# Recent Streamlit releases reject text_area heights below 68 px with an
# exception, so the smallest accepted height is used instead of 20.
title = st.text_area("*Enter title (required)", height=68)

st.markdown("### Abstract")
abstract = st.text_area("Enter abstract", height=200)

# A title made only of whitespace counts as missing.
if not title.strip():
    st.warning("Please fill in required fields")
else:
    st.markdown("### Result")
    # Title and abstract are classified together as a single text, truncated
    # to at most 512 tokens.
    encoded_input = tokenizer(title + '. ' + abstract, return_tensors='pt', padding=True,
                              max_length=512, truncation=True)
    with torch.no_grad():
        # pooler_output is indexed as a 3-D tensor here; index 0 on the
        # second axis is presumably the first token's position — TODO confirm.
        outputs = model(**encoded_input).pooler_output[:, 0, :]
    predict(outputs)
|