chat / app.py
SimrusDenuvo's picture
Update app.py
88dd26b verified
raw
history blame
5.85 kB
import gradio as gr
import time
from transformers import pipeline
from datasets import load_dataset
# Загружаем датасет
DATASET_NAME = "Romjiik/Russian_bank_reviews"
dataset = load_dataset(DATASET_NAME, split="train")
# Краткий список примеров для подстановки в промпт (для классификации)
few_shot_examples = [
"Клиент: Не могу войти в приложение.\nКлассификация: Техническая проблема",
"Клиент: Почему с меня сняли деньги дважды?\nКлассификация: Ошибка транзакции",
"Клиент: Хочу оформить кредит.\nКлассификация: Запрос на продукт",
"Клиент: У меня украли карту.\nКлассификация: Безопасность",
"Клиент: Не приходит СМС для входа.\nКлассификация: Проблема авторизации"
]
# Инструкции
cot_instruction = (
"Ты — банковский помощник. Клиент описывает ситуацию. "
"Проанализируй обращение шаг за шагом и определи категорию (например: 'Техническая проблема', 'Запрос на продукт', 'Безопасность' и т.п.)"
)
simple_instruction = (
"Ты — банковский помощник. Клиент описывает обращение. "
"Кратко укажи категорию обращения (например: 'Техническая проблема', 'Запрос на продукт', 'Безопасность' и т.п.)."
)
# Используемые модели (CPU-compatible, ≤16GB)
models = {
"ChatGPT-like (FRED-T5-small)": pipeline("text2text-generation", model="cointegrated/rugpt3small_based_on_gpt2", tokenizer="ai-forever/FRED-T5-Base", device=-1),
"DeepSeek-like (ruGPT3-small)": pipeline("text-generation", model="ai-forever/rugpt3small_based_on_gpt2", tokenizer="ai-forever/rugpt3small_based_on_gpt2", device=-1),
"GigaChat-like (RuBERT-tiny2-clf)": pipeline("text-classification", model="cointegrated/rubert-tiny2", tokenizer="cointegrated/rubert-tiny2", device=-1)
}
# Построение промптов
def build_cot_prompt(user_input):
examples = "\n\n".join(few_shot_examples)
return (
f"{cot_instruction}\n\n{examples}\n\nКлиент: {user_input}\nРассуждение и классификация:"
)
def build_simple_prompt(user_input):
examples = "\n\n".join(few_shot_examples)
return (
f"{simple_instruction}\n\n{examples}\n\nКлиент: {user_input}\nКлассификация:"
)
# Генерация классификаций
def generate_dual_answers(user_input):
results = {}
prompt_cot = build_cot_prompt(user_input)
prompt_simple = build_simple_prompt(user_input)
for name, pipe in models.items():
if "text-generation" in str(pipe.task):
# CoT
start_cot = time.time()
out_cot = pipe(prompt_cot, max_length=256, do_sample=True, top_p=0.9, temperature=0.7)[0]["generated_text"]
end_cot = round(time.time() - start_cot, 2)
answer_cot = out_cot.strip().split("\n")[-1]
# Simple
start_simple = time.time()
out_simple = pipe(prompt_simple, max_length=128, do_sample=True, top_p=0.9, temperature=0.7)[0]["generated_text"]
end_simple = round(time.time() - start_simple, 2)
answer_simple = out_simple.strip().split("\n")[-1]
elif "text2text-generation" in str(pipe.task):
start_cot = time.time()
out_cot = pipe(prompt_cot, max_new_tokens=50)[0]["generated_text"]
end_cot = round(time.time() - start_cot, 2)
start_simple = time.time()
out_simple = pipe(prompt_simple, max_new_tokens=30)[0]["generated_text"]
end_simple = round(time.time() - start_simple, 2)
answer_cot = out_cot.strip()
answer_simple = out_simple.strip()
elif "text-classification" in str(pipe.task):
# Для классификации используем только сам ввод без промпта
start = time.time()
answer = pipe(user_input)[0]['label']
end = round(time.time() - start, 2)
answer_cot = answer
answer_simple = answer
end_cot = end_simple = end
results[name] = {
"cot_answer": answer_cot,
"cot_time": end_cot,
"simple_answer": answer_simple,
"simple_time": end_simple
}
return tuple(
results[model][key] for model in models for key in ["cot_answer", "cot_time", "simple_answer", "simple_time"]
)
# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("## 🧠 Классификация клиентских обращений в банке (CoT vs обычный промпт)")
inp = gr.Textbox(label="Вопрос клиента", placeholder="Например: У меня не проходит оплата картой", lines=2)
btn = gr.Button("Сгенерировать")
results_blocks = []
for name in models:
gr.Markdown(f"### {name}")
cot = gr.Textbox(label="CoT ответ")
cot_time = gr.Textbox(label="Время CoT")
simple = gr.Textbox(label="Обычный ответ")
simple_time = gr.Textbox(label="Время обычного")
results_blocks.extend([cot, cot_time, simple, simple_time])
btn.click(generate_dual_answers, inputs=[inp], outputs=results_blocks)
demo.launch()