chat / app.py
SimrusDenuvo's picture
Update app.py
7e06069 verified
raw
history blame
4.73 kB
import gradio as gr
import time
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset
# Инициализация трёх бесплатных русскоязычных моделей
models = {}
# 1) ruDialoGPT-small
models['ruDialoGPT-small'] = pipeline(
'text-generation',
model='t-bank-ai/ruDialoGPT-small',
tokenizer='t-bank-ai/ruDialoGPT-small',
device=-1
)
# 2) ruGPT3-small
models['ruGPT3-small'] = pipeline(
'text-generation',
model='ai-forever/rugpt3small_based_on_gpt2',
tokenizer='ai-forever/rugpt3small_based_on_gpt2',
device=-1
)
# 3) rut5-small-chitchat (T5 requires text2text и slow tokenizer)
t5_tokenizer = AutoTokenizer.from_pretrained(
'cointegrated/rut5-small-chitchat',
use_fast=False
)
models['rut5-small-chitchat'] = pipeline(
'text2text-generation',
model='cointegrated/rut5-small-chitchat',
tokenizer=t5_tokenizer,
device=-1
)
# Загрузка "мини" банковского датасета для few-shot (стриминг)
bank_data_stream = load_dataset(
'ai-lab/MBD-mini',
split='train',
streaming=True
)
# Определяем колонку с диалогами по ключам
first_record = next(iter(bank_data_stream))
col = next((c for c in first_record.keys() if 'dialog' in c.lower() or 'диалог' in c.lower()), None)
if col is None:
raise ValueError('Не найдена колонка с диалогами в MBD-mini')
# Собираем два few-shot примера
examples = []
for rec in bank_data_stream:
examples.append(rec[col])
if len(examples) == 2:
break
# Системная инструкция для CoT
system_instruction = (
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка,"
" рассказывать о причинах и способах решения их проблем с банковскими услугами."
)
# Построение CoT-промпта с few-shot
def build_prompt(question: str) -> str:
few_shot_text = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples)
prompt = (
f"{system_instruction}\n\n"
f"{few_shot_text}\n\n"
f"Вопрос клиента: {question}\n"
"Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ."
)
return prompt
# Генерация ответов и измерение времени
def generate(question: str):
prompt = build_prompt(question)
results = {}
for name, pipe in models.items():
start = time.time()
# для T5 используем text2text, для других text-generation
out = pipe(
prompt,
max_length=400,
do_sample=True,
top_p=0.9,
temperature=0.7
)[0]['generated_text']
elapsed = round(time.time() - start, 2)
# Извлечь финальный ответ
if 'Ответ:' in out:
answer = out.split('Ответ:')[-1].strip()
else:
answer = out.strip().split('\n')[-1]
results[name] = {'answer': answer, 'time': elapsed}
return results
# Подготовка вывода для Gradio
def format_outputs(question: str):
res = generate(question)
return (
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s",
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s",
res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s"
)
# Интерфейс Gradio
with gr.Blocks() as demo:
gr.Markdown('## Клиентские обращения: CoT на трёх моделях с MBD-mini и тайминг')
txt = gr.Textbox(
label='Опишите проблему клиента',
placeholder='Например: "Почему я не могу снять деньги с карты?"',
lines=2
)
btn = gr.Button('Сгенерировать ответ')
out1 = gr.Textbox(label='ruDialoGPT-small Ответ')
t1 = gr.Textbox(label='ruDialoGPT-small Время')
out2 = gr.Textbox(label='ruGPT3-small Ответ')
t2 = gr.Textbox(label='ruGPT3-small Время')
out3 = gr.Textbox(label='rut5-small-chitchat Ответ')
t3 = gr.Textbox(label='rut5-small-chitchat Время')
btn.click(
format_outputs,
inputs=[txt],
outputs=[out1, t1, out2, t2, out3, t3]
)
demo.launch()