chat / app.py
SimrusDenuvo's picture
Update app.py
f4f65e1 verified
raw
history blame
4.64 kB
import gradio as gr
import time
from transformers import pipeline
from datasets import load_dataset
# Инициализация трёх бесплатных русскоязычных моделей
models = {
'ruDialoGPT-small': pipeline(
'text-generation',
model='t-bank-ai/ruDialoGPT-small',
tokenizer='t-bank-ai/ruDialoGPT-small',
device=-1
),
'ruGPT3-small': pipeline(
'text-generation',
model='ai-forever/rugpt3small_based_on_gpt2',
tokenizer='ai-forever/rugpt3small_based_on_gpt2',
device=-1
),
'rut5-small-chitchat': pipeline(
'text-generation',
model='cointegrated/rut5-small-chitchat',
tokenizer='cointegrated/rut5-small-chitchat',
device=-1
)
}
# Стриминг основного банковского датасета чтобы не загружать всё сразу
bank_stream = load_dataset(
'ai-lab/MBD',
split='train',
streaming=True
)
# Используем явно колонку 'dialogs' из описания датасета
# Берём первые два примера для few-shot
examples = []
for record in bank_stream:
if 'dialogs' in record:
examples.append(record['dialogs'])
elif 'dialog_embeddings' in record:
examples.append(record['dialog_embeddings'])
if len(examples) == 2:
break
if len(examples) < 2:
raise ValueError('Не удалось получить два примера dialog из MBD')
# Системная инструкция для CoT
system_instruction = (
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка,"
" рассказывать о причинах и способах решения их проблем с банковскими услугами."
)
# Функция построения CoT промпта с few-shot примерами
def build_prompt(question: str) -> str:
few_shot = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples)
prompt = (
f"{system_instruction}\n\n"
f"{few_shot}\n\n"
f"Вопрос клиента: {question}\n"
"Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ."
)
return prompt
# Генерация ответов и измерение времени
def generate(question: str):
prompt = build_prompt(question)
results = {}
for name, pipe in models.items():
start = time.time()
out = pipe(
prompt,
max_length=400,
do_sample=True,
top_p=0.9,
temperature=0.7
)[0]['generated_text']
elapsed = round(time.time() - start, 2)
# Извлечение итогового ответа после 'Ответ:' или последней строки
if 'Ответ:' in out:
answer = out.split('Ответ:')[-1].strip()
else:
answer = out.strip().split('\n')[-1]
results[name] = {'answer': answer, 'time': elapsed}
return results
# Форматируем данные для Gradio интерфейса
def format_outputs(question: str):
res = generate(question)
return (
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s",
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s",
res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s"
)
# Графический интерфейс Gradio
with gr.Blocks() as demo:
gr.Markdown('## Ответы на клиентские обращения с CoT на трёх моделях и таймингом')
txt = gr.Textbox(
label='Описание проблемы клиента',
placeholder='Например: "Почему я не могу снять деньги с карты?"',
lines=2
)
btn = gr.Button('Сгенерировать ответ')
out1 = gr.Textbox(label='ruDialoGPT-small Ответ')
t1 = gr.Textbox(label='ruDialoGPT-small Время')
out2 = gr.Textbox(label='ruGPT3-small Ответ')
t2 = gr.Textbox(label='ruGPT3-small Время')
out3 = gr.Textbox(label='rut5-small-chitchat Ответ')
t3 = gr.Textbox(label='rut5-small-chitchat Время')
btn.click(format_outputs, inputs=[txt], outputs=[out1, t1, out2, t2, out3, t3])
demo.launch()