File size: 6,030 Bytes
57645e8
1382bb2
57878bd
af5aa20
bd826f0
1382bb2
 
884ac0c
1382bb2
d5e56ca
 
1382bb2
 
d5e56ca
d0ca3ae
1382bb2
d5e56ca
 
1382bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7882653
1382bb2
 
 
 
 
af5aa20
1382bb2
 
 
 
af5aa20
1382bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af5aa20
1382bb2
 
 
 
af5aa20
1382bb2
af5aa20
1382bb2
 
 
af5aa20
1382bb2
5dee7eb
7882653
1382bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dee7eb
1e10719
90cd180
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from datasets import load_dataset

# Загрузка датасета
dataset = load_dataset("ZhenDOS/alpha_bank_data")

# Инициализация разных моделей
model_name = "ai-forever/rugpt3large_based_on_gpt2"
    "GigaChat-like:" "ai-forever/rugpt3large_based_on_gpt2",  # Русская модель большого размера
    "ChatGPT-like": "tinkoff-ai/ruDialoGPT-medium",           # Диалоговая модель для русского языка
    "DeepSeek-like": "ai-forever/sbert_large_nlu_ru"           # Русская модель для понимания текста


# Инициализация моделей и токенизаторов
models = model_name
tokenizers = AutoTokenizer.from_pretrained(model_name)

for model_name, model_path in MODELS.items():
    try:
        if model_name == "DeepSeek-like":
            # Для SBERT используем pipeline
            models[model_name] = pipeline("text-generation", model=model_path)
        else:
            tokenizers[model_name] = AutoTokenizer.from_pretrained(model_path)
            models[model_name] = AutoModelForCausalLM.from_pretrained(model_path)
    except Exception as e:
        print(f"Ошибка при загрузке модели {model_name}: {e}")

# Промпты для обработки обращений
PROMPTS = {
    "Анализ проблемы": 
        "Проанализируй клиентское обращение и выдели основную проблему. "
        "Обращение: {text}\n\nПроблема:",
    
    "Формирование ответа":
        "Клиент обратился с проблемой: {problem}\n\n"
        "Сформируй вежливый и профессиональный ответ, предлагая решение. "
        "Используй информацию о банковских услугах. Ответ:"
}

def generate_with_model(prompt, model_name, max_length=150):
    """Генерация ответа с помощью выбранной модели"""
    if model_name not in models:
        return f"Модель {model_name} не загружена"
    
    try:
        if model_name == "DeepSeek-like":
            # Обработка через pipeline
            result = models[model_name](
                prompt,
                max_length=max_length,
                do_sample=True,
                temperature=0.7,
                top_p=0.9
            )
            return result[0]['generated_text']
        else:
            # Обработка через transformers
            inputs = tokenizers[model_name](prompt, return_tensors="pt", truncation=True)
            
            with torch.no_grad():
                outputs = models[model_name].generate(
                    **inputs,
                    max_new_tokens=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    eos_token_id=tokenizers[model_name].eos_token_id
                )
            
            response = tokenizers[model_name].decode(outputs[0], skip_special_tokens=True)
            return response[len(prompt):] if response.startswith(prompt) else response
    except Exception as e:
        return f"Ошибка генерации: {str(e)}"

def process_complaint(text, prompt_type):
    """Обработка клиентского обращения с выбранным промптом"""
    if prompt_type not in PROMPTS:
        return "Неверный тип промпта"
    
    # Получаем случайный пример из датасета, если текст не введен
    if not text.strip():
        example = dataset['train'].shuffle().select(range(1))[0]
        text = example['text']
    
    prompt = PROMPTS[prompt_type].format(text=text, problem="")
    
    results = {}
    for model_name in MODELS.keys():
        results[model_name] = generate_with_model(prompt, model_name)
    
    return results

# Интерфейс Gradio
with gr.Blocks(title="Анализ клиентских обращений Alpha Bank") as demo:
    gr.Markdown("## Анализ клиентских обращений Alpha Bank")
    gr.Markdown("Тестирование разных моделей на обработку обращений")
    
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Текст обращения",
                placeholder="Введите текст обращения или оставьте пустым для примера из датасета",
                lines=5
            )
            prompt_type = gr.Radio(
                list(PROMPTS.keys()),
                label="Тип промпта",
                value=list(PROMPTS.keys())[0]
            )
            submit_btn = gr.Button("Обработать")
        
        with gr.Column():
            outputs = []
            for model_name in MODELS.keys():
                outputs.append(
                    gr.Textbox(
                        label=f"{model_name}",
                        interactive=False,
                        lines=5
                    )
                )
    
    # Примеры из датасета
    examples = gr.Examples(
        examples=[x['text'] for x in dataset['train'].select(range(3))],
        inputs=text_input,
        label="Примеры из датасета"
    )
    
    def process_and_display(text, prompt_type):
        results = process_complaint(text, prompt_type)
        return [results.get(model_name, "") for model_name in MODELS.keys()]
    
    submit_btn.click(
        fn=process_and_display,
        inputs=[text_input, prompt_type],
        outputs=outputs
    )

if __name__ == "__main__":
    demo.launch()