File size: 2,076 Bytes
0efe602 78f93e5 fffa35e 78f93e5 2083f73 78f93e5 2083f73 0efe602 2083f73 78f93e5 b8ff6ff 78f93e5 0efe602 78f93e5 b8ff6ff 78f93e5 b8ff6ff 78f93e5 b8ff6ff 78f93e5 b8ff6ff 78f93e5 b8ff6ff 78f93e5 0efe602 78f93e5 0efe602 78f93e5 0efe602 78f93e5 0efe602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import pandas as pd
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def combine_title_summary(title, summary):
    """Join a paper's title and abstract into the single text string fed to the classifier."""
    return f"title: {title} summary: {summary}"
# Human-readable topic label -> classifier output index.
# NOTE(review): presumably this order matches the label order used when the
# checkpoint was fine-tuned — verify against the checkpoint's config/label map.
tag2ind = {
    "Biology": 0,
    "Physics": 1,
    "Math": 2,
    "Computer Science": 3,
}
@st.cache_resource
def load_model():
    """Load the fine-tuned classifier and its tokenizer once, cached across Streamlit reruns.

    Returns:
        (tokenizer, model): the slow tokenizer and the sequence-classification
        model, with the model placed on GPU when one is available.
    """
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Earlier experiment: "./distilbert/distilbert-base-cased/checkpoint-738"
    checkpoint_dir = "./microsoft/deberta-v3-small/checkpoint-4915"
    tok = AutoTokenizer.from_pretrained(checkpoint_dir, use_fast=False)
    clf = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir).to(run_device)
    return tok, clf
tokenizer, model = load_model()
def run_model(model, tokenizer, title, summary, threshold=0.95):
    """Classify a paper and return the smallest tag set covering `threshold` probability mass.

    Args:
        model: sequence-classification model whose logit indices follow `tag2ind`.
        tokenizer: callable mapping raw text to model inputs (`return_tensors="pt"`).
        title: paper title (may be empty).
        summary: paper abstract (may be empty).
        threshold: cumulative-probability cutoff; tags are emitted in descending
            probability order until their sum reaches this value. Default 0.95
            preserves the original hard-coded behavior.

    Returns:
        (best_tags, best_probs): tag names and their probabilities as plain
        floats, sorted by descending probability.
    """
    text = combine_title_summary(title, summary)
    tokens_info = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
    )
    model.eval()
    # Run on whichever device the (cached) model already lives on. The previous
    # `model.cpu()` call mutated the shared cached model on every request,
    # undoing load_model()'s `.to(device)` placement and disabling the GPU.
    device = next(model.parameters()).device
    tokens_info = {k: v.to(device) for k, v in tokens_info.items()}
    with torch.no_grad():
        out = model(**tokens_info)
        probs = torch.nn.functional.softmax(out.logits, dim=-1)[0]
    tags = list(tag2ind.keys())
    ids = torch.argsort(probs, descending=True)
    cumulative = 0.0
    best_tags, best_probs = [], []
    for ind in ids.tolist():
        # .item() yields plain floats so the results table shows numbers,
        # not tensor(...) reprs.
        p = probs[ind].item()
        cumulative += p
        best_tags.append(tags[ind])
        best_probs.append(p)
        if cumulative >= threshold:
            break
    return best_tags, best_probs
def main():
    """Streamlit page: collect a title/abstract, classify, and show a probability table."""
    title = st.text_input(label="Title", value="")
    abstract = st.text_area(label="Abstract", value="", height=200)
    if not st.button("Classify"):
        return
    if title == "" and abstract == "":
        st.error("At least one of title or abstract must be provided")
        return
    best_tags, best_probs = run_model(model, tokenizer, title, abstract)
    rows = dict(zip(best_tags, best_probs)).items()
    st.table(pd.DataFrame(rows, columns=["Theme", "Probability"]))


if __name__ == "__main__":
    main()
|