File size: 4,789 Bytes
5f01a56
b4c0a34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
from transformers import pipeline

# Загружаем модель (замените на вашу модель, если нужно)
# Для примера используем zero-shot-classification
try:
    classifier = pipeline("zero-shot-classification")
except OSError as e:
    st.error(f"Ошибка загрузки модели: {e}. Убедитесь, что модель доступна или укажите другую.")
    st.stop()  # Остановка выполнения приложения при ошибке


# model = 
# tokenizer = 
# topic_classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
topic_classifier = pipeline("text-classification")

text = "This is an example sentence for topic classification."
result = topic_classifier(text)
print(result)

def classify_text(title, description, candidate_labels, show_all=False, threshold=0.95):
    """
    Классифицирует текст и возвращает результаты в отсортированном виде.

    Args:
        title (str): Заголовок текста.
        description (str): Краткое описание текста.
        candidate_labels (list): Список меток-кандидатов.
        show_all (bool): Показывать ли все результаты, независимо от порога.
        threshold (float): Порог суммарной вероятности.

    Returns:
        list: Отсортированный список результатов классификации.
    """
    text = f"{title} {description}"  # Объединяем заголовок и описание
    try:
        results = topic_classifier(text) 
        # results = topic_classifier(text, candidate_labels, multi_label=True)  # multi_label=True для нескольких меток
    except Exception as e:
        st.error(f"Ошибка классификации: {e}")
        return []

    # Сортируем результаты по убыванию вероятности
    sorted_results = sorted(zip(results['labels'], results['scores']), key=lambda x: x[1], reverse=True)

    if show_all:
        return sorted_results
    else:
        cumulative_prob = 0
        filtered_results = []
        for label, score in sorted_results:
            filtered_results.append((label, score))
            cumulative_prob += score
            if cumulative_prob >= threshold:
                break
        return filtered_results


# --- Интерфейс Streamlit ---
st.title("Классификация статей")

# Ввод данных
title = st.text_input("Заголовок статьи")
description = st.text_area("Краткое описание статьи", height=150)

# Ввод меток-кандидатов (разделенных запятыми)
default_labels = "политика, экономика, спорт, культура, технологии, наука, происшествия"
candidate_labels_str = st.text_input("Метки-кандидаты (через запятую)", default_labels)
candidate_labels = [label.strip() for label in candidate_labels_str.split(",") if label.strip()]

# Кнопка "Классифицировать"
if st.button("Классифицировать"):
    if not title or not description or not candidate_labels:
        st.warning("Пожалуйста, заполните все поля.")
    else:
        with st.spinner("Идет классификация..."):  # Индикатор загрузки
            results = classify_text(title, description, candidate_labels)
            if results:
              st.subheader("Результаты классификации (с ограничением по вероятности):")
              for label, score in results:
                  st.write(f"- **{label}**: {score:.4f}")

              # Кнопка "Показать все"
              if st.button("Показать все категории"):
                  all_results = classify_text(title, description, candidate_labels, show_all=True)
                  st.subheader("Полные результаты классификации:")
                  for label, score in all_results:
                      st.write(f"- **{label}**: {score:.4f}")
            else:
                st.info("Не удалось получить результаты классификации.")

elif title or description or candidate_labels_str != default_labels: #небольшой костыль, чтобы при старте не было предупреждения
    st.warning("Пожалуйста, заполните все поля.")