SimrusDenuvo commited on
Commit
22193d7
·
verified ·
1 Parent(s): e78c1cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -66
app.py CHANGED
@@ -1,73 +1,21 @@
1
- from transformers import BertForSequenceClassification, BertTokenizerFast, Trainer, TrainingArguments
2
- from datasets import load_dataset
3
- import torch
4
- import pandas as pd
5
- import numpy as np
6
- import gradio as gr
7
 
8
- # Загрузка датасета ZhenDOS/alpha_bank_data
9
- dataset = load_dataset("ZhenDOS/alpha_bank_data")
 
10
 
11
- # ✔️ Загрузка базовой модели и токенайзера
12
- tokenizer = BertTokenizerFast.from_pretrained("DeepPavlov/rubert-base-cased")
13
- model = BertForSequenceClassification.from_pretrained("DeepPavlov/rubert-base-cased", num_labels=len(dataset["train"].features["label"].names))
14
 
15
- # ➕ Токенизация входных данных
16
- def tokenize_function(examples):
17
- return tokenizer(examples["text"], padding="max_length", truncation=True)
18
 
19
- tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
20
 
21
- # 🏃‍♂️ Настройки обучения
22
- training_args = TrainingArguments(
23
- output_dir="./results",
24
- evaluation_strategy="epoch",
25
- learning_rate=2e-5,
26
- per_device_train_batch_size=16,
27
- per_device_eval_batch_size=64,
28
- num_train_epochs=3,
29
- weight_decay=0.01,
30
- )
31
-
32
- # 💨 Процесс обучения
33
- trainer = Trainer(
34
- model=model,
35
- args=training_args,
36
- train_dataset=tokenized_datasets["train"],
37
- eval_dataset=tokenized_datasets["validation"],
38
- )
39
-
40
- trainer.train()
41
-
42
- # 📊 Функционал для демонстрации через Gradio
43
- def classify_question(question):
44
- tokens = tokenizer(question, return_tensors="pt")
45
- outputs = model(**tokens)
46
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
47
- pred_label_idx = torch.argmax(probabilities, dim=1).item()
48
- categories = dataset["train"].features["label"].names
49
- return {
50
- "Вероятности классов": dict(zip(categories, probabilities.detach().numpy()[0])),
51
- "Прогнозируемый класс": categories[pred_label_idx],
52
- }
53
-
54
- # 🖥️ Графический интерфейс Gradio
55
- demo = gr.Interface(
56
- fn=classify_question,
57
- inputs="text",
58
- outputs=[
59
- gr.Label(label="Категории"),
60
- gr.Textbox(label="Прогнозируемый класс"),
61
- ],
62
- examples=[
63
- ["Как перевести деньги между картами?"],
64
- ["Что такое кредитная история?"],
65
- ["Почему моя карта заблокирована?"],
66
- ],
67
- title="Классификация клиентских запросов банка",
68
- description="Приложение помогает определить категорию клиентского запроса и оценить вероятность принадлежности каждого класса.",
69
- )
70
-
71
- demo.launch()
72
 
73
 
 
1
+ from transformers import pipeline
 
 
 
 
 
2
 
3
+ clf = pipeline(
4
+ task = 'sentiment-analysis',
5
+ model = 'SkolkovoInstitute/russian_toxicity_classifier')
6
 
7
+ text = ['У нас в есть убунты и текникал превью.',
8
+ 'Как минимум два малолетних дегенерата в треде, мда.']
 
9
 
10
+ def data(text):
11
+ for row in text:
12
+ yield row
13
 
14
+ for out in clf(data(text)):
15
+ print(out)
16
 
17
+ #вывод
18
+ {'label': 'neutral', 'score': 0.9872767329216003}
19
+ {'label': 'toxic', 'score': 0.985331654548645}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21