Spaces:
Running
Running
import gradio as gr | |
import torch | |
import time | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from datasets import load_dataset | |
# 1) Публичные русскоязычные модели из RuGPT-3 | |
MODEL_CONFIGS = { | |
"GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2", | |
"ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2", | |
"DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2" | |
} | |
# 2) Устройство (GPU если есть, иначе CPU) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# 3) Загрузка моделей и токенизаторов | |
models = {} | |
for label, repo_id in MODEL_CONFIGS.items(): | |
tokenizer = AutoTokenizer.from_pretrained(repo_id) | |
model = AutoModelForCausalLM.from_pretrained(repo_id) | |
model.to(device).eval() | |
models[label] = (tokenizer, model) | |
# 4) (По необходимости) загрузка датасета для примеров / дообучения | |
# Если не нужен — можно закомментировать | |
load_dataset("ZhenDOS/alpha_bank_data", split="train") | |
# 5) CoT-промпты | |
def cot_prompt_1(text: str) -> str: | |
return ( | |
f"Клиент задал вопрос: «{text}»\n" | |
"Подумай шаг за шагом и подробно объясни ответ от лица банка." | |
) | |
def cot_prompt_2(text: str) -> str: | |
return ( | |
f"Вопрос клиента: «{text}»\n" | |
"Разложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями." | |
) | |
# 6) Генерация ответов и замер времени | |
def generate_all_responses(question: str): | |
results = {} | |
for name, (tokenizer, model) in models.items(): | |
results[name] = {} | |
for idx, prompt_fn in enumerate([cot_prompt_1, cot_prompt_2], start=1): | |
prompt = prompt_fn(question) | |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
inputs = {k: v.to(device) for k, v in inputs.items()} | |
start = time.time() | |
with torch.no_grad(): | |
output_ids = model.generate( | |
**inputs, | |
max_new_tokens=200, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
latency = round(time.time() - start, 2) | |
text = tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
# Убираем повтор промпта | |
if text.startswith(prompt): | |
text = text[len(prompt):].strip() | |
results[name][f"CoT-промпт {idx}"] = { | |
"response": text, | |
"time": f"{latency} сек." | |
} | |
return results | |
# 7) Оформление Markdown-вывода | |
def display_responses(question: str) -> str: | |
all_res = generate_all_responses(question) | |
md = [] | |
for model_name, prompts in all_res.items(): | |
md.append(f"## Модель: **{model_name}**") | |
for label, data in prompts.items(): | |
md.append(f"**{label}** ({data['time']}):\n> {data['response']}") | |
return "\n\n".join(md) | |
# 8) Интерфейс Gradio | |
demo = gr.Interface( | |
fn=display_responses, | |
inputs=gr.Textbox(lines=4, label="Введите вопрос клиента"), | |
outputs=gr.Markdown(label="Ответы трёх моделей"), | |
title="Alpha Bank Assistant — сравнение CoT-моделей", | |
description="Задайте вопрос клиентского обращения и сравните Chain-of-Thought ответы трёх русскоязычных моделей." | |
) | |
if __name__ == "__main__": | |
demo.launch() | |