oiisa commited on
Commit
60ea95a
·
verified ·
1 Parent(s): 22c2b62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -62
app.py CHANGED
@@ -1,76 +1,44 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
- import time
4
 
5
- # Настройки для экономии памяти
6
- MAX_LENGTH = 200
7
- MODEL_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
8
 
9
- st.set_page_config(page_title="Магистратура Помощник")
10
- st.title("🎓 Помощник по вопросам магистратуры")
11
- st.write("Задайте вопрос о поступлении или обучении в магистратуре")
12
 
13
- # Кешируем загрузку модели
14
- @st.cache_resource(show_spinner=False)
15
  def load_model():
16
- with st.spinner("Загружаем модель... Это займет около 30 секунд"):
17
- return pipeline(
18
- "text-generation",
19
- model=MODEL_NAME,
20
- device_map="auto"
21
- )
22
 
23
- generator = load_model()
24
 
25
- # История диалога
26
  if "history" not in st.session_state:
27
  st.session_state.history = []
28
 
29
- # Функция генерации ответа
30
- def generate_answer(question):
31
- prompt = f"Вопрос о магистратуре: {question}\nОтвет:"
32
- try:
33
- result = generator(
34
- prompt,
35
- max_length=MAX_LENGTH,
36
- num_return_sequences=1,
37
- temperature=0.7,
38
- repetition_penalty=1.2,
39
- pad_token_id=50256
40
- )
41
- return result[0]['generated_text'].split("Ответ:")[-1].strip()
42
- except Exception as e:
43
- return f"Ошибка: {str(e)}"
44
 
45
- # Форма ввода
46
- with st.form("question_form"):
47
- user_input = st.text_input("Ваш вопрос:", placeholder="Что нужно для поступления?")
48
- submit_button = st.form_submit_button("Спросить")
49
 
50
- # Обработка вопроса
51
- if submit_button and user_input:
52
- with st.spinner("Генерируем ответ..."):
53
- start_time = time.time()
54
- response = generate_answer(user_input)
55
- st.session_state.history.append((user_input, response))
56
-
57
- # Очищаем поле ввода после отправки
58
- st.rerun()
59
 
60
- # Показываем историю диалога
61
- for i, (question, answer) in enumerate(st.session_state.history[::-1]):
62
- st.divider()
63
- st.subheader(f"❓ {question}")
64
- st.write(f"💡 {answer}")
65
 
66
- # Информация о модели
67
- st.sidebar.markdown("""
68
- ### О приложении
69
- Использует русскоязычную модель [rugpt3small](https://huggingface.co/sberbank-ai/rugpt3small_based_on_gpt2) для ответов на вопросы о магистратуре.
70
- **Примеры вопросов:**
71
- - Какие документы нужны для поступления?
72
- - Сколько длится обучение?
73
- - Чем магистратура отличается от бакалавриата?
74
- - Какие есть направления?
75
- - Есть ли бюджетные места?
76
- """)
 
1
  import streamlit as st
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
+ st.set_page_config(page_title="ИТМО Магистратура Чат-бот", page_icon="🎓")
6
+ st.title("🎓 Чат-бот про магистратуру ИТМО на русском")
 
7
 
8
+ MODEL_NAME = "Grossmend/rudialogpt3_medium_based_on_gpt2"
 
 
9
 
10
+ @st.cache_resource
 
11
  def load_model():
12
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
14
+ if torch.cuda.is_available():
15
+ model = model.to('cuda')
16
+ return tokenizer, model
 
17
 
18
+ tokenizer, model = load_model()
19
 
 
20
  if "history" not in st.session_state:
21
  st.session_state.history = []
22
 
23
+ SYSTEM_PROMPT = """Вы являетесь виртуальным помощником для абитуриентов магистратуры Университета ИТМО.
24
+ Отвечаете на вопросы о магистерских программах ИТМО.
25
+ """
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
+ user_input = st.text_input("Введите ваш вопрос про магистратуру ИТМО:")
 
 
 
28
 
29
+ if user_input:
30
+ input_text = SYSTEM_PROMPT + "\n" + user_input
31
+ inputs = tokenizer(input_text, return_tensors="pt")
 
 
 
 
 
 
32
 
33
+ if torch.cuda.is_available():
34
+ inputs = {k: v.to('cuda') for k, v in inputs.items()}
 
 
 
35
 
36
+ outputs = model.generate(**inputs, max_length=500, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id)
37
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
38
+ reply = response[len(input_text):].strip()
39
+
40
+ st.session_state.history.append((user_input, reply))
41
+
42
+ for i, (q, a) in enumerate(st.session_state.history):
43
+ st.markdown(f"**Вы:** {q}")
44
+ st.markdown(f"**Бот:** {a}")