File size: 3,607 Bytes
57645e8
57878bd
a191742
 
af5aa20
bd826f0
a191742
 
 
 
 
884ac0c
a191742
d5e56ca
a191742
 
 
 
 
 
 
d0ca3ae
a191742
 
1382bb2
a191742
 
1382bb2
a191742
 
af5aa20
a191742
 
 
 
 
 
 
 
 
1382bb2
a191742
1382bb2
a191742
1382bb2
 
 
a191742
1382bb2
a191742
 
 
 
 
 
 
 
1382bb2
5dee7eb
a191742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dee7eb
1e10719
90cd180
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

MODEL_CONFIGS = {
    "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
    "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
    "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

models = {}
for label, name in MODEL_CONFIGS.items():
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForCausalLM.from_pretrained(name)
    model.to(device)
    model.eval()
    models[label] = (tokenizer, model)

# Загрузка датасета
load_dataset("ZhenDOS/alpha_bank_data", split="train")

def cot_prompt_1(text):
    return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка."

def cot_prompt_2(text):
    return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."

def generate_all_responses(question):
    results = {}
    for model_name, (tokenizer, model) in models.items():
        results[model_name] = {}
        for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1):
            prompt = prompt_func(question)
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            start_time = time.time()
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=200,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    eos_token_id=tokenizer.eos_token_id
                )
            end_time = time.time()
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = response.replace(prompt, "").strip()
            duration = round(end_time - start_time, 2)
            results[model_name][f"CoT Промпт {i}"] = {
                "response": response,
                "time": f"{duration} сек."
            }
    return results

def display_responses(question):
    all_responses = generate_all_responses(question)
    output = ""
    for model_name, prompts in all_responses.items():
        output += f"\n### Модель: {model_name}\n"
        for prompt_label, content in prompts.items():
            output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n"
    return output.strip()

demo = gr.Interface(
    fn=display_responses,
    inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"),
    outputs=gr.Markdown(label="Ответы от разных моделей"),
    title="Alpha Bank Assistant — сравнение моделей",
    description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.",
    examples=[
        "Как восстановить доступ в мобильный банк?",
        "Почему с меня списали комиссию за обслуживание карты?",
        "Какие условия по потребительскому кредиту?",
    ]
)

if __name__ == "__main__":
    demo.launch()