File size: 1,776 Bytes
22c2b62
60ea95a
 
22c2b62
7c9d529
 
 
9dbe8ba
 
 
22c2b62
 
60ea95a
22c2b62
60ea95a
 
 
 
 
22c2b62
60ea95a
22c2b62
 
 
 
fdfb963
22c2b62
60ea95a
22c2b62
60ea95a
 
 
22c2b62
60ea95a
 
22c2b62
60ea95a
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

st.set_page_config(page_title="ИТМО Магистратура Чат-бот", page_icon="🎓")
st.title("🎓 Чат-бот про магистратуру ИТМО")

# MODEL_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
MODEL_NAME = "IlyaGusev/saiga_llama3_8b, saiga_7b_lora"



@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    if torch.cuda.is_available():
        model = model.to('cuda')
    return tokenizer, model

tokenizer, model = load_model()

if "history" not in st.session_state:
    st.session_state.history = []

SYSTEM_PROMPT = """Вы являетесь виртуальным помощником для абитуриентов магистратуры Университета ИТМО. Отвечаете на вопросы о магистерских программах ИТМО."""

user_input = st.text_input("Введите ваш вопрос про магистратуру ИТМО:")

if user_input:
    input_text = SYSTEM_PROMPT + "\n" + user_input
    inputs = tokenizer(input_text, return_tensors="pt")

    if torch.cuda.is_available():
        inputs = {k: v.to('cuda') for k, v in inputs.items()}

    outputs = model.generate(**inputs, max_length=500, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    reply = response[len(input_text):].strip()

    st.session_state.history.append((user_input, reply))

for i, (q, a) in enumerate(st.session_state.history):
    st.markdown(f"**Вы:** {q}")
    st.markdown(f"**Бот:** {a}")