SimrusDenuvo commited on
Commit
a191742
·
verified ·
1 Parent(s): d5e56ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -123
app.py CHANGED
@@ -1,144 +1,83 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  import torch
 
 
4
  from datasets import load_dataset
5
 
6
- # Загрузка датасета
7
- dataset = load_dataset("ZhenDOS/alpha_bank_data")
 
 
 
8
 
9
- # Инициализация разных моделей
10
- model_name = "ai-forever/rugpt3large_based_on_gpt2"
11
- "GigaChat-like:" "ai-forever/rugpt3large_based_on_gpt2", # Русская модель большого размера
12
- "ChatGPT-like": "tinkoff-ai/ruDialoGPT-medium", # Диалоговая модель для русского языка
13
- "DeepSeek-like": "ai-forever/sbert_large_nlu_ru" # Русская модель для понимания текста
14
 
 
 
 
 
 
 
 
15
 
16
- # Инициализация моделей и токенизаторов
17
- models = model_name
18
- tokenizers = AutoTokenizer.from_pretrained(model_name)
19
 
20
- for model_name, model_path in MODELS.items():
21
- try:
22
- if model_name == "DeepSeek-like":
23
- # Для SBERT используем pipeline
24
- models[model_name] = pipeline("text-generation", model=model_path)
25
- else:
26
- tokenizers[model_name] = AutoTokenizer.from_pretrained(model_path)
27
- models[model_name] = AutoModelForCausalLM.from_pretrained(model_path)
28
- except Exception as e:
29
- print(f"Ошибка при загрузке модели {model_name}: {e}")
30
 
31
- # Промпты для обработки обращений
32
- PROMPTS = {
33
- "Анализ проблемы":
34
- "Проанализируй клиентское обращение и выдели основную проблему. "
35
- "Обращение: {text}\n\nПроблема:",
36
-
37
- "Формирование ответа":
38
- "Клиент обратился с проблемой: {problem}\n\n"
39
- "Сформируй вежливый и профессиональный ответ, предлагая решение. "
40
- "Используй информацию о банковских услугах. Ответ:"
41
- }
42
 
43
- def generate_with_model(prompt, model_name, max_length=150):
44
- """Генерация ответа с помощью выбранной модели"""
45
- if model_name not in models:
46
- return f"Модель {model_name} не загружена"
47
-
48
- try:
49
- if model_name == "DeepSeek-like":
50
- # Обработка через pipeline
51
- result = models[model_name](
52
- prompt,
53
- max_length=max_length,
54
- do_sample=True,
55
- temperature=0.7,
56
- top_p=0.9
57
- )
58
- return result[0]['generated_text']
59
- else:
60
- # Обработка через transformers
61
- inputs = tokenizers[model_name](prompt, return_tensors="pt", truncation=True)
62
-
63
  with torch.no_grad():
64
- outputs = models[model_name].generate(
65
  **inputs,
66
- max_new_tokens=max_length,
67
  do_sample=True,
68
  temperature=0.7,
69
  top_p=0.9,
70
- eos_token_id=tokenizers[model_name].eos_token_id
71
  )
72
-
73
- response = tokenizers[model_name].decode(outputs[0], skip_special_tokens=True)
74
- return response[len(prompt):] if response.startswith(prompt) else response
75
- except Exception as e:
76
- return f"Ошибка генерации: {str(e)}"
77
-
78
- def process_complaint(text, prompt_type):
79
- """Обработка клиентского обращения с выбранным промптом"""
80
- if prompt_type not in PROMPTS:
81
- return "Неверный тип промпта"
82
-
83
- # Получаем случайный пример из датасета, если текст не введен
84
- if not text.strip():
85
- example = dataset['train'].shuffle().select(range(1))[0]
86
- text = example['text']
87
-
88
- prompt = PROMPTS[prompt_type].format(text=text, problem="")
89
-
90
- results = {}
91
- for model_name in MODELS.keys():
92
- results[model_name] = generate_with_model(prompt, model_name)
93
-
94
  return results
95
 
96
- # Интерфейс Gradio
97
- with gr.Blocks(title="Анализ клиентских обращений Alpha Bank") as demo:
98
- gr.Markdown("## Анализ клиентских обращений Alpha Bank")
99
- gr.Markdown("Тестирование разных моделей на обработку обращений")
100
-
101
- with gr.Row():
102
- with gr.Column():
103
- text_input = gr.Textbox(
104
- label="Текст обращения",
105
- placeholder="Введите текст обращения или оставьте пустым для примера из датасета",
106
- lines=5
107
- )
108
- prompt_type = gr.Radio(
109
- list(PROMPTS.keys()),
110
- label="Тип промпта",
111
- value=list(PROMPTS.keys())[0]
112
- )
113
- submit_btn = gr.Button("Обработать")
114
-
115
- with gr.Column():
116
- outputs = []
117
- for model_name in MODELS.keys():
118
- outputs.append(
119
- gr.Textbox(
120
- label=f"{model_name}",
121
- interactive=False,
122
- lines=5
123
- )
124
- )
125
-
126
- # Примеры из датасета
127
- examples = gr.Examples(
128
- examples=[x['text'] for x in dataset['train'].select(range(3))],
129
- inputs=text_input,
130
- label="Примеры из датасета"
131
- )
132
-
133
- def process_and_display(text, prompt_type):
134
- results = process_complaint(text, prompt_type)
135
- return [results.get(model_name, "") for model_name in MODELS.keys()]
136
-
137
- submit_btn.click(
138
- fn=process_and_display,
139
- inputs=[text_input, prompt_type],
140
- outputs=outputs
141
- )
142
 
143
  if __name__ == "__main__":
144
  demo.launch()
 
1
  import gradio as gr
 
2
  import torch
3
+ import time
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from datasets import load_dataset
6
 
7
+ MODEL_CONFIGS = {
8
+ "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
9
+ "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
10
+ "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
11
+ }
12
 
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
14
 
15
+ models = {}
16
+ for label, name in MODEL_CONFIGS.items():
17
+ tokenizer = AutoTokenizer.from_pretrained(name)
18
+ model = AutoModelForCausalLM.from_pretrained(name)
19
+ model.to(device)
20
+ model.eval()
21
+ models[label] = (tokenizer, model)
22
 
23
+ # Загрузка датасета
24
+ load_dataset("ZhenDOS/alpha_bank_data", split="train")
 
25
 
26
+ def cot_prompt_1(text):
27
+ return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка."
 
 
 
 
 
 
 
 
28
 
29
+ def cot_prompt_2(text):
30
+ return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
 
 
 
 
 
 
 
 
 
31
 
32
+ def generate_all_responses(question):
33
+ results = {}
34
+ for model_name, (tokenizer, model) in models.items():
35
+ results[model_name] = {}
36
+ for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1):
37
+ prompt = prompt_func(question)
38
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
39
+ inputs = {k: v.to(device) for k, v in inputs.items()}
40
+ start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
41
  with torch.no_grad():
42
+ outputs = model.generate(
43
  **inputs,
44
+ max_new_tokens=200,
45
  do_sample=True,
46
  temperature=0.7,
47
  top_p=0.9,
48
+ eos_token_id=tokenizer.eos_token_id
49
  )
50
+ end_time = time.time()
51
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
+ response = response.replace(prompt, "").strip()
53
+ duration = round(end_time - start_time, 2)
54
+ results[model_name][f"CoT Промпт {i}"] = {
55
+ "response": response,
56
+ "time": f"{duration} сек."
57
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  return results
59
 
60
+ def display_responses(question):
61
+ all_responses = generate_all_responses(question)
62
+ output = ""
63
+ for model_name, prompts in all_responses.items():
64
+ output += f"\n### Модель: {model_name}\n"
65
+ for prompt_label, content in prompts.items():
66
+ output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n"
67
+ return output.strip()
68
+
69
+ demo = gr.Interface(
70
+ fn=display_responses,
71
+ inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"),
72
+ outputs=gr.Markdown(label="Ответы от разных моделей"),
73
+ title="Alpha Bank Assistant — сравнение моделей",
74
+ description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.",
75
+ examples=[
76
+ "Как восстановить доступ в мобильный банк?",
77
+ "Почему с меня списали комиссию за обслуживание карты?",
78
+ "Какие условия по потребительскому кредиту?",
79
+ ]
80
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  if __name__ == "__main__":
83
  demo.launch()