File size: 4,637 Bytes
57645e8
66fabfa
f4f65e1
66fabfa
1267d48
39ebd04
f4f65e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e06069
f4f65e1
 
 
39ebd04
 
 
f4f65e1
 
39ebd04
f4f65e1
 
 
 
 
39ebd04
 
f4f65e1
 
39ebd04
 
 
 
 
 
66fabfa
f4f65e1
66fabfa
39ebd04
f4f65e1
66fabfa
39ebd04
f4f65e1
39ebd04
66fabfa
 
 
 
39ebd04
66fabfa
39ebd04
66fabfa
 
 
 
39ebd04
 
 
 
 
 
 
66fabfa
f4f65e1
66fabfa
 
 
 
 
 
 
f4f65e1
66fabfa
39ebd04
66fabfa
 
 
39ebd04
 
66fabfa
 
f4f65e1
 
1267d48
f4f65e1
7e06069
f4f65e1
7e06069
 
 
39ebd04
66fabfa
 
39ebd04
 
 
 
f4f65e1
90cd180
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import time
from transformers import pipeline
from datasets import load_dataset

# Инициализация трёх бесплатных русскоязычных моделей
models = {
    'ruDialoGPT-small': pipeline(
        'text-generation',
        model='t-bank-ai/ruDialoGPT-small',
        tokenizer='t-bank-ai/ruDialoGPT-small',
        device=-1
    ),
    'ruGPT3-small': pipeline(
        'text-generation',
        model='ai-forever/rugpt3small_based_on_gpt2',
        tokenizer='ai-forever/rugpt3small_based_on_gpt2',
        device=-1
    ),
    'rut5-small-chitchat': pipeline(
        'text-generation',
        model='cointegrated/rut5-small-chitchat',
        tokenizer='cointegrated/rut5-small-chitchat',
        device=-1
    )
}

# Стриминг основного банковского датасета чтобы не загружать всё сразу
bank_stream = load_dataset(
    'ai-lab/MBD',
    split='train',
    streaming=True
)
# Используем явно колонку 'dialogs' из описания датасета
# Берём первые два примера для few-shot
examples = []
for record in bank_stream:
    if 'dialogs' in record:
        examples.append(record['dialogs'])
    elif 'dialog_embeddings' in record:
        examples.append(record['dialog_embeddings'])
    if len(examples) == 2:
        break
if len(examples) < 2:
    raise ValueError('Не удалось получить два примера dialog из MBD')

# Системная инструкция для CoT
system_instruction = (
    "Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка,"
    " рассказывать о причинах и способах решения их проблем с банковскими услугами."
)

# Функция построения CoT промпта с few-shot примерами

def build_prompt(question: str) -> str:
    few_shot = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples)
    prompt = (
        f"{system_instruction}\n\n"
        f"{few_shot}\n\n"
        f"Вопрос клиента: {question}\n"
        "Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ."
    )
    return prompt

# Генерация ответов и измерение времени

def generate(question: str):
    prompt = build_prompt(question)
    results = {}
    for name, pipe in models.items():
        start = time.time()
        out = pipe(
            prompt,
            max_length=400,
            do_sample=True,
            top_p=0.9,
            temperature=0.7
        )[0]['generated_text']
        elapsed = round(time.time() - start, 2)
        # Извлечение итогового ответа после 'Ответ:' или последней строки
        if 'Ответ:' in out:
            answer = out.split('Ответ:')[-1].strip()
        else:
            answer = out.strip().split('\n')[-1]
        results[name] = {'answer': answer, 'time': elapsed}
    return results

# Форматируем данные для Gradio интерфейса

def format_outputs(question: str):
    res = generate(question)
    return (
        res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s",
        res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s",
        res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s"
    )

# Графический интерфейс Gradio

with gr.Blocks() as demo:
    gr.Markdown('## Ответы на клиентские обращения с CoT на трёх моделях и таймингом')
    txt = gr.Textbox(
        label='Описание проблемы клиента',
        placeholder='Например: "Почему я не могу снять деньги с карты?"',
        lines=2
    )
    btn = gr.Button('Сгенерировать ответ')
    out1 = gr.Textbox(label='ruDialoGPT-small Ответ')
    t1 = gr.Textbox(label='ruDialoGPT-small Время')
    out2 = gr.Textbox(label='ruGPT3-small Ответ')
    t2 = gr.Textbox(label='ruGPT3-small Время')
    out3 = gr.Textbox(label='rut5-small-chitchat Ответ')
    t3 = gr.Textbox(label='rut5-small-chitchat Время')
    btn.click(format_outputs, inputs=[txt], outputs=[out1, t1, out2, t2, out3, t3])
    demo.launch()