import gradio as gr import torch import time from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset MODEL_CONFIGS = { "GigaChat-like": "ai-forever/rugpt3large_based_on_gpt2", "ChatGPT-like": "ai-forever/rugpt3medium_based_on_gpt2", "DeepSeek-like": "ai-forever/rugpt3small_based_on_gpt2" } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") models = {} for label, name in MODEL_CONFIGS.items(): tokenizer = AutoTokenizer.from_pretrained(name) model = AutoModelForCausalLM.from_pretrained(name) model.to(device) model.eval() models[label] = (tokenizer, model) # Загрузка датасета load_dataset("ZhenDOS/alpha_bank_data", split="train") def cot_prompt_1(text): return f"Клиент задал вопрос: {text}\nПодумай шаг за шагом и объясни, как бы ты ответил на это обращение от лица банка." def cot_prompt_2(text): return f"Вопрос клиента: {text}\nРазложи на части, что именно спрашивает клиент, и предложи логичный ответ с пояснениями." def generate_all_responses(question): results = {} for model_name, (tokenizer, model) in models.items(): results[model_name] = {} for i, prompt_func in enumerate([cot_prompt_1, cot_prompt_2], start=1): prompt = prompt_func(question) inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) inputs = {k: v.to(device) for k, v in inputs.items()} start_time = time.time() with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9, eos_token_id=tokenizer.eos_token_id ) end_time = time.time() response = tokenizer.decode(outputs[0], skip_special_tokens=True) response = response.replace(prompt, "").strip() duration = round(end_time - start_time, 2) results[model_name][f"CoT Промпт {i}"] = { "response": response, "time": f"{duration} сек." } return results def display_responses(question): all_responses = generate_all_responses(question) output = "" for model_name, prompts in all_responses.items(): output += f"\n### Модель: {model_name}\n" for prompt_label, content in prompts.items(): output += f"\n**{prompt_label}** ({content['time']}):\n{content['response']}\n" return output.strip() demo = gr.Interface( fn=display_responses, inputs=gr.Textbox(lines=4, label="Введите клиентский вопрос"), outputs=gr.Markdown(label="Ответы от разных моделей"), title="Alpha Bank Assistant — сравнение моделей", description="Сравнение CoT-ответов от GigaChat, ChatGPT и DeepSeek-подобных моделей на обращение клиента.", examples=[ "Как восстановить доступ в мобильный банк?", "Почему с меня списали комиссию за обслуживание карты?", "Какие условия по потребительскому кредиту?", ] ) if __name__ == "__main__": demo.launch()