File size: 4,399 Bytes
57645e8
66fabfa
f4f65e1
52af11a
19236d7
4fb6307
19236d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fb6307
66fabfa
19236d7
f5f461a
19236d7
52af11a
19236d7
 
 
 
 
 
 
 
 
 
079a9a0
19236d7
 
 
 
52af11a
66fabfa
f5f461a
66fabfa
19236d7
f5f461a
19236d7
f5f461a
 
19236d7
f5f461a
19236d7
f5f461a
 
 
19236d7
f5f461a
19236d7
f5f461a
 
 
66fabfa
19236d7
 
 
 
 
 
66fabfa
 
19236d7
 
 
f5f461a
19236d7
f5f461a
 
19236d7
f5f461a
 
 
 
 
19236d7
f5f461a
 
 
 
 
19236d7
f5f461a
 
 
 
 
19236d7
f5f461a
 
 
 
a4bd38b
cafcd4f
52af11a
19236d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import time
from transformers import pipeline

# Инициализация моделей
models = {
    "ruGPT3-Medium": pipeline(
        "text-generation",
        model="ai-forever/rugpt3medium_based_on_gpt2",
        tokenizer="ai-forever/rugpt3medium_based_on_gpt2",
        device=-1
    ),
    "ruGPT3-Small": pipeline(
        "text-generation",
        model="ai-forever/rugpt3small_based_on_gpt2",
        tokenizer="ai-forever/rugpt3small_based_on_gpt2",
        device=-1
    ),
    "SambaLingo": pipeline(
        "text-generation",
        model="sambanovasystems/SambaLingo-Russian-Chat",
        tokenizer="sambanovasystems/SambaLingo-Russian-Chat",
        device=-1
    )
}

# Построение обычного промпта
def build_simple_prompt(user_input):
    return f"Клиент: {user_input}\nКатегория обращения:"

# Построение CoT-промпта
def build_cot_prompt(user_input):
    return (
        f"Клиент: {user_input}\n"
        "Проанализируй обращение клиента пошагово:\n"
        "1. Определи суть проблемы.\n"
        "2. Выяви возможные причины.\n"
        "3. Предложи возможные решения.\n"
        "На основе анализа, укажи категорию обращения:"
    )

# Генерация результатов
def generate_classification(user_input):
    prompt_simple = build_simple_prompt(user_input)
    prompt_cot = build_cot_prompt(user_input)

    results = {}

    for name, pipe in models.items():
        # CoT ответ
        start_cot = time.time()
        cot_output = pipe(prompt_cot, max_new_tokens=100, do_sample=True, top_p=0.9, temperature=0.7)[0]["generated_text"]
        end_cot = round(time.time() - start_cot, 2)

        # Обычный ответ
        start_simple = time.time()
        simple_output = pipe(prompt_simple, max_new_tokens=60, do_sample=True, top_p=0.9, temperature=0.7)[0]["generated_text"]
        end_simple = round(time.time() - start_simple, 2)

        results[name] = {
            "cot_answer": cot_output.strip(),
            "cot_time": end_cot,
            "simple_answer": simple_output.strip(),
            "simple_time": end_simple
        }

    return (
        results["ruGPT3-Medium"]["cot_answer"], f"{results['ruGPT3-Medium"]["cot_time"]} сек",
        results["ruGPT3-Medium"]["simple_answer"], f"{results["ruGPT3-Medium"]["simple_time"]} сек",
        results["ruGPT3-Small"]["cot_answer"], f"{results["ruGPT3-Small"]["cot_time"]} сек",
        results["ruGPT3-Small"]["simple_answer"], f"{results["ruGPT3-Small"]["simple_time"]} сек",
        results["SambaLingo"]["cot_answer"], f"{results["SambaLingo"]["cot_time"]} сек",
        results["SambaLingo"]["simple_answer"], f"{results["SambaLingo"]["simple_time"]} сек"
    )

# Gradio UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("🏦 **Банковский помощник: CoT vs. Обычный ответ (магистерская работа)**")

    inp = gr.Textbox(label="Вопрос клиента", placeholder="Например: Я не могу перевести деньги на другую карту", lines=2)
    btn = gr.Button("Сгенерировать")

    gr.Markdown("### ruGPT3-Medium")
    cot1 = gr.Textbox(label="CoT ответ")
    cot1_time = gr.Textbox(label="Время CoT")
    simple1 = gr.Textbox(label="Обычный ответ")
    simple1_time = gr.Textbox(label="Время обычного")

    gr.Markdown("### ruGPT3-Small")
    cot2 = gr.Textbox(label="CoT ответ")
    cot2_time = gr.Textbox(label="Время CoT")
    simple2 = gr.Textbox(label="Обычный ответ")
    simple2_time = gr.Textbox(label="Время обычного")

    gr.Markdown("### SambaLingo-Russian-Chat")
    cot3 = gr.Textbox(label="CoT ответ")
    cot3_time = gr.Textbox(label="Время CoT")
    simple3 = gr.Textbox(label="Обычный ответ")
    simple3_time = gr.Textbox(label="Время обычного")

    btn.click(generate_classification, inputs=[inp], outputs=[
        cot1, cot1_time, simple1, simple1_time,
        cot2, cot2_time, simple2, simple2_time,
        cot3, cot3_time, simple3, simple3_time
    ])

demo.launch()