SimrusDenuvo commited on
Commit
1382bb2
·
verified ·
1 Parent(s): d0ca3ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -81
app.py CHANGED
@@ -1,100 +1,144 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
  from datasets import load_dataset
5
 
6
- # Загружаем модель и токенизатор
7
- model_name = "ai-forever/rugpt3small_based_on_gpt2"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForCausalLM.from_pretrained(model_name)
10
 
11
- # Дополнительные знания о банковских услугах
12
- BANK_KNOWLEDGE = {
13
- "Как проверить баланс карты?": [
14
- "1. Через мобильное приложение банка (раздел 'Карты' → 'Баланс')",
15
- "2. В интернет-банке личном кабинете выберите карту)",
16
- "3. По SMS (отправьте BALANCE на номер 900)",
17
- "4. В банкомате (вставьте карту и выберите 'Запрос баланса')",
18
- "5. По телефону горячей линии (8-800-100-00-00)"
19
- ],
20
- "Как восстановить утерянную карту?": [
21
- "1. Немедленно позвоните в банк по телефону 8-800-100-00-00 для блокировки карты",
22
- "2. Обратитесь в отделение банка с паспортом",
23
- "3. Заполните заявление на перевыпуск карты",
24
- "4. Новая карта будет готова через 3-5 рабочих дней"
25
- ]
26
  }
27
 
28
- def enhance_response(question, generated_response):
29
- # Если вопрос есть в наших знаниях, возвращаем структурированный ответ
30
- if question in BANK_KNOWLEDGE:
31
- return "\n".join(BANK_KNOWLEDGE[question])
32
-
33
- # Улучшаем стандартные ответы модели
34
- improvements = {
35
- "баланс": "Вы можете проверить баланс карты:\n"
36
- "1. В мобильном приложении\n"
37
- "2. Через интернет-банк\n"
38
- "3. В банкомате\n"
39
- "4. По телефону горячей линии 8-800-100-00-00",
40
- "кредит": "По вопросам кредитования вы можете:\n"
41
- "1. Оставить заявку на сайте\n"
42
- "2. Обратиться в отделение банка\n"
43
- "3. Позвонить по телефону 8-800-100-00-00",
44
- "карт": "По вопросам банковских карт:\n"
45
- "1. Обратитесь в отделение банка\n"
46
- "2. Позвоните на горячую линию\n"
47
- "3. Используйте чат в мобильном приложении"
48
- }
49
-
50
- for keyword, improved_answer in improvements.items():
51
- if keyword in question.lower():
52
- return improved_answer
53
 
54
- return generated_response
 
 
 
 
55
 
56
- def generate_response(prompt):
57
- # Генерируем ответ с помощью модели
58
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
59
 
60
- with torch.no_grad():
61
- outputs = model.generate(
62
- **inputs,
63
- max_new_tokens=150,
64
- do_sample=True,
65
- temperature=0.7,
66
- top_k=50,
67
- top_p=0.95,
68
- eos_token_id=tokenizer.eos_token_id,
69
- no_repeat_ngram_size=3,
70
- early_stopping=True
71
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
74
 
75
- # Удаляем промпт из ответа
76
- if response.startswith(prompt):
77
- response = response[len(prompt):].strip()
78
 
79
- # Улучшаем ответ
80
- enhanced_response = enhance_response(prompt, response)
 
81
 
82
- return enhanced_response
83
 
84
  # Интерфейс Gradio
85
- demo = gr.Interface(
86
- fn=generate_response,
87
- inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"),
88
- outputs=gr.Textbox(label="Ответ модели"),
89
- title="Анализ клиентских обращений — Alpha Bank Assistant",
90
- description="Получите точные ответы на вопросы о банковских услугах",
91
- examples=[
92
- "Как проверить баланс карты?",
93
- "Как восстановить утерянную карту?",
94
- "Как оформить кредитную карту?",
95
- "Какие документы нужны для открытия счета?"
96
- ]
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  if __name__ == "__main__":
100
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
3
  import torch
4
  from datasets import load_dataset
5
 
6
+ # Загрузка датасета
7
+ dataset = load_dataset("ZhenDOS/alpha_bank_data")
 
 
8
 
9
+ # Инициализация разных моделей
10
+ MODELS = {
11
+ "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2", # Русская модель большого размера
12
+ "ChatGPT-like": "tinkoff-ai/ruDialoGPT-medium", # Диалоговая модель для русского языка
13
+ "DeepSeek-like": "ai-forever/sbert_large_nlu_ru" # Русская модель для понимания текста
 
 
 
 
 
 
 
 
 
 
14
  }
15
 
16
+ # Инициализация моделей и токенизаторов
17
+ models = {}
18
+ tokenizers = {}
19
+
20
+ for model_name, model_path in MODELS.items():
21
+ try:
22
+ if model_name == "DeepSeek-like":
23
+ # Для SBERT используем pipeline
24
+ models[model_name] = pipeline("text-generation", model=model_path)
25
+ else:
26
+ tokenizers[model_name] = AutoTokenizer.from_pretrained(model_path)
27
+ models[model_name] = AutoModelForCausalLM.from_pretrained(model_path)
28
+ except Exception as e:
29
+ print(f"Ошибка при загрузке модели {model_name}: {e}")
30
+
31
+ # Промпты для обработки обращений
32
+ PROMPTS = {
33
+ "Анализ проблемы":
34
+ "Проанализируй клиентское обращение и выдели основную проблему. "
35
+ "Обращение: {text}\n\nПроблема:",
 
 
 
 
 
36
 
37
+ "Формирование ответа":
38
+ "Клиент обратился с проблемой: {problem}\n\n"
39
+ "Сформируй вежливый и профессиональный ответ, предлагая решение. "
40
+ "Используй информацию о банковских услугах. Ответ:"
41
+ }
42
 
43
+ def generate_with_model(prompt, model_name, max_length=150):
44
+ """Генерация ответа с помощью выбранной модели"""
45
+ if model_name not in models:
46
+ return f"Модель {model_name} не загружена"
47
 
48
+ try:
49
+ if model_name == "DeepSeek-like":
50
+ # Обработка через pipeline
51
+ result = models[model_name](
52
+ prompt,
53
+ max_length=max_length,
54
+ do_sample=True,
55
+ temperature=0.7,
56
+ top_p=0.9
57
+ )
58
+ return result[0]['generated_text']
59
+ else:
60
+ # Обработка через transformers
61
+ inputs = tokenizers[model_name](prompt, return_tensors="pt", truncation=True)
62
+
63
+ with torch.no_grad():
64
+ outputs = models[model_name].generate(
65
+ **inputs,
66
+ max_new_tokens=max_length,
67
+ do_sample=True,
68
+ temperature=0.7,
69
+ top_p=0.9,
70
+ eos_token_id=tokenizers[model_name].eos_token_id
71
+ )
72
+
73
+ response = tokenizers[model_name].decode(outputs[0], skip_special_tokens=True)
74
+ return response[len(prompt):] if response.startswith(prompt) else response
75
+ except Exception as e:
76
+ return f"Ошибка генерации: {str(e)}"
77
+
78
+ def process_complaint(text, prompt_type):
79
+ """Обработка клиентского обращения с выбранным промптом"""
80
+ if prompt_type not in PROMPTS:
81
+ return "Неверный тип промпта"
82
 
83
+ # Получаем случайный пример из датасета, если текст не введен
84
+ if not text.strip():
85
+ example = dataset['train'].shuffle().select(range(1))[0]
86
+ text = example['text']
87
 
88
+ prompt = PROMPTS[prompt_type].format(text=text, problem="")
 
 
89
 
90
+ results = {}
91
+ for model_name in MODELS.keys():
92
+ results[model_name] = generate_with_model(prompt, model_name)
93
 
94
+ return results
95
 
96
  # Интерфейс Gradio
97
+ with gr.Blocks(title="Анализ клиентских обращений Alpha Bank") as demo:
98
+ gr.Markdown("## Анализ клиентских обращений Alpha Bank")
99
+ gr.Markdown("Тестирование разных моделей на обработку обращений")
100
+
101
+ with gr.Row():
102
+ with gr.Column():
103
+ text_input = gr.Textbox(
104
+ label="Текст обращения",
105
+ placeholder="Введите текст обращения или оставьте пустым для примера из датасета",
106
+ lines=5
107
+ )
108
+ prompt_type = gr.Radio(
109
+ list(PROMPTS.keys()),
110
+ label="Тип промпта",
111
+ value=list(PROMPTS.keys())[0]
112
+ )
113
+ submit_btn = gr.Button("Обработать")
114
+
115
+ with gr.Column():
116
+ outputs = []
117
+ for model_name in MODELS.keys():
118
+ outputs.append(
119
+ gr.Textbox(
120
+ label=f"{model_name}",
121
+ interactive=False,
122
+ lines=5
123
+ )
124
+ )
125
+
126
+ # Примеры из датасета
127
+ examples = gr.Examples(
128
+ examples=[x['text'] for x in dataset['train'].select(range(3))],
129
+ inputs=text_input,
130
+ label="Примеры из датасета"
131
+ )
132
+
133
+ def process_and_display(text, prompt_type):
134
+ results = process_complaint(text, prompt_type)
135
+ return [results.get(model_name, "") for model_name in MODELS.keys()]
136
+
137
+ submit_btn.click(
138
+ fn=process_and_display,
139
+ inputs=[text_input, prompt_type],
140
+ outputs=outputs
141
+ )
142
 
143
  if __name__ == "__main__":
144
  demo.launch()