1NEYRON1 commited on
Commit
e786d48
·
1 Parent(s): 0855d6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -45
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
 
4
- # Инициализация состояния сессии
5
  if 'show_all' not in st.session_state:
6
  st.session_state.show_all = False
7
  if 'results' not in st.session_state:
@@ -146,9 +146,7 @@ id_to_cat = {0: 'Performance',
146
  136: 'Nuclear Experiment',
147
  137: 'Artificial Intelligence'}
148
 
149
- id_to_cat = {0: 'Performance', ...} # Ваш полный словарь категорий
150
-
151
- # Загружаем модель
152
  model_name = 'checkpoint'
153
  try:
154
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
@@ -159,25 +157,35 @@ try:
159
  )
160
  except OSError as e:
161
  st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.")
162
- st.stop()
 
163
 
164
  def classify_text(title, description):
165
- text = f"{title} {description}"
166
- topic_classifier = pipeline("text-classification",
167
- model=model,
168
- tokenizer=tokenizer,
169
- top_k=len(id_to_cat))
 
 
 
 
 
170
  try:
171
- results = topic_classifier(text)
172
- processed = []
173
- for item in results[0]:
174
- label_id = int(item['label'].split('_')[1])
175
- processed.append((id_to_cat[label_id], item['score']))
176
- return sorted(processed, key=lambda x: x[1], reverse=True)
177
  except Exception as e:
178
  st.error(f"Ошибка классификации: {e}")
179
  return []
180
 
 
 
 
 
 
 
 
 
 
181
  # --- Интерфейс Streamlit ---
182
  st.title("Классификация статей 1")
183
 
@@ -185,43 +193,34 @@ st.title("Классификация статей 1")
185
  title = st.text_input("Заголовок статьи")
186
  description = st.text_area("Краткое описание статьи", height=150)
187
 
188
- # Кнопка классификации
189
  if st.button("Классифицировать"):
190
  if not title and not description:
191
  st.warning("Пожалуйста, заполните хотя бы одно поле.")
192
  else:
193
- with st.spinner("Идет классификация..."):
194
  st.session_state.results = classify_text(title, description)
195
- st.session_state.show_all = False
196
 
197
- # Отображение результатов
198
  if st.session_state.results:
199
- st.subheader("Результаты классификации:")
200
-
201
- # Определение порога отображения
202
- cumulative = 0
203
- shown_results = []
204
- for label, score in st.session_state.results:
205
- if not st.session_state.show_all and cumulative < 0.95:
206
- shown_results.append((label, score))
207
- cumulative += score
208
- else:
209
- shown_results.append((label, score))
210
-
211
- # Отображение результатов
212
- for label, score in shown_results:
213
- st.write(f"- **{label}**: {score:.4f}")
214
-
215
- # Кнопка переключения режима отображения
216
  if st.session_state.show_all:
217
- if st.button("Скрыть подробности"):
218
- st.session_state.show_all = False
219
- st.experimental_rerun()
220
  else:
221
- if st.button("Показать все категории"):
 
 
 
 
 
 
 
 
 
222
  st.session_state.show_all = True
223
- st.experimental_rerun()
224
 
225
- # Отображение предупреждения только после первой попытки
226
- elif any([title, description]) and not st.session_state.results:
227
- st.warning("Пожалуйста, нажмите кнопку 'Классифицировать'")
 
1
  import streamlit as st
2
  from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
3
 
4
+ # Initialize session state
5
  if 'show_all' not in st.session_state:
6
  st.session_state.show_all = False
7
  if 'results' not in st.session_state:
 
146
  136: 'Nuclear Experiment',
147
  137: 'Artificial Intelligence'}
148
 
149
+ # Загружаем модель (замените на вашу модель, если нужно)
 
 
150
  model_name = 'checkpoint'
151
  try:
152
  tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased')
 
157
  )
158
  except OSError as e:
159
  st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.")
160
+ st.stop() # Остановка выполнения приложения при ошибке
161
+
162
 
163
  def classify_text(title, description):
164
+ """
165
+ Классифицирует текст и возвращает результаты в отсортированном виде.
166
+ Args:
167
+ title (str): Заголовок текста.
168
+ description (str): Краткое описание текста.
169
+ Returns:
170
+ list: Отсортированный список результатов классификации.
171
+ """
172
+ text = f"{title} {description}" # Объединяем заголовок и описание
173
+ topic_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, top_k = len(id_to_cat))
174
  try:
175
+ results = topic_classifier(text)
 
 
 
 
 
176
  except Exception as e:
177
  st.error(f"Ошибка классификации: {e}")
178
  return []
179
 
180
+ for i in results[0]:
181
+ i['label'] = id_to_cat[int(i['label'].split('_')[1])]
182
+
183
+ filtered_results = []
184
+ for i in results[0]:
185
+ filtered_results.append((i['label'], i['score']))
186
+ return filtered_results
187
+
188
+
189
  # --- Интерфейс Streamlit ---
190
  st.title("Классификация статей 1")
191
 
 
193
  title = st.text_input("Заголовок статьи")
194
  description = st.text_area("Краткое описание статьи", height=150)
195
 
196
+ # Кнопка "Классифицировать"
197
  if st.button("Классифицировать"):
198
  if not title and not description:
199
  st.warning("Пожалуйста, заполните хотя бы одно поле.")
200
  else:
201
+ with st.spinner("Идет классификация..."): # Индикатор загрузки
202
  st.session_state.results = classify_text(title, description)
203
+ st.session_state.show_all = False # Reset to show only top 95%
204
 
205
+ # Display results if available
206
  if st.session_state.results:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  if st.session_state.show_all:
208
+ st.subheader("Полные результаты классификации:")
209
+ for label, score in st.session_state.results:
210
+ st.write(f"- **{label}**: {score:.4f}")
211
  else:
212
+ st.subheader("Результаты классификации (top 95%):")
213
+ cumulative_prob = 0
214
+ for label, score in st.session_state.results:
215
+ st.write(f"- **{label}**: {score:.4f}")
216
+ cumulative_prob += score
217
+ if cumulative_prob >= 0.95:
218
+ break
219
+
220
+ # Renamed button that refreshes the page
221
+ if st.button("Покажи все"):
222
  st.session_state.show_all = True
223
+ st.experimental_rerun() # Refresh the page
224
 
225
+ elif title or description: # небольшой костыль, чтобы при старте не было предупреждения
226
+ st.warning("Пожалуйста, нажмите кнопку 'Классифицировать'.")