Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
from transformers import pipeline, AutoTokenizer | |
from datasets import load_dataset | |
# Инициализация трёх бесплатных русскоязычных моделей | |
models = {} | |
# 1) ruDialoGPT-small | |
models['ruDialoGPT-small'] = pipeline( | |
'text-generation', | |
model='t-bank-ai/ruDialoGPT-small', | |
tokenizer='t-bank-ai/ruDialoGPT-small', | |
device=-1 | |
) | |
# 2) ruGPT3-small | |
models['ruGPT3-small'] = pipeline( | |
'text-generation', | |
model='ai-forever/rugpt3small_based_on_gpt2', | |
tokenizer='ai-forever/rugpt3small_based_on_gpt2', | |
device=-1 | |
) | |
# 3) rut5-small-chitchat (T5 requires text2text и slow tokenizer) | |
t5_tokenizer = AutoTokenizer.from_pretrained( | |
'cointegrated/rut5-small-chitchat', | |
use_fast=False | |
) | |
models['rut5-small-chitchat'] = pipeline( | |
'text2text-generation', | |
model='cointegrated/rut5-small-chitchat', | |
tokenizer=t5_tokenizer, | |
device=-1 | |
) | |
# Загрузка "мини" банковского датасета для few-shot (стриминг) | |
bank_data_stream = load_dataset( | |
'ai-lab/MBD-mini', | |
split='train', | |
streaming=True | |
) | |
# Определяем колонку с диалогами по ключам | |
first_record = next(iter(bank_data_stream)) | |
col = next((c for c in first_record.keys() if 'dialog' in c.lower() or 'диалог' in c.lower()), None) | |
if col is None: | |
raise ValueError('Не найдена колонка с диалогами в MBD-mini') | |
# Собираем два few-shot примера | |
examples = [] | |
for rec in bank_data_stream: | |
examples.append(rec[col]) | |
if len(examples) == 2: | |
break | |
# Системная инструкция для CoT | |
system_instruction = ( | |
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка," | |
" рассказывать о причинах и способах решения их проблем с банковскими услугами." | |
) | |
# Построение CoT-промпта с few-shot | |
def build_prompt(question: str) -> str: | |
few_shot_text = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples) | |
prompt = ( | |
f"{system_instruction}\n\n" | |
f"{few_shot_text}\n\n" | |
f"Вопрос клиента: {question}\n" | |
"Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ." | |
) | |
return prompt | |
# Генерация ответов и измерение времени | |
def generate(question: str): | |
prompt = build_prompt(question) | |
results = {} | |
for name, pipe in models.items(): | |
start = time.time() | |
# для T5 используем text2text, для других text-generation | |
out = pipe( | |
prompt, | |
max_length=400, | |
do_sample=True, | |
top_p=0.9, | |
temperature=0.7 | |
)[0]['generated_text'] | |
elapsed = round(time.time() - start, 2) | |
# Извлечь финальный ответ | |
if 'Ответ:' in out: | |
answer = out.split('Ответ:')[-1].strip() | |
else: | |
answer = out.strip().split('\n')[-1] | |
results[name] = {'answer': answer, 'time': elapsed} | |
return results | |
# Подготовка вывода для Gradio | |
def format_outputs(question: str): | |
res = generate(question) | |
return ( | |
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s", | |
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s", | |
res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s" | |
) | |
# Интерфейс Gradio | |
with gr.Blocks() as demo: | |
gr.Markdown('## Клиентские обращения: CoT на трёх моделях с MBD-mini и тайминг') | |
txt = gr.Textbox( | |
label='Опишите проблему клиента', | |
placeholder='Например: "Почему я не могу снять деньги с карты?"', | |
lines=2 | |
) | |
btn = gr.Button('Сгенерировать ответ') | |
out1 = gr.Textbox(label='ruDialoGPT-small Ответ') | |
t1 = gr.Textbox(label='ruDialoGPT-small Время') | |
out2 = gr.Textbox(label='ruGPT3-small Ответ') | |
t2 = gr.Textbox(label='ruGPT3-small Время') | |
out3 = gr.Textbox(label='rut5-small-chitchat Ответ') | |
t3 = gr.Textbox(label='rut5-small-chitchat Время') | |
btn.click( | |
format_outputs, | |
inputs=[txt], | |
outputs=[out1, t1, out2, t2, out3, t3] | |
) | |
demo.launch() | |