SimrusDenuvo commited on
Commit
e78c1cb
·
verified ·
1 Parent(s): e2fc154

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -38
app.py CHANGED
@@ -1,45 +1,73 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
2
  import torch
 
 
3
  import gradio as gr
4
 
5
- # Название модели
6
- model_name = "sberbank-ai/rugpt3medium_based_on_gpt2"
7
-
8
- # Загрузка модели и токенизатора
9
- tokenizer = AutoTokenizer.from_pretrained(model_name)
10
- model = AutoModelForCausalLM.from_pretrained(model_name)
11
-
12
- # Функция для генерации ответа
13
- def generate_response(prompt):
14
- instruction = f"Ответь кратко и по существу на вопрос:\n{prompt.strip()}\nОтвет:"
15
- input_ids = tokenizer.encode(instruction, return_tensors="pt")
16
-
17
- # Параметры генерации для уменьшения времени отклика
18
- output = model.generate(
19
- input_ids,
20
- max_new_tokens=50, # Уменьшение числа токенов для более короткого ответа
21
- do_sample=True,
22
- top_k=50,
23
- top_p=0.95,
24
- temperature=0.7, # Уменьшение случайности
25
- pad_token_id=tokenizer.eos_token_id,
26
- eos_token_id=tokenizer.eos_token_id
27
- )
28
-
29
- # Декодирование ответа и удаление части промпта
30
- response = tokenizer.decode(output[0], skip_special_tokens=True)
31
- return response.replace(instruction, "").strip()
32
-
33
- # Интерфейс Gradio
34
- iface = gr.Interface(
35
- fn=generate_response,
36
- inputs=gr.Textbox(label="Введите ваш вопрос"),
37
- outputs=gr.Textbox(label="Ответ модели"),
38
- title="Ответ от ruGPT-3 Medium",
39
- description="Генерация ответа с помощью модели Sberbank ruGPT-3 Medium"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  )
41
 
42
- # Запуск интерфейса
43
- iface.launch()
44
 
45
 
 
1
+ from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments
2
+ from datasets import load_dataset
3
  import torch
4
+ import pandas as pd
5
+ import numpy as np
6
  import gradio as gr
7
 
8
+ # Загрузка датасета ZhenDOS/alpha_bank_data
9
+ dataset = load_dataset("ZhenDOS/alpha_bank_data")
10
+
11
+ # ✔️ Загрузка базовой модели и токенайзера
12
+ tokenizer = BertTokenizerFast.from_pretrained("DeepPavlov/rubert-base-cased")
13
+ model = BertForSequenceClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels=len(dataset["train"].features["label"].names))
14
+
15
+ # Токенизация входных данных
16
+ def tokenize_function(examples):
17
+ return tokenizer(examples["text"], padding="max_length", truncation=True)
18
+
19
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
20
+
21
+ # 🏃‍♂️ Настройки обучения
22
+ training_args = TrainingArguments(
23
+ output_dir="./results",
24
+ evaluation_strategy="epoch",
25
+ learning_rate=2e-5,
26
+ per_device_train_batch_size=16,
27
+ per_device_eval_batch_size=64,
28
+ num_train_epochs=3,
29
+ weight_decay=0.01,
30
+ )
31
+
32
+ # 💨 Процесс обучения
33
+ trainer = Trainer(
34
+ model=model,
35
+ args=training_args,
36
+ train_dataset=tokenized_datasets["train"],
37
+ eval_dataset=tokenized_datasets["validation"],
38
+ )
39
+
40
+ trainer.train()
41
+
42
+ # 📊 Функционал для демонстрации через Gradio
43
+ def classify_question(question):
44
+ tokens = tokenizer(question, return_tensors="pt")
45
+ outputs = model(**tokens)
46
+ probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
47
+ pred_label_idx = torch.argmax(probabilities, dim=1).item()
48
+ categories = dataset["train"].features["label"].names
49
+ return {
50
+ "Вероятности классов": dict(zip(categories, probabilities.detach().numpy()[0])),
51
+ "Прогнозируемый класс": categories[pred_label_idx],
52
+ }
53
+
54
+ # 🖥️ Графический интерфейс Gradio
55
+ demo = gr.Interface(
56
+ fn=classify_question,
57
+ inputs="text",
58
+ outputs=[
59
+ gr.Label(label="Категории"),
60
+ gr.Textbox(label="Прогнозируемый класс"),
61
+ ],
62
+ examples=[
63
+ ["Как перевести деньги между картами?"],
64
+ ["Что такое кредитная история?"],
65
+ ["Почему моя карта заблокирована?"],
66
+ ],
67
+ title="Классификация клиентских запросов банка",
68
+ description="Приложение помогает определить категорию клиентского запроса и оценить вероятность принадлежности каждого класса.",
69
  )
70
 
71
+ demo.launch()
 
72
 
73