import streamlit as st
import torch
import pandas as pd
from transformers import AutoModelForSequenceClassification, RobertaTokenizer
from typing import Dict, List

USED_MODEL = "distilroberta-base"


def load_category_mappings():
    # Reading the CSV locally is very fast, so we do not cache this,
    # though it would be easy to add caching here as well.
    arxiv_topics_df = pd.read_csv('arxiv_topics.csv')
    category_to_index = {}
    current_index = 0
    for _, row in arxiv_topics_df.iterrows():
        category = row['category']
        if category not in category_to_index:
            category_to_index[category] = current_index
            current_index += 1
    index_to_category = {value: key for key, value in category_to_index.items()}
    return category_to_index, index_to_category


@st.cache_resource  # cache the model so it is loaded only once
def load_model():
    category_to_index, index_to_category = load_category_mappings()
    model = AutoModelForSequenceClassification.from_pretrained(
        f"bumchik2/train-{USED_MODEL}-tags-classification",
        problem_type="multi_label_classification",
        num_labels=len(category_to_index),
        id2label=index_to_category,
        label2id=category_to_index,
    )
    model.eval()
    return model


model = load_model()


@st.cache_resource
def get_tokenizer():
    return RobertaTokenizer.from_pretrained(USED_MODEL)


def tokenize_function(text):
    tokenizer = get_tokenizer()
    return tokenizer(text, padding="max_length", truncation=True)


@torch.no_grad()
def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:
    category_to_index, index_to_category = load_category_mappings()
    # The title and summary are joined with a '$' separator, matching the
    # input format used when the model was trained.
    text = f'{title} $ {summary or ""}'
    inputs = {
        key: torch.tensor(value).to(model.device).unsqueeze(0)
        for key, value in tokenize_function(text).items()
    }
    category_logits = model(**inputs).logits
    sigmoid = torch.nn.Sigmoid()
    category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()
    # Normalize the per-label sigmoid scores so they sum to one.
    category_probs /= category_probs.sum()
    category_probs_dict = {category: 0.0 for category in category_to_index}
    for index in range(len(index_to_category)):
        category_probs_dict[index_to_category[index]] += float(category_probs[index])
    return category_probs_dict


def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float,
                           print_probabilities: bool) -> List[str]:
    # Take categories in descending order of probability until their
    # cumulative probability exceeds target_probability.
    current_p = 0.0
    probs_list = sorted(((value, key) for key, value in probs_dict.items()), reverse=True)
    current_index = 0
    answer = []
    while current_p <= target_probability:
        prob, category = probs_list[current_index]
        current_p += prob
        if print_probabilities:
            answer.append(f'{category} ({prob})')
        else:
            answer.append(category)
        current_index += 1
        if current_index >= len(probs_list):
            break
    return answer


title = st.text_input("Article title", placeholder="Enter title here...")
summary = st.text_input("Article summary", placeholder="Enter summary here...")

need_to_print_probabilities = st.radio("Print probabilities:", ('Yes', 'No'), index=0)
st.session_state['need_to_print_probabilities'] = need_to_print_probabilities

target_probability = st.slider("Select minimum probability sum", 0.0, 1.0, step=0.01, value=0.95)
st.session_state['target_probability'] = target_probability

if title or summary:
    category_probs_dict = get_category_probs_dict(model=model, title=title, summary=summary or '')
    result = get_most_probable_keys(
        probs_dict=category_probs_dict,
        target_probability=target_probability,
        print_probabilities=need_to_print_probabilities == 'Yes',
    )
    # Two trailing spaces before "\n" force a Markdown line break in st.write.
    result_str = "  \n".join(result)
    st.write(result_str)
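# Usage note (assuming this file is saved as app.py and arxiv_topics.csv
# sits in the same directory): launch the app with Streamlit's CLI:
#
#   streamlit run app.py
#
# On first run the fine-tuned weights are downloaded from the Hugging Face
# Hub; st.cache_resource then keeps the loaded model and tokenizer in memory
# across Streamlit reruns, so only the first prediction pays the load cost.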