1NEYRON1 commited on
Commit
0855d6f
·
1 Parent(s): 92010d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -50
app.py CHANGED
@@ -1,6 +1,12 @@
1
  import streamlit as st
2
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
 
 
 
 
 
 
 
4
  id_to_cat = {0: 'Performance',
5
  1: 'Molecular Networks',
6
  2: 'Operating Systems',
@@ -140,7 +146,9 @@ id_to_cat = {0: 'Performance',
140
  136: 'Nuclear Experiment',
141
  137: 'Artificial Intelligence'}
142
 
143
- # Загружаем модель (замените на вашу модель, если нужно)
 
 
144
  model_name = 'checkpoint'
145
  try:
146
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
@@ -151,42 +159,25 @@ try:
151
  )
152
  except OSError as e:
153
  st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.")
154
- st.stop() # Остановка выполнения приложения при ошибке
155
-
156
 
157
  def classify_text(title, description):
158
- """
159
- Классифицирует текст и возвращает результаты в отсортированном виде.
160
-
161
- Args:
162
- title (str): Заголовок текста.
163
- description (str): Краткое описание текста.
164
- show_all (bool): Показывать ли все результаты, независимо от порога.
165
- threshold (float): Порог суммарной вероятности.
166
-
167
- Returns:
168
- list: Отсортированный список результатов классификации.
169
- """
170
- text = f"{title} {description}" # Объединяем заголовок и описание
171
- topic_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k = len(id_to_cat))
172
  try:
173
- results = topic_classifier(text)
174
- # results = topic_classifier(text, candidate_labels, multi_label=True) # multi_label=True для нескольких меток
 
 
 
 
175
  except Exception as e:
176
  st.error(f"Ошибка классификации: {e}")
177
  return []
178
 
179
- for i in results[0]:
180
- i['label'] = id_to_cat[int(i['label'].split('_')[1])]
181
-
182
-
183
- filtered_results = []
184
- for i in results[0]:
185
- filtered_results.append((i['label'], i['score']))
186
- return filtered_results
187
-
188
-
189
-
190
  # --- Интерфейс Streamlit ---
191
  st.title("Классификация статей 1")
192
 
@@ -194,29 +185,43 @@ st.title("Классификация статей 1")
194
  title = st.text_input("Заголовок статьи")
195
  description = st.text_area("Краткое описание статьи", height=150)
196
 
197
- # Кнопка "Классифицировать"
198
  if st.button("Классифицировать"):
199
  if not title and not description:
200
  st.warning("Пожалуйста, заполните хотя бы одно поле.")
201
  else:
202
- with st.spinner("Идет классификация..."): # Индикатор загрузки
203
- results = classify_text(title, description)
204
- if results:
205
- st.subheader("Результаты классификации (top 95%):")
206
- cumulative_prob = 0
207
- for label, score in results:
208
- st.write(f"- **{label}**: {score:.4f}")
209
- cumulative_prob += score
210
- if cumulative_prob >= 0.95:
211
- break
212
 
213
- # Кнопка "Показать все"
214
- if st.button("Показать все категории"):
215
- st.subheader("Полные результаты классификации:")
216
- for label, score in results:
217
- st.write(f"- **{label}**: {score:.4f}")
218
- else:
219
- st.info("Не удалось получить результаты классификации.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
- elif title or description: #небольшой костыль, чтобы при старте не было предупреждения
222
- st.warning("Пожалуйста, заполните хотя бы одно поле.")
 
 
1
  import streamlit as st
2
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
 
4
+ # Инициализация состояния сессии
5
+ if 'show_all' not in st.session_state:
6
+ st.session_state.show_all = False
7
+ if 'results' not in st.session_state:
8
+ st.session_state.results = []
9
+
10
  id_to_cat = {0: 'Performance',
11
  1: 'Molecular Networks',
12
  2: 'Operating Systems',
 
146
  136: 'Nuclear Experiment',
147
  137: 'Artificial Intelligence'}
148
 
149
+ id_to_cat = {0: 'Performance', ...} # Ваш полный словарь категорий
150
+
151
+ # Загружаем модель
152
  model_name = 'checkpoint'
153
  try:
154
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
 
159
  )
160
  except OSError as e:
161
  st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.")
162
+ st.stop()
 
163
 
164
  def classify_text(title, description):
165
+ text = f"{title} {description}"
166
+ topic_classifier = pipeline("text-classification",
167
+ model=model,
168
+ tokenizer=tokenizer,
169
+ top_k=len(id_to_cat))
 
 
 
 
 
 
 
 
 
170
  try:
171
+ results = topic_classifier(text)
172
+ processed = []
173
+ for item in results[0]:
174
+ label_id = int(item['label'].split('_')[1])
175
+ processed.append((id_to_cat[label_id], item['score']))
176
+ return sorted(processed, key=lambda x: x[1], reverse=True)
177
  except Exception as e:
178
  st.error(f"Ошибка классификации: {e}")
179
  return []
180
 
 
 
 
 
 
 
 
 
 
 
 
181
  # --- Интерфейс Streamlit ---
182
  st.title("Классификация статей 1")
183
 
 
185
  title = st.text_input("Заголовок статьи")
186
  description = st.text_area("Краткое описание статьи", height=150)
187
 
188
+ # Кнопка классификации
189
  if st.button("Классифицировать"):
190
  if not title and not description:
191
  st.warning("Пожалуйста, заполните хотя бы одно поле.")
192
  else:
193
+ with st.spinner("Идет классификация..."):
194
+ st.session_state.results = classify_text(title, description)
195
+ st.session_state.show_all = False
 
 
 
 
 
 
 
196
 
197
+ # Отображение результатов
198
+ if st.session_state.results:
199
+ st.subheader("Результаты классификации:")
200
+
201
+ # Определение порога отображения
202
+ cumulative = 0
203
+ shown_results = []
204
+ for label, score in st.session_state.results:
205
+ if not st.session_state.show_all and cumulative < 0.95:
206
+ shown_results.append((label, score))
207
+ cumulative += score
208
+ else:
209
+ shown_results.append((label, score))
210
+
211
+ # Отображение результатов
212
+ for label, score in shown_results:
213
+ st.write(f"- **{label}**: {score:.4f}")
214
+
215
+ # Кнопка переключения режима отображения
216
+ if st.session_state.show_all:
217
+ if st.button("Скрыть подробности"):
218
+ st.session_state.show_all = False
219
+ st.experimental_rerun()
220
+ else:
221
+ if st.button("Показать все категории"):
222
+ st.session_state.show_all = True
223
+ st.experimental_rerun()
224
 
225
+ # Отображение предупреждения только после первой попытки
226
+ elif any([title, description]) and not st.session_state.results:
227
+ st.warning("Пожалуйста, нажмите кнопку 'Классифицировать'")