Spaces:
Runtime error
Runtime error
import torch | |
import streamlit as st | |
import torch.nn as nn | |
from transformers import RobertaTokenizer, RobertaModel | |
def init():
    """Build the tokenizer and the fine-tuned classifier used by the app.

    The tokenizer and the encoder both come from the ``roberta-large-mnli``
    checkpoint.  The encoder's stock pooler is swapped for a small MLP
    (1024 -> 256 -> 8, ending in a sigmoid), after which the weights stored
    in ``model.pt`` are loaded onto the CPU and the model is switched to
    evaluation mode.

    Returns:
        A ``(tokenizer, model)`` tuple.

    NOTE(review): nothing here is cached, so every call repeats the full
    load.  Streamlit re-executes the script on each interaction; consider
    wrapping this in Streamlit's resource cache if the installed version
    provides one.
    """
    checkpoint = "roberta-large-mnli"
    tok = RobertaTokenizer.from_pretrained(checkpoint)
    net = RobertaModel.from_pretrained(checkpoint)

    # Replacement head with 8 outputs -- presumably one per entry in `cats`.
    head = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid(),
    )
    net.pooler = head

    # The head must be attached before loading so the state dict keys match.
    weights = torch.load('model.pt', map_location=torch.device("cpu"))
    net.load_state_dict(weights)
    net.eval()

    return tok, net
# Category labels shown to the user; `predict` zips them with the model's
# probabilities, so the order here must match the order of the model outputs.
cats = [
    'Computer Science',
    'Economics',
    'Electrical Engineering',
    'Mathematics',
    'Physics',
    'Biology',
    'Finance',
    'Statistics',
]
def predict(outputs):
    """Write the most likely categories to the Streamlit page.

    A softmax is taken over dimension 1 of ``outputs`` and only the first
    row of the result is used.  Each probability is paired with the label at
    the same position in ``cats``; the pairs are sorted in descending order
    and written one per line as ``"<label>: <percent>"`` until the running
    total of written percentages is no longer below 95.  The entry that
    crosses the threshold is itself still written.

    Args:
        outputs: Model scores; assumed to be a 2-D tensor with one column
            per entry in ``cats`` -- TODO confirm against the caller.

    Returns:
        None.  Results are emitted through ``st.write``.
    """
    first_row = nn.functional.softmax(outputs, dim=1).tolist()[0]

    # Sort the (probability, label) tuples themselves so that ties on
    # probability fall back to comparing the label.
    ranked = list(zip(first_row, cats))
    ranked.sort(reverse=True)

    covered = 0
    for share, label in ranked:
        # The running total never shrinks, so once the threshold is reached
        # no later entry would be written either.
        if not covered < 95:
            break
        pct = share * 100
        covered += pct
        st.write(f'{label}: {round(pct, 1)}')
# ---- Streamlit page: collect a title and abstract, then classify them ----
tokenizer, model = init()

st.markdown("### Title")
# Streamlit documents a 68-pixel minimum for text_area heights and rejects
# smaller values, so the smallest allowed height is used for this short field.
title = st.text_area("Enter title", height=68)

st.markdown("### Abstract")
abstract = st.text_area("Enter abstract", height=200)

if not title:
    # The title is the only mandatory field; the abstract may stay empty.
    st.warning("Please fill out the required fields")
else:
    # Title and abstract are joined into one sequence, truncated to at most
    # 512 tokens.
    encoded_input = tokenizer(
        title + '. ' + abstract,
        return_tensors='pt',
        padding=True,
        max_length=512,
        truncation=True,
    )
    with torch.no_grad():
        # pooler_output is indexed as a 3-D tensor here; index 0 along
        # dimension 1 is kept (presumably the first token -- confirm).
        outputs = model(**encoded_input).pooler_output[:, 0, :]
    predict(outputs)