SimrusDenuvo commited on
Commit
961a138
·
verified ·
1 Parent(s): e07efc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -49
app.py CHANGED
@@ -3,74 +3,68 @@ import torch
3
  import time
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
- # 1) Конфигурация доступных моделей
7
  MODEL_CONFIGS = {
8
- "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
9
- "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
10
- "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
11
  }
12
 
13
- # 2) Выбор устройства
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- # 3) Загрузка токенизаторов и моделей один раз при старте
17
  models = {}
18
- for label, repo_id in MODEL_CONFIGS.items():
19
- try:
20
- tok = AutoTokenizer.from_pretrained(repo_id)
21
- mdl = AutoModelForCausalLM.from_pretrained(repo_id)
22
- mdl.to(device).eval()
23
- models[label] = (tok, mdl)
24
- except Exception as e:
25
- print(f"Не удалось загрузить {repo_id}: {e}")
26
 
27
- # 4) Chain-of-Thought промпты
28
- def cot_prompt_1(q): return f"Клиент: «{q}»\nШаг за шагом объясни ответ от лица банка."
29
- def cot_prompt_2(q): return f"Клиент: «{q}»\nРазбери вопрос на части и дай развёрнутый ответ."
30
 
31
- # 5) Функция генерации
32
- def generate_all_responses(question):
33
- if not question.strip():
34
- return {k: {"error": "Пустой вопрос"} for k in models}
35
  out = {}
36
  for name, (tok, mdl) in models.items():
37
  out[name] = {}
38
- for idx, prm in enumerate((cot_prompt_1, cot_prompt_2), start=1):
39
- prompt = prm(question)
40
- try:
41
- inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
42
- t0 = time.time()
43
- with torch.no_grad():
44
- ids = mdl.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9)
45
- t1 = time.time()
46
- txt = tok.decode(ids[0], skip_special_tokens=True)
47
- if txt.startswith(prompt): txt = txt[len(prompt):].strip()
48
- out[name][f"CoT-промпт {idx}"] = {
49
- "response": txt or "— пустой ответ —",
50
- "time": f"{round(t1-t0,2)} сек."
51
- }
52
- except Exception as e:
53
- out[name][f"CoT-промпт {idx}"] = {"response": f"Ошибка генерации: {e}", "time": "-"}
54
  return out
55
 
56
  # 6) Обёртка для Gradio
57
- def run_all(question):
58
- res = generate_all_responses(question)
59
- md = []
60
- for model_name, prompts in res.items():
61
- md.append(f"### 🔹 {model_name}")
62
- for label, data in prompts.items():
63
- md.append(f"**{label}** ({data['time']}):\n> {data['response']}")
64
- return "\n\n".join(md)
65
 
66
- # 7) Интерфейс Gradio с блоками
67
  with gr.Blocks() as demo:
68
  gr.Markdown("# Alpha Bank Assistant — сравнение CoT-моделей")
69
- inp = gr.Textbox(lines=3, placeholder="Введите вопрос клиента...", label="Вопрос клиента")
70
  btn = gr.Button("Сгенерировать ответы")
71
- out = gr.Markdown(label="Результаты")
72
- btn.click(fn=run_all, inputs=inp, outputs=out)
 
 
 
73
 
74
  if __name__ == "__main__":
75
  demo.launch()
76
 
 
 
3
  import time
4
  from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
+ # 1) Публичные русскоязычные модели
7
  MODEL_CONFIGS = {
8
+ "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
9
+ "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
10
+ "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
11
  }
12
 
13
+ # 2) Устройство
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ # 3) Загружаем модели один раз при старте
17
  models = {}
18
+ for name, repo_id in MODEL_CONFIGS.items():
19
+ tok = AutoTokenizer.from_pretrained(repo_id)
20
+ mdl = AutoModelForCausalLM.from_pretrained(repo_id)
21
+ mdl.to(device).eval()
22
+ models[name] = (tok, mdl)
 
 
 
23
 
24
+ # 4) CoT-промпты
25
+ def cot1(q): return f"Клиент: «{q}»\nШаг за шагом объясни, как ответил бы банк."
26
+ def cot2(q): return f"Клиент: «{q}»\nРазбери запрос и дай развернутый ответ."
27
 
28
+ # 5) Генерация ответов + замер времени
29
+ def generate_all(q):
 
 
30
  out = {}
31
  for name, (tok, mdl) in models.items():
32
  out[name] = {}
33
+ for idx, prm in enumerate((cot1, cot2), start=1):
34
+ prompt = prm(q)
35
+ inputs = tok(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
36
+ t0 = time.time()
37
+ with torch.no_grad():
38
+ ids = mdl.generate(**inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9)
39
+ dt = round(time.time() - t0, 2)
40
+ resp = tok.decode(ids[0], skip_special_tokens=True)
41
+ if resp.startswith(prompt):
42
+ resp = resp[len(prompt):].strip()
43
+ out[name][f"CoT-промпт {idx}"] = f"{resp}\n⏱ {dt} сек."
 
 
 
 
 
44
  return out
45
 
46
  # 6) Обёртка для Gradio
47
+ def run_all(q):
48
+ res = generate_all(q)
49
+ # вернём 3 больших текста: сначала GigaChat-like, потом ChatGPT-like, потом DeepSeek-like
50
+ return (
51
+ "\n\n".join(f"### {k}\n\n" + "\n\n".join(v.values()) for k, v in [("GigaChat-like", res["GigaChat-like"])]),
52
+ "\n\n".join(f"### {k}\n\n" + "\n\n".join(v.values()) for k, v in [("ChatGPT-like", res["ChatGPT-like"])]),
53
+ "\n\n".join(f"### {k}\n\n" + "\n\n".join(v.values()) for k, v in [("DeepSeek-like", res["DeepSeek-like"])]),
54
+ )
55
 
56
+ # 7) Blocks-интерфейс с явным полем вывода
57
  with gr.Blocks() as demo:
58
  gr.Markdown("# Alpha Bank Assistant — сравнение CoT-моделей")
59
+ inp = gr.Textbox(label="Вопрос клиента", placeholder="Например: Как восстановить доступ в мобильный банк?", lines=3)
60
  btn = gr.Button("Сгенерировать ответы")
61
+ # вот поле вывода: три текстовых Textbox’а под кнопкой
62
+ out1 = gr.Textbox(label="GigaChat-like", lines=8)
63
+ out2 = gr.Textbox(label="ChatGPT-like", lines=8)
64
+ out3 = gr.Textbox(label="DeepSeek-like", lines=8)
65
+ btn.click(fn=run_all, inputs=inp, outputs=[out1, out2, out3])
66
 
67
  if __name__ == "__main__":
68
  demo.launch()
69
 
70
+