oiisa commited on
Commit
d7be5ca
·
verified ·
1 Parent(s): 4b3df4c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +64 -54
src/streamlit_app.py CHANGED
@@ -1,67 +1,77 @@
1
  import streamlit as st
2
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
 
3
 
4
- # Заголовок приложения
5
- st.title("🤖 Помощник по вопросам магистратуры")
6
- st.write("Задайте вопрос о поступлении, обучении или программах магистратуры")
7
 
8
- # Инициализация модели (кешируется)
9
- @st.cache_resource
 
 
 
 
10
  def load_model():
11
- model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
12
- tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- model = AutoModelForCausalLM.from_pretrained(model_name)
14
- generator = pipeline(
15
- "text-generation",
16
- model=model,
17
- tokenizer=tokenizer,
18
- device=-1 # CPU (для GPU измените на 0)
19
- )
20
- return generator
21
 
22
  generator = load_model()
23
 
24
- # Форма для ввода вопроса
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  with st.form("question_form"):
26
- user_input = st.text_area("Ваш вопрос:", "Что такое магистратура?")
27
- submit_button = st.form_submit_button("Получить ответ")
28
 
29
  # Обработка вопроса
30
- if submit_button:
31
- if not user_input.strip():
32
- st.warning("Пожалуйста, введите вопрос")
33
- else:
34
- with st.spinner("Генерируем ответ..."):
35
- # Форматируем промпт для лучшего ответа
36
- prompt = f"Вопрос о магистратуре: {user_input}\nОтвет:"
37
-
38
- # Генерация ответа
39
- try:
40
- results = generator(
41
- prompt,
42
- max_length=300,
43
- num_return_sequences=1,
44
- temperature=0.7,
45
- repetition_penalty=1.5,
46
- pad_token_id=50256
47
- )
48
-
49
- # Извлекаем и очищаем ответ
50
- answer = results[0]['generated_text'].split("Ответ:")[-1].strip()
51
- answer = answer.split('\n')[0] # Берем первый абзац
52
-
53
- # Выводим результат
54
- st.subheader("Ответ:")
55
- st.write(answer)
56
-
57
- except Exception as e:
58
- st.error(f"Ошибка генерации: {str(e)}")
59
 
60
  # Информация о модели
61
- st.divider()
62
- st.markdown("""
63
- **О приложении:**
64
- - Использует русскоязычную модель `rugpt3small_based_on_gpt2`
65
- - Отвечает на вопросы о магистратуре
66
- - Работает на CPU (для ускорения используйте GPU)
 
 
 
 
67
  """)
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
+ import time
4
 
5
+ # Настройки для экономии памяти
6
+ MAX_LENGTH = 200
7
+ MODEL_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
8
 
9
+ st.set_page_config(page_title="Магистратура Помощник")
10
+ st.title("🎓 Помощник по вопросам магистратуры")
11
+ st.write("Задайте вопрос о поступлении или обучении в магистратуре")
12
+
13
+ # Кешируем загрузку модели
14
+ @st.cache_resource(show_spinner=False)
15
  def load_model():
16
+ with st.spinner("Загружаем модель... Это займет около 30 секунд"):
17
+ return pipeline(
18
+ "text-generation",
19
+ model=MODEL_NAME,
20
+ device_map="auto"
21
+ )
 
 
 
 
22
 
23
  generator = load_model()
24
 
25
+ # История диалога
26
+ if "history" not in st.session_state:
27
+ st.session_state.history = []
28
+
29
+ # Функция генерации ответа
30
+ def generate_answer(question):
31
+ prompt = f"Вопрос о магистратуре: {question}\nОтвет:"
32
+ try:
33
+ result = generator(
34
+ prompt,
35
+ max_length=MAX_LENGTH,
36
+ num_return_sequences=1,
37
+ temperature=0.7,
38
+ repetition_penalty=1.2,
39
+ pad_token_id=50256
40
+ )
41
+ return result[0]['generated_text'].split("Ответ:")[-1].strip()
42
+ except Exception as e:
43
+ return f"Ошибка: {str(e)}"
44
+
45
+ # Форма ввода
46
  with st.form("question_form"):
47
+ user_input = st.text_input("Ваш вопрос:", placeholder="Что нужно для поступления?")
48
+ submit_button = st.form_submit_button("Спросить")
49
 
50
  # Обработка вопроса
51
+ if submit_button and user_input:
52
+ with st.spinner("Генерируем ответ..."):
53
+ start_time = time.time()
54
+ response = generate_answer(user_input)
55
+ st.session_state.history.append((user_input, response))
56
+
57
+ # Очищаем поле ввода после отправки
58
+ st.rerun()
59
+
60
+ # Показываем историю диалога
61
+ for i, (question, answer) in enumerate(st.session_state.history[::-1]):
62
+ st.divider()
63
+ st.subheader(f"❓ {question}")
64
+ st.write(f"💡 {answer}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Информация о модели
67
+ st.sidebar.markdown("""
68
+ ### О приложении
69
+ Использует ру��скоязычную модель [rugpt3small](https://huggingface.co/sberbank-ai/rugpt3small_based_on_gpt2) для ответов на вопросы о магистратуре.
70
+
71
+ **Примеры вопросов:**
72
+ - Какие документы нужны для поступления?
73
+ - Сколько длится обучение?
74
+ - Чем магистратура отличается от бакалавриата?
75
+ - Какие есть направления?
76
+ - Есть ли бюджетные места?
77
  """)