# app.py: "Update app.py" by rrevoid (commit 954b6e7, 1.6 kB)
import torch
import streamlit as st
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
# Cache the (large) model across reruns. Streamlit >= 1.18 provides
# st.cache_resource for shared, non-serializable objects such as models.
# On older releases fall back to the legacy st.cache. There we pass
# allow_output_mutation=True so Streamlit does not hash the whole returned
# model on every call just to detect mutations.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True, suppress_st_warning=True)
)


@_cache_model
def init():
    """Load the tokenizer and the fine-tuned classifier, cached by Streamlit.

    The model's ``pooler`` is replaced by a small classification head
    (Linear 1024->256, LayerNorm, ReLU, Linear 256->8, Sigmoid). The
    fine-tuned weights are then loaded from ``model.pt`` onto the CPU.

    Returns:
        tuple: ``(tokenizer, model)``, with ``model`` switched to eval mode.
    """
    tokenizer = RobertaTokenizer.from_pretrained("roberta-large-mnli")
    model = RobertaModel.from_pretrained("roberta-large-mnli")
    # The caller indexes ``pooler_output[:, 0, :]``, so this head is
    # presumably applied per token. 1024 is assumed to be roberta-large's
    # hidden size (TODO confirm). The 8 outputs must match the 8 entries
    # of the category list used for display.
    model.pooler = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid(),
    )
    model_path = "model.pt"
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # Inference only: disable dropout.
    model.eval()
    return tokenizer, model
# Display names for the 8 model outputs, in output-index order.
cats = ['Computer Science', 'Economics', 'Electrical Engineering',
        'Mathematics', 'Physics', 'Biology', 'Finance', 'Statistics']


def predict(outputs, threshold=95, labels=None, write=None):
    """Show the most likely categories until their cumulative share reaches *threshold*.

    Scores are turned into probabilities with a softmax over ``dim=1``, and
    only the first row of the batch is used. Categories are shown from most
    to least probable. Each line has the form ``"<category>: <percent>"``,
    rounded to one decimal. The category that pushes the running total to
    or past *threshold* is still shown.

    Args:
        outputs: Tensor of scores, indexed as ``(batch, num_categories)``.
        threshold: Cumulative percentage at which to stop listing categories.
        labels: Category names aligned with the score columns. Defaults to
            the module-level ``cats``.
        write: Callable that receives each formatted line. Defaults to
            ``st.write``.

    Returns:
        list[tuple[str, float]]: The ``(category, percent)`` pairs shown, in
        display order.
    """
    if labels is None:
        labels = cats
    if write is None:
        write = st.write
    # NOTE(review): the model head built in init() ends in a Sigmoid, so these
    # scores are already in (0, 1). A softmax on top flattens the
    # distribution; confirm this matches how the model was trained.
    probs = nn.functional.softmax(outputs, dim=1).tolist()[0]
    shown = []
    top = 0
    # Ties on probability fall back to comparing labels (descending).
    for prob, cat in sorted(zip(probs, labels), reverse=True):
        if top >= threshold:
            break  # remaining categories are never displayed
        percent = prob * 100
        top += percent
        write(f'{cat}: {round(percent, 1)}')
        shown.append((cat, percent))
    return shown
# Streamlit re-executes this script on every interaction; init() is cached,
# so the model is only loaded once.
tokenizer, model = init()

st.markdown("### Title")
# NOTE(review): newer Streamlit releases may reject st.text_area heights
# below a minimum (68 px); confirm against the pinned Streamlit version.
title = st.text_area("Enter title", height=20)
st.markdown("### Abstract")
abstract = st.text_area("Enter abstract", height=200)

# The title is required (whitespace-only counts as empty); the abstract is
# optional.
if not title.strip():
    st.warning("Please fill out the required fields")
else:
    # Title and abstract are classified as a single sequence, truncated to
    # 512 tokens.
    encoded_input = tokenizer(
        title + '. ' + abstract,
        return_tensors='pt',
        padding=True,
        max_length=512,
        truncation=True,
    )
    with torch.no_grad():
        # Keep only position 0 along the sequence axis. Presumably this is
        # the leading special token; verify against training.
        outputs = model(**encoded_input).pooler_output[:, 0, :]
    predict(outputs)