chat / app.py
SimrusDenuvo's picture
Update app.py
52af11a verified
raw
history blame
7.76 kB
import gradio as gr
import time
from transformers import pipeline
from datasets import load_dataset
# Группированные обращения по категориям
categories = {
"доступ": [
"Клиент: Я не могу войти в личный кабинет\nОтвет: Пожалуйста, проверьте правильность логина и пароля. Если проблема сохраняется — воспользуйтесь восстановлением доступа или обратитесь в поддержку.",
"Клиент: У меня заблокирован вход в интернет-банк\nОтвет: Обратитесь в поддержку для подтверждения личности и восстановления доступа."
],
"переводы": [
"Клиент: Как перевести деньги на другую карту?\nОтвет: Перевод возможен через мобильное приложение или в отделении. Убедитесь, что карта получателя активна.",
"Клиент: Почему не проходит перевод?\nОтвет: Проверьте лимиты по карте и правильность реквизитов. Если всё верно, свяжитесь с поддержкой."
],
"смс": [
"Клиент: Мне пришло смс, которого я не ожидал\nОтвет: Это может быть техническое уведомление. Уточните дату и текст сообщения, чтобы мы проверили.",
"Клиент: Получаю подозрительные смс от банка\nОтвет: Не переходите по ссылкам. Немедленно смените пароль и сообщите в службу безопасности."
]
}
# Инструкция CoT (расширенная)
cot_instruction = (
"Ты — банковский ассистент. Клиент задал вопрос. Сначала разложи его по смысловым блокам: что именно он хочет, какие данные нужны, возможные причины. "
"После анализа предложи чёткий, полезный и вежливый ответ. Если информации недостаточно — запроси уточнение."
)
# Инструкция обычная
simple_instruction = (
"Ты — банковский помощник. Отвечай официально, коротко и понятно, без рассуждений."
)
# Классификация темы обращения
def classify_topic(user_input):
if any(x in user_input.lower() for x in ["войти", "кабинет", "доступ", "логин", "пароль"]):
return "доступ"
elif any(x in user_input.lower() for x in ["перевод", "перевести", "деньги", "карта"]):
return "переводы"
elif any(x in user_input.lower() for x in ["смс", "сообщение", "уведомление"]):
return "смс"
return "доступ" # fallback
# Сбор примеров по теме
def get_few_shot_examples(user_input):
topic = classify_topic(user_input)
return categories.get(topic, categories["доступ"])
# Модели
models = {
"ruDialoGPT-small": pipeline("text-generation", model="t-bank-ai/ruDialoGPT-small", tokenizer="t-bank-ai/ruDialoGPT-small", device=-1),
"ruDialoGPT-medium": pipeline("text-generation", model="t-bank-ai/ruDialoGPT-medium", tokenizer="t-bank-ai/ruDialoGPT-medium", device=-1),
"ruGPT3-small": pipeline("text-generation", model="ai-forever/rugpt3small_based_on_gpt2", tokenizer="ai-forever/rugpt3small_based_on_gpt2", device=-1),
}
# Формирование промптов
def build_cot_prompt(user_input):
examples = "\n\n".join(get_few_shot_examples(user_input))
return f"{cot_instruction}\n\n{examples}\n\nКлиент: {user_input}\nРассуждение и ответ:"
def build_simple_prompt(user_input):
examples = "\n\n".join(get_few_shot_examples(user_input))
return f"{simple_instruction}\n\n{examples}\n\nКлиент: {user_input}\nОтвет:"
# Фильтрация финального ответа
def extract_final_answer(generated_text):
for line in generated_text.split('\n'):
if "Ответ:" in line:
return line.split("Ответ:")[-1].strip()
return generated_text.strip().split('\n')[-1]
# Генерация
def generate_dual_answers(user_input):
results = {}
prompt_cot = build_cot_prompt(user_input)
prompt_simple = build_simple_prompt(user_input)
for name, pipe in models.items():
# CoT
start_cot = time.time()
out_cot = pipe(prompt_cot, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.9)[0]["generated_text"]
end_cot = round(time.time() - start_cot, 2)
answer_cot = extract_final_answer(out_cot)
# Simple
start_simple = time.time()
out_simple = pipe(prompt_simple, max_new_tokens=100, do_sample=True, top_p=0.95, temperature=0.9)[0]["generated_text"]
end_simple = round(time.time() - start_simple, 2)
answer_simple = extract_final_answer(out_simple)
results[name] = {
"cot_answer": answer_cot,
"cot_time": end_cot,
"simple_answer": answer_simple,
"simple_time": end_simple
}
return (
results["ruDialoGPT-small"]["cot_answer"], f"{results['ruDialoGPT-small']['cot_time']} сек",
results["ruDialoGPT-small"]["simple_answer"], f"{results['ruDialoGPT-small']['simple_time']} сек",
results["ruDialoGPT-medium"]["cot_answer"], f"{results['ruDialoGPT-medium']['cot_time']} сек",
results["ruDialoGPT-medium"]["simple_answer"], f"{results['ruDialoGPT-medium']['simple_time']} сек",
results["ruGPT3-small"]["cot_answer"], f"{results['ruGPT3-small']['cot_time']} сек",
results["ruGPT3-small"]["simple_answer"], f"{results['ruGPT3-small']['simple_time']} сек",
)
# Интерфейс Gradio
with gr.Blocks() as demo:
gr.Markdown("## 🏦 Банковский помощник: CoT vs. Обычный ответ (магистерская работа)")
inp = gr.Textbox(label="Вопрос клиента", placeholder="Например: Почему не проходит перевод?", lines=2)
btn = gr.Button("Сгенерировать")
gr.Markdown("### ruDialoGPT-small")
cot1 = gr.Textbox(label="CoT ответ")
cot1_time = gr.Textbox(label="Время CoT")
simple1 = gr.Textbox(label="Обычный ответ")
simple1_time = gr.Textbox(label="Время обычного")
gr.Markdown("### ruDialoGPT-medium")
cot2 = gr.Textbox(label="CoT ответ")
cot2_time = gr.Textbox(label="Время CoT")
simple2 = gr.Textbox(label="Обычный ответ")
simple2_time = gr.Textbox(label="Время обычного")
gr.Markdown("### ruGPT3-small")
cot3 = gr.Textbox(label="CoT ответ")
cot3_time = gr.Textbox(label="Время CoT")
simple3 = gr.Textbox(label="Обычный ответ")
simple3_time = gr.Textbox(label="Время обычного")
btn.click(generate_dual_answers, inputs=[inp], outputs=[
cot1, cot1_time, simple1, simple1_time,
cot2, cot2_time, simple2, simple2_time,
cot3, cot3_time, simple3, simple3_time
])
demo.launch()