Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import time | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from datasets import load_dataset | |
MODEL_CONFIGS = { | |
"GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2", | |
"ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2", | |
"DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2" | |
} | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
models = {} | |
for label, name in MODEL_CONFIGS.items(): | |
tokenizer = AutoTokenizer.from_pretrained(name) | |
model = AutoModelForCausalLM.from_pretrained(name) | |
model.to(device) | |
model.eval() | |
models[label] = (tokenizer, model) | |
# Загрузка датасета | |
load_dataset("ZhenDOS/alpha_bank_data", split="train") | |
def cot_prompt_1(text): | |
return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка." | |
def cot_prompt_2(text): | |
return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями." | |
def generate_all_responses(question): | |
results = {} | |
for model_name, (tokenizer, model) in models.items(): | |
results[model_name] = {} | |
for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1): | |
prompt = prompt_func(question) | |
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
inputs = {k: v.to(device) for k, v in inputs.items()} | |
start_time = time.time() | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=200, | |
do_sample=True, | |
temperature=0.7, | |
top_p=0.9, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
end_time = time.time() | |
response = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
response = response.replace(prompt, "").strip() | |
duration = round(end_time - start_time, 2) | |
results[model_name][f"CoT Промпт {i}"] = { | |
"response": response, | |
"time": f"{duration} сек." | |
} | |
return results | |
def display_responses(question): | |
all_responses = generate_all_responses(question) | |
output = "" | |
for model_name, prompts in all_responses.items(): | |
output += f"\n### Модель: {model_name}\n" | |
for prompt_label, content in prompts.items(): | |
output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n" | |
return output.strip() | |
demo = gr.Interface( | |
fn=display_responses, | |
inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"), | |
outputs=gr.Markdown(label="Ответы от разных моделей"), | |
title="Alpha Bank Assistant — сравнение моделей", | |
description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.", | |
examples=[ | |
"Как восстановить доступ в мобильный банк?", | |
"Почему с меня списали комиссию за обслуживание карты?", | |
"Какие условия по потребительскому кредиту?", | |
] | |
) | |
if __name__ == "__main__": | |
demo.launch() | |