# Aleksei Ovchenkov
# final
# 5e0be4d
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import pickle
import sklearn
from sklearn.preprocessing import LabelEncoder
# Download (on first use) and load the tokenizer.
# Streamlit re-executes this script on every widget interaction, so the
# loader is wrapped in st.cache_resource to avoid re-creating the tokenizer
# on each rerun.
@st.cache_resource
def _load_tokenizer():
    """Return the DistilBERT tokenizer for the 'distilbert-base-cased' checkpoint."""
    return DistilBertTokenizer.from_pretrained('distilbert-base-cased')


tokenizer = _load_tokenizer()
# Load the classification model; st.cache_resource keeps the returned object
# so repeated calls with the same arguments reuse it.
@st.cache_resource
def load_quantized_model(model_path, num_labels):
    """Build a DistilBERT sequence classifier and restore weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load``. The loaded
            object must expose the weights under the ``'model_state_dict'`` key.
        num_labels: Number of output classes for the classification head.

    Returns:
        The ``DistilBertForSequenceClassification`` instance with the
        checkpoint's state dict loaded (tensors are mapped to CPU on load).

    NOTE(review): despite the function name, no quantization step is applied
    here -- confirm whether the checkpoint itself is expected to be quantized.
    """
    classifier = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-cased',
        num_labels=num_labels,
    )
    # Always deserialize onto CPU; the caller decides the final device.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    state_dict = checkpoint['model_state_dict']
    classifier.load_state_dict(state_dict)
    return classifier
# Checkpoint file and the number of labels for the classification head.
model_path = "epoch_2.pt"
num_labels = 126  # or a different label count, depending on your task

# Use CUDA when it is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = load_quantized_model(model_path, num_labels)
model.eval()  # switch the model to evaluation mode
model.to(device)
# Streamlit interface: collect a title and summary, classify the combined
# text, and list the most probable classes until they cover 95% of the
# probability mass.


@st.cache_resource
def _load_label_encoder(path='label_encoder.pkl'):
    """Unpickle the label encoder once and reuse it across reruns.

    NOTE(review): pickle.load executes arbitrary code from the file; only
    deploy a label_encoder.pkl that comes from a trusted training run.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _select_top_classes(probs, threshold=0.95):
    """Pick the most probable classes until their cumulative mass reaches ``threshold``.

    Args:
        probs: 1-D tensor of class probabilities.
        threshold: Cumulative probability at which selection stops.

    Returns:
        List of ``(class_id, probability)`` tuples in descending probability
        order; the entry that crosses the threshold is included.
    """
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    selected = []
    cumulative = 0.0
    for class_id, probability in zip(sorted_ids.tolist(), sorted_probs.tolist()):
        cumulative += probability
        selected.append((class_id, probability))
        if cumulative >= threshold:
            break
    return selected


st.title('Arxiv tag classification')
title = st.text_input('Title:', '')
summary = st.text_area('Summary:', '')

if st.button('Predict'):
    # Prepare the input text.
    combined_text = f"{title} {summary}"
    if not combined_text.strip():
        # Both fields are blank: there is nothing meaningful to classify.
        st.warning('Please enter a title or a summary.')
    else:
        inputs = tokenizer(
            combined_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits
        # squeeze(0) removes only the batch axis, so the class axis survives
        # even if the head had a single label.
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)

        # Sort by probability and keep classes until the cumulative sum >= 0.95.
        selected_indices = _select_top_classes(probs)

        # Decode all selected ids with one inverse_transform call.
        label_encoder = _load_label_encoder()
        class_names = label_encoder.inverse_transform(
            [class_id for class_id, _ in selected_indices]
        )

        # Display the result.
        st.write("Predicted classes by probability up to 95%:")
        for class_name, (_, probability) in zip(class_names, selected_indices):
            st.write(f'Class : {class_name}, Probability: {probability * 100:.2f}%')