File size: 4,728 Bytes
57645e8
66fabfa
7e06069
66fabfa
1267d48
39ebd04
7e06069
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a191742
7e06069
 
 
 
 
 
 
 
 
 
 
 
 
39ebd04
 
 
 
 
7e06069
39ebd04
7e06069
66fabfa
39ebd04
7e06069
39ebd04
 
 
 
 
 
 
 
 
 
 
66fabfa
7e06069
66fabfa
39ebd04
 
66fabfa
39ebd04
 
 
66fabfa
 
 
 
39ebd04
66fabfa
39ebd04
66fabfa
 
 
 
7e06069
39ebd04
 
 
 
 
 
 
66fabfa
7e06069
66fabfa
 
 
 
 
 
 
7e06069
66fabfa
39ebd04
66fabfa
 
 
39ebd04
 
66fabfa
 
 
1267d48
39ebd04
7e06069
 
 
 
 
39ebd04
66fabfa
 
39ebd04
 
 
 
7e06069
 
 
 
 
90cd180
e07efc3
961a138
8fab447
1267d48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import time
from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

# Инициализация трёх бесплатных русскоязычных моделей
models = {}

# 1) ruDialoGPT-small
models['ruDialoGPT-small'] = pipeline(
    'text-generation',
    model='t-bank-ai/ruDialoGPT-small',
    tokenizer='t-bank-ai/ruDialoGPT-small',
    device=-1
)

# 2) ruGPT3-small
models['ruGPT3-small'] = pipeline(
    'text-generation',
    model='ai-forever/rugpt3small_based_on_gpt2',
    tokenizer='ai-forever/rugpt3small_based_on_gpt2',
    device=-1
)

# 3) rut5-small-chitchat (T5 requires text2text и slow tokenizer)
t5_tokenizer = AutoTokenizer.from_pretrained(
    'cointegrated/rut5-small-chitchat',
    use_fast=False
)
models['rut5-small-chitchat'] = pipeline(
    'text2text-generation',
    model='cointegrated/rut5-small-chitchat',
    tokenizer=t5_tokenizer,
    device=-1
)

# Загрузка "мини" банковского датасета для few-shot (стриминг)
bank_data_stream = load_dataset(
    'ai-lab/MBD-mini',
    split='train',
    streaming=True
)
# Определяем колонку с диалогами по ключам
first_record = next(iter(bank_data_stream))
col = next((c for c in first_record.keys() if 'dialog' in c.lower() or 'диалог' in c.lower()), None)
if col is None:
    raise ValueError('Не найдена колонка с диалогами в MBD-mini')
# Собираем два few-shot примера
examples = []
for rec in bank_data_stream:
    examples.append(rec[col])
    if len(examples) == 2:
        break

# Системная инструкция для CoT
system_instruction = (
    "Вы — банковский ассистент. Ваша задача — корректно и вежливо отвечать на запросы клиентов банка,"
    " рассказывать о причинах и способах решения их проблем с банковскими услугами."
)

# Построение CoT-промпта с few-shot

def build_prompt(question: str) -> str:
    few_shot_text = '\n\n'.join(f"Пример диалога:\n{ex}" for ex in examples)
    prompt = (
        f"{system_instruction}\n\n"
        f"{few_shot_text}\n\n"
        f"Вопрос клиента: {question}\n"
        "Сначала подробно опишите рассуждения шаг за шагом, а затем дайте краткий связный ответ."
    )
    return prompt

# Генерация ответов и измерение времени

def generate(question: str):
    prompt = build_prompt(question)
    results = {}
    for name, pipe in models.items():
        start = time.time()
        # для T5 используем text2text, для других text-generation
        out = pipe(
            prompt,
            max_length=400,
            do_sample=True,
            top_p=0.9,
            temperature=0.7
        )[0]['generated_text']
        elapsed = round(time.time() - start, 2)
        # Извлечь финальный ответ
        if 'Ответ:' in out:
            answer = out.split('Ответ:')[-1].strip()
        else:
            answer = out.strip().split('\n')[-1]
        results[name] = {'answer': answer, 'time': elapsed}
    return results

# Подготовка вывода для Gradio

def format_outputs(question: str):
    res = generate(question)
    return (
        res['ruDialoGPT-small']['answer'], f"{res['ruDialoGPT-small']['time']}s",
        res['ruGPT3-small']['answer'], f"{res['ruGPT3-small']['time']}s",
        res['rut5-small-chitchat']['answer'], f"{res['rut5-small-chitchat']['time']}s"
    )

# Интерфейс Gradio
with gr.Blocks() as demo:
    gr.Markdown('## Клиентские обращения: CoT на трёх моделях с MBD-mini и тайминг')
    txt = gr.Textbox(
        label='Опишите проблему клиента',
        placeholder='Например: "Почему я не могу снять деньги с карты?"',
        lines=2
    )
    btn = gr.Button('Сгенерировать ответ')
    out1 = gr.Textbox(label='ruDialoGPT-small Ответ')
    t1 = gr.Textbox(label='ruDialoGPT-small Время')
    out2 = gr.Textbox(label='ruGPT3-small Ответ')
    t2 = gr.Textbox(label='ruGPT3-small Время')
    out3 = gr.Textbox(label='rut5-small-chitchat Ответ')
    t3 = gr.Textbox(label='rut5-small-chitchat Время')
    btn.click(
        format_outputs,
        inputs=[txt],
        outputs=[out1, t1, out2, t2, out3, t3]
    )
    demo.launch()