oiisa commited on
Commit
4b3df4c
·
verified ·
1 Parent(s): 9a40368

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +59 -44
src/streamlit_app.py CHANGED
@@ -1,52 +1,67 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
3
- import torch
4
 
5
- # Настройки
6
- MODEL_NAME = "AlexKay/xlm-roberta-large-qa-multilingual-finedtuned-ru"
7
- CONTEXT = """
8
- Абитуриенты магистратуры подают документы через личный кабинет. Требуется диплом бакалавра/специалиста.
9
- Средний балл диплома рассчитывается как сумма всех оценок, деленная на количество дисциплин (без учета ВКР).
10
- Доступные программы:
11
- - Информатика и вычислительная техника (код 09.04.01).
12
- - Управление персоналом (код 38.04.03).
13
- - Физика (код 04.04.01).
14
- Параллельное зачисление на две программы возможно только на платной основе.
15
- Прием документов дистанционный через Госуслуги или чат-бот вуза.
16
- """
17
 
18
- # Загрузка модели
19
  @st.cache_resource
20
- def load_qa_model():
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
- model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
23
- return pipeline("question-answering", model=model, tokenizer=tokenizer)
 
 
 
 
 
 
 
24
 
25
- qa_pipeline = load_qa_model()
26
 
27
- # Интерфейс
28
- st.title("🤖 Консультант для абитуриентов магистратуры")
29
- st.info("Отвечаю только на вопросы о магистратуре! Примеры: программы, документы, средний балл.")
 
30
 
31
- question = st.text_input("Задайте вопрос:")
32
- if question:
33
- # Фильтр тематики
34
- forbidden_keywords = ["бакалавр", "егэ", "олимпиад", "школ", "аспирантур", "специалитет"]
35
- if any(word in question.lower() for word in forbidden_keywords):
36
- st.error("Извините, я консультирую только по магистратуре. Задайте вопрос о программах, документах или экзаменах.")
37
  else:
38
- # Поиск ответа в контексте
39
- result = qa_pipeline(question=question, context=CONTEXT)
40
- st.subheader("Ответ:")
41
- st.write(result["answer"])
42
- st.caption(f"Точность: {result['score']:.2f}")
43
-
44
- # Показать программы, если спрашивают о них
45
- if "программ" in question.lower():
46
- st.divider()
47
- st.write("**Все программы магистратуры:**")
48
- st.markdown("""
49
- - 🖥️ **Информатика и вычислительная техника** (09.04.01)
50
- - 👥 **Управление персоналом** (38.04.03)
51
- - 🔬 **Физика** (04.04.01)
52
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
3
 
4
+ # Заголовок приложения
5
+ st.title("🤖 Помощник по вопросам магистратуры")
6
+ st.write("Задайте вопрос о поступлении, обучении или программах магистратуры")
 
 
 
 
 
 
 
 
 
7
 
8
+ # Инициализация модели (кешируется)
9
  @st.cache_resource
10
+ def load_model():
11
+ model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
12
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
13
+ model = AutoModelForCausalLM.from_pretrained(model_name)
14
+ generator = pipeline(
15
+ "text-generation",
16
+ model=model,
17
+ tokenizer=tokenizer,
18
+ device=-1 # CPU (для GPU измените на 0)
19
+ )
20
+ return generator
21
 
22
+ generator = load_model()
23
 
24
+ # Форма для ввода вопроса
25
+ with st.form("question_form"):
26
+ user_input = st.text_area("Ваш вопрос:", "Что такое магистратура?")
27
+ submit_button = st.form_submit_button("Получить ответ")
28
 
29
+ # Обработка вопроса
30
+ if submit_button:
31
+ if not user_input.strip():
32
+ st.warning("Пожалуйста, введите вопрос")
 
 
33
  else:
34
+ with st.spinner("Генерируем ответ..."):
35
+ # Форматируем промпт для лучшего ответа
36
+ prompt = f"Вопрос о магистратуре: {user_input}\nОтвет:"
37
+
38
+ # Генерация ответа
39
+ try:
40
+ results = generator(
41
+ prompt,
42
+ max_length=300,
43
+ num_return_sequences=1,
44
+ temperature=0.7,
45
+ repetition_penalty=1.5,
46
+ pad_token_id=50256
47
+ )
48
+
49
+ # Извлекаем и очищаем ответ
50
+ answer = results[0]['generated_text'].split("Ответ:")[-1].strip()
51
+ answer = answer.split('\n')[0] # Берем первый абзац
52
+
53
+ # Выводим результат
54
+ st.subheader("Ответ:")
55
+ st.write(answer)
56
+
57
+ except Exception as e:
58
+ st.error(f"Ошибка генерации: {str(e)}")
59
+
60
+ # Информация о модели
61
+ st.divider()
62
+ st.markdown("""
63
+ **О приложении:**
64
+ - Использует русскоязычную модель `rugpt3small_based_on_gpt2`
65
+ - Отвечает на вопросы о магистратуре
66
+ - Работает на CPU (для ускорения используйте GPU)
67
+ """)