Spaces:
Running
Running
import streamlit as st | |
from transformers import pipeline | |
import torch | |
from transformers import AutoModelForSequenceClassification | |
import pandas as pd | |
from typing import Dict | |
from transformers import RobertaTokenizer | |
from typing import List | |
# Name of the base checkpoint: it selects both the fine-tuned classifier
# repository and the tokenizer matching it.
USED_MODEL = "distilroberta-base"


# Caching: Streamlit re-executes this whole script on every widget interaction,
# so the loader is wrapped in ``st.cache_resource`` to load the model once per
# server process instead of on every rerun.
@st.cache_resource
def load_model():
    """Load the fine-tuned multi-label arXiv category classifier.

    The label mapping is rebuilt from ``arxiv_topics.csv``. Each distinct value
    of its ``category`` column gets an index in order of first appearance, and
    that mapping is passed to the model as ``id2label`` / ``label2id``.

    Returns:
        The classifier, switched to evaluation mode.
    """
    # The CSV is local and quick to read, so it is not cached separately
    # (adding that would be easy if it ever became necessary).
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    # dict.fromkeys de-duplicates while keeping first-appearance order, so
    # index i is the i-th distinct category in the CSV.
    category_to_index = {
        category: index
        for index, category in enumerate(dict.fromkeys(arxiv_topics_df['category']))
    }
    index_to_category = {index: category for category, index in category_to_index.items()}
    model = AutoModelForSequenceClassification.from_pretrained(
        f"bumchik2/train-{USED_MODEL}-tags-classification",
        problem_type="multi_label_classification",
        num_labels=len(category_to_index),
        id2label=index_to_category,
        label2id=category_to_index,
    )
    model.eval()
    return model


model = load_model()
@st.cache_resource
def get_tokenizer():
    """Return the tokenizer for ``USED_MODEL``.

    Cached with ``st.cache_resource`` so the tokenizer is loaded once per
    server process. Without the cache it is reloaded on every call to
    ``tokenize_function``, and therefore on every Streamlit rerun.
    """
    return RobertaTokenizer.from_pretrained(USED_MODEL)
def tokenize_function(text):
    """Encode *text* with the project tokenizer.

    Inputs are truncated when too long and padded up to the tokenizer's
    configured maximum length, so every encoding has the same length.
    """
    tok = get_tokenizer()
    encoded = tok(text, truncation=True, padding="max_length")
    return encoded
def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
    """Score an article against every arXiv category.

    The title and summary are joined into one text (``"{title} $ {summary}"``),
    tokenized and passed through ``model``. The per-category sigmoid scores are
    then rescaled to sum to 1, so they can be read as a distribution.

    Args:
        model: multi-label sequence classifier whose logits are indexed in the
            same order as the distinct categories of ``arxiv_topics.csv``
            (first-appearance order).
        title: article title.
        summary: article abstract; a falsy value (``None``/``""``) is treated
            as an empty string.

    Returns:
        Mapping from every category in ``arxiv_topics.csv`` to its normalized
        score.
    """
    # The CSV is local and quick to read, so it is deliberately not cached.
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    # Index = order of first appearance in the CSV; this must match the
    # id2label mapping the model was loaded with.
    index_to_category = dict(enumerate(dict.fromkeys(arxiv_topics_df['category'])))

    text = f'{title} $ {summary or ""}'
    inputs = {
        key: torch.tensor(value).to(model.device).unsqueeze(0)
        for key, value in tokenize_function(text).items()
    }
    # Inference only. Without no_grad the logits carry an autograd graph, and
    # the ``.numpy()`` call below raises a RuntimeError.
    with torch.no_grad():
        category_logits = model(**inputs).logits
    sigmoid = torch.nn.Sigmoid()
    # squeeze(0) drops only the batch dimension, so a single-label model
    # still yields a 1-D array that can be indexed.
    category_probs = sigmoid(category_logits.squeeze(0).cpu()).numpy()
    category_probs /= category_probs.sum()
    return {
        category: float(category_probs[index])
        for index, category in index_to_category.items()
    }
def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:
    """Pick the most probable keys until their cumulative probability exceeds a threshold.

    Keys are visited in descending order of probability, with ties broken by
    key in descending order. Each visited key is appended and its probability
    added to a running sum. Selection stops once the sum is strictly greater
    than ``target_probability`` or the keys run out. A non-empty input
    therefore always yields at least one key.

    Args:
        probs_dict: mapping from key to probability.
        target_probability: cumulative-probability threshold.
        print_probabilities: when True, each entry is rendered as
            ``"key (probability)"`` instead of the bare key.

    Returns:
        The selected keys (or formatted strings), most probable first. The
        list is empty when ``probs_dict`` is empty.
    """
    answer = []
    cumulative = 0.0
    # Sorting (value, key) pairs keeps the tie-break on key.
    ranked = sorted(((value, key) for key, value in probs_dict.items()), reverse=True)
    for value, key in ranked:
        if cumulative > target_probability:
            break
        cumulative += value
        answer.append(f'{key} ({value})' if print_probabilities else key)
    return answer
# Prompts are placeholders rather than default values: a default value would be
# classified as if it were a real article and would make the guard below
# always true.
title = st.text_input("Article title", placeholder="Enter title here...")
summary = st.text_input("Article summary", placeholder="Enter summary here...")
need_to_print_probabilities = st.radio("Need to print probabilities: ", ('Yes', 'No'), index=0)
st.session_state['need_to_print_probabilities'] = need_to_print_probabilities
target_probability = st.slider("Select minimum probability sum", 0.0, 1.0, step=0.01, value=0.95)
# Store the chosen slider value itself, not the name of the key.
st.session_state['target_probability'] = target_probability
# Classify only once the user has entered something.
if title or summary:
    category_probs_dict = get_category_probs_dict(model=model, title=title, summary=summary or '')
    result = get_most_probable_keys(
        probs_dict=category_probs_dict,
        target_probability=target_probability,
        print_probabilities=need_to_print_probabilities == 'Yes',
    )
    # st.write renders strings as Markdown, where a hard line break needs two
    # trailing spaces before the newline. This puts one category per line.
    result_str = "  \n".join(result)
    st.write(result_str)