from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments from datasets import load_dataset import torch import pandas as pd import numpy as np import gradio as gr # ❗ Загрузка датасета ZhenDOS/alpha_bank_data dataset = load_dataset("ZhenDOS/alpha_bank_data") # ✔️ Загрузка базовой модели и токенайзера tokenizer = BertTokenizerFast.from_pretrained("DeepPavlov/rubert-base-cased") model = BertForSequenceClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels=len(dataset["train"].features["label"].names)) # ➕ Токенизация входных данных def tokenize_function(examples): return tokenizer(examples["text"], padding="max_length", truncation=True) tokenized_datasets = dataset.map(tokenize_function, batched=True) # 🏃‍♂️ Настройки обучения training_args = TrainingArguments( output_dir="./results", evaluation_strategy="epoch", learning_rate=2e-5, per_device_train_batch_size=16, per_device_eval_batch_size=64, num_train_epochs=3, weight_decay=0.01, ) # 💨 Процесс обучения trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["validation"], ) trainer.train() # 📊 Функционал для демонстрации через Gradio def classify_question(question): tokens = tokenizer(question, return_tensors="pt") outputs = model(**tokens) probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) pred_label_idx = torch.argmax(probabilities, dim=1).item() categories = dataset["train"].features["label"].names return { "Вероятности классов": dict(zip(categories, probabilities.detach().numpy()[0])), "Прогнозируемый класс": categories[pred_label_idx], } # 🖥️ Графический интерфейс Gradio demo = gr.Interface( fn=classify_question, inputs="text", outputs=[ gr.Label(label="Категории"), gr.Textbox(label="Прогнозируемый класс"), ], examples=[ ["Как перевести деньги между картами?"], ["Что такое кредитная история?"], ["Почему моя карта заблокирована?"], ], title="Классификация клиентских запросов банка", description="Приложение помогает определить категорию клиентского запроса и оценить вероятность принадлежности каждого класса.", ) demo.launch()