Spaces:
Sleeping
Sleeping
File size: 4,637 Bytes
57645e8 66fabfa f4f65e1 66fabfa 1267d48 39ebd04 f4f65e1 7e06069 f4f65e1 39ebd04 f4f65e1 39ebd04 f4f65e1 39ebd04 f4f65e1 39ebd04 66fabfa f4f65e1 66fabfa 39ebd04 f4f65e1 66fabfa 39ebd04 f4f65e1 39ebd04 66fabfa 39ebd04 66fabfa 39ebd04 66fabfa 39ebd04 66fabfa f4f65e1 66fabfa f4f65e1 66fabfa 39ebd04 66fabfa 39ebd04 66fabfa f4f65e1 1267d48 f4f65e1 7e06069 f4f65e1 7e06069 39ebd04 66fabfa 39ebd04 f4f65e1 90cd180 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import gradio as gr
import time
from transformers import pipeline
from datasets import load_dataset
# Инициализация трёх бесплатных русскоязычных моделей
models = {
'ruDialoGPT-small': pipeline(
'text-generation',
model='t-bank-ai/ruDialoGPT-small',
tokenizer='t-bank-ai/ruDialoGPT-small',
device=-1
),
'ruGPT3-small': pipeline(
'text-generation',
model='ai-forever/rugpt3small_based_on_gpt2',
tokenizer='ai-forever/rugpt3small_based_on_gpt2',
device=-1
),
'rut5-small-chitchat': pipeline(
'text-generation',
model='cointegrated/rut5-small-chitchat',
tokenizer='cointegrated/rut5-small-chitchat',
device=-1
)
}
# Стриминг основного банковского датасета чтобы не загружать всё сразу
bank_stream = load_dataset(
'ai-lab/MBD',
split='train',
streaming=True
)
# Используем явно колонку 'dialogs' из описания датасета
# Берём первые два примера для few-shot
examples = []
for record in bank_stream:
if 'dialogs' in record:
examples.append(record['dialogs'])
elif 'dialog_embeddings' in record:
examples.append(record['dialog_embeddings'])
if len(examples) == 2:
break
if len(examples) < 2:
raise ValueError('Не удалось получить два примера dialog из MBD')
# Системная инструкция для CoT
system_instruction = (
"Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка,"
" рассказывать о причинах и способах решения их проблем с банковскими услугами."
)
# Функция построения CoT промпта с few-shot примерами
def build_prompt(question: str) -> str:
few_shot = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples)
prompt = (
f"{system_instruction}\n\n"
f"{few_shot}\n\n"
f"Вопрос клиента: {question}\n"
"Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ."
)
return prompt
# Генерация ответов и измерение времени
def generate(question: str):
prompt = build_prompt(question)
results = {}
for name, pipe in models.items():
start = time.time()
out = pipe(
prompt,
max_length=400,
do_sample=True,
top_p=0.9,
temperature=0.7
)[0]['generated_text']
elapsed = round(time.time() - start, 2)
# Извлечение итогового ответа после 'Ответ:' или последней строки
if 'Ответ:' in out:
answer = out.split('Ответ:')[-1].strip()
else:
answer = out.strip().split('\n')[-1]
results[name] = {'answer': answer, 'time': elapsed}
return results
# Форматируем данные для Gradio интерфейса
def format_outputs(question: str):
res = generate(question)
return (
res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s",
res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s",
res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s"
)
# Графический интерфейс Gradio
with gr.Blocks() as demo:
gr.Markdown('## Ответы на клиентские обращения с CoT на трёх моделях и таймингом')
txt = gr.Textbox(
label='Описание проблемы клиента',
placeholder='Например: "Почему я не могу снять деньги с карты?"',
lines=2
)
btn = gr.Button('Сгенерировать ответ')
out1 = gr.Textbox(label='ruDialoGPT-small Ответ')
t1 = gr.Textbox(label='ruDialoGPT-small Время')
out2 = gr.Textbox(label='ruGPT3-small Ответ')
t2 = gr.Textbox(label='ruGPT3-small Время')
out3 = gr.Textbox(label='rut5-small-chitchat Ответ')
t3 = gr.Textbox(label='rut5-small-chitchat Время')
btn.click(format_outputs, inputs=[txt], outputs=[out1, t1, out2, t2, out3, t3])
demo.launch()
|