# Spaces: Sleeping
import streamlit as st | |
import torch | |
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification | |
import pickle | |
import sklearn | |
from sklearn.preprocessing import LabelEncoder | |
# Download (on first use) and load the tokenizer.
# Streamlit re-executes this script on every widget interaction, so the load
# is wrapped in a cached helper and happens once per server process. On
# Streamlit versions without ``st.cache_resource`` the decorator falls back to
# a no-op and the tokenizer is loaded on each run, as before.
_cache_tokenizer = getattr(st, "cache_resource", lambda fn: fn)


@_cache_tokenizer
def _load_tokenizer():
    """Return the tokenizer for the 'distilbert-base-cased' checkpoint."""
    return DistilBertTokenizer.from_pretrained('distilbert-base-cased')


tokenizer = _load_tokenizer()
# Load the model weights from a checkpoint file
def load_quantized_model(model_path, num_labels):
    """Build a DistilBERT sequence classifier and load checkpoint weights into it.

    Args:
        model_path: Path handed to ``torch.load``. The loaded object must
            provide a ``'model_state_dict'`` entry.
        num_labels: Number of output labels for the classification head.

    Returns:
        The ``DistilBertForSequenceClassification`` instance with the
        checkpoint's state dict loaded. Tensors are mapped to the CPU while
        the checkpoint is read.

    NOTE(review): despite the name, no quantization step is visible in this
    function; presumably the checkpoint holds regular weights - confirm.
    """
    classifier = DistilBertForSequenceClassification.from_pretrained(
        'distilbert-base-cased',
        num_labels=num_labels,
    )
    cpu = torch.device('cpu')
    checkpoint = torch.load(model_path, map_location=cpu)
    state_dict = checkpoint['model_state_dict']
    classifier.load_state_dict(state_dict)
    return classifier
model_path = "epoch_2.pt"
num_labels = 126  # or a different number of labels, depending on your task

# Streamlit re-executes this script on every widget interaction. Without
# caching, the checkpoint would be re-read and the model rebuilt on each
# rerun. On Streamlit versions without ``st.cache_resource`` the decorator
# falls back to a no-op, which keeps the previous load-per-run behavior.
_cache_model = getattr(st, "cache_resource", lambda fn: fn)


@_cache_model
def _load_eval_model(path, label_count):
    """Load the classifier from ``path`` and switch it to evaluation mode."""
    net = load_quantized_model(path, label_count)
    net.eval()  # evaluation mode
    return net


model = _load_eval_model(model_path, num_labels)

# Use CUDA if it is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def _select_top_classes(probs, threshold=0.95):
    """Pick the most probable classes until their probabilities reach ``threshold``.

    Args:
        probs: 1-D tensor of class probabilities.
        threshold: Cumulative probability at which selection stops.

    Returns:
        List of ``(class_id, probability)`` tuples, most probable first. The
        class whose probability brings the running sum to ``threshold`` is
        included.
    """
    values = probs.tolist()
    order = torch.argsort(probs, descending=True).tolist()
    selected = []
    cumulative = 0.0
    for class_id in order:
        probability = values[class_id]
        cumulative += probability
        selected.append((class_id, probability))
        if cumulative >= threshold:
            break
    return selected


# Streamlit interface
st.title('Arxiv tag classification')
title = st.text_input('Title:', '')
summary = st.text_area('Summary:', '')

if st.button('Predict'):
    # Prepare the input: title and summary are classified as one text
    combined_text = f"{title} {summary}"
    if not combined_text.strip():
        # Both fields are blank; skip inference instead of ranking classes
        # for an empty text.
        st.warning('Please enter a title or a summary before predicting.')
    else:
        inputs = tokenizer(
            combined_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # A single text was tokenized, so dim 0 is a batch of one; drop only
        # that dimension.
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)

        # Take classes in descending probability until their sum is >= 0.95
        selected_indices = _select_top_classes(probs, threshold=0.95)

        # SECURITY: pickle.load can execute arbitrary code. This file must be
        # a trusted artifact shipped with the app, never user-supplied.
        with open('label_encoder.pkl', 'rb') as f:
            label_encoder = pickle.load(f)

        # Decode all selected class ids in one call
        class_names = label_encoder.inverse_transform(
            [class_id for class_id, _ in selected_indices]
        )

        # Show the result
        st.write("Predicted classes by probability up to 95%:")
        for class_name, (_, probability) in zip(class_names, selected_indices):
            st.write(f'Class : {class_name}, Probability: {probability * 100:.2f}%')