"""Streamlit app that predicts arXiv tags from a paper title and summary.

A DistilBERT sequence classifier is restored from a local checkpoint and the
most probable classes are listed until their cumulative probability reaches
95%.
"""
import pickle

# sklearn / LabelEncoder are not referenced by name below; the imports are
# kept from the original (label_encoder.pkl presumably holds a fitted
# scikit-learn LabelEncoder -- confirm against the training code).
import sklearn
import streamlit as st
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

BASE_CHECKPOINT = 'distilbert-base-cased'
LABEL_ENCODER_PATH = 'label_encoder.pkl'
# Classes are listed until their summed probability reaches this value.
CUMULATIVE_PROB_THRESHOLD = 0.95


@st.cache_resource
def _load_tokenizer():
    """Return the DistilBERT tokenizer, built once per server process.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the tokenizer would be re-created on each rerun.
    """
    return DistilBertTokenizer.from_pretrained(BASE_CHECKPOINT)


@st.cache_resource
def load_quantized_model(model_path, num_labels):
    """Build a DistilBERT classifier and restore fine-tuned weights into it.

    Despite the name, no quantization is performed here: a regular
    ``distilbert-base-cased`` classifier is instantiated and the state dict
    stored under the ``'model_state_dict'`` key of the checkpoint is loaded.

    Args:
        model_path: Path to a ``torch.save``-d checkpoint containing a
            ``'model_state_dict'`` entry.
        num_labels: Number of output classes of the classification head.

    Returns:
        The model with the checkpoint weights loaded (tensors mapped to CPU).
    """
    model = DistilBertForSequenceClassification.from_pretrained(
        BASE_CHECKPOINT, num_labels=num_labels
    )
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    return model


@st.cache_resource
def _load_label_encoder(path):
    """Unpickle the label encoder once instead of on every button click.

    NOTE(review): pickle.load executes arbitrary code from the file, so
    ``path`` must point to a trusted artifact.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _select_top_classes(probs, threshold=CUMULATIVE_PROB_THRESHOLD):
    """Pick the most probable classes until their mass reaches ``threshold``.

    Args:
        probs: 1-D tensor of class probabilities.
        threshold: Cumulative probability at which selection stops.

    Returns:
        A list of ``(class_id, probability)`` tuples in descending order of
        probability; the entry that crosses the threshold is included.
    """
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    selected = []
    cumulative = 0.0
    # .tolist() converts both tensors in one transfer instead of calling
    # .item() several times per class.
    for class_id, probability in zip(sorted_ids.tolist(), sorted_probs.tolist()):
        cumulative += probability
        selected.append((class_id, probability))
        if cumulative >= threshold:
            break
    return selected


# Download (on first use) and load the tokenizer.
tokenizer = _load_tokenizer()

model_path = "epoch_2.pt"
num_labels = 126  # or another number of labels, depending on the task
model = load_quantized_model(model_path, num_labels)
model.eval()  # switch the model to evaluation mode

# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Streamlit interface
st.title('Arxiv tag classification')

title = st.text_input('Title:', '')
summary = st.text_area('Summary:', '')

if st.button('Predict'):
    # Prepare the input text.
    combined_text = f"{title} {summary}"

    if not combined_text.strip():
        # Nothing to classify; avoid running the model on blank input.
        st.warning('Please enter a title or a summary first.')
    else:
        inputs = tokenizer(
            combined_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            # Exactly one text is passed in, so drop only the batch axis.
            probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)

        # Keep the top classes until their cumulative probability is >= 0.95.
        selected_indices = _select_top_classes(probs)

        label_encoder = _load_label_encoder(LABEL_ENCODER_PATH)
        # Decode all selected class ids in a single call.
        class_names = label_encoder.inverse_transform(
            [class_id for class_id, _ in selected_indices]
        )

        # Display the result.
        st.write("Predicted classes by probability up to 95%:")
        for class_name, (_, probability) in zip(class_names, selected_indices):
            st.write(f'Class : {class_name}, Probability: {probability * 100:.2f}%')