Spaces:
Sleeping
Sleeping
import gradio as gr | |
import time | |
from transformers import pipeline | |
from datasets import load_dataset | |
# Инициализация трёх бесплатных русскоязычных моделей | |
models = { | |
'ruDialoGPT-small': pipeline( | |
'text-generation', | |
model='t-bank-ai/ruDialoGPT-small', | |
tokenizer='t-bank-ai/ruDialoGPT-small', | |
device=-1 | |
), | |
'ruGPT3-small': pipeline( | |
'text-generation', | |
model='ai-forever/rugpt3small_based_on_gpt2', | |
tokenizer='ai-forever/rugpt3small_based_on_gpt2', | |
device=-1 | |
), | |
'rut5-small-chitchat': pipeline( | |
'text-generation', | |
model='cointegrated/rut5-small-chitchat', | |
tokenizer='cointegrated/rut5-small-chitchat', | |
device=-1 | |
) | |
} | |
# Стриминг основного банковского датасета чтобы не загружать всё сразу | |
bank_stream = load_dataset( | |
'ai-lab/MBD', | |
split='train', | |
streaming=True | |
) | |
# Используем явно колонку 'dialogs' из описания датасета | |
# Берём первые два примера для few-shot | |
examples = [] | |
for record in bank_stream: | |
if 'dialogs' in record: | |
examples.append(record['dialogs']) | |
elif 'dialog_embeddings' in record: | |
examples.append(record['dialog_embeddings']) | |
if len(examples) == 2: | |
break | |
if len(examples) < 2: | |
raise ValueError('Не удалось получить два примера dialog из MBD') | |
# Системная инструкция для CoT | |
system_instruction = ( | |
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка," | |
" рассказывать о причинах и способах решения их проблем с банковскими услугами." | |
) | |
# Функция построения CoT промпта с few-shot примерами | |
def build_prompt(question: str) -> str: | |
few_shot = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples) | |
prompt = ( | |
f"{system_instruction}\n\n" | |
f"{few_shot}\n\n" | |
f"Вопрос клиента: {question}\n" | |
"Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ." | |
) | |
return prompt | |
# Генерация ответов и измерение времени | |
def generate(question: str): | |
prompt = build_prompt(question) | |
results = {} | |
for name, pipe in models.items(): | |
start = time.time() | |
out = pipe( | |
prompt, | |
max_length=400, | |
do_sample=True, | |
top_p=0.9, | |
temperature=0.7 | |
)[0]['generated_text'] | |
elapsed = round(time.time() - start, 2) | |
# Извлечение итогового ответа после 'Ответ:' или последней строки | |
if 'Ответ:' in out: | |
answer = out.split('Ответ:')[-1].strip() | |
else: | |
answer = out.strip().split('\n')[-1] | |
results[name] = {'answer': answer, 'time': elapsed} | |
return results | |
# Форматируем данные для Gradio интерфейса | |
def format_outputs(question: str): | |
res = generate(question) | |
return ( | |
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s", | |
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s", | |
res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s" | |
) | |
# Графический интерфейс Gradio | |
with gr.Blocks() as demo: | |
gr.Markdown('## Ответы на клиентские обращения с CoT на трёх моделях и таймингом') | |
txt = gr.Textbox( | |
label='Описание проблемы клиента', | |
placeholder='Например: "Почему я не могу снять деньги с карты?"', | |
lines=2 | |
) | |
btn = gr.Button('Сгенерировать ответ') | |
out1 = gr.Textbox(label='ruDialoGPT-small Ответ') | |
t1 = gr.Textbox(label='ruDialoGPT-small Время') | |
out2 = gr.Textbox(label='ruGPT3-small Ответ') | |
t2 = gr.Textbox(label='ruGPT3-small Время') | |
out3 = gr.Textbox(label='rut5-small-chitchat Ответ') | |
t3 = gr.Textbox(label='rut5-small-chitchat Время') | |
btn.click(format_outputs, inputs=[txt], outputs=[out1, t1, out2, t2, out3, t3]) | |
demo.launch() | |