chat / app.py
SimrusDenuvo's picture
Update app.py
a191742 verified
raw
history blame
3.61 kB
import gradio as gr
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
MODEL_CONFIGS = {
"GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2",
"ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2",
"DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2"
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
models = {}
for label, name in MODEL_CONFIGS.items():
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model.to(device)
model.eval()
models[label] = (tokenizer, model)
# Загрузка датасета
load_dataset("ZhenDOS/alpha_bank_data", split="train")
def cot_prompt_1(text):
return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка."
def cot_prompt_2(text):
return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями."
def generate_all_responses(question):
results = {}
for model_name, (tokenizer, model) in models.items():
results[model_name] = {}
for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1):
prompt = prompt_func(question)
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
inputs = {k: v.to(device) for k, v in inputs.items()}
start_time = time.time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=200,
do_sample=True,
temperature=0.7,
top_p=0.9,
eos_token_id=tokenizer.eos_token_id
)
end_time = time.time()
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = response.replace(prompt, "").strip()
duration = round(end_time - start_time, 2)
results[model_name][f"CoT Промпт {i}"] = {
"response": response,
"time": f"{duration} сек."
}
return results
def display_responses(question):
all_responses = generate_all_responses(question)
output = ""
for model_name, prompts in all_responses.items():
output += f"\n### Модель: {model_name}\n"
for prompt_label, content in prompts.items():
output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n"
return output.strip()
demo = gr.Interface(
fn=display_responses,
inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"),
outputs=gr.Markdown(label="Ответы от разных моделей"),
title="Alpha Bank Assistant — сравнение моделей",
description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.",
examples=[
"Как восстановить доступ в мобильный банк?",
"Почему с меня списали комиссию за обслуживание карты?",
"Какие условия по потребительскому кредиту?",
]
)
if __name__ == "__main__":
demo.launch()