bumchik2's picture
minor fix
e225b80
raw
history blame
3.98 kB
import streamlit as st
from transformers import pipeline
import torch
from transformers import AutoModelForSequenceClassification
import pandas as pd
from typing import Dict
from transformers import DistilBertTokenizer
from typing import List
# Name of the base checkpoint whose tokenizer is loaded by get_tokenizer().
USED_MODEL = "distilbert-base-cased"
@st.cache_resource  # caching
def load_model():
    """Load the fine-tuned tag classifier with label mappings from the CSV.

    The ``tag`` column of ``arxiv_topics.csv`` defines the label set: every
    tag is mapped to the index label of its row, and the inverse mapping is
    passed to the model as ``id2label``.

    Returns:
        The result of ``AutoModelForSequenceClassification.from_pretrained``
        configured for multi-label classification.
    """
    topics = pd.read_csv('arxiv_topics.csv')
    label2id = {row['tag']: row_label for row_label, row in topics.iterrows()}
    id2label = {row_label: tag for tag, row_label in label2id.items()}
    return AutoModelForSequenceClassification.from_pretrained(
        "bumchik2/train_distilbert-base-cased-tags-classification-simple",
        problem_type="multi_label_classification",
        num_labels=len(label2id),
        id2label=id2label,
        label2id=label2id,
    )
# Module-level classifier instance; load_model is wrapped in st.cache_resource.
model = load_model()
@st.cache_resource()
def get_tokenizer():
    """Return the DistilBERT tokenizer for the ``USED_MODEL`` checkpoint."""
    tokenizer = DistilBertTokenizer.from_pretrained(USED_MODEL)
    return tokenizer
def tokenize_function(text):
    """Tokenize ``text`` with max-length padding and truncation enabled."""
    return get_tokenizer()(text, padding="max_length", truncation=True)
@torch.no_grad()
def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
    """Aggregate the model's per-tag scores into per-category probabilities.

    The title and summary are joined as ``'{title} $ {summary}'``, tokenized
    with ``tokenize_function`` and passed through ``model``. The tag logits go
    through a sigmoid, are normalised so they sum to 1, and are then summed
    per category using the ``category`` column of ``arxiv_topics.csv``.

    Args:
        model: Callable that accepts the tokenizer outputs as keyword
            arguments and returns an object with a ``.logits`` attribute;
            it must also expose ``.device``. Assumed to emit one logit per
            CSV row, in row order -- TODO confirm against the training code.
        title: Article title.
        summary: Article summary (abstract).

    Returns:
        Mapping from category name to its share of the normalised tag
        probabilities, keyed in order of first appearance in the CSV.
    """
    # Reading the CSV locally is very fast, so it is not cached
    # (although caching would probably be easy to add).
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    categories = arxiv_topics_df['category'].tolist()

    text = f'{title} $ {summary}'
    # Add a leading batch dimension of size 1 to every tokenizer output.
    model_inputs = {
        key: torch.tensor(value).to(model.device).unsqueeze(0)
        for key, value in tokenize_function(text).items()
    }
    tags_logits = model(**model_inputs).logits

    # squeeze(0) removes only the batch dimension, so the result stays 1-D
    # (and indexable) even if the model has a single label.
    tags_probs = torch.sigmoid(tags_logits.squeeze(0).cpu()).numpy()
    tags_probs /= tags_probs.sum()

    # dict.fromkeys keeps CSV order, giving a deterministic key order.
    category_probs_dict = dict.fromkeys(categories, 0.0)
    for index, category in enumerate(categories):
        category_probs_dict[category] += float(tags_probs[index])
    return category_probs_dict
def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:
    """Return the most probable keys until their cumulative sum passes a target.

    Keys are visited in descending order of probability (ties broken by key,
    descending). A key is taken whenever the sum accumulated *before* it is
    still ``<= target_probability``, so the first key is always returned for
    a non-negative target and a non-empty input.

    Args:
        probs_dict: Mapping from key to probability.
        target_probability: Cumulative probability to reach.
        print_probabilities: When true, each entry is formatted as
            ``'<key> (<probability>)'`` instead of the bare key.

    Returns:
        The selected keys (or formatted entries), most probable first.
        An empty ``probs_dict`` yields an empty list.
    """
    ranked = sorted(((value, key) for key, value in probs_dict.items()), reverse=True)
    answer = []
    current_p = 0
    for probability, key in ranked:
        # Written as `not (a <= b)` so that a NaN running sum stops the loop.
        if not current_p <= target_probability:
            break
        current_p += probability
        answer.append(f'{key} ({probability})' if print_probabilities else key)
    return answer
# --- Streamlit UI: collect inputs, run the classifier, show the categories ---

# The hint text is passed as a placeholder rather than as the initial value,
# so empty fields stay falsy and the hint itself is never classified.
title = st.text_input("Article title", placeholder="Enter title here...")
summary = st.text_input("Article summary", placeholder="Enter summary here...")

need_to_print_probabilities = st.radio("Need to print probabilities: ", ('Yes', 'No'), index=0)
st.session_state['need_to_print_probabilities'] = need_to_print_probabilities

target_probability = st.slider("Select minimum probability sum", 0.0, 1.0, step=0.01, value=0.95)
# Store the slider value itself (previously the literal key name was stored).
st.session_state['target_probability'] = target_probability

# Only run the model once both fields have been filled in.
if title and summary:
    category_probs_dict = get_category_probs_dict(model=model, title=title, summary=summary)
    result = get_most_probable_keys(
        probs_dict=category_probs_dict,
        target_probability=target_probability,
        print_probabilities=need_to_print_probabilities == 'Yes',
    )
    result_str = " \n ".join(result)
    st.write(result_str)