SimrusDenuvo commited on
Commit
74a6389
·
verified ·
1 Parent(s): a191742

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -44
app.py CHANGED
@@ -4,42 +4,54 @@ import time
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from datasets import load_dataset
6
 
 
7
  MODEL_CONFIGS = {
8
- "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
9
- "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
10
- "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
11
  }
12
 
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
 
15
  models = {}
16
- for label, name in MODEL_CONFIGS.items():
17
- tokenizer = AutoTokenizer.from_pretrained(name)
18
- model = AutoModelForCausalLM.from_pretrained(name)
19
- model.to(device)
20
- model.eval()
21
  models[label] = (tokenizer, model)
22
 
23
- # Загрузка датасета
 
24
  load_dataset("ZhenDOS/alpha_bank_data", split="train")
25
 
26
- def cot_prompt_1(text):
27
- return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка."
 
 
 
 
28
 
29
- def cot_prompt_2(text):
30
- return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
 
 
 
31
 
32
- def generate_all_responses(question):
 
33
  results = {}
34
- for model_name, (tokenizer, model) in models.items():
35
- results[model_name] = {}
36
- for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1):
37
- prompt = prompt_func(question)
38
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
39
  inputs = {k: v.to(device) for k, v in inputs.items()}
40
- start_time = time.time()
 
41
  with torch.no_grad():
42
- outputs = model.generate(
43
  **inputs,
44
  max_new_tokens=200,
45
  do_sample=True,
@@ -47,36 +59,36 @@ def generate_all_responses(question):
47
  top_p=0.9,
48
  eos_token_id=tokenizer.eos_token_id
49
  )
50
- end_time = time.time()
51
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
52
- response = response.replace(prompt, "").strip()
53
- duration = round(end_time - start_time, 2)
54
- results[model_name][f"CoT Промпт {i}"] = {
55
- "response": response,
56
- "time": f"{duration} сек."
 
 
 
57
  }
58
  return results
59
 
60
- def display_responses(question):
61
- all_responses = generate_all_responses(question)
62
- output = ""
63
- for model_name, prompts in all_responses.items():
64
- output += f"\n### Модель: {model_name}\n"
65
- for prompt_label, content in prompts.items():
66
- output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n"
67
- return output.strip()
 
68
 
 
69
  demo = gr.Interface(
70
  fn=display_responses,
71
- inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"),
72
- outputs=gr.Markdown(label="Ответы от разных моделей"),
73
- title="Alpha Bank Assistant — сравнение моделей",
74
- description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.",
75
- examples=[
76
- "Как восстановить доступ в мобильный банк?",
77
- "Почему с меня списали комиссию за обслуживание карты?",
78
- "Какие условия по потребительскому кредиту?",
79
- ]
80
  )
81
 
82
  if __name__ == "__main__":
 
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
  from datasets import load_dataset
6
 
7
+ # 1) Публичные русскоязычные модели из RuGPT-3
8
  MODEL_CONFIGS = {
9
+ "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
10
+ "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
11
+ "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
12
  }
13
 
14
+ # 2) Устройство (GPU если есть, иначе CPU)
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
+ # 3) Загрузка моделей и токенизаторов
18
  models = {}
19
+ for label, repo_id in MODEL_CONFIGS.items():
20
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
21
+ model = AutoModelForCausalLM.from_pretrained(repo_id)
22
+ model.to(device).eval()
 
23
  models[label] = (tokenizer, model)
24
 
25
+ # 4) (По необходимости) загрузка датасета для примеров / дообучения
26
+ # Если не нужен — можно закомментировать
27
  load_dataset("ZhenDOS/alpha_bank_data", split="train")
28
 
29
+ # 5) CoT-промпты
30
+ def cot_prompt_1(text: str) -> str:
31
+ return (
32
+ f"Клиент задал вопрос: «{text}»\n"
33
+ "Подумай шаг за шагом и подробно объясни ответ от лица банка."
34
+ )
35
 
36
+ def cot_prompt_2(text: str) -> str:
37
+ return (
38
+ f"Вопрос клиента: «{text}»\n"
39
+ "Разложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
40
+ )
41
 
42
+ # 6) Генерация ответов и замер времени
43
+ def generate_all_responses(question: str):
44
  results = {}
45
+ for name, (tokenizer, model) in models.items():
46
+ results[name] = {}
47
+ for idx, prompt_fn in enumerate([cot_prompt_1, cot_prompt_2], start=1):
48
+ prompt = prompt_fn(question)
49
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
50
  inputs = {k: v.to(device) for k, v in inputs.items()}
51
+
52
+ start = time.time()
53
  with torch.no_grad():
54
+ output_ids = model.generate(
55
  **inputs,
56
  max_new_tokens=200,
57
  do_sample=True,
 
59
  top_p=0.9,
60
  eos_token_id=tokenizer.eos_token_id
61
  )
62
+ latency = round(time.time() - start, 2)
63
+
64
+ text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
+ # Убираем повтор промпта
66
+ if text.startswith(prompt):
67
+ text = text[len(prompt):].strip()
68
+
69
+ results[name][f"CoT-промпт {idx}"] = {
70
+ "response": text,
71
+ "time": f"{latency} сек."
72
  }
73
  return results
74
 
75
+ # 7) Оформление Markdown-вывода
76
+ def display_responses(question: str) -> str:
77
+ all_res = generate_all_responses(question)
78
+ md = []
79
+ for model_name, prompts in all_res.items():
80
+ md.append(f"## Модель: **{model_name}**")
81
+ for label, data in prompts.items():
82
+ md.append(f"**{label}** ({data['time']}):\n> {data['response']}")
83
+ return "\n\n".join(md)
84
 
85
+ # 8) Интерфейс Gradio
86
  demo = gr.Interface(
87
  fn=display_responses,
88
+ inputs=gr.Textbox(lines=4, label="Введите вопрос клиента"),
89
+ outputs=gr.Markdown(label="Ответы трёх моделей"),
90
+ title="Alpha Bank Assistant — сравнение CoT-моделей",
91
+ description="Задайте вопрос клиентского обращения и сравните Chain-of-Thought ответы трёх русскоязычных моделей."
 
 
 
 
 
92
  )
93
 
94
  if __name__ == "__main__":