oiisa's picture
Update app.py
d0cfd15 verified
raw
history blame
1.72 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
st.set_page_config(page_title="ИТМО Магистратура Чат-бот", page_icon="🎓")
st.title("🎓 Чат-бот про магистратуру ИТМО")
MODEL_NAME = "sberbank-ai/rugpt3small_based_on_gpt2"
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
if torch.cuda.is_available():
model = model.to('cuda')
return tokenizer, model
tokenizer, model = load_model()
if "history" not in st.session_state:
st.session_state.history = []
SYSTEM_PROMPT = """Вы являетесь виртуальным помощником для абитуриентов магистратуры Университета ИТМО. Отвечаете на вопросы о магистерских программах ИТМО."""
user_input = st.text_input("Введите ваш вопрос про магистратуру ИТМО:")
if user_input:
input_text = SYSTEM_PROMPT + "\n" + user_input
inputs = tokenizer(input_text, return_tensors="pt")
if torch.cuda.is_available():
inputs = {k: v.to('cuda') for k, v in inputs.items()}
outputs = model.generate(**inputs, max_length=500, do_sample=True, temperature=0.7, pad_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
reply = response[len(input_text):].strip()
st.session_state.history.append((user_input, reply))
for i, (q, a) in enumerate(st.session_state.history):
st.markdown(f"**Вы:** {q}")
st.markdown(f"**Бот:** {a}")