# Source: app.py by rrevoid, commit ad3d6a3 ("Update app.py"), 1.82 kB
import torch
import tokenizers
import pandas as pd
import streamlit as st
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def init_model():
    """Build the paper-topic classifier and load its fine-tuned weights.

    Starts from the pretrained ``roberta-large-mnli`` backbone, replaces its
    pooler with a small head (Linear 1024->256, LayerNorm, ReLU,
    Linear 256->8, Sigmoid), loads the state dict from ``model.pt`` onto the
    CPU and switches the model to eval mode.

    The result is cached by Streamlit, so script reruns reuse the same model
    instead of rebuilding it. ``allow_output_mutation=True`` stops
    ``st.cache`` from re-hashing the (very large) returned model on every
    call just to detect mutation.

    Returns:
        The loaded ``RobertaModel`` in eval mode.
    """
    model = RobertaModel.from_pretrained("roberta-large-mnli")
    # Replacement head: 8 outputs, one per entry in ``cats`` (paired by
    # position in ``predict``). The caller indexes ``pooler_output[:, 0, :]``,
    # so it expects a per-token output from this head.
    model.pooler = nn.Sequential(
        nn.Linear(1024, 256),
        nn.LayerNorm(256),
        nn.ReLU(),
        nn.Linear(256, 8),
        nn.Sigmoid(),
    )
    model_path = 'model.pt'
    # torch.load unpickles the file; it must only ever point at the trusted
    # checkpoint shipped with the app. map_location forces CPU tensors.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()  # inference only: disable dropout
    return model
# Human-readable label for each classifier output; ``predict`` pairs these
# with the model's probability vector by position, so order matters.
cats = [
    'Computer Science',
    'Economics',
    'Electrical Engineering',
    'Mathematics',
    'Physics',
    'Biology',
    'Finance',
    'Statistics',
]
def predict(outputs, threshold=95):
    """Render a bar chart of the most likely categories for one input.

    Softmax-normalises the first row of ``outputs`` and pairs each value with
    the label at the same position in ``cats``. Categories are taken in
    descending probability order until their cumulative probability reaches
    ``threshold`` percent; only those are charted.

    Args:
        outputs: tensor of per-category scores; only row 0 is used
            (assumed shape ``(batch, len(cats))`` -- TODO confirm).
        threshold: cumulative probability, in percent, at which no further
            categories are added. Defaults to 95.
    """
    # NOTE(review): the model head ends in Sigmoid, so this softmax runs over
    # values already in [0, 1]; confirm this matches how the head was trained.
    probs = nn.functional.softmax(outputs, dim=1).tolist()[0]

    top_cats = []
    top_probs = []
    cumulative = 0  # running total, in percent
    for prob, cat in sorted(zip(probs, cats), reverse=True):
        if cumulative >= threshold:
            break  # remaining items are all less likely; stop scanning
        cumulative += prob * 100
        top_cats.append(cat)
        top_probs.append(prob)

    # One row with one column per category. Passing the flat list directly
    # would make a single column, and pandas raises ValueError as soon as
    # more than one column label is supplied for it.
    chart_data = pd.DataFrame([top_probs], columns=top_cats)
    st.bar_chart(chart_data)
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def _init_tokenizer():
    """Load the RoBERTa tokenizer once instead of on every Streamlit rerun."""
    return RobertaTokenizer.from_pretrained("roberta-large-mnli")


tokenizer = _init_tokenizer()
model = init_model()

st.markdown("### Title")
title = st.text_area("* Enter title (required)", height=20)
st.markdown("### Abstract")
abstract = st.text_area("Enter abstract", height=200)

# A title made only of whitespace counts as missing.
if not title.strip():
    st.warning("Please fill out the required fields")
else:
    st.markdown("### Result")
    # Title and abstract are classified as one sequence, truncated to the
    # 512-token max_length given here.
    encoded_input = tokenizer(
        title + '. ' + abstract,
        return_tensors='pt',
        padding=True,
        max_length=512,
        truncation=True,
    )
    with torch.no_grad():  # inference only: no autograd bookkeeping
        # Keep the head's output for the first token of each sequence.
        outputs = model(**encoded_input).pooler_output[:, 0, :]
    predict(outputs)