Aleksei Ovchenkov committed
Commit a1b510f · 1 Parent(s): fba60b4
Files changed (4)
  1. .gitattributes +1 -0
  2. app.py +64 -0
  3. epoch_2.pt +3 -0
  4. label_encoder.pkl +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ epoch_2.pt filter=lfs diff=lfs merge=lfs -text
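This is the rule that a command such as `git lfs track "epoch_2.pt"` would normally append: it keeps the ~790 MB checkpoint out of regular Git objects and stores it via Git LFS instead.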
app.py ADDED
@@ -0,0 +1,64 @@
+ import streamlit as st
+ import torch
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
+ import pickle
+
+ # Download and load the tokenizer
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
+
+ # Load the quantized model
+ def load_quantized_model(model_path, num_labels):
+     model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', num_labels=num_labels)
+     model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model_state_dict'])
+     return model
+
+ model_path = "epoch_2.pt"
+ num_labels = 126  # or a different number of labels, depending on your task
+ model = load_quantized_model(model_path, num_labels)
+ model.eval()  # switch the model to evaluation mode
+
+ # Use CUDA if available, otherwise fall back to the CPU
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # Streamlit interface
+ st.title('Arxiv tag classification')
+
+ title = st.text_input('Title:', '')
+ summary = st.text_area('Summary:', '')
+
+ if st.button('Predict'):
+     # Prepare the input data
+     combined_text = f"{title} {summary}"
+     inputs = tokenizer(
+         combined_text,
+         padding=True,
+         truncation=True,
+         max_length=512,
+         return_tensors='pt'
+     ).to(device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+         logits = outputs.logits
+         probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
+
+     # Sort the classes by probability
+     sorted_indices = torch.argsort(probs, descending=True)
+     cumulative_probs = 0.0
+     selected_indices = []
+
+     # Select classes until the cumulative probability reaches 0.95
+     for idx in sorted_indices:
+         cumulative_probs += probs[idx].item()
+         selected_indices.append((idx.item(), probs[idx].item()))
+         if cumulative_probs >= 0.95:
+             break
+
+     with open('label_encoder.pkl', 'rb') as f:
+         label_encoder = pickle.load(f)
+
+     # Display the result
+     st.write("Predicted classes by probability up to 95%:")
+     for class_id, probability in selected_indices:
+         st.write(f'Class: {label_encoder.inverse_transform([class_id])[0]}, Probability: {probability * 100:.2f}%')
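
For context, `load_quantized_model()` above indexes the checkpoint with `['model_state_dict']`, so `epoch_2.pt` is expected to be a dict rather than a bare state dict. A minimal sketch of how such a checkpoint could have been written on the training side (the `epoch` key and the elided fine-tuning loop are assumptions, not part of this commit):

import torch
from transformers import DistilBertForSequenceClassification

# Hypothetical training-side save: app.py only requires the 'model_state_dict' key;
# the 'epoch' entry is an assumption suggested by the file name epoch_2.pt.
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', num_labels=126)
# ... fine-tuning on arXiv titles and abstracts would happen here ...
torch.save({'model_state_dict': model.state_dict(), 'epoch': 2}, 'epoch_2.pt')

The app itself is started with `streamlit run app.py`.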
epoch_2.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4288752cd10f875ea22c7aafe7c55082b42745bf9ea80d29e745d37539593d1c
+ size 790669426
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a070ca9de881130ec4ccda071fdec6842f604051a7c47507857c0d2343c76eca
+ size 9312
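
app.py deserializes `label_encoder.pkl` and calls `inverse_transform`, which matches scikit-learn's `LabelEncoder` API. A minimal sketch of how such a file could be produced (the tag list is an illustrative placeholder; the real encoder would cover all 126 classes):

import pickle
from sklearn.preprocessing import LabelEncoder

# Hypothetical preprocessing step: fit an encoder on the arXiv category strings
# and pickle it so app.py can map predicted class ids back to tag names.
tags = ['cs.CL', 'cs.LG', 'stat.ML']  # placeholder subset of arXiv tags
label_encoder = LabelEncoder()
label_encoder.fit(tags)
with open('label_encoder.pkl', 'wb') as f:
    pickle.dump(label_encoder, f)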