chat / app.py
SimrusDenuvo's picture
Update app.py
74a6389 verified
raw
history blame
3.92 kB
import gradio as gr
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
# 1) Публичные русскоязычные модели из RuGPT-3
MODEL_CONFIGS = {
"GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
"ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
"DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
}
# 2) Устройство (GPU если есть, иначе CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 3) Загрузка моделей и токенизаторов
models = {}
for label, repo_id in MODEL_CONFIGS.items():
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id)
model.to(device).eval()
models[label] = (tokenizer, model)
# 4) (По необходимости) загрузка датасета для примеров / дообучения
# Если не нужен — можно закомментировать
load_dataset("ZhenDOS/alpha_bank_data", split="train")
# 5) CoT-промпты
def cot_prompt_1(text: str) -> str:
return (
f"Клиент задал вопрос: «{text}»\n"
"Подумай шаг за шагом и подробно объясни ответ от лица банка."
)
def cot_prompt_2(text: str) -> str:
return (
f"Вопрос клиента: «{text}»\n"
"Разложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
)
# 6) Генерация ответов и замер времени
def generate_all_responses(question: str):
results = {}
for name, (tokenizer, model) in models.items():
results[name] = {}
for idx, prompt_fn in enumerate([cot_prompt_1, cot_prompt_2], start=1):
prompt = prompt_fn(question)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
start = time.time()
with torch.no_grad():
output_ids = model.generate(
**inputs,
max_new_tokens=200,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id
)
latency = round(time.time() - start, 2)
text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Убираем повтор промпта
if text.startswith(prompt):
text = text[len(prompt):].strip()
results[name][f"CoT-промпт {idx}"] = {
"response": text,
"time": f"{latency} сек."
}
return results
# 7) Оформление Markdown-вывода
def display_responses(question: str) -> str:
all_res = generate_all_responses(question)
md = []
for model_name, prompts in all_res.items():
md.append(f"## Модель: **{model_name}**")
for label, data in prompts.items():
md.append(f"**{label}** ({data['time']}):\n> {data['response']}")
return "\n\n".join(md)
# 8) Интерфейс Gradio
demo = gr.Interface(
fn=display_responses,
inputs=gr.Textbox(lines=4, label="Введите вопрос клиента"),
outputs=gr.Markdown(label="Ответы трёх моделей"),
title="Alpha Bank Assistant — сравнение CoT-моделей",
description="Задайте вопрос клиентского обращения и сравните Chain-of-Thought ответы трёх русскоязычных моделей."
)
if __name__ == "__main__":
demo.launch()