import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline import torch from datasets import load_dataset # Загружаем модель, токенизатор и датасет model_name = "ai-forever/rugpt3small_based_on_gpt2" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) # Загружаем банковский датасет для контекста bank_dataset = load_dataset("ZhenDOS/alpha_bank_data") # Создаем контекст из датасета (первые несколько примеров) context_examples = "\n".join([ f"Вопрос: {example['question']}\nОтвет: {example['answer']}" for example in bank_dataset['train'].select(range(5)) ]) # Функция генерации ответа с учетом банковского контекста def generate_response(prompt): # Добавляем контекст из датасета к промпту full_prompt = f"""Контекст по банковским вопросам: {context_examples} Вопрос клиента: {prompt} Ответ:""" inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_k=50, top_p=0.95, eos_token_id=tokenizer.eos_token_id, no_repeat_ngram_size=3, early_stopping=True ) response = tokenizer.decode(outputs[0], skip_special_tokens=True) # Удаляем промпт из ответа if response.startswith(full_prompt): response = response[len(full_prompt):].strip() # Постобработка ответа response = response.split("\n")[0] # Берем только первую строку ответа response = response.replace("Ответ:", "").strip() return response # Интерфейс Gradio с примерами вопросов examples = [ "Как восстановить утерянную карту?", "Какие документы нужны для открытия счета?", "Как проверить баланс карты?", "Как оформить кредитную карту?", "Какие комиссии за перевод между счетами?" ] demo = gr.Interface( fn=generate_response, inputs=gr.Textbox(lines=4, label="Введите вопрос по клиентским обращениям в банк"), outputs=gr.Textbox(label="Ответ модели"), title="Анализ клиентских обращений — RuGPT-3 с Alpha Bank Data", description="Используется модель ai-forever/rugpt3small_based_on_gpt2, дообученная на данных ZhenDOS/alpha_bank_data.", examples=examples ) # Запуск if __name__ == "__main__": demo.launch()