Spaces:
Sleeping
Sleeping
File size: 3,607 Bytes
57645e8 57878bd a191742 af5aa20 bd826f0 a191742 884ac0c a191742 d5e56ca a191742 d0ca3ae a191742 1382bb2 a191742 1382bb2 a191742 af5aa20 a191742 1382bb2 a191742 1382bb2 a191742 1382bb2 a191742 1382bb2 a191742 1382bb2 5dee7eb a191742 5dee7eb 1e10719 90cd180 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import gradio as gr
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
MODEL_CONFIGS = {
"GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
"ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
"DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {}
for label, name in MODEL_CONFIGS.items():
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model.to(device)
model.eval()
models[label] = (tokenizer, model)
# Загрузка датасета
load_dataset("ZhenDOS/alpha_bank_data", split="train")
def cot_prompt_1(text):
return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка."
def cot_prompt_2(text):
return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
def generate_all_responses(question):
results = {}
for model_name, (tokenizer, model) in models.items():
results[model_name] = {}
for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1):
prompt = prompt_func(question)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
start_time = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=200,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id
)
end_time = time.time()
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.replace(prompt, "").strip()
duration = round(end_time - start_time, 2)
results[model_name][f"CoT Промпт {i}"] = {
"response": response,
"time": f"{duration} сек."
}
return results
def display_responses(question):
all_responses = generate_all_responses(question)
output = ""
for model_name, prompts in all_responses.items():
output += f"\n### Модель: {model_name}\n"
for prompt_label, content in prompts.items():
output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n"
return output.strip()
demo = gr.Interface(
fn=display_responses,
inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"),
outputs=gr.Markdown(label="Ответы от разных моделей"),
title="Alpha Bank Assistant — сравнение моделей",
description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.",
examples=[
"Как восстановить доступ в мобильный банк?",
"Почему с меня списали комиссию за обслуживание карты?",
"Какие условия по потребительскому кредиту?",
]
)
if __name__ == "__main__":
demo.launch()
|