SimrusDenuvo commited on
Commit
e07efc3
·
verified ·
1 Parent(s): f826a0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -72
app.py CHANGED
@@ -2,94 +2,75 @@ import gradio as gr
2
  import torch
3
  import time
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
- from datasets import load_dataset
6
 
7
- # 1) Публичные русскоязычные модели из RuGPT-3
8
  MODEL_CONFIGS = {
9
- "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
10
- "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
11
- "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
12
  }
13
 
14
- # 2) Устройство (GPU если есть, иначе CPU)
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
 
17
- # 3) Загрузка моделей и токенизаторов
18
  models = {}
19
  for label, repo_id in MODEL_CONFIGS.items():
20
- tokenizer = AutoTokenizer.from_pretrained(repo_id)
21
- model = AutoModelForCausalLM.from_pretrained(repo_id)
22
- model.to(device).eval()
23
- models[label] = (tokenizer, model)
 
 
 
24
 
25
- # 4) (По необходимости) загрузка датасета для примеров / дообучения
26
- # Если не нужен можно закомментировать
27
- load_dataset("ZhenDOS/alpha_bank_data", split="train")
28
 
29
- # 5) CoT-промпты
30
- def cot_prompt_1(text: str) -> str:
31
- return (
32
- f"Клиент задал вопрос: «{text}»\n"
33
- "Подумай шаг за шагом и подробно объясни ответ от лица банка."
34
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- def cot_prompt_2(text: str) -> str:
37
- return (
38
- f"Вопрос клиента: «{text}»\n"
39
- "Разложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
40
- )
41
-
42
- # 6) Генерация ответов и замер времени
43
- def generate_all_responses(question: str):
44
- results = {}
45
- for name, (tokenizer, model) in models.items():
46
- results[name] = {}
47
- for idx, prompt_fn in enumerate([cot_prompt_1, cot_prompt_2], start=1):
48
- prompt = prompt_fn(question)
49
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
50
- inputs = {k: v.to(device) for k, v in inputs.items()}
51
-
52
- start = time.time()
53
- with torch.no_grad():
54
- output_ids = model.generate(
55
- **inputs,
56
- max_new_tokens=200,
57
- do_sample=True,
58
- temperature=0.7,
59
- top_p=0.9,
60
- eos_token_id=tokenizer.eos_token_id
61
- )
62
- latency = round(time.time() - start, 2)
63
-
64
- text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
65
- # Убираем повтор промпта
66
- if text.startswith(prompt):
67
- text = text[len(prompt):].strip()
68
-
69
- results[name][f"CoT-промпт {idx}"] = {
70
- "response": text,
71
- "time": f"{latency} сек."
72
- }
73
- return results
74
-
75
- # 7) Оформление Markdown-вывода
76
- def display_responses(question: str) -> str:
77
- all_res = generate_all_responses(question)
78
  md = []
79
- for model_name, prompts in all_res.items():
80
- md.append(f"## Модель: **{model_name}**")
81
  for label, data in prompts.items():
82
  md.append(f"**{label}** ({data['time']}):\n> {data['response']}")
83
  return "\n\n".join(md)
84
 
85
- # 8) Интерфейс Gradio
86
- demo = gr.Interface(
87
- fn=display_responses,
88
- inputs=gr.Textbox(lines=4, label="Введите вопрос клиента"),
89
- outputs=gr.Markdown(label="Ответы трёх моделей"),
90
- title="Alpha Bank Assistant — сравнение CoT-моделей",
91
- description="Задайте вопрос клиентского обращения и сравните Chain-of-Thought ответы трёх русскоязычных моделей."
92
- )
93
 
94
  if __name__ == "__main__":
95
  demo.launch()
 
 
2
  import torch
3
  import time
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
5
 
6
+ # 1) Конфигурация доступных моделей
7
  MODEL_CONFIGS = {
8
+ "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
9
+ "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
10
+ "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
11
  }
12
 
13
+ # 2) Выбор устройства
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ # 3) Загрузка токенизаторов и моделей один раз при старте
17
  models = {}
18
  for label, repo_id in MODEL_CONFIGS.items():
19
+ try:
20
+ tok = AutoTokenizer.from_pretrained(repo_id)
21
+ mdl = AutoModelForCausalLM.from_pretrained(repo_id)
22
+ mdl.to(device).eval()
23
+ models[label] = (tok, mdl)
24
+ except Exception as e:
25
+ print(f"Не удалось загрузить {repo_id}: {e}")
26
 
27
+ # 4) Chain-of-Thought промпты
28
+ def cot_prompt_1(q): return f"Клиент: «{q}»\nШаг за шагом объясни ответ от лица банка."
29
+ def cot_prompt_2(q): return f"Клиент: «{q}»\nРазбери вопрос на части и дай развёрнутый ответ."
30
 
31
+ # 5) Функция генерации
32
+ def generate_all_responses(question):
33
+ if not question.strip():
34
+ return {k: {"error": "Пустой вопрос"} for k in models}
35
+ out = {}
36
+ for name, (tok, mdl) in models.items():
37
+ out[name] = {}
38
+ for idx, prm in enumerate((cot_prompt_1, cot_prompt_2), start=1):
39
+ prompt = prm(question)
40
+ try:
41
+ inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
42
+ t0 = time.time()
43
+ with torch.no_grad():
44
+ ids = mdl.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9)
45
+ t1 = time.time()
46
+ txt = tok.decode(ids[0], skip_special_tokens=True)
47
+ if txt.startswith(prompt): txt = txt[len(prompt):].strip()
48
+ out[name][f"CoT-промпт {idx}"] = {
49
+ "response": txt or "— пустой ответ —",
50
+ "time": f"{round(t1-t0,2)} сек."
51
+ }
52
+ except Exception as e:
53
+ out[name][f"CoT-промпт {idx}"] = {"response": f"Ошибка генерации: {e}", "time": "-"}
54
+ return out
55
 
56
+ # 6) Обёртка для Gradio
57
+ def run_all(question):
58
+ res = generate_all_responses(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  md = []
60
+ for model_name, prompts in res.items():
61
+ md.append(f"### 🔹 {model_name}")
62
  for label, data in prompts.items():
63
  md.append(f"**{label}** ({data['time']}):\n> {data['response']}")
64
  return "\n\n".join(md)
65
 
66
+ # 7) Интерфейс Gradio с блоками
67
+ with gr.Blocks() as demo:
68
+ gr.Markdown("# Alpha Bank Assistant — сравнение CoT-моделей")
69
+ inp = gr.Textbox(lines=3, placeholder="Введите вопрос клиента...", label="Вопрос клиента")
70
+ btn = gr.Button("Сгенерировать ответы")
71
+ out = gr.Markdown(label="Результаты")
72
+ btn.click(fn=run_all, inputs=inp, outputs=out)
 
73
 
74
  if __name__ == "__main__":
75
  demo.launch()
76
+